From 4e48655972677b13015974c11c6df5368f83c025 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 16 Dec 2022 13:41:23 +0000 Subject: [PATCH 01/23] [BugFix] Fix NoopReset in parallel settings (#747) --- test/test_transforms.py | 16 +++++++ torchrl/envs/common.py | 6 ++- torchrl/envs/transforms/transforms.py | 60 +++++++++++++++------------ 3 files changed, 54 insertions(+), 28 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index ed953e7ce45..c63f6f69828 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1275,6 +1275,22 @@ def test_noop_reset_env(self, random, device, compose): else: assert transformed_env.step_count == 30 + @pytest.mark.parametrize("random", [True, False]) + @pytest.mark.parametrize("compose", [True, False]) + @pytest.mark.parametrize("device", get_available_devices()) + def test_noop_reset_env_error(self, random, device, compose): + torch.manual_seed(0) + env = SerialEnv(3, lambda: ContinuousActionVecMockEnv()) + env.set_seed(100) + noop_reset_env = NoopResetEnv(random=random) + transformed_env = TransformedEnv(env) + transformed_env.append_transform(noop_reset_env) + with pytest.raises( + ValueError, + match="there is more than one done state in the parent environment", + ): + transformed_env.reset() + @pytest.mark.parametrize( "default_keys", [["action"], ["action", "monkeys jumping on the bed"]] ) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index b9bf02fcdc5..dbd2c118893 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -399,9 +399,11 @@ def reset( f"env._reset returned an object of type {type(tensordict_reset)} but a TensorDict was expected." ) - self.is_done = tensordict_reset.get( + self.is_done = tensordict_reset.set_default( "done", - torch.zeros(self.batch_size, dtype=torch.bool, device=self.device), + torch.zeros( + *tensordict_reset.batch_size, 1, dtype=torch.bool, device=self.device + ), ) if self.is_done: raise RuntimeError( diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index b14ae8b0508..62c0d6b22f4 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1960,40 +1960,48 @@ def base_env(self): def reset(self, tensordict: TensorDictBase) -> TensorDictBase: """Do no-op action for a number of steps in [1, noop_max].""" + td_reset = tensordict.clone(False) + tensordict = tensordict.clone(False) + # check that there is a single done state -- behaviour is undefined for multiple dones parent = self.parent - # keys = tensordict.keys() + if tensordict.get("done").numel() > 1: + raise ValueError( + "there is more than one done state in the parent environment. " + "NoopResetEnv is designed to work on single env instances, as partial reset " + "is currently not supported. If you feel like this is a missing feature, submit " + "an issue on TorchRL github repo. " + "In case you are trying to use NoopResetEnv over a batch of environments, know " + "that you can have a transformed batch of transformed envs, such as: " + "`TransformedEnv(ParallelEnv(3, lambda: TransformedEnv(MyEnv(), NoopResetEnv(3))), OtherTransform())`." 
+ ) noops = ( self.noops if not self.random else torch.randint(self.noops, (1,)).item() ) - i = 0 trial = 0 - while i < noops: - i += 1 - tensordict = parent.rand_step(tensordict) - tensordict = step_mdp(tensordict) - if parent.is_done: - parent.reset() - i = 0 - trial += 1 - if trial > _MAX_NOOPS_TRIALS: - tensordict = parent.reset(tensordict) - tensordict = parent.rand_step(tensordict) + while True: + i = 0 + while i < noops: + i += 1 + tensordict = parent.rand_step(tensordict) + tensordict = step_mdp(tensordict, exclude_done=False) + if tensordict.get("done"): + tensordict = parent.reset(td_reset.clone(False)) break - if parent.is_done: - raise RuntimeError("NoopResetEnv concluded with done environment") - # td = step_mdp( - # tensordict, exclude_done=False, exclude_reward=True, exclude_action=True - # ) - - # for k in keys: - # if k not in td.keys(): - # td.set(k, tensordict.get(k)) - - # # replace the next_ prefix - # for out_key in parent.observation_spec: - # td.rename_key(out_key[5:], out_key) + else: + break + + trial += 1 + if trial > _MAX_NOOPS_TRIALS: + tensordict = parent.rand_step(tensordict) + if tensordict.get("done"): + raise RuntimeError( + f"parent is still done after a single random step (i={i})." + ) + break + if tensordict.get("done"): + raise RuntimeError("NoopResetEnv concluded with done environment") return tensordict def __repr__(self) -> str: From 286d9b1fd883f6d1425bd6ed3e663a63ddaf5dac Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 16 Dec 2022 13:42:56 +0000 Subject: [PATCH 02/23] [Refactor] Remove env.is_done attribute (#748) --- test/mocking_classes.py | 2 - torchrl/envs/common.py | 62 ++++++++++----------------- torchrl/envs/env_creator.py | 2 +- torchrl/envs/gym_like.py | 4 +- torchrl/envs/libs/brax.py | 1 - torchrl/envs/libs/gym.py | 2 +- torchrl/envs/libs/jumanji.py | 4 -- torchrl/envs/model_based/common.py | 1 - torchrl/envs/transforms/transforms.py | 12 ------ torchrl/envs/vec_env.py | 15 +------ 10 files changed, 27 insertions(+), 78 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index db26b1c687b..7629d2874e7 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -340,7 +340,6 @@ def _step( if not self.categorical_action_encoding: assert (a.sum(-1) == 1).all() - assert not self.is_done, "trying to execute step in done env" obs = self._get_in_obs(tensordict.get(self._out_key)) + a / self.maxstep tensordict = tensordict.select() # empty tensordict @@ -423,7 +422,6 @@ def _step( self.step_count += 1 tensordict = tensordict.to(self.device) a = tensordict.get("action") - assert not self.is_done, "trying to execute step in done env" obs = self._obs_step(self._get_in_obs(tensordict.get(self._out_key)), a) tensordict = tensordict.select() # empty tensordict diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index dbd2c118893..fef7249fa38 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -186,8 +186,6 @@ class EnvBase(nn.Module, metaclass=abc.ABCMeta): - reward_spec (TensorSpec): sampling spec of the rewards; - batch_size (torch.Size): number of environments contained in the instance; - device (torch.device): device where the env input and output are expected to live - - is_done (torch.Tensor): boolean value(s) indicating if the environment has reached a done state since the - last reset - run_type_checks (bool): if True, the observation and reward dtypes will be compared against their respective spec and an exception will be raised if they don't match. 
@@ -212,7 +210,6 @@ def __init__( super().__init__() if device is not None: self.device = torch.device(device) - self._is_done = None self.dtype = dtype_map.get(dtype, dtype) if "is_closed" not in self.__dir__(): self.is_closed = True @@ -331,7 +328,6 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: "tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty() or " "tensordict.select()) inside _step before writing new tensors onto this new instance." ) - self.is_done = tensordict_out.get("done") if self.run_type_checks: for key in self._select_observation_keys(tensordict_out): obs = tensordict_out.get(key) @@ -399,13 +395,13 @@ def reset( f"env._reset returned an object of type {type(tensordict_reset)} but a TensorDict was expected." ) - self.is_done = tensordict_reset.set_default( + tensordict_reset.set_default( "done", torch.zeros( *tensordict_reset.batch_size, 1, dtype=torch.bool, device=self.device ), ) - if self.is_done: + if tensordict_reset.get("done").any(): raise RuntimeError( f"Env {self} was done after reset. This is (currently) not allowed." ) @@ -454,16 +450,6 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None: f"got {tensordict.batch_size} and {self.batch_size}" ) - def is_done_get_fn(self) -> bool: - if self._is_done is None: - self._is_done = torch.zeros(self.batch_size, device=self.device) - return self._is_done.all() - - def is_done_set_fn(self, val: torch.Tensor) -> None: - self._is_done = val - - is_done = property(is_done_get_fn, is_done_set_fn) - def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase: """Performs a random step in the environment given the action_spec attribute. @@ -550,30 +536,27 @@ def policy(td): return td.set("action", self.action_spec.rand(self.batch_size)) tensordicts = [] - if not self.is_done: - for i in range(max_steps): - if auto_cast_to_device: - tensordict = tensordict.to(policy_device) - tensordict = policy(tensordict) - if auto_cast_to_device: - tensordict = tensordict.to(env_device) - tensordict = self.step(tensordict) - tensordicts.append(tensordict.clone()) - if ( - break_when_any_done and tensordict.get("done").any() - ) or i == max_steps - 1: - break - tensordict = step_mdp( - tensordict, - keep_other=True, - exclude_reward=False, - exclude_action=False, - ) + for i in range(max_steps): + if auto_cast_to_device: + tensordict = tensordict.to(policy_device) + tensordict = policy(tensordict) + if auto_cast_to_device: + tensordict = tensordict.to(env_device) + tensordict = self.step(tensordict) + tensordicts.append(tensordict.clone()) + if ( + break_when_any_done and tensordict.get("done").any() + ) or i == max_steps - 1: + break + tensordict = step_mdp( + tensordict, + keep_other=True, + exclude_reward=False, + exclude_action=False, + ) - if callback is not None: - callback(self, tensordict) - else: - raise Exception("reset env before calling rollout!") + if callback is not None: + callback(self, tensordict) batch_size = self.batch_size if tensordict is None else tensordict.batch_size @@ -642,7 +625,6 @@ def to(self, device: DEVICE_TYPING) -> EnvBase: self.reward_spec = self.reward_spec.to(device) self.observation_spec = self.observation_spec.to(device) self.input_spec = self.input_spec.to(device) - self.is_done = self.is_done.to(device) self.device = device return super().to(device) diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py index 0fbb2b15943..c9121c76dd1 100644 --- a/torchrl/envs/env_creator.py +++ 
b/torchrl/envs/env_creator.py @@ -49,7 +49,7 @@ class EnvCreator: ... tensordict = env.reset() ... for _ in range(10): ... env.rand_step(tensordict) - ... if env.is_done: + ... if tensordict.get("done"): ... tensordict = env.reset(tensordict) ... print("env 1: ", env.transform._td.get(("next", "observation_count"))) >>> diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 7f832a10b74..9021a4f590b 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -215,7 +215,6 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: reward = np.nan reward = self._to_tensor(reward, dtype=self.reward_spec.dtype) done = self._to_tensor(done, dtype=torch.bool) - self.is_done = done tensordict_out = TensorDict( obs_dict, batch_size=tensordict.batch_size, device=self.device @@ -240,8 +239,7 @@ def _reset( batch_size=self.batch_size, device=self.device, ) - self._is_done = torch.zeros(self.batch_size, dtype=torch.bool) - tensordict_out.set("done", self._is_done) + tensordict_out.set("done", torch.zeros(*self.batch_size, 1, dtype=torch.bool)) return tensordict_out def _output_transform(self, step_outputs_tuple: Tuple) -> Tuple: diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index e0e271b5ede..47dc69ba2a5 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -224,7 +224,6 @@ def _step_with_grad(self, tensordict: TensorDictBase): # extract done values next_done = next_state_nograd["done"].bool() - self._is_done = next_done # merge with tensors with grad function next_state = next_state_nograd diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 1c4d0a2680c..309ed3b066a 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -272,7 +272,7 @@ def _make_specs(self, env: "gym.Env") -> None: ) def _init_env(self): - self.reset() # make sure that _is_done is populated + self.reset() def __repr__(self) -> str: return ( diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index e430acba5bc..2e59e452b19 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -255,8 +255,6 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: done = timestep.step_type == self.lib.types.StepType.LAST done = _ndarray_to_tensor(done).view(torch.bool).to(self.device) - self._is_done = done - # build results tensordict_out = TensorDict( source=obs_dict, @@ -288,8 +286,6 @@ def _reset( obs_dict = self.read_obs(timestep.observation) done = torch.zeros(self.batch_size, dtype=torch.bool) - self._is_done = done - # build results tensordict_out = TensorDict( source=obs_dict, diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index 1ff0cd03712..e63825f13f7 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -91,7 +91,6 @@ class ModelBasedEnvBase(EnvBase, metaclass=abc.ABCMeta): - input_spec (CompositeSpec): sampling spec of the inputs; - batch_size (torch.Size): batch_size to be used by the env. If not set, the env accept tensordicts of all batch sizes. 
- device (torch.device): device where the env input and output are expected to live - - is_done (torch.Tensor): boolean value(s) indicating if the environment has reached a done state since the last reset Args: world_model (nn.Module): model that generates world states and its corresponding rewards; diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 62c0d6b22f4..4164cc44cfd 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -485,16 +485,6 @@ def is_closed(self) -> bool: def is_closed(self, value: bool): self.base_env.is_closed = value - @property - def is_done(self) -> bool: - if self._is_done is None: - return self.base_env.is_done - return self._is_done.all() - - @is_done.setter - def is_done(self, val: torch.Tensor) -> None: - self._is_done = val - def close(self): self.base_env.close() self.is_closed = True @@ -568,8 +558,6 @@ def to(self, device: DEVICE_TYPING) -> TransformedEnv: self.base_env.to(device) self.transform.to(device) - self.is_done = self.is_done.to(device) - if self.cache_specs: self._input_spec = None self._observation_spec = None diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 878d851f946..a601b417172 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -356,9 +356,6 @@ def reward_spec(self) -> TensorSpec: def reward_spec(self, value: TensorSpec) -> None: self._reward_spec = value - def is_done_set_fn(self, value: bool) -> None: - self._is_done = value.all() - def _create_td(self) -> None: """Creates self.shared_tensordict_parent, a TensorDict used to store the most recent observations.""" if self._single_task: @@ -988,11 +985,8 @@ def _run_worker_pipe_shared_mem( _td.pin_memory() tensordict.update_(_td) child_pipe.send(("reset_obs", reset_keys)) - just_reset = True - if env.is_done: - raise RuntimeError( - f"{env.__class__.__name__}.is_done is {env.is_done} after reset" - ) + if _td.get("done").any(): + raise RuntimeError(f"{env.__class__.__name__} is done after reset") elif cmd == "step": if not initialized: @@ -1002,10 +996,6 @@ def _run_worker_pipe_shared_mem( *env_input_keys, strict=False, ) - if env.is_done and not allow_step_when_done: - raise RuntimeError( - f"calling step when env is done, just reset = {just_reset}" - ) _td = env._step(_td) if step_keys is None: step_keys = set(env.observation_spec.keys()).union( @@ -1020,7 +1010,6 @@ def _run_worker_pipe_shared_mem( msg = "step_result" data = (msg, step_keys) child_pipe.send(data) - just_reset = False elif cmd == "close": del tensordict, _td, data From 1f8341cd9cd706b598356857297af8a738df7641 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Fri, 16 Dec 2022 14:58:50 +0000 Subject: [PATCH 03/23] Drop use of prototype modules (#738) --- test/test_tensordictmodules.py | 9 --------- torchrl/modules/tensordict_module/probabilistic.py | 8 ++++---- tutorials/sphinx-tutorials/tensordict_module.py | 2 +- tutorials/sphinx-tutorials/torchrl_demo.py | 2 +- 4 files changed, 6 insertions(+), 15 deletions(-) diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 619b04c700e..b7e1708bb77 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -692,9 +692,6 @@ def test_stateful(self, safe, spec_type, lazy): assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) - with pytest.raises(RuntimeError, match="Cannot call get_dist on a sequence"): - dist, *_ = tdmodule.get_dist(td) - # test bounds if not safe and spec_type 
== "bounded": assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() @@ -856,9 +853,6 @@ def test_functional(self, safe, spec_type): assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) - with pytest.raises(RuntimeError, match="Cannot call get_dist on a sequence"): - dist, *_ = tdmodule.get_dist(td, params=params) - # test bounds if not safe and spec_type == "bounded": assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() @@ -1012,9 +1006,6 @@ def test_functional_with_buffer(self, safe, spec_type): td = TensorDict({"in": torch.randn(3, 7)}, [3]) tdmodule(td, params=params) - with pytest.raises(RuntimeError, match="Cannot call get_dist on a sequence"): - dist, *_ = tdmodule.get_dist(td, params=params) - assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 7]) diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 7e4bc6a68ca..a399edde0b5 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -6,10 +6,10 @@ import warnings from typing import Optional, Sequence, Type, Union -from tensordict.nn import TensorDictModule -from tensordict.nn.prototype import ( +from tensordict.nn import ( ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, + TensorDictModule, ) from tensordict.tensordict import TensorDictBase @@ -20,7 +20,7 @@ class SafeProbabilisticModule(ProbabilisticTensorDictModule): - """A :obj:``SafeProbabilisticModule`` is an :obj:``tensordict.nn.prototype.ProbabilisticTensorDictModule`` subclass that accepts a :obj:``TensorSpec`` as argument to control the output domain. + """A :obj:``SafeProbabilisticModule`` is an :obj:``tensordict.nn.ProbabilisticTensorDictModule`` subclass that accepts a :obj:``TensorSpec`` as argument to control the output domain. `SafeProbabilisticModule` is a non-parametric module representing a probability distribution. It reads the distribution parameters from an input @@ -190,7 +190,7 @@ def random_sample(self, tensordict: TensorDictBase) -> TensorDictBase: class SafeProbabilisticSequential(ProbabilisticTensorDictSequential, SafeSequential): - """A :obj:``SafeProbabilisticSequential`` is an :obj:``tensordict.nn.prototype.ProbabilisticTensorDictSequential`` subclass that accepts a :obj:``TensorSpec`` as argument to control the output domain. + """A :obj:``SafeProbabilisticSequential`` is an :obj:``tensordict.nn.ProbabilisticTensorDictSequential`` subclass that accepts a :obj:``TensorSpec`` as argument to control the output domain. 
Similarly to :obj:`TensorDictSequential`, but enforces that the final module in the sequence is an :obj:`ProbabilisticTensorDictModule` and also exposes ``get_dist`` diff --git a/tutorials/sphinx-tutorials/tensordict_module.py b/tutorials/sphinx-tutorials/tensordict_module.py index 442a2f242d3..7e8b9dfef1d 100644 --- a/tutorials/sphinx-tutorials/tensordict_module.py +++ b/tutorials/sphinx-tutorials/tensordict_module.py @@ -189,7 +189,7 @@ def forward(self, x): print("the output tensordict shape is: ", result_td.shape) -from tensordict.nn.prototype import ( +from tensordict.nn import ( ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, ) diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 27a7c97e3a6..eb4a97a4cc3 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -505,7 +505,7 @@ tensordict = TensorDict({"obs": torch.randn(5)}, batch_size=[]) actor(tensordict) # action is the default value -from tensordict.nn.prototype import ( +from tensordict.nn import ( ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, ) From 70d9e6096871296b850b7c3b1d12cb3aeb1a4cb2 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Wed, 21 Dec 2022 11:37:31 +0000 Subject: [PATCH 04/23] Fix test_cost --- torchrl/envs/common.py | 12 ++++++++++-- torchrl/envs/transforms/transforms.py | 6 +++--- torchrl/objectives/deprecated.py | 3 +++ torchrl/objectives/dreamer.py | 7 ++++--- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 6eeadf5c733..a37a4752743 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -351,14 +351,22 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: reward = tensordict_out.get("reward") # unsqueeze rewards if needed expected_reward_shape = torch.Size([*tensordict_out.batch_size, *self.reward_spec.shape]) - if reward.shape != expected_reward_shape: + n = len(expected_reward_shape) + if len(reward.shape) >= n and reward.shape[-n:] != expected_reward_shape: + reward = reward.view(*reward.shape[:n], *expected_reward_shape) + tensordict_out.set("reward", reward) + elif len(reward.shape) < n: reward = reward.view(expected_reward_shape) tensordict_out.set("reward", reward) done = tensordict_out.get("done") # unsqueeze done if needed expected_done_shape = torch.Size([*tensordict_out.batch_size, 1]) - if done.shape != expected_done_shape: + n = len(expected_done_shape) + if len(done.shape) >= n and done.shape[-n:] != expected_done_shape: + done = done.view(*done.shape[:n], *expected_done_shape) + tensordict_out.set("done", done) + elif len(done.shape) < n: done = done.view(expected_done_shape) tensordict_out.set("done", done) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 569a3ac550d..c9612abde2b 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -561,9 +561,9 @@ def __repr__(self) -> str: def _erase_metadata(self): if self.cache_specs: - self._input_spec = None - self._observation_spec = None - self._reward_spec = None + self.__dict__["_input_spec"] = None + self.__dict__["_observation_spec"] = None + self.__dict__["_reward_spec"] = None def to(self, device: DEVICE_TYPING) -> TransformedEnv: self.base_env.to(device) diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index b80cf2854ff..38a284079cb 100644 --- a/torchrl/objectives/deprecated.py +++ 
b/torchrl/objectives/deprecated.py @@ -218,6 +218,9 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: next_td, selected_q_params, ) + state_action_value = next_td.get("state_action_value") + if state_action_value.shape[-len(sample_log_prob.shape):] != sample_log_prob.shape: + sample_log_prob = sample_log_prob.unsqueeze(-1) state_value = ( next_td.get("state_action_value") - self.alpha * sample_log_prob ) diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 8d839078154..c42c8e29387 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -73,7 +73,7 @@ def forward(self, tensordict: TensorDict) -> torch.Tensor: tensordict.get(("next", "prior_std")), tensordict.get(("next", "posterior_mean")), tensordict.get(("next", "posterior_std")), - ) + ).unsqueeze(-1) reco_loss = distance_loss( tensordict.get(("next", "pixels")), tensordict.get(("next", "reco_pixels")), @@ -81,7 +81,7 @@ def forward(self, tensordict: TensorDict) -> torch.Tensor: ) if not self.global_average: reco_loss = reco_loss.sum((-3, -2, -1)) - reco_loss = reco_loss.mean() + reco_loss = reco_loss.mean().unsqueeze(-1) reward_loss = distance_loss( tensordict.get("true_reward"), @@ -90,7 +90,8 @@ def forward(self, tensordict: TensorDict) -> torch.Tensor: ) if not self.global_average: reward_loss = reward_loss.squeeze(-1) - reward_loss = reward_loss.mean() + reward_loss = reward_loss.mean().unsqueeze(-1) + # import ipdb; ipdb.set_trace() return ( TensorDict( { From ce4f4d23db3775a82f567d52c6610a19d4b8a72f Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Wed, 21 Dec 2022 11:38:53 +0000 Subject: [PATCH 05/23] Lint and format --- torchrl/data/tensor_specs.py | 8 +++++--- torchrl/envs/common.py | 4 +++- torchrl/envs/gym_like.py | 4 +++- torchrl/envs/libs/gym.py | 2 +- torchrl/objectives/deprecated.py | 5 ++++- 5 files changed, 16 insertions(+), 7 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index debe7e43a1c..debcb7f1623 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -222,13 +222,15 @@ def encode(self, val: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: ): val = val.copy() val = torch.tensor(val, dtype=self.dtype, device=self.device) - if val.shape[-len(self.shape):] != self.shape: + if val.shape[-len(self.shape) :] != self.shape: # option 1: add a singleton dim at the end if self.shape == torch.Size([1]): val = val.unsqueeze(-1) else: - raise RuntimeError(f"Shape mismatch: the value has shape {val.shape} which " - f"is incompatible with the spec shape {self.shape}.") + raise RuntimeError( + f"Shape mismatch: the value has shape {val.shape} which " + f"is incompatible with the spec shape {self.shape}." 
+ ) if not _NO_CHECK_SPEC_ENCODE: self.assert_is_in(val) return val diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index a37a4752743..920b7e3ad7b 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -350,7 +350,9 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: reward = tensordict_out.get("reward") # unsqueeze rewards if needed - expected_reward_shape = torch.Size([*tensordict_out.batch_size, *self.reward_spec.shape]) + expected_reward_shape = torch.Size( + [*tensordict_out.batch_size, *self.reward_spec.shape] + ) n = len(expected_reward_shape) if len(reward.shape) >= n and reward.shape[-n:] != expected_reward_shape: reward = reward.view(*reward.shape[:n], *expected_reward_shape) diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 5b117be37c6..f3fd531be44 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -204,7 +204,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: reward = self.read_reward(reward, _reward) - if isinstance(done, bool) or (isinstance(done, np.ndarray) and not len(done)): + if isinstance(done, bool) or ( + isinstance(done, np.ndarray) and not len(done) + ): done = torch.tensor([done], device=self.device) done, do_break = self.read_done(done) diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 2ef53b188a6..64e9c756e1d 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -15,9 +15,9 @@ DiscreteTensorSpec, MultOneHotDiscreteTensorSpec, NdBoundedTensorSpec, + NdUnboundedContinuousTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, - UnboundedContinuousTensorSpec, NdUnboundedContinuousTensorSpec, ) from ..._utils import implement_for diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 38a284079cb..69d09fc3fe0 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -219,7 +219,10 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: selected_q_params, ) state_action_value = next_td.get("state_action_value") - if state_action_value.shape[-len(sample_log_prob.shape):] != sample_log_prob.shape: + if ( + state_action_value.shape[-len(sample_log_prob.shape) :] + != sample_log_prob.shape + ): sample_log_prob = sample_log_prob.unsqueeze(-1) state_value = ( next_td.get("state_action_value") - self.alpha * sample_log_prob From 84b5f37acebc49d2a9fad2414fdb581c78e62d3f Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Wed, 21 Dec 2022 11:55:51 +0000 Subject: [PATCH 06/23] Fix test_postprocs --- test/test_postprocs.py | 14 +++----------- torchrl/collectors/utils.py | 2 +- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/test/test_postprocs.py b/test/test_postprocs.py index d50b74bb08f..38cbf806f68 100644 --- a/test/test_postprocs.py +++ b/test/test_postprocs.py @@ -97,8 +97,8 @@ def create_fake_trajs( num_workers=32, traj_len=200, ): - traj_ids = torch.arange(num_workers).unsqueeze(-1) - steps_count = torch.zeros(num_workers).unsqueeze(-1) + traj_ids = torch.arange(num_workers) + steps_count = torch.zeros(num_workers) workers = torch.arange(num_workers) out = [] @@ -125,15 +125,7 @@ def create_fake_trajs( return out @pytest.mark.parametrize("num_workers", range(3, 34, 3)) - @pytest.mark.parametrize( - "traj_len", - [ - 10, - 17, - 50, - 97, - ], - ) + @pytest.mark.parametrize("traj_len", [10, 17, 50, 97]) def test_splits(self, num_workers, traj_len): trajs = TestSplits.create_fake_trajs(num_workers, traj_len) diff --git a/torchrl/collectors/utils.py 
b/torchrl/collectors/utils.py index 9edbb81c0e5..8a7c360b37b 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -72,7 +72,7 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: td = TensorDict( source=out_dict, device=rollout_tensordict.device, - batch_size=out_dict["mask"].shape[:-1], + batch_size=out_dict["mask"].shape, ) td = td.unflatten_keys(sep) if (out_dict["done"].sum(1) > 1).any(): From d61f7a53fec5717fe12e35dc414470c933c2d58e Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Wed, 21 Dec 2022 16:53:40 +0000 Subject: [PATCH 07/23] Fix test_collectors --- torchrl/collectors/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 8a7c360b37b..9ea8c67e831 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -69,10 +69,11 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: key: torch.nn.utils.rnn.pad_sequence(_o, batch_first=True) for key, _o in out_splits.items() } + out_dict["mask"] = out_dict["mask"].squeeze(-1) td = TensorDict( source=out_dict, device=rollout_tensordict.device, - batch_size=out_dict["mask"].shape, + batch_size=out_dict["mask"].squeeze(-1).shape, ) td = td.unflatten_keys(sep) if (out_dict["done"].sum(1) > 1).any(): From 459ab12aa77713118883396444d772decaeae112 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 22 Dec 2022 10:05:11 +0100 Subject: [PATCH 08/23] amend --- test/test_collector.py | 60 ++++++++++++++++++++++++++++++-- torchrl/collectors/collectors.py | 34 +++++++++++------- torchrl/collectors/utils.py | 10 +++--- 3 files changed, 85 insertions(+), 19 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 769ef221ae6..b12e974097a 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -306,8 +306,8 @@ def make_env(): ) for _data in collector: continue - steps = _data["step_count"][..., 1:, :] - done = _data["done"][..., :-1, :] + steps = _data["step_count"][..., 1:] + done = _data["done"][..., :-1, :].squeeze(-1) # we don't want just one done assert done.sum() > 3 # check that after a done, the next step count is always 1 @@ -370,6 +370,62 @@ def make_env(seed): del collector +@pytest.mark.parametrize("frames_per_batch", [200, 10]) +@pytest.mark.parametrize("num_env", [1, 3]) +@pytest.mark.parametrize("env_name", ["vec"]) +def test_split_trajs(num_env, env_name, frames_per_batch, seed=5): + if num_env == 1: + + def env_fn(seed): + env = MockSerialEnv(device="cpu") + env.set_seed(seed) + return env + + else: + + def env_fn(seed): + def make_env(seed): + env = MockSerialEnv(device="cpu") + env.set_seed(seed) + return env + + env = SerialEnv( + num_workers=num_env, + create_env_fn=make_env, + create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], + allow_step_when_done=True, + ) + env.set_seed(seed) + return env + + policy = make_policy(env_name) + + collector = SyncDataCollector( + create_env_fn=env_fn, + create_env_kwargs={"seed": seed}, + policy=policy, + frames_per_batch=frames_per_batch * num_env, + max_frames_per_traj=2000, + total_frames=20000, + device="cpu", + pin_memory=False, + reset_when_done=True, + split_trajs=True, + ) + for _, d in enumerate(collector): # noqa + break + + assert d.ndimension() == 2 + assert d["mask"].shape == d.shape + assert d["step_count"].shape == d.shape + assert d["traj_ids"].shape == d.shape + for traj in d.unbind(0): + assert traj["traj_ids"].unique().numel() == 1 + assert 
(traj["step_count"][1:] - traj["step_count"][:-1] == 1).all() + + del collector + + # TODO: design a test that ensures that collectors are interrupted even if __del__ is not called # @pytest.mark.parametrize("should_shutdown", [True, False]) # def test_shutdown_collector(should_shutdown, num_env=3, env_name="vec", seed=40): diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 48b80e30a58..d7c1ff6643a 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -432,7 +432,8 @@ def __init__( self._tensordict = env.reset() self._tensordict.set( - "step_count", torch.zeros(*self.env.batch_size, 1, dtype=torch.int) + "step_count", + torch.zeros(self.env.batch_size, dtype=torch.int, device=env.device), ) if ( @@ -468,14 +469,21 @@ def __init__( ) # in addition to outputs of the policy, we add traj_ids and step_count to # _tensordict_out which will be collected during rollout - if len(self.env.batch_size): - traj_ids = torch.zeros(*self._tensordict_out.batch_size, 1) - else: - traj_ids = torch.zeros(*self._tensordict_out.batch_size, 1, 1) - - self._tensordict_out.set("traj_ids", traj_ids) self._tensordict_out.set( - "step_count", torch.zeros(*self._tensordict_out.batch_size, 1) + "traj_ids", + torch.zeros( + *self._tensordict_out.batch_size, + dtype=torch.int64, + device=self.env_device, + ), + ) + self._tensordict_out.set( + "step_count", + torch.zeros( + *self._tensordict_out.batch_size, + dtype=torch.int64, + device=self.env_device, + ), ) self.return_in_place = return_in_place @@ -589,7 +597,7 @@ def _reset_if_necessary(self) -> None: if not self.reset_when_done: done = torch.zeros_like(done) steps = self._tensordict.get("step_count") - done_or_terminated = done | (steps == self.max_frames_per_traj) + done_or_terminated = done.squeeze(-1) | (steps == self.max_frames_per_traj) if self._has_been_done is None: self._has_been_done = done_or_terminated else: @@ -604,7 +612,7 @@ def _reset_if_necessary(self) -> None: traj_ids = self._tensordict.get("traj_ids").clone() steps = steps.clone() if len(self.env.batch_size): - self._tensordict.masked_fill_(done_or_terminated.squeeze(-1), 0) + self._tensordict.masked_fill_(done_or_terminated, 0) self._tensordict.set("reset_workers", done_or_terminated) else: self._tensordict.zero_() @@ -620,8 +628,8 @@ def _reset_if_necessary(self) -> None: 1, done_or_terminated.sum() + 1, device=traj_ids.device ) steps[done_or_terminated] = 0 - self._tensordict.set("traj_ids", traj_ids) # no ops if they already match - self._tensordict.set("step_count", steps) + self._tensordict.set_("traj_ids", traj_ids) # no ops if they already match + self._tensordict.set_("step_count", steps) @torch.no_grad() def rollout(self) -> TensorDictBase: @@ -636,7 +644,7 @@ def rollout(self) -> TensorDictBase: self._tensordict.fill_("step_count", 0) n = self.env.batch_size[0] if len(self.env.batch_size) else 1 - self._tensordict.set("traj_ids", torch.arange(n).unsqueeze(-1)) + self._tensordict.set("traj_ids", torch.arange(n).view(self.env.batch_size[:1])) tensordict_out = [] with set_exploration_mode(self.exploration_mode): diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 9ea8c67e831..2f3df6c3a5e 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -41,7 +41,7 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: splits = traj_ids.view(-1) splits = [(splits == i).sum().item() for i in splits.unique_consecutive()] # if all splits are identical then we can 
skip this function - if len(set(splits)) == 1 and splits[0] == traj_ids.shape[1]: + if len(set(splits)) == 1 and splits[0] == traj_ids.shape[-1]: rollout_tensordict.set( "mask", torch.ones( @@ -63,17 +63,19 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: dones = out_splits["done"] valid_ids = list(range(len(dones))) out_splits = {key: [_out[i] for i in valid_ids] for key, _out in out_splits.items()} - mask = [torch.ones_like(_out, dtype=torch.bool) for _out in out_splits["done"]] + mask = [ + torch.ones_like(_out[..., 0], dtype=torch.bool) for _out in out_splits["done"] + ] out_splits["mask"] = mask out_dict = { key: torch.nn.utils.rnn.pad_sequence(_o, batch_first=True) for key, _o in out_splits.items() } - out_dict["mask"] = out_dict["mask"].squeeze(-1) + out_dict["mask"] = out_dict["mask"] td = TensorDict( source=out_dict, device=rollout_tensordict.device, - batch_size=out_dict["mask"].squeeze(-1).shape, + batch_size=out_dict["mask"].shape, ) td = td.unflatten_keys(sep) if (out_dict["done"].sum(1) > 1).any(): From c2b9f9ce5d88af8f3f8d1549912c24f610f10190 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 22 Dec 2022 12:17:24 +0000 Subject: [PATCH 09/23] [BugFix] Fixes for `speed` branch merge on tensordict (#755) * init * empty * empty * amend * empty * run legit tests --- test/test_shared.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_shared.py b/test/test_shared.py index f93adcaa90b..f28c4d81d4a 100644 --- a/test/test_shared.py +++ b/test/test_shared.py @@ -59,7 +59,9 @@ def test_shared(self, indexing_method): td = tensordict.clone().share_memory_() if indexing_method == 0: subtd = TensorDict( - source={key: item[0] for key, item in td.items()}, batch_size=[] + source={key: item[0] for key, item in td.items()}, + batch_size=[], + _is_shared=True, ) elif indexing_method == 1: subtd = td.get_sub_tensordict(0) From c1f6364cea1ad7c555aca72cd21b58c496511c07 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 22 Dec 2022 13:35:57 +0100 Subject: [PATCH 10/23] empty From 66278e2d7af4d77bd735ff69cf98eb31b327c52c Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 22 Dec 2022 13:42:22 +0100 Subject: [PATCH 11/23] reset_workers --- docs/source/reference/envs.rst | 2 +- test/test_cost.py | 2 +- test/test_env.py | 4 ++-- test/test_trainer.py | 2 +- torchrl/collectors/collectors.py | 1 - torchrl/data/postprocs/postprocs.py | 2 +- torchrl/envs/vec_env.py | 4 ++-- torchrl/modules/models/recipes/impala.py | 2 +- torchrl/trainers/trainers.py | 10 +++++----- tutorials/sphinx-tutorials/coding_ddpg.py | 2 +- tutorials/sphinx-tutorials/coding_dqn.py | 4 ++-- 11 files changed, 17 insertions(+), 18 deletions(-) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 1b8bfdeb240..be856e59298 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -105,7 +105,7 @@ It is also possible to reset some but not all of the environments: fields={ done: Tensor(torch.Size([4, 1]), dtype=torch.bool), pixels: Tensor(torch.Size([4, 500, 500, 3]), dtype=torch.uint8), - reset_workers: Tensor(torch.Size([4, 1]), dtype=torch.bool)}, + reset_workers: Tensor(torch.Size([4]), dtype=torch.bool)}, batch_size=torch.Size([4]), device=None, is_shared=True) diff --git a/test/test_cost.py b/test/test_cost.py index fbba7ac4c5a..d41f7b385c7 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -266,7 +266,7 @@ def _create_seq_mock_data_dqn( "reward": reward * mask.to(obs.dtype), "action": action * 
mask.to(obs.dtype), "action_value": action_value - * expand_as_right(mask.to(obs.dtype).squeeze(-1), action_value), + * expand_as_right(mask.to(obs.dtype), action_value), }, ) return td diff --git a/test/test_env.py b/test/test_env.py index a20f7e06d39..a1e6a725e34 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -501,7 +501,7 @@ def test_parallel_env( td1 = env_parallel.step(td) td_reset = TensorDict( - source={"reset_workers": torch.zeros(N, 1, dtype=torch.bool).bernoulli_()}, + source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()}, batch_size=[ N, ], @@ -585,7 +585,7 @@ def test_parallel_env_with_policy( td1 = env_parallel.step(td) td_reset = TensorDict( - source={"reset_workers": torch.zeros(N, 1, dtype=torch.bool).bernoulli_()}, + source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()}, batch_size=[ N, ], diff --git a/test/test_trainer.py b/test/test_trainer.py index bd0c8a8ea59..87919df592d 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -773,7 +773,7 @@ def test_masking(): ) td_out = trainer._process_batch_hook(td) assert td_out.shape[0] == td.get("mask").sum() - assert (td["tensor"][td["mask"].squeeze(-1)] == td_out["tensor"]).all() + assert (td["tensor"][td["mask"]] == td_out["tensor"]).all() class TestSubSampler: diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index d7c1ff6643a..f5881e05f3c 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -681,7 +681,6 @@ def reset(self, index=None, **kwargs) -> None: raise RuntimeError("resetting unique env with index is not permitted.") reset_workers = torch.zeros( *self.env.batch_size, - 1, dtype=torch.bool, device=self.env.device, ) diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py index 6d6ed79d4b0..8d7e47d3753 100644 --- a/torchrl/data/postprocs/postprocs.py +++ b/torchrl/data/postprocs/postprocs.py @@ -99,7 +99,7 @@ def _select_and_repeat( ) tensor_cat = torch.cat([tensor, tensor_repeat], 1) + post_terminal_tensor tensor_cat = tensor_cat[:, -T:] - mask = expand_as_right(mask.squeeze(-1), tensor_cat) + mask = expand_as_right(mask, tensor_cat) return tensor_cat.masked_fill(~mask, 0.0) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 0e3e1f57f38..06f9821eb8a 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -618,7 +618,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: self._assert_tensordict_shape(tensordict) reset_workers = tensordict.get("reset_workers") else: - reset_workers = torch.ones(self.num_workers, 1, dtype=torch.bool) + reset_workers = torch.ones(self.num_workers, dtype=torch.bool) keys = set() for i, _env in enumerate(self._envs): @@ -834,7 +834,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: self._assert_tensordict_shape(tensordict) reset_workers = tensordict.get("reset_workers") else: - reset_workers = torch.ones(self.num_workers, 1, dtype=torch.bool) + reset_workers = torch.ones(self.num_workers, dtype=torch.bool) for i, channel in enumerate(self.parent_channels): if not reset_workers[i]: diff --git a/torchrl/modules/models/recipes/impala.py b/torchrl/modules/models/recipes/impala.py index 6dedfb42e77..67d35635637 100644 --- a/torchrl/modules/models/recipes/impala.py +++ b/torchrl/modules/models/recipes/impala.py @@ -177,7 +177,7 @@ def forward(self, tensordict: TensorDictBase): # noqa: D102 x = tensordict.get(self.observation_key) done = 
tensordict.get("done").squeeze(-1) reward = tensordict.get("reward").squeeze(-1) - mask = tensordict.get("mask").squeeze(-1) + mask = tensordict.get("mask") core_state = ( tensordict.get("core_state") if "core_state" in tensordict.keys() else None ) diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 07ad793668e..7b779ce1b83 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -646,7 +646,7 @@ def __init__( def extend(self, batch: TensorDictBase) -> TensorDictBase: if self.flatten_tensordicts: if "mask" in batch.keys(): - batch = batch[batch.get("mask").squeeze(-1)] + batch = batch[batch.get("mask")] else: batch = batch.reshape(-1) else: @@ -810,7 +810,7 @@ def __init__(self, logname="r_training", log_pbar: bool = False): def __call__(self, batch: TensorDictBase) -> Dict: if "mask" in batch.keys(): return { - self.logname: batch.get("reward")[batch.get("mask").squeeze(-1)] + self.logname: batch.get("reward")[batch.get("mask")] .mean() .item(), "log_pbar": self.log_pbar, @@ -857,7 +857,7 @@ def __init__( def update_reward_stats(self, batch: TensorDictBase) -> None: reward = batch.get("reward") if "mask" in batch.keys(): - reward = reward[batch.get("mask").squeeze(-1)] + reward = reward[batch.get("mask")] if self._update_has_been_called and not self._normalize_has_been_called: # We'd like to check that rewards are normalized. Problem is that the trainer can collect data without calling steps... # raise RuntimeError( @@ -935,7 +935,7 @@ def mask_batch(batch: TensorDictBase) -> TensorDictBase: """ if "mask" in batch.keys(): mask = batch.get("mask") - return batch[mask.squeeze(-1)] + return batch[mask] return batch @@ -997,7 +997,7 @@ def __call__(self, batch: TensorDictBase) -> TensorDictBase: if "mask" in batch.keys(): # if a valid mask is present, it's important to sample only # valid steps - traj_len = batch.get("mask").sum(1).squeeze() + traj_len = batch.get("mask").sum(-1) sub_traj_len = max( self.min_sub_traj_len, min(sub_traj_len, traj_len.min().int().item()), diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 78fe2973492..91965a9f091 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -643,7 +643,7 @@ def make_replay_buffer(make_replay_buffer=3): if "mask" in tensordict.keys(): # if multi-step, a mask is present to help filter padded values current_frames = tensordict["mask"].sum() - tensordict = tensordict[tensordict.get("mask").squeeze(-1)] + tensordict = tensordict[tensordict.get("mask")] else: tensordict = tensordict.view(-1) current_frames = tensordict.numel() diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index 78003cbd9b2..c5b728f440c 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -365,7 +365,7 @@ def make_model(): pbar = tqdm.tqdm(total=total_frames) for j, data in enumerate(data_collector): # trajectories are padded to be stored in the same tensordict: since we do not care about consecutive step, we'll just mask the tensordict and get the flattened representation instead. 
- mask = data["mask"].squeeze(-1) + mask = data["mask"] current_frames = mask.sum().cpu().item() pbar.update(current_frames) @@ -602,7 +602,7 @@ def make_model(): pbar = tqdm.tqdm(total=total_frames) for j, data in enumerate(data_collector): - mask = data["mask"].squeeze(-1) + mask = data["mask"] data = pad(data, [0, 0, 0, max_size - data.shape[1]]) current_frames = mask.sum().cpu().item() pbar.update(current_frames) From 031601a25d0c033fa90bc17a4dc7d2574eb901f6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 22 Dec 2022 14:04:39 +0100 Subject: [PATCH 12/23] brax/jumanji fix --- torchrl/envs/libs/brax.py | 15 +++++++++------ torchrl/envs/libs/jumanji.py | 10 +++++----- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index 47dc69ba2a5..7ab3253ba8c 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -117,22 +117,25 @@ def _make_state_spec(self, env: "brax.envs.env.Env"): return state_spec def _make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821 - self._input_spec = CompositeSpec( + self.input_spec = CompositeSpec( action=NdBoundedTensorSpec( minimum=-1, maximum=1, shape=(env.action_size,), device=self.device ) ) - self._reward_spec = NdUnboundedContinuousTensorSpec( - shape=(), device=self.device + self.reward_spec = NdUnboundedContinuousTensorSpec( + shape=[ + 1, + ], + device=self.device, ) - self._observation_spec = CompositeSpec( + self.observation_spec = CompositeSpec( observation=NdUnboundedContinuousTensorSpec( shape=(env.observation_size,), device=self.device ) ) # extract state spec from instance - self._state_spec = self._make_state_spec(env) - self._input_spec["state"] = self._state_spec + self.state_spec = self._make_state_spec(env) + self.input_spec["state"] = self.state_spec def _make_state_example(self): key = jax.random.PRNGKey(0) diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 2e59e452b19..c8e0b2f2a00 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -193,13 +193,13 @@ def _make_reward_spec(self, env) -> TensorSpec: def _make_specs(self, env: "jumanji.env.Environment") -> None: # noqa: F821 # extract spec from jumanji definition - self._input_spec = self._make_input_spec(env) - self._observation_spec = self._make_observation_spec(env) - self._reward_spec = self._make_reward_spec(env) + self.input_spec = self._make_input_spec(env) + self.observation_spec = self._make_observation_spec(env) + self.reward_spec = self._make_reward_spec(env) # extract state spec from instance - self._state_spec = self._make_state_spec(env) - self._input_spec["state"] = self._state_spec + self.state_spec = self._make_state_spec(env) + self.input_spec["state"] = self.state_spec # build state example for data conversion self._state_example = self._make_state_example(env) From d024665bfe1986d70af6bcc122581f508a763386 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 22 Dec 2022 14:11:09 +0100 Subject: [PATCH 13/23] lint --- torchrl/trainers/trainers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 7b779ce1b83..da312aaf9f5 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -810,9 +810,7 @@ def __init__(self, logname="r_training", log_pbar: bool = False): def __call__(self, batch: TensorDictBase) -> Dict: if "mask" in batch.keys(): return { - self.logname: batch.get("reward")[batch.get("mask")] - .mean() - .item(), + 
self.logname: batch.get("reward")[batch.get("mask")].mean().item(), "log_pbar": self.log_pbar, } return { From 4c777633d4c897c31e3d35d1b4ef81a17b8a6ccd Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 22 Dec 2022 14:26:41 +0100 Subject: [PATCH 14/23] reward shape for jumanji --- torchrl/envs/libs/jumanji.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index c8e0b2f2a00..5bf22d8988f 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -59,25 +59,31 @@ def _jumanji_to_torchrl_spec_transform( dtype = numpy_to_torch_dtype_dict[spec.dtype] return action_space_cls(spec.num_values, dtype=dtype, device=device) elif isinstance(spec, jumanji.specs.BoundedArray): + shape = spec.shape + if not len(shape): + shape = torch.Size([1]) if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] return NdBoundedTensorSpec( - shape=spec.shape, + shape=shape, minimum=np.asarray(spec.minimum), maximum=np.asarray(spec.maximum), dtype=dtype, device=device, ) elif isinstance(spec, jumanji.specs.Array): + shape = spec.shape + if not len(shape): + shape = torch.Size([1]) if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] if dtype in (torch.float, torch.double, torch.half): return NdUnboundedContinuousTensorSpec( - shape=spec.shape, dtype=dtype, device=device + shape=shape, dtype=dtype, device=device ) else: return NdUnboundedDiscreteTensorSpec( - shape=spec.shape, dtype=dtype, device=device + shape=shape, dtype=dtype, device=device ) elif isinstance(spec, jumanji.specs.Spec) and hasattr(spec, "__dict__"): new_spec = {} From a5e9b121cf6caa86db824906f314e10d259a92d6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 22 Dec 2022 14:42:05 +0100 Subject: [PATCH 15/23] jumanji --- torchrl/envs/libs/jumanji.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 5bf22d8988f..ff5300560e7 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -227,7 +227,7 @@ def _set_seed(self, seed): def read_state(self, state): state_dict = _object_to_tensordict(state, self.device, self.batch_size) - return self._state_spec.encode(state_dict) + return self.state_spec.encode(state_dict) def read_obs(self, obs): if isinstance(obs, (list, jnp.ndarray, np.ndarray)): From 5e9b603e5677417f895f51d2a9af1f5bdb253487 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 22 Dec 2022 15:31:49 +0100 Subject: [PATCH 16/23] jumanji --- torchrl/data/tensor_specs.py | 6 +++--- torchrl/envs/libs/jumanji.py | 9 ++++----- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index debcb7f1623..0db3b720b28 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -447,7 +447,7 @@ def rand(self, shape=None) -> torch.Tensor: return out else: interval = self.space.maximum - self.space.minimum - r = torch.rand(*shape, *interval.shape, device=interval.device) + r = torch.rand(torch.Size([*shape, *interval.shape]), device=interval.device) r = interval * r r = self.space.minimum + r r = r.to(self.dtype).to(self.device) @@ -533,7 +533,7 @@ def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = torch.Size([]) return torch.nn.functional.gumbel_softmax( - torch.rand(*shape, self.space.n, device=self.device), + torch.rand(torch.Size([*shape, self.space.n]), device=self.device), hard=True, dim=-1, ).to(torch.long) @@ -666,7 +666,7 @@ def 
rand(self, shape=None) -> torch.Tensor: if shape is None: shape = torch.Size([]) interval = self.space.maximum - self.space.minimum - r = torch.rand(*shape, *interval.shape, device=interval.device) + r = torch.rand(torch.Size([*shape, *interval.shape]), device=interval.device) r = r * interval r = self.space.minimum + r r = r.to(self.dtype) diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index ff5300560e7..db0034f4979 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -60,8 +60,6 @@ def _jumanji_to_torchrl_spec_transform( return action_space_cls(spec.num_values, dtype=dtype, device=device) elif isinstance(spec, jumanji.specs.BoundedArray): shape = spec.shape - if not len(shape): - shape = torch.Size([1]) if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] return NdBoundedTensorSpec( @@ -73,8 +71,6 @@ def _jumanji_to_torchrl_spec_transform( ) elif isinstance(spec, jumanji.specs.Array): shape = spec.shape - if not len(shape): - shape = torch.Size([1]) if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] if dtype in (torch.float, torch.double, torch.half): @@ -194,7 +190,10 @@ def _make_observation_spec(self, env) -> TensorSpec: raise TypeError(f"Unsupported spec type {type(spec)}") def _make_reward_spec(self, env) -> TensorSpec: - return _jumanji_to_torchrl_spec_transform(env.reward_spec(), device=self.device) + reward_spec = _jumanji_to_torchrl_spec_transform(env.reward_spec(), device=self.device) + if not len(reward_spec.shape): + reward_spec.shape = torch.Size([1]) + return reward_spec def _make_specs(self, env: "jumanji.env.Environment") -> None: # noqa: F821 From 87735a326cce9be214c64b23a397d1e4b1c8dd00 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 22 Dec 2022 16:16:33 +0100 Subject: [PATCH 17/23] bf --- test/test_libs.py | 8 +++++++- torchrl/data/tensor_specs.py | 8 +++++--- torchrl/envs/libs/jumanji.py | 4 +++- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index c16774de4c2..de3e8c665c4 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -342,7 +342,13 @@ def test_habitat(self, envname): @pytest.mark.skipif(not _has_jumanji, reason="jumanji not installed") -@pytest.mark.parametrize("envname", ["Snake-6x6-v0", "TSP50-v0"]) +@pytest.mark.parametrize( + "envname", + [ + "TSP50-v0", + "Snake-6x6-v0", + ], +) class TestJumanji: def test_jumanji_seeding(self, envname): final_seed = [] diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 0db3b720b28..a46af2881e4 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -447,7 +447,9 @@ def rand(self, shape=None) -> torch.Tensor: return out else: interval = self.space.maximum - self.space.minimum - r = torch.rand(torch.Size([*shape, *interval.shape]), device=interval.device) + r = torch.rand( + torch.Size([*shape, *interval.shape]), device=interval.device + ) r = interval * r r = self.space.minimum + r r = r.to(self.dtype).to(self.device) @@ -1035,7 +1037,7 @@ def __init__( dtype: Optional[Union[str, torch.dtype]] = torch.long, ): if shape is None: - shape = torch.Size((1,)) + shape = torch.Size([]) dtype, device = _default_dtype_and_device(dtype, device) space = DiscreteBox(n) super().__init__(shape, space, device, dtype, domain="discrete") @@ -1070,7 +1072,7 @@ def __eq__(self, other): ) def to_numpy(self, val: TensorDict, safe: bool = True) -> dict: - return super().to_numpy(val, safe).squeeze(-1) + return super().to_numpy(val, safe) class 

From 80e3c35ba7f632254621fb9e4fbb2407ab735ec1 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Thu, 22 Dec 2022 17:39:59 +0100
Subject: [PATCH 18/23] amend

---
 test/test_postprocs.py   | 4 ++--
 test/test_tensor_spec.py | 5 ++---
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/test/test_postprocs.py b/test/test_postprocs.py
index 38cbf806f68..d684793670d 100644
--- a/test/test_postprocs.py
+++ b/test/test_postprocs.py
@@ -108,10 +108,10 @@ def create_fake_trajs(
         td = TensorDict(
             source={
                 "traj_ids": traj_ids,
-                "a": traj_ids.clone(),
+                "a": traj_ids.clone().unsqueeze(-1),
                 "steps_count": steps_count,
                 "workers": workers,
-                "done": done,
+                "done": done.unsqueeze(-1),
             },
             batch_size=[num_workers],
         )
diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py
index 0b64c6df52d..7a8a455615d 100644
--- a/test/test_tensor_spec.py
+++ b/test/test_tensor_spec.py
@@ -53,7 +53,7 @@ def test_discrete(cls):
     r = ts.rand()
     ts.to_numpy(r)
     ts.encode(torch.tensor([5]))
-    ts.encode(torch.tensor([5]).numpy())
+    ts.encode(torch.tensor(5).numpy())
     ts.encode(9)
     with pytest.raises(AssertionError):
         ts.encode(torch.tensor([11]))  # out of bounds
@@ -887,9 +887,8 @@ def test_categorical_action_spec_rand(self):

         sample = action_spec.rand((10000,))

-        sample_list = sample[:, 0]
+        sample_list = sample
         sample_list = [sum(sample_list == i).item() for i in range(10)]
-        print(sample_list)
         assert chisquare(sample_list).pvalue > 0.1

         sample = action_spec.to_numpy(sample)
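The test changes above, and the larger test_cost.py diff that follows, move from masks shaped [batch, time, 1] to masks shaped [batch, time], re-adding the trailing singleton only at the point of use via `unsqueeze(-1)`. A minimal sketch of that masking pattern with made-up sizes:

    import torch

    B, T, F = 2, 5, 3
    obs = torch.randn(B, T, F)
    mask = torch.ones(B, T, dtype=torch.bool)  # validity mask over (batch, time)
    mask[:, -1] = False                        # pretend the last step is padding

    # unsqueeze(-1) restores a trailing singleton so the mask broadcasts
    # over the feature dimension; ~mask selects the entries to zero out.
    obs_masked = obs.masked_fill(~mask.unsqueeze(-1), 0.0)
    assert (obs_masked[:, -1] == 0).all()
    assert obs_masked.shape == (B, T, F)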

From 61c41e97ff074a6facde961d7ec27788256d5ddf Mon Sep 17 00:00:00 2001
From: vmoens
Date: Thu, 22 Dec 2022 21:20:50 +0100
Subject: [PATCH 19/23] amend

---
 test/test_cost.py                   | 98 ++++++++++++++++-------
 torchrl/data/postprocs/postprocs.py |  4 +-
 torchrl/objectives/sac.py           | 12 +++-
 3 files changed, 66 insertions(+), 48 deletions(-)

diff --git a/test/test_cost.py b/test/test_cost.py
index d41f7b385c7..30ecee922dc 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -26,7 +26,6 @@

 # from torchrl.data.postprocs.utils import expand_as_right
 from tensordict.tensordict import assert_allclose_td, TensorDict
-from tensordict.utils import expand_as_right
 from torch import autograd, nn
 from torchrl.data import (
     CompositeSpec,
@@ -253,20 +252,22 @@ def _create_seq_mock_data_dqn(
         if action_spec_type == "categorical":
             action_value = torch.max(action_value, -1, keepdim=True)[0]
             action = torch.argmax(action, -1, keepdim=True)
+        # action_value = action_value.unsqueeze(-1)
         reward = torch.randn(batch, T, 1, device=device)
         done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
-        mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
+        mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device)
         td = TensorDict(
             batch_size=(batch, T),
             source={
-                "observation": obs * mask.to(obs.dtype),
-                "next": {"observation": next_obs * mask.to(obs.dtype)},
+                "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "next": {
+                    "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
+                },
                 "done": done,
                 "mask": mask,
-                "reward": reward * mask.to(obs.dtype),
-                "action": action * mask.to(obs.dtype),
-                "action_value": action_value
-                * expand_as_right(mask.to(obs.dtype), action_value),
+                "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "action_value": action_value.masked_fill_(~mask.unsqueeze(-1), 0.0),
             },
         )
         return td
@@ -488,16 +489,18 @@ def _create_seq_mock_data_ddpg(
         action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
         reward = torch.randn(batch, T, 1, device=device)
         done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
-        mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
+        mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device)
         td = TensorDict(
             batch_size=(batch, T),
             source={
-                "observation": obs * mask.to(obs.dtype),
-                "next": {"observation": next_obs * mask.to(obs.dtype)},
+                "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "next": {
+                    "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
+                },
                 "done": done,
                 "mask": mask,
-                "reward": reward * mask.to(obs.dtype),
-                "action": action * mask.to(obs.dtype),
+                "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
             },
             device=device,
         )
@@ -726,16 +729,18 @@ def _create_seq_mock_data_sac(
         action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
         reward = torch.randn(batch, T, 1, device=device)
         done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
-        mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
+        mask = torch.ones(batch, T, dtype=torch.bool, device=device)
         td = TensorDict(
             batch_size=(batch, T),
             source={
-                "observation": obs * mask.to(obs.dtype),
-                "next": {"observation": next_obs * mask.to(obs.dtype)},
+                "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "next": {
+                    "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
+                },
                 "done": done,
                 "mask": mask,
-                "reward": reward * mask.to(obs.dtype),
-                "action": action * mask.to(obs.dtype),
+                "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
             },
             device=device,
         )
@@ -1129,16 +1134,18 @@ def _create_seq_mock_data_redq(
         action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
         reward = torch.randn(batch, T, 1, device=device)
         done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
-        mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
+        mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device)
         td = TensorDict(
             batch_size=(batch, T),
             source={
-                "observation": obs * mask.to(obs.dtype),
-                "next": {"observation": next_obs * mask.to(obs.dtype)},
+                "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "next": {
+                    "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
+                },
                 "done": done,
                 "mask": mask,
-                "reward": reward * mask.to(obs.dtype),
-                "action": action * mask.to(obs.dtype),
+                "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
             },
             device=device,
         )
@@ -1564,23 +1571,25 @@ def _create_seq_mock_data_ppo(
         action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
         reward = torch.randn(batch, T, 1, device=device)
         done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
-        mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
+        mask = torch.ones(batch, T, dtype=torch.bool, device=device)
         params_mean = torch.randn_like(action) / 10
         params_scale = torch.rand_like(action) / 10
         td = TensorDict(
             batch_size=(batch, T),
             source={
-                "observation": obs * mask.to(obs.dtype),
-                "next": {"observation": next_obs * mask.to(obs.dtype)},
+                "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "next": {
+                    "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
+                },
                 "done": done,
                 "mask": mask,
-                "reward": reward * mask.to(obs.dtype),
-                "action": action * mask.to(obs.dtype),
-                "sample_log_prob": torch.randn_like(action[..., :1])
-                / 10
-                * mask.to(obs.dtype),
-                "loc": params_mean * mask.to(obs.dtype),
-                "scale": params_scale * mask.to(obs.dtype),
+                "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "sample_log_prob": (
+                    torch.randn_like(action[..., :1]) / 10
+                ).masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0),
             },
             device=device,
         )
@@ -1835,23 +1844,26 @@ def _create_seq_mock_data_a2c(
         action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
         reward = torch.randn(batch, T, 1, device=device)
         done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
-        mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
+        mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device)
         params_mean = torch.randn_like(action) / 10
         params_scale = torch.rand_like(action) / 10
         td = TensorDict(
             batch_size=(batch, T),
             source={
-                "observation": obs * mask.to(obs.dtype),
-                "next": {"observation": next_obs * mask.to(obs.dtype)},
+                "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "next": {
+                    "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
+                },
                 "done": done,
                 "mask": mask,
-                "reward": reward * mask.to(obs.dtype),
-                "action": action * mask.to(obs.dtype),
-                "sample_log_prob": torch.randn_like(action[..., :1])
-                / 10
-                * mask.to(obs.dtype),
-                "loc": params_mean * mask.to(obs.dtype),
-                "scale": params_scale * mask.to(obs.dtype),
+                "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "sample_log_prob": torch.randn_like(action[..., :1]).masked_fill_(
+                    ~mask.unsqueeze(-1), 0.0
+                )
+                / 10,
+                "loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0),
             },
             device=device,
         )
diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py
index 8d7e47d3753..922880387a5 100644
--- a/torchrl/data/postprocs/postprocs.py
+++ b/torchrl/data/postprocs/postprocs.py
@@ -99,7 +99,7 @@ def _select_and_repeat(
     )
     tensor_cat = torch.cat([tensor, tensor_repeat], 1) + post_terminal_tensor
     tensor_cat = tensor_cat[:, -T:]
-    mask = expand_as_right(mask, tensor_cat)
+    mask = mask.expand_as(tensor_cat)
     return tensor_cat.masked_fill(~mask, 0.0)


@@ -170,7 +170,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         done = tensordict.get("done")

         if "mask" in tensordict.keys():
-            mask = tensordict.get("mask")
+            mask = tensordict.get("mask").view_as(done)
         else:
             mask = done.clone().flip(1).cumsum(1).flip(1).to(torch.bool)
         reward = tensordict.get("reward")
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index 994bb1cab7c..c5d31e19f5a 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -177,23 +177,29 @@ def device(self) -> torch.device:
         )

     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        shape = None
         if tensordict.ndimension() > 1:
-            tensordict = tensordict.view(-1)
+            shape = tensordict.shape
+            tensordict_reshape = tensordict.reshape(-1)
+        else:
+            tensordict_reshape = tensordict
         device = self.device
-        td_device = tensordict.to(device)
+        td_device = tensordict_reshape.to(device)
         loss_actor = self._loss_actor(td_device)
         loss_qvalue, priority = self._loss_qvalue(td_device)
         loss_value = self._loss_value(td_device)
         loss_alpha = self._loss_alpha(td_device)
-        tensordict.set(self.priority_key, priority)
+        tensordict_reshape.set(self.priority_key, priority)

         if (loss_actor.shape != loss_qvalue.shape) or (
             loss_actor.shape != loss_value.shape
         ):
             raise RuntimeError(
                 f"Losses shape mismatch: {loss_actor.shape}, {loss_qvalue.shape} and {loss_value.shape}"
             )
+        if shape:
+            tensordict.update(tensordict_reshape.view(shape))
         return TensorDict(
             {
                 "loss_actor": loss_actor.mean(),
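The SAC change above flattens the input tensordict for the loss computation and then writes the result back into the original layout, so that keys added on the flat view (such as the priority) are visible on the caller's tensordict. A sketch of that round trip with made-up shapes and key names, using the same `reshape`/`view`/`update` calls as the diff:

    import torch
    from tensordict import TensorDict

    td = TensorDict({"reward": torch.randn(3, 4, 1)}, batch_size=[3, 4])
    shape = td.shape
    td_flat = td.reshape(-1)                    # losses are computed on a flat batch
    td_flat.set("td_error", torch.rand(12, 1))  # e.g. a per-sample priority
    td.update(td_flat.view(shape))              # propagate new keys back to [3, 4]
    assert td["td_error"].shape == torch.Size([3, 4, 1])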

From f693cf9c0b3bb38ff9c17d4c08eac6aa8017ea8f Mon Sep 17 00:00:00 2001
From: vmoens
Date: Fri, 23 Dec 2022 08:40:37 +0100
Subject: [PATCH 20/23] amend

---
 torchrl/trainers/helpers/models.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py
index 3294c2b87fa..947e3e9a613 100644
--- a/torchrl/trainers/helpers/models.py
+++ b/torchrl/trainers/helpers/models.py
@@ -169,7 +169,6 @@ def make_dqn_actor(

     # automatically infer in key
     (in_key,) = itertools.islice(env_specs["observation_spec"], 1)
-    out_features = action_spec.shape[0]

     actor_class = QValueActor
     actor_kwargs = {}
@@ -178,6 +177,8 @@ def make_dqn_actor(
         # to the number of possible choices and also set categorical behavioural for actors.
         actor_kwargs.update({"action_space": "categorical"})
         out_features = env_specs["action_spec"].space.n
+    else:
+        out_features = action_spec.shape[0]

     if cfg.distributional:
         if not atoms:

From 0b47c80ea48f8dffa7d9fb1b5e6b0d7d48e75fa2 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Fri, 23 Dec 2022 09:20:08 +0100
Subject: [PATCH 21/23] bf

---
 torchrl/data/postprocs/postprocs.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py
index 922880387a5..6463fc02d72 100644
--- a/torchrl/data/postprocs/postprocs.py
+++ b/torchrl/data/postprocs/postprocs.py
@@ -99,7 +99,7 @@ def _select_and_repeat(
     )
     tensor_cat = torch.cat([tensor, tensor_repeat], 1) + post_terminal_tensor
     tensor_cat = tensor_cat[:, -T:]
-    mask = mask.expand_as(tensor_cat)
+    mask = expand_as_right(mask.squeeze(-1), tensor_cat)
     return tensor_cat.masked_fill(~mask, 0.0)

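Patch 21 above swaps `Tensor.expand_as` back for `expand_as_right` from `tensordict.utils` (the helper already used earlier in that file), after squeezing the trailing singleton of the mask. `expand_as_right` adds singleton dims on the right before expanding, which is what lets a [batch, time] mask line up with a [batch, time, feature] tensor. A short sketch with illustrative shapes:

    import torch
    from tensordict.utils import expand_as_right

    mask = torch.ones(2, 5, 1, dtype=torch.bool)  # done-style mask with a trailing 1
    tensor_cat = torch.randn(2, 5, 4)

    full_mask = expand_as_right(mask.squeeze(-1), tensor_cat)
    assert full_mask.shape == tensor_cat.shape
    out = tensor_cat.masked_fill(~full_mask, 0.0)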

From 7448875582e54947742561f7b2547b37487d35ea Mon Sep 17 00:00:00 2001
From: vmoens
Date: Fri, 23 Dec 2022 09:51:10 +0100
Subject: [PATCH 22/23] bf

---
 torchrl/objectives/ppo.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index 26b73e70612..0ab4924e271 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -105,13 +105,12 @@ def _log_weight(

         dist = self.actor.get_dist(tensordict_clone, params=self.actor_params)
         log_prob = dist.log_prob(action)
-        log_prob = log_prob.unsqueeze(-1)

         prev_log_prob = tensordict.get("sample_log_prob")
         if prev_log_prob.requires_grad:
             raise RuntimeError("tensordict prev_log_prob requires grad.")

-        log_weight = log_prob - prev_log_prob
+        log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
         return log_weight, dist

     def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:

From c519d669e02c0310544c5f7d03ac9fcd07770811 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Fri, 23 Dec 2022 10:24:10 +0100
Subject: [PATCH 23/23] bf

---
 test/test_cost.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/test/test_cost.py b/test/test_cost.py
index 30ecee922dc..08c705cbfcd 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -1550,7 +1550,7 @@ def _create_mock_data_ppo(
             "done": done,
             "reward": reward,
             "action": action,
-            "sample_log_prob": torch.randn_like(action[..., :1]) / 10,
+            "sample_log_prob": torch.randn_like(action[..., 1]) / 10,
         },
         device=device,
     )
@@ -1585,9 +1585,9 @@ def _create_seq_mock_data_ppo(
                 "mask": mask,
                 "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
                 "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
-                "sample_log_prob": (
-                    torch.randn_like(action[..., :1]) / 10
-                ).masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "sample_log_prob": (torch.randn_like(action[..., 1]) / 10).masked_fill_(
+                    ~mask, 0.0
+                ),
                 "loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0),
                 "scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0),
             },
@@ -1858,8 +1858,8 @@ def _create_seq_mock_data_a2c(
                 "mask": mask,
                 "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
                 "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
-                "sample_log_prob": torch.randn_like(action[..., :1]).masked_fill_(
-                    ~mask.unsqueeze(-1), 0.0
+                "sample_log_prob": torch.randn_like(action[..., 1]).masked_fill_(
+                    ~mask, 0.0
                 )
                 / 10,
                 "loc": params_mean.masked_fill_(~mask.unsqueeze(-1),