From 51ef1edcfe597bcb692af82fb0bde89764866903 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 7 Mar 2024 21:08:38 +0000 Subject: [PATCH 01/14] init --- torchrl/data/tensor_specs.py | 4 ++-- torchrl/envs/common.py | 28 +++++++++++++++++++++++----- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 69f3c7b3efd..3632e6d6f42 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1555,9 +1555,9 @@ def __init__( dtype = torch.get_default_dtype() if not isinstance(low, torch.Tensor): - low = torch.as_tensor(low, dtype=dtype, device=device) + low = torch.tensor(low, dtype=dtype, device=device) if not isinstance(high, torch.Tensor): - high = torch.as_tensor(high, dtype=dtype, device=device) + high = torch.tensor(high, dtype=dtype, device=device) if high.device != device: high = high.to(device) if low.device != device: diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 06a3ddfee3f..2a625f4b883 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2661,6 +2661,19 @@ def step_and_maybe_reset( tensordict_ = self.maybe_reset(tensordict_) return tensordict, tensordict_ + @property + def _simple_done(self): + _simple_done = self.__dict__.get("_simple_done_value", None) + if _simple_done is None: + key_set = set(self.full_done_spec.keys()) + _simple_done = key_set == { + "done", + "truncated", + "terminated", + } or key_set == {"done", "terminated"} + self.__dict__["_simple_done_value"] = _simple_done + return _simple_done + def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase: """Checks the done keys of the input tensordict and, if needed, resets the environment where it is done. @@ -2672,11 +2685,16 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase: not reset and contains the new reset data where the environment was reset. """ - any_done = _terminated_or_truncated( - tensordict, - full_done_spec=self.output_spec["full_done_spec"], - key="_reset", - ) + if self._simple_done: + done = tensordict._get_str("done", default=None) + tensordict._set_str("done", done.clone(), validated=True, inplace=False) + any_done = done.any() + else: + any_done = _terminated_or_truncated( + tensordict, + full_done_spec=self.output_spec["full_done_spec"], + key="_reset", + ) if any_done: tensordict = self.reset(tensordict) return tensordict From 63a3fbbe4ce3e35ca14ae6d50067a194a1b9f7d4 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 8 Mar 2024 10:03:08 +0000 Subject: [PATCH 02/14] amend --- test/test_env.py | 49 ++++++++++++ torchrl/data/tensor_specs.py | 4 +- torchrl/envs/common.py | 31 +++----- torchrl/envs/utils.py | 148 +++++++++++++++++++++++++++++++++++ 4 files changed, 211 insertions(+), 21 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index d13a7ed7934..1241cde1dc9 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -79,6 +79,7 @@ from torchrl.envs.libs.gym import _has_gym, GymEnv, GymWrapper from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv from torchrl.envs.utils import ( + _StepMDP, _terminated_or_truncated, check_env_specs, check_marl_grouping, @@ -1312,6 +1313,54 @@ def test_steptensordict( if has_out: assert out is next_tensordict + @pytest.mark.parametrize("keep_other", [True, False]) + @pytest.mark.parametrize("exclude_reward", [True, False]) + @pytest.mark.parametrize("exclude_done", [False, True]) + @pytest.mark.parametrize("exclude_action", [False, True]) + @pytest.mark.parametrize( + "envcls", + [ + ContinuousActionVecMockEnv, + CountingBatchedEnv, + CountingEnv, + NestedCountingEnv, + CountingBatchedEnv, + HeterogeneousCountingEnv, + DiscreteActionConvMockEnv, + ], + ) + def test_step_class( + self, + envcls, + keep_other, + exclude_reward, + exclude_done, + exclude_action, + ): + torch.manual_seed(0) + env = envcls() + + tensordict = env.rand_step(env.reset()) + out = step_mdp( + tensordict.lock_(), + keep_other=keep_other, + exclude_reward=exclude_reward, + exclude_done=exclude_done, + exclude_action=exclude_action, + done_keys=env.done_keys, + action_keys=env.action_keys, + reward_keys=env.reward_keys, + ) + step_func = _StepMDP( + env, + keep_other=keep_other, + exclude_reward=exclude_reward, + exclude_done=exclude_done, + exclude_action=exclude_action, + ) + out2 = step_func(tensordict) + assert (out == out2).all() + @pytest.mark.parametrize("nested_obs", [True, False]) @pytest.mark.parametrize("nested_action", [True, False]) @pytest.mark.parametrize("nested_done", [True, False]) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 3632e6d6f42..3bf92a0e7c4 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -576,9 +576,9 @@ def encode( ): val = val.copy() if not ignore_device: - val = torch.as_tensor(val, device=self.device, dtype=self.dtype) + val = torch.tensor(val, device=self.device, dtype=self.dtype) else: - val = torch.as_tensor(val, dtype=self.dtype) + val = torch.tensor(val, dtype=self.dtype) if val.shape != self.shape: # if val.shape[-len(self.shape) :] != self.shape: # option 1: add a singleton dim at the end diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 2a625f4b883..d189626d15d 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -28,10 +28,10 @@ from torchrl.envs.utils import ( _make_compatible_policy, _repr_by_depth, + _StepMDP, _terminated_or_truncated, _update_during_reset, get_available_libraries, - step_mdp, ) LIBRARIES = get_available_libraries() @@ -2513,6 +2513,14 @@ def rollout( out_td.refine_names(..., "time") return out_td + @property + def _step_mdp(self): + step_func = self.__dict__.get("_step_mdp_value", None) + if step_func is None: + step_func = _StepMDP(self, exclude_action=False) + self.__dict__["_step_mdp_value"] = step_func + return step_func + def _rollout_stop_early( self, *, @@ -2543,15 +2551,8 @@ def _rollout_stop_early( if i == max_steps - 1: # we don't truncated as one could potentially continue the run break - tensordict = step_mdp( - tensordict, - keep_other=True, - exclude_action=False, - exclude_reward=True, - reward_keys=self.reward_keys, - action_keys=self.action_keys, - done_keys=self.done_keys, - ) + tensordict = self._step_mdp_stop_early(tensordict) + # done and truncated are in done_keys # We read if any key is done. any_done = _terminated_or_truncated( @@ -2649,15 +2650,7 @@ def step_and_maybe_reset( tensordict = self.step(tensordict) # done and truncated are in done_keys # We read if any key is done. - tensordict_ = step_mdp( - tensordict, - keep_other=True, - exclude_action=False, - exclude_reward=True, - reward_keys=self.reward_keys, - action_keys=self.action_keys, - done_keys=self.done_keys, - ) + tensordict_ = self._step_mdp(tensordict) tensordict_ = self.maybe_reset(tensordict_) return tensordict, tensordict_ diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index e779bfc165d..d50473995c4 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -12,6 +12,7 @@ import inspect import os import re +import warnings from enum import Enum from typing import Any, Dict, List, Union @@ -20,6 +21,7 @@ from tensordict import ( is_tensor_collection, LazyStackedTensorDict, + TensorDict, TensorDictBase, unravel_key, ) @@ -41,6 +43,7 @@ from torchrl.data.tensor_specs import ( CompositeSpec, + NO_DEFAULT, TensorSpec, UnboundedContinuousTensorSpec, ) @@ -82,6 +85,150 @@ def __get__(self, cls, owner): return classmethod(self.fget).__get__(None, owner)() +class _StepMDP: + """Stateful version of step_mdp. + + Precomputes the list of keys to include and exclude during a call to step_mdp + to reduce runtime. + + """ + + def __init__( + self, + env, + *, + keep_other: bool = True, + exclude_reward: bool = True, + exclude_done: bool = False, + exclude_action: bool = True, + ): + action_keys = env.action_keys + done_keys = env.done_keys + reward_keys = env.reward_keys + observation_keys = env.full_observation_spec.keys(True, True) + state_keys = env.full_state_spec.keys(True, True) + self.action_keys = [unravel_key(key) for key in action_keys] + self.done_keys = [unravel_key(key) for key in done_keys] + self.reward_keys = [unravel_key(key) for key in reward_keys] + self.observation_keys = [unravel_key(key) for key in observation_keys] + self.state_keys = [unravel_key(key) for key in state_keys] + + excluded = set() + if exclude_reward: + excluded = excluded.union(self.reward_keys) + if exclude_done: + excluded = excluded.union(self.done_keys) + if exclude_action: + excluded = excluded.union(self.action_keys) + + self.excluded = [unravel_key(key) for key in excluded] + + self.keep_other = keep_other + self.exclude_action = exclude_action + + self.keys_from_next = list(self.observation_keys) + if not exclude_reward: + self.keys_from_next += self.reward_keys + if not exclude_done: + self.keys_from_next += self.done_keys + self.keys_from_root = [] + if not exclude_action: + self.keys_from_root += self.action_keys + if keep_other: + self.keys_from_root += self.state_keys + self.keys_from_root = self._repr_key_list_as_tree(self.keys_from_root) + self.keys_from_next = self._repr_key_list_as_tree(self.keys_from_next) + self.validated = None + + def validate(self, tensordict): + if self.validated: + return True + if self.validated is None: + # check that the key set of the tensordict matches what is expected + expected = ( + self.state_keys + + self.action_keys + + self.done_keys + + self.observation_keys + + [unravel_key(("next", key)) for key in self.observation_keys] + + [unravel_key(("next", key)) for key in self.done_keys] + + [unravel_key(("next", key)) for key in self.reward_keys] + ) + actual = set(tensordict.keys(True, True)) + self.validated = set(expected) == actual + if not self.validated: + warnings.warn( + "The expected key set and actual key set differ. " + "This will work but with a slower throughput than " + "when the specs match exactly the actual key set " + "in the data. " + f"Expected - Actual keys={set(expected) - actual}, \n" + f"Actual - Expected keys={actual- set(expected)}." + ) + return self.validated + + @staticmethod + def _repr_key_list_as_tree(key_list): + """Represents the keys as a tree to facilitate iteration.""" + key_dict = {key: torch.zeros(()) for key in key_list} + td = TensorDict(key_dict) + return tree_map(lambda x: None, td.to_dict()) + + @classmethod + def _grab_and_place( + cls, nested_key_dict: dict, data_in: TensorDictBase, data_out: TensorDictBase + ): + for key, subdict in nested_key_dict.items(): + val = data_in._get_str(key, NO_DEFAULT) + if subdict is not None: + val_out = data_out._get_str(key, None) + if val_out is None: + val_out = val.empty() + if isinstance(val, LazyStackedTensorDict): + + val = LazyStackedTensorDict( + *( + cls._grab_and_place(subdict, _val, _val_out) + for (_val, _val_out) in zip( + val.unbind(val.stack_dim), + val_out.unbind(val_out.stack_dim), + ) + ), + stack_dim=val.stack_dim, + ) + else: + val = cls._grab_and_place(subdict, val, val_out) + data_out._set_str(key, val, validated=True, inplace=False) + return data_out + + def __call__(self, tensordict): + if isinstance(tensordict, LazyStackedTensorDict): + out = LazyStackedTensorDict.lazy_stack( + [self.__call__(td) for td in tensordict.tensordicts], + tensordict.stack_dim, + ) + return out + + next_td = tensordict._get_str("next", None) + out = next_td.empty() + if self.validate(tensordict): + self._grab_and_place(self.keys_from_root, tensordict, out) + self._grab_and_place(self.keys_from_next, next_td, out) + return out + else: + total_key = () + if self.keep_other: + for key in tensordict.keys(): + if key != "next": + _set(tensordict, out, key, total_key, self.excluded) + elif not self.exclude_action: + for action_key in self.action_keys: + _set_single_key(tensordict, out, action_key) + for key in next_td.keys(): + _set(next_td, out, key, total_key, self.excluded) + return out + + def step_mdp( tensordict: TensorDictBase, next_tensordict: TensorDictBase = None, @@ -297,6 +444,7 @@ def _set(source, dest, key, total_key, excluded): try: val = source.get(key) if is_tensor_collection(val): + # if val is a tensordict we need to copy the structure new_val = dest.get(key, None) if new_val is None: new_val = val.empty() From 44f99125629aedf59dfeb0e9296e5d88a7e912e6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 8 Mar 2024 10:03:32 +0000 Subject: [PATCH 03/14] amend --- torchrl/data/tensor_specs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 3bf92a0e7c4..3632e6d6f42 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -576,9 +576,9 @@ def encode( ): val = val.copy() if not ignore_device: - val = torch.tensor(val, device=self.device, dtype=self.dtype) + val = torch.as_tensor(val, device=self.device, dtype=self.dtype) else: - val = torch.tensor(val, dtype=self.dtype) + val = torch.as_tensor(val, dtype=self.dtype) if val.shape != self.shape: # if val.shape[-len(self.shape) :] != self.shape: # option 1: add a singleton dim at the end From c497b81c1c64eb5af6520f968349a3e4ad4c7d78 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 8 Mar 2024 10:59:18 +0000 Subject: [PATCH 04/14] amend --- torchrl/envs/gym_like.py | 24 +++++++--------- torchrl/envs/libs/gym.py | 10 +++++++ torchrl/envs/transforms/transforms.py | 41 ++++++++++++++------------- 3 files changed, 42 insertions(+), 33 deletions(-) diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index d3b3dfd659c..4fc2b4abc7d 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -235,7 +235,14 @@ def read_reward(self, reward): reward (torch.Tensor or TensorDict): reward to be mapped. """ - return self.reward_spec.encode(reward, ignore_device=True) + if isinstance(reward, int) and reward == 0: + return self.reward_spec.zero() + reward = self.reward_spec.encode(reward, ignore_device=True) + + if reward is None: + reward = torch.tensor(np.nan).expand(self.reward_spec.shape) + + return reward def read_obs( self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray] @@ -277,14 +284,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: done, info_dict, ) = self._output_transform(self._env.step(action_np)) - if isinstance(obs, list) and len(obs) == 1: - # Until gym 0.25.2 we had rendered frames returned in lists of length 1 - obs = obs[0] - - if _reward is None: - _reward = self.reward_spec.zero() - reward = reward + _reward + if _reward is not None: + reward = reward + _reward terminated, truncated, done, do_break = self.read_done( terminated=terminated, truncated=truncated, done=done @@ -294,17 +296,13 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: reward = self.read_reward(reward) obs_dict = self.read_obs(obs) - - if reward is None: - reward = torch.tensor(np.nan).expand(self.reward_spec.shape) - obs_dict[self.reward_key] = reward # if truncated/terminated is not in the keys, we just don't pass it even if it # is defined. if terminated is None: terminated = done - if truncated is not None and "truncated" in self.done_keys: + if truncated is not None: obs_dict["truncated"] = truncated obs_dict["done"] = done obs_dict["terminated"] = terminated diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index a419b013722..aa43717e296 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -1062,6 +1062,11 @@ def _output_transform(self, step_outputs_tuple): # noqa: F811 # if it's not a ndarray, we must return bool # since it's not a bool, we make it so terminated = bool(terminated) + + if isinstance(observations, list) and len(obs) == 1: + # Until gym 0.25.2 we had rendered frames returned in lists of length 1 + obs = obs[0] + return (observations, reward, terminated, truncated, done, info) @implement_for("gym", "0.24", "0.26") @@ -1083,6 +1088,11 @@ def _output_transform(self, step_outputs_tuple): # noqa: F811 # if it's not a ndarray, we must return bool # since it's not a bool, we make it so terminated = bool(terminated) + + if isinstance(observations, list) and len(obs) == 1: + # Until gym 0.25.2 we had rendered frames returned in lists of length 1 + obs = obs[0] + return (observations, reward, terminated, truncated, done, info) @implement_for("gym", "0.26", None) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 6448bd9bd0f..b17bd06e966 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -732,7 +732,9 @@ def input_spec(self) -> TensorSpec: return input_spec def _step(self, tensordict: TensorDictBase) -> TensorDictBase: - tensordict = tensordict.clone(False) + # No need to clone here because inv does it already + # tensordict = tensordict.clone(False) + next_preset = tensordict.get("next", None) tensordict_in = self.transform.inv(tensordict) next_tensordict = self.base_env._step(tensordict_in) @@ -3831,30 +3833,29 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: self.in_keys = self._find_in_keys() self._initialized = True - if all(key in tensordict.keys(include_nested=True) for key in self.in_keys): - values = [tensordict.get(key) for key in self.in_keys] - if self.unsqueeze_if_oor: - pos_idx = self.dim > 0 - abs_idx = self.dim if pos_idx else -self.dim - 1 - values = [ - v - if abs_idx < v.ndimension() - else v.unsqueeze(0) - if not pos_idx - else v.unsqueeze(-1) - for v in values - ] - - out_tensor = torch.cat(values, dim=self.dim) - tensordict.set(self.out_keys[0], out_tensor) - if self._del_keys: - tensordict.exclude(*self.keys_to_exclude, inplace=True) - else: + values = [tensordict.get(key, None) for key in self.in_keys] + if any(value is None for value in values): raise Exception( f"CatTensor failed, as it expected input keys =" f" {sorted(self.in_keys, key=_sort_keys)} but got a TensorDict with keys" f" {sorted(tensordict.keys(include_nested=True), key=_sort_keys)}" ) + if self.unsqueeze_if_oor: + pos_idx = self.dim > 0 + abs_idx = self.dim if pos_idx else -self.dim - 1 + values = [ + v + if abs_idx < v.ndimension() + else v.unsqueeze(0) + if not pos_idx + else v.unsqueeze(-1) + for v in values + ] + + out_tensor = torch.cat(values, dim=self.dim) + tensordict.set(self.out_keys[0], out_tensor) + if self._del_keys: + tensordict.exclude(*self.keys_to_exclude, inplace=True) return tensordict forward = _call From 9d1205aa2af866fef9c60c4f3386f975c9b7357f Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 8 Mar 2024 11:43:30 +0000 Subject: [PATCH 05/14] amend --- torchrl/data/tensor_specs.py | 10 +++++++--- torchrl/envs/gym_like.py | 25 ++++++++++++++++--------- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 3632e6d6f42..ebb2bd5cf98 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -3599,13 +3599,17 @@ def project(self, val: TensorDictBase) -> TensorDictBase: def rand(self, shape=None) -> TensorDictBase: if shape is None: shape = torch.Size([]) - _dict = { - key: self[key].rand(shape) for key in self.keys() if self[key] is not None - } + _dict = {} + for key, item in self.items(): + if item is not None: + _dict[key] = item.rand(shape) return TensorDict( _dict, batch_size=[*shape, *self.shape], device=self._device, + # No need to run checks since we know Composite is compliant with + # TensorDict requirements + _run_checks=False, ) def keys( diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 4fc2b4abc7d..f04567206c0 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -21,7 +21,7 @@ UnboundedContinuousTensorSpec, ) from torchrl.envs.common import _EnvWrapper - +from collections import Mapping class BaseInfoDictReader(metaclass=abc.ABCMeta): """Base class for info-readers.""" @@ -260,14 +260,21 @@ def read_obs( # naming it 'state' will result in envs that have a different name for the state vector # when queried with and without pixels observations["observation"] = observations.pop("state") - if not isinstance(observations, (TensorDict, dict)): - (key,) = itertools.islice(self.observation_spec.keys(True, True), 1) - observations = {key: observations} - for key, val in observations.items(): - observations[key] = self.observation_spec[key].encode( - val, ignore_device=True - ) - # observations = self.observation_spec.encode(observations, ignore_device=True) + if not isinstance(observations, Mapping): + observations_dict = None + for key, spec in self.observation_spec.items(True, True): + if observations_dict is None: + observations_dict = {} + observations_dict[key] = spec.encode(observations, ignore_device=True) + else: + raise RuntimeError("There is more than one observation key but only one observation " + "was found in the step data.") + observations = observations_dict + else: + for key, val in observations.items(): + observations[key] = self.observation_spec[key].encode( + val, ignore_device=True + ) return observations def _step(self, tensordict: TensorDictBase) -> TensorDictBase: From 8ceb5c202de6f1e064e834cdf2f2d48bfa6c5ce9 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 8 Mar 2024 11:47:54 +0000 Subject: [PATCH 06/14] amend --- torchrl/envs/gym_like.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index f04567206c0..710cfff0e9f 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -21,7 +21,7 @@ UnboundedContinuousTensorSpec, ) from torchrl.envs.common import _EnvWrapper -from collections import Mapping +from typing import Mapping class BaseInfoDictReader(metaclass=abc.ABCMeta): """Base class for info-readers.""" From ecbc56a6745cc8d7e70b9597b15953cedc4359e0 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 8 Mar 2024 13:37:15 +0000 Subject: [PATCH 07/14] amend --- torchrl/envs/common.py | 2 +- torchrl/envs/gym_like.py | 17 +++++++++++------ torchrl/envs/libs/gym.py | 4 ++-- torchrl/envs/transforms/transforms.py | 15 ++++++++++++--- 4 files changed, 26 insertions(+), 12 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index d189626d15d..7808de80ad2 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2551,7 +2551,7 @@ def _rollout_stop_early( if i == max_steps - 1: # we don't truncated as one could potentially continue the run break - tensordict = self._step_mdp_stop_early(tensordict) + tensordict = self._step_mdp(tensordict) # done and truncated are in done_keys # We read if any key is done. diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 710cfff0e9f..73d8cded45e 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -8,7 +8,7 @@ import abc import itertools import warnings -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -21,7 +21,7 @@ UnboundedContinuousTensorSpec, ) from torchrl.envs.common import _EnvWrapper -from typing import Mapping + class BaseInfoDictReader(metaclass=abc.ABCMeta): """Base class for info-readers.""" @@ -265,10 +265,14 @@ def read_obs( for key, spec in self.observation_spec.items(True, True): if observations_dict is None: observations_dict = {} - observations_dict[key] = spec.encode(observations, ignore_device=True) + observations_dict[key] = spec.encode( + observations, ignore_device=True + ) else: - raise RuntimeError("There is more than one observation key but only one observation " - "was found in the step data.") + raise RuntimeError( + "There is more than one observation key but only one observation " + "was found in the step data." + ) observations = observations_dict else: for key, val in observations.items(): @@ -327,7 +331,8 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_out = TensorDict( obs_dict, batch_size=tensordict.batch_size, _run_checks=False ) - tensordict_out = tensordict_out.to(self.device, non_blocking=True) + if self.device is not None: + tensordict_out = tensordict_out.to(self.device, non_blocking=True) if self.info_dict_reader and (info_dict is not None): if not isinstance(info_dict, dict): diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index aa43717e296..045261059b9 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -1089,9 +1089,9 @@ def _output_transform(self, step_outputs_tuple): # noqa: F811 # since it's not a bool, we make it so terminated = bool(terminated) - if isinstance(observations, list) and len(obs) == 1: + if isinstance(observations, list) and len(observations) == 1: # Until gym 0.25.2 we had rendered frames returned in lists of length 1 - obs = obs[0] + observations = observations[0] return (observations, reward, terminated, truncated, done, info) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index b17bd06e966..fb7e778f957 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -317,9 +317,8 @@ def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor: return state def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: - # # We create a shallow copy of the tensordict to avoid that changes are - # # exposed to the user: we'd like that the input keys remain unchanged - # # in the originating script if they're being transformed. + if not self.in_keys_inv: + return tensordict for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): data = tensordict.get(in_key, None) if data is not None: @@ -3281,6 +3280,16 @@ def __init__( in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, ): + if in_keys is not None and in_keys_inv is None: + warnings.warn( + "in_keys have been provided but not in_keys_inv. From v0.5, " + "this will result in in_keys_inv being an empty list whereas " + "now the input keys are retrieved automatically. " + "To silence this warning, pass the (possibly empty) " + "list of in_keys_inv.", + category=DeprecationWarning, + ) + self.dtype_in = dtype_in self.dtype_out = dtype_out super().__init__( From f85e309ce676f58d556d0fb1a20a1069d42a6c6f Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 8 Mar 2024 13:38:41 +0000 Subject: [PATCH 08/14] amend --- torchrl/envs/libs/gym.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 045261059b9..5c5a387f762 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -1063,9 +1063,9 @@ def _output_transform(self, step_outputs_tuple): # noqa: F811 # since it's not a bool, we make it so terminated = bool(terminated) - if isinstance(observations, list) and len(obs) == 1: + if isinstance(observations, list) and len(observations) == 1: # Until gym 0.25.2 we had rendered frames returned in lists of length 1 - obs = obs[0] + observations = observations[0] return (observations, reward, terminated, truncated, done, info) From dbfffff9cbcc8d6caa1ac7af6bee99845f613016 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 8 Mar 2024 15:41:06 +0000 Subject: [PATCH 09/14] amend --- torchrl/data/tensor_specs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index ebb2bd5cf98..98b365df28b 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -3605,7 +3605,7 @@ def rand(self, shape=None) -> TensorDictBase: _dict[key] = item.rand(shape) return TensorDict( _dict, - batch_size=[*shape, *self.shape], + batch_size=torch.Size([*shape, *self.shape]), device=self._device, # No need to run checks since we know Composite is compliant with # TensorDict requirements From a00dd1246603d6e3986492aca25b0fdd25fadf29 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 8 Mar 2024 16:02:38 +0000 Subject: [PATCH 10/14] amend --- torchrl/envs/common.py | 2 +- torchrl/envs/transforms/transforms.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 7808de80ad2..d885a3ae832 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2680,7 +2680,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase: """ if self._simple_done: done = tensordict._get_str("done", default=None) - tensordict._set_str("done", done.clone(), validated=True, inplace=False) + tensordict._set_str("_reset", done.clone(), validated=True, inplace=False) any_done = done.any() else: any_done = _terminated_or_truncated( diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index fb7e778f957..7d3a7cb0ab9 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -733,7 +733,6 @@ def input_spec(self) -> TensorSpec: def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # No need to clone here because inv does it already # tensordict = tensordict.clone(False) - next_preset = tensordict.get("next", None) tensordict_in = self.transform.inv(tensordict) next_tensordict = self.base_env._step(tensordict_in) From 28795c5c864cd19614577a73b0acd2523887f9fe Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 8 Mar 2024 16:13:44 +0000 Subject: [PATCH 11/14] amend --- torchrl/envs/common.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index d885a3ae832..59262c98eff 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2680,8 +2680,11 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase: """ if self._simple_done: done = tensordict._get_str("done", default=None) - tensordict._set_str("_reset", done.clone(), validated=True, inplace=False) any_done = done.any() + if any_done: + tensordict._set_str( + "_reset", done.clone(), validated=True, inplace=False + ) else: any_done = _terminated_or_truncated( tensordict, From 994e38ea16a1f839f897233882071df881e84778 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 8 Mar 2024 18:06:24 +0000 Subject: [PATCH 12/14] amend --- torchrl/envs/gym_like.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 73d8cded45e..6ea36823166 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -6,7 +6,6 @@ from __future__ import annotations import abc -import itertools import warnings from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union @@ -261,18 +260,14 @@ def read_obs( # when queried with and without pixels observations["observation"] = observations.pop("state") if not isinstance(observations, Mapping): - observations_dict = None for key, spec in self.observation_spec.items(True, True): - if observations_dict is None: - observations_dict = {} - observations_dict[key] = spec.encode( - observations, ignore_device=True - ) - else: - raise RuntimeError( - "There is more than one observation key but only one observation " - "was found in the step data." - ) + observations_dict = {} + observations_dict[key] = spec.encode(observations, ignore_device=True) + # we don't check that there is only one spec because obs spec also + # contains the data spec of the info dict. + break + else: + raise RuntimeError("Could not find any element in observation_spec.") observations = observations_dict else: for key, val in observations.items(): From 3035bcfa8714a16d5e212effbbf5bb9caa824f3c Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 8 Mar 2024 18:07:07 +0000 Subject: [PATCH 13/14] amend --- torchrl/envs/gym_like.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 6ea36823166..3c52f941bdb 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -31,7 +31,8 @@ def __call__( ) -> TensorDictBase: raise NotImplementedError - @abc.abstractproperty + @property + @abc.abstractmethod def info_spec(self) -> Dict[str, TensorSpec]: raise NotImplementedError From 8c2d463a7ca7795b3d5ab5e034ba7837e530434d Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 8 Mar 2024 20:13:15 +0000 Subject: [PATCH 14/14] amend --- test/test_cost.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_cost.py b/test/test_cost.py index 6865af1238e..da4b01d621a 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -12151,7 +12151,6 @@ def test_args_kwargs_timedim(self, device): time_dim=-3, )[0] - v2 = vec_generalized_advantage_estimate( gamma=gamma, lmbda=lmbda,