In [None]:
import random, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
import ale_py   
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing, FrameStack
from torch.distributions import Categorical
from collections import deque

In [None]:
# ──────────────────────────────────────────────────────────────────────────
# 1. Hyper-parameters & configuration
# ──────────────────────────────────────────────────────────────────────────
class Config:
    """Static hyper-parameters and runtime settings for the A2C run.

    All values are plain class attributes; the class itself is used as the
    settings namespace.
    """

    # ── Environment ──────────────────────────────────────────────────────
    env_id: str = "ALE/MsPacman-v5"
    seed: int = 42
    # Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
    device: torch.device = (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    )

    # ── A2C ──────────────────────────────────────────────────────────────
    epsilon: float = 0.05            # chance of overriding the sampled action with a uniform one
    gamma: float = 0.99              # discount factor for the n-step returns
    rollout_steps: int = 5           # environment steps collected per update
    total_episodes: int = 10_000     # number of training episodes

    # ── Optimisation ─────────────────────────────────────────────────────
    lr: float = 1e-4                 # Adam learning rate
    entropy_coef: float = 0.05       # weight of the entropy bonus in the loss
    value_loss_coef: float = 0.5     # weight of the critic (value) loss
    grad_clip_norm: float = 0.5      # max gradient norm passed to clip_grad_norm_

    # ── Logging ──────────────────────────────────────────────────────────
    log_interval: int = 1            # print the reward line every N episodes
    moving_avg_episodes: int = 100   # window length of the reward moving average

    # ── Debug ────────────────────────────────────────────────────────────
    debug_interval: int = 10         # print debug stats every N episodes; 0 disables them
    

# Reproducibility: seed Python's, NumPy's and torch's generators (in that
# order) with the configured seed, then the CUDA generators when on GPU.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(Config.seed)
if Config.device.type == "cuda":
    torch.cuda.manual_seed_all(Config.seed)

In [None]:
# ──────────────────────────────────────────────────────────────────────────
# 2. Helpers – environment factory
# ──────────────────────────────────────────────────────────────────────────
def make_env(env_id: str, seed: int | None = None) -> gym.Env:
    """Build the preprocessed, frame-stacked Atari environment.

    Args:
        env_id: Gymnasium environment id handed to ``gym.make``.
        seed: When given, the environment is reset once with this seed and the
            action space is seeded with it as well.

    Returns:
        The environment wrapped in ``AtariPreprocessing`` and a 4-frame
        ``FrameStack``.
    """
    # Frame skipping is switched off in the raw emulator (frameskip=1) and
    # applied by the preprocessing wrapper instead (frame_skip=4).
    raw_env = gym.make(env_id, frameskip=1, full_action_space=False)

    preprocessing_kwargs = dict(
        frame_skip=4,
        screen_size=84,
        terminal_on_life_loss=False,
        grayscale_obs=True,
        scale_obs=True,
    )
    stacked_env = FrameStack(
        AtariPreprocessing(raw_env, **preprocessing_kwargs),
        num_stack=4,
    )

    if seed is None:
        return stacked_env

    stacked_env.reset(seed=seed)
    stacked_env.action_space.seed(seed)
    return stacked_env

In [None]:
# ──────────────────────────────────────────────────────────────────────────
# 3. Neural network – shared conv base + policy/value heads
# ──────────────────────────────────────────────────────────────────────────
class ActorCriticNet(nn.Module):
    """Convolutional actor-critic with a shared trunk and two linear heads.

    Args:
        input_shape: 3-D observation shape. It is read as ``(C, H, W)`` when
            the first entry is at most 4, otherwise as ``(H, W, C)``.
        num_actions: Number of discrete actions (width of the logits output).

    Raises:
        ValueError: If ``input_shape`` does not have exactly three entries.
    """

    def __init__(self, input_shape: tuple[int, int, int], num_actions: int):
        super().__init__()

        if len(input_shape) != 3:
            raise ValueError("Expected 3-D input shape")
        if input_shape[0] <= 4:
            channels, height, width = input_shape
        else:
            height, width, channels = input_shape

        # Remembered so forward() can recognise the layout of incoming batches
        # with the same rule used here, instead of a hard-coded channel list.
        self.in_channels = channels

        self.conv1 = nn.Conv2d(channels, 32, 8, 4)
        self.conv2 = nn.Conv2d(32, 64, 4, 2)
        self.conv3 = nn.Conv2d(64, 64, 3, 1)

        # Probe the conv stack once to learn the flattened feature size.
        with torch.no_grad():
            dummy = torch.zeros(1, channels, height, width)
            flat = self._forward_conv(dummy).view(1, -1).size(1)

        self.fc = nn.Linear(flat, 1024)
        self.policy_head = nn.Linear(1024, num_actions)
        self.value_head = nn.Linear(1024, 1)

    def _forward_conv(self, x):
        """Run the three ReLU-activated convolutions."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        return x

    def forward(self, x):
        """Return ``(logits, value)`` for a batch of observations.

        Args:
            x: 4-D batch, channels-first or channels-last. ``uint8`` input is
                always treated as 0-255 pixels and rescaled; other dtypes are
                rescaled only when their maximum exceeds 1.0.

        Returns:
            ``logits`` of shape ``(batch, num_actions)`` and ``value`` of shape
            ``(batch,)``.
        """
        # Decide on pixel scaling from the dtype *before* the float cast, so a
        # dark uint8 frame (max <= 1) is not mistaken for pre-scaled data.
        is_raw_pixels = x.dtype == torch.uint8
        x = x.float()

        # Channels-last → channels-first, keyed on the configured channel
        # count so any stack size accepted by __init__ works here too.
        if (
            x.ndim == 4
            and x.shape[1] != self.in_channels
            and x.shape[-1] == self.in_channels
        ):
            x = x.permute(0, 3, 1, 2)

        if is_raw_pixels or x.max() > 1.0:
            x = x / 255.0

        x = self._forward_conv(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc(x))
        logits = self.policy_head(x)
        value = self.value_head(x)
        return logits, value.squeeze(-1)

In [None]:
# ──────────────────────────────────────────────────────────────────────────
# 4. A2C agent
# ──────────────────────────────────────────────────────────────────────────
class A2CAgent:
    """Advantage actor-critic learner around a policy/value network.

    Args:
        model: Network whose forward pass returns ``(logits, value)``.
        cfg: Settings object providing ``device``, ``gamma``, ``lr``,
            ``entropy_coef``, ``value_loss_coef``, ``grad_clip_norm`` and
            ``epsilon``. Every hyper-parameter is read from this object, not
            from a module-level global.
    """

    def __init__(self, model: nn.Module, cfg: "Config"):
        self.model = model
        self.device = cfg.device
        self.gamma = cfg.gamma
        self.entropy_coef = cfg.entropy_coef
        self.value_loss_coef = cfg.value_loss_coef
        self.grad_clip_norm = cfg.grad_clip_norm
        self.epsilon = cfg.epsilon
        self.optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

    def select_action(self, state):
        """Pick an action for a single observation.

        The action is sampled from the policy; with probability ``epsilon`` it
        is replaced by a uniformly random action. The returned log-probability
        is always evaluated under the policy for the action actually taken.

        Args:
            state: One observation (array-like or tensor). A 3-D input gets a
                leading batch dimension added.

        Returns:
            ``(action, log_prob, value, entropy)`` — a Python int followed by
            three tensors that stay attached to the autograd graph.
        """
        if torch.is_tensor(state):
            state_t = state.detach().to(self.device)
        else:
            # Materialise array-likes as a single ndarray first, then copy it
            # so later mutation of the source buffer cannot touch the graph.
            state_t = torch.tensor(np.asarray(state), device=self.device)
        if state_t.ndim == 3:
            state_t = state_t.unsqueeze(0)

        logits, value = self.model(state_t)
        dist = Categorical(logits=logits)
        action = dist.sample()
        if random.random() < self.epsilon:
            # Uniform exploratory override, same shape/dtype as the sample.
            action = torch.full_like(action, random.randrange(logits.size(-1)))

        return (
            action.item(),
            dist.log_prob(action).squeeze(0),
            value.squeeze(0),
            dist.entropy().squeeze(0),
        )

    def update(self, rewards, log_probs, values, entropies, next_value, done):
        """Run one A2C optimisation step on a rollout and return debug scalars.

        Args:
            rewards: Per-step rewards of the rollout, oldest first.
            log_probs: Per-step log-probability tensors from ``select_action``.
            values: Per-step value-estimate tensors from ``select_action``.
            entropies: Per-step policy-entropy tensors from ``select_action``.
            next_value: Value estimate of the state after the rollout
                (one-element tensor or number); ignored when ``done`` is true.
            done: Whether the episode ended inside this rollout.

        Returns:
            Python floats ``(policy_loss, value_loss, entropy, total_loss,
            adv_mean, adv_min, adv_max)``. The advantage figures are taken
            after normalisation.

        Raises:
            ValueError: If the rollout is empty.
        """
        steps = len(rewards)
        if steps == 0:
            raise ValueError("update() requires at least one transition")

        log_probs = torch.stack(log_probs)
        values = torch.stack(values)        # shape [T]
        entropies = torch.stack(entropies)

        # Discounted n-step returns, filled back-to-front in plain Python so
        # the tensor is built once instead of from a list of 0-dim tensors.
        running = 0.0 if done else float(next_value)
        returns = [0.0] * steps
        for t in range(steps - 1, -1, -1):
            running = float(rewards[t]) + self.gamma * running
            returns[t] = running
        returns = torch.tensor(returns, device=self.device, dtype=torch.float32)

        advantages = returns - values.detach()
        # Population std (unbiased=False) is 0, not NaN, for a one-step rollout.
        adv_std = advantages.std(unbiased=False)
        advantages = (advantages - advantages.mean()) / (adv_std + 1e-8)

        policy_loss = -(advantages * log_probs).mean()
        value_loss = F.mse_loss(values, returns)
        entropy_mean = entropies.mean()
        total_loss = (
            policy_loss
            + self.value_loss_coef * value_loss
            - self.entropy_coef * entropy_mean
        )

        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
        self.optimizer.step()

        return (
            policy_loss.item(),
            value_loss.item(),
            entropy_mean.item(),
            total_loss.item(),
            advantages.mean().item(),
            advantages.min().item(),
            advantages.max().item(),
        )

In [None]:
# ──────────────────────────────────────────────────────────────────────────
# 5. Training loop with integrated debugging
# ──────────────────────────────────────────────────────────────────────────
def _bootstrap_value(agent, state, done):
    """Return the value used to bootstrap the rollout: zero once the episode is over."""
    if done:
        return torch.zeros(1, device=Config.device)
    state_t = torch.tensor(np.asarray(state), device=Config.device)
    if state_t.ndim == 3:
        state_t = state_t.unsqueeze(0)
    with torch.no_grad():
        _, value = agent.model(state_t)
    return value.squeeze(0)


def _run_episode(env, agent, n_actions):
    """Play one episode, updating the agent after every n-step rollout.

    Returns:
        ``(total_reward, action_counts, update_stats)`` where ``update_stats``
        holds the debug tuple of every ``agent.update`` call in the episode.
    """
    state, _ = env.reset()
    done, total_reward = False, 0.0
    action_counts = np.zeros(n_actions, dtype=int)
    update_stats = []

    while not done:
        rewards, log_probs, values, entropies = [], [], [], []

        # n-step rollout (or until the episode ends)
        for _ in range(Config.rollout_steps):
            action, logprob, value, entropy = agent.select_action(state)
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            rewards.append(reward)
            log_probs.append(logprob)
            values.append(value)
            entropies.append(entropy)
            total_reward += reward
            action_counts[action] += 1

            if done:
                break

        next_val = _bootstrap_value(agent, state, done)
        update_stats.append(
            agent.update(rewards, log_probs, values, entropies, next_val, done)
        )

    return total_reward, action_counts, update_stats


def _print_debug_summary(ep, interval, update_stats, act_counts):
    """Print loss/advantage averages and the action histogram for one debug interval."""
    if not update_stats:
        return
    stats = np.asarray(update_stats, dtype=float)
    avg_pl, avg_vl, avg_ent, avg_tot, avg_adv_mean = stats[:, :5].mean(axis=0)
    adv_min_int = stats[:, 5].min()
    adv_max_int = stats[:, 6].max()

    total_act = act_counts.sum()
    if total_act:
        pct = act_counts / total_act * 100.0
        act_dist = ", ".join(
            f"{i}:{c} ({p:.1f}%)" for i, (c, p) in enumerate(zip(act_counts, pct))
        )
    else:
        act_dist = "n/a"

    print(f"\n--- DEBUG {ep - interval + 1}-{ep} ---")
    print(f"Avg PolicyLoss: {avg_pl:.4f} | Avg ValueLoss: {avg_vl:.4f} | "
          f"Avg Entropy: {avg_ent:.4f} | Avg TotalLoss: {avg_tot:.4f}")
    print(f"Advantage μ: {avg_adv_mean:.4f} | min: {adv_min_int:.4f} | max: {adv_max_int:.4f}")
    print(f"Action distribution (last {interval} eps): {act_dist}\n")


def main():
    """Train the A2C agent for ``Config.total_episodes`` episodes with logging.

    Prints a reward line every ``Config.log_interval`` episodes and, when
    ``Config.debug_interval`` is non-zero, a debug summary every that many
    episodes. Debug averages cover every update performed in the interval,
    not just the last update of each episode. The environment is closed even
    if training is interrupted.
    """
    env = make_env(Config.env_id, Config.seed)
    try:
        obs_shape = env.observation_space.shape
        n_actions = env.action_space.n

        model = ActorCriticNet(obs_shape, n_actions).to(Config.device)
        agent = A2CAgent(model, Config)

        episode_rewards = deque(maxlen=Config.moving_avg_episodes)

        debug_on = bool(Config.debug_interval)
        interval = Config.debug_interval or 1  # avoid modulo by zero

        interval_stats = []
        act_counts_interval = np.zeros(n_actions, dtype=int)

        for ep in range(1, Config.total_episodes + 1):
            total_reward, ep_act_counts, update_stats = _run_episode(
                env, agent, n_actions
            )
            episode_rewards.append(total_reward)

            if debug_on:
                interval_stats.extend(update_stats)
                act_counts_interval += ep_act_counts

            # reward log (a log_interval of 0 disables it instead of crashing)
            if Config.log_interval and ep % Config.log_interval == 0:
                avg_r = sum(episode_rewards) / len(episode_rewards)
                print(f"Episode {ep:<4d} | EpReward: {total_reward:>6.2f} | MovingAvg({len(episode_rewards)}): {avg_r:6.2f}")

            if debug_on and ep % interval == 0:
                _print_debug_summary(ep, interval, interval_stats, act_counts_interval)
                interval_stats.clear()
                act_counts_interval.fill(0)
    finally:
        env.close()

In [None]:
# ──────────────────────────────────────────────────────────────────────────
# 6. Execution
# ──────────────────────────────────────────────────────────────────────────
# Run training only when executed directly (a notebook cell or a script),
# so importing this module does not start a training run as a side effect.
if __name__ == "__main__":
    main()