In [1]:
!pip install gymnasium
!pip install ale-py
!pip install torch
!pip install stable-baselines3



In [2]:
!pip install wandb



In [1]:
import os
import random
import time
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv
)
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
import ale_py

In [2]:
# Hand the ale_py module to Gymnasium so the Atari ids used below
# (e.g. "ALE/MsPacman-v5") can be created with gym.make.
# NOTE(review): registration may already happen on `import ale_py`, with this
# call mainly keeping the import explicitly "used" - confirm in Gymnasium docs.
gym.register_envs(ale_py)

In [3]:
from gymnasium.wrappers import GrayscaleObservation

In [4]:
# ---- Hyperparameters (edit here, then re-run the cells below) ---- #

# Run setup
env_id = "ALE/MsPacman-v5"          # Gymnasium id of the game to train on
seed = 1                            # base seed for python / numpy / torch / env
total_timesteps = 10_000_000        # number of agent steps in the training loop
num_envs = 1                        # number of environments in the vector env

# Optimisation
learning_rate = 1e-4                # step size passed to the optimizer
batch_size = 32                     # transitions sampled per gradient update
gamma = 0.99                        # discount factor applied to bootstrapped values
train_frequency = 4                 # run one gradient update every N steps
learning_starts = 80_000            # steps collected before any update is made

# Replay buffer and target network
buffer_size = 1_000_000             # capacity of the replay buffer (transitions)
target_network_frequency = 1_000    # refresh the target network every N steps
tau = 1.0                           # mixing weight of online params in that refresh

# Epsilon-greedy exploration
start_e = 1                         # epsilon at step 0
end_e = 0.01                        # floor epsilon decays to
exploration_fraction = 0.10         # share of total_timesteps spent decaying

In [5]:
def make_env(env_id, seed, idx, capture_video=False, run_name=""):
    """Return a zero-argument factory that builds one preprocessed Atari env.

    The factory ("thunk") is what `gym.vector.SyncVectorEnv` expects: it is
    called later, once per sub-environment.

    Args:
        env_id: Gymnasium environment id passed to `gym.make`.
        seed: Seed applied to the environment's action space.
        idx: Index of this sub-environment; only index 0 records video.
        capture_video: When true (and `idx == 0`), wrap the env in
            `RecordVideo`, writing to `videos/{run_name}`.
        run_name: Sub-directory name used for recorded videos.

    Returns:
        A callable with no arguments that creates and returns the wrapped env.
    """
    # Ids in the "ALE/" namespace (the v5 family) default to frameskip=4 inside
    # the emulator. Combined with MaxAndSkipEnv(skip=4) below, every agent
    # action would be repeated for 16 frames. Ask the emulator for frameskip=1
    # so the wrapper is the only place frames are skipped. Other ids are
    # created exactly as before.
    make_kwargs = {"frameskip": 1} if env_id.startswith("ALE/") else {}

    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array", **make_kwargs)
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id, **make_kwargs)

        # Statistics are recorded innermost so they describe full games,
        # not the per-life episodes produced by EpisodicLifeEnv further out.
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        env = EpisodicLifeEnv(env)

        # Some games only start after FIRE is pressed.
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)

        env = ClipRewardEnv(env)
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayscaleObservation(env)
        env = gym.wrappers.FrameStackObservation(env, 4)
        env.action_space.seed(seed)

        return env

    return thunk

In [6]:
class QNetwork(nn.Module):
    """Convolutional network mapping stacked frames to one Q-value per action.

    The first convolution expects 4 input channels, and the first linear layer
    expects 3136 features, which is what the three convolutions produce for an
    84x84 input (64 channels x 7 x 7).
    """

    def __init__(self, env):
        """Build the layers; the output width is `env.single_action_space.n`."""
        super().__init__()
        num_actions = env.single_action_space.n

        # Layers are created in forward order so parameter initialisation
        # and the `network.<index>` state-dict keys stay stable.
        conv_stack = [
            nn.Conv2d(4, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        head = [
            nn.Linear(3136, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions),
        ]
        self.network = nn.Sequential(*conv_stack, *head)

    def forward(self, x):
        """Return Q-values for a batch `x`, after dividing the input by 255."""
        # presumably x holds raw 0-255 pixel intensities - verify against caller
        scaled = x / 255.0
        return self.network(scaled)

In [7]:
  # ---- Exploration Schedule ---- #
  def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
      """Linearly anneal epsilon from `start_e` towards `end_e`.

      Args:
          start_e: Value returned at `t == 0`.
          end_e: Lower bound; the result never drops below this value.
          duration: Number of steps over which the value moves from
              `start_e` to `end_e`.
          t: Current step.

      Returns:
          `start_e + t * (end_e - start_e) / duration`, clipped from below at
          `end_e`. When `duration` is zero or negative there is nothing to
          anneal over, so `end_e` is returned directly.
      """
      # Guard against ZeroDivisionError, e.g. when exploration_fraction is 0.
      if duration <= 0:
          return end_e
      slope = (end_e - start_e) / duration
      return max(slope * t + start_e, end_e)

In [None]:
# ---- WandB Setup ---- #
import wandb
from tqdm import tqdm

run_name="CSIR-GPU_run"

# Single source of truth for the hyperparameters that get logged. Using an
# explicit dict (instead of locals()) keeps the rest of the notebook namespace
# (modules, networks, cell history, ...) out of the logs.
hparams = {
    "env_id": env_id,
    "seed": seed,
    "num_envs": num_envs,
    "learning_rate": learning_rate,
    "buffer_size": buffer_size,
    "batch_size": batch_size,
    "gamma": gamma,
    "tau": tau,
    "target_network_frequency": target_network_frequency,
    "train_frequency": train_frequency,
    "learning_starts": learning_starts,
    "exploration_fraction": exploration_fraction,
    "start_e": start_e,
    "end_e": end_e,
    "total_timesteps": total_timesteps,
}

wandb.init(
    project="ATARI-RL",
    name=run_name,
    config=hparams,
)

# ---- Training and Evaluation ---- #
# Setting up TensorBoard for logging
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
    "hyperparameters",
    "|param|value|\n|-|-|\n" + "\n".join(f"|{key}|{value}|" for key, value in hparams.items()),
)

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True

# Setup device: use the GPU when there is one, otherwise fall back to the CPU
# instead of failing on the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The replay buffer below is created for a single environment, and the
# autoreset handling in the loop skips whole steps, so more than one
# environment is not supported here.
assert num_envs == 1, "this training loop only supports num_envs == 1"

# Create Environment
envs = gym.vector.SyncVectorEnv(
    [make_env(env_id, seed + i, i, capture_video=False, run_name=run_name) for i in range(num_envs)]
)

q_network = QNetwork(envs).to(device)
optimizer = optim.Adam(q_network.parameters(), lr=learning_rate)
target_network = QNetwork(envs).to(device)
target_network.load_state_dict(q_network.state_dict())

# Setup Replay Buffer
rb = ReplayBuffer(
    buffer_size,
    envs.single_observation_space,
    envs.single_action_space,
    device,
    optimize_memory_usage=True,
    handle_timeout_termination=False
)

# ---- Training Loop ---- #
start_time = time.time()
obs, _ = envs.reset(seed=seed)

# Most recent finished-episode return, shown in the progress bar. Initialised
# here so a value left over from an earlier run of this cell is never shown.
episodic_return = 0
# With Gymnasium >= 1.0 the vector env resets a finished sub-environment on the
# *next* call to step(): that call ignores the action and returns the first
# observation of the new episode. This mask marks sub-envs in that state.
autoreset_pending = np.zeros(envs.num_envs, dtype=bool)

try:
    # Initialize tqdm progress bar
    with tqdm(total=total_timesteps, desc="Training Progress", unit="steps") as pbar:
        for global_step in range(total_timesteps):
            epsilon = linear_schedule(start_e, end_e, exploration_fraction * total_timesteps, global_step)

            # Epsilon-greedy action selection.
            if random.random() < epsilon:
                actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
            else:
                # No gradients are needed to pick an action.
                with torch.no_grad():
                    q_values = q_network(torch.Tensor(obs).to(device))
                actions = torch.argmax(q_values, dim=1).cpu().numpy()

            next_obs, rewards, terminated, truncated, infos = envs.step(actions)

            # RecordEpisodeStatistics adds an "episode" entry when a full game
            # ends; the vector env batches it into arrays and adds an
            # "_episode" mask telling which sub-envs actually reported it.
            # (The pre-1.0 "final_info" key is no longer emitted.)
            if "episode" in infos:
                finished = infos.get("_episode", np.ones(envs.num_envs, dtype=bool))
                for env_idx in np.flatnonzero(finished):
                    episodic_return = float(infos["episode"]["r"][env_idx])
                    episode_length = int(infos["episode"]["l"][env_idx])

                    writer.add_scalar("charts/episodic_return", episodic_return, global_step)
                    writer.add_scalar("charts/episode_length", episode_length, global_step)
                    writer.add_scalar("charts/epsilon", epsilon, global_step)

                    # Log to WandB
                    wandb.log({
                        "episodic_return": episodic_return,
                        "episode_length": episode_length,
                        "epsilon": epsilon,
                        "global_step": global_step,
                    })

            # On the step that ends an episode, `next_obs` already is the real
            # final observation, so it can be stored as-is. The following
            # (autoreset) step is not a real transition - its action was
            # ignored and it jumps from the old episode to the new one - so it
            # is kept out of the replay buffer.
            if not autoreset_pending.any():
                rb.add(obs, next_obs, actions, rewards, terminated, infos)
            autoreset_pending = np.logical_or(terminated, truncated)

            obs = next_obs

            # Guard against a zero elapsed time on coarse clocks.
            steps_per_second = int(global_step / max(time.time() - start_time, 1e-9))

            if global_step > learning_starts:
                if global_step % train_frequency == 0:
                    data = rb.sample(batch_size)
                    with torch.no_grad():
                        target_max, _ = target_network(data.next_observations).max(dim=1)
                        td_target = data.rewards.flatten() + gamma * target_max * (1 - data.dones.flatten())
                    # squeeze(1) drops only the gathered-action axis, so the
                    # batch axis survives even for a batch of one.
                    old_val = q_network(data.observations).gather(1, data.actions).squeeze(1)
                    loss = F.mse_loss(td_target, old_val)

                    if global_step % 100 == 0:
                        writer.add_scalar("losses/td_loss", loss.item(), global_step)
                        writer.add_scalar("charts/SPS", steps_per_second, global_step)

                        # Log to WandB
                        wandb.log({
                            "td_loss": loss.item(),
                            "SPS": steps_per_second,
                            "global_step": global_step,
                        })

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # Blend the online weights into the target network
                # (tau == 1.0 is a plain copy).
                if global_step % target_network_frequency == 0:
                    for target_param, param in zip(target_network.parameters(), q_network.parameters()):
                        target_param.data.copy_(
                            tau * param.data + (1.0 - tau) * target_param.data
                        )

            # Update tqdm progress bar with additional metrics. refresh=False
            # lets tqdm redraw on its own schedule instead of on every step.
            pbar.set_postfix({
                "episodic_return": episodic_return,
                "epsilon": epsilon,
                "SPS": steps_per_second,
            }, refresh=False)
            pbar.update(1)

    # ---- Save Model ---- #
    model_path = f"runs/{run_name}/dqn_model.pth"
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(q_network.state_dict(), model_path)
    wandb.save(model_path)
    print(f"Model saved to {model_path} and uploaded to WandB")
finally:
    # Always release the emulator, flush TensorBoard and close the WandB run,
    # including when training is interrupted or fails part-way.
    envs.close()
    writer.close()
    wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcasarulez[0m ([33mHarish-Personal[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


A.L.E: Arcade Learning Environment (version 0.10.1+unknown)
[Powered by Stella]
Training Progress:   0%| | 30885/10000000 [01:06<5:59:27, 462.22steps/s, episodi