Add C51 algorithm #266

Merged
merged 26 commits on Jan 6, 2021
Changes from 10 commits
154 changes: 154 additions & 0 deletions examples/atari/atari_c51.py
@@ -0,0 +1,154 @@
import os
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import C51Policy
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.discrete import C51
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer

from atari_wrapper import wrap_deepmind


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--num-atoms', type=int, default=51)
parser.add_argument('--v-min', type=float, default=-10.)
parser.add_argument('--v-max', type=float, default=10.)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=10000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--frames_stack', type=int, default=4)
parser.add_argument('--resume_path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()


def make_atari_env(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack,
episode_life=False, clip_rewards=False)


def test_c51(args=get_args()):
env = make_atari_env(args)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.env.action_space.shape or env.env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
for _ in range(args.training_num)])
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
net = C51(*args.state_shape, args.action_shape,
args.num_atoms, args.device).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = C51Policy(net, optim, args.gamma, args.num_atoms,
args.v_min, args.v_max, args.n_step,
target_update_freq=args.target_update_freq)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(
args.resume_path, map_location=args.device
))
print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_only_last_obs` and `stack_num` can be removed
    # together when you have enough RAM
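    # with `save_only_last_obs=True` only the newest frame of each stacked
    # observation is kept in the buffer and frames are re-stacked at sampling
    # time via `stack_num`, trading a little CPU for a lot of RAM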
buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
save_only_last_obs=True, stack_num=args.frames_stack)
# collector
train_collector = Collector(policy, train_envs, buffer)
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'c51')
writer = SummaryWriter(log_path)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(mean_rewards):
if env.env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
return mean_rewards >= 20
else:
return False

def train_fn(epoch, env_step):
# nature DQN setting, linear decay in the first 1M steps
if env_step <= 1e6:
eps = args.eps_train - env_step / 1e6 * \
(args.eps_train - args.eps_train_final)
else:
eps = args.eps_train_final
policy.set_eps(eps)
writer.add_scalar('train/eps', eps, global_step=env_step)

def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)

# watch agent's performance
def watch():
print("Testing agent ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
pprint.pprint(result)

if args.watch:
watch()
exit(0)

# test train_collector and start filling replay buffer
train_collector.collect(n_step=args.batch_size * 4)
# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False)

pprint.pprint(result)
watch()


if __name__ == '__main__':
test_c51(get_args())
2 changes: 2 additions & 0 deletions tianshou/policy/__init__.py
@@ -2,6 +2,7 @@
from tianshou.policy.random import RandomPolicy
from tianshou.policy.imitation.base import ImitationPolicy
from tianshou.policy.modelfree.dqn import DQNPolicy
from tianshou.policy.modelfree.c51 import C51Policy
from tianshou.policy.modelfree.pg import PGPolicy
from tianshou.policy.modelfree.a2c import A2CPolicy
from tianshou.policy.modelfree.ddpg import DDPGPolicy
@@ -18,6 +19,7 @@
"RandomPolicy",
"ImitationPolicy",
"DQNPolicy",
"C51Policy",
"PGPolicy",
"A2CPolicy",
"DDPGPolicy",
225 changes: 225 additions & 0 deletions tianshou/policy/modelfree/c51.py
@@ -0,0 +1,225 @@
import torch
import numpy as np
from numba import njit
from typing import Any, Dict, Union, Optional, Tuple

from tianshou.policy import DQNPolicy
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy


class C51Policy(DQNPolicy):
"""Implementation of Categorical Deep Q-network. arXiv:1707.06887.

:param torch.nn.Module model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
:param float discount_factor: in [0, 1].
:param int num_atoms: the number of atoms in the support set of the
value distribution, defaults to 51.
:param float v_min: the value of the smallest atom in the support set,
defaults to -10.0.
    :param float v_max: the value of the largest atom in the support set,
        defaults to 10.0.
    :param int estimation_step: the number of steps to look ahead, defaults
        to 1.
:param int target_update_freq: the target network update frequency (0 if
you do not use the target network).
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to False.

.. seealso::

Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed
explanation.
"""

def __init__(
self,
model: torch.nn.Module,
optim: torch.optim.Optimizer,
discount_factor: float = 0.99,
num_atoms: int = 51,
v_min: float = -10.0,
v_max: float = 10.0,
estimation_step: int = 1,
target_update_freq: int = 0,
reward_normalization: bool = False,
**kwargs: Any,
) -> None:
super().__init__(model, optim, discount_factor,
estimation_step, target_update_freq,
reward_normalization, **kwargs)
self._num_atoms = num_atoms
self._v_min = v_min
self._v_max = v_max
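        # fixed support {z_i}: num_atoms atoms evenly spaced over
        # [v_min, v_max] with spacing delta_z = (v_max - v_min) / (num_atoms - 1)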
self.support = torch.linspace(self._v_min, self._v_max,
self._num_atoms)
self.delta_z = (v_max - v_min) / (num_atoms - 1)

@staticmethod
def prepare_n_step(
batch: Batch,
buffer: ReplayBuffer,
indice: np.ndarray,
gamma: float = 0.99,
n_step: int = 1,
rew_norm: bool = False,
) -> Batch:
"""Modify the obs_next, done and rew in batch for computing n-step return.

:param batch: a data batch, which is equal to buffer[indice].
:type batch: :class:`~tianshou.data.Batch`
:param buffer: a data buffer which contains several full-episode data
chronologically.
:type buffer: :class:`~tianshou.data.ReplayBuffer`
:param indice: sampled timestep.
:type indice: numpy.ndarray
:param float gamma: the discount factor, should be in [0, 1], defaults
to 0.99.
        :param int n_step: the number of estimation steps, should be greater
            than 0, defaults to 1.
:param bool rew_norm: normalize the reward to Normal(0, 1), defaults
to False.

:return: a Batch with modified obs_next, done and rew.
"""
buf_len = len(buffer)
if rew_norm:
bfr = buffer.rew[: min(buf_len, 1000)] # avoid large buffer
mean, std = bfr.mean(), bfr.std()
            if np.isclose(std, 0.0, atol=1e-2):
mean, std = 0.0, 1.0
else:
mean, std = 0.0, 1.0
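        # bootstrap from the transition n_step - 1 indices ahead, whose
        # obs_next is the observation n steps after the sampled one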
buffer_n = buffer[(indice + n_step - 1) % buf_len]
batch.obs_next = buffer_n.obs_next
rew_n, done_n = _nstep_batch(buffer.rew, buffer.done,
indice, gamma, n_step, buf_len, mean, std)
batch.rew = rew_n
batch.done = done_n
return batch

def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
"""Prepare the batch for calculating the n-step return.

More details can be found at
:meth:`~tianshou.policy.C51Policy.prepare_n_step`.
"""
batch = self.prepare_n_step(
batch, buffer, indice,
self._gamma, self._n_step, self._rew_norm)
return batch

def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: str = "model",
input: str = "obs",
**kwargs: Any,
) -> Batch:
"""Compute action over the given batch data.

:return: A :class:`~tianshou.data.Batch` which has 2 keys:

* ``act`` the action.
* ``state`` the hidden state.

.. seealso::

Please refer to :meth:`~tianshou.policy.DQNPolicy.forward` for
more detailed explanation.
"""
model = getattr(self, model)
obs = batch[input]
obs_ = obs.obs if hasattr(obs, "obs") else obs
Collaborator: Why could hasattr(obs, "obs") be false?

Collaborator: These three lines are the same as in the existing DQNPolicy. I guess we can make a separate PR to enhance these things :)

Collaborator: Yes, I noticed that :)

dist, h = model(obs_, state=state, info=batch.info)
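        # greedy action from the expected value of each return distribution:
        #   Q(s, a) = sum_i z_i * p_i(s, a)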
q = (dist * to_torch_as(self.support, dist)).sum(2)
act: np.ndarray = to_numpy(q.max(dim=1)[1])
if hasattr(obs, "mask"):
Collaborator: I don't like this approach much, but right now I have no idea how to avoid it. Maybe add a masked_array method to the Batch class to offer something similar to numpy's masked arrays. Internally it would use the same mechanism, but it would be hidden inside Batch, which is much better in my opinion.

            # some of the actions are masked and cannot be selected
q_: np.ndarray = to_numpy(q)
q_[~obs.mask] = -np.inf
act = q_.argmax(axis=1)
# add eps to act in training or testing phase
if not self.updating and not np.isclose(self.eps, 0.0):
for i in range(len(q)):
if np.random.rand() < self.eps:
q_ = np.random.rand(*q[i].shape)
if hasattr(obs, "mask"):
q_[~obs.mask[i]] = -np.inf
act[i] = q_.argmax()
return Batch(logits=dist, act=act, state=h)

def _target_dist(
self, batch: Batch
) -> torch.Tensor:
if self._target:
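            # the online network selects the next action, the target network
            # provides its return distribution (Double-DQN style evaluation)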
a = self(batch, input="obs_next").act
next_dist = self(
batch, model="model_old", input="obs_next"
).logits
else:
next_b = self(batch, input="obs_next")
a = next_b.act
next_dist = next_b.logits
batch_size = len(a)
next_dist = next_dist[np.arange(batch_size), a, :]
device = next_dist.device
reward = torch.from_numpy(batch.rew).to(device).unsqueeze(1)
        done = torch.from_numpy(batch.done).to(device).float().unsqueeze(1)
support = self.support.to(device)

        # Compute the projection of the Bellman update Tz onto the support z.
target_support = reward + (self._gamma ** self._n_step
) * (1.0 - done) * support.unsqueeze(0)
target_support = target_support.clamp(self._v_min, self._v_max)

# An amazing trick for calculating the projection gracefully.
# ref: https://github.com/ShangtongZhang/DeepRL
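        # each target atom Tz_j spreads its mass over the neighbouring support
        # atoms z_i with weight clamp(1 - |Tz_j - z_i| / delta_z, 0, 1); the
        # sum over j below yields the projected distribution m_i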
target_dist = (1 - (target_support.unsqueeze(1) -
support.view(1, -1, 1)).abs() / self.delta_z
).clamp(0, 1) * next_dist.unsqueeze(1)
target_dist = target_dist.sum(-1)
return target_dist

def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
if self._target and self._cnt % self._freq == 0:
self.sync_weight()
self.optim.zero_grad()
weight = batch.pop("weight", 1.0)
with torch.no_grad():
target_dist = self._target_dist(batch)
curr_dist = self(batch).logits
act = batch.act
curr_dist = curr_dist[np.arange(len(act)), act, :]
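        # cross-entropy between the projected target distribution m and the
        # predicted distribution p(s, a): loss = -sum_i m_i * log p_i(s, a)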
cross_entropy = - (target_dist * torch.log(curr_dist + 1e-8)).sum(1)
loss = (cross_entropy * weight).mean()
batch.weight = cross_entropy.detach() # prio-buffer
loss.backward()
self.optim.step()
self._cnt += 1
Collaborator: I recommend a more explicit variable name than _cnt.

return {"loss": loss.item()}


@njit
def _nstep_batch(
rew: np.ndarray,
done: np.ndarray,
indice: np.ndarray,
gamma: float,
n_step: int,
buf_len: int,
mean: float,
std: float,
) -> Tuple[np.ndarray, np.ndarray]:
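    # accumulate the n-step return backwards over the window:
    #   G = r_t + gamma * (1 - done_t) * G,
    # OR-ing the done flags so the return never crosses an episode boundary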
rew_n = np.zeros(indice.shape)
done_n = done[indice]
for n in range(n_step - 1, -1, -1):
now = (indice + n) % buf_len
done_t = done[now]
done_n = np.bitwise_or(done_n, done_t)
rew_n = (rew[now] - mean) / std + (1.0 - done_t) * gamma * rew_n
return rew_n, done_n
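
As a sanity check on the projection used in _target_dist above, here is a minimal standalone PyTorch sketch (not part of this diff; the function name and shapes are chosen only for illustration) that applies the same atom-spreading trick to a batch of distributions and verifies that each projected distribution still sums to one:

import torch


def project_distribution(next_dist, rew, done, support, gamma, v_min, v_max):
    # next_dist: (B, N) probabilities over `support` for the next state
    # rew, done: (B, 1) float tensors; support: (N,) atoms z_i
    delta_z = (v_max - v_min) / (support.numel() - 1)
    # Tz_j = r + gamma * (1 - done) * z_j, clamped to the support range
    target_support = (rew + gamma * (1.0 - done) * support.unsqueeze(0)
                      ).clamp(v_min, v_max)
    # weight[b, i, j] = clamp(1 - |Tz_j - z_i| / delta_z, 0, 1)
    weight = (1 - (target_support.unsqueeze(1) -
                   support.view(1, -1, 1)).abs() / delta_z).clamp(0, 1)
    return (weight * next_dist.unsqueeze(1)).sum(-1)  # (B, N) projected m_i


if __name__ == "__main__":
    v_min, v_max, num_atoms = -10.0, 10.0, 51
    support = torch.linspace(v_min, v_max, num_atoms)
    next_dist = torch.softmax(torch.randn(4, num_atoms), dim=-1)
    rew, done = torch.randn(4, 1), torch.zeros(4, 1)
    m = project_distribution(next_dist, rew, done, support, 0.99, v_min, v_max)
    print(m.sum(-1))  # every entry should be ~1.0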