From 3695f126b1b8a340b1c72801fc13c69e26b33522 Mon Sep 17 00:00:00 2001
From: Trinkle23897
Date: Mon, 4 Jan 2021 11:11:22 +0800
Subject: [PATCH] nstep multidim support

---
 .../{Breakout_rew.PNG => Breakout_rew.png}    | Bin
 .../c51/{Enduro_rew.PNG => Enduro_rew.png}    | Bin
 .../{MsPacman_rew.PNG => MsPacman_rew.png}    | Bin
 .../c51/{Pong_rew.PNG => Pong_rew.png}        | Bin
 .../c51/{Qbert_rew.PNG => Qbert_rew.png}      | Bin
 test/base/test_returns.py                     |  16 +++
 tianshou/policy/base.py                       |  17 ++-
 tianshou/policy/modelfree/c51.py              | 112 +++---------------
 tianshou/utils/net/common.py                  |   2 +-
 9 files changed, 43 insertions(+), 104 deletions(-)
 rename examples/atari/results/c51/{Breakout_rew.PNG => Breakout_rew.png} (100%)
 rename examples/atari/results/c51/{Enduro_rew.PNG => Enduro_rew.png} (100%)
 rename examples/atari/results/c51/{MsPacman_rew.PNG => MsPacman_rew.png} (100%)
 rename examples/atari/results/c51/{Pong_rew.PNG => Pong_rew.png} (100%)
 rename examples/atari/results/c51/{Qbert_rew.PNG => Qbert_rew.png} (100%)

diff --git a/examples/atari/results/c51/Breakout_rew.PNG b/examples/atari/results/c51/Breakout_rew.png
similarity index 100%
rename from examples/atari/results/c51/Breakout_rew.PNG
rename to examples/atari/results/c51/Breakout_rew.png
diff --git a/examples/atari/results/c51/Enduro_rew.PNG b/examples/atari/results/c51/Enduro_rew.png
similarity index 100%
rename from examples/atari/results/c51/Enduro_rew.PNG
rename to examples/atari/results/c51/Enduro_rew.png
diff --git a/examples/atari/results/c51/MsPacman_rew.PNG b/examples/atari/results/c51/MsPacman_rew.png
similarity index 100%
rename from examples/atari/results/c51/MsPacman_rew.PNG
rename to examples/atari/results/c51/MsPacman_rew.png
diff --git a/examples/atari/results/c51/Pong_rew.PNG b/examples/atari/results/c51/Pong_rew.png
similarity index 100%
rename from examples/atari/results/c51/Pong_rew.PNG
rename to examples/atari/results/c51/Pong_rew.png
diff --git a/examples/atari/results/c51/Qbert_rew.PNG b/examples/atari/results/c51/Qbert_rew.png
similarity index 100%
rename from examples/atari/results/c51/Qbert_rew.PNG
rename to examples/atari/results/c51/Qbert_rew.png
diff --git a/test/base/test_returns.py b/test/base/test_returns.py
index 69649e864..a8bdc7c9d 100644
--- a/test/base/test_returns.py
+++ b/test/base/test_returns.py
@@ -76,6 +76,10 @@ def target_q_fn(buffer, indice):
     return torch.tensor(-buffer.rew[indice], dtype=torch.float32)
 
 
+def target_q_fn_multidim(buffer, indice):
+    return target_q_fn(buffer, indice).unsqueeze(1).repeat(1, 51)
+
+
 def compute_nstep_return_base(nstep, gamma, buffer, indice):
     returns = np.zeros_like(indice, dtype=np.float)
     buf_len = len(buffer)
@@ -108,6 +112,10 @@ def test_nstep_returns(size=10000):
     assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
     r_ = compute_nstep_return_base(1, .1, buf, indice)
     assert np.allclose(returns, r_), (r_, returns)
+    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=1
+    ).pop('returns'))
+    assert np.allclose(returns_multidim, returns[:, np.newaxis])
     # test nstep = 2
     returns = to_numpy(BasePolicy.compute_nstep_return(
         batch, buf, indice, target_q_fn, gamma=.1, n_step=2).pop('returns'))
@@ -115,6 +123,10 @@ def test_nstep_returns(size=10000):
         3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
     r_ = compute_nstep_return_base(2, .1, buf, indice)
     assert np.allclose(returns, r_)
+    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=2
+    ).pop('returns'))
+    assert np.allclose(returns_multidim, returns[:, np.newaxis])
     # test nstep = 10
     returns = to_numpy(BasePolicy.compute_nstep_return(
         batch, buf, indice, target_q_fn, gamma=.1, n_step=10).pop('returns'))
@@ -122,6 +134,10 @@ def test_nstep_returns(size=10000):
         3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12])
     r_ = compute_nstep_return_base(10, .1, buf, indice)
     assert np.allclose(returns, r_)
+    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=10
+    ).pop('returns'))
+    assert np.allclose(returns_multidim, returns[:, np.newaxis])
 
     if __name__ == '__main__':
         buf = ReplayBuffer(size)
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 809751fe7..b875f6075 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -245,7 +245,7 @@ def compute_nstep_return(
             to False.
 
         :return: a Batch. The result will be stored in batch.returns as a
-            torch.Tensor with shape (bsz, ).
+            torch.Tensor with the same shape as target_q_fn's return tensor.
        """
         rew = buffer.rew
         if rew_norm:
@@ -257,12 +257,11 @@ def compute_nstep_return(
             mean, std = 0.0, 1.0
         buf_len = len(buffer)
         terminal = (indice + n_step - 1) % buf_len
-        target_q_torch = target_q_fn(buffer, terminal).flatten()  # (bsz, )
+        target_q_torch = target_q_fn(buffer, terminal)  # (bsz, ?)
         target_q = to_numpy(target_q_torch)
 
         target_q = _nstep_return(rew, buffer.done, target_q, indice,
                                  gamma, n_step, len(buffer), mean, std)
-
         batch.returns = to_torch_as(target_q, target_q_torch)
         if hasattr(batch, "weight"):  # prio buffer update
             batch.weight = to_torch_as(batch.weight, target_q_torch)
@@ -275,7 +274,7 @@ def _compile(self) -> None:
         i64 = np.array([0, 1], dtype=np.int64)
         _episodic_return(f64, f64, b, 0.1, 0.1)
         _episodic_return(f32, f64, b, 0.1, 0.1)
-        _nstep_return(f64, b, f32, i64, 0.1, 1, 4, 1.0, 0.0)
+        _nstep_return(f64, b, f32, i64, 0.1, 1, 4, 0.0, 1.0)
 
 
 @njit
@@ -311,7 +310,12 @@ def _nstep_return(
     std: float,
 ) -> np.ndarray:
     """Numba speedup: 0.3s -> 0.15s."""
-    returns = np.zeros(indice.shape)
+    target_shape = target_q.shape
+    bsz = target_shape[0]
+    # change rew/target_q to 2d array
+    target_q = target_q.reshape(bsz, -1)
+    rew = rew.reshape(-1, 1)  # assume reward is a scalar
+    returns = np.zeros(target_q.shape)
     gammas = np.full(indice.shape, n_step)
     for n in range(n_step - 1, -1, -1):
         now = (indice + n) % buf_len
@@ -319,5 +323,6 @@ def _nstep_return(
         returns[done[now] > 0] = 0.0
         returns = (rew[now] - mean) / std + gamma * returns
     target_q[gammas != n_step] = 0.0
+    gammas = gammas.reshape(-1, 1)
     target_q = target_q * (gamma ** gammas) + returns
-    return target_q
+    return target_q.reshape(target_shape)
diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py
index 4d1f148ee..a0662d2ff 100644
--- a/tianshou/policy/modelfree/c51.py
+++ b/tianshou/policy/modelfree/c51.py
@@ -1,7 +1,6 @@
 import torch
 import numpy as np
-from numba import njit
-from typing import Any, Dict, Union, Optional, Tuple
+from typing import Any, Dict, Union, Optional
 
 from tianshou.policy import DQNPolicy
 from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
@@ -19,7 +18,7 @@ class C51Policy(DQNPolicy):
     :param float v_min: the value of the smallest atom in the support set,
         defaults to -10.0.
     :param float v_max: the value of the largest atom in the support set,
-        defaults to -10.0.
+        defaults to 10.0.
     :param int estimation_step: greater than 1, the number of steps to look
         ahead.
     :param int target_update_freq: the target network update frequency (0 if
@@ -30,7 +29,7 @@ class C51Policy(DQNPolicy):
     .. seealso::
 
         Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed
-    explanation.
+        explanation.
     """
 
     def __init__(
@@ -46,9 +45,10 @@ def __init__(
         reward_normalization: bool = False,
         **kwargs: Any,
     ) -> None:
-        super().__init__(model, optim, discount_factor,
-                         estimation_step, target_update_freq,
-                         reward_normalization, **kwargs)
+        super().__init__(model, optim, discount_factor, estimation_step,
+                         target_update_freq, reward_normalization, **kwargs)
+        assert num_atoms > 1, "num_atoms should be greater than 1"
+        assert v_min < v_max, "v_max should be larger than v_min"
         self._num_atoms = num_atoms
         self._v_min = v_min
         self._v_max = v_max
@@ -56,61 +56,10 @@ def __init__(
                                       self._num_atoms)
         self.delta_z = (v_max - v_min) / (num_atoms - 1)
 
-    @staticmethod
-    def prepare_n_step(
-        batch: Batch,
-        buffer: ReplayBuffer,
-        indice: np.ndarray,
-        gamma: float = 0.99,
-        n_step: int = 1,
-        rew_norm: bool = False,
-    ) -> Batch:
-        """Modify the obs_next, done and rew in batch for computing n-step return.
-
-        :param batch: a data batch, which is equal to buffer[indice].
-        :type batch: :class:`~tianshou.data.Batch`
-        :param buffer: a data buffer which contains several full-episode data
-            chronologically.
-        :type buffer: :class:`~tianshou.data.ReplayBuffer`
-        :param indice: sampled timestep.
-        :type indice: numpy.ndarray
-        :param float gamma: the discount factor, should be in [0, 1], defaults
-            to 0.99.
-        :param int n_step: the number of estimation step, should be an int
-            greater than 0, defaults to 1.
-        :param bool rew_norm: normalize the reward to Normal(0, 1), defaults
-            to False.
-
-        :return: a Batch with modified obs_next, done and rew.
-        """
-        buf_len = len(buffer)
-        if rew_norm:
-            bfr = buffer.rew[: min(buf_len, 1000)]  # avoid large buffer
-            mean, std = bfr.mean(), bfr.std()
-            if np.isclose(std, 0, 1e-2):
-                mean, std = 0.0, 1.0
-        else:
-            mean, std = 0.0, 1.0
-        buffer_n = buffer[(indice + n_step - 1) % buf_len]
-        batch.obs_next = buffer_n.obs_next
-        rew_n, done_n = _nstep_batch(buffer.rew, buffer.done,
-                                     indice, gamma, n_step, buf_len, mean, std)
-        batch.rew = rew_n
-        batch.done = done_n
-        return batch
-
-    def process_fn(
-        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
-    ) -> Batch:
-        """Prepare the batch for calculating the n-step return.
-
-        More details can be found at
-        :meth:`~tianshou.policy.C51Policy.prepare_n_step`.
-        """
-        batch = self.prepare_n_step(
-            batch, buffer, indice,
-            self._gamma, self._n_step, self._rew_norm)
-        return batch
+    def _target_q(
+        self, buffer: ReplayBuffer, indice: np.ndarray
+    ) -> torch.Tensor:
+        return self.support.repeat(len(indice), 1)  # shape: [bsz, num_atoms]
 
     def forward(
         self,
@@ -164,25 +113,15 @@ def _target_dist(self, batch: Batch) -> torch.Tensor:
             a = next_b.act
             next_dist = next_b.logits
         next_dist = next_dist[np.arange(len(a)), a, :]
-        device = next_dist.device
-        reward = torch.from_numpy(batch.rew).to(device).unsqueeze(1)
-        done = torch.from_numpy(batch.done).to(device).float().unsqueeze(1)
-        support = self.support.to(device)
-
-        # Compute the projection of bellman update Tz onto the support z.
-        target_support = reward + (
-            self._gamma ** self._n_step) * (1.0 - done) * support.unsqueeze(0)
-        target_support = target_support.clamp(self._v_min, self._v_max)
-
+        support = self.support.to(next_dist.device)
+        target_support = batch.returns.clamp(
+            self._v_min, self._v_max).to(next_dist.device)
         # An amazing trick for calculating the projection gracefully.
         # ref: https://github.com/ShangtongZhang/DeepRL
         target_dist = (1 - (target_support.unsqueeze(1) -
                             support.view(1, -1, 1)).abs() / self.delta_z
                        ).clamp(0, 1) * next_dist.unsqueeze(1)
-        target_dist = target_dist.sum(-1)
-        if hasattr(batch, "weight"):  # prio buffer update
-            batch.weight = to_torch_as(batch.weight, target_dist)
-        return target_dist
+        return target_dist.sum(-1)
 
     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
         if self._target and self._cnt % self._freq == 0:
@@ -201,24 +140,3 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
         self.optim.step()
         self._cnt += 1
         return {"loss": loss.item()}
-
-
-@njit
-def _nstep_batch(
-    rew: np.ndarray,
-    done: np.ndarray,
-    indice: np.ndarray,
-    gamma: float,
-    n_step: int,
-    buf_len: int,
-    mean: float,
-    std: float,
-) -> Tuple[np.ndarray, np.ndarray]:
-    rew_n = np.zeros(indice.shape)
-    done_n = done[indice]
-    for n in range(n_step - 1, -1, -1):
-        now = (indice + n) % buf_len
-        done_t = done[now]
-        done_n = np.bitwise_or(done_n, done_t)
-        rew_n = (rew[now] - mean) / std + (1.0 - done_t) * gamma * rew_n
-    return rew_n, done_n
diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py
index afde5f3f9..ea804385a 100644
--- a/tianshou/utils/net/common.py
+++ b/tianshou/utils/net/common.py
@@ -175,7 +175,7 @@ class CategoricalNet(Net):
     .. seealso::
 
         Please refer to :class:`~tianshou.utils.net.common.Net` for
-    more detailed explanation.
+        more detailed explanation.
     """
 
     def __init__(
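
What the _nstep_return change above boils down to: a 2-D target value, such as C51's per-atom target of shape (bsz, num_atoms), can reuse the scalar n-step bootstrap once the reward is broadcast over the last axis and the bootstrap term is scaled row by row. A condensed NumPy sketch of that idea follows; it is illustrative only and not the library function: the name nstep_return_sketch is ours, reward normalization is omitted, and the termination bookkeeping is slightly simplified.

import numpy as np


def nstep_return_sketch(rew, done, target_q, indice, gamma, n_step, buf_len):
    # rew, done: 1-D arrays covering the whole (circular) replay buffer.
    # target_q:  bootstrap values, shape (bsz,) or (bsz, k), e.g. k = num_atoms.
    # indice:    shape (bsz,), start index of each sampled transition.
    target_shape = target_q.shape
    bsz = target_shape[0]
    target_q = target_q.astype(np.float64).reshape(bsz, -1)  # copy and promote to 2-D
    rew = rew.reshape(-1, 1)               # scalar reward, broadcast over the last axis
    returns = np.zeros(target_q.shape)
    gammas = np.full(indice.shape, n_step)  # stays n_step only if no done in the window
    for n in range(n_step - 1, -1, -1):     # walk the n-step window backwards
        now = (indice + n) % buf_len
        gammas[done[now] > 0] = n
        returns[done[now] > 0] = 0.0        # an episode ended here: drop later rewards
        returns = rew[now] + gamma * returns
    target_q[gammas != n_step] = 0.0        # never bootstrap across an episode end
    target_q = target_q * (gamma ** gammas.reshape(-1, 1)) + returns
    return target_q.reshape(target_shape)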
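
A quick way to convince yourself the broadcast is right, mirroring what the new target_q_fn_multidim test asserts: every column of a 2-D target comes out equal to the 1-D result. The toy buffer below is circular, so the last window wraps around, exactly as the modulo in the sketch (and in the patch) allows.

rew = np.array([1., 2., 3., 4.])
done = np.array([0., 0., 1., 0.])
indice = np.arange(4)
q1 = nstep_return_sketch(rew, done, np.ones(4), indice, 0.9, 2, 4)
q2 = nstep_return_sketch(rew, done, np.ones((4, 51)), indice, 0.9, 2, 4)
assert np.allclose(q2, q1[:, np.newaxis])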
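
On the C51 side, _target_q now just tiles the support once per sample, so after compute_nstep_return the batch.returns tensor already holds the n-step-shifted atoms r + gamma^n * z, and _target_dist only has to clamp them and project back onto the fixed support. Below is a sketch of that projection under the same caveats (a free function with our own name, uniform atoms, next_dist already gathered for the chosen next action). Note also that the prioritized-replay weight cast moves out of _target_dist in this patch: compute_nstep_return now handles batch.weight itself.

import torch


def c51_target_dist_sketch(target_support, next_dist, v_min, v_max, num_atoms):
    # target_support: (bsz, num_atoms), e.g. batch.returns from the n-step pass.
    # next_dist:      (bsz, num_atoms), next-state probabilities of the chosen action.
    delta_z = (v_max - v_min) / (num_atoms - 1)
    support = torch.linspace(v_min, v_max, num_atoms)     # fixed atoms z_i
    target_support = target_support.clamp(v_min, v_max)   # keep Tz inside [v_min, v_max]
    # Each shifted atom Tz_j spreads its probability mass over the neighbouring
    # fixed atoms, linearly in |Tz_j - z_i| / delta_z (the ShangtongZhang/DeepRL trick).
    target_dist = (
        1 - (target_support.unsqueeze(1) - support.view(1, -1, 1)).abs() / delta_z
    ).clamp(0, 1) * next_dist.unsqueeze(1)                 # (bsz, num_atoms, num_atoms)
    return target_dist.sum(-1)                             # (bsz, num_atoms)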