本章展示了天授与一些著名的强化学习算法平台对比评测的结果。
我们在众多强化学习平台中选取了5个具有代表性的平台:RLlib rllib
、OpenAI Baselines baselines
、PyTorch-DRL pytorch-drl
、Stable-Baselines stable-baselines
(以下简写为SB)、rlpyt rlpyt
,选取评测的平台列表和原因如 表 4.1 所示。
平台名称 | 评测ID | 选取原因 |
---|---|---|
RLlib rllib |
0.8.5 |
GitHub深度强化学习平台星标数目最多,算法实现全面 |
Baselines baselines |
ea25b9e |
GitHub深度强化学习平台星标数目第二多 |
PyTorch-DRL pytorch-drl |
49b5ec0 |
GitHub PyTorch深度强化学习平台星标数目第一多,算法实现全面 |
Stable-Baselines stable-baselines |
2.10.0 |
Baselines的改进版本,算法实现全面,提供了一系列调优的参数 |
rlpyt rlpyt |
668290d |
GitHub PyTorch深度强化学习平台星标数目第四多,算法实现全面 |
天授 | 57bca16 |
- |
接下来将从算法支持、模块化、定制化、单元测试和文档教程这五个维度来对包括天授在内的6个深度强化学习平台进行功能维度上的对比评测。
深度强化学习算法主要分为免模型强化学习(MFRL)、基于模型的强化学习(MBRL)、多智能体学习(MARL)、模仿学习(IL)等。此外根据所解决的问题分类,还可分为马尔科夫决策过程(MDP)和部分可观测马尔科夫决策过程(POMDP),其中POMDP要求策略网络支持循环神经网络(RNN)的训练。如今研究者们使用最多的是免模型强化学习算法,以下会对其详细对比。
免模型强化学习算法主要分为基于策略梯度的算法、基于价值函数的算法和二者结合的算法。现在深度强化学习社区公认的强化学习经典算法有:(1)基于价值函数:DQN dqn
及其改进版本Double-DQN double-dqn
(DDQN)、DQN优先级经验重放 per
(PDQN);(2)基于策略梯度:PG pg
、A2C a2c
、PPO ppo
;(3)二者结合:DDPG ddpg
、TD3 td3
、SAC sac
。
各个平台实现算法的程度如 表 4.2 所示。可以看出,大部分平台支持的算法种类是全面的,有些平台如Baselines支持的算法类型并不全面。天授平台支持所有的这些算法。
表 4.2:各平台支持的免模型深度强化学习算法一览平台与算法 | DQN | DDQN | PDQN | PG | A2C | PPO | DDPG | TD3 | SAC | 总计 |
---|---|---|---|---|---|---|---|---|---|---|
RLlib | 9 | |||||||||
Baselines | × | × | × | × | 5 | |||||
PyTorch-DRL | 9 | |||||||||
SB | × | 8 | ||||||||
rlpyt | 9 | |||||||||
天授 | 9 |
其他类型的强化学习算法包括基于模型的强化学习(MBRL)、多智能体强化学习(MARL)、元强化学习(MetaRL)、模仿学习(IL)。表 4.3 列出了各个平台的支持情况。可以看出,少有平台支持所有这些类型的算法。天授支持了模仿学习,但值得一提的是,基于模型的强化学习算法和多智能体强化学习算法都可以在现有的平台接口上完整实现。我们正在努力实现天授平台的MBRL和MARL算法中。
表 4.3:各平台支持的其他类型强化学习算法一览平台与算法类型 | MBRL | MARL | MetaRL | IL |
---|---|---|---|---|
RLlib | × | × | ||
Baselines | × | × | × | |
PyTorch-DRL | × | × | × | × |
SB | × | × | × | |
rlpyt | × | × | × | × |
天授 | × | × | × |
针对不完全信息观测的马尔科夫决策过程(POMDP),通常有两种处理方式:第一种是直接当作完全信息模式处理,但可能会导致一些诸如收敛性难以保证的问题;第二种是在智能体中维护一个内部状态,具体而言,将循环神经网络模型(RNN)融合到策略网络中。表 4.4 列出了各个平台对循环神经网络的支持程度。从 表 4.4 中可以看出,部分平台对RNN的支持程度并不大。天授平台中所有算法均支持RNN网络,还支持获取历史状态、历史动作和历史奖励,以及其他用户或者环境定义的变量的历史记录。
表 4.4:各平台对RNN的支持平台 | RNN |
---|---|
RLlib | |
Baselines | × |
PyTorch-DRL | × |
SB | × |
rlpyt | |
天授 |
最初的强化学习算法仅是单个智能体和单个环境进行交互,这样的话采样效率较低,因为每一次网络前向都只能以单个样本进行计算,无法充分利用批处理加速的优势,从而导致了强化学习即使在简单场景中训练速度仍然较慢的问题。解决的方案是并行环境采样:智能体每次与若干个环境同时进行交互,将神经网络的前向数据量加大但又不增加推理时间,从而做到采样速率是之前的数倍。表 4.5 显示了各个平台、各个算法支持的并行环境采样的情况。从 表 4.5 中可以看出,只有RLlib、rlpyt、天授全面地支持了各种算法的并行环境采样功能,剩下的平台要么缺失部分算法实现、要么缺失部分算法的并行环境采样功能。这对于强化学习智能体的训练而言,性能方面可能会大打折扣。
表 4.5:各平台各免模型深度强化学习算法支持并行环境采样情况一览平台与算法 | DQN | DDQN | PDQN | PG | A2C | PPO | DDPG | TD3 | SAC |
---|---|---|---|---|---|---|---|---|---|
RLlib | |||||||||
Baselines | × | - | × | - | - | - | |||
PyTorch-DRL | × | × | × | × | × | × | × | ||
SB | × | × | - | × | × | ||||
rlpyt | |||||||||
天授 |
注:“-”表示算法未实现
模块化的强化学习算法框架能够让开发者以更少的代码量来更简单地实现新功能,在增加了代码的可重用性的同时也减少了出错的可能性。表 4.6 列出了各个平台模块化的详细情况。从表中可以看出,除Baselines、Stable-Baselines和PyTorch-DRL三个框架外,其余的平台都做到了模块化。天授平台并没有在训练策略上做完全的模块化,因为在训练策略模块化虽然会节省代码,但是会使得用户难以二次修改代码进行开发。天授为了能够让开发者有更好的体验,在这二者之中做了折中:提供了一个定制化的训练策略函数,但不是必须的。用户可以利用天授提供的接口,和正常写强化学习代码一样,自由地编写所需训练策略。
表 4.6:各平台模块化功能实现一览,其中:(1)算法实现模块化,指实现强化学习算法的时候遵循一套统一的接口;(2)数据处理模块化,指将内部数据流进行封装存储;(3)训练策略模块化,指由专门的类或函数来处理如何训练强化学习智能体。平台与模块化 | 算法实现 | 数据处理 | 训练策略 |
---|---|---|---|
RLlib | |||
Baselines | × | × | × |
PyTorch-DRL | 部分模块化 | × | |
SB | × | × | × |
rlpyt | |||
天授 | 部分模块化 |
强化学习平台除了具有作为社区研究者中复现其他算法结果的作用之外,还承担在新场景、新任务上的运用和新算法的开发的作用,此时一个平台是否具有清晰简洁的代码结构、是否支持二次开发、是否能够方便地运用于新的任务(比如多模态环境)上,就成为一个衡量平台易用性的一个标准。 表 4.7 总结了各平台在代码复杂度与是否可定制化训练环境两个维度的测试结果,其中前者采用开源工具cloc1 进行代码统计,除去了测试代码和示例代码;后者采用Mujoco环境中 FetchReach-v1 任务进行模拟测试,其观测状态为一个字典,包含三个元素。此处使用这个任务来模拟对定制化多模态环境的测试,凡是报异常错误或者直接使用装饰器 gym.wrappers.FlattenObservation()
对观测值进行数据扁平化处理的平台,都不被认为对定制化训练环境做到了很好的支持。可以看出,天授在易用性的这两个评价层面上相比其他平台都具有十分明显的优势,使用精简的代码却能够支持更多需求。
平台与易用性 | 代码复杂度 | 环境定制化 | 文档 | 教程 |
---|---|---|---|---|
RLlib | 250/24065 | × | ||
Baselines | 110/10499 | × | × | × |
PyTorch-DRL | 55/4366 | × | × | × |
SB | 100/10989 | × | ||
rlpyt | 243/14487 | × | × | |
天授 | 29/2141 |
文档与教程对于平台的易用性而言具有十分重要的意义。表 4.7 列举出了各个平台的API接口文档与教程的情况。尽管天授的文档与Stable-Baselines相比还有待提高,但相比其它平台而言仍然提供了丰富的教程,供使用者使用。
单元测试对强化学习平台有着十分重要的作用:它在本身就难以训练的强化学习算法上加上了一个保险栓,进行代码正确性检查,避免了一些低级的错误发生,同时还保证了一些基础算法的可复现性。表 4.8 从代码风格测试、基本功能测试、训练过程测试、代码覆盖率这些维度展示了各个平台所拥有的单元测试。大部分平台满足代码风格测试和基本功能测试要求,只有约一半的平台有对完整训练过程进行测试(此处指从智能体的神经网络随机初始化至智能体完全解决问题),以及显示代码覆盖率。综合来看,天授平台是其中表现最好的。
表 4.8:各平台单元测试情况一览平台与单元测试 | PEP8代码风格 | 基本功能 | 训练过程 | 代码覆盖率 |
---|---|---|---|---|
RLlib | 部分 | 暂缺* | ||
Baselines | 部分 | 53% * * | ||
PyTorch-DRL | 不遵循+无测试 | 完整 | 62% * * | |
SB | 部分 | 85% | ||
rlpyt | × | 部分 | 部分 | 22% |
天授 | 完整 | 85% |
**:手动在其单元测试脚本中添加代码覆盖率开启选项,并在 Travis CI 第三方测试平台中获取测试结果。
本章节将各个强化学习平台在OpenAI Gym gym
简单环境中进行性能测试。实验运行环境配置参数如 表 4.9 所示。所有运行实验耗时取纯CPU和CPU+GPU混合使用的这两种运行状态模式下的时间的最优值。为减小测试结果误差,每组实验将会以不同的随机种子运行5次。
类型 | 参数 |
---|---|
操作系统 | Ubuntu 18.04 |
内核 | 5.3.0-53-generic |
CPU | Intel i7-8750H (12) @ 4.100GHz |
GPU | NVIDIA GeForce GTX 1060 Mobile |
RAM | 31.1 GiB DDR4 |
Disk | SAMSUNG MZVLB512HAJQ-000L2 SSD |
NVIDIA驱动版本 | 440.64.00 |
CUDA版本 | 10.0 |
Python版本 | 3.6.9 |
TensorFlow版本 | 1.14.0 |
PyTorch版本 | 1.4.0(PyTorch-DRL)或 1.5.0 |
离散动作空间的一系列强化学习任务中,最简单的任务是OpenAI Gym环境中的CartPole-v0任务:该任务要求智能体操纵小车,使得小车上的倒立摆能够保持垂直状态,一旦偏离超过一定角度、或者小车位置超出规定范围,则认为游戏结束。该任务观测空间为一个四维向量,动作空间取值为0或1,表示在这个时间节点内将小车向左或是向右移动。图 4.1 对该任务进行了可视化展示。
图 4.1:CartPole-v0任务可视化该任务选取PG pg
、DQN dqn
、A2C a2c
、PPO ppo
四种经典的免模型强化学习算法进行评测。根据Gym中说明的规则,每个算法必须在连续100次任务中,总奖励值取平均之后大于等于195才算解决了这个任务。各个平台不同算法解决任务的测试结果如 表 4.10 所示,原始数据见 附表 1
。天授与其他平台相比,有着令人惊艳的性能,尤其是PG、DQN和A2C算法,能够在平均不到10秒的时间内解决该问题。
平台与算法 | PG | DQN | A2C | PPO |
---|---|---|---|---|
RLlib | 19.26 ± 2.29 | 28.56 ± 4.60 | 57.92 ± 9.94 | 44.60 ± 17.04 |
Baselines | - | × | × | × |
PyTorch-DRL * | × | 31.58 ± 11.30 | × | 23.99 ± 9.26 |
SB | - | 93.47 ± 58.05 | 57.56 ± 12.87 | 34.79 ± 17.02 |
rlpyt | * * | * * | * * | * * |
天授 | 6.09 ± 4.60 | 6.09 ± 0.87 | 10.59 ± 2.04 | 31.82 ± 7.76 |
*:由于PyTorch-DRL中并未实现专门的评测函数,因此适当放宽条件为“训练过程中连续20次完整游戏的平均总奖励大于等于195”;
**:rlpyt对于离散动作空间非Atari任务的支持不友好,可参考 astooke/rlpyt#135 。
连续动作空间的一系列强化学习任务中,最简单的任务是OpenAI Gym环境中的Pendulum-v0任务:该任务要求智能体操控倒立摆,使其尽量保持直立,奖励值最大对应着与目标保持垂直,并且旋转速度和扭矩均为最小的状态。该任务观测空间为一个三维向量,动作空间为一个二维向量,范围为 [ − 2, 2]。图 4.2 对该任务进行了可视化展示。
图 4.2:Pendulum-v0任务可视化该任务选取PPO ppo
、DDPG ddpg
、TD3 td3
、SAC sac
四种经典的免模型强化学习算法进行评测。和上一小节中的评测方法类似,每个算法必须在连续100次任务中,总奖励值取平均值后大于等于-250才算解决该任务。各个平台不同算法解决任务的测试结果如 表 4.11 所示,原始数据见 附表 2
。与之前结果类似,天授平台在各个算法中的测试都取得了不错的成绩。
平台与算法 | PPO | DDPG | TD3 | SAC |
---|---|---|---|---|
RLlib | 123.62 ± 44.23 | 314.70 ± 7.92 | 149.90 ± 7.54 | 97.42 ± 4.75 |
Baselines | 745.43 ± 160.82 | × | - | - |
PyTorch-DRL * | * * | 59.05 ± 10.03 | 57.52 ± 17.71 | 63.80 ± 27.37 |
SB | 259.73 ± 27.37 | 277.52 ± 92.67 | 99.75 ± 21.63 | 124.85 ± 79.14 |
rlpyt | * * * | 123.57 ± 30.76 | 113.00 ± 13.31 | 132.80 ± 21.74 |
天授 | 16.18 ± 2.49 | 37.26 ± 9.55 | 44.04 ± 6.37 | 36.02 ± 0.77 |
*:由于PyTorch-DRL中并未实现专门的评测函数,因此适当放宽条件为“训练过程中连续20次完整游戏的平均总奖励大于等于-250”;
**:PyTorch-DRL中的PPO算法在连续动作空间任务中会报异常错误;
***:rlpyt并未提供使用PPO算法的任何示例代码,经尝试无法成功跑通。
本章节将天授平台与比较流行的5个深度强化学习平台进行功能维度和性能维度的对比。实验结果表明天授与其他平台相比,具有模块化、实现简洁、代码质量可靠、用户易用、速度快等等优点。
+-----------+--------+---------+----------+---------+---------+----------+----------+----------+ | 平台 | 算法 | 1 | 2 | 3 | 4 | 5 | 平均值 | 标准差 | +===========+========+=========+==========+=========+=========+==========+==========+==========+ | | PG | 19.43 | 17.62 | 18.27 | 17.38 | 23.61 | 19.26 | 2.29 | + +--------+---------+----------+---------+---------+----------+----------+----------+ | | DQN | 36.21 | 27.79 | 25.82 | 30.42 | 22.57 | 28.56 | 4.60 | + RLlib +--------+---------+----------+---------+---------+----------+----------+----------+ | | A2C | 42.84 | 55.18 | 63.22 | 55.45 | 72.91 | 57.92 | 9.94 | + +--------+---------+----------+---------+---------+----------+----------+----------+ | | PPO | 27.18 | 29.08 | 54.44 | 39.65 | 72.64 | 44.60 | 17.04 | +-----------+--------+---------+----------+---------+---------+----------+----------+----------+ | | PG | 未实现 | + +--------+---------+----------+---------+---------+----------+----------+----------+ | | DQN | 超过1000秒未完成任务,但在1000秒之后完成 | + Baselines +--------+---------+----------+---------+---------+----------+----------+----------+ | | A2C | | + +--------+ 超过1000秒未完成任务,且不能收敛 + | | PPO | | +-----------+--------+---------+----------+---------+---------+----------+----------+----------+ | | PG | 超过1000秒未完成任务,且不能收敛 | + +--------+---------+----------+---------+---------+----------+----------+----------+ | | DQN | 24.21 | 53.96 | 24.42 | 28.17 | 27.12 | 31.58 | 11.30 | +PyTorch-DRL+--------+---------+----------+---------+---------+----------+----------+----------+ A2C | 超过1000秒未完成任务,且不能收敛 | + +--------+---------+----------+---------+---------+----------+----------+----------+ | | PPO | 9.30 | 21.11 | 22.26 | 30.91 | 36.39 | 23.99 | 9.26 | +-----------+--------+---------+----------+---------+---------+----------+----------+----------+ | | PG | 未实现 | + +--------+---------+----------+---------+---------+----------+----------+----------+ DQN | 45.84 | 108.08 | 51.31 | 59.56 | 202.58 | 93.47 | 58.05 | +Baselines +--------+---------+----------+---------+---------+----------+----------+----------+ | | A2C | 81.00 | 44.06 | 56.70 | 47.81 | 58.23 | 57.56 | 12.87 | + +--------+---------+----------+---------+---------+----------+----------+----------+ | | PPO | 20.64 | 53.35 | 21.50 | 57.78 | 20.67 | 34.79 | 17.02 | +-----------+--------+---------+----------+---------+---------+----------+----------+----------+ | | PG | | + +--------+ + | | DQN | rlpyt对于离散动作空间非Atari任务的支持不友好, | + rlpyt +--------+ 可参考 astooke/rlpyt#135 + | | A2C | | + +--------+ + | | PPO | | +-----------+--------+---------+----------+---------+---------+----------+----------+----------+ | | PG | 1.65 | 4.98 | 14.79 | 6.01 | 3.03 | 6.09 | 4.60 | + +--------+---------+----------+---------+---------+----------+----------+----------+ | | DQN | 5.14 | 6.32 | 7.62 | 5.41 | 5.97 | 6.09 | 0.87 | + 天授 +--------+---------+----------+---------+---------+----------+----------+----------+ | | A2C | 9.54 | 12.06 | 8.17 | 9.40 | 13.80 | 10.59 | 2.04 | + +--------+---------+----------+---------+---------+----------+----------+----------+ | | PPO | 30.12 | 25.21 | 43.53 | 22.63 | 37.59 | 31.82 | 7.76 | +-----------+--------+---------+----------+---------+---------+----------+----------+----------+
附表 1:CartPole-v0实验原始数据
*:由于PyTorch-DRL中并未实现专门的评测函数,因此适当放宽条件为“训练过程中连续20次完整游戏的平均总奖励大于等于195”。
+-----------+--------+----------+----------+----------+----------+----------+----------+----------+ | 平台 | 算法 | 1 | 2 | 3 | 4 | 5 | 平均值 | 标准差 | +===========+========+==========+==========+==========+==========+==========+==========+==========+ | | PPO | 126.91 | 105.82 | 131.34 | 195.46 | 58.56 | 123.62 | 44.23 | + +--------+----------+----------+----------+----------+----------+----------+----------+ | | DDPG | 312.93 | 329.85 | 307.26 | 313.70 | 309.75 | 314.70 | 7.92 | + RLlib +--------+----------+----------+----------+----------+----------+----------+----------+ | | TD3 | 139.18 | 158.29 | 144.52 | 158.24 | 149.29 | 149.90 | 7.54 | + +--------+----------+----------+----------+----------+----------+----------+----------+ | | SAC | 102.93 | 95.21 | 89.95 | 102.04 | 96.98 | 97.42 | 4.75 | +-----------+--------+----------+----------+----------+----------+----------+----------+----------+ | | PPO | 804.92 | 832.88 | 444.79 | 733.01 | 911.53 | 745.43 | 160.82 | + +--------+----------+----------+----------+----------+----------+----------+----------+ | | DDPG | 超过1000秒未完成任务,且不能收敛 | + Baselines +--------+----------+----------+----------+----------+----------+----------+----------+ | | TD3 | | + +--------+ 未实现 + | | SAC | | +-----------+--------+----------+----------+----------+----------+----------+----------+----------+ | | PPO | PyTorch-DRL中的PPO算法在连续动作空间任务中会报异常错误 | + +--------+----------+----------+----------+----------+----------+----------+----------+ | | DDPG | 42.50 | 56.21 | 69.02 | 57.53 | 69.99 | 59.05 | 10.03 | +PyTorch-DRL+--------+----------+----------+----------+----------+----------+----------+----------+ TD3 | 43.97 | 46.44 | 46.06 | 91.04 | 60.10 | 57.52 | 17.71 | + +--------+----------+----------+----------+----------+----------+----------+----------+ | | SAC | 113.88 | 37.82 | 40.08 | 64.38 | 62.84 | 63.80 | 27.37 | +-----------+--------+----------+----------+----------+----------+----------+----------+----------+ | | PPO | 206.71 | 284.84 | 271.73 | 271.81 | 263.58 | 259.73 | 27.37 | + +--------+----------+----------+----------+----------+----------+----------+----------+ DDPG | 206.58 | 384.53 | 135.68 | 140.45 | 270.36 | 277.52 | 92.67 | +Baselines +--------+----------+----------+----------+----------+----------+----------+----------+ | | TD3 | 86.22 | 142.88 | 91.53 | 88.77 | 89.34 | 99.75 | 21.63 | + +--------+----------+----------+----------+----------+----------+----------+----------+ | | SAC | 251.22 | 123.47 | 165.39 | 42.07 | 42.10 | 124.85 | 79.14 | +-----------+--------+----------+----------+----------+----------+----------+----------+----------+ | | PPO | rlpyt并未提供使用PPO的任何示例代码,经尝试无法成功跑通 | + +--------+----------+----------+----------+----------+----------+----------+----------+ | | DDPG | 180.56 | 130.14 | 105.95 | 106.69 | 94.51 | 123.57 | 30.76 | + rlpyt +--------+----------+----------+----------+----------+----------+----------+----------+ | | TD3 | 106.37 | 98.42 | 136.02 | 119.05 | 105.12 | 113.00 | 13.31 | + +--------+----------+----------+----------+----------+----------+----------+----------+ | | SAC | 122.58 | 169.20 | 104.50 | 141.96 | 125.77 | 132.80 | 21.74 | +-----------+--------+----------+----------+----------+----------+----------+----------+----------+ | | PPO | 17.64 | 14.97 | 20.29 | 13.28 | 14.70 | 16.18 | 2.49 | + +--------+----------+----------+----------+----------+----------+----------+----------+ | | DDPG | 24.34 | 51.15 | 30.25 | 36.46 | 44.09 | 37.26 | 9.55 | + 天授 +--------+----------+----------+----------+----------+----------+----------+----------+ | | TD3 | 38.22 | 52.67 | 42.15 | 50.32 | 36.85 | 44.04 | 6.37 | + +--------+----------+----------+----------+----------+----------+----------+----------+ | | SAC | 35.56 | 35.08 | 35.61 | 36.83 | 37.04 | 36.02 | 0.77 | +-----------+--------+----------+----------+----------+----------+----------+----------+----------+
附表 2:Pendulum-v0实验原始数据
*:由于PyTorch-DRL中并未实现专门的评测函数,因此适当放宽条件为“训练过程中连续20次完整游戏的平均总奖励大于等于-250”。
GitHub地址: https://github.com/AlDanial/cloc↩