diff --git a/test/test_transforms.py b/test/test_transforms.py
index 4f84001480f..124e63aaad3 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -7711,10 +7711,10 @@ def _test_vecnorm_subproc_auto(
         queue_out.put(True)
         msg = queue_in.get(timeout=TIMEOUT)
         assert msg == "all_done"
-        t = env.transform
-        obs_sum = t._td.get("observation_sum").clone()
-        obs_ssq = t._td.get("observation_ssq").clone()
-        obs_count = t._td.get("observation_count").clone()
+        t = env.transform[1]
+        obs_sum = t._td.get(("some", "obs_sum")).clone()
+        obs_ssq = t._td.get(("some", "obs_ssq")).clone()
+        obs_count = t._td.get(("some", "obs_count")).clone()
         reward_sum = t._td.get("reward_sum").clone()
         reward_ssq = t._td.get("reward_ssq").clone()
         reward_count = t._td.get("reward_count").clone()
@@ -7729,18 +7729,34 @@ def _test_vecnorm_subproc_auto(
         queue_in.close()
         del queue_in, queue_out

+    @property
+    def rename_t(self):
+        return RenameTransform(in_keys=["observation"], out_keys=[("some", "obs")])
+
     @pytest.mark.parametrize("nprc", [2, 5])
     def test_vecnorm_parallel_auto(self, nprc):
         queues = []
         prcs = []
         if _has_gym:
-            make_env = EnvCreator(
-                lambda: TransformedEnv(GymEnv(PENDULUM_VERSIONED()), VecNorm(decay=1.0))
+            maker = lambda: TransformedEnv(
+                GymEnv(PENDULUM_VERSIONED()),
+                Compose(
+                    self.rename_t,
+                    VecNorm(decay=1.0, in_keys=[("some", "obs"), "reward"]),
+                ),
             )
+            check_env_specs(maker())
+            make_env = EnvCreator(maker)
         else:
-            make_env = EnvCreator(
-                lambda: TransformedEnv(ContinuousActionVecMockEnv(), VecNorm(decay=1.0))
+            maker = lambda: TransformedEnv(
+                ContinuousActionVecMockEnv(),
+                Compose(
+                    self.rename_t,
+                    VecNorm(decay=1.0, in_keys=[("some", "obs"), "reward"]),
+                ),
             )
+            check_env_specs(maker())
+            make_env = EnvCreator(maker)

         for idx in range(nprc):
             prc_queue_in = mp.Queue(1)
@@ -7764,11 +7780,11 @@ def test_vecnorm_parallel_auto(self, nprc):
         for idx in range(nprc):
             queues[idx][1].put(msg)

-        td = make_env.state_dict()["_extra_state"]["td"]
+        td = make_env.state_dict()["transforms.1._extra_state"]["td"]

-        obs_sum = td.get("observation_sum").clone()
-        obs_ssq = td.get("observation_ssq").clone()
-        obs_count = td.get("observation_count").clone()
+        obs_sum = td.get(("some", "obs_sum")).clone()
+        obs_ssq = td.get(("some", "obs_ssq")).clone()
+        obs_count = td.get(("some", "obs_count")).clone()
         reward_sum = td.get("reward_sum").clone()
         reward_ssq = td.get("reward_ssq").clone()
         reward_count = td.get("reward_count").clone()
@@ -7836,11 +7852,21 @@ def _run_parallelenv(parallel_env, queue_in, queue_out):
     def test_parallelenv_vecnorm(self):
         if _has_gym:
             make_env = EnvCreator(
-                lambda: TransformedEnv(GymEnv(PENDULUM_VERSIONED()), VecNorm())
+                lambda: TransformedEnv(
+                    GymEnv(PENDULUM_VERSIONED()),
+                    Compose(
+                        self.rename_t, VecNorm(in_keys=[("some", "obs"), "reward"])
+                    ),
+                )
             )
         else:
             make_env = EnvCreator(
-                lambda: TransformedEnv(ContinuousActionVecMockEnv(), VecNorm())
+                lambda: TransformedEnv(
+                    ContinuousActionVecMockEnv(),
+                    Compose(
+                        self.rename_t, VecNorm(in_keys=[("some", "obs"), "reward"])
+                    ),
+                )
             )
         parallel_env = ParallelEnv(
             2,
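The tests above exercise VecNorm over a nested key by first renaming the flat
"observation" entry. A minimal standalone sketch of that pattern outside the
test harness (assuming a Gym backend; "Pendulum-v1" and the ("some", "obs")
key are illustrative, mirroring the test code):

    from torchrl.envs import Compose, TransformedEnv
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs.transforms import RenameTransform, VecNorm

    env = TransformedEnv(
        GymEnv("Pendulum-v1"),
        Compose(
            # move the flat "observation" entry under a nested key
            RenameTransform(in_keys=["observation"], out_keys=[("some", "obs")]),
            # normalize the nested observation and the flat reward
            VecNorm(decay=1.0, in_keys=[("some", "obs"), "reward"]),
        ),
    )
    td = env.reset()
    assert ("some", "obs") in td.keys(include_nested=True)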
diff --git a/torchrl/_utils.py b/torchrl/_utils.py
index 1dff10eb122..161ae04573a 100644
--- a/torchrl/_utils.py
+++ b/torchrl/_utils.py
@@ -27,6 +27,7 @@
 import numpy as np
 import torch
 from packaging.version import parse
+from tensordict import unravel_key
 from tensordict.utils import NestedKey
 from torch import multiprocessing as mp

@@ -719,6 +720,14 @@ def _replace_last(key: NestedKey, new_ending: str) -> NestedKey:
         return key[:-1] + (new_ending,)


+def _append_last(key: NestedKey, new_suffix: str) -> NestedKey:
+    key = unravel_key(key)
+    if isinstance(key, str):
+        return key + new_suffix
+    else:
+        return key[:-1] + (key[-1] + new_suffix,)
+
+
 class _rng_decorator(_DecoratorContextManager):
     """Temporarily sets the seed and sets back the rng state when exiting."""

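The new helper mirrors _replace_last just above it, appending a suffix to the
last element of a (possibly nested) key instead of replacing it. A quick
illustration (it is a private helper, so the import is for demonstration
only):

    from torchrl._utils import _append_last

    assert _append_last("observation", "_sum") == "observation_sum"
    assert _append_last(("some", "obs"), "_sum") == ("some", "obs_sum")
    # unravel_key normalizes nested tuples first, so this is equivalent:
    assert _append_last(("some", ("obs",)), "_ssq") == ("some", "obs_ssq")

This is what lets VecNorm keep its "_sum"/"_ssq"/"_count" trackers next to a
nested entry rather than flattening the key to an underscore-joined string.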
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index c646ddc3292..4c61dd82f88 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -45,7 +45,7 @@
 from torch import nn, Tensor
 from torch.utils._pytree import tree_map

-from torchrl._utils import _ends_with, _replace_last
+from torchrl._utils import _append_last, _ends_with, _replace_last

 from torchrl.data.tensor_specs import (
     BinaryDiscreteTensorSpec,
@@ -4854,9 +4854,9 @@ def __init__(
         if shared_td is not None:
             for key in in_keys:
                 if (
-                    (key + "_sum" not in shared_td.keys())
-                    or (key + "_ssq" not in shared_td.keys())
-                    or (key + "_count" not in shared_td.keys())
+                    (_append_last(key, "_sum") not in shared_td.keys(True))
+                    or (_append_last(key, "_ssq") not in shared_td.keys(True))
+                    or (_append_last(key, "_count") not in shared_td.keys(True))
                 ):
                     raise KeyError(
                         f"key {key} not present in the shared tensordict "
@@ -4868,16 +4868,12 @@ def __init__(
         self.shapes = shapes
         self.eps = eps

-    def _key_str(self, key):
-        if not isinstance(key, str):
-            key = "_".join(key)
-        return key
-
     def _reset(
         self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
     ) -> TensorDictBase:
+        # TODO: remove this context manager when trackers are in data
         with _set_missing_tolerance(self, True):
             tensordict_reset = self._call(tensordict_reset)
         return tensordict_reset

     def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
@@ -4886,6 +4882,9 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:

         for key in self.in_keys:
             if key not in tensordict.keys(include_nested=True):
+                # TODO: init missing rewards with this
+                # for key_suffix in [_append_last(key, suffix) for suffix in ("_sum", "_ssq", "_count")]:
+                #     tensordict.set(key_suffix, self.container.observation_spec[key_suffix].zero())
                 continue
             self._init(tensordict, key)
             # update and standardize
@@ -4903,18 +4902,17 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
     forward = _call

     def _init(self, tensordict: TensorDictBase, key: str) -> None:
-        key_str = self._key_str(key)
-        if self._td is None or key_str + "_sum" not in self._td.keys():
-            if key is not key_str and key_str in tensordict.keys():
+        if self._td is None or _append_last(key, "_sum") not in self._td.keys(True):
+            if _append_last(key, "_sum") in tensordict.keys(True):
                 raise RuntimeError(
-                    f"Conflicting key names: {key_str} from VecNorm and input tensordict keys."
+                    f"Conflicting key names: {_append_last(key, '_sum')} from VecNorm and input tensordict keys."
                 )
             if self.shapes is None:
                 td_view = tensordict.view(-1)
                 td_select = td_view[0]
                 item = td_select.get(key)
-                d = {key_str + "_sum": torch.zeros_like(item)}
-                d.update({key_str + "_ssq": torch.zeros_like(item)})
+                d = {_append_last(key, "_sum"): torch.zeros_like(item)}
+                d.update({_append_last(key, "_ssq"): torch.zeros_like(item)})
             else:
                 idx = 0
                 for in_key in self.in_keys:
@@ -4925,13 +4923,13 @@ def _init(self, tensordict: TensorDictBase, key: str) -> None:
                     shape = self.shapes[idx]
                     item = tensordict.get(key)
                 d = {
-                    key_str
-                    + "_sum": torch.zeros(shape, device=item.device, dtype=item.dtype)
+                    _append_last(key, "_sum"): torch.zeros(
+                        shape, device=item.device, dtype=item.dtype
+                    )
                 }
                 d.update(
                     {
-                        key_str
-                        + "_ssq": torch.zeros(
+                        _append_last(key, "_ssq"): torch.zeros(
                             shape, device=item.device, dtype=item.dtype
                         )
                     }
@@ -4939,8 +4937,9 @@ def _init(self, tensordict: TensorDictBase, key: str) -> None:

             d.update(
                 {
-                    key_str
-                    + "_count": torch.zeros(1, device=item.device, dtype=torch.float)
+                    _append_last(key, "_count"): torch.zeros(
+                        1, device=item.device, dtype=torch.float
+                    )
                 }
             )
             if self._td is None:
@@ -4951,34 +4950,32 @@ def _init(self, tensordict: TensorDictBase, key: str) -> None:
             pass

     def _update(self, key, value, N) -> torch.Tensor:
-        key = self._key_str(key)
-        _sum = self._td.get(key + "_sum")
-        _ssq = self._td.get(key + "_ssq")
-        _count = self._td.get(key + "_count")
+        _sum = self._td.get(_append_last(key, "_sum"))
+        _ssq = self._td.get(_append_last(key, "_ssq"))
+        _count = self._td.get(_append_last(key, "_count"))

-        _sum = self._td.get(key + "_sum")
         value_sum = _sum_left(value, _sum)
         _sum *= self.decay
         _sum += value_sum
         self._td.set_(
-            key + "_sum",
+            _append_last(key, "_sum"),
             _sum,
         )

-        _ssq = self._td.get(key + "_ssq")
+        _ssq = self._td.get(_append_last(key, "_ssq"))
         value_ssq = _sum_left(value.pow(2), _ssq)
         _ssq *= self.decay
         _ssq += value_ssq
         self._td.set_(
-            key + "_ssq",
+            _append_last(key, "_ssq"),
             _ssq,
         )

-        _count = self._td.get(key + "_count")
+        _count = self._td.get(_append_last(key, "_count"))
         _count *= self.decay
         _count += N
         self._td.set_(
-            key + "_count",
+            _append_last(key, "_count"),
             _count,
         )
@@ -4990,9 +4987,9 @@ def to_observation_norm(self) -> Union[Compose, ObservationNorm]:
         """Converts VecNorm into an ObservationNorm class that can be used at inference time."""
         out = []
         for key in self.in_keys:
-            _sum = self._td.get(key + "_sum")
-            _ssq = self._td.get(key + "_ssq")
-            _count = self._td.get(key + "_count")
+            _sum = self._td.get(_append_last(key, "_sum"))
+            _ssq = self._td.get(_append_last(key, "_ssq"))
+            _count = self._td.get(_append_last(key, "_count"))
             mean = _sum / _count
             std = (_ssq / _count - mean.pow(2)).clamp_min(self.eps).sqrt()
@@ -5056,9 +5053,9 @@ def build_td_for_shared_vecnorm(
         )
         keys = list(td_select.keys())
         for key in keys:
-            td_select.set(key + "_ssq", td_select.get(key).clone())
+            td_select.set(_append_last(key, "_ssq"), td_select.get(key).clone())
             td_select.set(
-                key + "_count",
+                _append_last(key, "_count"),
                 torch.zeros(
                     *td.batch_size,
                     1,
@@ -5066,7 +5063,7 @@ def build_td_for_shared_vecnorm(
                     dtype=torch.float,
                 ),
             )
-            td_select.rename_key_(key, key + "_sum")
+            td_select.rename_key_(key, _append_last(key, "_sum"))
         td_select.exclude(*keys).zero_()
         td_select = td_select.unflatten_keys(sep)
         if memmap:
@@ -5112,6 +5109,32 @@ def __setstate__(self, state: Dict[str, Any]):
             state["lock"] = _lock
         self.__dict__.update(state)

+    @_apply_to_composite
+    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
+        if isinstance(observation_spec, BoundedTensorSpec):
+            return UnboundedContinuousTensorSpec(
+                shape=observation_spec.shape,
+                dtype=observation_spec.dtype,
+                device=observation_spec.device,
+            )
+        return observation_spec
+
+    # TODO: incorporate this when trackers are part of the data
+    # def transform_output_spec(self, output_spec: TensorSpec) -> TensorSpec:
+    #     observation_spec = output_spec["full_observation_spec"]
+    #     reward_spec = output_spec["full_reward_spec"]
+    #     for key in list(observation_spec.keys(True, True)):
+    #         if key in self.in_keys:
+    #             observation_spec[_append_last(key, "_sum")] = observation_spec[key].clone()
+    #             observation_spec[_append_last(key, "_ssq")] = observation_spec[key].clone()
+    #             observation_spec[_append_last(key, "_count")] = observation_spec[key].clone()
+    #     for key in list(reward_spec.keys(True, True)):
+    #         if key in self.in_keys:
+    #             observation_spec[_append_last(key, "_sum")] = reward_spec[key].clone()
+    #             observation_spec[_append_last(key, "_ssq")] = reward_spec[key].clone()
+    #             observation_spec[_append_last(key, "_count")] = reward_spec[key].clone()
+    #     return output_spec
+

 class RewardSum(Transform):
     """Tracks episode cumulative rewards.
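With this change the running statistics follow the nesting of their source
key: the suffix is appended to the last element rather than to a flattened
string. Continuing the sketch from the test section above (this peeks at the
private _td state, as the updated tests do; the arithmetic mirrors
to_observation_norm()):

    env.rollout(10)  # populate the trackers

    vecnorm = env.transform[1]  # the VecNorm inside the Compose
    obs_sum = vecnorm._td.get(("some", "obs_sum"))
    obs_ssq = vecnorm._td.get(("some", "obs_ssq"))
    obs_count = vecnorm._td.get(("some", "obs_count"))
    mean = obs_sum / obs_count
    std = (obs_ssq / obs_count - mean.pow(2)).clamp_min(vecnorm.eps).sqrt()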