# Imports

In [9]:
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing, FrameStack
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
from stable_baselines3.common.atari_wrappers import ClipRewardEnv, MaxAndSkipEnv, EpisodicLifeEnv, FireResetEnv
from stable_baselines3.common.vec_env import DummyVecEnv, VecTransposeImage

import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt

# Environment Setup: Atari Preprocessing and Frame Stacking

In [17]:
def makeEnv(envId: str, renderMode: str, frameSkip: int = 1, nStack: int = 4):
    """Create one preprocessed Atari environment whose observations are frame stacks.

    Pipeline: ``gym.make`` -> ``AtariPreprocessing`` -> ``FrameStack``.

    Args:
        envId: Gymnasium environment id, e.g. ``'ALE/Zaxxon-v5'``.
        renderMode: Forwarded to ``gym.make`` as ``render_mode``
            (e.g. ``'human'``, ``'rgb_array'`` or ``None``).
        frameSkip: ``frame_skip`` passed to ``AtariPreprocessing``. Defaults to 1
            because ALE ``v5`` ids presumably already skip frames inside
            ``gym.make`` -- NOTE(review): confirm the env's own frameskip before
            raising this, or frames will be skipped twice.
        nStack: How many of the most recent preprocessed frames each
            observation contains (``FrameStack`` size). Defaults to 4.

    Returns:
        The wrapped environment. Its observations are already frame-stacked,
        so wrapping the vectorized env in ``VecFrameStack`` as well would stack
        the stacks.

    Raises:
        ValueError: If ``nStack`` is smaller than 1.
    """
    # Validate before creating (and possibly opening a window for) the env.
    if nStack < 1:
        raise ValueError(f"nStack must be >= 1, got {nStack}")

    env = gym.make(envId, render_mode=renderMode)
    # All other AtariPreprocessing options are left at their library defaults.
    env = AtariPreprocessing(env, frame_skip=frameSkip)
    env = FrameStack(env, nStack)

    return env

In [18]:
envId = 'ALE/Zaxxon-v5'
# NOTE: 'human' rendering draws every step on screen; it is likely a major reason
# the training log below reports only ~5-7 fps. Use renderMode = None for a
# faster training run if you do not need to watch it.
renderMode = 'human'

# makeEnv already stacks the last 4 frames with gymnasium's FrameStack, so the
# vectorized env must NOT also be wrapped in VecFrameStack: that stacks the
# stacks (4 x 4 = 16 channels), duplicating every frame and making each
# observation -- and therefore the replay buffer -- 4x larger.
env = DummyVecEnv([lambda: makeEnv(envId, renderMode)])

# DQN Model Training

In [14]:
# Fixed seed so that re-running the notebook reproduces the same training run.
SEED = 42
TOTAL_TIMESTEPS = 10000

model = DQN(
    'CnnPolicy',
    env,
    verbose=1,
    buffer_size=10000,            # replay buffer capacity (transitions)
    learning_starts=1000,         # transitions collected before the first gradient step
    batch_size=32,
    gamma=0.99,
    target_update_interval=1000,  # env steps between target-network syncs
    train_freq=4,                 # run a training round every 4 env steps...
    gradient_steps=1,             # ...consisting of one gradient step
    # NOTE(review): epsilon decays over 10% of TOTAL_TIMESTEPS (~1000 steps), i.e.
    # it reaches its floor about when learning starts (the log shows 0.01 already
    # by episode 4). Consider a larger fraction if exploration looks insufficient.
    exploration_fraction=0.1,
    exploration_final_eps=0.01,
    seed=SEED,
)
model.learn(total_timesteps=TOTAL_TIMESTEPS)
model.save('dqn_zaxxon')  # written as 'dqn_zaxxon.zip', which the next section loads

Using cpu device




----------------------------------
| rollout/            |          |
|    exploration_rate | 0.01     |
| time/               |          |
|    episodes         | 4        |
|    fps              | 5        |
|    time_elapsed     | 656      |
|    total_timesteps  | 3448     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 4.02e-07 |
|    n_updates        | 611      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.01     |
| time/               |          |
|    episodes         | 8        |
|    fps              | 7        |
|    time_elapsed     | 877      |
|    total_timesteps  | 7011     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.77e-06 |
|    n_updates        | 1502     |
----------------------------------


# Load the Trained Model and Watch It Play

In [15]:
model = DQN.load('dqn_zaxxon.zip')

# Number of environment steps to watch the trained agent for.
nEvalSteps = 10000

obs = env.reset()
try:
    for _ in range(nEvalSteps):
        # deterministic=True takes the greedy (argmax-Q) action. With the default
        # deterministic=False, DQN.predict still acts randomly with probability
        # model.exploration_rate, which is not what we want when evaluating.
        action, _states = model.predict(obs, deterministic=True)
        # SB3 vectorized envs reset finished episodes automatically, so `dones`
        # needs no handling here.
        obs, rewards, dones, info = env.step(action)
        env.render()
finally:
    # Release the environment (and its render window) even if the loop is
    # interrupted. Re-create `env` (environment-setup cell) before re-running this cell.
    env.close()

