# A2C
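## Implementation
* Multi-step TD target
    * A Python list holds the most recent 1 to m Transitions
    * After each new Transition is inserted, the returns of the stored Transitions are updated (the new reward is weighted by $\gamma^{\Delta t}$, where $\Delta t$ is the number of steps since that Transition was stored)
    * When the Episode ends or m Transitions have accumulated, they are packed into one Batch for training
    * When packing, note that every Transition in the Batch shares the same next_state, but the exponent of $\gamma$ in its TD target differs
    * With multi-step TD targets the model performs fewer backward passes, which greatly speeds up training while also making training more stable
* `torch.distributions.Categorical()` samples an action from the probability mass distribution
* `torch.Tensor.detach()` turns a tensor into a gradient-free constant, used when computing the policy gradient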
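
The multi-step bookkeeping described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the code in `src.RL.A2C`: the names `Transition`, `MultiStepBuffer`, `push`, `ready`, and `pop_batch` are hypothetical, states are stand-in integers, and `next_value` stands for the critic's estimate of the shared next_state.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Transition:
    state: int        # stand-in type; a real state would be an array or tensor
    action: int
    ret: float        # discounted sum of the rewards folded in so far
    steps: int = 1    # how many rewards `ret` already contains


class MultiStepBuffer:
    """Hypothetical helper that keeps at most m Transitions in a Python list."""

    def __init__(self, m: int, gamma: float):
        self.m = m
        self.gamma = gamma
        self.items: List[Transition] = []

    def push(self, state: int, action: int, reward: float) -> None:
        # Fold the new reward into every older Transition, weighted by gamma^(age).
        for tr in self.items:
            tr.ret += (self.gamma ** tr.steps) * reward
            tr.steps += 1
        self.items.append(Transition(state, action, reward))

    def ready(self, done: bool) -> bool:
        # Pack a batch when the episode ends or m Transitions have accumulated.
        return done or len(self.items) >= self.m

    def pop_batch(self, next_value: float, done: bool) -> List[Tuple[int, int, float]]:
        # All Transitions bootstrap from the same next_state value,
        # but each one uses its own exponent of gamma.
        bootstrap = 0.0 if done else next_value
        batch = [
            (tr.state, tr.action, tr.ret + (self.gamma ** tr.steps) * bootstrap)
            for tr in self.items
        ]
        self.items = []
        return batch


# Usage: three steps with reward 1.0 each, gamma = 0.5, critic value 8.0.
demo = MultiStepBuffer(m=3, gamma=0.5)
for s, a in [(0, 0), (1, 1), (2, 0)]:
    demo.push(s, a, 1.0)
print(demo.pop_batch(next_value=8.0, done=False))
# prints [(0, 0, 2.75), (1, 1, 3.5), (2, 0, 5.0)]
```

The oldest Transition gets a 3-step target ($1 + 0.5 + 0.25 + 0.5^3 \cdot 8$), the newest a 1-step target ($1 + 0.5 \cdot 8$), which is the "same next_state, different exponent" point from the list. When the episode has ended, the bootstrap term is dropped.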

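## Observations
* Compared with DQN, A2C cannot use experience replay, and its stability turns out to be very poor. Even with hyperparameters tuned by tools such as Optuna, its performance after convergence is still worse than DQN's, although A2C trains slightly faster than DQN
* Introducing multi-step TD targets considerably reduces A2C's instability, though it does not eliminate it. In addition, because m steps of experience are packed into one Batch, fewer backward passes are needed and training speeds up further: compared with DQN, about twice as many Episodes are needed to converge, but the training time to convergence is roughly halved, which leaves room for a large hyperparameter search
* Another trait of A2C: if training stops at a point where the model has converged but its performance has since degraded, the model also performs very poorly in actual testing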


In [1]:
import os
import sys
from ipynb_utility import get_file, set_seed
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(get_file()), '..')))

seed = 114514
set_seed(seed)

import gymnasium as gym

from src.RL.A2C import A2C, A2C_WithMultiStep, HyperParam
from src.RL.utility.train_rl import RL_Teacher

In [2]:
model = A2C(HyperParam(
    lr_critic = 1.6e-3,
    lr_actor = 7.5e-5,
    gamma = 0.99,
    hidden_dim = 128
))

teacher = RL_Teacher(model, "CartPole-v1_A2C", f"seed_{seed}", id = "CartPole-v1", render_mode = "rgb_array")
teacher.train(episode = 800, is_fix_seed = True)
print("CartPole-v1_A2C: ", teacher.test())

100%|██████████| 800/800 [10:08<00:00,  1.31it/s]
100%|██████████| 10/10 [00:11<00:00,  1.11s/it]


CartPole-v1_A2C:  500.0


In [3]:
model = A2C_WithMultiStep(HyperParam(
    lr_critic = 2.2e-3,
    lr_actor = 1.2e-4,
    m = 5
))

teacher = RL_Teacher(model, "CartPole-v1_A2C_WithMultiStep", f"seed_{seed}", id = "CartPole-v1", render_mode = "rgb_array")
teacher.train(episode = 1200, is_fix_seed = True)
print("CartPole-v1_A2C_WithMultiStep: ", teacher.test())

100%|██████████| 1200/1200 [11:40<00:00,  1.71it/s]
100%|██████████| 10/10 [00:16<00:00,  1.69s/it]


CartPole-v1_A2C_WithMultiStep:  500.0


In [4]:
# Optuna hyperparameter search for A2C with m-step TD target

# import logging
# import sys

# import optuna

# # Add stream handler of stdout to show the messages
# optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))

# def target(trial: optuna.Trial):
#     lr_critic = trial.suggest_float("lr_critic", 1e-5, 1e-2, log = True)
#     lr_actor = trial.suggest_float("lr_actor", 1e-5, 1e-2, log = True)

#     model = A2C_WithMultiStep(HyperParam(
#         lr_critic = lr_critic,
#         lr_actor = lr_actor,
#         gamma = 0.99,
#         hidden_dim = 128
#     ))

#     teacher = RL_Teacher(model, "CartPole-v1_A2C_WithMultiStep_param_search", f"lrc_{lr_critic:.2e}_lra_{lr_actor:.2e}", id = "CartPole-v1", render_mode = "rgb_array")
#     avg_return = teacher.train(
#         episode = 1200, is_log = False, 
#         last_episode_return = 300, 
#         is_fix_seed = True
#     )
#     return teacher.test(
#         is_log_vedio = True, 
#         vedio_record_gap = 6
#     ) + avg_return * 0.1

# study = optuna.create_study(
#     direction = "maximize", 
#     study_name = f"CartPole-v1_A2C_WithMultiStep", 
#     storage = "sqlite:///optuna_study/CartPole-v1_A2C_WithMultiStep.db", 
#     load_if_exists = True
# )
# study.optimize(target, 10)
# print(study.best_params)

In [5]:
# Optuna hyperparameter search for A2C

# import logging
# import sys

# import optuna

# # Add stream handler of stdout to show the messages
# optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))

# def target(trial: optuna.Trial):
#     lr_critic = trial.suggest_float("lr_critic", 1e-5, 1e-2, log = True)
#     lr_actor = trial.suggest_float("lr_actor", 1e-5, 1e-2, log = True)

#     model = A2C(HyperParam(
#         lr_critic = lr_critic,
#         lr_actor = lr_actor,
#         gamma = 0.99,
#         hidden_dim = 128
#     ))

#     teacher = RL_Teacher(model, "CartPole-v1_A2C_param_search", f"lrc_{lr_critic:.2e}_lra_{lr_actor:.2e}", id = "CartPole-v1", render_mode = "rgb_array")
#     avg_return = teacher.train(
#         episode = 1200, is_log = False, 
#         last_episode_return = 300, 
#         is_fix_seed = True
#     )
#     return teacher.test(
#         is_log_vedio = True, 
#         vedio_record_gap = 6
#     ) + avg_return * 0.1

# study = optuna.create_study(
#     direction = "maximize", 
#     study_name = f"CartPole-v1_A2C", 
#     storage = "sqlite:///optuna_study/CartPole-v1_A2C.db", 
#     load_if_exists = True
# )
# study.optimize(target, 25)
# print(study.best_params)