In [None]:
!pip install gymnasium[atari] stable-baselines3[extra] torch tensorboard ale-py
!pip install jupyter matplotlib


In [None]:
import gymnasium as gym
from stable_baselines3 import DQN
import ale_py
import numpy as np
from stable_baselines3.common.callbacks import BaseCallback

class TrackProgressCallback(BaseCallback):
    """Record finished-episode rewards and periodically checkpoint the model.

    After every environment step the callback looks at the ``infos`` entries of
    the current rollout step. Each entry carrying an ``"episode"`` record
    (presumably added by SB3's ``Monitor`` wrapper -- confirm) contributes its
    reward ``r`` to ``episode_rewards`` and a running mean over the most recent
    ``avg_window`` episodes to ``avg_rewards``.

    Roughly every ``save_freq`` timesteps the model is saved to
    ``dqn_breakout_step_<num_timesteps>``. When training ends both reward
    histories are written to ``episode_rewards.npy`` and
    ``average_rewards.npy`` in the current working directory.

    :param save_freq: Number of timesteps between checkpoints (must be > 0).
    :param verbose: Verbosity level forwarded to ``BaseCallback``.
    :raises ValueError: If ``save_freq`` is not strictly positive.
    """

    # Number of most recent episodes included in the running mean.
    avg_window = 100

    def __init__(self, save_freq=100_000, verbose=1):
        super().__init__(verbose)
        if save_freq <= 0:
            raise ValueError(f"save_freq must be a positive number of timesteps, got {save_freq!r}")
        self.episode_rewards = []
        self.avg_rewards = []
        self.save_freq = save_freq
        # Timestep threshold at which the next checkpoint is written.
        self._next_save = save_freq

    def _on_training_start(self):
        # Align the first checkpoint with the next multiple of save_freq that
        # lies strictly after the current timestep counter.
        self._next_save = (self.num_timesteps // self.save_freq + 1) * self.save_freq

    def _on_step(self):
        """Collect finished-episode rewards and checkpoint when due.

        :return: Always ``True`` so that training continues.
        """
        # Inspect every info dict, not just the first one, so that episodes
        # finishing in any sub-environment of a vectorized env are recorded.
        for info in self.locals.get("infos", ()):
            episode = info.get("episode")
            if episode is None:
                continue
            self.episode_rewards.append(episode["r"])
            # Running mean over the last `avg_window` finished episodes.
            self.avg_rewards.append(np.mean(self.episode_rewards[-self.avg_window:]))

        # A threshold comparison (instead of `% save_freq == 0`) still fires
        # when num_timesteps advances by more than one per call and therefore
        # never lands exactly on a multiple of save_freq.
        if self.num_timesteps >= self._next_save:
            self.model.save(f"dqn_breakout_step_{self.num_timesteps}")
            self._next_save = (self.num_timesteps // self.save_freq + 1) * self.save_freq
        return True

    def _on_training_end(self):
        """Persist both reward histories as ``.npy`` files."""
        np.save("episode_rewards.npy", np.array(self.episode_rewards))
        np.save("average_rewards.npy", np.array(self.avg_rewards))




In [None]:
from gymnasium.wrappers import FrameStackObservation,AtariPreprocessing
from typing import Callable

def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """
    Build a schedule that scales a base learning rate linearly with progress.

    :param initial_value: Learning rate returned when the remaining
      progress equals 1.
    :return: a callable mapping the remaining progress to the
      learning rate to use at that point
    """
    def scaled_rate(progress_remaining: float) -> float:
        """
        Scale the base rate by the fraction of progress that is left.

        :param progress_remaining: fraction of progress still to go
          (1 at the beginning, 0 at the end)
        :return: ``progress_remaining * initial_value``
        """
        current_rate = progress_remaining * initial_value
        return current_rate

    return scaled_rate

# Training environment: raw Breakout wrapped with Atari preprocessing and a
# 4-frame observation stack. frameskip=1 is set on the base env, presumably
# because AtariPreprocessing applies its own frame skipping -- confirm against
# the gymnasium docs.
env = gym.make("ALE/Breakout-v5", render_mode="rgb_array", frameskip=1)
env = AtariPreprocessing(env, terminal_on_life_loss=True)
env = FrameStackObservation(env, stack_size=4)

model = DQN(
    "CnnPolicy",
    env,
    verbose=1,
    buffer_size=100_000,
    learning_starts=10_000,
    batch_size=32,
    gamma=0.99,
    train_freq=4,
    # Learning rate decays linearly from 1e-4 to 0 over the course of training.
    learning_rate=linear_schedule(0.0001),
    target_update_interval=10_000,
    exploration_fraction=0.1,
    exploration_final_eps=0.01,
    tensorboard_log="./tensorboard_log/",
    # "auto" uses the GPU when one is available and falls back to the CPU
    # otherwise, so the notebook no longer crashes on machines without CUDA.
    device="auto",
)


# The callback records episode rewards and checkpoints the model every 100k steps.
progress = TrackProgressCallback(save_freq=100_000)
model.learn(total_timesteps=3_000_000,callback=progress)


# Persist the final weights and the replay buffer.
model.save("dqn_breakout_test")
model.save_replay_buffer("dqn_replay_buffer")
print("Test training completed.")



In [None]:
import matplotlib.pyplot as plt
# Load the reward histories that TrackProgressCallback wrote when training ended.
ep_rewards = np.load("episode_rewards.npy")
avg_rewards = np.load("average_rewards.npy")

# Learning curve: running-mean reward per finished episode.
fig, ax = plt.subplots()
ax.plot(avg_rewards)
ax.set_xlabel("Episode")
ax.set_ylabel("100-episode Average Reward")
ax.set_title("Breakout DQN Learning Curve")
plt.show()

# Roll out a single greedy episode in a fresh environment built with the same
# wrappers as the training environment.
env = gym.make("ALE/Breakout-v5", render_mode="rgb_array",frameskip=1)
env = AtariPreprocessing(env, terminal_on_life_loss=True)
env = FrameStackObservation(env, stack_size=4)
try:
    obs, _ = env.reset()
    done, truncated = False, False
    total_reward = 0

    while not (done or truncated):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = env.step(action)
        total_reward += reward
finally:
    # Always release the environment, even if the rollout raises.
    env.close()


print("Total reward:", total_reward)


In [None]:
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv

def make_env():
    """Create a Breakout environment with Atari preprocessing and a 4-frame stack.

    :return: the wrapped environment (frame stack outermost)
    """
    base_env = gym.make("ALE/Breakout-v5", render_mode="rgb_array", frameskip=1)
    preprocessed_env = AtariPreprocessing(base_env, terminal_on_life_loss=True)
    return FrameStackObservation(preprocessed_env, stack_size=4)

# Number of environment steps to record. Shared by the recorder and the rollout
# loop so the two cannot drift apart.
VIDEO_LENGTH = 1000

env = DummyVecEnv([make_env])
env = VecVideoRecorder(
    env,
    "./videos/",
    # Start a single recording at step 0 instead of re-triggering on every step.
    record_video_trigger=lambda step: step == 0,
    video_length=VIDEO_LENGTH,
    name_prefix="dqn-breakout",
)

try:
    obs = env.reset()
    for _ in range(VIDEO_LENGTH):
        action, _ = model.predict(obs, deterministic=True)
        obs, rewards, dones, infos = env.step(action)
finally:
    # Close the recorder even if the rollout raises.
    env.close()


In [None]:
from IPython.display import Video
# Use the same relative folder the recorder writes to ("./videos/") rather than an
# absolute, Colab-only path, so the cell also works outside /content.
Video('./videos/dqn-breakout-step-0-to-step-1000.mp4', embed=True)
