diff --git a/README.md b/README.md index 21e8cdbc7..833a8562a 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ - [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) - [Double DQN](https://arxiv.org/pdf/1509.06461.pdf) - [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf) +- [C51](https://arxiv.org/pdf/1707.06887.pdf) - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/) - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf) - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf) diff --git a/docs/index.rst b/docs/index.rst index b7e65998c..3b1fe023c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -13,6 +13,7 @@ Welcome to Tianshou! * :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network `_ * :class:`~tianshou.policy.DQNPolicy` `Double DQN `_ * :class:`~tianshou.policy.DQNPolicy` `Dueling DQN `_ +* :class:`~tianshou.policy.C51Policy` `C51 `_ * :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic `_ * :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient `_ * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization `_ diff --git a/examples/atari/README.md b/examples/atari/README.md index 474f74c42..7fd034461 100644 --- a/examples/atari/README.md +++ b/examples/atari/README.md @@ -23,3 +23,19 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. Note: The `eps_train_final` and `eps_test` in the original DQN paper is 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that smaller eps helps improve the performance. Also, a large batchsize (say 64 instead of 32) will help faster convergence but will slow down the training speed. We haven't tuned this result to the best, so have fun with playing these hyperparameters! + +# C51 (single run) + +One epoch here is equal to 100,000 env steps; 100 epochs therefore correspond to 10M env steps. + +| task                        | best reward | reward curve                          | parameters                                                   | | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | | PongNoFrameskip-v4          | 20          | ![](results/c51/Pong_rew.png)         | `python3 atari_c51.py --task "PongNoFrameskip-v4" --batch-size 64` | | BreakoutNoFrameskip-v4      | 536.6       | ![](results/c51/Breakout_rew.png)     | `python3 atari_c51.py --task "BreakoutNoFrameskip-v4" --n-step 1` | | EnduroNoFrameskip-v4        | 1032        | ![](results/c51/Enduro_rew.png)       | `python3 atari_c51.py --task "EnduroNoFrameskip-v4"` | | QbertNoFrameskip-v4         | 16245       | ![](results/c51/Qbert_rew.png)        | `python3 atari_c51.py --task "QbertNoFrameskip-v4"` | | MsPacmanNoFrameskip-v4      | 3133        | ![](results/c51/MsPacman_rew.png)     | `python3 atari_c51.py --task "MsPacmanNoFrameskip-v4"` | | SeaquestNoFrameskip-v4      | 6226        | ![](results/c51/Seaquest_rew.png)     | `python3 atari_c51.py --task "SeaquestNoFrameskip-v4"` | | SpaceInvadersNoFrameskip-v4 | 988.5       | ![](results/c51/SpaceInvader_rew.png) | `python3 atari_c51.py --task "SpaceInvadersNoFrameskip-v4"` | + +Note: The selection of `n_step` is based on Figure 6 in the [Rainbow](https://arxiv.org/abs/1710.02298) paper.
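`atari_c51.py` writes its best checkpoint to `policy.pth` under `<logdir>/<task>/c51/` (see `save_fn` in the script below), so a finished run can be replayed with the script's own `--watch`/`--resume_path` flags; the path in this example is just the default layout with `--logdir log` and should be adjusted to your own setup: `python3 atari_c51.py --task "PongNoFrameskip-v4" --watch --resume_path log/PongNoFrameskip-v4/c51/policy.pth`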
\ No newline at end of file diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py new file mode 100644 index 000000000..ce465ed2a --- /dev/null +++ b/examples/atari/atari_c51.py @@ -0,0 +1,155 @@ +import os +import torch +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.policy import C51Policy +from tianshou.env import SubprocVectorEnv +from tianshou.utils.net.discrete import C51 +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, ReplayBuffer + +from atari_wrapper import wrap_deepmind + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--eps-test', type=float, default=0.005) + parser.add_argument('--eps-train', type=float, default=1.) + parser.add_argument('--eps-train-final', type=float, default=0.05) + parser.add_argument('--buffer-size', type=int, default=100000) + parser.add_argument('--lr', type=float, default=0.0001) + parser.add_argument('--gamma', type=float, default=0.99) + parser.add_argument('--num-atoms', type=int, default=51) + parser.add_argument('--v-min', type=float, default=-10.) + parser.add_argument('--v-max', type=float, default=10.) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=500) + parser.add_argument('--epoch', type=int, default=100) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--batch-size', type=int, default=32) + parser.add_argument('--training-num', type=int, default=16) + parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+ parser.add_argument( + '--device', type=str, + default='cuda' if torch.cuda.is_available() else 'cpu') + parser.add_argument('--frames_stack', type=int, default=4) + parser.add_argument('--resume_path', type=str, default=None) + parser.add_argument('--watch', default=False, action='store_true', + help='watch the play of pre-trained policy only') + return parser.parse_args() + + +def make_atari_env(args): + return wrap_deepmind(args.task, frame_stack=args.frames_stack) + + +def make_atari_env_watch(args): + return wrap_deepmind(args.task, frame_stack=args.frames_stack, + episode_life=False, clip_rewards=False) + + +def test_c51(args=get_args()): + env = make_atari_env(args) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.env.action_space.shape or env.env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # make environments + train_envs = SubprocVectorEnv([lambda: make_atari_env(args) + for _ in range(args.training_num)]) + test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) + for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # define model + net = C51(*args.state_shape, args.action_shape, + args.num_atoms, args.device) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + # define policy + policy = C51Policy( + net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max, + args.n_step, target_update_freq=args.target_update_freq + ).to(args.device) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load( + args.resume_path, map_location=args.device + )) + print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together + # when you have enough RAM + buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True, + save_only_last_obs=True, stack_num=args.frames_stack) + # collector + train_collector = Collector(policy, train_envs, buffer) + test_collector = Collector(policy, test_envs) + # log + log_path = os.path.join(args.logdir, args.task, 'c51') + writer = SummaryWriter(log_path) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + if env.env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + elif 'Pong' in args.task: + return mean_rewards >= 20 + else: + return False + + def train_fn(epoch, env_step): + # nature DQN setting, linear decay in the first 1M steps + if env_step <= 1e6: + eps = args.eps_train - env_step / 1e6 * \ + (args.eps_train - args.eps_train_final) + else: + eps = args.eps_train_final + policy.set_eps(eps) + writer.add_scalar('train/eps', eps, global_step=env_step) + + def test_fn(epoch, env_step): + policy.set_eps(args.eps_test) + + # watch agent's performance + def watch(): + print("Testing agent ...") + policy.eval() + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) + pprint.pprint(result) + + if args.watch: + watch() + exit(0) + + # test train_collector and start filling replay buffer + train_collector.collect(n_step=args.batch_size * 4) + # trainer + result = offpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.collect_per_step, 
args.test_num, + args.batch_size, train_fn=train_fn, test_fn=test_fn, + stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) + + pprint.pprint(result) + watch() + + +if __name__ == '__main__': + test_c51(get_args()) diff --git a/examples/atari/results/c51/Breakout_rew.png b/examples/atari/results/c51/Breakout_rew.png new file mode 100644 index 000000000..6b510a04f Binary files /dev/null and b/examples/atari/results/c51/Breakout_rew.png differ diff --git a/examples/atari/results/c51/Enduro_rew.png b/examples/atari/results/c51/Enduro_rew.png new file mode 100644 index 000000000..0e08380a1 Binary files /dev/null and b/examples/atari/results/c51/Enduro_rew.png differ diff --git a/examples/atari/results/c51/MsPacman_rew.png b/examples/atari/results/c51/MsPacman_rew.png new file mode 100644 index 000000000..001af14ee Binary files /dev/null and b/examples/atari/results/c51/MsPacman_rew.png differ diff --git a/examples/atari/results/c51/Pong_rew.png b/examples/atari/results/c51/Pong_rew.png new file mode 100644 index 000000000..c835399d3 Binary files /dev/null and b/examples/atari/results/c51/Pong_rew.png differ diff --git a/examples/atari/results/c51/Qbert_rew.png b/examples/atari/results/c51/Qbert_rew.png new file mode 100644 index 000000000..47ee25ed9 Binary files /dev/null and b/examples/atari/results/c51/Qbert_rew.png differ diff --git a/examples/atari/results/c51/Seaquest_rew.png b/examples/atari/results/c51/Seaquest_rew.png new file mode 100644 index 000000000..6cc069fd8 Binary files /dev/null and b/examples/atari/results/c51/Seaquest_rew.png differ diff --git a/examples/atari/results/c51/SpaceInvader_rew.png b/examples/atari/results/c51/SpaceInvader_rew.png new file mode 100644 index 000000000..108cf04be Binary files /dev/null and b/examples/atari/results/c51/SpaceInvader_rew.png differ diff --git a/test/base/test_returns.py b/test/base/test_returns.py index 69649e864..a8bdc7c9d 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -76,6 +76,10 @@ def target_q_fn(buffer, indice): return torch.tensor(-buffer.rew[indice], dtype=torch.float32) +def target_q_fn_multidim(buffer, indice): + return target_q_fn(buffer, indice).unsqueeze(1).repeat(1, 51) + + def compute_nstep_return_base(nstep, gamma, buffer, indice): returns = np.zeros_like(indice, dtype=np.float) buf_len = len(buffer) @@ -108,6 +112,10 @@ def test_nstep_returns(size=10000): assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12]) r_ = compute_nstep_return_base(1, .1, buf, indice) assert np.allclose(returns, r_), (r_, returns) + returns_multidim = to_numpy(BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=1 + ).pop('returns')) + assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 2 returns = to_numpy(BasePolicy.compute_nstep_return( batch, buf, indice, target_q_fn, gamma=.1, n_step=2).pop('returns')) @@ -115,6 +123,10 @@ def test_nstep_returns(size=10000): 3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12]) r_ = compute_nstep_return_base(2, .1, buf, indice) assert np.allclose(returns, r_) + returns_multidim = to_numpy(BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=2 + ).pop('returns')) + assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 10 returns = to_numpy(BasePolicy.compute_nstep_return( batch, buf, indice, target_q_fn, gamma=.1, n_step=10).pop('returns')) @@ -122,6 +134,10 @@ def test_nstep_returns(size=10000): 3.4, 4, 5.678, 6.78, 7.8, 
8, 10.122, 11.22, 12.2, 12]) r_ = compute_nstep_return_base(10, .1, buf, indice) assert np.allclose(returns, r_) + returns_multidim = to_numpy(BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=10 + ).pop('returns')) + assert np.allclose(returns_multidim, returns[:, np.newaxis]) if __name__ == '__main__': buf = ReplayBuffer(size) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 6057dfc6b..e947571fc 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -4,7 +4,7 @@ from tianshou.utils import MovAvg from tianshou.utils import SummaryWriter from tianshou.utils.net.common import Net -from tianshou.utils.net.discrete import DQN +from tianshou.utils.net.discrete import DQN, C51 from tianshou.exploration import GaussianNoise, OUNoise from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic @@ -61,6 +61,10 @@ def test_net(): expect_output_shape = [bsz, *action_shape] net = DQN(*state_shape, action_shape) assert list(net(data)[0].shape) == expect_output_shape + num_atoms = 51 + net = C51(*state_shape, action_shape, num_atoms) + expect_output_shape = [bsz, *action_shape, num_atoms] + assert list(net(data)[0].shape) == expect_output_shape def test_summary_writer(): diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py new file mode 100644 index 000000000..f017955e6 --- /dev/null +++ b/test/discrete/test_c51.py @@ -0,0 +1,135 @@ +import os +import gym +import torch +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.policy import C51Policy +from tianshou.env import DummyVectorEnv +from tianshou.utils.net.common import Net +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--eps-test', type=float, default=0.05) + parser.add_argument('--eps-train', type=float, default=0.1) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--gamma', type=float, default=0.9) + parser.add_argument('--num-atoms', type=int, default=51) + parser.add_argument('--v-min', type=float, default=-10.) + parser.add_argument('--v-max', type=float, default=10.) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=320) + parser.add_argument('--epoch', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--layer-num', type=int, default=3) + parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+ parser.add_argument('--prioritized-replay', type=int, default=0) + parser.add_argument('--alpha', type=float, default=0.6) + parser.add_argument('--beta', type=float, default=0.4) + parser.add_argument( + '--device', type=str, + default='cuda' if torch.cuda.is_available() else 'cpu') + args = parser.parse_known_args()[0] + return args + + +def test_c51(args=get_args()): + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # train_envs = gym.make(args.task) + # you can also use tianshou.env.SubprocVectorEnv + train_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)]) + # test_envs = gym.make(args.task) + test_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net = Net(args.layer_num, args.state_shape, args.action_shape, args.device, + softmax=True, num_atoms=args.num_atoms) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + policy = C51Policy( + net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max, + args.n_step, target_update_freq=args.target_update_freq + ).to(args.device) + # buffer + if args.prioritized_replay > 0: + buf = PrioritizedReplayBuffer( + args.buffer_size, alpha=args.alpha, beta=args.beta) + else: + buf = ReplayBuffer(args.buffer_size) + # collector + train_collector = Collector(policy, train_envs, buf) + test_collector = Collector(policy, test_envs) + # policy.set_eps(1) + train_collector.collect(n_step=args.batch_size) + # log + log_path = os.path.join(args.logdir, args.task, 'c51') + writer = SummaryWriter(log_path) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + def train_fn(epoch, env_step): + # eps annealing, just a demo + if env_step <= 10000: + policy.set_eps(args.eps_train) + elif env_step <= 50000: + eps = args.eps_train - (env_step - 10000) / \ + 40000 * (0.9 * args.eps_train) + policy.set_eps(eps) + else: + policy.set_eps(0.1 * args.eps_train) + + def test_fn(epoch, env_step): + policy.set_eps(args.eps_test) + + # trainer + result = offpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.collect_per_step, args.test_num, + args.batch_size, train_fn=train_fn, test_fn=test_fn, + stop_fn=stop_fn, save_fn=save_fn, writer=writer) + + assert stop_fn(result['best_reward']) + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance!
+ env = gym.make(args.task) + policy.eval() + policy.set_eps(args.eps_test) + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + + +def test_pc51(args=get_args()): + args.prioritized_replay = 1 + args.gamma = .95 + args.seed = 1 + test_c51(args) + + +if __name__ == '__main__': + test_c51(get_args()) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 456993d5d..9626d97eb 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -2,6 +2,7 @@ from tianshou.policy.random import RandomPolicy from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.modelfree.c51 import C51Policy from tianshou.policy.modelfree.pg import PGPolicy from tianshou.policy.modelfree.a2c import A2CPolicy from tianshou.policy.modelfree.ddpg import DDPGPolicy @@ -18,6 +19,7 @@ "RandomPolicy", "ImitationPolicy", "DQNPolicy", + "C51Policy", "PGPolicy", "A2CPolicy", "DDPGPolicy", diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 809751fe7..6e172dfd7 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -245,7 +245,7 @@ def compute_nstep_return( to False. :return: a Batch. The result will be stored in batch.returns as a - torch.Tensor with shape (bsz, ). + torch.Tensor with the same shape as target_q_fn's return tensor. """ rew = buffer.rew if rew_norm: @@ -257,12 +257,11 @@ def compute_nstep_return( mean, std = 0.0, 1.0 buf_len = len(buffer) terminal = (indice + n_step - 1) % buf_len - target_q_torch = target_q_fn(buffer, terminal).flatten() # (bsz, ) + target_q_torch = target_q_fn(buffer, terminal) # (bsz, ?) target_q = to_numpy(target_q_torch) target_q = _nstep_return(rew, buffer.done, target_q, indice, gamma, n_step, len(buffer), mean, std) - batch.returns = to_torch_as(target_q, target_q_torch) if hasattr(batch, "weight"): # prio buffer update batch.weight = to_torch_as(batch.weight, target_q_torch) @@ -275,7 +274,7 @@ def _compile(self) -> None: i64 = np.array([0, 1], dtype=np.int64) _episodic_return(f64, f64, b, 0.1, 0.1) _episodic_return(f32, f64, b, 0.1, 0.1) - _nstep_return(f64, b, f32, i64, 0.1, 1, 4, 1.0, 0.0) + _nstep_return(f64, b, f32, i64, 0.1, 1, 4, 0.0, 1.0) @njit @@ -311,13 +310,18 @@ def _nstep_return( std: float, ) -> np.ndarray: """Numba speedup: 0.3s -> 0.15s.""" - returns = np.zeros(indice.shape) + target_shape = target_q.shape + bsz = target_shape[0] + # change target_q to 2d array + target_q = target_q.reshape(bsz, -1) + returns = np.zeros(target_q.shape) gammas = np.full(indice.shape, n_step) for n in range(n_step - 1, -1, -1): now = (indice + n) % buf_len gammas[done[now] > 0] = n returns[done[now] > 0] = 0.0 - returns = (rew[now] - mean) / std + gamma * returns + returns = (rew[now].reshape(-1, 1) - mean) / std + gamma * returns target_q[gammas != n_step] = 0.0 + gammas = gammas.reshape(-1, 1) target_q = target_q * (gamma ** gammas) + returns - return target_q + return target_q.reshape(target_shape) diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py new file mode 100644 index 000000000..96cc6801a --- /dev/null +++ b/tianshou/policy/modelfree/c51.py @@ -0,0 +1,143 @@ +import torch +import numpy as np +from typing import Any, Dict, Union, Optional + +from tianshou.policy import DQNPolicy +from tianshou.data import Batch, ReplayBuffer, to_numpy + + +class C51Policy(DQNPolicy): + """Implementation of 
Categorical Deep Q-Network. arXiv:1707.06887. + + :param torch.nn.Module model: a model following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. + :param float discount_factor: in [0, 1]. + :param int num_atoms: the number of atoms in the support set of the + value distribution, defaults to 51. + :param float v_min: the value of the smallest atom in the support set, + defaults to -10.0. + :param float v_max: the value of the largest atom in the support set, + defaults to 10.0. + :param int estimation_step: greater than 1, the number of steps to look + ahead. + :param int target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param bool reward_normalization: normalize the reward to Normal(0, 1), + defaults to False. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed + explanation. + """ + + def __init__( + self, + model: torch.nn.Module, + optim: torch.optim.Optimizer, + discount_factor: float = 0.99, + num_atoms: int = 51, + v_min: float = -10.0, + v_max: float = 10.0, + estimation_step: int = 1, + target_update_freq: int = 0, + reward_normalization: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(model, optim, discount_factor, estimation_step, + target_update_freq, reward_normalization, **kwargs) + assert num_atoms > 1, "num_atoms should be greater than 1" + assert v_min < v_max, "v_max should be larger than v_min" + self._num_atoms = num_atoms + self._v_min = v_min + self._v_max = v_max + self.support = torch.nn.Parameter( + torch.linspace(self._v_min, self._v_max, self._num_atoms), + requires_grad=False, + ) + self.delta_z = (v_max - v_min) / (num_atoms - 1) + + def _target_q( + self, buffer: ReplayBuffer, indice: np.ndarray + ) -> torch.Tensor: + return self.support.repeat(len(indice), 1) # shape: [bsz, num_atoms] + + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + model: str = "model", + input: str = "obs", + **kwargs: Any, + ) -> Batch: + """Compute action over the given batch data. + + :return: A :class:`~tianshou.data.Batch` which has 2 keys: + + * ``act`` the action. + * ``state`` the hidden state. + + .. seealso:: + + Please refer to :meth:`~tianshou.policy.DQNPolicy.forward` for + more detailed explanation. 
+ """ + model = getattr(self, model) + obs = batch[input] + obs_ = obs.obs if hasattr(obs, "obs") else obs + dist, h = model(obs_, state=state, info=batch.info) + q = (dist * self.support).sum(2) + act: np.ndarray = to_numpy(q.max(dim=1)[1]) + if hasattr(obs, "mask"): + # some actions are masked; they cannot be selected + q_: np.ndarray = to_numpy(q) + q_[~obs.mask] = -np.inf + act = q_.argmax(axis=1) + # add eps to act in training or testing phase + if not self.updating and not np.isclose(self.eps, 0.0): + for i in range(len(q)): + if np.random.rand() < self.eps: + q_ = np.random.rand(*q[i].shape) + if hasattr(obs, "mask"): + q_[~obs.mask[i]] = -np.inf + act[i] = q_.argmax() + return Batch(logits=dist, act=act, state=h) + + def _target_dist(self, batch: Batch) -> torch.Tensor: + if self._target: + a = self(batch, input="obs_next").act + next_dist = self( + batch, model="model_old", input="obs_next" + ).logits + else: + next_b = self(batch, input="obs_next") + a = next_b.act + next_dist = next_b.logits + next_dist = next_dist[np.arange(len(a)), a, :] + target_support = batch.returns.clamp(self._v_min, self._v_max) + # An amazing trick for calculating the projection gracefully. + # ref: https://github.com/ShangtongZhang/DeepRL + target_dist = (1 - (target_support.unsqueeze(1) - + self.support.view(1, -1, 1)).abs() / self.delta_z + ).clamp(0, 1) * next_dist.unsqueeze(1) + return target_dist.sum(-1) + + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + if self._target and self._cnt % self._freq == 0: + self.sync_weight() + self.optim.zero_grad() + with torch.no_grad(): + target_dist = self._target_dist(batch) + weight = batch.pop("weight", 1.0) + curr_dist = self(batch).logits + act = batch.act + curr_dist = curr_dist[np.arange(len(act)), act, :] + cross_entropy = - (target_dist * torch.log(curr_dist + 1e-8)).sum(1) + loss = (cross_entropy * weight).mean() + # ref: https://github.com/Kaixhin/Rainbow/blob/master/agent.py L94-100 + batch.weight = cross_entropy.detach() # prio-buffer + loss.backward() + self.optim.step() + self._cnt += 1 + return {"loss": loss.item()} diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 8b52a87a0..40f050af8 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -32,6 +32,8 @@ class Net(nn.Module): (for Dueling DQN), defaults to False. :param norm_layer: use which normalization before ReLU, e.g., ``nn.LayerNorm`` and ``nn.BatchNorm1d``, defaults to None. + :param int num_atoms: the number of atoms for a distributional RL + head (e.g. C51), defaults to 1.
""" def __init__( @@ -45,11 +47,14 @@ def __init__( hidden_layer_size: int = 128, dueling: Optional[Tuple[int, int]] = None, norm_layer: Optional[Callable[[int], nn.modules.Module]] = None, + num_atoms: int = 1, ) -> None: super().__init__() self.device = device self.dueling = dueling self.softmax = softmax + self.num_atoms = num_atoms + self.action_num = np.prod(action_shape) input_size = np.prod(state_shape) if concat: input_size += np.prod(action_shape) @@ -62,7 +67,8 @@ def __init__( if dueling is None: if action_shape and not concat: - model += [nn.Linear(hidden_layer_size, np.prod(action_shape))] + model += [nn.Linear( + hidden_layer_size, num_atoms * self.action_num)] else: # dueling DQN q_layer_num, v_layer_num = dueling Q, V = [], [] @@ -75,8 +81,9 @@ def __init__( hidden_layer_size, hidden_layer_size, norm_layer) if action_shape and not concat: - Q += [nn.Linear(hidden_layer_size, np.prod(action_shape))] - V += [nn.Linear(hidden_layer_size, 1)] + Q += [nn.Linear( + hidden_layer_size, num_atoms * self.action_num)] + V += [nn.Linear(hidden_layer_size, num_atoms)] self.Q = nn.Sequential(*Q) self.V = nn.Sequential(*V) @@ -94,7 +101,12 @@ def forward( logits = self.model(s) if self.dueling is not None: # Dueling DQN q, v = self.Q(logits), self.V(logits) + if self.num_atoms > 1: + v = v.view(-1, 1, self.num_atoms) + q = q.view(-1, self.action_num, self.num_atoms) logits = q - q.mean(dim=1, keepdim=True) + v + elif self.num_atoms > 1: + logits = logits.view(-1, self.action_num, self.num_atoms) if self.softmax: logits = torch.softmax(logits, dim=-1) return logits, state diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 0a8452ced..bff229cd7 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -130,3 +130,36 @@ def forward( if not isinstance(x, torch.Tensor): x = to_torch(x, device=self.device, dtype=torch.float32) return self.net(x), state + + +class C51(DQN): + """Reference: A distributional perspective on reinforcement learning. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + + def __init__( + self, + c: int, + h: int, + w: int, + action_shape: Sequence[int], + num_atoms: int = 51, + device: Union[str, int, torch.device] = "cpu", + ) -> None: + super().__init__(c, h, w, [np.prod(action_shape) * num_atoms], device) + self.action_shape = action_shape + self.num_atoms = num_atoms + + def forward( + self, + x: Union[np.ndarray, torch.Tensor], + state: Optional[Any] = None, + info: Dict[str, Any] = {}, + ) -> Tuple[torch.Tensor, Any]: + r"""Mapping: x -> Z(x, \*).""" + x, state = super().forward(x) + x = x.view(-1, self.num_atoms).softmax(dim=-1) + x = x.view(-1, np.prod(self.action_shape), self.num_atoms) + return x, state
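As a sanity check on the projection step above, here is a minimal, self-contained sketch (every variable name is illustrative and local to the sketch, not part of Tianshou's API) comparing the clamp-based one-liner used in `C51Policy._target_dist` with the explicit categorical projection of Algorithm 1 in arXiv:1707.06887. In the policy, the role of `returns` below is played by `batch.returns`, i.e. the per-atom n-step target (discounted reward sum plus gamma^n times each support atom) produced by the generalized `_nstep_return`.

```python
import torch

v_min, v_max, num_atoms, bsz = -10.0, 10.0, 51, 4
support = torch.linspace(v_min, v_max, num_atoms)          # the atoms z_i
delta_z = (v_max - v_min) / (num_atoms - 1)

# stand-ins for what the policy computes: p(z_j | s', a*) and r + gamma^n * z_j
next_dist = torch.softmax(torch.randn(bsz, num_atoms), dim=-1)
returns = torch.rand(bsz, 1) * 3.0 - 1.5 + 0.9 * support   # shape (bsz, num_atoms)
target_support = returns.clamp(v_min, v_max)               # Tz_j, clipped to [v_min, v_max]

# vectorized projection, same expression as in C51Policy._target_dist
target_dist = (1 - (target_support.unsqueeze(1) -
                    support.view(1, -1, 1)).abs() / delta_z
               ).clamp(0, 1) * next_dist.unsqueeze(1)
target_dist = target_dist.sum(-1)                          # shape (bsz, num_atoms)

# explicit projection (Algorithm 1 of the C51 paper), atom by atom
expected = torch.zeros(bsz, num_atoms)
for b in range(bsz):
    for j in range(num_atoms):
        pos = (target_support[b, j] - v_min) / delta_z     # fractional atom index
        low, up = int(torch.floor(pos)), int(torch.ceil(pos))
        if low == up:                                      # Tz_j hits an atom exactly
            expected[b, low] += next_dist[b, j]
        else:                                              # split mass between neighbors
            expected[b, low] += next_dist[b, j] * (up - pos)
            expected[b, up] += next_dist[b, j] * (pos - low)

assert torch.allclose(target_dist, expected, atol=1e-4)
assert torch.allclose(target_dist.sum(-1), torch.ones(bsz))  # still a valid distribution
```

The assertion holds because, on a uniform grid of atoms, `clamp(1 - |Tz_j - z_i| / delta_z, 0, 1)` is exactly the fraction of `p_j` that Algorithm 1 assigns to atom `i`, which is why the policy can skip the floor/ceil index bookkeeping entirely.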