# 使用DQN训练Atari Donkey Kong

本notebook实现了一个DQN代理来玩Atari游戏Donkey Kong，并包含以下特性：
- 并行训练多个游戏环境
- 预处理游戏帧以提高训练效率
- 使用优先经验回放提高训练质量
- 训练日志记录
- 定期保存模型
- 定期评估并录制游戏视频

## 1. 安装必要的依赖

In [111]:
# 安装必要的库
# %pip install stable-baselines3[extra] gymnasium[atari] numpy matplotlib opencv-python tensorboard autorom[accept-rom-license]

## 2. 导入库

In [112]:
import os
import random
import time
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym
from gymnasium.wrappers import RecordVideo, FrameStackObservation
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from collections import deque, namedtuple
from tqdm.notebook import tqdm
import cv2
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.atari_wrappers import AtariWrapper
import ale_py
from gymnasium import spaces

# 设置随机种子，保证实验可复现
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

## 3. 配置参数

In [113]:
# 环境参数
ENV_NAME = "ALE/DonkeyKong-v5"
NUM_ENVS = 1  # 并行环境数量
FRAME_SKIP = 4  # 跳帧数，每隔4帧进行一次决策
ALLOWED_ACTIONS = [0,1,2,3,4,5,11,12]  # 有效动作

# 模型参数
BATCH_SIZE = 32
GAMMA = 0.99  # 折扣因子
LEARNING_RATE = 0.0001
MEMORY_SIZE = 100000  # 经验回放缓冲区大小
TARGET_UPDATE = 10000  # 目标网络更新频率

# 训练参数
NUM_FRAMES = 15_000_000  # 总训练帧数
EPSILON_START = 1.0
EPSILON_END = 0.1
EPSILON_DECAY = 10_000_000
DEMO_PATH = "./demo/dk_demo_20250325_171619.pkl"

# 保存和评估参数
SAVE_INTERVAL = 100_000  # 保存模型的间隔（帧数）
EVAL_INTERVAL = 50_000   # 评估模型的间隔（帧数）
EVAL_EPISODES = 3       # 每次评估的游戏局数

# 创建保存模型和日志的目录
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
SAVE_PATH = f"./models/donkey_kong_{timestamp}"
LOG_PATH = f"./logs/donkey_kong_{timestamp}"
VIDEO_PATH = f"./videos/donkey_kong_{timestamp}"

for path in [SAVE_PATH, LOG_PATH, VIDEO_PATH]:
    if not os.path.exists(path):
        os.makedirs(path)

# 设置设备（GPU或CPU）
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

使用设备: cuda


## 4. 环境预处理

In [114]:
# 限制动作空间,减少 agent 的无用动作
class ActionRestrictWrapper(gym.ActionWrapper):
    def __init__(self, env, allowed_actions):
        super().__init__(env)
        self.allowed_actions = allowed_actions
        self.action_space = spaces.Discrete(len(self.allowed_actions))

    def action(self, act):
        # 把 agent 输出的动作索引映射成原动作编号
        return self.allowed_actions[act]

    def reverse_action(self, act):
        return self.allowed_actions.index(act)

# 强制首个动作为FIRE的包装器
class ForceFirstFireWrapper(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        self.first_action_done = False
    
    def reset(self, **kwargs):
        self.first_action_done = False
        return self.env.reset(**kwargs)
    
    def step(self, action):
        # 如果是首个动作且不是FIRE(1)，则强制替换为FIRE
        if not self.first_action_done:
            self.first_action_done = True
            # 使用FIRE动作(索引1)代替传入的动作
            return self.env.step(1)  # 1对应FIRE动作
        return self.env.step(action)

# 根据颜色检测人物位置的函数
def get_agent_position(frame): 
    """ 根据颜色检测人物位置，返回 (x, y) 坐标。未检测到则返回 None。 """
    # 确保frame是numpy数组且格式正确
    if frame is None:
        return None
    
    # 目标颜色（BGR 格式）
    target_bgr = np.array([194, 64, 82], dtype=np.uint8)

    # 容差范围（可调，20~40 一般比较合适）
    tolerance = 30
    lower = np.array([max(0, c - tolerance) for c in target_bgr], dtype=np.uint8)
    upper = np.array([min(255, c + tolerance) for c in target_bgr], dtype=np.uint8)

    # 生成掩码
    mask = cv2.inRange(frame, lower, upper)
    
    # 查找轮廓
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None

    # 找面积最大轮廓
    largest = max(contours, key=cv2.contourArea)
    M = cv2.moments(largest)

    if M["m00"] == 0:
        return None

    cx = int(M["m10"] / M["m00"])
    cy = int(M["m01"] / M["m00"])

    return (cx, cy)

# 自定义视频显示包装器，用于在视频中显示动作和代理位置
class VideoDisplayWrapper(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        self.current_action = None
        self.action_names = {
            0: "NOOP",
            1: "FIRE",
            2: "UP",
            3: "RIGHT",
            4: "LEFT",
            5: "DOWN",
            11: "RIGHT-FIRE",
            12: "LEFT-FIRE"
        }
        
    def step(self, action):
        # 记录当前动作
        self.current_action = action
        return self.env.step(action)
    
    def reset(self, **kwargs):
        self.current_action = None
        return self.env.reset(**kwargs)
    
    def render(self):
        # 获取原始渲染帧
        frame = self.env.render()
        
        if frame is None:
            return None
        
        # 确保帧是RGB格式
        if len(frame.shape) == 2:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)

        # 1. 在右上角显示当前动作
        if self.current_action is not None:
            action_name = self.action_names.get(self.current_action, f"ACTION_{self.current_action}")
            cv2.putText(frame, 
                       action_name, 
                       (frame.shape[1] - 50, 10), # 右上角位置
                       cv2.FONT_HERSHEY_SIMPLEX, 
                       0.3,
                       (255, 255, 255), # 白色文本
                       1, 
                       cv2.LINE_AA)
        
        # 2. 在代理位置绘制绿色圆点
        position = get_agent_position(frame)
        if position:
            x, y = position
            cv2.circle(frame, 
                      (x, y), 
                      2, # 圆点半径
                      (0, 255, 0), # 绿色
                      -1) # 填充圆
        
        return frame

# 自定义奖励包装器，用于根据Agent的位置变化调整奖励
class CustomRewardWrapper(gym.Wrapper):
    def __init__(self, env, y_static_penalty=0, up_success_reward=5,
                 up_fail_penalty=0, x_static_penalty=0.01,
                 y_threshold=3, x_threshold=3, 
                 y_static_frames=30, x_static_frames=30):
        super().__init__(env)
        # 奖励参数
        self.y_static_penalty = y_static_penalty  # 垂直静止惩罚
        self.up_success_reward = up_success_reward  # 成功向上移动奖励
        self.up_fail_penalty = up_fail_penalty  # 向上失败惩罚
        self.x_static_penalty = x_static_penalty  # 水平静止惩罚
        
        # 阈值参数
        self.y_threshold = y_threshold  # 垂直移动阈值
        self.x_threshold = x_threshold  # 水平移动阈值
        self.y_static_frames = y_static_frames  # 垂直静止判定帧数
        self.x_static_frames = x_static_frames  # 水平静止判定帧数
        
        # 状态记录
        self.prev_positions = []  # 存储过去的位置 [(x, y), ...]
        self.y_static_count = 0  # 垂直静止计数
        self.x_static_count = 0  # 水平静止计数
        self.prev_action = None  # 上一个动作
    
    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        # 重置状态记录
        self.prev_positions = []
        self.y_static_count = 0
        self.x_static_count = 0
        self.prev_action = None
        return obs, info
    
    def step(self, action):
        # 记录当前动作
        self.prev_action = action
        
        # 执行环境步骤
        obs, reward, terminated, truncated, info = self.env.step(action)
        
        # 从观察中提取RGB帧
        frame = None
        try:
            if isinstance(obs, np.ndarray):
                if len(obs.shape) == 4:  # (stack, height, width, channel)
                    frame = obs[-1]  # 最后一帧
                elif len(obs.shape) == 3:  # (height, width, channel)
                    frame = obs
                elif len(obs.shape) == 2:  # (height, width)
                    frame = obs
            elif hasattr(obs, '__getitem__'):
                # 对于 FrameStackObservation
                try:
                    frame = obs[-1]
                except:
                    try:
                        frame = obs[3]  # 假设是4帧堆叠
                    except:
                        pass
            
            # 如果上述尝试都失败，尝试渲染环境
            if frame is None:
                try:
                    frame = self.env.render()
                except:
                    pass
        except Exception as e:
            print(f"从观察中提取帧时发生错误: {e}")
            frame = None
        
        # 检测Agent位置
        position = None
        if frame is not None:
            position = get_agent_position(frame)
        
        # 如果检测到位置，则更新位置历史并计算奖励调整
        additional_reward = 0
        
        if position is not None:
            x, y = position
            self.prev_positions.append((x, y))
            
            # 保持历史记录在合理大小
            if len(self.prev_positions) > max(self.y_static_frames, self.x_static_frames):
                self.prev_positions.pop(0)
            
            # 至少有两个位置记录才能判断移动
            if len(self.prev_positions) >= 2:
                prev_x, prev_y = self.prev_positions[-2]
                
                # 1. 检查垂直方向是否静止
                if abs(y - prev_y) < self.y_threshold:
                    self.y_static_count += 1
                    if self.y_static_count >= self.y_static_frames:
                        # 线性增加惩罚
                        additional_reward -= self.y_static_penalty * (self.y_static_count - self.y_static_frames + 1)
                else:
                    self.y_static_count = 0
                
                # 2. 检查UP动作的效果
                if self.prev_action == 2:  # 假设2是UP动作
                    if (prev_y - y) > self.y_threshold:  # 成功向上移动
                        additional_reward += self.up_success_reward
                    else:  # 未成功向上移动
                        additional_reward -= self.up_fail_penalty
                
                # 3. 检查水平方向是否静止
                if abs(x - prev_x) < self.x_threshold:
                    self.x_static_count += 1
                    if self.x_static_count >= self.x_static_frames:
                        # 线性增加惩罚
                        additional_reward -= self.x_static_penalty * (self.x_static_count - self.x_static_frames + 1)
                else:
                    self.x_static_count = 0
        
        # 应用奖励调整
        adjusted_reward = reward + additional_reward
        
        return obs, adjusted_reward, terminated, truncated, info

# 创建预处理后的环境的函数
def make_env(env_id, idx, capture_video=False, run_name=None):
    def thunk():
        import ale_py
        
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            # 添加视频显示包装器
            env = VideoDisplayWrapper(env)
            env = RecordVideo(
                env,
                VIDEO_PATH,
                episode_trigger=lambda x: True,
                name_prefix=f"donkey_kong_{run_name}"
            )
        else:
            env = gym.make(env_id)
        
        env = ActionRestrictWrapper(env, ALLOWED_ACTIONS)
        env = ForceFirstFireWrapper(env)
        # 添加自定义奖励包装器
        env = CustomRewardWrapper(env)
        env = AtariWrapper(env, terminal_on_life_loss=True, frame_skip=FRAME_SKIP)
        env = FrameStackObservation(env, 4)  # 堆叠4帧以捕获时间信息
        return env
    return thunk

# 创建并行环境
def make_vec_env(env_id, num_envs, seed=SEED):
    env_fns = [make_env(env_id, i) for i in range(num_envs)]
    envs = SubprocVecEnv(env_fns)
    envs.seed(seed)
    return envs

## 5. DQN网络模型

In [115]:
class DQN(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(DQN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        
        conv_out_size = self._get_conv_out(input_shape)
        
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )
    
    def _get_conv_out(self, shape):
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))
    
    def forward(self, x):
        # 输入形状: (batch, stack_frames, height, width)
        conv_out = self.conv(x).view(x.size()[0], -1)
        return self.fc(conv_out)

## 6. 优先经验回放

In [116]:
# 使用优先经验回放提高训练效率
class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.6, beta_start=0.4, beta_frames=100000):
        self.capacity = capacity
        self.alpha = alpha  # 控制优先级的程度
        self.beta_start = beta_start
        self.beta_frames = beta_frames
        self.frame = 1  # 当前帧，用于beta计算
        self.buffer = []
        self.priorities = np.zeros((capacity,), dtype=np.float32)
        self.position = 0
        self.Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))
    
    def beta_by_frame(self, frame_idx):
        # beta从beta_start线性增加到1.0
        return min(1.0, self.beta_start + frame_idx * (1.0 - self.beta_start) / self.beta_frames)
    
    def push(self, *args):
        # 添加新的经验
        max_prio = np.max(self.priorities) if self.buffer else 1.0
        
        if len(self.buffer) < self.capacity:
            self.buffer.append(self.Transition(*args))
        else:
            self.buffer[self.position] = self.Transition(*args)
        
        self.priorities[self.position] = max_prio
        self.position = (self.position + 1) % self.capacity
    
    def sample(self, batch_size):
        if len(self.buffer) == self.capacity:
            prios = self.priorities
        else:
            prios = self.priorities[:self.position]
        
        # 计算采样概率
        probs = prios ** self.alpha
        probs /= probs.sum()
        
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        samples = [self.buffer[idx] for idx in indices]
        
        # 计算重要性采样权重
        beta = self.beta_by_frame(self.frame)
        self.frame += 1
        
        weights = (len(self.buffer) * probs[indices]) ** (-beta)
        weights /= weights.max()
        weights = torch.tensor(weights, device=device, dtype=torch.float32)
        
        # 转换为批量处理格式
        batch = self.Transition(*zip(*samples))
        states = torch.cat(batch.state)
        actions = torch.tensor(batch.action, device=device)
        rewards = torch.tensor(batch.reward, device=device, dtype=torch.float32)
        next_states = torch.cat(batch.next_state)
        dones = torch.tensor(batch.done, device=device, dtype=torch.float32)
        
        return states, actions, rewards, next_states, dones, indices, weights
    
    def update_priorities(self, indices, priorities):
        # 更新优先级
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = priority
    
    def __len__(self):
        return len(self.buffer)

## 7. DQN代理

In [117]:
class DQNAgent:
    def __init__(self, state_shape, n_actions):
        self.state_shape = state_shape
        self.n_actions = n_actions
        
        # 创建策略网络和目标网络
        self.policy_net = DQN(state_shape, n_actions).to(device)
        self.target_net = DQN(state_shape, n_actions).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()  # 目标网络不需要计算梯度
        
        # 设置优化器
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
        
        # 创建经验回放缓冲区
        self.memory = PrioritizedReplayBuffer(MEMORY_SIZE)
        
        # 训练相关参数
        self.steps_done = 0
        self.epsilon = EPSILON_START
        
        # 日志记录器
        self.writer = SummaryWriter(LOG_PATH)
    
    def select_action(self, state, eval_mode=False):
        # ε-贪婪策略选择动作
        sample = random.random()
        # 在评估模式下，始终选择最佳动作
        if eval_mode:
            eps_threshold = 0.05  # 评估时使用小的epsilon，增加一些探索性
        else:
            # 线性衰减epsilon
            self.epsilon = max(EPSILON_END, EPSILON_START - self.steps_done / EPSILON_DECAY)
            eps_threshold = self.epsilon
            
        if sample > eps_threshold:
            with torch.no_grad():
                return self.policy_net(state).max(1)[1].view(1, 1)
        else:
            return torch.tensor([[random.randrange(self.n_actions)]], device=device, dtype=torch.long)
    
    def optimize_model(self):
        if len(self.memory) < BATCH_SIZE:
            return 0.0  # 缓冲区中的样本不足
        
        # 从经验回放缓冲区中采样
        states, actions, rewards, next_states, dones, indices, weights = self.memory.sample(BATCH_SIZE)
        
        # 计算当前Q值
        q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        
        # 使用Double DQN计算下一个状态的Q值
        # 使用策略网络选择动作
        next_actions = self.policy_net(next_states).max(1)[1].unsqueeze(1)
        # 使用目标网络评估动作
        next_q_values = self.target_net(next_states).gather(1, next_actions).squeeze(1)
        # 将终止状态的下一个Q值设为0
        next_q_values = next_q_values * (1 - dones)
        # 计算目标Q值
        target_q_values = rewards + GAMMA * next_q_values
        
        # 计算损失（TD误差）
        td_error = torch.abs(q_values - target_q_values).detach().cpu().numpy()
        loss = F.smooth_l1_loss(q_values, target_q_values, reduction='none') * weights
        loss = loss.mean()
        
        # 优化模型
        self.optimizer.zero_grad()
        loss.backward()
        # 梯度裁剪，防止梯度爆炸
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 10)
        self.optimizer.step()
        
        # 更新优先级
        self.memory.update_priorities(indices, td_error + 1e-5)
        
        return loss.item()
    
    def update_target_network(self):
        self.target_net.load_state_dict(self.policy_net.state_dict())
    
    def save_model(self, path):
        torch.save({
            'policy_net': self.policy_net.state_dict(),
            'target_net': self.target_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'steps_done': self.steps_done,
            'epsilon': self.epsilon
        }, path)
    
    def load_model(self, path):
        checkpoint = torch.load(path)
        self.policy_net.load_state_dict(checkpoint['policy_net'])
        self.target_net.load_state_dict(checkpoint['target_net'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.steps_done = checkpoint['steps_done']
        self.epsilon = checkpoint['epsilon']

## 8. 预处理和状态转换函数

In [118]:
def preprocess_observation(obs):
    # 把堆叠的4帧图像转换为PyTorch的输入格式
    frames = np.array(obs).squeeze(-1)
    tensor = torch.tensor(frames, dtype=torch.float32, device=device).unsqueeze(0)
    return tensor / 255.0  # 归一化

def preprocess_batch_observation(obs):
    # 处理批量观察数据
    frames = np.array(obs).squeeze(-1)
    tensor = torch.tensor(frames, dtype=torch.float32, device=device)
    return tensor / 255.0  # 归一化

## 9. 评估函数

In [119]:
def evaluate(agent, env_id, num_episodes=5, video_prefix="evaluation"):
    # 创建评估环境，包括视频录制
    env = make_env(env_id, 0, capture_video=True, run_name=video_prefix)()  
    episode_rewards = []
    
    for i in range(num_episodes):
        obs, _ = env.reset()
        obs_tensor = preprocess_observation(obs)
        done = False
        total_reward = 0.0
        
        while not done:
            action = agent.select_action(obs_tensor, eval_mode=True).item()
            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total_reward += reward
            
            obs = next_obs
            obs_tensor = preprocess_observation(obs)
        
        episode_rewards.append(total_reward)
    
    env.close()
    return np.mean(episode_rewards), np.std(episode_rewards)

## 10. 训练函数

In [120]:
def train(agent, envs, num_frames):
    # 初始化环境和进度条
    obs = envs.reset()
    obs_tensor = preprocess_batch_observation(obs)
    
    losses = []
    all_rewards = []
    episode_reward = np.zeros(NUM_ENVS)
    episode_length = np.zeros(NUM_ENVS)
    
    progress_bar = tqdm(range(1, num_frames + 1), desc="Training")
    
    # 训练循环
    for frame_idx in progress_bar:
        # 选择动作
        actions = []
        for i in range(NUM_ENVS):
            action = agent.select_action(obs_tensor[i:i+1])
            actions.append(action.item())
        
        # 执行动作
        next_obs, rewards, terminateds, truncateds = envs.step(actions)
        
        # 处理每个环境的数据
        dones = []
        for t, tr in zip(terminateds, truncateds):
            if isinstance(tr, dict):
                done = t or tr.get("TimeLimit.truncated", False)
            else:
                done = t or tr
            dones.append(done)

        next_obs_tensor = preprocess_batch_observation(next_obs)
        
        # 更新累计奖励和回合长度
        episode_reward += rewards
        episode_length += 1
        
        # 将数据存入经验回放缓冲区
        for i in range(NUM_ENVS):
            agent.memory.push(
                obs_tensor[i:i+1],
                actions[i],
                rewards[i],
                next_obs_tensor[i:i+1],
                float(dones[i])
            )
        
        # 更新观察
        obs = next_obs
        obs_tensor = next_obs_tensor
        
        # 优化模型
        loss = agent.optimize_model()
        losses.append(loss)
        
        # 检查是否有回合结束
        for i, done in enumerate(dones):
            if done:
                # 记录回合结果
                agent.writer.add_scalar("train/episode_reward", episode_reward[i], agent.steps_done)
                agent.writer.add_scalar("train/episode_length", episode_length[i], agent.steps_done)
                all_rewards.append(episode_reward[i])
                
                # 重置回合统计
                episode_reward[i] = 0
                episode_length[i] = 0
        
        # 更新目标网络
        if frame_idx % TARGET_UPDATE == 0:
            agent.update_target_network()
        
        # 记录训练统计信息
        if frame_idx % 1000 == 0:
            mean_reward = np.mean(all_rewards[-100:]) if all_rewards else 0
            mean_loss = np.mean(losses[-100:]) if losses else 0
            agent.writer.add_scalar("train/epsilon", agent.epsilon, frame_idx)
            agent.writer.add_scalar("train/loss", mean_loss, frame_idx)
            agent.writer.add_scalar("train/mean_reward_100", mean_reward, frame_idx)
            
            progress_bar.set_postfix({
                "avg_reward": f"{mean_reward:.2f}",
                "loss": f"{mean_loss:.5f}",
                "epsilon": f"{agent.epsilon:.2f}"
            })
        
        # 保存模型
        if frame_idx % SAVE_INTERVAL == 0:
            save_path = os.path.join(SAVE_PATH, f"model_{frame_idx}.pt")
            agent.save_model(save_path)
            print(f"\n模型已保存到: {save_path}")
        
        # 评估模型
        if frame_idx % EVAL_INTERVAL == 0:
            print("\n开始评估...")
            eval_reward, eval_std = evaluate(
                agent,
                ENV_NAME,
                num_episodes=EVAL_EPISODES,
                video_prefix=f"eval_{frame_idx}"
            )
            agent.writer.add_scalar("eval/mean_reward", eval_reward, frame_idx)
            agent.writer.add_scalar("eval/reward_std", eval_std, frame_idx)
            print(f"评估结果: 平均奖励 = {eval_reward:.2f} ± {eval_std:.2f}")
        
        # 更新代理的步数计数器
        agent.steps_done += 1
    
    # 训练结束，保存最终模型
    final_path = os.path.join(SAVE_PATH, "model_final.pt")
    agent.save_model(final_path)
    print(f"\n最终模型已保存到: {final_path}")


def load_demonstrations(agent, filepath):
    """加载示范轨迹文件并注入 agent 的 replay buffer"""
    import pickle

    with open(filepath, 'rb') as f:
        all_trajectories = pickle.load(f)

    count = 0
    for traj in all_trajectories:
        for s, a, r, ns, d in traj:
            # 放到 cuda 上
            agent.memory.push(
                s.to(device),
                a,
                r.to(device),
                ns.to(device),
                d.to(device)
            )
            count += 1
    print(f"🚀 导入示范轨迹完成，共 {count} 条 transition 已加入 replay buffer。")

## 11. 主训练流程

In [None]:
# 创建并行环境
envs = make_vec_env(ENV_NAME, NUM_ENVS)

# 获取环境信息
obs_shape = (4, 84, 84)  # 堆叠的4帧，每帧84x84
n_actions = envs.action_space.n

print(f"观察空间形状: {obs_shape}")
print(f"动作空间大小: {n_actions}")

# 创建DQN代理
agent = DQNAgent(obs_shape, n_actions)

# 加载示范轨迹
if DEMO_PATH:
    Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))
    load_demonstrations(agent, DEMO_PATH)

# 开始训练
print("开始训练...")
train(agent, envs, NUM_FRAMES)

# 关闭环境
envs.close()

A.L.E: Arcade Learning Environment (version 0.10.2+c9d4b19)
[Powered by Stella]


观察空间形状: (4, 84, 84)
动作空间大小: 8
🚀 导入示范轨迹完成，共 2131 条 transition 已加入 replay buffer。
开始训练...


Training:   0%|          | 0/15000000 [00:00<?, ?it/s]

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [32, 4, 84, 84, 1]

## 12. 加载和测试训练好的模型

In [None]:
def play_and_record_video(model_path, env_id, num_episodes=5):
    # 创建环境
    env = make_env(env_id, 0, capture_video=True, run_name="final_test")()  
    
    # 创建代理并加载模型
    obs_shape = (4, 84, 84)  # 堆叠的4帧，每帧84x84
    n_actions = env.action_space.n
    agent = DQNAgent(obs_shape, n_actions)
    agent.load_model(model_path)
    
    # 测试训练好的代理
    rewards = []
    for i in range(num_episodes):
        obs, _ = env.reset()
        obs_tensor = preprocess_observation(obs)
        done = False
        episode_reward = 0
        
        while not done:
            action = agent.select_action(obs_tensor, eval_mode=True).item()
            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            episode_reward += reward
            
            obs = next_obs
            obs_tensor = preprocess_observation(obs)
            
        rewards.append(episode_reward)
        print(f"Episode {i+1}: Reward = {episode_reward}")
    
    env.close()
    print(f"平均奖励: {np.mean(rewards):.2f} ± {np.std(rewards):.2f}")

In [None]:
# 加载并测试最终模型
model_path = os.path.join(SAVE_PATH, "model_final.pt")
play_and_record_video(model_path, ENV_NAME, num_episodes=5)



## 13. 可视化训练结果

In [None]:
# 使用TensorBoard可视化训练结果
print(f"可以通过以下命令在终端中启动TensorBoard查看训练指标:")
print(f"tensorboard --logdir={LOG_PATH}")

