Skip to content

Commit

Permalink
Add C51 algorithm (#266)
Browse files Browse the repository at this point in the history
This is the PR for C51algorithm: https://arxiv.org/abs/1707.06887

1. add C51 policy in tianshou/policy/modelfree/c51.py.
2. add C51 net in tianshou/utils/net/discrete.py.
3. add C51 atari example in examples/atari/atari_c51.py.
4. add C51 statement in tianshou/policy/__init__.py.
5. add C51 test in test/discrete/test_c51.py.
6. add C51 atari results in examples/atari/results/c51/.

By running "python3 atari_c51.py --task "PongNoFrameskip-v4" --batch-size 64", get  best_result': '20.50 ± 0.50', in epoch 9.

By running "python3 atari_c51.py --task "BreakoutNoFrameskip-v4" --n-step 1 --epoch 40", get best_reward: 407.400000 ± 31.155096 in epoch 39.
  • Loading branch information
shengxiang19 committed Jan 6, 2021
1 parent 5d13d8a commit c6f2648
Show file tree
Hide file tree
Showing 19 changed files with 533 additions and 11 deletions.
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -23,6 +23,7 @@
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
- [C51](https://arxiv.org/pdf/1707.06887.pdf)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Expand Up @@ -13,6 +13,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
* :class:`~tianshou.policy.C51Policy` `C51 <https://arxiv.org/pdf/1707.06887.pdf>`_
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
Expand Down
16 changes: 16 additions & 0 deletions examples/atari/README.md
Expand Up @@ -23,3 +23,19 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
Note: The `eps_train_final` and `eps_test` in the original DQN paper is 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that smaller eps helps improve the performance. Also, a large batchsize (say 64 instead of 32) will help faster convergence but will slow down the training speed.

We haven't tuned this result to the best, so have fun with playing these hyperparameters!

# C51 (single run)

One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.

| task | best reward | reward curve | parameters |
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20 | ![](results/c51/Pong_rew.png) | `python3 atari_c51.py --task "PongNoFrameskip-v4" --batch-size 64` |
| BreakoutNoFrameskip-v4 | 536.6 | ![](results/c51/Breakout_rew.png) | `python3 atari_c51.py --task "BreakoutNoFrameskip-v4" --n-step 1` |
| EnduroNoFrameskip-v4 | 1032 | ![](results/c51/Enduro_rew.png) | `python3 atari_c51.py --task "EnduroNoFrameskip-v4 " ` |
| QbertNoFrameskip-v4 | 16245 | ![](results/c51/Qbert_rew.png) | `python3 atari_c51.py --task "QbertNoFrameskip-v4"` |
| MsPacmanNoFrameskip-v4 | 3133 | ![](results/c51/MsPacman_rew.png) | `python3 atari_c51.py --task "MsPacmanNoFrameskip-v4"` |
| SeaquestNoFrameskip-v4 | 6226 | ![](results/c51/Seaquest_rew.png) | `python3 atari_c51.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 988.5 | ![](results/c51/SpaceInvader_rew.png) | `python3 atari_c51.py --task "SpaceInvadersNoFrameskip-v4"` |

Note: The selection of `n_step` is based on Figure 6 in the [Rainbow](https://arxiv.org/abs/1710.02298) paper.
155 changes: 155 additions & 0 deletions examples/atari/atari_c51.py
@@ -0,0 +1,155 @@
import os
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import C51Policy
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.discrete import C51
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer

from atari_wrapper import wrap_deepmind


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--num-atoms', type=int, default=51)
parser.add_argument('--v-min', type=float, default=-10.)
parser.add_argument('--v-max', type=float, default=10.)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=10000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--frames_stack', type=int, default=4)
parser.add_argument('--resume_path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()


def make_atari_env(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack,
episode_life=False, clip_rewards=False)


def test_c51(args=get_args()):
env = make_atari_env(args)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.env.action_space.shape or env.env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
for _ in range(args.training_num)])
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
net = C51(*args.state_shape, args.action_shape,
args.num_atoms, args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = C51Policy(
net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
args.n_step, target_update_freq=args.target_update_freq
).to(args.device)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(
args.resume_path, map_location=args.device
))
print("Loaded agent from: ", args.resume_path)
# replay buffer: `save_last_obs` and `stack_num` can be removed together
# when you have enough RAM
buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
save_only_last_obs=True, stack_num=args.frames_stack)
# collector
train_collector = Collector(policy, train_envs, buffer)
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'c51')
writer = SummaryWriter(log_path)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(mean_rewards):
if env.env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
return mean_rewards >= 20
else:
return False

def train_fn(epoch, env_step):
# nature DQN setting, linear decay in the first 1M steps
if env_step <= 1e6:
eps = args.eps_train - env_step / 1e6 * \
(args.eps_train - args.eps_train_final)
else:
eps = args.eps_train_final
policy.set_eps(eps)
writer.add_scalar('train/eps', eps, global_step=env_step)

def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)

# watch agent's performance
def watch():
print("Testing agent ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
pprint.pprint(result)

if args.watch:
watch()
exit(0)

# test train_collector and start filling replay buffer
train_collector.collect(n_step=args.batch_size * 4)
# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False)

pprint.pprint(result)
watch()


if __name__ == '__main__':
test_c51(get_args())
Binary file added examples/atari/results/c51/Breakout_rew.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/atari/results/c51/Enduro_rew.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/atari/results/c51/MsPacman_rew.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/atari/results/c51/Pong_rew.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/atari/results/c51/Qbert_rew.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/atari/results/c51/Seaquest_rew.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/atari/results/c51/SpaceInvader_rew.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 16 additions & 0 deletions test/base/test_returns.py
Expand Up @@ -76,6 +76,10 @@ def target_q_fn(buffer, indice):
return torch.tensor(-buffer.rew[indice], dtype=torch.float32)


def target_q_fn_multidim(buffer, indice):
return target_q_fn(buffer, indice).unsqueeze(1).repeat(1, 51)


def compute_nstep_return_base(nstep, gamma, buffer, indice):
returns = np.zeros_like(indice, dtype=np.float)
buf_len = len(buffer)
Expand Down Expand Up @@ -108,20 +112,32 @@ def test_nstep_returns(size=10000):
assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
r_ = compute_nstep_return_base(1, .1, buf, indice)
assert np.allclose(returns, r_), (r_, returns)
returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=1
).pop('returns'))
assert np.allclose(returns_multidim, returns[:, np.newaxis])
# test nstep = 2
returns = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn, gamma=.1, n_step=2).pop('returns'))
assert np.allclose(returns, [
3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
r_ = compute_nstep_return_base(2, .1, buf, indice)
assert np.allclose(returns, r_)
returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=2
).pop('returns'))
assert np.allclose(returns_multidim, returns[:, np.newaxis])
# test nstep = 10
returns = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn, gamma=.1, n_step=10).pop('returns'))
assert np.allclose(returns, [
3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12])
r_ = compute_nstep_return_base(10, .1, buf, indice)
assert np.allclose(returns, r_)
returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=10
).pop('returns'))
assert np.allclose(returns_multidim, returns[:, np.newaxis])

if __name__ == '__main__':
buf = ReplayBuffer(size)
Expand Down
6 changes: 5 additions & 1 deletion test/base/test_utils.py
Expand Up @@ -4,7 +4,7 @@
from tianshou.utils import MovAvg
from tianshou.utils import SummaryWriter
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import DQN
from tianshou.utils.net.discrete import DQN, C51
from tianshou.exploration import GaussianNoise, OUNoise
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic

Expand Down Expand Up @@ -61,6 +61,10 @@ def test_net():
expect_output_shape = [bsz, *action_shape]
net = DQN(*state_shape, action_shape)
assert list(net(data)[0].shape) == expect_output_shape
num_atoms = 51
net = C51(*state_shape, action_shape, num_atoms)
expect_output_shape = [bsz, *action_shape, num_atoms]
assert list(net(data)[0].shape) == expect_output_shape


def test_summary_writer():
Expand Down
135 changes: 135 additions & 0 deletions test/discrete/test_c51.py
@@ -0,0 +1,135 @@
import os
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import C51Policy
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--num-atoms', type=int, default=51)
parser.add_argument('--v-min', type=float, default=-10.)
parser.add_argument('--v-max', type=float, default=10.)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=10)
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=3)
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--prioritized-replay', type=int, default=0)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_known_args()[0]
return args


def test_c51(args=get_args()):
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, args.action_shape, args.device,
softmax=True, num_atoms=args.num_atoms)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = C51Policy(
net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
args.n_step, target_update_freq=args.target_update_freq
).to(args.device)
# buffer
if args.prioritized_replay > 0:
buf = PrioritizedReplayBuffer(
args.buffer_size, alpha=args.alpha, beta=args.beta)
else:
buf = ReplayBuffer(args.buffer_size)
# collector
train_collector = Collector(policy, train_envs, buf)
test_collector = Collector(policy, test_envs)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size)
# log
log_path = os.path.join(args.logdir, args.task, 'c51')
writer = SummaryWriter(log_path)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold

def train_fn(epoch, env_step):
# eps annnealing, just a demo
if env_step <= 10000:
policy.set_eps(args.eps_train)
elif env_step <= 50000:
eps = args.eps_train - (env_step - 10000) / \
40000 * (0.9 * args.eps_train)
policy.set_eps(eps)
else:
policy.set_eps(0.1 * args.eps_train)

def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)

# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, writer=writer)

assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')


def test_pc51(args=get_args()):
args.prioritized_replay = 1
args.gamma = .95
args.seed = 1
test_c51(args)


if __name__ == '__main__':
test_c51(get_args())

0 comments on commit c6f2648

Please sign in to comment.