<a href="https://colab.research.google.com/github/timmyt110/Deep-Q-Learning-on-Atari-Final-Project/blob/main/Mario_DDQN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so the models, results and videos written by later cells
# persist across Colab sessions.
from google.colab import drive

DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT)




Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os

# Root folder on Google Drive for this project's artifacts
# (change the folder name to start from a clean slate).
base = "/content/drive/MyDrive/CSCI166_RL_New3"

# One sub-folder per artifact type; the keys are used throughout the notebook.
_subfolders = (
    ("models", "Models"),
    ("results", "Results"),
    ("videos", "Videos"),
    ("runs", "Runs"),
)
paths = {key: f"{base}/{folder_name}" for key, folder_name in _subfolders}

# exist_ok=True makes re-running this cell harmless.
for folder in paths.values():
    os.makedirs(folder, exist_ok=True)

paths


{'models': '/content/drive/MyDrive/CSCI166_RL_New3/Models',
 'results': '/content/drive/MyDrive/CSCI166_RL_New3/Results',
 'videos': '/content/drive/MyDrive/CSCI166_RL_New3/Videos',
 'runs': '/content/drive/MyDrive/CSCI166_RL_New3/Runs'}

In [None]:
# Install Gymnasium with its Atari extras, then download the Atari 2600 ROMs
# via AutoROM (accepting the ROM license non-interactively).
# NOTE(review): versions are unpinned, so a future release could change
# behavior; `%pip` is the kernel-safe form of `!pip` in notebooks.
!pip install -q "gymnasium[atari,accept-rom-license]"
!pip install -q autorom
!AutoROM --accept-license


[0mAutoROM will download the Atari 2600 ROMs.
They will be installed to:
	/usr/local/lib/python3.12/dist-packages/AutoROM/roms

Existing ROMs will be overwritten.


In [None]:
# === Step 6: DQN model ===
import torch
import torch.nn as nn

class DQN(nn.Module):
    """Convolutional Q-network: three conv layers followed by two linear layers.

    Input is a batch of stacked frames shaped (B, C, H, W) with raw pixel
    values in [0, 255] (uint8). Scaling to [0, 1] happens inside forward(),
    so callers pass raw frames. Output has shape (B, n_actions).
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()

        # (out_channels, kernel_size, stride) per conv layer; each gets a ReLU.
        conv_spec = ((32, 8, 4), (64, 4, 2), (64, 3, 1))
        layers = []
        in_channels = input_shape[0]  # input_shape is (C, H, W)
        for out_channels, kernel, stride in conv_spec:
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=stride))
            layers.append(nn.ReLU())
            in_channels = out_channels
        layers.append(nn.Flatten())
        # Same Sequential layout (indices 0..6) as before, so saved
        # state_dict keys such as "conv.0.weight" keep matching.
        self.conv = nn.Sequential(*layers)

        # Run a dummy input through the conv stack to learn the flattened size.
        with torch.no_grad():
            probe = self.conv(torch.zeros(1, *input_shape))
        flat_features = probe.shape[-1]

        self.fc = nn.Sequential(
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def forward(self, x):
        scaled = x.float() / 255.0
        features = self.conv(scaled)
        return self.fc(features)


In [None]:
# === Step 7: Replay buffer ===
import collections
import random
import numpy as np

# One stored transition. sample() unpacks fields positionally, so keep this order.
Experience = collections.namedtuple(
    "Experience", "state action reward done next_state"
)

class ReplayBuffer:
    """Bounded FIFO of transitions with uniform random minibatch sampling.

    When `capacity` transitions are held, each push evicts the oldest one.
    """

    def __init__(self, capacity: int):
        self.buf = collections.deque(maxlen=capacity)

    def __len__(self):
        return len(self.buf)

    def push(self, e: Experience):
        """Append one transition (normally an Experience(...))."""
        self.buf.append(e)

    def sample(self, batch_size: int):
        """Draw `batch_size` distinct transitions uniformly at random.

        Returns (states, actions, rewards, dones, next_states) as numpy arrays
        with the batch on axis 0: actions int64, rewards float32, dones uint8;
        states/next_states keep the dtype they were stored with.
        """
        chosen = random.sample(self.buf, batch_size)
        states, actions, rewards, dones, next_states = zip(*chosen)

        batch_states = np.stack(states)
        batch_actions = np.asarray(actions, dtype=np.int64)
        batch_rewards = np.asarray(rewards, dtype=np.float32)
        batch_dones = np.asarray(dones, dtype=np.uint8)
        batch_next = np.stack(next_states)
        return batch_states, batch_actions, batch_rewards, batch_dones, batch_next


In [None]:
# === Step 8: Helpers ===
import numpy as np
import random
import torch

def epsilon_by_frame(step: int, eps_start: float, eps_end: float, eps_decay_steps: int) -> float:
    """Linearly anneal epsilon from `eps_start` (step 0) to `eps_end` (step >= eps_decay_steps).

    Args:
        step: current environment step.
        eps_start: epsilon at step 0.
        eps_end: floor reached once `step >= eps_decay_steps`.
        eps_decay_steps: length of the linear schedule. A value <= 0 means
            "no annealing": `eps_end` is returned immediately. Previously 0
            raised ZeroDivisionError and negatives pushed epsilon above eps_start.

    Returns:
        The exploration rate for this step.
    """
    if eps_decay_steps <= 0:
        return eps_end
    frac = max(0.0, 1.0 - step / float(eps_decay_steps))
    return eps_end + (eps_start - eps_end) * frac

def obs_to_chw(obs) -> np.ndarray:
    """Copy an (H, W, C) observation into a new array and view it as (C, H, W)."""
    hwc = np.array(obs)  # np.array copies, so the result never aliases `obs`
    return hwc.transpose(2, 0, 1)

def select_action(online_net, state_chw: np.ndarray, epsilon: float, action_space, device: str):
    """Choose an action epsilon-greedily.

    With probability `epsilon`, returns `action_space.sample()` as-is.
    Otherwise adds a batch axis to the single (C, H, W) observation, runs
    `online_net` on `device` without gradients, and returns the argmax action
    as a Python int.
    """
    if random.random() < epsilon:
        return action_space.sample()

    batch = torch.from_numpy(np.expand_dims(state_chw, 0)).to(device)  # (1, C, H, W)
    with torch.no_grad():
        q_values = online_net(batch)
    best = q_values.argmax(dim=1)
    return int(best.item())


In [None]:
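# === Step 9: Environment Setup (AtariPreprocessing + FrameStack) ===
# NOTE: the next cell redefines SimpleFrameStack and make_env(); those later
# definitions are the ones in effect when the notebook runs top to bottom.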

import gymnasium as gym
import numpy as np
from collections import deque

# Use this import path for AtariPreprocessing (works across versions)
try:
    from gymnasium.wrappers.atari_preprocessing import AtariPreprocessing
except ImportError:
    from gymnasium.wrappers.atari import AtariPreprocessing

from gymnasium.wrappers import RecordEpisodeStatistics, RecordVideo

class SimpleFrameStack(gym.Wrapper):
    """Stack the last `k` frames of the wrapped env along a new last axis -> (H, W, k).

    The declared observation space is a uint8 [0, 255] Box, which matches
    AtariPreprocessing with scale_obs=False. (H, W) is read from the wrapped
    env's observation space. On reset the stack is filled with k copies of
    the first frame.
    """

    def __init__(self, env, k=4):
        if k < 1:
            raise ValueError(f"SimpleFrameStack needs k >= 1, got {k}")
        super().__init__(env)
        self.k = k
        self.frames = deque(maxlen=k)
        # Take (H, W) from the wrapped env instead of hard-coding 84x84, so the
        # declared space matches real observations if screen_size changes.
        h, w = env.observation_space.shape[:2]
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(h, w, k), dtype=np.uint8
        )

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.frames.clear()
        for _ in range(self.k):
            self.frames.append(obs)  # seed the stack with the first frame
        return np.stack(self.frames, axis=-1), info

    def step(self, action):
        obs, reward, term, trunc, info = self.env.step(action)
        self.frames.append(obs)
        return np.stack(self.frames, axis=-1), reward, term, trunc, info


def make_env(seed: int = 0, env_name: str = "ALE/MarioBros-v5"):
    """
    Build a MarioBros training env that also records a video every 50 episodes.

    Pipeline: gym.make (sticky actions, full action set, rgb_array rendering)
    -> AtariPreprocessing (frame-skip 4, 84x84 grayscale, unscaled uint8)
    -> SimpleFrameStack(k=4) -> RecordEpisodeStatistics -> RecordVideo.
    Only the action space is seeded here; the env itself is not reset.

    NOTE: the following cell redefines make_env() without RecordVideo; that
    later definition is the one used when the notebook runs top to bottom.

    Args:
        seed: seed for the action space's sampler.
        env_name: Gymnasium env id (new, optional; defaults to MarioBros v5).
    """
    env = gym.make(
        env_name,
        render_mode="rgb_array",  # RecordVideo can only capture rgb_array renders
        frameskip=1,              # frame skipping is done by AtariPreprocessing
        repeat_action_probability=0.25,
        full_action_space=True,
    )

    env = AtariPreprocessing(
        env,
        frame_skip=4,
        screen_size=84,
        grayscale_obs=True,
        scale_obs=False,
        terminal_on_life_loss=False,
    )

    env = SimpleFrameStack(env, k=4)
    env = RecordEpisodeStatistics(env)

    # Record training videos every 50 episodes.
    env = RecordVideo(
        env,
        video_folder=paths["videos"],
        episode_trigger=lambda ep: ep % 50 == 0,
        name_prefix="train_mario",
    )

    env.action_space.seed(seed)
    return env



In [None]:
import gymnasium as gym
import numpy as np
from collections import deque

# Use this import path for AtariPreprocessing (works across versions)
try:
    from gymnasium.wrappers.atari_preprocessing import AtariPreprocessing
except ImportError:
    # older layout fallback
    from gymnasium.wrappers.atari import AtariPreprocessing

from gymnasium.wrappers import RecordEpisodeStatistics, RecordVideo

# ---- Minimal, version-proof FrameStack replacement ----
class SimpleFrameStack(gym.Wrapper):
    """
    Stacks the last `k` grayscale frames along the last axis (H, W, k).

    Declares a uint8 [0, 255] Box observation space, which matches
    AtariPreprocessing with scale_obs=False. (H, W) is read from the wrapped
    env, e.g. (84, 84) after AtariPreprocessing(grayscale_obs=True,
    screen_size=84). On reset the stack is filled with k copies of the
    first frame.
    """
    def __init__(self, env, k=4):
        if k < 1:
            raise ValueError(f"SimpleFrameStack needs k >= 1, got {k}")
        super().__init__(env)
        self.k = k
        self.frames = deque(maxlen=k)
        # Derive the frame size from the wrapped env instead of hard-coding
        # 84x84, so the declared space stays correct if screen_size changes.
        h, w = env.observation_space.shape[:2]
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(h, w, k), dtype=np.uint8
        )

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.frames.clear()
        for _ in range(self.k):
            self.frames.append(obs)  # seed stack with first frame
        return self._get_obs(), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.frames.append(obs)
        return self._get_obs(), reward, terminated, truncated, info

    def _get_obs(self):
        # Stack along the last dimension -> (H, W, k)
        return np.stack(list(self.frames), axis=-1)

def make_env(env_name="ALE/MarioBros-v5", seed=None):
    """
    Create the preprocessed Atari env used for training.

    Pipeline: gym.make (v5 sticky actions, full action set, rgb_array render)
    -> AtariPreprocessing (frame-skip 4, 84x84 grayscale uint8, no termination
    on life loss) -> SimpleFrameStack(k=4) -> RecordEpisodeStatistics.

    If `seed` is given, the env is reset once with it and the action space is
    seeded too; otherwise nothing is seeded.
    """
    base_env = gym.make(
        env_name,
        full_action_space=True,
        repeat_action_probability=0.25,  # v5 sticky actions
        frameskip=1,                     # skipping happens in AtariPreprocessing
        render_mode="rgb_array",
    )

    preprocessed = AtariPreprocessing(
        base_env,
        frame_skip=4,
        screen_size=84,
        grayscale_obs=True,
        scale_obs=False,            # keep uint8 [0, 255]
        terminal_on_life_loss=False,
    )

    # Stack frames, then record per-episode return/length in `info`.
    env = RecordEpisodeStatistics(SimpleFrameStack(preprocessed, k=4))

    if seed is None:
        return env
    env.reset(seed=seed)
    env.action_space.seed(seed)
    return env

In [None]:
# === Step 10: Training setup & loop (Mario DDQN) ===
import os, time
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F # <-- Added for mse_loss
from dataclasses import dataclass

# TensorBoard is optional. If importing it fails for any reason, fall back to a
# do-nothing writer exposing the three methods the training loops call.
try:
    from torch.utils.tensorboard import SummaryWriter
except Exception as e:
    print(" TensorBoard not available, using dummy SummaryWriter:", e)

    class SummaryWriter:
        """No-op replacement for torch.utils.tensorboard.SummaryWriter."""

        def __init__(self, *_args, **_kwargs):
            return None

        def add_scalar(self, *_args, **_kwargs):
            return None

        def close(self):
            return None

# ---- Config ----
@dataclass
class TrainConfig:
    """Hyperparameters shared by the DQN / Double DQN training loops.

    Field order is the positional constructor order, so keep it stable.
    """

    # Hardware: prefer the GPU when PyTorch can see one.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Optimisation.
    lr: float = 1e-4
    gamma: float = 0.99          # discount factor
    batch_size: int = 32
    grad_clip: float = 10.0      # max global gradient norm

    # Replay memory.
    replay_size: int = 1_000_000
    warmup_steps: int = 50_000   # transitions collected before learning starts

    # Schedules, all measured in environment steps.
    update_every: int = 4
    target_sync: int = 10_000
    eps_start: float = 1.0
    eps_end: float = 0.05
    eps_decay_steps: int = 1_000_000
    max_frames: int = 1_000_000

CFG = TrainConfig()

# Artifact folders from Step 1 (`paths`); makedirs is a no-op if they exist.
MODELS_DIR  = paths["models"]    # .pth checkpoints
RESULTS_DIR = paths["results"]   # per-episode CSVs and curve PNGs
RUNS_DIR    = paths["runs"]      # TensorBoard logs
VIDEOS_DIR  = paths["videos"]    # recorded .mp4 clips

for folder in (MODELS_DIR, RESULTS_DIR, RUNS_DIR, VIDEOS_DIR):
    os.makedirs(folder, exist_ok=True)

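# NOTE: the training loops below reuse make_env() and SimpleFrameStack from the
# environment cells above. The make_env() in effect (the later definition) applies
# AtariPreprocessing + SimpleFrameStack + RecordEpisodeStatistics but no RecordVideo;
# videos are recorded separately by record_clip().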

# ---- Double DQN learn step ----
def ddqn_learn_step(online_net, target_net, optimizer, buffer, batch_size, gamma, grad_clip, device):
    """
    One Double DQN gradient step on a minibatch sampled from `buffer`.

    Target: r + gamma * Q_target(s', argmax_a' Q_online(s', a')) * (1 - done).
    The online net selects the next action and the target net evaluates it.

    Args:
        online_net / target_net: Q-networks that take raw uint8 frames
            (they scale to [0, 1] themselves in forward()).
        optimizer: optimizer over online_net's parameters.
        buffer: object whose sample(batch_size) returns
            (states, actions, rewards, dones, next_states) numpy arrays.
        batch_size, gamma, grad_clip: minibatch size, discount, max grad norm.
        device: torch device for the batch.

    Returns:
        The MSE TD loss as a Python float.
    """
    s, a, r, d, sp = buffer.sample(batch_size)

    # Keep frames as raw uint8: DQN.forward already does float()/255. Dividing
    # here too fed the network inputs 255x smaller during training than the raw
    # frames select_action() gives it when acting.
    s_t = torch.from_numpy(s).to(device)
    a_t = torch.from_numpy(a).to(device)
    r_t = torch.from_numpy(r).to(device)
    d_t = torch.from_numpy(d).to(device).float()
    sp_t = torch.from_numpy(sp).to(device)

    # Q(s, a) for the actions actually taken.
    q_values = online_net(s_t)
    q_s_a = q_values.gather(1, a_t.unsqueeze(-1)).squeeze(-1)

    with torch.no_grad():
        # Action selection with the online network ...
        next_actions = online_net(sp_t).argmax(dim=1)
        # ... value estimation with the target network.
        next_q_values_target = target_net(sp_t)
        q_s_prime_a_prime = next_q_values_target.gather(1, next_actions.unsqueeze(-1)).squeeze(-1)

        # (1 - d_t) zeroes the bootstrap term for terminal transitions.
        target_q_s_a = r_t + gamma * q_s_prime_a_prime * (1 - d_t)

    loss = F.mse_loss(q_s_a, target_q_s_a)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(online_net.parameters(), grad_clip)
    optimizer.step()

    return float(loss.item())

# ---- Trainer (expects DQN, ReplayBuffer, Experience, helpers from Steps 6–9) ----
def train_ddqn(run_label: str,
               env_name: str = "ALE/MarioBros-v5",
               cfg: TrainConfig = CFG,
               ckpt_every: int = 200_000):
    """
    Train a Double DQN agent with epsilon-greedy exploration and experience replay.

    Side effects:
      * RESULTS_DIR/<run_label>.csv is deleted, then one "step,episode,return"
        row is appended per finished episode (read by plot_returns()).
      * TensorBoard scalars are written under RUNS_DIR/<run_label>.
      * After warmup, online weights are saved to
        MODELS_DIR/<run_label>_<step>.pth every `ckpt_every` env steps.

    Args:
        run_label: name for the CSV, TensorBoard run and checkpoint files.
        env_name: Gymnasium env id forwarded to make_env().
        cfg: hyperparameters (TrainConfig).
        ckpt_every: checkpoint interval in env steps; <= 0 disables periodic
            checkpoints. The default keeps the original 200k interval.

    Returns:
        The trained online network.
    """
    env = make_env(env_name=env_name, seed=0)
    tb = None
    try:
        obs, _ = env.reset()
        action_space = env.action_space

        # Env returns (H, W, C) frame stacks; the network expects (C, H, W).
        H, W, C = np.asarray(obs).shape
        n_actions = action_space.n

        online = DQN((C, H, W), n_actions).to(cfg.device)
        target = DQN((C, H, W), n_actions).to(cfg.device)
        target.load_state_dict(online.state_dict())

        optimizer = optim.Adam(online.parameters(), lr=cfg.lr)
        buffer = ReplayBuffer(cfg.replay_size)

        tb = SummaryWriter(log_dir=os.path.join(RUNS_DIR, run_label))
        csv_path = os.path.join(RESULTS_DIR, f"{run_label}.csv")
        if os.path.exists(csv_path):
            os.remove(csv_path)  # start a fresh learning curve for this label

        state_chw = obs_to_chw(obs)
        ep_return, ep_idx, step = 0.0, 0, 0

        while step < cfg.max_frames:
            step += 1
            eps = epsilon_by_frame(step, cfg.eps_start, cfg.eps_end, cfg.eps_decay_steps)

            action = select_action(online, state_chw, eps, action_space, cfg.device)
            next_obs, reward, term, trunc, info = env.step(action)
            done = term or trunc
            next_chw = obs_to_chw(next_obs)

            # Store `term`, not `done`: only a true terminal state has zero future
            # value. A time-limit truncation must still bootstrap from next_chw.
            buffer.push(Experience(state_chw, action, reward, term, next_chw))
            ep_return += reward
            state_chw = next_chw

            if done:
                ep_idx += 1
                tb.add_scalar("charts/episode_return", ep_return, step)
                tb.add_scalar("charts/epsilon", eps, step)
                if "episode" in info:  # populated by RecordEpisodeStatistics
                    tb.add_scalar("charts/episode_length", info["episode"]["l"], step)

                with open(csv_path, "a") as f:
                    f.write(f"{step},{ep_idx},{ep_return}\n")

                obs, _ = env.reset()
                state_chw = obs_to_chw(obs)
                ep_return = 0.0

            # Fill the replay buffer before any gradient step.
            if len(buffer) < cfg.warmup_steps:
                continue

            if step % cfg.update_every == 0:
                loss = ddqn_learn_step(
                    online, target, optimizer, buffer,
                    cfg.batch_size, cfg.gamma, cfg.grad_clip, cfg.device
                )
                tb.add_scalar("loss/td", loss, step)

            # Hard-copy the online weights into the target network.
            if step % cfg.target_sync == 0:
                target.load_state_dict(online.state_dict())

            if ckpt_every > 0 and step % ckpt_every == 0:
                ckpt = os.path.join(MODELS_DIR, f"{run_label}_{step}.pth")
                torch.save(online.state_dict(), ckpt)
                print(f" saved {ckpt}")
    finally:
        # Release the env and flush TensorBoard even if training is interrupted.
        env.close()
        if tb is not None:
            tb.close()

    print(f" Finished {run_label}")
    return online

In [None]:
# Step 11: Plot learning curve
import os
import pandas as pd
import matplotlib.pyplot as plt

def plot_returns(run_label: str, window: int = 50):
    """Save a smoothed learning curve for `run_label` as a PNG next to its CSV.

    Reads paths["results"]/<run_label>.csv (rows of step, episode, return),
    applies a rolling mean over `window` episodes (min_periods=1), and writes
    <run_label>_curve.png to the same folder. If the CSV is missing, prints a
    message and returns without plotting.
    """
    results_dir = paths["results"]
    csv_path = os.path.join(results_dir, f"{run_label}.csv")
    out_png = os.path.join(results_dir, f"{run_label}_curve.png")

    if not os.path.exists(csv_path):
        print(" CSV not found:", csv_path)
        return None

    history = pd.read_csv(csv_path, header=None, names=["steps", "episode", "return"])
    smoothed = history["return"].rolling(window, min_periods=1).mean()

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(history["steps"], smoothed, label="Smoothed Return")
    ax.set_xlabel("Env Steps")
    ax.set_ylabel("Episodic Return")
    ax.set_title(f"Learning Curve: {run_label}")
    ax.grid(True)
    ax.legend()
    fig.savefig(out_png, bbox_inches="tight")
    plt.close(fig)

    print(" Saved learning-curve plot:", out_png)


In [None]:
# === Step 12: Record short videos (10–30s) ===
import os
import time
import torch

def record_clip(env_name: str,
                run_label: str,
                net,
                ckpt_path: str = None,
                eps: float = 0.05,
                seconds: int = 20,
                device: str = CFG.device):
    """
    Roll out `net` in a fresh env and save the episode as .mp4 under
    paths["videos"]/<run_label>/.

    Args:
        env_name: Gymnasium env id, e.g. "ALE/MarioBros-v5".
        run_label: sub-folder name and video file-name prefix.
        net: Q-network used to act. If `ckpt_path` is given, its weights are
            overwritten in place by the checkpoint.
        ckpt_path: optional state_dict file to load into `net`.
        eps: epsilon for epsilon-greedy acting.
        seconds: target clip length in *video* seconds. The rollout stops after
            seconds * render_fps env steps (RecordVideo writes one frame per
            step) or when the episode ends. Previously this was a wall-clock
            budget, so clip length depended on how fast the machine ran.
        device: device the network inputs are moved to.
    """
    # Use the videos directory from Step 1
    VIDEOS_DIR = paths["videos"]
    vdir = os.path.join(VIDEOS_DIR, run_label)
    os.makedirs(vdir, exist_ok=True)

    # Build env with rgb rendering enabled
    base = gym.make(
        env_name,
        frameskip=1,
        repeat_action_probability=0.25,
        full_action_space=True,
        render_mode="rgb_array",  # required for video
    )

    env = AtariPreprocessing(
        base,
        frame_skip=4,
        screen_size=84,
        grayscale_obs=True,
        scale_obs=False,
        terminal_on_life_loss=False,
    )
    env = SimpleFrameStack(env, k=4)
    env = RecordEpisodeStatistics(env)

    # Wrap RecordVideo LAST so it records actual frames
    from gymnasium.wrappers import RecordVideo
    env = RecordVideo(
        env,
        video_folder=vdir,
        episode_trigger=lambda e: True,   # record first episode
        name_prefix=run_label,
    )

    try:
        # Load checkpoint if provided
        if ckpt_path is not None:
            net.load_state_dict(torch.load(ckpt_path, map_location=device))

        # Frame budget for the requested clip length; fall back to 30 fps if
        # the env does not advertise a render_fps.
        fps = env.metadata.get("render_fps") or 30
        max_steps = int(round(seconds * fps))

        obs, _ = env.reset(seed=123)
        for _ in range(max_steps):
            chw = obs_to_chw(obs)
            action = select_action(net, chw, eps, env.action_space, device)
            obs, _reward, term, trunc, _ = env.step(action)
            if term or trunc:
                break
    finally:
        # Closing RecordVideo finalizes the .mp4, even if the rollout fails.
        env.close()

    print(f" Saved video(s) to: {vdir}   (look for .mp4 files)")


In [None]:
# === Step 13: Train DDQN on Mario Bros + Plot + Record Videos (Colab) ===

import os
import torch
import gymnasium as gym
import ale_py

# Register Atari environments (required for Gymnasium v1.0+)
gym.register_envs(ale_py)

# -----------------------------
#  Paths (from Step 1: `paths`)
# -----------------------------
MODELS_DIR  = paths["models"]
RESULTS_DIR = paths["results"]
VIDEOS_DIR  = paths["videos"]

for d in [MODELS_DIR, RESULTS_DIR, VIDEOS_DIR]:
    os.makedirs(d, exist_ok=True)

print("Models dir :", MODELS_DIR)
print("Results dir:", RESULTS_DIR)
print("Videos dir :", VIDEOS_DIR)

# -----------------------------
#  Config overrides
# -----------------------------
CFG.device       = "cuda" if torch.cuda.is_available() else "cpu"
CFG.max_frames   = 205_000          # just over 200k so we hit the auto-save
CFG.warmup_steps = 10_000
CFG.target_sync  = 5_000
CFG.update_every = 4
# Fit the epsilon schedule to this run. The TrainConfig default anneals over
# 1M steps, which would leave epsilon near 0.8 at step 205k, so the agent would
# act almost entirely at random for the whole run. Reach eps_end halfway instead.
CFG.eps_decay_steps = CFG.max_frames // 2

print("Using device:", CFG.device)
print("Epsilon decays over:", CFG.eps_decay_steps, "steps")

mario_label = "DDQN_MarioBros"

# -----------------------------
#  Training
# -----------------------------
print(" Training on:", mario_label)

# train_ddqn(run_label, env_name="ALE/MarioBros-v5", cfg=CFG) is defined in Step 10.
mario_net = train_ddqn(run_label=mario_label, cfg=CFG)

# -----------------------------
#  Plot learning curve
# -----------------------------
print(" Plotting curve...")
plot_returns(mario_label)

# -----------------------------
#  Checkpoints
# -----------------------------
# Saved automatically by train_ddqn at 200k steps. NOTE: that is only 5k steps
# before the final model, so the "early" and "learned" clips mostly differ by eps.
early_ckpt = os.path.join(MODELS_DIR, f"{mario_label}_200000.pth")

# Save the final model; the file name tracks CFG.max_frames.
final_ckpt = os.path.join(MODELS_DIR, f"{mario_label}_{CFG.max_frames}.pth")
torch.save(mario_net.state_dict(), final_ckpt)
print(" Saved final checkpoint:", final_ckpt)

# -----------------------------
#  Record videos
# -----------------------------
# Early (more random: higher eps, 200k checkpoint)
record_clip(
    "ALE/MarioBros-v5",
    "Mario_early",
    mario_net,
    ckpt_path=early_ckpt,    # 200k model
    eps=0.20,
    seconds=20
)

# Learned (final policy at CFG.max_frames)
record_clip(
    "ALE/MarioBros-v5",
    "Mario_learned",
    mario_net,
    ckpt_path=final_ckpt,    # final model
    eps=0.05,
    seconds=20
)

print(" Step 13 completed successfully!")
print(" Curves in  :", RESULTS_DIR)
print(" Models in  :", MODELS_DIR)
print(" Videos in  :", VIDEOS_DIR)

Models dir : /content/drive/MyDrive/CSCI166_RL_New3/Models
Results dir: /content/drive/MyDrive/CSCI166_RL_New3/Results
Videos dir : /content/drive/MyDrive/CSCI166_RL_New3/Videos
Using device: cpu
 Training on: DDQN_MarioBros
 saved /content/drive/MyDrive/CSCI166_RL_New3/Models/DDQN_MarioBros_200000.pth
 Finished DDQN_MarioBros
 Plotting curveâ€¦
 Saved learning-curve plot: /content/drive/MyDrive/CSCI166_RL_New3/Results/DDQN_MarioBros_curve.png
 Saved final checkpoint: /content/drive/MyDrive/CSCI166_RL_New3/Models/DDQN_MarioBros_205000.pth


  logger.warn(


 Saved video(s) to: /content/drive/MyDrive/CSCI166_RL_New3/Videos/Mario_early   (look for .mp4 files)


  logger.warn(


 Saved video(s) to: /content/drive/MyDrive/CSCI166_RL_New3/Videos/Mario_learned   (look for .mp4 files)
 Step 13 completed successfully!
 Curves in  : /content/drive/MyDrive/CSCI166_RL_New3/Results
 Models in  : /content/drive/MyDrive/CSCI166_RL_New3/Models
 Videos in  : /content/drive/MyDrive/CSCI166_RL_New3/Videos


In [None]:
# === Step 13b: Second training run (305k frames) ===

import os
import torch

mario_label = "DDQN_MarioBros_305k"

# New config for longer run
CFG.device       = "cuda" if torch.cuda.is_available() else "cpu"
CFG.max_frames   = 305_000
CFG.warmup_steps = 10_000
CFG.target_sync  = 5_000
CFG.update_every = 4
# Set the epsilon schedule explicitly for this run instead of inheriting whatever
# an earlier cell left in CFG. The TrainConfig default (1M steps) would leave
# epsilon around 0.7 at step 305k, so training would be mostly random.
CFG.eps_decay_steps = CFG.max_frames // 2

print("Using device:", CFG.device)
print(" Training on:", mario_label)

mario_net_305k = train_ddqn(run_label=mario_label, cfg=CFG)

# Plot curve
plot_returns(mario_label)

# Checkpoints: train_ddqn auto-saves at 200k; save the final weights ourselves.
early_ckpt = os.path.join(paths["models"], f"{mario_label}_200000.pth")
final_ckpt = os.path.join(paths["models"], f"{mario_label}_{CFG.max_frames}.pth")
torch.save(mario_net_305k.state_dict(), final_ckpt)
print("Saved final checkpoint:", final_ckpt)

# Videos
record_clip("ALE/MarioBros-v5", "305k_early", mario_net_305k, ckpt_path=early_ckpt, eps=0.20, seconds=20)
record_clip("ALE/MarioBros-v5", "305k_late", mario_net_305k, ckpt_path=final_ckpt, eps=0.05, seconds=20)

print(" 305k run completed!")


Using device: cpu
 Training on: DDQN_MarioBros_305k
 saved /content/drive/MyDrive/CSCI166_RL_New3/Models/DDQN_MarioBros_305k_200000.pth
 Finished DDQN_MarioBros_305k
 Saved learning-curve plot: /content/drive/MyDrive/CSCI166_RL_New3/Results/DDQN_MarioBros_305k_curve.png
Saved final checkpoint: /content/drive/MyDrive/CSCI166_RL_New3/Models/DDQN_MarioBros_305k_305000.pth


  logger.warn(


 Saved video(s) to: /content/drive/MyDrive/CSCI166_RL_New3/Videos/305k_early   (look for .mp4 files)


  logger.warn(


 Saved video(s) to: /content/drive/MyDrive/CSCI166_RL_New3/Videos/305k_late   (look for .mp4 files)
 305k run completed!


In [None]:
def dqn_learn_step(online_net, target_net, optimizer, buffer, batch_size, gamma, grad_clip, device):
    """
    Vanilla DQN TD update:
      - uses TARGET net to compute max_a' Q(s', a')
      - no Double DQN selection step

    Args:
        online_net / target_net: Q-networks that take raw uint8 frames
            (they scale to [0, 1] themselves in forward()).
        optimizer: optimizer over online_net's parameters.
        buffer: object whose sample(batch_size) returns
            (states, actions, rewards, dones, next_states) numpy arrays.
        batch_size, gamma, grad_clip: minibatch size, discount, max grad norm.
        device: torch device for the batch.

    Returns:
        The MSE TD loss as a Python float.
    """
    s, a, r, d, sp = buffer.sample(batch_size)

    # Keep frames as raw uint8: DQN.forward already does float()/255. Dividing
    # here too fed the network inputs 255x smaller during training than the raw
    # frames select_action() gives it when acting.
    s_t  = torch.from_numpy(s).to(device)
    a_t  = torch.from_numpy(a).to(device)
    r_t  = torch.from_numpy(r).to(device)
    d_t  = torch.from_numpy(d).to(device).float()
    sp_t = torch.from_numpy(sp).to(device)

    # Q(s,a) from online net
    q_values = online_net(s_t)
    q_s_a = q_values.gather(1, a_t.unsqueeze(-1)).squeeze(-1)

    with torch.no_grad():
        # Vanilla DQN: max_a' Q_target(s', a'); (1 - d_t) zeroes terminal bootstraps.
        next_q_values = target_net(sp_t)
        max_next_q, _ = next_q_values.max(dim=1)
        target_q = r_t + gamma * max_next_q * (1 - d_t)

    loss = F.mse_loss(q_s_a, target_q)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(online_net.parameters(), grad_clip)
    optimizer.step()

    return float(loss.item())


def train_dqn(run_label: str,
              env_name: str = "ALE/MarioBros-v5",
              cfg: TrainConfig = CFG,
              ckpt_every: int = 200_000):
    """
    Train the vanilla-DQN baseline (TD targets from dqn_learn_step).

    Mirrors train_ddqn() except for the update rule. Side effects:
      * RESULTS_DIR/<run_label>.csv is deleted, then one "step,episode,return"
        row is appended per finished episode (read by plot_returns()).
      * TensorBoard scalars are written under RUNS_DIR/<run_label>.
      * After warmup, online weights are saved to
        MODELS_DIR/<run_label>_<step>.pth every `ckpt_every` env steps.

    Args:
        run_label: name for the CSV, TensorBoard run and checkpoint files.
        env_name: Gymnasium env id forwarded to make_env().
        cfg: hyperparameters (TrainConfig).
        ckpt_every: checkpoint interval in env steps; <= 0 disables periodic
            checkpoints. The default keeps the original 200k interval.

    Returns:
        The trained online network.
    """
    env = make_env(env_name=env_name, seed=0)
    tb = None
    try:
        obs, _ = env.reset()
        action_space = env.action_space

        # Env returns (H, W, C) frame stacks; the network expects (C, H, W).
        H, W, C = np.asarray(obs).shape
        n_actions = action_space.n

        online = DQN((C, H, W), n_actions).to(cfg.device)
        target = DQN((C, H, W), n_actions).to(cfg.device)
        target.load_state_dict(online.state_dict())

        optimizer = optim.Adam(online.parameters(), lr=cfg.lr)
        buffer = ReplayBuffer(cfg.replay_size)

        tb = SummaryWriter(log_dir=os.path.join(RUNS_DIR, run_label))

        csv_path = os.path.join(RESULTS_DIR, f"{run_label}.csv")
        if os.path.exists(csv_path):
            os.remove(csv_path)  # start a fresh learning curve for this label

        state_chw = obs_to_chw(obs)
        ep_return = 0.0
        ep_idx = 0
        step = 0

        while step < cfg.max_frames:
            step += 1
            eps = epsilon_by_frame(step, cfg.eps_start, cfg.eps_end, cfg.eps_decay_steps)

            # epsilon-greedy action from ONLINE net
            action = select_action(online, state_chw, eps, action_space, cfg.device)
            next_obs, reward, term, trunc, info = env.step(action)
            done = term or trunc
            next_chw = obs_to_chw(next_obs)

            # Store `term`, not `done`: only a true terminal state has zero future
            # value. A time-limit truncation must still bootstrap from next_chw.
            buffer.push(Experience(state_chw, action, reward, term, next_chw))
            ep_return += reward
            state_chw = next_chw

            if done:
                ep_idx += 1
                tb.add_scalar("charts/episode_return", ep_return, step)
                tb.add_scalar("charts/epsilon", eps, step)
                if "episode" in info:  # populated by RecordEpisodeStatistics
                    tb.add_scalar("charts/episode_length", info["episode"]["l"], step)

                # Log to CSV so plot_returns() works.
                with open(csv_path, "a") as f:
                    f.write(f"{step},{ep_idx},{ep_return}\n")

                obs, _ = env.reset()
                state_chw = obs_to_chw(obs)
                ep_return = 0.0

            # warmup period
            if len(buffer) < cfg.warmup_steps:
                continue

            # Gradient update
            if step % cfg.update_every == 0:
                loss = dqn_learn_step(
                    online, target, optimizer, buffer,
                    cfg.batch_size, cfg.gamma, cfg.grad_clip, cfg.device
                )
                tb.add_scalar("losses/td_loss", loss, step)

            # Sync target network
            if step % cfg.target_sync == 0:
                target.load_state_dict(online.state_dict())

            # Periodic checkpoint
            if ckpt_every > 0 and step % ckpt_every == 0:
                ckpt = os.path.join(MODELS_DIR, f"{run_label}_{step}.pth")
                torch.save(online.state_dict(), ckpt)
                print(f" Saved checkpoint: {ckpt}")
    finally:
        # Release the env and flush TensorBoard even if training is interrupted.
        env.close()
        if tb is not None:
            tb.close()

    print(f" Finished training {run_label}")
    return online



In [None]:
from dataclasses import dataclass

import torch

@dataclass
class TrainConfig:
    """Hyperparameters for the baseline-DQN vs Double-DQN comparison runs.

    Redefines the earlier TrainConfig; field order is the positional
    constructor order, so keep it stable.
    """
    lr: float = 1e-4
    gamma: float = 0.99
    eps_start: float = 1.0
    eps_end: float = 0.1
    eps_decay_steps: int = 300_000
    batch_size: int = 32
    replay_size: int = 100_000
    warmup_steps: int = 10_000
    update_every: int = 4
    target_sync: int = 10_000
    max_frames: int = 300_000
    grad_clip: float = 10.0
    # Use the GPU when one is available. A hard-coded "cpu" here silently
    # disabled CUDA for every run that uses this config.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

CFG = TrainConfig()


In [None]:
# Both runs below share the warmup / target-sync / update settings (and the
# epsilon schedule via CFG), so the two learning curves differ only in the
# TD-target rule: vanilla max (train_dqn) vs. Double DQN (train_ddqn).
CFG.warmup_steps = 10_000
CFG.target_sync  = 5_000
CFG.update_every = 4

# --------- Run 1: Baseline DQN (project requirement) ----------
dqn_label = "DQN_MarioBros"
CFG.max_frames = 200_000   # e.g., 200k frames

print("Using device:", CFG.device)
print(" Training baseline:", dqn_label)
mario_dqn = train_dqn(run_label=dqn_label, cfg=CFG)
plot_returns(dqn_label)

# --------- Run 2: Double DQN ----------
# Use a distinct label: "DDQN_MarioBros" belongs to the Step 13 run, and
# train_ddqn() deletes that label's CSV and overwrites its 200k checkpoint.
mario_label = "DDQN_MarioBros_vs_DQN"
# Runs longer than the baseline; compare the two curves over the first 200k
# steps, where every setting above is identical.
CFG.max_frames = 305_000

print("Using device:", CFG.device)
print(" Training DDQN:", mario_label)
mario_net_305k = train_ddqn(run_label=mario_label, cfg=CFG)
plot_returns(mario_label)



Using device: cpu
 Training baseline: DQN_MarioBros
 Saved checkpoint: /content/drive/MyDrive/CSCI166_RL_New3/Models/DQN_MarioBros_200000.pth
 Finished training DQN_MarioBros
 Saved learning-curve plot: /content/drive/MyDrive/CSCI166_RL_New3/Results/DQN_MarioBros_curve.png
Using device: cpu
 Training DDQN: DDQN_MarioBros
 saved /content/drive/MyDrive/CSCI166_RL_New3/Models/DDQN_MarioBros_200000.pth
 Finished DDQN_MarioBros
 Saved learning-curve plot: /content/drive/MyDrive/CSCI166_RL_New3/Results/DDQN_MarioBros_curve.png


In [None]:
# ----- Record videos for Baseline DQN -----

# train_dqn() checkpoints every 200k steps, so this file exists only if the
# baseline ran for at least 200k frames.
dqn_ckpt_200k = os.path.join(MODELS_DIR, "DQN_MarioBros_200000.pth")

# Always save the in-memory weights at the end of training.
dqn_final_ckpt = os.path.join(MODELS_DIR, "DQN_MarioBros_final.pth")
torch.save(mario_dqn.state_dict(), dqn_final_ckpt)

# Fall back to the final weights if no 200k checkpoint was written, instead of
# failing inside torch.load(). NOTE: when max_frames == 200k both files hold the
# same weights, so the two clips differ only in epsilon.
if os.path.exists(dqn_ckpt_200k):
    dqn_early_ckpt = dqn_ckpt_200k
else:
    print(" No 200k DQN checkpoint found; using final weights for DQN_early")
    dqn_early_ckpt = dqn_final_ckpt

# ---------- Early-ish DQN behavior ----------
record_clip(
    "ALE/MarioBros-v5",
    "DQN_early",
    mario_dqn,
    ckpt_path=dqn_early_ckpt,
    eps=0.25,
    seconds=20
)


# ---------- Learned DQN behavior ----------
record_clip(
    "ALE/MarioBros-v5",
    "DQN_learned",
    mario_dqn,
    ckpt_path=dqn_final_ckpt,
    eps=0.05,
    seconds=20
)


  logger.warn(


 Saved video(s) to: /content/drive/MyDrive/CSCI166_RL_New3/Videos/DQN_early   (look for .mp4 files)


  logger.warn(


 Saved video(s) to: /content/drive/MyDrive/CSCI166_RL_New3/Videos/DQN_learned   (look for .mp4 files)
