# Training Atari Donkey Kong with DQN

This notebook implements a DQN agent to play the Atari game Donkey Kong, with the following features:
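- Parallel experience collection across multiple game environments (`SubprocVecEnv`)
- Game frame preprocessing (84x84 grayscale frames, frame skipping, 4-frame stacking) for more efficient training
- Double DQN targets with prioritized experience replay to improve training quality
- Position-based reward shaping and optional pre-loading of demonstration trajectories
- TensorBoard logging of training metrics
- Regular model checkpointing
- Periodic evaluation with recorded gameplay videos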

## 1. Installing Required Dependencies

In [None]:
# Install required libraries.
# NOTE: `%pip` is IPython/Jupyter magic syntax and only works inside a notebook kernel;
# outside Jupyter run the same `pip install ...` command in a shell instead.
%pip install stable-baselines3[extra] gymnasium[atari] numpy matplotlib opencv-python tensorboard autorom[accept-rom-license] ipywidgets gymnasium[other]

## 2. Import Libraries

In [None]:

import os
import random
import time
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym
from gymnasium.wrappers import RecordVideo, FrameStackObservation
import torch
print(torch.cuda.is_available())  # sanity check: True if PyTorch can see a CUDA GPU
torch.cuda.empty_cache()  # release cached-but-unused GPU memory (e.g. left over from re-running cells)
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from collections import deque, namedtuple
from tqdm.notebook import tqdm
import cv2
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.atari_wrappers import AtariWrapper
import ale_py
from gymnasium import spaces

# Seed every random number generator used in this process so runs are reproducible.
# (The SubprocVecEnv workers are seeded separately, in make_vec_env.)
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

## 3. Configuration Parameters

In [None]:
# Environment parameters
ENV_NAME = "ALE/DonkeyKong-v5"
NUM_ENVS = 4  # Number of parallel environments
FRAME_SKIP = 4  # Frame skip: the agent chooses a new action every 4 frames
# Original ALE action numbers the agent may use (labels as shown by VideoDisplayWrapper):
# 0 no-op, 1 jump, 2 up, 3 right, 4 left, 5 down, 11 jump right, 12 jump left
ALLOWED_ACTIONS = [0, 1, 2, 3, 4, 5, 11, 12]

# Model parameters
BATCH_SIZE = 64
GAMMA = 0.99  # Discount factor
LEARNING_RATE = 0.0001
MEMORY_SIZE = 100000  # Experience replay buffer capacity (transitions)
TARGET_UPDATE = 10000  # Copy policy -> target network every this many training iterations

# Training parameters
NUM_FRAMES = 10_000_000  # Training-loop iterations (each iteration steps all NUM_ENVS envs once)
EPSILON_START = 1.0
EPSILON_END = 0.1
EPSILON_DECAY = 6_000_000  # epsilon drops by 1/EPSILON_DECAY per step until it reaches EPSILON_END
DEMO_PATH = "./demo/dk_demo_20250325_192148.pkl"  # optional demonstrations; skipped if missing

# Save and evaluation parameters
SAVE_INTERVAL = 100_000  # Checkpoint interval (training iterations)
EVAL_INTERVAL = 20_000   # Evaluation interval (training iterations)
EVAL_EPISODES = 3        # Number of episodes per evaluation

# Per-run output directories for models, logs and videos, keyed by start time
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
SAVE_PATH = f"./models/donkey_kong_{timestamp}"
LOG_PATH = f"./logs/donkey_kong_{timestamp}"
VIDEO_PATH = f"./videos/donkey_kong_{timestamp}"

for path in (SAVE_PATH, LOG_PATH, VIDEO_PATH):
    # exist_ok avoids the check-then-create race and tolerates re-running this cell
    os.makedirs(path, exist_ok=True)

# Set device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")  # uncomment to force CPU training
print(f"Device: {device}")

## 4. Environment Preprocessing

In [None]:
# Restrict the action space so the agent only chooses among a small set of useful actions
class ActionRestrictWrapper(gym.ActionWrapper):
    """Expose ``Discrete(len(allowed_actions))`` and translate indices to real actions.

    The agent outputs an index ``i``; the wrapped environment receives
    ``allowed_actions[i]``.
    """

    def __init__(self, env, allowed_actions):
        super().__init__(env)
        self.allowed_actions = allowed_actions
        n_allowed = len(allowed_actions)
        self.action_space = spaces.Discrete(n_allowed)

    def action(self, act):
        """Map an agent-side index to the original environment action number."""
        original_action = self.allowed_actions[act]
        return original_action

    def reverse_action(self, act):
        """Map an original environment action number back to its agent-side index."""
        position = self.allowed_actions.index(act)
        return position

# Wrapper that forces the first action after reset to be RIGHT FIRE (jump right)
class ForceFirstFireWrapper(gym.Wrapper):
    """Replace the first action taken after each ``reset()`` with a fixed action.

    By default the forced action is the agent-side index of ALE action 11
    (RIGHT FIRE) within ``ALLOWED_ACTIONS``, so this wrapper is meant to sit on
    top of ``ActionRestrictWrapper``. Every later step passes the agent's
    action through unchanged. The "first step" flag is re-armed only when this
    wrapper's own ``reset()`` is called.

    Args:
        env: Environment to wrap (it receives agent-side action indices).
        first_action: Action to force on the first step. ``None`` (default)
            uses ``ALLOWED_ACTIONS.index(11)``.
    """

    def __init__(self, env, first_action=None):
        super().__init__(env)
        if first_action is None:
            first_action = ALLOWED_ACTIONS.index(11)
        self.first_action = first_action
        self.first_action_done = False

    def reset(self, **kwargs):
        self.first_action_done = False
        return self.env.reset(**kwargs)

    def step(self, action):
        # The first step of an episode is always replaced, whatever the agent chose
        if not self.first_action_done:
            self.first_action_done = True
            return self.env.step(self.first_action)
        return self.env.step(action)

# Function to detect the player position based on sprite color
def get_agent_position(frame, target_color=(194, 64, 82), tolerance=30):
    """Locate the player sprite in a color frame by color thresholding.

    Pixels whose every channel lies within ``tolerance`` of ``target_color``
    are masked, and the centroid of the largest connected blob is returned.

    Args:
        frame: ``H x W x 3`` uint8 image, or ``None``. Its channel order must
            match ``target_color``. Gymnasium ALE observations are RGB, so the
            default is treated as an RGB triple — TODO confirm against a frame.
        target_color: Color of the player sprite (3 channel values, 0-255).
        tolerance: Per-channel tolerance; values around 20-40 usually work.

    Returns:
        ``(x, y)`` integer pixel coordinates of the blob centroid, or ``None``
        if ``frame`` is ``None`` or no matching blob with non-zero area exists.
    """
    if frame is None:
        return None

    # Compute the bounds in a wide integer type and clip; doing +/- on uint8
    # values would silently wrap around for colors near 0 or 255.
    target = np.asarray(target_color, dtype=np.int32)
    lower = np.clip(target - tolerance, 0, 255).astype(np.uint8)
    upper = np.clip(target + tolerance, 0, 255).astype(np.uint8)

    # Binary mask of pixels inside the color box
    mask = cv2.inRange(frame, lower, upper)

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None

    # Keep only the largest blob; smaller matches are treated as noise
    largest = max(contours, key=cv2.contourArea)
    moments = cv2.moments(largest)
    if moments["m00"] == 0:
        return None

    cx = int(moments["m10"] / moments["m00"])
    cy = int(moments["m01"] / moments["m00"])
    return (cx, cy)

# Video wrapper that overlays the name of the current action on rendered frames
class VideoDisplayWrapper(gym.Wrapper):
    """Draw the most recent action's name in the top-right corner of ``render()`` frames.

    It remembers whatever action value ``step`` receives, so the label table is
    keyed by original ALE action numbers; ``make_env`` places this wrapper
    directly on the raw environment, beneath ``ActionRestrictWrapper``.
    """

    def __init__(self, env):
        super().__init__(env)
        self.current_action = None
        self.action_names = {
            0: "",
            1: "Jump",
            2: "Up",
            3: "Right",
            4: "Left",
            5: "Down",
            11: "Jump R",
            12: "Jump L"
        }

    def step(self, action):
        self.current_action = action  # remembered for the overlay drawn in render()
        return self.env.step(action)

    def reset(self, **kwargs):
        self.current_action = None
        return self.env.reset(**kwargs)

    def render(self):
        frame = self.env.render()
        if frame is None:
            return None

        # Expand grayscale frames to 3 channels so colored text can be drawn
        if frame.ndim == 2:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)

        if self.current_action is None:
            return frame

        label = self.action_names.get(self.current_action, f"ACTION_{self.current_action}")
        text_origin = (frame.shape[1] - 85, 28)  # near the top-right corner
        white = (255, 255, 255)
        cv2.putText(frame, label, text_origin, cv2.FONT_HERSHEY_SIMPLEX, 0.3, white, 1, cv2.LINE_AA)
        return frame

# Custom reward wrapper that shapes rewards from step-to-step changes in the player's position
class CustomRewardWrapper(gym.Wrapper):
    """Add position-based shaping terms to the environment reward.

    On every step the player position is detected in the observation (via
    ``get_agent_position``) and compared with the previous detection:

    * vertical static penalty: once ``y_static_count`` (consecutive steps with
      vertical displacement below ``y_threshold``) reaches ``y_static_frames``,
      subtract ``y_static_penalty * (y_static_count - y_static_frames + 1)``;
    * UP bonus/penalty: when the action was 2 (UP), add ``up_success_reward``
      if the player rose by more than ``y_threshold`` pixels, otherwise subtract
      ``up_fail_penalty``;
    * horizontal static penalty: same scheme using ``x_threshold``,
      ``x_static_frames`` and ``x_static_penalty``.

    Steps where no position is detected add no shaping and leave the counters
    unchanged.

    NOTE(review): thresholds are compared against the displacement between two
    consecutive detections, not over a window, so with ``y_threshold=20``
    nearly every step likely counts as "vertically static" — confirm intended.
    """

    def __init__(self, env, y_static_penalty=0.1, up_success_reward=10,
                 up_fail_penalty=0, x_static_penalty=0,
                 y_threshold=20, x_threshold=3,
                 y_static_frames=30, x_static_frames=30):
        super().__init__(env)
        # Reward magnitudes
        self.y_static_penalty = y_static_penalty  # Vertical static penalty (per excess step)
        self.up_success_reward = up_success_reward  # Successful upward movement reward
        self.up_fail_penalty = up_fail_penalty  # Failed upward movement penalty
        self.x_static_penalty = x_static_penalty  # Horizontal static penalty (per excess step)

        # Movement thresholds (pixels) and static-streak lengths (steps)
        self.y_threshold = y_threshold
        self.x_threshold = x_threshold
        self.y_static_frames = y_static_frames
        self.x_static_frames = x_static_frames

        # State tracking: a bounded deque replaces list + pop(0) (O(n) per step)
        self._history_len = max(y_static_frames, x_static_frames, 0)
        self.prev_positions = deque(maxlen=self._history_len)  # recent (x, y) detections, oldest first
        self.y_static_count = 0  # Consecutive vertically-static steps
        self.x_static_count = 0  # Consecutive horizontally-static steps
        self.prev_action = None  # Action passed to the most recent step()

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.prev_positions.clear()
        self.y_static_count = 0
        self.x_static_count = 0
        self.prev_action = None
        return obs, info

    def _extract_frame(self, obs):
        """Return one image frame from ``obs`` (newest if stacked), or ``None`` if none is found."""
        frame = None
        if isinstance(obs, np.ndarray):
            if obs.ndim == 4:  # (stack, height, width, channel): use the newest frame
                frame = obs[-1]
            elif obs.ndim in (2, 3):  # a single grayscale or color frame
                frame = obs
        elif hasattr(obs, "__getitem__"):
            # Sequence-like stacked observation: try the newest frame, then slot 3 (4-frame stack)
            for key in (-1, 3):
                try:
                    frame = obs[key]
                    break
                except (IndexError, KeyError, TypeError):
                    continue

        if frame is None:
            # Last resort: ask the environment to render the current screen
            try:
                frame = self.env.render()
            except Exception:  # best effort: rendering may be unavailable for this env
                frame = None
        return frame

    def _shaping_reward(self, position):
        """Record ``position``, update the static counters and return the shaping term."""
        x, y = position
        self.prev_positions.append((x, y))

        # Need at least two detections to measure movement
        if len(self.prev_positions) < 2:
            return 0
        prev_x, prev_y = self.prev_positions[-2]
        additional_reward = 0

        # 1. Vertical stillness: penalty grows linearly once the streak reaches y_static_frames
        if abs(y - prev_y) < self.y_threshold:
            self.y_static_count += 1
            if self.y_static_count >= self.y_static_frames:
                additional_reward -= self.y_static_penalty * (self.y_static_count - self.y_static_frames + 1)
        else:
            self.y_static_count = 0

        # 2. UP action (agent index 2 maps to ALE action 2 = UP); image y grows
        #    downward, so moving up means y decreased
        if self.prev_action == 2:
            if (prev_y - y) > self.y_threshold:
                additional_reward += self.up_success_reward
            else:
                additional_reward -= self.up_fail_penalty

        # 3. Horizontal stillness, same scheme as the vertical check
        if abs(x - prev_x) < self.x_threshold:
            self.x_static_count += 1
            if self.x_static_count >= self.x_static_frames:
                additional_reward -= self.x_static_penalty * (self.x_static_count - self.x_static_frames + 1)
        else:
            self.x_static_count = 0

        return additional_reward

    def step(self, action):
        self.prev_action = action
        obs, reward, terminated, truncated, info = self.env.step(action)

        frame = self._extract_frame(obs)
        position = get_agent_position(frame) if frame is not None else None
        additional_reward = 0 if position is None else self._shaping_reward(position)

        return obs, reward + additional_reward, terminated, truncated, info

# Function to create a preprocessed environment
def make_env(env_id, idx, capture_video=False, run_name=None):
    """Return a zero-argument factory (thunk) that builds one wrapped environment.

    Wrapper stack, innermost first: raw env
    -> [VideoDisplayWrapper -> RecordVideo] (only when ``capture_video`` and ``idx == 0``)
    -> ActionRestrictWrapper -> ForceFirstFireWrapper -> CustomRewardWrapper
    -> AtariWrapper (``FRAME_SKIP``, episode ends on life loss)
    -> FrameStackObservation (4 frames).

    NOTE(review): SB3's AtariWrapper clips rewards to their sign by default
    (``clip_reward=True``), which also clips the shaping added by
    CustomRewardWrapper — confirm this is intended.

    Args:
        env_id: Gymnasium environment id, e.g. ``"ALE/DonkeyKong-v5"``.
        idx: Environment index; video is only recorded for index 0.
        capture_video: Record every episode to ``VIDEO_PATH`` (read when the thunk runs).
        run_name: Prefix for the recorded video file names.

    Returns:
        A callable that creates the environment, as expected by SB3 vector envs.
    """
    def thunk():
        import ale_py  # noqa: F401 -- importing ale_py registers the ALE/* ids with Gymnasium

        # ALE v5 environments skip 4 frames internally by default, and AtariWrapper
        # skips FRAME_SKIP more on top, so decisions would happen every
        # 4 * FRAME_SKIP frames. Disable the built-in skip for ALE/* ids.
        make_kwargs = {"frameskip": 1} if env_id.startswith("ALE/") else {}

        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array", **make_kwargs)
            # Overlay the current action on rendered frames before recording them
            env = VideoDisplayWrapper(env)
            env = RecordVideo(
                env,
                VIDEO_PATH,
                episode_trigger=lambda x: True,
                name_prefix=f"{run_name}"
            )
        else:
            env = gym.make(env_id, **make_kwargs)

        env = ActionRestrictWrapper(env, ALLOWED_ACTIONS)
        env = ForceFirstFireWrapper(env)
        env = CustomRewardWrapper(env)  # sits inside AtariWrapper, so it sees raw color frames
        env = AtariWrapper(env, terminal_on_life_loss=True, frame_skip=FRAME_SKIP)
        env = FrameStackObservation(env, 4)  # Stack 4 frames to capture temporal information

        return env
    return thunk

# Create parallel environments
def make_vec_env(env_id, num_envs, seed=SEED):
    """Build ``num_envs`` environments, each running in its own subprocess.

    ``seed`` is passed to the vector env's ``seed()`` so the workers start
    from reproducible states on the next reset.
    """
    vec_env = SubprocVecEnv([make_env(env_id, worker_idx) for worker_idx in range(num_envs)])
    vec_env.seed(seed)
    return vec_env

## 5. DQN Network Model

In [None]:
class DQN(nn.Module):
    """Convolutional Q-network with the Nature-DQN layer layout.

    Three conv layers (8x8 stride 4, 4x4 stride 2, 3x3 stride 1) feed a
    512-unit hidden layer and a linear output with one Q-value per action.

    Args:
        input_shape: ``(channels, height, width)`` of one observation,
            e.g. ``(4, 84, 84)`` for 4 stacked 84x84 frames.
        n_actions: Number of discrete actions (width of the output layer).
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()
        in_channels = input_shape[0]
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
        )

        flat_features = self._get_conv_out(input_shape)

        self.fc = nn.Sequential(
            nn.Linear(flat_features, 512), nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def _get_conv_out(self, shape):
        """Return how many features the conv stack produces for one input of ``shape``."""
        dummy = torch.zeros(1, *shape)
        return int(np.prod(self.conv(dummy).size()))

    def forward(self, x):
        """Map a ``(batch, channels, height, width)`` tensor to ``(batch, n_actions)`` Q-values."""
        features = self.conv(x)
        return self.fc(features.view(features.size(0), -1))

## 6. Prioritized Experience Replay

In [None]:
# Prioritized experience replay: transitions with larger TD error are sampled more often
class PrioritizedReplayBuffer:
    """Proportional prioritized replay buffer stored as a circular list.

    Transitions are sampled with probability ``p_i**alpha / sum_j p_j**alpha``
    and returned with importance-sampling weights ``(N * P(i))**(-beta)``
    normalized by their maximum.

    Args:
        capacity: Maximum number of stored transitions; the oldest are overwritten.
        alpha: Prioritization exponent (0 gives uniform sampling).
        beta_start: Initial importance-sampling exponent.
        beta_frames: Number of ``sample()`` calls over which beta rises linearly to 1.0.
    """

    def __init__(self, capacity, alpha=0.6, beta_start=0.4, beta_frames=100000):
        self.capacity = capacity
        self.alpha = alpha  # Controls the degree of prioritization
        self.beta_start = beta_start
        self.beta_frames = beta_frames
        self.frame = 1  # Counts sample() calls; drives beta annealing
        self.buffer = []
        self.priorities = np.zeros((capacity,), dtype=np.float32)
        self.position = 0  # Next slot to write (wraps around at capacity)
        self.Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))

    def beta_by_frame(self, frame_idx):
        """Return beta, rising linearly from ``beta_start`` to 1.0 over ``beta_frames``."""
        return min(1.0, self.beta_start + frame_idx * (1.0 - self.beta_start) / self.beta_frames)

    def push(self, *args):
        """Store ``(state, action, reward, next_state, done)`` with the current max priority."""
        # New transitions get the highest priority seen so far so they are sampled at least once
        max_prio = np.max(self.priorities) if self.buffer else 1.0

        if len(self.buffer) < self.capacity:
            self.buffer.append(self.Transition(*args))
        else:
            self.buffer[self.position] = self.Transition(*args)

        self.priorities[self.position] = max_prio
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions (with replacement) proportionally to priority**alpha.

        Returns:
            ``(states, actions, rewards, next_states, dones, indices, weights)``.
            ``states``/``next_states`` are the stored tensors concatenated along
            dim 0. ``actions`` is int64 and ``rewards``/``dones``/``weights`` are
            float32, all on the device of the stored states. ``indices`` is a
            NumPy array of buffer slots, for ``update_priorities``.
        """
        if len(self.buffer) == self.capacity:
            prios = self.priorities
        else:
            prios = self.priorities[:self.position]

        # Sampling probabilities
        probs = prios ** self.alpha
        probs /= probs.sum()

        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        samples = [self.buffer[idx] for idx in indices]

        # Importance-sampling weights, annealed towards full correction (beta -> 1)
        beta = self.beta_by_frame(self.frame)
        self.frame += 1
        weights = (len(self.buffer) * probs[indices]) ** (-beta)
        weights /= weights.max()

        batch = self.Transition(*zip(*samples))
        states = torch.cat(batch.state)
        next_states = torch.cat(batch.next_state)
        target_device = states.device  # keep the whole batch where the states live

        # int()/float() accept Python numbers, NumPy scalars and one-element tensors alike,
        # so transitions pushed from different sources batch the same way.
        weights = torch.tensor(weights, device=target_device, dtype=torch.float32)
        actions = torch.tensor([int(a) for a in batch.action], device=target_device)
        rewards = torch.tensor([float(r) for r in batch.reward], device=target_device, dtype=torch.float32)
        dones = torch.tensor([float(d) for d in batch.done], device=target_device, dtype=torch.float32)

        return states, actions, rewards, next_states, dones, indices, weights

    def update_priorities(self, indices, priorities):
        """Overwrite the priorities of the given buffer slots."""
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = priority

    def __len__(self):
        return len(self.buffer)

## 7. DQN Agent

In [None]:
class DQNAgent:
    """Double-DQN agent with a prioritized replay buffer and TensorBoard logging.

    Args:
        state_shape: Shape of one observation, e.g. ``(4, 84, 84)``.
        n_actions: Number of discrete actions.

    Reads the module-level settings ``device``, ``LEARNING_RATE``, ``MEMORY_SIZE``,
    ``LOG_PATH``, ``EPSILON_START``/``EPSILON_END``/``EPSILON_DECAY``,
    ``BATCH_SIZE`` and ``GAMMA``.
    """

    def __init__(self, state_shape, n_actions):
        self.state_shape = state_shape
        self.n_actions = n_actions

        # Online (policy) network and a periodically synced copy used for targets
        self.policy_net = DQN(state_shape, n_actions).to(device)
        self.target_net = DQN(state_shape, n_actions).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        # eval() only switches layer modes; requires_grad_(False) is what actually
        # stops autograd from tracking (and accumulating .grad on) the target net.
        self.target_net.eval()
        self.target_net.requires_grad_(False)

        # Setup optimizer (policy network only)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)

        # Create experience replay buffer
        self.memory = PrioritizedReplayBuffer(MEMORY_SIZE)

        # Training counters; steps_done is advanced by the training loop and drives epsilon
        self.steps_done = 0
        self.epsilon = EPSILON_START

        # Logger
        self.writer = SummaryWriter(LOG_PATH)

    def select_action(self, state, eval_mode=False):
        """Choose an action epsilon-greedily.

        Args:
            state: Batch-of-one observation tensor accepted by the policy network.
            eval_mode: Use a fixed epsilon of 0.05 (a little residual exploration,
                not fully greedy) instead of the training schedule; ``self.epsilon``
                is left untouched.

        Returns:
            A ``(1, 1)`` long tensor holding the chosen action index.
        """
        sample = random.random()
        if eval_mode:
            eps_threshold = 0.05
        else:
            # Linear decay: epsilon drops by 1/EPSILON_DECAY per step, floored at EPSILON_END
            self.epsilon = max(EPSILON_END, EPSILON_START - self.steps_done / EPSILON_DECAY)
            eps_threshold = self.epsilon

        if sample > eps_threshold:
            with torch.no_grad():
                return self.policy_net(state).max(1)[1].view(1, 1)
        return torch.tensor([[random.randrange(self.n_actions)]], device=device, dtype=torch.long)

    def optimize_model(self):
        """Run one Double-DQN update on a prioritized minibatch.

        Returns:
            The importance-weighted Smooth-L1 loss as a float, or ``0.0`` when the
            buffer holds fewer than ``BATCH_SIZE`` transitions.
        """
        if len(self.memory) < BATCH_SIZE:
            return 0.0  # Not enough samples in buffer

        states, actions, rewards, next_states, dones, indices, weights = self.memory.sample(BATCH_SIZE)

        # Q(s, a) for the actions that were actually taken
        q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Double-DQN target: the policy net picks a', the target net evaluates it.
        # The target is a constant for this update, so build it without autograd
        # (avoids an unused graph through both networks).
        with torch.no_grad():
            next_actions = self.policy_net(next_states).max(1)[1].unsqueeze(1)
            next_q_values = self.target_net(next_states).gather(1, next_actions).squeeze(1)
            next_q_values = next_q_values * (1 - dones)  # no bootstrapping past terminal states
            target_q_values = rewards + GAMMA * next_q_values

        # |TD error| becomes the new priority of each sampled transition
        td_error = torch.abs(q_values - target_q_values).detach().cpu().numpy()
        loss = F.smooth_l1_loss(q_values, target_q_values, reduction='none') * weights
        loss = loss.mean()

        self.optimizer.zero_grad()
        loss.backward()
        # Gradient clipping to prevent explosion
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 10)
        self.optimizer.step()

        # Small constant keeps every priority strictly positive
        self.memory.update_priorities(indices, td_error + 1e-5)

        return loss.item()

    def update_target_network(self):
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save_model(self, path):
        """Write networks, optimizer state and counters to ``path``."""
        torch.save({
            'policy_net': self.policy_net.state_dict(),
            'target_net': self.target_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'steps_done': self.steps_done,
            'epsilon': self.epsilon
        }, path)

    def load_model(self, path):
        """Restore state written by ``save_model``.

        ``map_location=device`` lets a checkpoint saved on a GPU be loaded on a
        CPU-only machine (and vice versa).
        """
        checkpoint = torch.load(path, map_location=device)
        self.policy_net.load_state_dict(checkpoint['policy_net'])
        self.target_net.load_state_dict(checkpoint['target_net'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.steps_done = checkpoint['steps_done']
        self.epsilon = checkpoint['epsilon']

## 8. Observation Preprocessing and State Transformation Functions

In [None]:
def preprocess_observation(obs, target_device=None):
    """Convert one stacked observation into a normalized ``(1, stack, H, W)`` float tensor.

    Args:
        obs: Array-like of shape ``(stack, H, W, 1)`` with pixel values 0-255
            (the trailing single-channel axis is dropped).
        target_device: Device for the result; ``None`` uses the module-level ``device``.

    Returns:
        float32 tensor scaled to ``[0, 1]`` with a leading batch dimension of 1.
    """
    if target_device is None:
        target_device = device
    frames = np.array(obs).squeeze(-1)
    # Move the raw array first and convert to float on the target device:
    # for uint8 frames this copies 4x less data to the GPU than float32 would.
    tensor = torch.as_tensor(frames, device=target_device).float().unsqueeze(0)
    return tensor / 255.0  # Normalize

def preprocess_batch_observation(obs, target_device=None):
    """Convert a batch of stacked observations into a normalized ``(N, stack, H, W)`` float tensor.

    Args:
        obs: Array-like of shape ``(N, stack, H, W, 1)`` with pixel values 0-255
            (e.g. what a vector env returns); the trailing axis is dropped.
        target_device: Device for the result; ``None`` uses the module-level ``device``.

    Returns:
        float32 tensor scaled to ``[0, 1]``.
    """
    if target_device is None:
        target_device = device
    frames = np.array(obs).squeeze(-1)
    # Transfer the raw (typically uint8) data, then convert on the target device
    tensor = torch.as_tensor(frames, device=target_device).float()
    return tensor / 255.0  # Normalize

## 9. Evaluation Function

In [None]:
def evaluate(agent, env_id, num_episodes=5, video_prefix="evaluation"):
    """Play ``num_episodes`` evaluation episodes, recording a video of each.

    A fresh environment with video capture is built for every episode, so each
    episode lands in its own file. After the run, RecordVideo's ``-episode-0``
    suffix is stripped, leaving ``{video_prefix}_episode_{i}.mp4`` in ``VIDEO_PATH``.
    Actions come from ``agent.select_action(..., eval_mode=True)``.

    NOTE(review): the env comes from ``make_env`` (``terminal_on_life_loss=True``
    plus reward shaping), so an "episode" ends at the first life lost and the
    total includes shaping terms — it is not the raw game score.

    Returns:
        ``(mean_reward, std_reward, episode_rewards)``.
    """
    episode_rewards = []

    for i in range(num_episodes):
        run_name = f"{video_prefix}_episode_{i}"
        env = make_env(env_id, 0, capture_video=True, run_name=run_name)()
        try:
            obs, _ = env.reset()
            obs_tensor = preprocess_observation(obs)
            done = False
            total_reward = 0.0

            while not done:
                action = agent.select_action(obs_tensor, eval_mode=True).item()
                obs, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                total_reward += reward
                obs_tensor = preprocess_observation(obs)
        finally:
            # Always close: finalizes the video file and releases the emulator even on errors
            env.close()

        episode_rewards.append(total_reward)

        # Drop the "-episode-0" suffix RecordVideo appends to the file name
        base = os.path.join(VIDEO_PATH, run_name)
        recorded = os.path.abspath(base + "-episode-0.mp4")
        if os.path.exists(recorded):
            os.replace(recorded, os.path.abspath(base + ".mp4"))

    return np.mean(episode_rewards), np.std(episode_rewards), episode_rewards

## 10. Training Functions

In [None]:
def train(agent, envs, num_frames):
    """Run the main DQN training loop over a vectorized environment.

    Each of the ``num_frames`` iterations picks one epsilon-greedy action per
    env, steps all envs, pushes one transition per env into ``agent.memory`` and
    runs one optimization step. Periodically it:

    * syncs the target network (every ``TARGET_UPDATE`` iterations);
    * logs to TensorBoard (every 1000);
    * saves a checkpoint (every ``SAVE_INTERVAL``);
    * evaluates (every ``EVAL_INTERVAL``).

    ``model_final.pt`` is written to ``SAVE_PATH`` at the end.

    Args:
        agent: ``DQNAgent`` to train; ``agent.steps_done`` advances by one per iteration.
        envs: SB3 ``VecEnv``; it resets finished sub-environments automatically.
        num_frames: Number of loop iterations (each steps every env once).
    """
    num_envs = envs.num_envs
    obs_tensor = preprocess_batch_observation(envs.reset())

    # Only the last 100 values are ever read, so keep bounded windows instead of
    # lists that grow for the whole (multi-million-iteration) run.
    recent_losses = deque(maxlen=100)
    recent_rewards = deque(maxlen=100)
    episode_reward = np.zeros(num_envs)
    episode_length = np.zeros(num_envs)

    progress_bar = tqdm(range(1, num_frames + 1), desc="Training")

    for frame_idx in progress_bar:
        # Select one action per environment
        actions = [agent.select_action(obs_tensor[i:i + 1]).item() for i in range(num_envs)]

        # SB3 VecEnv.step returns (obs, rewards, dones, infos)
        next_obs, rewards, terminateds, infos = envs.step(actions)

        dones = []
        for term, info in zip(terminateds, infos):
            if isinstance(info, dict):
                done = term or info.get("TimeLimit.truncated", False)
            else:
                done = term or info
            dones.append(done)

        next_obs_tensor = preprocess_batch_observation(next_obs)

        # Update cumulative rewards and episode lengths
        episode_reward += rewards
        episode_length += 1

        # Store one transition per environment
        for i in range(num_envs):
            agent.memory.push(
                obs_tensor[i:i + 1],
                actions[i],
                rewards[i],
                next_obs_tensor[i:i + 1],
                float(dones[i])
            )

        obs_tensor = next_obs_tensor

        recent_losses.append(agent.optimize_model())

        # Log and reset statistics of finished episodes
        for i, done in enumerate(dones):
            if done:
                agent.writer.add_scalar("train/episode_reward", episode_reward[i], agent.steps_done)
                agent.writer.add_scalar("train/episode_length", episode_length[i], agent.steps_done)
                recent_rewards.append(episode_reward[i])
                episode_reward[i] = 0
                episode_length[i] = 0

        # Update target network
        if frame_idx % TARGET_UPDATE == 0:
            agent.update_target_network()

        # Record training statistics
        if frame_idx % 1000 == 0:
            mean_reward = np.mean(recent_rewards) if recent_rewards else 0
            mean_loss = np.mean(recent_losses) if recent_losses else 0
            agent.writer.add_scalar("train/epsilon", agent.epsilon, frame_idx)
            agent.writer.add_scalar("train/loss", mean_loss, frame_idx)
            agent.writer.add_scalar("train/mean_reward_100", mean_reward, frame_idx)

            progress_bar.set_postfix({
                "avg_reward": f"{mean_reward:.2f}",
                "loss": f"{mean_loss:.5f}",
                "epsilon": f"{agent.epsilon:.2f}"
            })

        # Save model
        if frame_idx % SAVE_INTERVAL == 0:
            save_path = os.path.join(SAVE_PATH, f"model_{frame_idx}.pt")
            agent.save_model(save_path)
            print(f"\nFrame {frame_idx}: Model saved to {save_path}")

        # Evaluate model
        if frame_idx % EVAL_INTERVAL == 0:
            print(f"\nFrame {frame_idx}: Evaluating...")
            eval_reward, eval_std, _ = evaluate(
                agent,
                ENV_NAME,
                num_episodes=EVAL_EPISODES,
                video_prefix=f"eval_{frame_idx}"
            )
            agent.writer.add_scalar("eval/mean_reward", eval_reward, frame_idx)
            agent.writer.add_scalar("eval/reward_std", eval_std, frame_idx)
            print(f"Evaluation results: Mean reward = {eval_reward:.2f} ± {eval_std:.2f}")

        # Update agent's step counter (drives epsilon decay)
        agent.steps_done += 1

    # Save final model after training
    final_path = os.path.join(SAVE_PATH, "model_final.pt")
    agent.save_model(final_path)
    print(f"\nFinal model saved to: {final_path}")


def load_demonstrations(agent, filepath):
    """Load demonstration trajectories from a pickle file into ``agent.memory``.

    The file is expected to contain an iterable of trajectories, each an
    iterable of ``(state, action, reward, next_state, done)`` tuples.
    ``action`` is an original ALE action number. Actions outside
    ``ALLOWED_ACTIONS`` are skipped; the rest are remapped to the agent's
    action indices (0-7).

    Assumes ``state``/``next_state`` are torch tensors shaped like the training
    observations (``(1, 4, 84, 84)``) — TODO confirm against the demo recorder.

    SECURITY: ``pickle.load`` can execute arbitrary code; only load demo files
    you created yourself.
    """
    import pickle

    # Same mapping as ActionRestrictWrapper: original action number -> agent index
    action_to_index = {a: i for i, a in enumerate(ALLOWED_ACTIONS)}

    with open(filepath, 'rb') as f:
        all_trajectories = pickle.load(f)

    count = 0
    skipped = 0
    for traj in all_trajectories:
        for s, a, r, ns, d in traj:
            # int() normalizes NumPy / one-element tensor actions so dict lookup works
            action_key = int(a)
            if action_key not in action_to_index:
                print(f"Skipping illegal action: {a}")
                skipped += 1
                continue

            agent.memory.push(
                s.to(device),
                action_to_index[action_key],
                float(r),  # plain floats, matching what train() pushes
                ns.to(device),
                float(d)
            )
            count += 1

    print(f"\nDemonstrations imported, {count} transitions added to replay buffer. Skipped {skipped} illegal actions.\n")

## 11. Main Training Process

In [None]:
# Create parallel training environments
envs = make_vec_env(ENV_NAME, NUM_ENVS)

# Network input: 4 stacked 84x84 single-channel frames (see make_env / preprocessing)
obs_shape = (4, 84, 84)
# Size of the restricted action space exposed by ActionRestrictWrapper
n_actions = envs.action_space.n

print(f"\nObservation shape: {obs_shape}\nAction space size: {n_actions}")


In [None]:

# Create the DQN agent
agent = DQNAgent(obs_shape, n_actions)

# Optionally pre-fill the replay buffer with demonstration transitions
if not (DEMO_PATH and os.path.exists(DEMO_PATH)):
    print(f"Demonstration file not found at {DEMO_PATH}, skipping demonstration loading.")
else:
    # Presumably defined so pickle can resolve `__main__.Transition` if the demo
    # file stores namedtuples created under that name — TODO confirm.
    Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))
    load_demonstrations(agent, DEMO_PATH)


In [None]:
# Start training
print("\nTraining started...\n")

# Let cuDNN benchmark convolution algorithms for the fixed input size (faster on GPU)
torch.backends.cudnn.benchmark = True

try:
    train(agent, envs, NUM_FRAMES)
finally:
    # Always shut down the SubprocVecEnv worker processes, even when training is
    # interrupted (e.g. the notebook's stop button raises KeyboardInterrupt) or fails.
    envs.close()

## 12. Load and Test Trained Models

In [None]:
def play_and_record_video(model_path, env_id, num_episodes=5):
    """Load a checkpoint, play ``num_episodes`` recorded episodes and print the scores.

    Videos are written by ``evaluate`` with the ``"final_test"`` prefix.
    """
    obs_shape = (4, 84, 84)  # 4 stacked frames, each 84x84

    # Build one env only to read the (restricted) action-space size, then release it
    probe_env = make_env(env_id, 0)()
    try:
        n_actions = probe_env.action_space.n
    finally:
        probe_env.close()

    agent = DQNAgent(obs_shape, n_actions)
    try:
        agent.load_model(model_path)
        mean, std, rewards = evaluate(agent, env_id, num_episodes=num_episodes, video_prefix="final_test")
    finally:
        # DQNAgent opens a SummaryWriter under LOG_PATH; flush and close it
        agent.writer.close()

    for i, reward in enumerate(rewards, start=1):
        print(f"Episode {i}: Reward = {reward}")

    print(f"\nAverage reward: {mean:.2f} ± {std:.2f}")

In [None]:
# Load and test a trained model.
# To evaluate the checkpoint written by this notebook's own training run instead, use:
# model_path = os.path.join(SAVE_PATH, "model_final.pt")
model_path = "./model_final.pt"
play_and_record_video(model_path=model_path, env_id=ENV_NAME, num_episodes=5)

## 13. Continue Training

In [None]:
import pandas as pd

def continue_training(model_path, envs, additional_frames=1_000_000, frames_per_session=200_000):
    import matplotlib.pyplot as plt
    
    # Create agent with the same configuration
    obs_shape = (4, 84, 84)
    n_actions = envs.action_space.n
    agent = DQNAgent(obs_shape, n_actions)
    
    # Load the existing model
    print(f"Loading model from {model_path}")
    agent.load_model(model_path)
    
    # Create new save paths for the continued training
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = f"./models/donkey_kong_continued_{timestamp}"
    log_path = f"./logs/donkey_kong_continued_{timestamp}"
    video_path = f"./videos/donkey_kong_continued_{timestamp}"
    
    for path in [save_path, log_path, video_path]:
        if not os.path.exists(path):
            os.makedirs(path)
    
    # Setup logging
    agent.writer = SummaryWriter(log_path)
    
    # Modify the train function to work with our custom save path
    global SAVE_PATH, LOG_PATH, VIDEO_PATH
    original_save_path = SAVE_PATH
    original_log_path = LOG_PATH
    original_video_path = VIDEO_PATH
    
    # Track training statistics
    episode_count = 0
    best_reward = float('-inf')
    
    # Use our custom paths for this training session
    SAVE_PATH = save_path
    LOG_PATH = log_path
    VIDEO_PATH = video_path
    
    # Metrics tracking for plotting
    metrics = {
        'episode_rewards': [],
        'episode_lengths': [],
        'losses': [],
        'mean_rewards_100': [],
        'eval_rewards': [],
        'epsilons': [],
        'frames': []
    }
    
    print(f"\nContinuing training for {additional_frames} frames from step {agent.steps_done}...\n")
    print(f"Training in sessions of {frames_per_session} frames each to manage memory usage")
    
    # Calculate how many sessions needed
    total_sessions = additional_frames // frames_per_session
    if additional_frames % frames_per_session > 0:
        total_sessions += 1
    
    # Train in smaller sessions to avoid memory issues
    frames_remaining = additional_frames
    session_start_time = time.time()
    overall_start_time = session_start_time
    
    for session in range(1, total_sessions + 1):
        # Find latest checkpoint if this isn't the first session
        if session > 1:
            # Get latest model file
            model_files = [f for f in os.listdir(save_path) if f.startswith("model_") and f.endswith(".pt")]
            if model_files:
                # Modified sorting function to handle 'model_best.pt'
                def get_model_number(filename):
                    parts = filename.split("_")[1].split(".")[0]
                    if parts == "best":
                        return float('inf')  # Make 'best' sort to the end
                    try:
                        return int(parts)
                    except ValueError:
                        return -1  # For any other non-integer names
                
                latest_model = sorted(model_files, key=get_model_number)[-1]
                latest_model_path = os.path.join(save_path, latest_model)
                print(f"\nResuming from checkpoint: {latest_model_path}")
                agent.load_model(latest_model_path)
        
        # Rest of the code remains the same until the training loop...
        
        # Calculate frames for this session
        session_frames = min(frames_per_session, frames_remaining)
        print(f"\nStarting training session {session}/{total_sessions} ({session_frames} frames)")
        
        # Initialize environment and progress bar
        obs = envs.reset()
        obs_tensor = preprocess_batch_observation(obs)
        
        losses = []
        all_rewards = []
        episode_reward = np.zeros(NUM_ENVS)
        episode_length = np.zeros(NUM_ENVS)
        
        progress_bar = tqdm(range(1, session_frames + 1), desc=f"Session {session}/{total_sessions}")
        
        try:
            # Training loop
            for frame_idx in progress_bar:
                # Code remains the same until the episode completion check...
                
                # Select actions
                actions = []
                for i in range(NUM_ENVS):
                    action = agent.select_action(obs_tensor[i:i+1])
                    actions.append(action.item())
                
                # Execute actions
                next_obs, rewards, terminateds, truncateds = envs.step(actions)
                
                # Process data for each environment
                dones = []
                for t, tr in zip(terminateds, truncateds):
                    if isinstance(tr, dict):
                        done = t or tr.get("TimeLimit.truncated", False)
                    else:
                        done = t or tr
                    dones.append(done)

                next_obs_tensor = preprocess_batch_observation(next_obs)
                
                # Update cumulative rewards and episode length
                episode_reward += rewards
                episode_length += 1
                
                # Store data in experience replay buffer
                for i in range(NUM_ENVS):
                    agent.memory.push(
                        obs_tensor[i:i+1],
                        actions[i],
                        rewards[i],
                        next_obs_tensor[i:i+1],
                        float(dones[i])
                    )
                
                # Update observations
                obs = next_obs
                obs_tensor = next_obs_tensor
                
                # Optimize model
                loss = agent.optimize_model()
                if loss is not None:
                    losses.append(loss)
                
                # Check for episode completion
                for i, done in enumerate(dones):
                    if done:
                        # Record episode results
                        agent.writer.add_scalar("train/episode_reward", episode_reward[i], agent.steps_done)
                        agent.writer.add_scalar("train/episode_length", episode_length[i], agent.steps_done)
                        all_rewards.append(episode_reward[i])
                        
                        # Store metrics for plotting
                        metrics['episode_rewards'].append(episode_reward[i])
                        metrics['episode_lengths'].append(episode_length[i])
                        metrics['frames'].append(agent.steps_done)
                        
                        episode_count += 1
                        
                        # Track best reward
                        if episode_reward[i] > best_reward:
                            best_reward = episode_reward[i]
                            print(f"\nNew best reward: {best_reward:.2f} at episode {episode_count}")
                            # Save best model
                            best_model_path = os.path.join(save_path, "model_best.pt")
                            agent.save_model(best_model_path)

                        # Print detailed episode information
                        elapsed = time.time() - overall_start_time
                        print(f"\nEpisode {episode_count} completed in env {i}:")
                        print(f"  Steps: {episode_length[i]}")
                        print(f"  Reward: {episode_reward[i]:.2f}")
                        print(f"  Epsilon: {agent.epsilon:.4f}")
                        if losses:
                            print(f"  Loss: {np.mean(losses[-100:]):.6f}")
                        print(f"  Total frames: {agent.steps_done}")
                        print(f"  Elapsed time: {elapsed/60:.2f} minutes")
                        print(f"  Frames per second: {agent.steps_done/elapsed:.2f}")
                        
                        # Generate and save plots periodically
                        if episode_count % 10 == 0:
                            plot_training_progress(metrics, save_path)
                        
                        # Reset episode statistics
                        episode_reward[i] = 0
                        episode_length[i] = 0
                
                # Update target network
                if agent.steps_done % TARGET_UPDATE == 0:
                    agent.update_target_network()
                    print(f"\nFrame {agent.steps_done}: Target network updated")
                
                # Record training statistics
                if frame_idx % 1000 == 0:
                    mean_reward = np.mean(all_rewards[-100:]) if all_rewards else 0
                    mean_loss = np.mean(losses[-100:]) if losses else 0
                    
                    # Store for plotting
                    metrics['mean_rewards_100'].append(mean_reward)
                    metrics['losses'].append(mean_loss)
                    metrics['epsilons'].append(agent.epsilon)
                    
                    agent.writer.add_scalar("train/epsilon", agent.epsilon, agent.steps_done)
                    agent.writer.add_scalar("train/loss", mean_loss, agent.steps_done)
                    agent.writer.add_scalar("train/mean_reward_100", mean_reward, agent.steps_done)
                    
                    progress_bar.set_postfix({
                        "avg_reward": f"{mean_reward:.2f}",
                        "loss": f"{mean_loss:.5f}",
                        "epsilon": f"{agent.epsilon:.2f}"
                    })
                
                # Save model every SAVE_INTERVAL steps and also at the end of each session
                if agent.steps_done % SAVE_INTERVAL == 0:
                    save_path_checkpoint = os.path.join(save_path, f"model_{agent.steps_done}.pt")
                    agent.save_model(save_path_checkpoint)
                    print(f"\nFrame {agent.steps_done}: Model saved to {save_path_checkpoint}")
                
                # Evaluate model
                if agent.steps_done % EVAL_INTERVAL == 0:
                    print(f"\nFrame {agent.steps_done}: Evaluating...")
                    eval_reward, eval_std, eval_rewards = evaluate(
                        agent,
                        ENV_NAME,
                        num_episodes=EVAL_EPISODES,
                        video_prefix=f"eval_{agent.steps_done}"
                    )
                    
                    # Store evaluation metrics
                    metrics['eval_rewards'].append(eval_reward)
                    
                    agent.writer.add_scalar("eval/mean_reward", eval_reward, agent.steps_done)
                    agent.writer.add_scalar("eval/reward_std", eval_std, agent.steps_done)
                    
                    print(f"Evaluation results: Mean reward = {eval_reward:.2f} ± {eval_std:.2f}")
                    for i, r in enumerate(eval_rewards):
                        print(f"  Eval episode {i+1}: Reward = {r:.2f}")
                
                # Update agent's step counter
                agent.steps_done += 1
            
            # Save model at the end of each session
            session_end_model_path = os.path.join(save_path, f"model_{agent.steps_done}.pt")
            agent.save_model(session_end_model_path)
            print(f"\nSession {session} complete. Model saved to {session_end_model_path}")
            
        except Exception as e:
            # Handle any exceptions (like CUDA OOM) by saving current progress
            print(f"\nError encountered: {e}")
            error_model_path = os.path.join(save_path, f"model_error_{agent.steps_done}.pt")
            agent.save_model(error_model_path)
            print(f"Model saved at error point: {error_model_path}")
            print("You can resume training from this checkpoint.")
        
        # Update frames_remaining for next session
        frames_remaining -= session_frames
        
        # Print session summary
        session_time = time.time() - session_start_time
        print(f"\nSession {session} Summary:")
        print(f"Frames processed: {session_frames}")
        print(f"Session time: {session_time/60:.2f} minutes")
        print(f"Frames per second: {session_frames/session_time:.2f}")
        
        # Generate and save plots at the end of each session
        plot_training_progress(metrics, save_path)
        
        # Reset for next session
        session_start_time = time.time()
    
    # Save final model after all training
    final_path = os.path.join(save_path, "model_final.pt")
    agent.save_model(final_path)
    print(f"\nFinal model saved to: {final_path}")
    
    # Generate final comprehensive plots
    plot_training_progress(metrics, save_path, final=True)
    
    # Save metrics as CSV for later analysis
    metrics_df = pd.DataFrame({
        'frames': metrics['frames'],
        'episode_rewards': metrics['episode_rewards'],
        'episode_lengths': metrics['episode_lengths']
    })
    metrics_df.to_csv(os.path.join(save_path, 'training_metrics.csv'), index=False)
    
    # Print overall training summary
    elapsed_time = time.time() - overall_start_time
    print("\n===== Training Summary =====")
    print(f"Total episodes: {episode_count}")
    print(f"Total frames: {agent.steps_done}")
    print(f"Best reward: {best_reward:.2f}")
    print(f"Average reward (last 100): {np.mean(all_rewards[-100:]):.2f}")
    print(f"Final epsilon: {agent.epsilon:.4f}")
    print(f"Total training time: {elapsed_time/60:.2f} minutes")
    print(f"Frames per second: {agent.steps_done/elapsed_time:.2f}")
    
    # Restore original paths
    SAVE_PATH = original_save_path
    LOG_PATH = original_log_path
    VIDEO_PATH = original_video_path
    
    return final_path, metrics_df

def plot_training_progress(metrics, save_path, final=False):
    """Generate and save a 3x2 grid of training-progress plots as a PNG.

    Args:
        metrics: Dict of lists with keys 'frames', 'episode_rewards',
            'episode_lengths', 'mean_rewards_100', 'losses', 'epsilons' and
            'eval_rewards'. The optional keys 'log_frames' and 'eval_frames'
            give the x-position (training step) of each periodic-log and
            evaluation entry. When they are absent or their length does not
            match, positions fall back to ``i * 1000`` and
            ``i * EVAL_INTERVAL``.
        save_path: Directory the PNG is written to.
        final: If True, write ``final_training_plots.png``; otherwise
            ``training_plots_<number of recorded episodes>.png``.
    """
    def x_positions(key, values, step):
        # Prefer recorded step numbers so every panel shares the same x-axis
        # as the episode panels (which use metrics['frames']).
        recorded = metrics.get(key)
        if recorded is not None and len(recorded) == len(values):
            return recorded
        return [i * step for i in range(len(values))]

    def draw(position, xs, ys, style, title, ylabel):
        plt.subplot(3, 2, position)
        plt.plot(xs, ys, style)
        plt.title(title)
        plt.xlabel('Frames')
        plt.ylabel(ylabel)
        plt.grid(True)

    fig = plt.figure(figsize=(20, 15))
    try:
        draw(1, metrics['frames'], metrics['episode_rewards'], 'b-',
             'Episode Rewards Over Time', 'Reward')

        if metrics['mean_rewards_100']:
            draw(2, x_positions('log_frames', metrics['mean_rewards_100'], 1000),
                 metrics['mean_rewards_100'], 'g-',
                 'Mean Reward (Last 100 Episodes)', 'Mean Reward')

        draw(3, metrics['frames'], metrics['episode_lengths'], 'r-',
             'Episode Lengths Over Time', 'Steps')

        if metrics['losses']:
            draw(4, x_positions('log_frames', metrics['losses'], 1000),
                 metrics['losses'], 'm-', 'Training Loss', 'Loss')

        if metrics['epsilons']:
            draw(5, x_positions('log_frames', metrics['epsilons'], 1000),
                 metrics['epsilons'], 'k-', 'Epsilon Decay', 'Epsilon')

        if metrics['eval_rewards']:
            draw(6, x_positions('eval_frames', metrics['eval_rewards'], EVAL_INTERVAL),
                 metrics['eval_rewards'], 'c-', 'Evaluation Rewards', 'Reward')

        plt.tight_layout()

        plot_filename = "final_training_plots.png" if final else f"training_plots_{len(metrics['frames'])}.png"
        fig.savefig(os.path.join(save_path, plot_filename))
    finally:
        # Always release the figure; this is called repeatedly during training
        plt.close(fig)


## How to Continue Training

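To continue training from an existing model, uncomment the code in the cell below and set `model_path` to the checkpoint you want to resume from. The cell creates a new set of parallel environments, loads the checkpoint, and trains for `additional_frames` more frames. Training is split into sessions of at most `frames_per_session` frames to keep memory usage under control. `continue_training` returns a tuple `(final_model_path, metrics_df)`: the path of the final model and a DataFrame of per-episode metrics.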

In [None]:
# torch.cuda.empty_cache()

# # Create parallel environments
# envs = make_vec_env(ENV_NAME, NUM_ENVS)

# print(f"Number of actions: {envs.action_space.n}")

# # Continue training from your existing model
# model_path = "./models/donkey_kong_continued_20250410_174640/model_error_10021320.pt"
# # continue_training returns a tuple: (final_model_path, metrics_df)
# try:
#     final_model_path, metrics_df = continue_training(model_path, envs, additional_frames=1_000_000, frames_per_session=200_000)
# finally:
#     # Close environments even if training fails or is interrupted
#     envs.close()

# print(f"\nContinued training complete. Final model saved at {final_model_path}")