# Basic DQN
> Based on <https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html>

## Overview
* A DQN with Double DQN and experience replay
* Model implementation: `./dqn/BaseDQN2.py`
* Test environment: gymnasium CartPole-v1
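
As a rough illustration of the experience replay mentioned above, here is a minimal sketch of a fixed-capacity buffer. The `Transition` fields and the `ReplayBuffer` class are hypothetical names chosen for this example; the actual implementation in `./dqn/BaseDQN2.py` may differ.

```python
import random
from collections import deque, namedtuple

# Hypothetical transition layout; the real one in BaseDQN2.py may differ.
Transition = namedtuple(
    "Transition", ("state", "action", "next_state", "reward", "terminated")
)


class ReplayBuffer:
    """Fixed-capacity FIFO buffer with uniform random sampling."""

    def __init__(self, capacity: int):
        # A deque with maxlen silently drops the oldest item when full.
        self.buffer = deque(maxlen=capacity)

    def push(self, *args) -> None:
        self.buffer.append(Transition(*args))

    def sample(self, batch_size: int) -> list:
        # Uniform sampling without replacement breaks the temporal
        # correlation between consecutive transitions.
        return random.sample(self.buffer, batch_size)

    def size(self) -> int:
        return len(self.buffer)
```

Sampling uniformly from past transitions, instead of learning only from the latest one, is what decorrelates the training batches.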

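## Changelog
* v2.0
    * Closer to the standard DQN algorithm than v1.1; converges on the CartPole-v1 environment
    * Soft-updates the target network once per episode
    * Updates epsilon each time an action is taken, which is a more sensible schedule
    * Uses the AdamW optimizer (with amsgrad enabled) and gradient clipping
    * Uses model hyperparameters better suited to CartPole-v1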
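
The training details listed for v2.0 can be sketched as follows. This is a minimal sketch and not the code in `./dqn/BaseDQN2.py`: the function names, the constants (`EPS_START`, `EPS_END`, `EPS_DECAY`, `TAU`, the learning rate, the clipping value) and the placeholder loss are all assumptions made for illustration.

```python
import math

import torch
from torch import nn

# Assumed values for illustration only.
EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 1000
TAU = 0.005


def epsilon_at(step: int) -> float:
    """Exponentially decayed epsilon, indexed by the number of actions taken."""
    return EPS_END + (EPS_START - EPS_END) * math.exp(-step / EPS_DECAY)


def soft_update(target_net: nn.Module, policy_net: nn.Module, tau: float = TAU) -> None:
    """In place: target <- (1 - tau) * target + tau * policy."""
    with torch.no_grad():
        for t, p in zip(target_net.parameters(), policy_net.parameters()):
            t.mul_(1.0 - tau).add_(p, alpha=tau)


policy_net = nn.Linear(4, 2)  # stand-in for the Q-network
target_net = nn.Linear(4, 2)
target_net.load_state_dict(policy_net.state_dict())

# AdamW with amsgrad enabled, as listed in the changelog.
optimizer = torch.optim.AdamW(policy_net.parameters(), lr=1e-4, amsgrad=True)

# One optimization step with gradient clipping (placeholder loss).
loss = policy_net(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
optimizer.step()

# Called once at the end of each episode in this sketch.
soft_update(target_net, policy_net)
```

Indexing epsilon by action count means exploration decays at the same rate whether episodes are short or long.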


In [4]:
import gymnasium as gym
from gymnasium.wrappers import RecordEpisodeStatistics, RecordVideo
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm import tqdm

from dqn.BaseDQN2 import *

In [5]:
def train(name: str, comment: str, episode: int = 500, hparam: HyperParam | None = None, is_write: bool = True):
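    '''
    * `name` name of the training run
    * `comment` comment appended to the run name
    * `episode` number of episodes to train for
    * `hparam` hyperparameters (defaults to `HyperParam()`)
    * `is_write` whether to record training data
    '''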
    env = gym.make(
        "CartPole-v1", 
        render_mode = "rgb_array"
    )
    if is_write:
        env = RecordEpisodeStatistics(env, buffer_length = 1)
        env = RecordVideo(
            env, 
            video_folder = "vedio_CartPole_with_BaseDQN", 
            name_prefix = name,
            episode_trigger = lambda x: (x + 1) % 100 == 0
        )

    if hparam is None:
        hparam = HyperParam()
    model = BaseDQN(hparam)

    writer = None
    if is_write:
        writer = SummaryWriter(comment = name + "_" + comment)

    for episode in tqdm(range(episode)):
        state, info = env.reset()
        done = False
        total_loss = 0

        while not done:
            
            # Perform one state transition
            action = model.take_action_single(state)
            next_state, reward, terminated, truncated, info = env.step(action[0])
            done = terminated or truncated

            # Update the model; only accumulate when update() returned a loss rather than False
            transition = make_transition_from_numpy(state, action, next_state, reward, terminated)
            loss = model.update(transition)
            if loss is not False:
                total_loss += loss

            state = next_state

        model.update_episode(episode)

        # Log the average loss and the episode return to TensorBoard
        if writer is not None:
            writer.add_scalar(
                f"{name}/avg_loss",
                total_loss / info["episode"]["l"],
                episode
            )
            writer.add_scalar(
                f"{name}/return",
                info["episode"]["r"],
                episode
            )

        # Log the action tendency (mean action stored in the replay buffer)
        if writer is not None:
            if episode % 50 == 0:
                action_sum = 0
                for i in model.reply_queue.buffer:
                    action_sum += i.action.item()

                writer.add_scalar(
                    f"{name}/avg_action",
                    action_sum / model.reply_queue.size(),
                    episode // 50
                )
    env.close()
    
    if writer is not None:
        writer.close()


In [None]:
train("CartPole-v1", "test", 600, is_write = True)

# Results

## v2.0
![](./res/CartPole_v1_BaseDQN_v2_0.png)

Videos are available in `./res/CartPole_v1_BaseDQN_v2_0.zip`.

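## TODO
* Optimize the code: when computing the TD target, skip the next-state prediction for transitions where done is True
* In vectorized environments, action selection may not work correctly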
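
The first TODO item could look roughly like the sketch below, which computes a Double DQN target and runs the networks only on non-terminal next states. The function `double_dqn_targets`, its parameter names and the default `gamma` are hypothetical; the sketch assumes both networks map a batch of states to per-action Q-values and is not taken from `./dqn/BaseDQN2.py`.

```python
import torch


def double_dqn_targets(policy_net, target_net, rewards, next_states, terminated, gamma=0.99):
    """Return r + gamma * Q_target(s', argmax_a Q_policy(s', a)).

    Terminal transitions keep a bootstrap value of 0, and the networks
    are never evaluated on their next states.
    """
    next_q = torch.zeros_like(rewards)
    alive = ~terminated  # boolean mask of non-terminal transitions
    if alive.any():
        with torch.no_grad():
            alive_states = next_states[alive]
            # Double DQN: the policy network selects the action ...
            best_actions = policy_net(alive_states).argmax(dim=1, keepdim=True)
            # ... and the target network evaluates it.
            next_q[alive] = target_net(alive_states).gather(1, best_actions).squeeze(1)
    return rewards + gamma * next_q
```

The mask uses `terminated` rather than `done`, matching the training loop above, where the stored transition carries `terminated` so that time-limit truncations still bootstrap from the next state.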