<a href="https://colab.research.google.com/github/tcharos/AIDL_B02-Advanced-Topics-in-Deep-Learning/blob/main/advanced_dqn_space_invaders.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Advanced DQN Variants for Space Invaders

Implementation of DQN, Double DQN, Dueling DQN, and Prioritized Experience Replay for ALE/SpaceInvaders-v5

## Install Dependencies

In [None]:
# Notebook shell escapes: install Gymnasium with its Atari extras, the ALE
# bindings, and the other packages installed for this notebook.
!pip install gymnasium[atari,accept-rom-license]
!pip install ale-py
!pip install torch scipy numpy psutil

# Import and verify ALE is available
# (the import itself is the check; register_envs is then called with the
# ale_py module so its environments are registered with Gymnasium)
import ale_py
import gymnasium as gym
gym.register_envs(ale_py)

## Import Libraries

In [None]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import deque, namedtuple
import random
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
import psutil
import os
from datetime import datetime

# Register ALE environments
import ale_py
gym.register_envs(ale_py)

## Google Drive Setup (Optional)

In [None]:
USE_GDRIVE = False  # Set to True to enable Google Drive integration

if not USE_GDRIVE:
    # Local run: checkpoints go to a folder relative to the working directory.
    CHECKPOINT_DIR = './checkpoints'
else:
    # Colab run: mount Google Drive and store checkpoints on it.
    from google.colab import drive
    drive.mount('/content/drive')
    CHECKPOINT_DIR = '/content/drive/MyDrive/DQN_SpaceInvaders_Checkpoints'

# Either way the directory must exist before the first checkpoint is written.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print(
    f"Checkpoints will be saved to: {CHECKPOINT_DIR}"
    if USE_GDRIVE
    else f"Checkpoints will be saved locally to: {CHECKPOINT_DIR}"
)

## DQN Network Architectures

In [None]:
class DQN(nn.Module):
    """Standard DQN network: a 3-layer conv feature extractor and an MLP head.

    Args:
        input_shape: Shape of a single observation without the batch axis.
            Its first entry is used as the number of input channels.
        n_actions: Width of the final linear layer (one Q-value per action).
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        # The head's input width depends on the conv output for this input
        # shape, so it is measured rather than hard-coded.
        conv_out_size = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def _get_conv_out(self, shape):
        """Return the number of features the conv stack emits for one input of `shape`."""
        # This is only a shape probe: run it without autograd so no graph is
        # recorded for a throw-away forward pass.
        with torch.no_grad():
            o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        """Map a batch of observations to a (batch, n_actions) tensor of Q-values."""
        # Flatten everything except the batch axis before the linear head.
        conv_out = self.conv(x).view(x.size()[0], -1)
        return self.fc(conv_out)


class DuelingDQN(nn.Module):
    """Dueling DQN network with separate value and advantage streams.

    The conv feature extractor has the same layer layout as ``DQN``; its
    flattened output feeds two MLP heads whose outputs are combined in
    ``forward``.

    Args:
        input_shape: Shape of a single observation without the batch axis.
            Its first entry is used as the number of input channels.
        n_actions: Width of the advantage head (one advantage per action).
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        conv_out_size = self._get_conv_out(input_shape)

        # Value stream: one scalar V(s) per sample
        self.value_stream = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

        # Advantage stream: one A(s, a) per action
        self.advantage_stream = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def _get_conv_out(self, shape):
        """Return the number of features the conv stack emits for one input of `shape`."""
        # Shape probe only: skip autograd bookkeeping for this throw-away pass.
        with torch.no_grad():
            o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        """Map a batch of observations to a (batch, n_actions) tensor of Q-values."""
        conv_out = self.conv(x).view(x.size()[0], -1)
        value = self.value_stream(conv_out)
        advantage = self.advantage_stream(conv_out)

        # Q = V + (A - mean_a A): subtracting the per-sample mean advantage
        # means the mean of Q over actions equals V.
        q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
        return q_values

## Replay Buffers

In [None]:
class ReplayBuffer:
    """Fixed-capacity transition store with uniform random sampling."""

    def __init__(self, capacity):
        # A bounded deque discards its oldest entry once `capacity` is reached.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one (state, action, reward, next_state, done) transition."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw `batch_size` distinct transitions and return them column-wise.

        States and next states come back as stacked NumPy arrays; actions,
        rewards and done flags come back as tuples.
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*picked)
        return np.array(states), actions, rewards, np.array(next_states), dones

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)


class PrioritizedReplayBuffer:
    """Ring buffer of transitions sampled in proportion to priority ** alpha.

    New transitions enter with the largest priority currently stored, and
    `sample` also returns importance-sampling weights whose exponent `beta`
    is annealed linearly towards 1.0 with every call.
    """

    def __init__(self, capacity, alpha=0.6, beta_start=0.4, beta_frames=100000):
        self.capacity = capacity
        self.alpha = alpha  # 0 = uniform sampling, 1 = fully priority-proportional
        self.beta_start = beta_start  # initial importance-sampling exponent
        self.beta_frames = beta_frames  # number of sample() calls until beta reaches 1.0
        self.frame = 1  # counts sample() calls; drives the beta schedule
        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.pos = 0  # next slot to write

    def beta_by_frame(self, frame_idx):
        """Return beta for `frame_idx`: linear from beta_start, capped at 1.0."""
        progress = frame_idx * (1.0 - self.beta_start) / self.beta_frames
        return min(1.0, self.beta_start + progress)

    def push(self, state, action, reward, next_state, done):
        """Store a transition at the write position with the current max priority."""
        transition = (state, action, reward, next_state, done)

        # The very first transition has nothing to compare against: use 1.0.
        if self.buffer:
            max_priority = self.priorities.max()
        else:
            max_priority = 1.0

        slot = self.pos
        if len(self.buffer) >= self.capacity:
            # Buffer is full: overwrite the slot at the write position.
            self.buffer[slot] = transition
        else:
            self.buffer.append(transition)

        self.priorities[slot] = max_priority
        self.pos = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Draw `batch_size` distinct transitions according to their priorities.

        Returns:
            (states, actions, rewards, next_states, dones, indices, weights):
            the first five are the batch column-wise (states as stacked
            arrays, the rest as tuples), `indices` are the sampled buffer
            slots and `weights` the importance-sampling weights scaled so
            the largest one in the batch is 1.
        """
        N = len(self.buffer)
        # While the buffer is still filling, only the first `pos` slots are in use.
        active = self.priorities if N == self.capacity else self.priorities[:self.pos]

        # Sampling distribution: priority ** alpha, normalised to sum to 1
        scaled = active ** self.alpha
        probabilities = scaled / scaled.sum()

        indices = np.random.choice(N, batch_size, p=probabilities, replace=False)

        # Advance the beta schedule by one step per sampling call
        beta = self.beta_by_frame(self.frame)
        self.frame += 1

        # Importance-sampling weights, normalised by the batch maximum
        weights = (N * probabilities[indices]) ** (-beta)
        weights /= weights.max()

        batch = [self.buffer[i] for i in indices]
        states, actions, rewards, next_states, dones = zip(*batch)

        return np.array(states), actions, rewards, np.array(next_states), dones, indices, weights

    def update_priorities(self, indices, priorities):
        """Overwrite the priority stored at each of `indices` with the paired value."""
        for slot, new_priority in zip(indices, priorities):
            self.priorities[slot] = new_priority

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

## Frame Preprocessing

In [None]:
def preprocess_frame(frame, output_size=(84, 84)):
    """Convert a frame to grayscale, resize it to `output_size`, and normalize.

    Args:
        frame: Array of shape (height, width, channels) with at least three
            channels; only the first three are used, weighted as R, G, B.
        output_size: Target (height, width). Defaults to (84, 84).

    Returns:
        float32 array of shape `output_size` with the input's 0-255 range
        scaled by 1/255.
    """
    # Luma-weighted channel sum -> 2-D grayscale image
    gray = np.dot(frame[..., :3], [0.299, 0.587, 0.114])

    # Derive the zoom factors from the actual frame size instead of assuming
    # 210x160 input, so other frame sizes still come out at `output_size`.
    in_h, in_w = gray.shape
    out_h, out_w = output_size
    resized = zoom(gray, (out_h / in_h, out_w / in_w), order=1)  # order=1: linear

    # Scale to [0, 1] and store compactly
    return (resized / 255.0).astype(np.float32)

## Configuration

In [None]:
# Base Configuration
# Shared defaults. Each training cell takes a .copy() of this dict and
# overrides DQN_TYPE / USE_PER / SEED before handing it to train_dqn.
BASE_CONFIG = {
    # Environment
    'ENV_ID': 'ALE/SpaceInvaders-v5',  # id passed to gym.make
    'SEED': 7,  # used for env.reset(seed=...) and the random/numpy/torch seeds
    
    # Network
    'N_FRAMES': 4,  # frames stacked per state; first entry of the network input shape
    'N_ACTIONS': 6,  # network output width and range of random actions
    
    # Training
    'N_EPISODES': 1000,
    'LEARNING_RATE': 0.00025,  # Adam learning rate
    'GAMMA': 0.99,  # discount factor in the TD target
    'BATCH_SIZE': 32,  # transitions per optimisation step
    
    # Exploration
    # epsilon = END + (START - END) * exp(-steps / DECAY), evaluated every step
    'EPSILON_START': 1.0,
    'EPSILON_END': 0.1,
    'EPSILON_DECAY': 10000,
    
    # Memory
    'BUFFER_SIZE': 10000,  # replay buffer capacity (transitions)
    'TARGET_UPDATE': 1000,  # environment steps between target-network syncs
    
    # Checkpointing
    'CHECKPOINT_EVERY': 200,  # Save checkpoint every N episodes
    'USE_GDRIVE': USE_GDRIVE,
    'CHECKPOINT_DIR': CHECKPOINT_DIR,
    
    # DQN Type
    'DQN_TYPE': 'DQN',  # Options: 'DQN', 'DoubleDQN', 'DuelingDQN'
    'USE_PER': False,  # Use Prioritized Experience Replay
    
    # PER hyperparameters (if USE_PER=True)
    'PER_ALPHA': 0.6,
    'PER_BETA_START': 0.4,
    'PER_BETA_FRAMES': 100000,
    # NOTE(review): no visible code reads PER_EPSILON; the 1e-6 used in
    # optimize_model_per is set there independently of this key.
    'PER_EPSILON': 1e-6  # Small constant to prevent zero priority
}

## Helper Functions

In [None]:
def select_action(state, epsilon, policy_net, n_actions, device):
    """Choose an action epsilon-greedily.

    With probability `epsilon` a uniformly random index in
    ``range(n_actions)`` is returned; otherwise `state` is given a batch
    axis, moved to `device`, and the index of the largest value that
    `policy_net` outputs for it is returned.
    """
    explore = random.random() < epsilon
    if explore:
        return random.randrange(n_actions)

    # Greedy branch: inference only, so no gradients are needed.
    with torch.no_grad():
        batch = torch.FloatTensor(state).unsqueeze(0).to(device)
        best = torch.max(policy_net(batch), 1)
        return best[1].item()


def optimize_model_dqn(policy_net, target_net, optimizer, replay_buffer, batch_size, gamma, device):
    """Run one standard-DQN gradient step on a sampled batch.

    Args:
        policy_net: Network being trained; maps states to per-action Q-values.
        target_net: Network used to evaluate next states (not updated here).
        optimizer: Optimizer over `policy_net`'s parameters.
        replay_buffer: Object with ``__len__`` and ``sample(batch_size)``
            returning (states, actions, rewards, next_states, dones).
        batch_size: Number of transitions per step.
        gamma: Discount factor.
        device: Device the batch tensors are moved to.

    Returns:
        The MSE loss as a float, or None (and no update) when the buffer
        holds fewer than `batch_size` transitions.
    """
    if len(replay_buffer) < batch_size:
        return None

    states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

    states = torch.FloatTensor(states).to(device)
    actions = torch.LongTensor(actions).to(device)
    rewards = torch.FloatTensor(rewards).to(device)
    next_states = torch.FloatTensor(next_states).to(device)
    dones = torch.FloatTensor(dones).to(device)

    # Q(s, a) for the actions actually taken. Squeeze only the gather axis:
    # a bare squeeze() would also drop the batch axis when batch_size == 1
    # and make the loss broadcast a 0-d input against a 1-d target.
    current_q = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # TD target: r + gamma * max_a' Q_target(s', a'), with the bootstrap term
    # zeroed where done == 1. No gradient flows through the target.
    with torch.no_grad():
        next_q = target_net(next_states).max(1)[0]
        target_q = rewards + (1 - dones) * gamma * next_q

    loss = F.mse_loss(current_q, target_q)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()


def optimize_model_double_dqn(policy_net, target_net, optimizer, replay_buffer, batch_size, gamma, device):
    """Run one Double-DQN gradient step on a sampled batch.

    The next action is chosen by `policy_net` and its value is read from
    `target_net`.

    Args:
        policy_net: Network being trained; maps states to per-action Q-values.
        target_net: Network used to evaluate the selected next actions.
        optimizer: Optimizer over `policy_net`'s parameters.
        replay_buffer: Object with ``__len__`` and ``sample(batch_size)``
            returning (states, actions, rewards, next_states, dones).
        batch_size: Number of transitions per step.
        gamma: Discount factor.
        device: Device the batch tensors are moved to.

    Returns:
        The MSE loss as a float, or None (and no update) when the buffer
        holds fewer than `batch_size` transitions.
    """
    if len(replay_buffer) < batch_size:
        return None

    states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

    states = torch.FloatTensor(states).to(device)
    actions = torch.LongTensor(actions).to(device)
    rewards = torch.FloatTensor(rewards).to(device)
    next_states = torch.FloatTensor(next_states).to(device)
    dones = torch.FloatTensor(dones).to(device)

    # Q(s, a) for the actions actually taken. squeeze(1) removes only the
    # gather axis, so a batch of one stays 1-D and lines up with target_q.
    current_q = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # Double DQN: use policy net to select actions, target net to evaluate them
    with torch.no_grad():
        next_actions = policy_net(next_states).max(1)[1].unsqueeze(1)
        next_q = target_net(next_states).gather(1, next_actions).squeeze(1)
        # Bootstrap term is zeroed where done == 1
        target_q = rewards + (1 - dones) * gamma * next_q

    loss = F.mse_loss(current_q, target_q)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()


def optimize_model_per(policy_net, target_net, optimizer, replay_buffer, batch_size, gamma, device,
                       dqn_type='DQN', per_epsilon=1e-6):
    """Run one gradient step using a prioritized replay buffer.

    Args:
        policy_net: Network being trained; maps states to per-action Q-values.
        target_net: Network used to build the TD target.
        optimizer: Optimizer over `policy_net`'s parameters.
        replay_buffer: Object with ``__len__``, ``sample(batch_size)``
            returning (states, actions, rewards, next_states, dones, indices,
            weights), and ``update_priorities(indices, priorities)``.
        batch_size: Number of transitions per step.
        gamma: Discount factor.
        device: Device the batch tensors are moved to.
        dqn_type: 'DoubleDQN' selects the next action with `policy_net` and
            evaluates it with `target_net`; any other value uses the
            target network's maximum.
        per_epsilon: Constant added to |TD error| so no priority is zero.
            Defaults to the previously hard-coded 1e-6.

    Returns:
        The importance-weighted MSE loss as a float, or None (and no update)
        when the buffer holds fewer than `batch_size` transitions.
    """
    if len(replay_buffer) < batch_size:
        return None

    states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(batch_size)

    states = torch.FloatTensor(states).to(device)
    actions = torch.LongTensor(actions).to(device)
    rewards = torch.FloatTensor(rewards).to(device)
    next_states = torch.FloatTensor(next_states).to(device)
    dones = torch.FloatTensor(dones).to(device)
    weights = torch.FloatTensor(weights).to(device)

    # Q(s, a) for the actions actually taken. squeeze(1) removes only the
    # gather axis, so a batch of one stays 1-D instead of collapsing to 0-d.
    current_q = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # Calculate target Q values based on DQN type
    with torch.no_grad():
        if dqn_type == 'DoubleDQN':
            next_actions = policy_net(next_states).max(1)[1].unsqueeze(1)
            next_q = target_net(next_states).gather(1, next_actions).squeeze(1)
        else:  # Standard DQN or Dueling DQN
            next_q = target_net(next_states).max(1)[0]

        target_q = rewards + (1 - dones) * gamma * next_q

    # |TD error| per sample, taken before the optimizer step changes the net
    td_errors = torch.abs(current_q - target_q).detach().cpu().numpy()

    # Per-sample squared error scaled by the importance-sampling weights
    loss = (weights * F.mse_loss(current_q, target_q, reduction='none')).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Feed the new priorities back to the buffer for the sampled slots
    new_priorities = td_errors + per_epsilon
    replay_buffer.update_priorities(indices, new_priorities)

    return loss.item()


def print_config(config):
    """Print a banner naming the DQN variant, followed by every config entry.

    The 'CHECKPOINT_DIR' entry is left out of the listing; a line announcing
    PER is added to the banner when config['USE_PER'] is truthy.
    """
    rule = "=" * 70

    # Banner (the leading "" yields the blank line before the first rule)
    lines = ["", rule, f"  DQN TYPE: {config['DQN_TYPE']}"]
    if config['USE_PER']:
        lines.append("  Using Prioritized Experience Replay (PER)")
    lines.extend([rule, "", "Configuration:", "-" * 70])

    # One aligned "key: value" row per entry, skipping long paths
    hidden = ('CHECKPOINT_DIR',)
    lines.extend(
        f"  {key:20s}: {value}"
        for key, value in config.items()
        if key not in hidden
    )

    # Closing rule plus a trailing blank line
    lines.extend([rule, ""])
    print("\n".join(lines))


def save_checkpoint(config, policy_net, target_net, optimizer, episode, avg_score, 
                   episode_rewards, is_best=False):
    """Write a training checkpoint into config['CHECKPOINT_DIR'].

    The file name starts with the DQN type (plus "_PER" when PER is on). A
    best-model checkpoint always uses the same "..._best.pth" name, so it
    replaces the previous one; regular checkpoints carry the episode number
    and a timestamp.

    Returns:
        The path of the file that was written.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    variant = config['DQN_TYPE']
    per_tag = "_PER" if config['USE_PER'] else ""

    filename = (
        f"{variant}{per_tag}_best.pth"
        if is_best
        else f"{variant}{per_tag}_ep{episode}_{stamp}.pth"
    )
    filepath = os.path.join(config['CHECKPOINT_DIR'], filename)

    # Everything stored for this checkpoint: both networks, the optimizer,
    # the config, and the reward history so far.
    payload = {
        'episode': episode,
        'config': config,
        'policy_net_state_dict': policy_net.state_dict(),
        'target_net_state_dict': target_net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'episode_rewards': episode_rewards,
        'avg_score': avg_score,
        'timestamp': stamp
    }
    torch.save(payload, filepath)

    print(f"Checkpoint saved: {filename}")
    return filepath

## Generic Training Function

In [None]:
def _recent_average(episode_rewards, window=100):
    """Return the mean of the last `window` rewards (all of them if fewer exist)."""
    # A [-window:] slice already yields the whole list when it is shorter
    # than `window`, so no separate length check is needed.
    return np.mean(episode_rewards[-window:])


def _gpu_memory_label(device):
    """Return the accelerator-memory text shown in the progress line."""
    if device.type == 'cuda':
        return f"{torch.cuda.memory_allocated() / 1024**3:.2f}GB"
    if device.type == 'mps':
        return "Active"  # no memory figure is queried for MPS
    return "N/A"


def train_dqn(config, policy_net, target_net, optimizer, replay_buffer, device='cpu'):
    """
    Generic DQN training function that works with all variants.

    Supports:
    - Standard DQN
    - Double DQN
    - Dueling DQN
    - Prioritized Experience Replay (PER)

    Args:
        config: Dict of settings (keys as in BASE_CONFIG).
        policy_net: Network that is trained and used to act.
        target_net: Network used for TD targets; overwritten with the policy
            weights every config['TARGET_UPDATE'] environment steps.
        optimizer: Optimizer over `policy_net`'s parameters.
        replay_buffer: Buffer matching config['USE_PER'] (prioritized or uniform).
        device: torch.device or device string the networks live on.

    Returns:
        List with the total reward of every episode, in order.
    """
    # The default is the string 'cpu' but the progress report reads
    # `device.type`; normalise so strings and torch.device objects both work.
    device = torch.device(device)

    # Print configuration
    print_config(config)

    # Create environment
    env = gym.make(config['ENV_ID'])

    n_actions = config['N_ACTIONS']
    n_frames = config['N_FRAMES']
    eps_start = config['EPSILON_START']
    eps_end = config['EPSILON_END']
    eps_decay = config['EPSILON_DECAY']

    episode_rewards = []
    steps = 0
    best_avg_score = -float('inf')

    # Select optimization function based on config
    step_args = (policy_net, target_net, optimizer, replay_buffer,
                 config['BATCH_SIZE'], config['GAMMA'], device)
    if config['USE_PER']:
        def optimize_fn():
            return optimize_model_per(*step_args, config['DQN_TYPE'])
    elif config['DQN_TYPE'] == 'DoubleDQN':
        def optimize_fn():
            return optimize_model_double_dqn(*step_args)
    else:  # Standard DQN or Dueling DQN
        def optimize_fn():
            return optimize_model_dqn(*step_args)

    # try/finally so the environment is closed even if training is
    # interrupted or raises.
    try:
        if config.get('SEED') is not None:
            env.reset(seed=config['SEED'])

        print("Starting training...\n")

        for episode in range(config['N_EPISODES']):
            frame, _ = env.reset()
            frame = preprocess_frame(frame)
            # Start the stack with the first frame repeated N_FRAMES times
            state_stack = deque([frame] * n_frames, maxlen=n_frames)

            episode_reward = 0
            done = False

            while not done:
                # Exponential epsilon decay over total environment steps
                epsilon = eps_end + (eps_start - eps_end) * np.exp(-1. * steps / eps_decay)

                # Select action
                state_array = np.array(state_stack)
                action = select_action(state_array, epsilon, policy_net, n_actions, device)

                # Take step
                next_frame, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                next_state_stack = state_stack.copy()
                next_state_stack.append(preprocess_frame(next_frame))

                # Store transition. Only true termination is stored as the
                # done flag: a time-limit truncation ends the episode loop
                # but must not zero the bootstrap term in the TD target.
                replay_buffer.push(
                    state_array,
                    action,
                    reward,
                    np.array(next_state_stack),
                    float(terminated)
                )

                state_stack = next_state_stack
                episode_reward += reward
                steps += 1

                # Optimize
                optimize_fn()

                # Update target network
                if steps % config['TARGET_UPDATE'] == 0:
                    target_net.load_state_dict(policy_net.state_dict())

            episode_rewards.append(episode_reward)

            # Print progress every 10 episodes
            if episode % 10 == 0:
                avg_score = _recent_average(episode_rewards)
                mem = psutil.virtual_memory()
                gpu_mem_str = _gpu_memory_label(device)

                print(f'Episode {episode}\tScore: {episode_reward:.1f}\tAvg: {avg_score:.2f}\tEps: {epsilon:.3f}\tSteps: {steps}')
                print(f'RAM: {mem.percent:.1f}% | GPU: {gpu_mem_str} | Buffer: {len(replay_buffer)}/{config["BUFFER_SIZE"]}')

                # Save best model (only checked on these reporting episodes)
                if avg_score > best_avg_score:
                    best_avg_score = avg_score
                    save_checkpoint(config, policy_net, target_net, optimizer, episode,
                                    avg_score, episode_rewards, is_best=True)

            # Save checkpoint every N episodes
            if episode > 0 and episode % config['CHECKPOINT_EVERY'] == 0:
                save_checkpoint(config, policy_net, target_net, optimizer, episode,
                                _recent_average(episode_rewards), episode_rewards,
                                is_best=False)
    finally:
        env.close()

    print("\nTraining completed!")
    print(f"Best average score: {best_avg_score:.2f}")
    return episode_rewards

## Initialize and Train - Standard DQN

In [None]:
# Configuration for Standard DQN
# .copy() is a shallow copy, so the overrides below stay local to this run.
CONFIG_DQN = BASE_CONFIG.copy()
CONFIG_DQN['DQN_TYPE'] = 'DQN'
CONFIG_DQN['USE_PER'] = False

# Setup
# Pick the compute device once: CUDA if available, else MPS, else CPU.
# The resulting global `device` is also used by the later training cells.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA GPU")
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS (Apple Silicon GPU)")
else:
    device = torch.device("cpu")
    print("Using CPU (will be slower)")
    
print(f"Using device: {device}")

# Set random seeds
# Done before the networks and buffer are created, so anything that draws
# from these generators afterwards starts from the seeded state.
random.seed(CONFIG_DQN['SEED'])
np.random.seed(CONFIG_DQN['SEED'])
torch.manual_seed(CONFIG_DQN['SEED'])

# Networks
# Input shape is (N_FRAMES, 84, 84); the target net starts as an exact copy
# of the policy net.
policy_net_dqn = DQN((CONFIG_DQN['N_FRAMES'], 84, 84), CONFIG_DQN['N_ACTIONS']).to(device)
target_net_dqn = DQN((CONFIG_DQN['N_FRAMES'], 84, 84), CONFIG_DQN['N_ACTIONS']).to(device)
target_net_dqn.load_state_dict(policy_net_dqn.state_dict())

# Only the policy net's parameters are optimised; uniform replay buffer.
optimizer_dqn = optim.Adam(policy_net_dqn.parameters(), lr=CONFIG_DQN['LEARNING_RATE'])
replay_buffer_dqn = ReplayBuffer(CONFIG_DQN['BUFFER_SIZE'])

# Train
# Returns the list of per-episode rewards for this run.
rewards_dqn = train_dqn(CONFIG_DQN, policy_net_dqn, target_net_dqn, 
                        optimizer_dqn, replay_buffer_dqn, device)

## Initialize and Train - Double DQN

In [None]:
# Configuration for Double DQN
# Same network class as the standard run; DQN_TYPE='DoubleDQN' is what makes
# train_dqn pick the Double-DQN optimisation step.
CONFIG_DDQN = BASE_CONFIG.copy()
CONFIG_DDQN['DQN_TYPE'] = 'DoubleDQN'
CONFIG_DDQN['USE_PER'] = False
# NOTE(review): each variant uses its own seed, so the runs differ in seed
# as well as in algorithm.
CONFIG_DDQN['SEED'] = 42  # Different seed

# Set random seeds
random.seed(CONFIG_DDQN['SEED'])
np.random.seed(CONFIG_DDQN['SEED'])
torch.manual_seed(CONFIG_DDQN['SEED'])

# Networks
# `device` comes from the Standard DQN cell; that cell must have been run first.
policy_net_ddqn = DQN((CONFIG_DDQN['N_FRAMES'], 84, 84), CONFIG_DDQN['N_ACTIONS']).to(device)
target_net_ddqn = DQN((CONFIG_DDQN['N_FRAMES'], 84, 84), CONFIG_DDQN['N_ACTIONS']).to(device)
target_net_ddqn.load_state_dict(policy_net_ddqn.state_dict())

optimizer_ddqn = optim.Adam(policy_net_ddqn.parameters(), lr=CONFIG_DDQN['LEARNING_RATE'])
replay_buffer_ddqn = ReplayBuffer(CONFIG_DDQN['BUFFER_SIZE'])

# Train
rewards_ddqn = train_dqn(CONFIG_DDQN, policy_net_ddqn, target_net_ddqn, 
                         optimizer_ddqn, replay_buffer_ddqn, device)

## Initialize and Train - Dueling DQN

In [None]:
# Configuration for Dueling DQN
# The difference from the standard run is the network class below; with
# DQN_TYPE='DuelingDQN' train_dqn uses the standard DQN optimisation step.
CONFIG_DuelDQN = BASE_CONFIG.copy()
CONFIG_DuelDQN['DQN_TYPE'] = 'DuelingDQN'
CONFIG_DuelDQN['USE_PER'] = False
CONFIG_DuelDQN['SEED'] = 123  # Different seed

# Set random seeds
random.seed(CONFIG_DuelDQN['SEED'])
np.random.seed(CONFIG_DuelDQN['SEED'])
torch.manual_seed(CONFIG_DuelDQN['SEED'])

# Networks - Use DuelingDQN architecture
# `device` comes from the Standard DQN cell; that cell must have been run first.
policy_net_dueling = DuelingDQN((CONFIG_DuelDQN['N_FRAMES'], 84, 84), CONFIG_DuelDQN['N_ACTIONS']).to(device)
target_net_dueling = DuelingDQN((CONFIG_DuelDQN['N_FRAMES'], 84, 84), CONFIG_DuelDQN['N_ACTIONS']).to(device)
target_net_dueling.load_state_dict(policy_net_dueling.state_dict())

optimizer_dueling = optim.Adam(policy_net_dueling.parameters(), lr=CONFIG_DuelDQN['LEARNING_RATE'])
replay_buffer_dueling = ReplayBuffer(CONFIG_DuelDQN['BUFFER_SIZE'])

# Train
rewards_dueling = train_dqn(CONFIG_DuelDQN, policy_net_dueling, target_net_dueling, 
                            optimizer_dueling, replay_buffer_dueling, device)

## Initialize and Train - DQN with PER

In [None]:
# Configuration for DQN with Prioritized Experience Replay
# USE_PER=True makes train_dqn use the PER optimisation step; the buffer
# created below must therefore be a PrioritizedReplayBuffer.
CONFIG_PER = BASE_CONFIG.copy()
CONFIG_PER['DQN_TYPE'] = 'DQN'
CONFIG_PER['USE_PER'] = True
CONFIG_PER['SEED'] = 456  # Different seed

# Set random seeds
random.seed(CONFIG_PER['SEED'])
np.random.seed(CONFIG_PER['SEED'])
torch.manual_seed(CONFIG_PER['SEED'])

# Networks
# `device` comes from the Standard DQN cell; that cell must have been run first.
policy_net_per = DQN((CONFIG_PER['N_FRAMES'], 84, 84), CONFIG_PER['N_ACTIONS']).to(device)
target_net_per = DQN((CONFIG_PER['N_FRAMES'], 84, 84), CONFIG_PER['N_ACTIONS']).to(device)
target_net_per.load_state_dict(policy_net_per.state_dict())

optimizer_per = optim.Adam(policy_net_per.parameters(), lr=CONFIG_PER['LEARNING_RATE'])
# PER_ALPHA / PER_BETA_START / PER_BETA_FRAMES are forwarded to the buffer;
# PER_EPSILON is not passed anywhere in this cell.
replay_buffer_per = PrioritizedReplayBuffer(
    CONFIG_PER['BUFFER_SIZE'],
    alpha=CONFIG_PER['PER_ALPHA'],
    beta_start=CONFIG_PER['PER_BETA_START'],
    beta_frames=CONFIG_PER['PER_BETA_FRAMES']
)

# Train
rewards_per = train_dqn(CONFIG_PER, policy_net_per, target_net_per, 
                        optimizer_per, replay_buffer_per, device)

## Consolidated Results Storage

In [None]:
# Store all results
# Maps a run label to its list of per-episode rewards; all four training
# cells must have been run for these names to exist.
all_results = {
    'DQN': rewards_dqn,
    'DoubleDQN': rewards_ddqn,
    'DuelingDQN': rewards_dueling,
    'DQN_PER': rewards_per
}

## Plotting Functions

In [None]:
def plot_consolidated_results(results_dict, window=100, figsize=(14, 8)):
    """
    Draw the moving-average reward curves of several runs on one figure.

    Each entry of `results_dict` maps a run label to its per-episode rewards.
    Runs with fewer than `window` episodes are left out. The legend shows
    each run's mean over its last 100 episodes, and two dashed lines mark
    the 500 and 400 score goals.
    """
    plt.figure(figsize=figsize)

    palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']

    for position, (name, rewards) in enumerate(results_dict.items()):
        # Too short for even one full window: nothing to draw for this run
        if len(rewards) < window:
            continue

        kernel = np.ones(window) / window
        smoothed = np.convolve(rewards, kernel, mode='valid')
        # The first averaged point corresponds to episode index window-1
        x_values = range(window - 1, len(rewards))
        final_avg = np.mean(rewards[-100:])

        plt.plot(
            x_values,
            smoothed,
            label=f'{name} (Avg={final_avg:.2f})',
            color=palette[position % len(palette)],
            linewidth=2,
        )

    # Goal lines
    for level, colour in ((500, 'green'), (400, 'red')):
        plt.axhline(y=level, color=colour, linestyle='--', linewidth=2,
                    label=f'Goal: {level}', alpha=0.7)

    plt.xlabel('Episode #', fontsize=12)
    plt.ylabel(f'Average Score ({window}-Game Window)', fontsize=12)
    plt.title(f'Consolidated DQN Training Progress ({window}-Episode Moving Average)', fontsize=14, fontweight='bold')
    plt.legend(loc='best', fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()


def plot_individual_results(rewards, name, window=100, figsize=(12, 6)):
    """Plot one run's per-episode rewards with its moving average.

    Args:
        rewards: Sequence of per-episode rewards.
        name: Run label used in the title and the printed summary.
        window: Moving-average window length in episodes.
        figsize: Figure size passed to matplotlib.

    The moving-average curve is drawn only when at least `window` episodes
    exist, matching plot_consolidated_results. After the plot, the mean of
    the last (up to) 100 episodes is printed.
    """
    n_episodes = len(rewards)

    plt.figure(figsize=figsize)
    plt.plot(rewards, alpha=0.6, label='Episode Reward')

    # With fewer than `window` rewards, np.convolve(..., mode='valid') still
    # returns a non-empty array while the x range below is empty, which makes
    # plt.plot fail on mismatched lengths. Skip the curve in that case.
    if n_episodes >= window:
        moving_avg = np.convolve(rewards, np.ones(window)/window, mode='valid')
        plt.plot(range(window-1, n_episodes), moving_avg,
                 label=f'Moving Average ({window})', linewidth=2)

    plt.axhline(y=500, color='r', linestyle='--', label='Target (500)')
    plt.axhline(y=400, color='orange', linestyle='--', label='Minimum (400)')

    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title(f'{name} - Training Progress on Space Invaders')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

    # Print final statistics (NaN for an empty run instead of a NumPy
    # empty-slice warning)
    final_avg = np.mean(rewards[-100:]) if n_episodes else float('nan')
    print(f"\n{name} - Final average reward (last 100 episodes): {final_avg:.2f}")

## Plot All Results

In [None]:
# Plot consolidated results
# One figure with the 100-episode moving average of every run.
plot_consolidated_results(all_results, window=100)

# Plot individual results
# One figure per run: raw episode rewards plus the moving average.
for name, rewards in all_results.items():
    plot_individual_results(rewards, name, window=100)

## Save Final Results

In [None]:
# Save all results to a single file
# (pickle is imported here; no other visible cell uses it)
import pickle

# The file is written into CHECKPOINT_DIR, next to the model checkpoints.
results_file = os.path.join(CHECKPOINT_DIR, 'all_results.pkl')
with open(results_file, 'wb') as f:
    pickle.dump(all_results, f)

print(f"All results saved to: {results_file}")

# Print summary statistics
# Per run: mean of the last (up to) 100 episode rewards and the single best
# episode reward.
print("\n" + "="*70)
print("FINAL RESULTS SUMMARY")
print("="*70)
for name, rewards in all_results.items():
    final_avg = np.mean(rewards[-100:])
    max_reward = max(rewards)
    print(f"{name:20s} - Avg (last 100): {final_avg:6.2f} | Max: {max_reward:6.1f}")
print("="*70)