diff --git a/test/test_cost.py b/test/test_cost.py index 6865af1238e..da4b01d621a 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -12151,7 +12151,6 @@ def test_args_kwargs_timedim(self, device): time_dim=-3, )[0] - v2 = vec_generalized_advantage_estimate( gamma=gamma, lmbda=lmbda, diff --git a/test/test_env.py b/test/test_env.py index d13a7ed7934..1241cde1dc9 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -79,6 +79,7 @@ from torchrl.envs.libs.gym import _has_gym, GymEnv, GymWrapper from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv from torchrl.envs.utils import ( + _StepMDP, _terminated_or_truncated, check_env_specs, check_marl_grouping, @@ -1312,6 +1313,54 @@ def test_steptensordict( if has_out: assert out is next_tensordict + @pytest.mark.parametrize("keep_other", [True, False]) + @pytest.mark.parametrize("exclude_reward", [True, False]) + @pytest.mark.parametrize("exclude_done", [False, True]) + @pytest.mark.parametrize("exclude_action", [False, True]) + @pytest.mark.parametrize( + "envcls", + [ + ContinuousActionVecMockEnv, + CountingBatchedEnv, + CountingEnv, + NestedCountingEnv, + CountingBatchedEnv, + HeterogeneousCountingEnv, + DiscreteActionConvMockEnv, + ], + ) + def test_step_class( + self, + envcls, + keep_other, + exclude_reward, + exclude_done, + exclude_action, + ): + torch.manual_seed(0) + env = envcls() + + tensordict = env.rand_step(env.reset()) + out = step_mdp( + tensordict.lock_(), + keep_other=keep_other, + exclude_reward=exclude_reward, + exclude_done=exclude_done, + exclude_action=exclude_action, + done_keys=env.done_keys, + action_keys=env.action_keys, + reward_keys=env.reward_keys, + ) + step_func = _StepMDP( + env, + keep_other=keep_other, + exclude_reward=exclude_reward, + exclude_done=exclude_done, + exclude_action=exclude_action, + ) + out2 = step_func(tensordict) + assert (out == out2).all() + @pytest.mark.parametrize("nested_obs", [True, False]) @pytest.mark.parametrize("nested_action", [True, False]) @pytest.mark.parametrize("nested_done", [True, False]) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 69f3c7b3efd..98b365df28b 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1555,9 +1555,9 @@ def __init__( dtype = torch.get_default_dtype() if not isinstance(low, torch.Tensor): - low = torch.as_tensor(low, dtype=dtype, device=device) + low = torch.tensor(low, dtype=dtype, device=device) if not isinstance(high, torch.Tensor): - high = torch.as_tensor(high, dtype=dtype, device=device) + high = torch.tensor(high, dtype=dtype, device=device) if high.device != device: high = high.to(device) if low.device != device: @@ -3599,13 +3599,17 @@ def project(self, val: TensorDictBase) -> TensorDictBase: def rand(self, shape=None) -> TensorDictBase: if shape is None: shape = torch.Size([]) - _dict = { - key: self[key].rand(shape) for key in self.keys() if self[key] is not None - } + _dict = {} + for key, item in self.items(): + if item is not None: + _dict[key] = item.rand(shape) return TensorDict( _dict, - batch_size=[*shape, *self.shape], + batch_size=torch.Size([*shape, *self.shape]), device=self._device, + # No need to run checks since we know Composite is compliant with + # TensorDict requirements + _run_checks=False, ) def keys( diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 06a3ddfee3f..59262c98eff 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -28,10 +28,10 @@ from torchrl.envs.utils import ( _make_compatible_policy, _repr_by_depth, + _StepMDP, _terminated_or_truncated, _update_during_reset, get_available_libraries, - step_mdp, ) LIBRARIES = get_available_libraries() @@ -2513,6 +2513,14 @@ def rollout( out_td.refine_names(..., "time") return out_td + @property + def _step_mdp(self): + step_func = self.__dict__.get("_step_mdp_value", None) + if step_func is None: + step_func = _StepMDP(self, exclude_action=False) + self.__dict__["_step_mdp_value"] = step_func + return step_func + def _rollout_stop_early( self, *, @@ -2543,15 +2551,8 @@ def _rollout_stop_early( if i == max_steps - 1: # we don't truncated as one could potentially continue the run break - tensordict = step_mdp( - tensordict, - keep_other=True, - exclude_action=False, - exclude_reward=True, - reward_keys=self.reward_keys, - action_keys=self.action_keys, - done_keys=self.done_keys, - ) + tensordict = self._step_mdp(tensordict) + # done and truncated are in done_keys # We read if any key is done. any_done = _terminated_or_truncated( @@ -2649,18 +2650,23 @@ def step_and_maybe_reset( tensordict = self.step(tensordict) # done and truncated are in done_keys # We read if any key is done. - tensordict_ = step_mdp( - tensordict, - keep_other=True, - exclude_action=False, - exclude_reward=True, - reward_keys=self.reward_keys, - action_keys=self.action_keys, - done_keys=self.done_keys, - ) + tensordict_ = self._step_mdp(tensordict) tensordict_ = self.maybe_reset(tensordict_) return tensordict, tensordict_ + @property + def _simple_done(self): + _simple_done = self.__dict__.get("_simple_done_value", None) + if _simple_done is None: + key_set = set(self.full_done_spec.keys()) + _simple_done = key_set == { + "done", + "truncated", + "terminated", + } or key_set == {"done", "terminated"} + self.__dict__["_simple_done_value"] = _simple_done + return _simple_done + def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase: """Checks the done keys of the input tensordict and, if needed, resets the environment where it is done. @@ -2672,11 +2678,19 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase: not reset and contains the new reset data where the environment was reset. """ - any_done = _terminated_or_truncated( - tensordict, - full_done_spec=self.output_spec["full_done_spec"], - key="_reset", - ) + if self._simple_done: + done = tensordict._get_str("done", default=None) + any_done = done.any() + if any_done: + tensordict._set_str( + "_reset", done.clone(), validated=True, inplace=False + ) + else: + any_done = _terminated_or_truncated( + tensordict, + full_done_spec=self.output_spec["full_done_spec"], + key="_reset", + ) if any_done: tensordict = self.reset(tensordict) return tensordict diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index d3b3dfd659c..3c52f941bdb 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -6,9 +6,8 @@ from __future__ import annotations import abc -import itertools import warnings -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -32,7 +31,8 @@ def __call__( ) -> TensorDictBase: raise NotImplementedError - @abc.abstractproperty + @property + @abc.abstractmethod def info_spec(self) -> Dict[str, TensorSpec]: raise NotImplementedError @@ -235,7 +235,14 @@ def read_reward(self, reward): reward (torch.Tensor or TensorDict): reward to be mapped. """ - return self.reward_spec.encode(reward, ignore_device=True) + if isinstance(reward, int) and reward == 0: + return self.reward_spec.zero() + reward = self.reward_spec.encode(reward, ignore_device=True) + + if reward is None: + reward = torch.tensor(np.nan).expand(self.reward_spec.shape) + + return reward def read_obs( self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray] @@ -253,14 +260,21 @@ def read_obs( # naming it 'state' will result in envs that have a different name for the state vector # when queried with and without pixels observations["observation"] = observations.pop("state") - if not isinstance(observations, (TensorDict, dict)): - (key,) = itertools.islice(self.observation_spec.keys(True, True), 1) - observations = {key: observations} - for key, val in observations.items(): - observations[key] = self.observation_spec[key].encode( - val, ignore_device=True - ) - # observations = self.observation_spec.encode(observations, ignore_device=True) + if not isinstance(observations, Mapping): + for key, spec in self.observation_spec.items(True, True): + observations_dict = {} + observations_dict[key] = spec.encode(observations, ignore_device=True) + # we don't check that there is only one spec because obs spec also + # contains the data spec of the info dict. + break + else: + raise RuntimeError("Could not find any element in observation_spec.") + observations = observations_dict + else: + for key, val in observations.items(): + observations[key] = self.observation_spec[key].encode( + val, ignore_device=True + ) return observations def _step(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -277,14 +291,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: done, info_dict, ) = self._output_transform(self._env.step(action_np)) - if isinstance(obs, list) and len(obs) == 1: - # Until gym 0.25.2 we had rendered frames returned in lists of length 1 - obs = obs[0] - if _reward is None: - _reward = self.reward_spec.zero() - - reward = reward + _reward + if _reward is not None: + reward = reward + _reward terminated, truncated, done, do_break = self.read_done( terminated=terminated, truncated=truncated, done=done @@ -294,17 +303,13 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: reward = self.read_reward(reward) obs_dict = self.read_obs(obs) - - if reward is None: - reward = torch.tensor(np.nan).expand(self.reward_spec.shape) - obs_dict[self.reward_key] = reward # if truncated/terminated is not in the keys, we just don't pass it even if it # is defined. if terminated is None: terminated = done - if truncated is not None and "truncated" in self.done_keys: + if truncated is not None: obs_dict["truncated"] = truncated obs_dict["done"] = done obs_dict["terminated"] = terminated @@ -322,7 +327,8 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_out = TensorDict( obs_dict, batch_size=tensordict.batch_size, _run_checks=False ) - tensordict_out = tensordict_out.to(self.device, non_blocking=True) + if self.device is not None: + tensordict_out = tensordict_out.to(self.device, non_blocking=True) if self.info_dict_reader and (info_dict is not None): if not isinstance(info_dict, dict): diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index a419b013722..5c5a387f762 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -1062,6 +1062,11 @@ def _output_transform(self, step_outputs_tuple): # noqa: F811 # if it's not a ndarray, we must return bool # since it's not a bool, we make it so terminated = bool(terminated) + + if isinstance(observations, list) and len(observations) == 1: + # Until gym 0.25.2 we had rendered frames returned in lists of length 1 + observations = observations[0] + return (observations, reward, terminated, truncated, done, info) @implement_for("gym", "0.24", "0.26") @@ -1083,6 +1088,11 @@ def _output_transform(self, step_outputs_tuple): # noqa: F811 # if it's not a ndarray, we must return bool # since it's not a bool, we make it so terminated = bool(terminated) + + if isinstance(observations, list) and len(observations) == 1: + # Until gym 0.25.2 we had rendered frames returned in lists of length 1 + observations = observations[0] + return (observations, reward, terminated, truncated, done, info) @implement_for("gym", "0.26", None) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 6448bd9bd0f..7d3a7cb0ab9 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -317,9 +317,8 @@ def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor: return state def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: - # # We create a shallow copy of the tensordict to avoid that changes are - # # exposed to the user: we'd like that the input keys remain unchanged - # # in the originating script if they're being transformed. + if not self.in_keys_inv: + return tensordict for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): data = tensordict.get(in_key, None) if data is not None: @@ -732,7 +731,8 @@ def input_spec(self) -> TensorSpec: return input_spec def _step(self, tensordict: TensorDictBase) -> TensorDictBase: - tensordict = tensordict.clone(False) + # No need to clone here because inv does it already + # tensordict = tensordict.clone(False) next_preset = tensordict.get("next", None) tensordict_in = self.transform.inv(tensordict) next_tensordict = self.base_env._step(tensordict_in) @@ -3279,6 +3279,16 @@ def __init__( in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, ): + if in_keys is not None and in_keys_inv is None: + warnings.warn( + "in_keys have been provided but not in_keys_inv. From v0.5, " + "this will result in in_keys_inv being an empty list whereas " + "now the input keys are retrieved automatically. " + "To silence this warning, pass the (possibly empty) " + "list of in_keys_inv.", + category=DeprecationWarning, + ) + self.dtype_in = dtype_in self.dtype_out = dtype_out super().__init__( @@ -3831,30 +3841,29 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: self.in_keys = self._find_in_keys() self._initialized = True - if all(key in tensordict.keys(include_nested=True) for key in self.in_keys): - values = [tensordict.get(key) for key in self.in_keys] - if self.unsqueeze_if_oor: - pos_idx = self.dim > 0 - abs_idx = self.dim if pos_idx else -self.dim - 1 - values = [ - v - if abs_idx < v.ndimension() - else v.unsqueeze(0) - if not pos_idx - else v.unsqueeze(-1) - for v in values - ] - - out_tensor = torch.cat(values, dim=self.dim) - tensordict.set(self.out_keys[0], out_tensor) - if self._del_keys: - tensordict.exclude(*self.keys_to_exclude, inplace=True) - else: + values = [tensordict.get(key, None) for key in self.in_keys] + if any(value is None for value in values): raise Exception( f"CatTensor failed, as it expected input keys =" f" {sorted(self.in_keys, key=_sort_keys)} but got a TensorDict with keys" f" {sorted(tensordict.keys(include_nested=True), key=_sort_keys)}" ) + if self.unsqueeze_if_oor: + pos_idx = self.dim > 0 + abs_idx = self.dim if pos_idx else -self.dim - 1 + values = [ + v + if abs_idx < v.ndimension() + else v.unsqueeze(0) + if not pos_idx + else v.unsqueeze(-1) + for v in values + ] + + out_tensor = torch.cat(values, dim=self.dim) + tensordict.set(self.out_keys[0], out_tensor) + if self._del_keys: + tensordict.exclude(*self.keys_to_exclude, inplace=True) return tensordict forward = _call diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index e779bfc165d..d50473995c4 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -12,6 +12,7 @@ import inspect import os import re +import warnings from enum import Enum from typing import Any, Dict, List, Union @@ -20,6 +21,7 @@ from tensordict import ( is_tensor_collection, LazyStackedTensorDict, + TensorDict, TensorDictBase, unravel_key, ) @@ -41,6 +43,7 @@ from torchrl.data.tensor_specs import ( CompositeSpec, + NO_DEFAULT, TensorSpec, UnboundedContinuousTensorSpec, ) @@ -82,6 +85,150 @@ def __get__(self, cls, owner): return classmethod(self.fget).__get__(None, owner)() +class _StepMDP: + """Stateful version of step_mdp. + + Precomputes the list of keys to include and exclude during a call to step_mdp + to reduce runtime. + + """ + + def __init__( + self, + env, + *, + keep_other: bool = True, + exclude_reward: bool = True, + exclude_done: bool = False, + exclude_action: bool = True, + ): + action_keys = env.action_keys + done_keys = env.done_keys + reward_keys = env.reward_keys + observation_keys = env.full_observation_spec.keys(True, True) + state_keys = env.full_state_spec.keys(True, True) + self.action_keys = [unravel_key(key) for key in action_keys] + self.done_keys = [unravel_key(key) for key in done_keys] + self.reward_keys = [unravel_key(key) for key in reward_keys] + self.observation_keys = [unravel_key(key) for key in observation_keys] + self.state_keys = [unravel_key(key) for key in state_keys] + + excluded = set() + if exclude_reward: + excluded = excluded.union(self.reward_keys) + if exclude_done: + excluded = excluded.union(self.done_keys) + if exclude_action: + excluded = excluded.union(self.action_keys) + + self.excluded = [unravel_key(key) for key in excluded] + + self.keep_other = keep_other + self.exclude_action = exclude_action + + self.keys_from_next = list(self.observation_keys) + if not exclude_reward: + self.keys_from_next += self.reward_keys + if not exclude_done: + self.keys_from_next += self.done_keys + self.keys_from_root = [] + if not exclude_action: + self.keys_from_root += self.action_keys + if keep_other: + self.keys_from_root += self.state_keys + self.keys_from_root = self._repr_key_list_as_tree(self.keys_from_root) + self.keys_from_next = self._repr_key_list_as_tree(self.keys_from_next) + self.validated = None + + def validate(self, tensordict): + if self.validated: + return True + if self.validated is None: + # check that the key set of the tensordict matches what is expected + expected = ( + self.state_keys + + self.action_keys + + self.done_keys + + self.observation_keys + + [unravel_key(("next", key)) for key in self.observation_keys] + + [unravel_key(("next", key)) for key in self.done_keys] + + [unravel_key(("next", key)) for key in self.reward_keys] + ) + actual = set(tensordict.keys(True, True)) + self.validated = set(expected) == actual + if not self.validated: + warnings.warn( + "The expected key set and actual key set differ. " + "This will work but with a slower throughput than " + "when the specs match exactly the actual key set " + "in the data. " + f"Expected - Actual keys={set(expected) - actual}, \n" + f"Actual - Expected keys={actual- set(expected)}." + ) + return self.validated + + @staticmethod + def _repr_key_list_as_tree(key_list): + """Represents the keys as a tree to facilitate iteration.""" + key_dict = {key: torch.zeros(()) for key in key_list} + td = TensorDict(key_dict) + return tree_map(lambda x: None, td.to_dict()) + + @classmethod + def _grab_and_place( + cls, nested_key_dict: dict, data_in: TensorDictBase, data_out: TensorDictBase + ): + for key, subdict in nested_key_dict.items(): + val = data_in._get_str(key, NO_DEFAULT) + if subdict is not None: + val_out = data_out._get_str(key, None) + if val_out is None: + val_out = val.empty() + if isinstance(val, LazyStackedTensorDict): + + val = LazyStackedTensorDict( + *( + cls._grab_and_place(subdict, _val, _val_out) + for (_val, _val_out) in zip( + val.unbind(val.stack_dim), + val_out.unbind(val_out.stack_dim), + ) + ), + stack_dim=val.stack_dim, + ) + else: + val = cls._grab_and_place(subdict, val, val_out) + data_out._set_str(key, val, validated=True, inplace=False) + return data_out + + def __call__(self, tensordict): + if isinstance(tensordict, LazyStackedTensorDict): + out = LazyStackedTensorDict.lazy_stack( + [self.__call__(td) for td in tensordict.tensordicts], + tensordict.stack_dim, + ) + return out + + next_td = tensordict._get_str("next", None) + out = next_td.empty() + if self.validate(tensordict): + self._grab_and_place(self.keys_from_root, tensordict, out) + self._grab_and_place(self.keys_from_next, next_td, out) + return out + else: + total_key = () + if self.keep_other: + for key in tensordict.keys(): + if key != "next": + _set(tensordict, out, key, total_key, self.excluded) + elif not self.exclude_action: + for action_key in self.action_keys: + _set_single_key(tensordict, out, action_key) + for key in next_td.keys(): + _set(next_td, out, key, total_key, self.excluded) + return out + + def step_mdp( tensordict: TensorDictBase, next_tensordict: TensorDictBase = None, @@ -297,6 +444,7 @@ def _set(source, dest, key, total_key, excluded): try: val = source.get(key) if is_tensor_collection(val): + # if val is a tensordict we need to copy the structure new_val = dest.get(key, None) if new_val is None: new_val = val.empty()