# 1 - Test the environment

In [None]:
import warnings

In [None]:
# Silence deprecation and user warnings for the rest of the session.
# The filters are registered in the same order as before (DeprecationWarning first).
for _category in (DeprecationWarning, UserWarning):
    warnings.simplefilter('ignore', category=_category)

In [None]:
import gym

from matplotlib import pyplot as plt
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

## 1.1 - Play games with random actions

In [None]:
# Create a single Space Invaders environment through gym's registry.
env = gym.make('ALE/SpaceInvaders-v5')

In [None]:
# Show the action meanings reported by the unwrapped environment (cell output).
env.unwrapped.get_action_meanings()

In [None]:
# Play five episodes with uniformly sampled actions and print each episode's score.
for episode in range(5):
    obs = env.reset()
    score = 0

    # Keep stepping until the environment flags the episode as finished.
    while True:
        obs, reward, done, info = env.step(env.action_space.sample())
        score += reward
        if done:
            break

    print(f'Episode : {episode + 1} --> Score : {score}')

# Release the emulator once all episodes have been played.
env.close()

## 1.2 - Stack frames to represent movements

In [None]:
# Vectorized Atari environment whose observations stack the last 4 frames.
env = VecFrameStack(make_atari_env('ALE/SpaceInvaders-v5'), n_stack=4)

In [None]:
obs = env.reset()

# Take 10 random steps; before each one, plot the frames stacked in the observation.
for i in range(10):
    plt.figure(figsize=(20, 16))

    # The last axis of the observation indexes the stacked frames (one panel each).
    n_frames = obs.shape[3]
    for position in range(1, n_frames + 1):
        plt.subplot(1, 4, position)
        plt.imshow(obs[0, :, :, position - 1])

    plt.show()

    # The vectorized env expects one action per sub-environment, hence the list.
    action = env.action_space.sample()
    obs, reward, done, info = env.step([action])

env.close()

# 2 - Train and evaluate models

In [None]:
import numpy as np

from sb3_contrib import MaskablePPO, QRDQN
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.type_aliases import Schedule

In [None]:
def mask_fn(_env: gym.Env) -> np.ndarray:
    """Return the fixed action mask applied to the Space Invaders environment.

    The mask has one entry per discrete action: ``True`` means the action is
    allowed, ``False`` means it is masked out. The first four actions stay
    available and the last two are disabled.

    NOTE(review): the action order is presumably
    NOOP, FIRE, RIGHT, LEFT, RIGHTFIRE, LEFTFIRE (i.e. the combined
    move-and-fire actions are the ones disabled) -- confirm against the output
    of ``env.unwrapped.get_action_meanings()``.

    Args:
        _env: The environment being masked. Unused, because the mask does not
            depend on the environment state.

    Returns:
        A boolean array of length 6.
    """
    # Return a real ndarray so the value matches the declared return type.
    return np.array([True, True, True, True, False, False], dtype=bool)

In [None]:
class ActionMaskerAtariWrapper(gym.Wrapper):
    """Single wrapper that applies action masking followed by Atari preprocessing.

    Bundling both wrappers into one class lets it be passed as the
    ``wrapper_class`` argument of ``make_vec_env``.
    """

    def __init__(self, _env: gym.Env):
        """Wrap ``_env`` with ``ActionMasker`` (using ``mask_fn``), then ``AtariWrapper``."""
        masked_env = ActionMasker(_env, mask_fn)
        super().__init__(AtariWrapper(masked_env))

In [None]:
def linear_schedule(initial_value: float) -> Schedule:
    """Build a schedule that scales ``initial_value`` by the remaining progress.

    Args:
        initial_value: Value returned when the remaining progress is 1.

    Returns:
        A callable mapping ``progress_remaining`` to
        ``progress_remaining * initial_value``.
    """
    def schedule(progress_remaining: float) -> float:
        # Linear in the remaining progress: full value at 1, zero at 0.
        return initial_value * progress_remaining

    return schedule

## 2.1 - Setup train and evaluation environments

In [None]:
# Training environment: 8 parallel masked Atari envs with 4 stacked frames each.
env = VecFrameStack(
    make_vec_env('ALE/SpaceInvaders-v5',
                 n_envs=8,
                 wrapper_class=ActionMaskerAtariWrapper),
    n_stack=4,
)

In [None]:
# Evaluation environment: same wrappers as the training env, default env count.
eval_env = VecFrameStack(
    make_vec_env('ALE/SpaceInvaders-v5',
                 wrapper_class=ActionMaskerAtariWrapper),
    n_stack=4,
)

## 2.2 - Setup checkpoint and evaluation callbacks

In [None]:
# Periodically save the model under 'logs/' using the 'space_invaders' file prefix.
# NOTE(review): save_freq is presumably counted in callback calls (one per
# vectorized step), so with 8 training envs 125_000 calls ~ 1M timesteps -- confirm.
checkpoint_callback = CheckpointCallback(save_freq=125_000,
                                         save_path='logs/',
                                         name_prefix='space_invaders')

In [None]:
# Periodically evaluate the agent on eval_env and keep the best model under 'logs/'.
# NOTE(review): eval_freq is presumably counted like save_freq above
# (per callback call, not per timestep) -- confirm.
eval_callback = EvalCallback(eval_env,
                             best_model_save_path='logs/',
                             eval_freq=12_500)

In [None]:
# Bundle both callbacks so a single object can be passed to model.learn().
callback_list = CallbackList([checkpoint_callback, eval_callback])

## 2.3 - Train a model with [QR-DQN algorithm](https://sb3-contrib.readthedocs.io/en/master/modules/qrdqn.html)

In [None]:
# Reuse the shared callbacks, but route QR-DQN checkpoints and the best model
# into their own 'qrdqn' sub-folder of 'logs/'.
checkpoint_callback.name_prefix = 'qrdqn/space_invaders'
eval_callback.best_model_save_path = 'logs/qrdqn/'

In [None]:
# QR-DQN agent with a CNN policy on the frame-stacked training env.
# Training metrics are logged for TensorBoard under 'logs/tensorboard/'.
model = QRDQN('CnnPolicy',
              env,
              optimize_memory_usage=True,
              exploration_fraction=0.025,
              verbose=1,
              tensorboard_log='logs/tensorboard/')

In [None]:
# Train for 10M timesteps with checkpointing and evaluation attached,
# then save the final model as 'qrdqn_space_invaders'.
model.learn(total_timesteps=int(1e7),
            callback=callback_list,
            log_interval=2000,
            tb_log_name='qrdqn_space_invaders')
model.save('qrdqn_space_invaders')

## 2.4 - Train a model with [Maskable PPO algorithm](https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html)

In [None]:
# Route Maskable PPO checkpoints and the best model into their own
# 'maskable_ppo' sub-folder of 'logs/'.
checkpoint_callback.name_prefix = 'maskable_ppo/space_invaders'
eval_callback.best_model_save_path = 'logs/maskable_ppo/'

In [None]:
# Maskable PPO agent with a CNN policy on the frame-stacked training env.
# Learning rate and clip range are both scaled by the remaining training
# progress through linear_schedule.
model = MaskablePPO('CnnPolicy',
                    env,
                    learning_rate=linear_schedule(2.5e-4),
                    n_steps=128,
                    batch_size=256,
                    n_epochs=4,
                    clip_range=linear_schedule(0.1),
                    ent_coef=0.01,
                    vf_coef=0.5,
                    verbose=1,
                    tensorboard_log='logs/tensorboard/')

In [None]:
# Train for 10M timesteps with checkpointing and evaluation attached,
# then save the final model as 'maskable_ppo_space_invaders'.
model.learn(total_timesteps=int(1e7),
            callback=callback_list,
            log_interval=100,
            tb_log_name='maskable_ppo_space_invaders')
model.save('maskable_ppo_space_invaders')

## 2.5 - Error : Python process interruption

The Python kernel crashed shortly after 6M steps, so a new model is initialized from the parameters saved in the 6M-step checkpoint (zip file) and trained for 4M more steps, reaching 10M steps in total.

In [None]:
# Write the resumed run's checkpoints to a separate 'follow-up' sub-folder.
checkpoint_callback.name_prefix = 'maskable_ppo/follow-up/space_invaders'

In [None]:
# Rebuild the Maskable PPO agent for the resumed run.
# The schedule start values are about 40% of the first run's (2.5e-4 -> 1.0001e-4,
# 0.1 -> 0.04), presumably so the linear decay continues from where the
# 6M-of-10M-step run stopped -- TODO confirm the intent of the 1.0001e-4 value.
model = MaskablePPO('CnnPolicy',
                    env,
                    learning_rate=linear_schedule(1.0001e-4),
                    n_steps=128,
                    batch_size=256,
                    n_epochs=4,
                    clip_range=linear_schedule(0.04),
                    ent_coef=0.01,
                    vf_coef=0.5,
                    verbose=1,
                    tensorboard_log='logs/tensorboard/')

In [None]:
# Load the parameters saved in the 6M-step checkpoint into the freshly built model.
model.set_parameters('logs/maskable_ppo/space_invaders_6000000_steps.zip')

In [None]:
# Train the remaining 4M timesteps, then save under the same final-model name
# as the interrupted run ('maskable_ppo_space_invaders').
model.learn(total_timesteps=int(4e6),
            callback=callback_list,
            log_interval=100,
            tb_log_name='maskable_ppo_space_invaders')
model.save('maskable_ppo_space_invaders')

# 3 - See the results of trained models

In [None]:
def demo(_model, _env):
    """Evaluate ``_model`` on ``_env`` with rendering enabled and print the result.

    Args:
        _model: Trained model passed to ``evaluate_policy``.
        _env: Environment on which the model is evaluated and rendered.

    Prints the mean reward and its standard deviation, both to two decimals.
    """
    reward_stats = evaluate_policy(_model, _env, render=True)
    mean_reward, std_reward = reward_stats
    print(f'mean_reward = {mean_reward:.2f} +/- {std_reward:.2f}')

## 3.1 - Setup the demonstration environment

In [None]:
# Demonstration environment: same wrappers and frame stacking as used for training.
demo_env = VecFrameStack(
    make_vec_env('ALE/SpaceInvaders-v5',
                 wrapper_class=ActionMaskerAtariWrapper),
    n_stack=4,
)

## 3.2 - Demo of models with [QR-DQN algorithm](https://sb3-contrib.readthedocs.io/en/master/modules/qrdqn.html)

In [None]:
# QR-DQN: the best model kept by the evaluation callback.
demo_model = QRDQN.load('logs/qrdqn/best_model')
demo(demo_model, demo_env)

In [None]:
# QR-DQN: the model saved at the end of training.
demo_model = QRDQN.load('qrdqn_space_invaders')
demo(demo_model, demo_env)

## 3.3 - Demo of models with [Maskable PPO algorithm](https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html)

In [None]:
# Maskable PPO: the best model kept by the evaluation callback.
demo_model = MaskablePPO.load('logs/maskable_ppo/best_model')
demo(demo_model, demo_env)

In [None]:
# Maskable PPO: the model saved at the end of training.
demo_model = MaskablePPO.load('maskable_ppo_space_invaders')
demo(demo_model, demo_env)

# Bonus - Make a gif of the best model

In [None]:
import imageio

In [None]:
# Frame-stacked Atari environment dedicated to recording the GIF.
gif_env = VecFrameStack(make_atari_env('ALE/SpaceInvaders-v5'), n_stack=4)

In [None]:
# Load the final Maskable PPO model and attach the GIF environment to it.
gif_model = MaskablePPO.load('maskable_ppo_space_invaders')
gif_model.set_env(gif_env)

In [None]:
# Roll the trained model out for 350 steps, capturing the rendered frame that
# precedes each step.
images = []
obs = gif_env.reset()
img = gif_env.render(mode='rgb_array')

for i in range(350):
    images.append(img)
    action, _ = gif_model.predict(obs)
    obs, _, _, _ = gif_env.step(action)
    img = gif_env.render(mode='rgb_array')

# Keep every second frame (indices 0, 2, 4, ...) when writing the GIF.
imageio.mimsave('test_space_invaders.gif',
                [np.array(img) for img in images[::2]],
                fps=29)

# Close the environment that was actually used for recording; the original
# closed the unrelated training `env` and left `gif_env` open.
gif_env.close()