In [1]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Location and filename patterns of the pickled training logs; edit these to
# point at your own checkpoint directory.
CHECKPOINT_DIR = Path("checkpoint_data")
reward_files, loss_files = (
    sorted(CHECKPOINT_DIR.glob(pattern))
    for pattern in ("total_rewards_list_*.pkl", "total_loss_list_*.pkl")
)

# load latest or specific
def load_latest(files):
    if not files:
        raise RuntimeError("No files found")
    with open(files[-1], "rb") as f:
        return pickle.load(f)

rewards = load_latest(reward_files)  # list of lists (episodes grouped)
losses = load_latest(loss_files)     # list of lists

# Reduce every stored entry to a single mean reward. np.size works uniformly on
# lists, tuples, numpy arrays and plain scalars, so entries that are already a
# single number (or an ndarray) are kept instead of being replaced by 0.
# Empty entries become 0.0 so positions stay aligned with checkpoint indices.
flat_rewards = [float(np.mean(batch)) if np.size(batch) else 0.0 for batch in rewards]

# Moving average over SMOOTH_WINDOW checkpoints. The window is clamped to the
# data length: np.convolve swaps its arguments when the kernel is longer than
# the signal (producing a meaningless curve), and raises on an empty signal.
SMOOTH_WINDOW = 50
window = min(SMOOTH_WINDOW, len(flat_rewards))
if window:
    smoothed = np.convolve(flat_rewards, np.ones(window) / window, mode="valid")
else:
    smoothed = np.array([])
# smoothed[i] averages checkpoints i .. i+window-1; plot it at the window's
# last checkpoint so the x-axis shows real checkpoint indices.
smoothed_x = np.arange(window - 1, window - 1 + len(smoothed))

Path("Imgs").mkdir(parents=True, exist_ok=True)  # savefig does not create directories
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(smoothed_x, smoothed, label=f"Smoothed Reward (win={window})")
ax.set_title("Training Reward Curve (Smoothed)")
ax.set_xlabel("Checkpoint Index")
ax.set_ylabel("Average Reward")
ax.legend()
ax.grid(True)
fig.savefig("Imgs/reward_curve_smoothed.png", dpi=300)
plt.close(fig)

# Loss plot: reduce each stored loss entry to its mean, one value per checkpoint.
# np.size (instead of `if batch`) handles lists, scalars and numpy arrays alike;
# truth-testing a multi-element ndarray raises "truth value is ambiguous".
# Empty entries become 0.0 so positions stay aligned with checkpoint indices.
mean_losses = [float(np.mean(batch)) if np.size(batch) else 0.0 for batch in losses]

Path("Imgs").mkdir(parents=True, exist_ok=True)  # savefig does not create directories
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(mean_losses, label="Mean Huber Loss")
ax.set_title("TD-Loss (Huber) vs Iterations")
ax.set_xlabel("Checkpoint Index")
ax.set_ylabel("Loss")
ax.legend()
ax.grid(True)
fig.savefig("Imgs/td_loss_iterations.png", dpi=300)
plt.close(fig)


In [2]:
import torch
import torch.nn as nn

class DQN(nn.Module):
    """Convolutional Q-network with the layer sizes of the Nature DQN.

    Three unpadded conv layers (32x8x8/4, 64x4x4/2, 64x3x3/1) with ReLU, then a
    512-unit hidden layer and a linear head emitting one Q-value per action.

    Args:
        nb_actions: number of discrete actions (width of the output layer).
        in_channels: number of input channels, e.g. stacked frames (default 4).
        frame_size: height and width of the square input frames (default 84).
            The flattened conv-output width is derived from it; the defaults
            give 64 * 7 * 7 = 3136, identical to the original hard-coded value.

    Raises:
        ValueError: if ``frame_size`` is too small to survive the conv stack.
    """

    def __init__(self, nb_actions, in_channels=4, frame_size=84):
        super().__init__()
        # Spatial size after each unpadded conv: out = (in - kernel) // stride + 1.
        # Pure arithmetic, so it consumes no RNG and initialisation is unchanged.
        side = frame_size
        for kernel, stride in ((8, 4), (4, 2), (3, 1)):
            side = (side - kernel) // stride + 1
        if side < 1:
            raise ValueError(f"frame_size={frame_size} is too small for this conv stack")
        # Layer order is kept identical so state_dict keys (network.0 ... network.9)
        # still match previously saved checkpoints.
        self.network = nn.Sequential(
            nn.Conv2d(in_channels, 32, 8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1), nn.ReLU(),
            nn.Flatten(), nn.Linear(64 * side * side, 512), nn.ReLU(),
            nn.Linear(512, nb_actions),
        )

    def forward(self, x):
        """Return Q-values of shape (N, nb_actions).

        ``x`` must be a float tensor of shape (N, in_channels, frame_size,
        frame_size); no pixel rescaling is applied here.
        """
        return self.network(x)

def param_table(nb_actions=4):
    """Print the parameter count of a freshly built DQN.

    The first line is the grand total. It is followed by one line per named
    parameter tensor with its name, shape and element count.
    """
    model = DQN(nb_actions)
    rows = [(name, tensor.shape, tensor.numel()) for name, tensor in model.named_parameters()]
    grand_total = sum(count for _, _, count in rows)
    print("Total params:", grand_total)
    for name, shape, count in rows:
        print(name, shape, count)


# Inside a notebook kernel __name__ is "__main__", so this runs with the cell.
if __name__ == "__main__":
    param_table(nb_actions=4)


Total params: 1686180
network.0.weight torch.Size([32, 4, 8, 8]) 8192
network.0.bias torch.Size([32]) 32
network.2.weight torch.Size([64, 32, 4, 4]) 32768
network.2.bias torch.Size([64]) 64
network.4.weight torch.Size([64, 64, 3, 3]) 36864
network.4.bias torch.Size([64]) 64
network.7.weight torch.Size([512, 3136]) 1605632
network.7.bias torch.Size([512]) 512
network.9.weight torch.Size([4, 512]) 2048
network.9.bias torch.Size([4]) 4


In [3]:
import gymnasium as gym
from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv
from gymnasium.wrappers import ResizeObservation, GrayscaleObservation, FrameStackObservation
import matplotlib.pyplot as plt
import ale_py

gym.register_envs(ale_py)

# Standard DQN Atari preprocessing chain. Order matters:
#  * ALE v5 environments already repeat every action for 4 frames (frameskip=4
#    by default). Adding MaxAndSkipEnv(skip=4) on top would make each agent step
#    span 16 emulator frames, so the built-in skip is disabled here.
#  * MaxAndSkipEnv must see raw frames so its max over the last two frames
#    removes sprite flicker. Placed after frame stacking, it would max whole stacks.
#  * FrameStackObservation goes last so the stack holds 4 consecutive agent-step
#    frames rather than 4 consecutive raw emulator frames.
# NOTE(review): keep this chain identical to the training environment,
# otherwise the figure does not show what the agent actually saw.
env = gym.make("ALE/Breakout-v5", frameskip=1)
env = MaxAndSkipEnv(env, skip=4)
env = ResizeObservation(env, (84, 84))
env = GrayscaleObservation(env)
env = FrameStackObservation(env, 4)

obs, _ = env.reset(seed=0)
# Right after reset the stack is padded with copies of the reset frame (the
# wrapper's default padding), so all four panels would look identical. Launch
# the ball and advance a few steps so the stack shows motion.
action_meanings = env.unwrapped.get_action_meanings()
obs, *_ = env.step(action_meanings.index("FIRE"))
for _ in range(3):
    obs, *_ = env.step(action_meanings.index("NOOP"))
env.close()

fig, axes = plt.subplots(1, 4, figsize=(10, 3))
for i, ax in enumerate(axes):
    ax.imshow(obs[i], cmap="gray")
    ax.axis("off")
fig.suptitle("Preprocessed Frame Stack (84×84×4)")
fig.savefig("Fig4_1_preprocessing_pipeline.png", dpi=300)
plt.close(fig)
