# DDPG and TD3 Algorithms
* The DDPG implementation follows <https://hrl.boyuai.com/chapter/2/ddpg%E7%AE%97%E6%B3%95>
* The TD3 implementation follows <https://blog.csdn.net/weixin_45492196/article/details/107866309>

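## Implementation Notes
* Model
    * Using `nn.Tanh()` as the output activation bounds the output to $[-1, 1]$, which makes it easy to map onto a bounded action space
    * `torch.randn()` generates normally distributed noise, which is added to the output action
* Training
    * When updating parameters, use `torch.mean()` to reduce the per-sample values of a batch to a single value (`backward()` without an explicit gradient argument only works on a scalar)
    * For gradient ascent, multiply the quantity being differentiated by -1. When updating parameters, compute the scalar to differentiate and optimize it directly; it does not have to be a loss function in the usual sense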
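
The notes above can be put together in a small sketch. This is not the code in `src.RL.DDPG`; the `Actor` and `Critic` classes, the noise scale `sigma`, and the batch of random states are assumptions made for illustration. The dimensions (3 observations, 1 action, torque in $[-2, 2]$) are those of `Pendulum-v1`.

```python
import torch
from torch import nn

torch.manual_seed(0)


class Actor(nn.Module):
    """Deterministic policy mapping a state to a bounded action (illustrative)."""

    def __init__(self, state_dim, action_dim, hidden_dim, max_action):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),  # bounds the raw output to [-1, 1]
        )
        self.max_action = max_action

    def forward(self, state):
        # Rescale [-1, 1] to the environment's action range.
        return self.net(state) * self.max_action


class Critic(nn.Module):
    """Q(s, a) approximator (illustrative)."""

    def __init__(self, state_dim, action_dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, state, action):
        return self.net(torch.cat([state, action], dim=1))


actor = Actor(state_dim=3, action_dim=1, hidden_dim=64, max_action=2.0)
critic = Critic(state_dim=3, action_dim=1, hidden_dim=64)
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=1e-3)

states = torch.randn(32, 3)  # stand-in for a batch sampled from a replay buffer

# Exploration: add Gaussian noise to the action, then clip to the valid range.
sigma = 0.1  # assumed noise scale
with torch.no_grad():
    action = actor(states)
    noisy_action = (action + sigma * torch.randn(action.shape)).clamp(-2.0, 2.0)

# Actor update: gradient ascent on Q, written as descent on -mean(Q).
# torch.mean() reduces the batch of Q-values to the scalar that backward() needs.
actor_objective = -torch.mean(critic(states, actor(states)))
actor_optimizer.zero_grad()
actor_objective.backward()
actor_optimizer.step()
```

Here `actor_objective` is never compared against a label, so it is not a loss in the supervised sense; it is simply the scalar whose gradient drives the update.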

## Notes on Characteristics


In [1]:
import os
import sys
from ipynb_utility import get_file, set_seed
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(get_file()), '..')))

seed = 0
set_seed(seed)

import gymnasium as gym

from src.RL.DDPG import DDPG, TD3, HyperParam
from src.RL.utility.train_rl import RL_Teacher

In [2]:
env = gym.make("Pendulum-v1", render_mode = "rgb_array")

model = DDPG(HyperParam(
    hidden_dim = 64,
    lr_critic = 5e-3,
    lr_actor = 1e-3
))

teacher = RL_Teacher(model, "Pendulum-v1_DDPG", f"seed_{seed}", id = "Pendulum-v1", render_mode = "rgb_array")
teacher.train(episode = 300)
print("Pendulum-v1_DDPG: ", teacher.test())

100%|██████████| 300/300 [08:40<00:00,  1.73s/it]
100%|██████████| 10/10 [00:07<00:00,  1.31it/s]


Pendulum-v1_DDPG:  -151.20351718065132


# DDPG Results
Test environment: `gymnasium Pendulum-v1`

## Learning Curve

![](../res/Pendulum-v1_DDPG.png)

## Example Video

<video controls src="../res/Pendulum-v1_DDPG.mp4">animation</video>


In [3]:
env = gym.make("Pendulum-v1", render_mode = "rgb_array")

# best
model = TD3(HyperParam(
    hidden_dim = 64,
    lr_critic = 5e-3,
    lr_actor = 2e-3,
    tau = 0.01,
    actor_update_period = 10
))

teacher = RL_Teacher(model, "Pendulum-v1_TD3", f"seed_{seed}", id = "Pendulum-v1", render_mode = "rgb_array")
teacher.train(episode = 300)
print("Pendulum-v1_TD3: ", teacher.test())

100%|██████████| 300/300 [08:31<00:00,  1.71s/it]
100%|██████████| 10/10 [00:06<00:00,  1.46it/s]


Pendulum-v1_TD3:  -129.83041391454074


# TD3 Results
Test environment: `gymnasium Pendulum-v1`
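
The TD3 run uses two parameters that the DDPG run does not set, `tau` and `actor_update_period`, and this notebook does not explain them. In the usual TD3 formulation they would be the soft-update coefficient of the target networks and the number of critic updates per actor update; whether `src.RL.DDPG.TD3` interprets them exactly this way is an assumption. The toy loop below only illustrates that schedule, with plain floats standing in for network weights and hypothetical function names.

```python
def soft_update(target_params, online_params, tau):
    """Polyak averaging: target <- tau * online + (1 - tau) * target."""
    return [tau * w + (1.0 - tau) * t for t, w in zip(target_params, online_params)]


def train(total_steps, tau, actor_update_period):
    """Toy loop that counts how often each part would be updated."""
    # Plain floats stand in for network weights (illustration only).
    online = [1.0, -2.0]
    target = [0.0, 0.0]
    critic_updates = 0
    actor_updates = 0
    for step in range(1, total_steps + 1):
        critic_updates += 1  # the critics are updated on every step
        if step % actor_update_period == 0:
            actor_updates += 1  # delayed actor update
            # Target weights drift slowly toward the online weights.
            target = soft_update(target, online, tau)
    return critic_updates, actor_updates, target


critic_updates, actor_updates, target = train(
    total_steps=100, tau=0.01, actor_update_period=10
)
print(critic_updates, actor_updates)  # 100 10
```

With `tau = 0.01` the target weights move only 1% of the way toward the online weights per soft update, so the targets used in the critic's TD error change slowly.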

## Learning Curve

![](../res/Pendulum-v1_TD3.png)

## Example Video

<video controls src="../res/Pendulum-v1_TD3.mp4">animation</video>


In [2]:
# Optuna hyperparameter search for DDPG

# import logging
# import sys

# import optuna

# # Add stream handler of stdout to show the messages
# optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))

# def target(trial: optuna.Trial):
#     lr_critic = trial.suggest_float("lr_critic", 5e-4, 5e-2, log = True)
#     lr_actor = trial.suggest_float("lr_actor", 1e-4, 1e-2, log = True)

#     model = DDPG(HyperParam(
#         lr_critic = lr_critic,
#         lr_actor = lr_actor,
#         gamma = 0.99,
#     ))

#     teacher = RL_Teacher(model, "Pendulum-v1_DDPG_param_search", f"lrc_{lr_critic:.2e}_lra_{lr_actor:.2e}", id = "Pendulum-v1", render_mode = "rgb_array")
#     avg_return = teacher.train(
#         episode = 300, is_log = False, 
#         last_episode_return = 100, 
#         is_fix_seed = True
#     )
#     return teacher.test(
#         is_log_vedio = True, 
#         vedio_record_gap = 6
#     ) + avg_return * 0.1

# study = optuna.create_study(
#     direction = "maximize", 
#     study_name = f"Pendulum-v1_DDPG", 
#     storage = "sqlite:///optuna_study/RL.db", 
#     load_if_exists = True
# )
# study.optimize(target, 10)
# print(study.best_params)

[I 2024-10-21 14:30:07,614] A new study created in RDB with name: Pendulum-v1_DDPG




100%|██████████| 300/300 [09:48<00:00,  1.96s/it]
100%|██████████| 10/10 [00:03<00:00,  3.16it/s]
[I 2024-10-21 14:40:01,631] Trial 0 finished with value: -892.3437244091112 and parameters: {'lr_critic': 0.0005646471848238721, 'lr_actor': 0.000981886646040247}. Best is trial 0 with value: -892.3437244091112.




100%|██████████| 300/300 [06:56<00:00,  1.39s/it]
100%|██████████| 10/10 [00:02<00:00,  3.41it/s]
[I 2024-10-21 14:47:01,684] Trial 1 finished with value: -691.056564706235 and parameters: {'lr_critic': 0.027559678541426973, 'lr_actor': 0.005273439249997273}. Best is trial 1 with value: -691.056564706235.




100%|██████████| 300/300 [08:39<00:00,  1.73s/it]
100%|██████████| 10/10 [00:02<00:00,  3.79it/s]
[I 2024-10-21 14:55:44,539] Trial 2 finished with value: -476.5666909139637 and parameters: {'lr_critic': 0.004389860221476969, 'lr_actor': 0.007758544640974687}. Best is trial 2 with value: -476.5666909139637.




100%|██████████| 300/300 [08:33<00:00,  1.71s/it]
100%|██████████| 10/10 [00:02<00:00,  3.73it/s]
[I 2024-10-21 15:04:20,803] Trial 3 finished with value: -1106.1020237719738 and parameters: {'lr_critic': 0.029325503880990752, 'lr_actor': 0.0002306547983913713}. Best is trial 2 with value: -476.5666909139637.




100%|██████████| 300/300 [08:38<00:00,  1.73s/it]
100%|██████████| 10/10 [00:02<00:00,  3.68it/s]
[I 2024-10-21 15:13:01,820] Trial 4 finished with value: -706.1215239515487 and parameters: {'lr_critic': 0.0009719841948298609, 'lr_actor': 0.00120108307219138}. Best is trial 2 with value: -476.5666909139637.




100%|██████████| 300/300 [08:35<00:00,  1.72s/it]
100%|██████████| 10/10 [00:02<00:00,  3.75it/s]
[I 2024-10-21 15:21:39,967] Trial 5 finished with value: -199.29209354795583 and parameters: {'lr_critic': 0.00249714679845251, 'lr_actor': 0.0008659353103577053}. Best is trial 5 with value: -199.29209354795583.




100%|██████████| 300/300 [08:37<00:00,  1.73s/it]
100%|██████████| 10/10 [00:02<00:00,  3.64it/s]
[I 2024-10-21 15:30:20,590] Trial 6 finished with value: -537.0282210996745 and parameters: {'lr_critic': 0.03673968666591285, 'lr_actor': 0.0008311868719512627}. Best is trial 5 with value: -199.29209354795583.




100%|██████████| 300/300 [08:32<00:00,  1.71s/it]
100%|██████████| 10/10 [00:02<00:00,  3.69it/s]
[I 2024-10-21 15:38:55,568] Trial 7 finished with value: -503.12321230499515 and parameters: {'lr_critic': 0.018058785870118606, 'lr_actor': 0.0013779613967998122}. Best is trial 5 with value: -199.29209354795583.




100%|██████████| 300/300 [08:30<00:00,  1.70s/it]
100%|██████████| 10/10 [00:02<00:00,  3.80it/s]
[I 2024-10-21 15:47:29,204] Trial 8 finished with value: -1317.7013675657226 and parameters: {'lr_critic': 0.0008998671627604154, 'lr_actor': 0.0021481401823776917}. Best is trial 5 with value: -199.29209354795583.




100%|██████████| 300/300 [08:25<00:00,  1.69s/it]
100%|██████████| 10/10 [00:02<00:00,  3.83it/s]
[I 2024-10-21 15:55:57,771] Trial 9 finished with value: -1645.1630078745793 and parameters: {'lr_critic': 0.001771324799077324, 'lr_actor': 0.0005877350666399808}. Best is trial 5 with value: -199.29209354795583.


{'lr_critic': 0.00249714679845251, 'lr_actor': 0.0008659353103577053}


In [3]:
# Optuna hyperparameter search for TD3

# import logging
# import sys

# import optuna

# # Add stream handler of stdout to show the messages
# optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))

# def target(trial: optuna.Trial):
#     lr_critic = trial.suggest_float("lr_critic", 5e-4, 5e-2, log = True)
#     lr_actor = trial.suggest_float("lr_actor", 2e-4, 2e-2, log = True)

#     model = TD3(HyperParam(
#         lr_critic = lr_critic,
#         lr_actor = lr_actor,
#         gamma = 0.99,
#     ))

#     teacher = RL_Teacher(model, "Pendulum-v1_TD3_param_search", f"lrc_{lr_critic:.2e}_lra_{lr_actor:.2e}", id = "Pendulum-v1", render_mode = "rgb_array")
#     avg_return = teacher.train(
#         episode = 300, is_log = False, 
#         last_episode_return = 100, 
#         is_fix_seed = True
#     )
#     return teacher.test(
#         is_log_vedio = True, 
#         vedio_record_gap = 6
#     ) + avg_return * 0.1

# study = optuna.create_study(
#     direction = "maximize", 
#     study_name = f"Pendulum-v1_TD3", 
#     storage = "sqlite:///optuna_study/RL.db", 
#     load_if_exists = True
# )
# study.optimize(target, 10)
# print(study.best_params)

[I 2024-10-21 15:55:57,882] A new study created in RDB with name: Pendulum-v1_TD3




100%|██████████| 300/300 [08:06<00:00,  1.62s/it]
100%|██████████| 10/10 [00:02<00:00,  3.73it/s]
[I 2024-10-21 16:04:07,293] Trial 0 finished with value: -1066.2809712154926 and parameters: {'lr_critic': 0.0005822701660399452, 'lr_actor': 0.015556272450804275}. Best is trial 0 with value: -1066.2809712154926.




100%|██████████| 300/300 [08:08<00:00,  1.63s/it]
100%|██████████| 10/10 [00:02<00:00,  3.75it/s]
[I 2024-10-21 16:12:18,251] Trial 1 finished with value: -1521.5917985177452 and parameters: {'lr_critic': 0.011657526306295074, 'lr_actor': 0.011536861112274562}. Best is trial 0 with value: -1066.2809712154926.




100%|██████████| 300/300 [08:11<00:00,  1.64s/it]
100%|██████████| 10/10 [00:02<00:00,  3.81it/s]
[I 2024-10-21 16:20:32,852] Trial 2 finished with value: -564.0634798111196 and parameters: {'lr_critic': 0.008847622045835033, 'lr_actor': 0.01415469664351834}. Best is trial 2 with value: -564.0634798111196.




100%|██████████| 300/300 [08:08<00:00,  1.63s/it]
100%|██████████| 10/10 [00:02<00:00,  3.71it/s]
[I 2024-10-21 16:28:44,079] Trial 3 finished with value: -871.1477523014414 and parameters: {'lr_critic': 0.03143411713036953, 'lr_actor': 0.001009148143980186}. Best is trial 2 with value: -564.0634798111196.




100%|██████████| 300/300 [08:12<00:00,  1.64s/it]
100%|██████████| 10/10 [00:02<00:00,  3.76it/s]
[I 2024-10-21 16:36:59,276] Trial 4 finished with value: -572.1537357027903 and parameters: {'lr_critic': 0.0018449803348097407, 'lr_actor': 0.00609952170837964}. Best is trial 2 with value: -564.0634798111196.




100%|██████████| 300/300 [08:09<00:00,  1.63s/it]
100%|██████████| 10/10 [00:02<00:00,  3.72it/s]
[I 2024-10-21 16:45:11,903] Trial 5 finished with value: -945.3197452317918 and parameters: {'lr_critic': 0.0017911556288311705, 'lr_actor': 0.0054727245133124585}. Best is trial 2 with value: -564.0634798111196.




100%|██████████| 300/300 [08:13<00:00,  1.65s/it]
100%|██████████| 10/10 [00:02<00:00,  3.81it/s]
[I 2024-10-21 16:53:28,366] Trial 6 finished with value: -199.03884512803745 and parameters: {'lr_critic': 0.020209824367021293, 'lr_actor': 0.0018334721290907891}. Best is trial 6 with value: -199.03884512803745.




100%|██████████| 300/300 [08:07<00:00,  1.63s/it]
100%|██████████| 10/10 [00:02<00:00,  3.81it/s]
[I 2024-10-21 17:01:38,779] Trial 7 finished with value: -589.2729938638469 and parameters: {'lr_critic': 0.016821087821542326, 'lr_actor': 0.01858721228385982}. Best is trial 6 with value: -199.03884512803745.




100%|██████████| 300/300 [08:18<00:00,  1.66s/it]
100%|██████████| 10/10 [00:02<00:00,  3.79it/s]
[I 2024-10-21 17:10:00,391] Trial 8 finished with value: -915.5141203818089 and parameters: {'lr_critic': 0.0007326643593495767, 'lr_actor': 0.017578295606188763}. Best is trial 6 with value: -199.03884512803745.




100%|██████████| 300/300 [08:08<00:00,  1.63s/it]
100%|██████████| 10/10 [00:02<00:00,  3.80it/s]
[I 2024-10-21 17:18:11,672] Trial 9 finished with value: -1352.013637893145 and parameters: {'lr_critic': 0.0010393904990054388, 'lr_actor': 0.0011558949626366295}. Best is trial 6 with value: -199.03884512803745.


{'lr_critic': 0.020209824367021293, 'lr_actor': 0.0018334721290907891}
