From 1a480c73b66bb46d45dbf898c11aa1be716a5896 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 16 Dec 2022 09:32:09 +0000 Subject: [PATCH 01/17] init --- torchrl/envs/common.py | 2 +- torchrl/envs/transforms/transforms.py | 60 +++++++++++++++------------ 2 files changed, 35 insertions(+), 27 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index b9bf02fcdc5..c698c473200 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -399,7 +399,7 @@ def reset( f"env._reset returned an object of type {type(tensordict_reset)} but a TensorDict was expected." ) - self.is_done = tensordict_reset.get( + self.is_done = tensordict_reset.set_default( "done", torch.zeros(self.batch_size, dtype=torch.bool, device=self.device), ) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index b14ae8b0508..d9805e654ac 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1960,40 +1960,48 @@ def base_env(self): def reset(self, tensordict: TensorDictBase) -> TensorDictBase: """Do no-op action for a number of steps in [1, noop_max].""" + td_reset = tensordict.clone(False) + tensordict = tensordict.clone(False) + # check that there is a single done state -- behaviour is undefined for multiple dones parent = self.parent - # keys = tensordict.keys() + if tensordict.get("done").numel() > 1: + raise ValueError( + "there is more than one done state in the parent environment. " + "NoopResetEnv is designed to work on single env instances, as partial reset " + "is currently not supported. If you feel like this is a missing feature, submit " + "an issue on TorchRL github repo. " + "In case you are trying to use NoopResetEnv over a batch of environments, know " + "that you can have a transformed batch of transformed envs, such as: " + "`TransformedEnv(ParallelEnv(3, lambda: TransformedEnv(MyEnv(), NoopResetEnv(3))), OtherTransform())`." + ) noops = ( self.noops if not self.random else torch.randint(self.noops, (1,)).item() ) - i = 0 trial = 0 - while i < noops: - i += 1 - tensordict = parent.rand_step(tensordict) - tensordict = step_mdp(tensordict) - if parent.is_done: - parent.reset() - i = 0 - trial += 1 - if trial > _MAX_NOOPS_TRIALS: - tensordict = parent.reset(tensordict) - tensordict = parent.rand_step(tensordict) + while True: + i = 0 + while i < noops: + i += 1 + tensordict = parent.rand_step(tensordict) + tensordict = step_mdp(tensordict, exclude_done=False) + if tensordict.get("is_done"): + tensordict = parent.reset(td_reset.clone(False)) break - if parent.is_done: - raise RuntimeError("NoopResetEnv concluded with done environment") - # td = step_mdp( - # tensordict, exclude_done=False, exclude_reward=True, exclude_action=True - # ) - - # for k in keys: - # if k not in td.keys(): - # td.set(k, tensordict.get(k)) - - # # replace the next_ prefix - # for out_key in parent.observation_spec: - # td.rename_key(out_key[5:], out_key) + else: + break + + trial += 1 + if trial > _MAX_NOOPS_TRIALS: + tensordict = parent.rand_step(tensordict) + if tensordict.get("is_done"): + raise RuntimeError( + f"parent is still done after a single random step (i={i})." + ) + break + if tensordict.get("is_done"): + raise RuntimeError("NoopResetEnv concluded with done environment") return tensordict def __repr__(self) -> str: From 35ea59704dee3136d9212fffc4a4d80103cd16d4 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 16 Dec 2022 09:40:21 +0000 Subject: [PATCH 02/17] init --- torchrl/envs/common.py | 62 ++++++++++----------------- torchrl/envs/env_creator.py | 2 +- torchrl/envs/gym_like.py | 4 +- torchrl/envs/libs/brax.py | 1 - torchrl/envs/libs/gym.py | 2 +- torchrl/envs/libs/jumanji.py | 4 -- torchrl/envs/model_based/common.py | 1 - torchrl/envs/transforms/transforms.py | 18 ++------ torchrl/envs/vec_env.py | 14 +----- 9 files changed, 30 insertions(+), 78 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index c698c473200..5bf65d16c6b 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -186,8 +186,6 @@ class EnvBase(nn.Module, metaclass=abc.ABCMeta): - reward_spec (TensorSpec): sampling spec of the rewards; - batch_size (torch.Size): number of environments contained in the instance; - device (torch.device): device where the env input and output are expected to live - - is_done (torch.Tensor): boolean value(s) indicating if the environment has reached a done state since the - last reset - run_type_checks (bool): if True, the observation and reward dtypes will be compared against their respective spec and an exception will be raised if they don't match. @@ -212,7 +210,6 @@ def __init__( super().__init__() if device is not None: self.device = torch.device(device) - self._is_done = None self.dtype = dtype_map.get(dtype, dtype) if "is_closed" not in self.__dir__(): self.is_closed = True @@ -331,7 +328,6 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: "tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty() or " "tensordict.select()) inside _step before writing new tensors onto this new instance." ) - self.is_done = tensordict_out.get("done") if self.run_type_checks: for key in self._select_observation_keys(tensordict_out): obs = tensordict_out.get(key) @@ -399,11 +395,11 @@ def reset( f"env._reset returned an object of type {type(tensordict_reset)} but a TensorDict was expected." ) - self.is_done = tensordict_reset.set_default( + tensordict_reset.set_default( "done", torch.zeros(self.batch_size, dtype=torch.bool, device=self.device), ) - if self.is_done: + if tensordict.get("done").any(): raise RuntimeError( f"Env {self} was done after reset. This is (currently) not allowed." ) @@ -452,16 +448,6 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None: f"got {tensordict.batch_size} and {self.batch_size}" ) - def is_done_get_fn(self) -> bool: - if self._is_done is None: - self._is_done = torch.zeros(self.batch_size, device=self.device) - return self._is_done.all() - - def is_done_set_fn(self, val: torch.Tensor) -> None: - self._is_done = val - - is_done = property(is_done_get_fn, is_done_set_fn) - def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase: """Performs a random step in the environment given the action_spec attribute. @@ -548,30 +534,27 @@ def policy(td): return td.set("action", self.action_spec.rand(self.batch_size)) tensordicts = [] - if not self.is_done: - for i in range(max_steps): - if auto_cast_to_device: - tensordict = tensordict.to(policy_device) - tensordict = policy(tensordict) - if auto_cast_to_device: - tensordict = tensordict.to(env_device) - tensordict = self.step(tensordict) - tensordicts.append(tensordict.clone()) - if ( - break_when_any_done and tensordict.get("done").any() - ) or i == max_steps - 1: - break - tensordict = step_mdp( - tensordict, - keep_other=True, - exclude_reward=False, - exclude_action=False, - ) + for i in range(max_steps): + if auto_cast_to_device: + tensordict = tensordict.to(policy_device) + tensordict = policy(tensordict) + if auto_cast_to_device: + tensordict = tensordict.to(env_device) + tensordict = self.step(tensordict) + tensordicts.append(tensordict.clone()) + if ( + break_when_any_done and tensordict.get("done").any() + ) or i == max_steps - 1: + break + tensordict = step_mdp( + tensordict, + keep_other=True, + exclude_reward=False, + exclude_action=False, + ) - if callback is not None: - callback(self, tensordict) - else: - raise Exception("reset env before calling rollout!") + if callback is not None: + callback(self, tensordict) batch_size = self.batch_size if tensordict is None else tensordict.batch_size @@ -640,7 +623,6 @@ def to(self, device: DEVICE_TYPING) -> EnvBase: self.reward_spec = self.reward_spec.to(device) self.observation_spec = self.observation_spec.to(device) self.input_spec = self.input_spec.to(device) - self.is_done = self.is_done.to(device) self.device = device return super().to(device) diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py index 0fbb2b15943..c9121c76dd1 100644 --- a/torchrl/envs/env_creator.py +++ b/torchrl/envs/env_creator.py @@ -49,7 +49,7 @@ class EnvCreator: ... tensordict = env.reset() ... for _ in range(10): ... env.rand_step(tensordict) - ... if env.is_done: + ... if tensordict.get("done"): ... tensordict = env.reset(tensordict) ... print("env 1: ", env.transform._td.get(("next", "observation_count"))) >>> diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 7f832a10b74..9021a4f590b 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -215,7 +215,6 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: reward = np.nan reward = self._to_tensor(reward, dtype=self.reward_spec.dtype) done = self._to_tensor(done, dtype=torch.bool) - self.is_done = done tensordict_out = TensorDict( obs_dict, batch_size=tensordict.batch_size, device=self.device @@ -240,8 +239,7 @@ def _reset( batch_size=self.batch_size, device=self.device, ) - self._is_done = torch.zeros(self.batch_size, dtype=torch.bool) - tensordict_out.set("done", self._is_done) + tensordict_out.set("done", torch.zeros(*self.batch_size, 1, dtype=torch.bool)) return tensordict_out def _output_transform(self, step_outputs_tuple: Tuple) -> Tuple: diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index e0e271b5ede..47dc69ba2a5 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -224,7 +224,6 @@ def _step_with_grad(self, tensordict: TensorDictBase): # extract done values next_done = next_state_nograd["done"].bool() - self._is_done = next_done # merge with tensors with grad function next_state = next_state_nograd diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 1c4d0a2680c..309ed3b066a 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -272,7 +272,7 @@ def _make_specs(self, env: "gym.Env") -> None: ) def _init_env(self): - self.reset() # make sure that _is_done is populated + self.reset() def __repr__(self) -> str: return ( diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index e430acba5bc..2e59e452b19 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -255,8 +255,6 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: done = timestep.step_type == self.lib.types.StepType.LAST done = _ndarray_to_tensor(done).view(torch.bool).to(self.device) - self._is_done = done - # build results tensordict_out = TensorDict( source=obs_dict, @@ -288,8 +286,6 @@ def _reset( obs_dict = self.read_obs(timestep.observation) done = torch.zeros(self.batch_size, dtype=torch.bool) - self._is_done = done - # build results tensordict_out = TensorDict( source=obs_dict, diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index 1ff0cd03712..e63825f13f7 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -91,7 +91,6 @@ class ModelBasedEnvBase(EnvBase, metaclass=abc.ABCMeta): - input_spec (CompositeSpec): sampling spec of the inputs; - batch_size (torch.Size): batch_size to be used by the env. If not set, the env accept tensordicts of all batch sizes. - device (torch.device): device where the env input and output are expected to live - - is_done (torch.Tensor): boolean value(s) indicating if the environment has reached a done state since the last reset Args: world_model (nn.Module): model that generates world states and its corresponding rewards; diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index d9805e654ac..4164cc44cfd 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -485,16 +485,6 @@ def is_closed(self) -> bool: def is_closed(self, value: bool): self.base_env.is_closed = value - @property - def is_done(self) -> bool: - if self._is_done is None: - return self.base_env.is_done - return self._is_done.all() - - @is_done.setter - def is_done(self, val: torch.Tensor) -> None: - self._is_done = val - def close(self): self.base_env.close() self.is_closed = True @@ -568,8 +558,6 @@ def to(self, device: DEVICE_TYPING) -> TransformedEnv: self.base_env.to(device) self.transform.to(device) - self.is_done = self.is_done.to(device) - if self.cache_specs: self._input_spec = None self._observation_spec = None @@ -1985,7 +1973,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: i += 1 tensordict = parent.rand_step(tensordict) tensordict = step_mdp(tensordict, exclude_done=False) - if tensordict.get("is_done"): + if tensordict.get("done"): tensordict = parent.reset(td_reset.clone(False)) break else: @@ -1994,13 +1982,13 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: trial += 1 if trial > _MAX_NOOPS_TRIALS: tensordict = parent.rand_step(tensordict) - if tensordict.get("is_done"): + if tensordict.get("done"): raise RuntimeError( f"parent is still done after a single random step (i={i})." ) break - if tensordict.get("is_done"): + if tensordict.get("done"): raise RuntimeError("NoopResetEnv concluded with done environment") return tensordict diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 878d851f946..5e0b2beb6ee 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -356,9 +356,6 @@ def reward_spec(self) -> TensorSpec: def reward_spec(self, value: TensorSpec) -> None: self._reward_spec = value - def is_done_set_fn(self, value: bool) -> None: - self._is_done = value.all() - def _create_td(self) -> None: """Creates self.shared_tensordict_parent, a TensorDict used to store the most recent observations.""" if self._single_task: @@ -989,10 +986,8 @@ def _run_worker_pipe_shared_mem( tensordict.update_(_td) child_pipe.send(("reset_obs", reset_keys)) just_reset = True - if env.is_done: - raise RuntimeError( - f"{env.__class__.__name__}.is_done is {env.is_done} after reset" - ) + if _td.get("done").any(): + raise RuntimeError(f"{env.__class__.__name__} is done after reset") elif cmd == "step": if not initialized: @@ -1002,10 +997,6 @@ def _run_worker_pipe_shared_mem( *env_input_keys, strict=False, ) - if env.is_done and not allow_step_when_done: - raise RuntimeError( - f"calling step when env is done, just reset = {just_reset}" - ) _td = env._step(_td) if step_keys is None: step_keys = set(env.observation_spec.keys()).union( @@ -1020,7 +1011,6 @@ def _run_worker_pipe_shared_mem( msg = "step_result" data = (msg, step_keys) child_pipe.send(data) - just_reset = False elif cmd == "close": del tensordict, _td, data From 75b78bd8096725671513fb7248a7a319094f2450 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 16 Dec 2022 09:41:02 +0000 Subject: [PATCH 03/17] bf --- torchrl/envs/transforms/transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index d9805e654ac..62c0d6b22f4 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1985,7 +1985,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: i += 1 tensordict = parent.rand_step(tensordict) tensordict = step_mdp(tensordict, exclude_done=False) - if tensordict.get("is_done"): + if tensordict.get("done"): tensordict = parent.reset(td_reset.clone(False)) break else: @@ -1994,13 +1994,13 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: trial += 1 if trial > _MAX_NOOPS_TRIALS: tensordict = parent.rand_step(tensordict) - if tensordict.get("is_done"): + if tensordict.get("done"): raise RuntimeError( f"parent is still done after a single random step (i={i})." ) break - if tensordict.get("is_done"): + if tensordict.get("done"): raise RuntimeError("NoopResetEnv concluded with done environment") return tensordict From 58aeb772b23a9b105745ea5df401d3b2c1ae2e15 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 16 Dec 2022 09:43:04 +0000 Subject: [PATCH 04/17] bf --- torchrl/envs/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 5bf65d16c6b..3edb5a14b63 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -399,7 +399,7 @@ def reset( "done", torch.zeros(self.batch_size, dtype=torch.bool, device=self.device), ) - if tensordict.get("done").any(): + if tensordict_reset.get("done").any(): raise RuntimeError( f"Env {self} was done after reset. This is (currently) not allowed." ) From d421582a64dc539ca30d663fa71bff6c29a0b2c1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 16 Dec 2022 09:52:13 +0000 Subject: [PATCH 05/17] bf --- test/mocking_classes.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index db26b1c687b..7629d2874e7 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -340,7 +340,6 @@ def _step( if not self.categorical_action_encoding: assert (a.sum(-1) == 1).all() - assert not self.is_done, "trying to execute step in done env" obs = self._get_in_obs(tensordict.get(self._out_key)) + a / self.maxstep tensordict = tensordict.select() # empty tensordict @@ -423,7 +422,6 @@ def _step( self.step_count += 1 tensordict = tensordict.to(self.device) a = tensordict.get("action") - assert not self.is_done, "trying to execute step in done env" obs = self._obs_step(self._get_in_obs(tensordict.get(self._out_key)), a) tensordict = tensordict.select() # empty tensordict From 40b6965ba5230746e96b98ae7325d5d368ea548f Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 16 Dec 2022 09:52:22 +0000 Subject: [PATCH 06/17] bf --- torchrl/envs/vec_env.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 5e0b2beb6ee..a601b417172 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -985,7 +985,6 @@ def _run_worker_pipe_shared_mem( _td.pin_memory() tensordict.update_(_td) child_pipe.send(("reset_obs", reset_keys)) - just_reset = True if _td.get("done").any(): raise RuntimeError(f"{env.__class__.__name__} is done after reset") From c84fc02e7db5872a0b7f89ffac67047b579bfd69 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 16 Dec 2022 09:59:17 +0000 Subject: [PATCH 07/17] test --- test/test_transforms.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index ed953e7ce45..c63f6f69828 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1275,6 +1275,22 @@ def test_noop_reset_env(self, random, device, compose): else: assert transformed_env.step_count == 30 + @pytest.mark.parametrize("random", [True, False]) + @pytest.mark.parametrize("compose", [True, False]) + @pytest.mark.parametrize("device", get_available_devices()) + def test_noop_reset_env_error(self, random, device, compose): + torch.manual_seed(0) + env = SerialEnv(3, lambda: ContinuousActionVecMockEnv()) + env.set_seed(100) + noop_reset_env = NoopResetEnv(random=random) + transformed_env = TransformedEnv(env) + transformed_env.append_transform(noop_reset_env) + with pytest.raises( + ValueError, + match="there is more than one done state in the parent environment", + ): + transformed_env.reset() + @pytest.mark.parametrize( "default_keys", [["action"], ["action", "monkeys jumping on the bed"]] ) From 8a06d43ade024e6e0680f5debf99e2f54d1352fe Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 16 Dec 2022 11:02:21 +0000 Subject: [PATCH 08/17] init --- test/test_transforms.py | 44 ++++++-- torchrl/envs/transforms/transforms.py | 144 +++++++++++++++----------- 2 files changed, 122 insertions(+), 66 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index c63f6f69828..5dc6147bca2 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -59,6 +59,7 @@ _has_tv, CenterCrop, DiscreteActionProjection, + FrameSkipTransform, gSDENoise, NoopResetEnv, PinMemoryTransform, @@ -375,17 +376,48 @@ def test_transform_parent(): t3 = RewardClipping(0.1, 0.5) env.append_transform(t3) - t1_parent_gt = t1._parent - t2_parent_gt = t2._parent - t3_parent_gt = t3._parent + t1_parent_gt = t1._container + t2_parent_gt = t2._container + t3_parent_gt = t3._container _ = t1.parent _ = t2.parent _ = t3.parent - assert t1_parent_gt == t1._parent - assert t2_parent_gt == t2._parent - assert t3_parent_gt == t3._parent + assert t1_parent_gt == t1._container + assert t2_parent_gt == t2._container + assert t3_parent_gt == t3._container + + +def test_transform_parent_cache(): + """Tests the caching and uncaching of the transformed envs.""" + env = TransformedEnv( + ContinuousActionVecMockEnv(), + FrameSkipTransform(3), + ) + + # print the parent + assert ( + type(env.transform.parent.transform) is Compose + and len(env.transform.parent.transform) == 0 + ) + parent1 = env.transform.parent + parent2 = env.transform.parent + assert parent1 is parent2 + + # change the env, re-print the parent + env.insert_transform(0, NoopResetEnv(3)) + parent3 = env.transform[-1].parent + assert parent1 is not parent3 + assert type(parent3.transform[0]) is NoopResetEnv + + # change the env, re-print the parent + env.insert_transform(0, CatTensors(["observation"])) + parent4 = env.transform[-1].parent + assert parent1 is not parent4 + assert parent3 is not parent4 + assert type(parent4.transform[0]) is CatTensors + assert type(parent4.transform[1]) is NoopResetEnv class TestTransforms: diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 4164cc44cfd..dd891ea573c 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -107,6 +107,7 @@ def __init__( if out_keys_inv is None: out_keys_inv = copy(self.in_keys_inv) self.out_keys_inv = out_keys_inv + self.__dict__["_container"] = None self.__dict__["_parent"] = None def reset(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -213,15 +214,16 @@ def dump(self, **kwargs) -> None: def __repr__(self) -> str: return f"{self.__class__.__name__}(keys={self.in_keys})" - def set_parent(self, parent: Union[Transform, EnvBase]) -> None: - if self.__dict__["_parent"] is not None: + def set_container(self, container: Union[Transform, EnvBase]) -> None: + if self.__dict__["_container"] is not None: raise AttributeError( - "parent of transform already set. " + f"parent of transform {type(self)} already set. " "Call `transform.clone()` to get a similar transform with no parent set." ) - self.__dict__["_parent"] = parent + self.__dict__["_container"] = container def reset_parent(self) -> None: + self.__dict__["_container"] = None self.__dict__["_parent"] = None def clone(self): @@ -231,45 +233,46 @@ def clone(self): @property def parent(self) -> Optional[EnvBase]: - if not hasattr(self, "_parent"): - raise AttributeError("transform parent uninitialized") - parent = self._parent - if parent is None: - return parent - out = None - if not isinstance(parent, EnvBase): - # if it's not an env, it should be a Compose transform - if not isinstance(parent, Compose): - raise ValueError( - "A transform parent must be either another Compose transform or an environment object." - ) - compose = parent - if compose.parent: - # the parent of the compose must be a TransformedEnv - compose_parent = compose.parent - if compose_parent.transform is not compose: - comp_parent_trans = compose_parent.transform.clone() - else: - comp_parent_trans = None - out = TransformedEnv( - compose_parent.base_env, - transform=comp_parent_trans, - ) - for orig_trans in compose.transforms: - if orig_trans is self: - break - transform = copy(orig_trans) - transform.reset_parent() - out.append_transform(transform) - elif isinstance(parent, TransformedEnv): - out = TransformedEnv(parent.base_env) - else: - raise ValueError(f"parent is of type {type(parent)}") - return out + if self.__dict__.get("_parent", None) is None: + if "_container" not in self.__dict__: + raise AttributeError("transform parent uninitialized") + parent = self.__dict__["_container"] + if parent is None: + return parent + out = None + if not isinstance(parent, EnvBase): + # if it's not an env, it should be a Compose transform + if not isinstance(parent, Compose): + raise ValueError( + "A transform parent must be either another Compose transform or an environment object." + ) + compose = parent + if compose.parent: + # the parent of the compose must be a TransformedEnv + compose_parent = compose.parent + if compose_parent.transform is not compose: + comp_parent_trans = compose_parent.transform.clone() + else: + comp_parent_trans = None + out = TransformedEnv( + compose_parent.base_env, + transform=comp_parent_trans, + ) + for orig_trans in compose.transforms: + if orig_trans is self: + break + transform = copy(orig_trans) + transform.reset_parent() + out.append_transform(transform) + elif isinstance(parent, TransformedEnv): + out = TransformedEnv(parent.base_env) + else: + raise ValueError(f"parent is of type {type(parent)}") + self.__dict__["_parent"] = out + return self.__dict__["_parent"] def empty_cache(self): - if self.parent is not None: - self.parent.empty_cache() + self.__dict__["_parent"] = None class TransformedEnv(EnvBase): @@ -353,7 +356,7 @@ def transform(self, transform: Transform): f"""Expected a transform of type torchrl.envs.transforms.Transform, but got an object of type {type(transform)}.""" ) - transform.set_parent(self) + transform.set_container(self) transform.eval() self._transform = transform @@ -432,13 +435,9 @@ def reward_spec(self) -> TensorSpec: return reward_spec def _step(self, tensordict: TensorDictBase) -> TensorDictBase: - # selected_keys = [key for key in tensordict.keys() if "action" in key] - # tensordict_in = tensordict.select(*selected_keys).clone() tensordict = tensordict.clone() tensordict_in = self.transform.inv(tensordict) tensordict_out = self.base_env._step(tensordict_in) - # tensordict should already have been processed by the transforms - # for logging purposes tensordict_out = tensordict_out.update( tensordict.exclude(*tensordict_out.keys()) ) @@ -518,8 +517,8 @@ def insert_transform(self, index: int, transform: Transform) -> None: ) transform = transform.to(self.device) if not isinstance(self.transform, Compose): - self.transform = Compose(self.transform) - self.transform.set_parent(self) + compose = Compose(self.transform.clone()) + self.transform = compose # parent set automatically self.transform.insert(index, transform) self._erase_metadata() @@ -620,7 +619,7 @@ def __init__(self, *transforms: Transform): super().__init__(in_keys=[]) self.transforms = nn.ModuleList(transforms) for t in self.transforms: - t.set_parent(self) + t.set_container(self) def forward(self, tensordict: TensorDictBase) -> TensorDictBase: for t in self.transforms: @@ -657,7 +656,7 @@ def __getitem__(self, item: Union[int, slice, List]) -> Union: transform = transform[item] if not isinstance(transform, Transform): out = Compose(*self.transforms[item]) - out.set_parent(self.parent) + out.set_container(self.parent) return out return transform @@ -683,7 +682,7 @@ def append(self, transform): ) transform.eval() self.transforms.append(transform) - transform.set_parent(self) + transform.set_container(self) def insert(self, index: int, transform: Transform) -> None: if not isinstance(transform, Transform): @@ -697,12 +696,13 @@ def insert(self, index: int, transform: Transform) -> None: f"Index expected to be between [-{len(self.transforms)}, {len(self.transforms)}] got index={index}" ) + # empty cache of all transforms to reset parents and specs self.empty_cache() if index < 0: index = index + len(self.transforms) transform.eval() self.transforms.insert(index, transform) - transform.set_parent(self) + transform.set_container(self) def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Compose: for t in self.transforms: @@ -721,6 +721,11 @@ def __repr__(self) -> str: ) return f"{self.__class__.__name__}(\n{indent(layers_str, 4 * ' ')})" + def empty_cache(self): + for t in self.transforms: + t.empty_cache() + super().empty_cache() + class ToTensorImage(ObservationTransform): """Transforms a numpy-like image (3 x W x H) to a pytorch image (3 x W x H). @@ -1034,8 +1039,8 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: observation = torch.flatten(observation, self.first_dim, self.last_dim) return observation - def set_parent(self, parent: Union[Transform, EnvBase]) -> None: - out = super().set_parent(parent) + def set_container(self, container: Union[Transform, EnvBase]) -> None: + out = super().set_container(container) try: observation_spec = self.parent.observation_spec for key in self.in_keys: @@ -1112,13 +1117,13 @@ def __init__( ) self._unsqueeze_dim_orig = unsqueeze_dim - def set_parent(self, parent: Union[Transform, EnvBase]) -> None: + def set_container(self, container: Union[Transform, EnvBase]) -> None: if self._unsqueeze_dim_orig < 0: self._unsqueeze_dim = self._unsqueeze_dim_orig else: - parent = self.parent + container = self.parent try: - batch_size = parent.batch_size + batch_size = container.batch_size except AttributeError: raise ValueError( f"Got the unsqueeze dimension {self._unsqueeze_dim_orig} which is greater or equal to zero. " @@ -1127,7 +1132,7 @@ def set_parent(self, parent: Union[Transform, EnvBase]) -> None: f"`TransformedEnv.append_transform()` method." ) self._unsqueeze_dim = self._unsqueeze_dim_orig + len(batch_size) - return super().set_parent(parent) + return super().set_container(container) @property def unsqueeze_dim(self): @@ -1915,6 +1920,25 @@ def __repr__(self) -> str: ) +class FrameSkipTransform(Transform): + inplace = False + + def __init__(self, frame_skip: int = 1): + super().__init__([]) + if frame_skip < 1: + raise ValueError("frame_skip should have a value greater or equal to one.") + self.frame_skip = frame_skip + + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + parent = self.parent + reward = tensordict.get("reward") + for _ in range(self.frame_skip - 1): + parent.step(tensordict) + reward = reward + tensordict.get("reward") + tensordict.set("reward", reward) + return super()._step(tensordict) + + class NoopResetEnv(Transform): """Runs a series of random actions when an environment is reset. @@ -2079,8 +2103,8 @@ def transform_observation_spec( observation_spec[key] = spec.to(self.device) return observation_spec - def set_parent(self, parent: Union[Transform, EnvBase]) -> None: - super().set_parent(parent) + def set_container(self, container: Union[Transform, EnvBase]) -> None: + super().set_container(container) @property def _batch_size(self): From dd885f244e9c3b639dd1f8a6618ac7a7b8ba07bb Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 16 Dec 2022 11:29:48 +0000 Subject: [PATCH 09/17] remove container when replacing transform --- test/test_transforms.py | 5 +++++ torchrl/envs/transforms/transforms.py | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index 5dc6147bca2..3e1673858d2 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -401,6 +401,7 @@ def test_transform_parent_cache(): type(env.transform.parent.transform) is Compose and len(env.transform.parent.transform) == 0 ) + transform = env.transform parent1 = env.transform.parent parent2 = env.transform.parent assert parent1 is parent2 @@ -419,6 +420,10 @@ def test_transform_parent_cache(): assert type(parent4.transform[0]) is CatTensors assert type(parent4.transform[1]) is NoopResetEnv + # check that we don't keep track of the wrong parent + env.transform = NoopResetEnv(3) + assert transform.parent is None + class TestTransforms: @pytest.mark.skipif(not _has_tv, reason="no torchvision") diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index dd891ea573c..d1bd1d8415e 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -356,6 +356,10 @@ def transform(self, transform: Transform): f"""Expected a transform of type torchrl.envs.transforms.Transform, but got an object of type {type(transform)}.""" ) + prev_transform = self.transform + if prev_transform is not None: + prev_transform.empty_cache() + prev_transform.__dict__["_container"] = None transform.set_container(self) transform.eval() self._transform = transform From 3726729c73935a6c7613687b25f3f19f04b7a18b Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 16 Dec 2022 11:34:05 +0000 Subject: [PATCH 10/17] bf --- torchrl/envs/common.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index c698c473200..dbd2c118893 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -401,7 +401,9 @@ def reset( self.is_done = tensordict_reset.set_default( "done", - torch.zeros(self.batch_size, dtype=torch.bool, device=self.device), + torch.zeros( + *tensordict_reset.batch_size, 1, dtype=torch.bool, device=self.device + ), ) if self.is_done: raise RuntimeError( From 0cf77b80cce42c24afbd28badd4119a0ec4b908c Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 20 Dec 2022 13:23:55 +0000 Subject: [PATCH 11/17] amend --- torchrl/envs/common.py | 5 +++-- torchrl/envs/transforms/__init__.py | 1 + torchrl/envs/transforms/transforms.py | 6 +++--- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index fef7249fa38..37dcaab1537 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -143,7 +143,7 @@ def build_tensordict( """ # build a tensordict from specs - td = TensorDict({}, batch_size=torch.Size([])) + td = TensorDict({}, batch_size=torch.Size([]), _run_checks=False) action_placeholder = torch.zeros( self["action_spec"].shape, dtype=self["action_spec"].dtype ) @@ -462,7 +462,7 @@ def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBa """ if tensordict is None: - tensordict = TensorDict({}, device=self.device, batch_size=self.batch_size) + tensordict = TensorDict({}, device=self.device, batch_size=self.batch_size, _run_checks=False) action = self.action_spec.rand(self.batch_size) tensordict.set("action", action) return self.step(tensordict) @@ -646,6 +646,7 @@ def fake_tensordict(self) -> TensorDictBase: }, batch_size=self.batch_size, device=self.device, + _run_checks=False, ) return fake_td diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 098c7545812..6c682c6210b 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -13,6 +13,7 @@ DoubleToFloat, FiniteTensorDictCheck, FlattenObservation, + FrameSkipTransform, GrayScale, gSDENoise, NoopResetEnv, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index d1bd1d8415e..b7583c4beea 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -439,7 +439,7 @@ def reward_spec(self) -> TensorSpec: return reward_spec def _step(self, tensordict: TensorDictBase) -> TensorDictBase: - tensordict = tensordict.clone() + tensordict = tensordict.clone(False) tensordict_in = self.transform.inv(tensordict) tensordict_out = self.base_env._step(tensordict_in) tensordict_out = tensordict_out.update( @@ -1937,10 +1937,10 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: parent = self.parent reward = tensordict.get("reward") for _ in range(self.frame_skip - 1): - parent.step(tensordict) + parent._step(tensordict) reward = reward + tensordict.get("reward") tensordict.set("reward", reward) - return super()._step(tensordict) + return tensordict class NoopResetEnv(Transform): From 80d6e028ad818e43d124dbbe2e6cdb45c7711e9d Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 22 Dec 2022 16:25:26 +0100 Subject: [PATCH 12/17] bf dm_control --- torchrl/envs/common.py | 4 +++- torchrl/envs/libs/dm_control.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 37dcaab1537..f8f1c2ddd11 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -462,7 +462,9 @@ def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBa """ if tensordict is None: - tensordict = TensorDict({}, device=self.device, batch_size=self.batch_size, _run_checks=False) + tensordict = TensorDict( + {}, device=self.device, batch_size=self.batch_size, _run_checks=False + ) action = self.action_spec.rand(self.batch_size) tensordict.set("action", action) return self.step(tensordict) diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index a3a08158c21..e14d7c18d3d 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -239,9 +239,12 @@ def observation_spec(self, value: TensorSpec) -> None: @property def reward_spec(self) -> TensorSpec: if self._reward_spec is None: - self._reward_spec = _dmcontrol_to_torchrl_spec_transform( + _reward_spec = _dmcontrol_to_torchrl_spec_transform( self._env.reward_spec(), device=self.device ) + if _reward_spec.shape == torch.Size([]): + _reward_spec.shape = torch.Size([1]) + self._reward_spec = _reward_spec return self._reward_spec @reward_spec.setter From 91172262e38429bac534c8977aa82cbef5dbc195 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 22 Dec 2022 17:04:44 +0100 Subject: [PATCH 13/17] amend --- torchrl/envs/transforms/functional.py | 4 ++-- torchrl/envs/transforms/transforms.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/transforms/functional.py b/torchrl/envs/transforms/functional.py index e2135c3e814..6ef23b11fd5 100644 --- a/torchrl/envs/transforms/functional.py +++ b/torchrl/envs/transforms/functional.py @@ -22,8 +22,8 @@ def _assert_channels(img: Tensor, permitted: List[int]) -> None: c = _get_image_num_channels(img) if c not in permitted: raise TypeError( - "Input image tensor permitted channel values are {}, but found" - "{}".format(permitted, c) + f"Input image tensor permitted channel values are {permitted}, but found " + f"{c} (full shape: {img.shape})" ) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index b7583c4beea..3d2c57ebdd1 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1925,6 +1925,17 @@ def __repr__(self) -> str: class FrameSkipTransform(Transform): + """A frame-skip transform. + + This transform applies the same action repeatedly in the parent environment, + which improves stability on certain training algorithms. + + Args: + frame_skip (int, optional): a positive integer representing the number + of frames during which the same action must be applied. + + """ + inplace = False def __init__(self, frame_skip: int = 1): From 4403537df8efee3930e5831f404a92b39a068e01 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 22 Dec 2022 17:27:37 +0100 Subject: [PATCH 14/17] amend --- torchrl/envs/transforms/transforms.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 3d2c57ebdd1..55d8bfd2c12 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -236,20 +236,22 @@ def parent(self) -> Optional[EnvBase]: if self.__dict__.get("_parent", None) is None: if "_container" not in self.__dict__: raise AttributeError("transform parent uninitialized") - parent = self.__dict__["_container"] - if parent is None: - return parent + container = self.__dict__["_container"] + if container is None: + return container out = None - if not isinstance(parent, EnvBase): + if not isinstance(container, EnvBase): # if it's not an env, it should be a Compose transform - if not isinstance(parent, Compose): + if not isinstance(container, Compose): raise ValueError( "A transform parent must be either another Compose transform or an environment object." ) - compose = parent - if compose.parent: + compose = container + if compose.__dict__["_container"]: # the parent of the compose must be a TransformedEnv - compose_parent = compose.parent + compose_parent = TransformedEnv( + compose.__dict__["_container"].base_env + ) if compose_parent.transform is not compose: comp_parent_trans = compose_parent.transform.clone() else: @@ -261,13 +263,13 @@ def parent(self) -> Optional[EnvBase]: for orig_trans in compose.transforms: if orig_trans is self: break - transform = copy(orig_trans) + transform = orig_trans.clone() transform.reset_parent() out.append_transform(transform) - elif isinstance(parent, TransformedEnv): - out = TransformedEnv(parent.base_env) + elif isinstance(container, TransformedEnv): + out = TransformedEnv(container.base_env) else: - raise ValueError(f"parent is of type {type(parent)}") + raise ValueError(f"container is of type {type(container)}") self.__dict__["_parent"] = out return self.__dict__["_parent"] From 0aa77af6b9e0caf556e9851495a0a77998524b7e Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 23 Dec 2022 09:03:34 +0100 Subject: [PATCH 15/17] tests --- test/test_transforms.py | 60 +++++++++++++++++++++++++++ torchrl/envs/transforms/transforms.py | 2 +- 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 3e1673858d2..3c6d29d995d 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -547,6 +547,66 @@ def test_flatten(self, keys, size, nchannels, batch, device): for key in keys: assert observation_spec[key].shape[-3] == expected_size + @pytest.mark.skipif(not _has_gym, reason="gym not installed") + @pytest.mark.parametrize("skip", [-1, 1, 2, 3]) + def test_frame_skip_transform_builtin(self, skip): + torch.manual_seed(0) + if skip < 0: + with pytest.raises( + ValueError, + match="frame_skip should have a value greater or equal to one", + ): + FrameSkipTransform(skip) + return + else: + fs = FrameSkipTransform(skip) + base_env = GymEnv("Pendulum-v1", frame_skip=skip) + tensordicts = TensorDict({"action": base_env.action_spec.rand((10,))}, [10]) + env = TransformedEnv(GymEnv("Pendulum-v1"), fs) + base_env.set_seed(0) + env.base_env.set_seed(0) + td1 = base_env.reset() + td2 = env.reset() + for key in td1.keys(): + torch.testing.assert_close(td1[key], td2[key]) + for i in range(10): + td1 = base_env.step(tensordicts[i].clone()).flatten_keys() + td2 = env.step(tensordicts[i].clone()).flatten_keys() + for key in td1.keys(): + torch.testing.assert_close(td1[key], td2[key]) + + @pytest.mark.skipif(not _has_gym, reason="gym not installed") + @pytest.mark.parametrize("skip", [-1, 1, 2, 3]) + def test_frame_skip_transform_unroll(self, skip): + torch.manual_seed(0) + if skip < 0: + with pytest.raises( + ValueError, + match="frame_skip should have a value greater or equal to one", + ): + FrameSkipTransform(skip) + return + else: + fs = FrameSkipTransform(skip) + base_env = GymEnv("Pendulum-v1") + tensordicts = TensorDict({"action": base_env.action_spec.rand((10,))}, [10]) + env = TransformedEnv(GymEnv("Pendulum-v1"), fs) + base_env.set_seed(0) + env.base_env.set_seed(0) + td1 = base_env.reset() + td2 = env.reset() + for key in td1.keys(): + torch.testing.assert_close(td1[key], td2[key]) + for i in range(10): + r = 0.0 + for _ in range(skip): + td1 = base_env.step(tensordicts[i].clone()).flatten_keys() + r = td1.get("reward") + r + td1.set("reward", r) + td2 = env.step(tensordicts[i].clone()).flatten_keys() + for key in td1.keys(): + torch.testing.assert_close(td1[key], td2[key]) + @pytest.mark.parametrize("unsqueeze_dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 55d8bfd2c12..be650cb6c8a 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1950,7 +1950,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: parent = self.parent reward = tensordict.get("reward") for _ in range(self.frame_skip - 1): - parent._step(tensordict) + tensordict = parent._step(tensordict) reward = reward + tensordict.get("reward") tensordict.set("reward", reward) return tensordict From 34753775453f5da96b37dc78d74aa69976407182 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 23 Dec 2022 10:30:54 +0100 Subject: [PATCH 16/17] bf --- test/test_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 3c6d29d995d..9a1b8f1eed3 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -562,7 +562,7 @@ def test_frame_skip_transform_builtin(self, skip): fs = FrameSkipTransform(skip) base_env = GymEnv("Pendulum-v1", frame_skip=skip) tensordicts = TensorDict({"action": base_env.action_spec.rand((10,))}, [10]) - env = TransformedEnv(GymEnv("Pendulum-v1"), fs) + env = TransformedEnv(GymEnv(PENDULUM_VERSIONED), fs) base_env.set_seed(0) env.base_env.set_seed(0) td1 = base_env.reset() From a8d0f34878340caccf57bcb0b61d1fc2df3542fa Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 23 Dec 2022 10:31:24 +0100 Subject: [PATCH 17/17] bf --- test/test_transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 9a1b8f1eed3..c07ee7cb8b3 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -560,7 +560,7 @@ def test_frame_skip_transform_builtin(self, skip): return else: fs = FrameSkipTransform(skip) - base_env = GymEnv("Pendulum-v1", frame_skip=skip) + base_env = GymEnv(PENDULUM_VERSIONED, frame_skip=skip) tensordicts = TensorDict({"action": base_env.action_spec.rand((10,))}, [10]) env = TransformedEnv(GymEnv(PENDULUM_VERSIONED), fs) base_env.set_seed(0) @@ -588,9 +588,9 @@ def test_frame_skip_transform_unroll(self, skip): return else: fs = FrameSkipTransform(skip) - base_env = GymEnv("Pendulum-v1") + base_env = GymEnv(PENDULUM_VERSIONED) tensordicts = TensorDict({"action": base_env.action_spec.rand((10,))}, [10]) - env = TransformedEnv(GymEnv("Pendulum-v1"), fs) + env = TransformedEnv(GymEnv(PENDULUM_VERSIONED), fs) base_env.set_seed(0) env.base_env.set_seed(0) td1 = base_env.reset()