diff --git a/test/test_transforms.py b/test/test_transforms.py
index ed953e7ce45..c63f6f69828 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -1275,6 +1275,22 @@ def test_noop_reset_env(self, random, device, compose):
         else:
             assert transformed_env.step_count == 30
 
+    @pytest.mark.parametrize("random", [True, False])
+    @pytest.mark.parametrize("compose", [True, False])
+    @pytest.mark.parametrize("device", get_available_devices())
+    def test_noop_reset_env_error(self, random, device, compose):
+        torch.manual_seed(0)
+        env = SerialEnv(3, lambda: ContinuousActionVecMockEnv())
+        env.set_seed(100)
+        noop_reset_env = NoopResetEnv(random=random)
+        transformed_env = TransformedEnv(env)
+        transformed_env.append_transform(noop_reset_env)
+        with pytest.raises(
+            ValueError,
+            match="there is more than one done state in the parent environment",
+        ):
+            transformed_env.reset()
+
     @pytest.mark.parametrize(
         "default_keys", [["action"], ["action", "monkeys jumping on the bed"]]
     )
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index b9bf02fcdc5..dbd2c118893 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -399,9 +399,11 @@ def reset(
                 f"env._reset returned an object of type {type(tensordict_reset)} but a TensorDict was expected."
             )
 
-        self.is_done = tensordict_reset.get(
+        self.is_done = tensordict_reset.set_default(
             "done",
-            torch.zeros(self.batch_size, dtype=torch.bool, device=self.device),
+            torch.zeros(
+                *tensordict_reset.batch_size, 1, dtype=torch.bool, device=self.device
+            ),
         )
         if self.is_done:
             raise RuntimeError(
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index b14ae8b0508..62c0d6b22f4 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -1960,40 +1960,48 @@ def base_env(self):
         return self.parent
 
     def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
         """Do no-op action for a number of steps in [1, noop_max]."""
+        td_reset = tensordict.clone(False)
+        tensordict = tensordict.clone(False)
+        # check that there is a single done state -- behaviour is undefined for multiple dones
         parent = self.parent
-        # keys = tensordict.keys()
+        if tensordict.get("done").numel() > 1:
+            raise ValueError(
+                "there is more than one done state in the parent environment. "
+                "NoopResetEnv is designed to work on single env instances, as partial reset "
+                "is currently not supported. If you feel like this is a missing feature, submit "
+                "an issue on TorchRL github repo. "
+                "In case you are trying to use NoopResetEnv over a batch of environments, know "
+                "that you can have a transformed batch of transformed envs, such as: "
+                "`TransformedEnv(ParallelEnv(3, lambda: TransformedEnv(MyEnv(), NoopResetEnv(3))), OtherTransform())`."
+            )
         noops = (
             self.noops if not self.random else torch.randint(self.noops, (1,)).item()
         )
-        i = 0
         trial = 0
-        while i < noops:
-            i += 1
-            tensordict = parent.rand_step(tensordict)
-            tensordict = step_mdp(tensordict)
-            if parent.is_done:
-                parent.reset()
-                i = 0
-                trial += 1
-                if trial > _MAX_NOOPS_TRIALS:
-                    tensordict = parent.reset(tensordict)
-                    tensordict = parent.rand_step(tensordict)
+        while True:
+            i = 0
+            while i < noops:
+                i += 1
+                tensordict = parent.rand_step(tensordict)
+                tensordict = step_mdp(tensordict, exclude_done=False)
+                if tensordict.get("done"):
+                    tensordict = parent.reset(td_reset.clone(False))
                     break
-        if parent.is_done:
-            raise RuntimeError("NoopResetEnv concluded with done environment")
-        # td = step_mdp(
-        #     tensordict, exclude_done=False, exclude_reward=True, exclude_action=True
-        # )
-
-        # for k in keys:
-        #     if k not in td.keys():
-        #         td.set(k, tensordict.get(k))
-
-        # # replace the next_ prefix
-        # for out_key in parent.observation_spec:
-        #     td.rename_key(out_key[5:], out_key)
+            else:
+                break
+
+            trial += 1
+            if trial > _MAX_NOOPS_TRIALS:
+                tensordict = parent.rand_step(tensordict)
+                if tensordict.get("done"):
+                    raise RuntimeError(
+                        f"parent is still done after a single random step (i={i})."
+                    )
+                break
+        if tensordict.get("done"):
+            raise RuntimeError("NoopResetEnv concluded with done environment")
         return tensordict
 
     def __repr__(self) -> str:
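
Note: the composition pattern recommended by the new `ValueError` message can be sketched as below. This is an illustrative example, not part of the patch: `GymEnv("Pendulum-v1")` stands in for any single-instance env, and the outer `TransformedEnv` wrapper mirrors the `OtherTransform()` placeholder in the message. The check works because `common.py` now stores the done flag with shape `[*batch_size, 1]`, so a batched parent yields `tensordict.get("done").numel() > 1`.

```python
# Sketch of the batched usage the error message points to: NoopResetEnv wraps
# each single env instance, while batch-level transforms wrap the ParallelEnv.
from torchrl.envs import ParallelEnv, TransformedEnv
from torchrl.envs.libs.gym import GymEnv  # stand-in; any single env works
from torchrl.envs.transforms import NoopResetEnv


def make_env():
    # NoopResetEnv is applied per worker, where exactly one done state exists.
    return TransformedEnv(GymEnv("Pendulum-v1"), NoopResetEnv(3))


env = ParallelEnv(3, make_env)  # each of the 3 workers no-op-resets itself
tensordict = env.reset()  # no ValueError: each inner env has a single done state
```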