In [86]:
FNAME = "atari_empty_16x16_plain_dqn_3"
import numpy as np

import torch
import torch.nn as nn

from stable_baselines3 import DQN
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CallbackList
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack


import gymnasium as gym
import ale_py
gym.register_envs(ale_py)

from gymnasium.wrappers import FrameStackObservation, ClipReward

from IPython import display

In [87]:
# Function to reset weights
def reset_weights(layer):
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        layer.reset_parameters()

# Custom callback to reset weights during training
class ResetWeightsCallback(BaseCallback):
    def __init__(self, reset_interval, verbose=0):
        super().__init__(verbose)
        self.reset_interval = reset_interval  # Number of steps between resets

    def _on_step(self) -> bool:
        # Reset weights every reset_interval steps
        if self.n_calls % self.reset_interval == 0: # n_calls inherited from BaseCallback
            # if self.verbose > 0:
            #     print(f"Resetting weights at step {self.n_calls}...")
            print(f"Policy weight reset at: {self.n_calls}")
            # Reset q_net and q_net_target
            self.model.policy.q_net.apply(reset_weights)
            self.model.policy.q_net_target.apply(reset_weights)
        return True

In [88]:
n_stack = 4 # run updates once every 4 frames (stack 4 frames for the model)
eval_freq = 5000 # once every 5000 timesteps, evaluate the model

In [89]:
env = make_atari_env("AlienNoFrameskip-v4", n_envs=n_stack) #seed can be used here
env = VecFrameStack(env, n_stack= n_stack)
eval_env = make_atari_env("AlienNoFrameskip-v4", n_envs= n_stack) #seed can be used here, different than env's seed
eval_env = VecFrameStack(eval_env, n_stack= n_stack)

In [90]:
log_path = f"./logs/sb3_atari_dqn_1"
policy_kwargs = dict()
# policy_kwargs.update(num_agent=1)
# policy_kwargs.update(action_select_coef=50)

In [92]:
timesteps = 100000
replay_ratio = 1
reset_interval = 80000

In [93]:
eval_callback = EvalCallback(env, best_model_save_path=log_path, log_path=log_path,
                             eval_freq=max(eval_freq // n_stack, 1), deterministic=True,
                             render=True)
# Create and attach the callback
reset_callback = ResetWeightsCallback(reset_interval=reset_interval // n_stack, verbose=1)

callback_list = CallbackList([eval_callback, reset_callback])

In [94]:
model = DQN(
    policy= "CnnPolicy", 
    env= env, 
    verbose= 1, 
    buffer_size= timesteps,
    learning_starts= 2000,
    tau= 0.005,
    train_freq= (1, "step"),
    gradient_steps= replay_ratio,
    target_update_interval= 1,
    policy_kwargs= policy_kwargs,
    tensorboard_log="./dqn_atari_logs",
    )
# need reset, reset_frequency and all_reset
model.learn(
    total_timesteps=timesteps,
    callback=eval_callback
    )

Using cuda device
Wrapping the env in a VecTransposeImage.




Logging to ./dqn_atari_logs/DQN_3




----------------------------------
| rollout/            |          |
|    exploration_rate | 0.913    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 782      |
|    time_elapsed     | 1        |
|    total_timesteps  | 916      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.804    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 930      |
|    time_elapsed     | 2        |
|    total_timesteps  | 2064     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0153   |
|    n_updates        | 15       |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 2.54e+03 |
|    ep_rew_mean      | 200      |
|    exploration_rate | 0.729    |
| time/               |          |
|    episodes       

: 

In [83]:
env.close()
eval_env.close()

In [84]:
model.save(f"./models/{FNAME}")

In [85]:
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, info = vec_env.step(action)
    vec_env.render("human")

print(f"mean_reward: {mean_reward}, std_reward:{std_reward}")

mean_reward: 522.0, std_reward:247.37825288412077
