Commit
This is the PR for the C51 algorithm (https://arxiv.org/abs/1707.06887):
1. add C51 policy in tianshou/policy/modelfree/c51.py
2. add C51 net in tianshou/utils/net/discrete.py
3. add C51 atari example in examples/atari/atari_c51.py
4. add C51 statement in tianshou/policy/__init__.py
5. add C51 test in test/discrete/test_c51.py
6. add C51 atari results in examples/atari/results/c51/

Running "python3 atari_c51.py --task 'PongNoFrameskip-v4' --batch-size 64" gives a best reward of 20.50 ± 0.50 in epoch 9. Running "python3 atari_c51.py --task 'BreakoutNoFrameskip-v4' --n-step 1 --epoch 40" gives a best reward of 407.40 ± 31.16 in epoch 39.
1 parent 5d13d8a · commit c6f2648
Showing 19 changed files with 533 additions and 11 deletions.
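For reference, C51 represents the return as a categorical distribution over a fixed support of atoms between v_min and v_max, and trains with cross-entropy against the projected Bellman target (Algorithm 1 in the paper). Below is a minimal NumPy sketch of that projection step, for illustration only; it is not the code added by this PR, and the function name and shapes are made up for the example:

    import numpy as np

    def project_distribution(next_dist, rewards, dones, gamma=0.99,
                             v_min=-10.0, v_max=10.0, num_atoms=51):
        # Illustrative sketch, not the PR's implementation.
        # next_dist: (batch, num_atoms) probs of the greedy next action
        # rewards:   (batch,) n-step returns; dones: (batch,) 0/1 flags
        delta_z = (v_max - v_min) / (num_atoms - 1)
        z = np.linspace(v_min, v_max, num_atoms)    # fixed support atoms
        # Bellman-update every atom and clip back into the support range
        tz = np.clip(rewards[:, None] + gamma * (1 - dones[:, None]) * z,
                     v_min, v_max)
        b = (tz - v_min) / delta_z                  # fractional atom index
        lo, up = np.floor(b).astype(int), np.ceil(b).astype(int)
        target = np.zeros_like(next_dist)
        for i in range(next_dist.shape[0]):         # spread mass to neighbors
            for j in range(num_atoms):
                if lo[i, j] == up[i, j]:            # landed exactly on an atom
                    target[i, lo[i, j]] += next_dist[i, j]
                else:
                    target[i, lo[i, j]] += next_dist[i, j] * (up[i, j] - b[i, j])
                    target[i, up[i, j]] += next_dist[i, j] * (b[i, j] - lo[i, j])
        return target  # cross-entropy against the online dist gives the loss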
examples/atari/atari_c51.py
@@ -0,0 +1,155 @@
import os
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import C51Policy
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.discrete import C51
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer

from atari_wrapper import wrap_deepmind


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--eps-test', type=float, default=0.005)
    parser.add_argument('--eps-train', type=float, default=1.)
    parser.add_argument('--eps-train-final', type=float, default=0.05)
    parser.add_argument('--buffer-size', type=int, default=100000)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--num-atoms', type=int, default=51)
    parser.add_argument('--v-min', type=float, default=-10.)
    parser.add_argument('--v-max', type=float, default=10.)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=500)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=10000)
    parser.add_argument('--collect-per-step', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--training-num', type=int, default=16)
    parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--frames_stack', type=int, default=4)
    parser.add_argument('--resume_path', type=str, default=None)
    parser.add_argument('--watch', default=False, action='store_true',
                        help='watch the play of pre-trained policy only')
    return parser.parse_args()


def make_atari_env(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack,
                         episode_life=False, clip_rewards=False)


def test_c51(args=get_args()):
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
                                   for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
                                  for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    net = C51(*args.state_shape, args.action_shape,
              args.num_atoms, args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = C51Policy(
        net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
        args.n_step, target_update_freq=args.target_update_freq
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(
            args.resume_path, map_location=args.device
        ))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_only_last_obs` and `stack_num` can be removed
    # together when you have enough RAM
    buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
                          save_only_last_obs=True, stack_num=args.frames_stack)
    # collector
    train_collector = Collector(policy, train_envs, buffer)
    test_collector = Collector(policy, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, 'c51')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
    def stop_fn(mean_rewards):
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        writer.add_scalar('train/eps', eps, global_step=env_step)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Testing agent ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=[1] * args.test_num,
                                        render=args.render)
        pprint.pprint(result)

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * 4)
    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False)

    pprint.pprint(result)
    watch()


if __name__ == '__main__':
    test_c51(get_args())
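In this example, the C51 network outputs, for each action, a softmax distribution over the atom support, and greedy action selection reduces to taking the expectation over atoms. A small self-contained PyTorch sketch of that reduction (illustrative only, not the tianshou internals; the shapes and helper name are made up):

    import torch

    v_min, v_max, num_atoms = -10.0, 10.0, 51
    support = torch.linspace(v_min, v_max, num_atoms)   # atom values z_i

    def q_from_dist(dist):
        # dist: (batch, n_actions, num_atoms), softmax over the last axis
        return (dist * support).sum(-1)                 # Q(s, a) = E[Z(s, a)]

    dist = torch.softmax(torch.randn(4, 6, num_atoms), dim=-1)
    greedy_actions = q_from_dist(dist).argmax(dim=-1)   # shape (4,)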
test/discrete/test_c51.py
@@ -0,0 +1,135 @@
import os
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import C51Policy
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--num-atoms', type=int, default=51)
    parser.add_argument('--v-min', type=float, default=-10.)
    parser.add_argument('--v-max', type=float, default=10.)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=10)
    parser.add_argument('--step-per-epoch', type=int, default=1000)
    parser.add_argument('--collect-per-step', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--layer-num', type=int, default=3)
    parser.add_argument('--training-num', type=int, default=8)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument('--prioritized-replay', type=int, default=0)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_known_args()[0]
    return args


def test_c51(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device,
              softmax=True, num_atoms=args.num_atoms)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = C51Policy(
        net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
        args.n_step, target_update_freq=args.target_update_freq
    ).to(args.device)
    # buffer
    if args.prioritized_replay > 0:
        buf = PrioritizedReplayBuffer(
            args.buffer_size, alpha=args.alpha, beta=args.beta)
    else:
        buf = ReplayBuffer(args.buffer_size)
    # collector
    train_collector = Collector(policy, train_envs, buf)
    test_collector = Collector(policy, test_envs)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size)
    # log
    log_path = os.path.join(args.logdir, args.task, 'c51')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold
    def train_fn(epoch, env_step):
        # eps annealing, just a demo
        if env_step <= 10000:
            policy.set_eps(args.eps_train)
        elif env_step <= 50000:
            eps = args.eps_train - (env_step - 10000) / \
                40000 * (0.9 * args.eps_train)
            policy.set_eps(eps)
        else:
            policy.set_eps(0.1 * args.eps_train)
    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, writer=writer)

    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        policy.set_eps(args.eps_test)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')


def test_pc51(args=get_args()):
    args.prioritized_replay = 1
    args.gamma = .95
    args.seed = 1
    test_c51(args)


if __name__ == '__main__':
    test_c51(get_args())
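The prioritized-replay variant exercised by test_pc51 can also be run by hand; for example, with flag values taken straight from the argparse defaults above:

    python3 test_c51.py --task "CartPole-v0" --prioritized-replay 1 --alpha 0.6 --beta 0.4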