From 1a480c73b66bb46d45dbf898c11aa1be716a5896 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 16 Dec 2022 09:32:09 +0000 Subject: [PATCH 1/4] init --- torchrl/envs/common.py | 2 +- torchrl/envs/transforms/transforms.py | 60 +++++++++++++++------------ 2 files changed, 35 insertions(+), 27 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index b9bf02fcdc5..c698c473200 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -399,7 +399,7 @@ def reset( f"env._reset returned an object of type {type(tensordict_reset)} but a TensorDict was expected." ) - self.is_done = tensordict_reset.get( + self.is_done = tensordict_reset.set_default( "done", torch.zeros(self.batch_size, dtype=torch.bool, device=self.device), ) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index b14ae8b0508..d9805e654ac 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1960,40 +1960,48 @@ def base_env(self): def reset(self, tensordict: TensorDictBase) -> TensorDictBase: """Do no-op action for a number of steps in [1, noop_max].""" + td_reset = tensordict.clone(False) + tensordict = tensordict.clone(False) + # check that there is a single done state -- behaviour is undefined for multiple dones parent = self.parent - # keys = tensordict.keys() + if tensordict.get("done").numel() > 1: + raise ValueError( + "there is more than one done state in the parent environment. " + "NoopResetEnv is designed to work on single env instances, as partial reset " + "is currently not supported. If you feel like this is a missing feature, submit " + "an issue on TorchRL github repo. " + "In case you are trying to use NoopResetEnv over a batch of environments, know " + "that you can have a transformed batch of transformed envs, such as: " + "`TransformedEnv(ParallelEnv(3, lambda: TransformedEnv(MyEnv(), NoopResetEnv(3))), OtherTransform())`." + ) noops = ( self.noops if not self.random else torch.randint(self.noops, (1,)).item() ) - i = 0 trial = 0 - while i < noops: - i += 1 - tensordict = parent.rand_step(tensordict) - tensordict = step_mdp(tensordict) - if parent.is_done: - parent.reset() - i = 0 - trial += 1 - if trial > _MAX_NOOPS_TRIALS: - tensordict = parent.reset(tensordict) - tensordict = parent.rand_step(tensordict) + while True: + i = 0 + while i < noops: + i += 1 + tensordict = parent.rand_step(tensordict) + tensordict = step_mdp(tensordict, exclude_done=False) + if tensordict.get("is_done"): + tensordict = parent.reset(td_reset.clone(False)) break - if parent.is_done: - raise RuntimeError("NoopResetEnv concluded with done environment") - # td = step_mdp( - # tensordict, exclude_done=False, exclude_reward=True, exclude_action=True - # ) - - # for k in keys: - # if k not in td.keys(): - # td.set(k, tensordict.get(k)) - - # # replace the next_ prefix - # for out_key in parent.observation_spec: - # td.rename_key(out_key[5:], out_key) + else: + break + + trial += 1 + if trial > _MAX_NOOPS_TRIALS: + tensordict = parent.rand_step(tensordict) + if tensordict.get("is_done"): + raise RuntimeError( + f"parent is still done after a single random step (i={i})." + ) + break + if tensordict.get("is_done"): + raise RuntimeError("NoopResetEnv concluded with done environment") return tensordict def __repr__(self) -> str: From 75b78bd8096725671513fb7248a7a319094f2450 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 16 Dec 2022 09:41:02 +0000 Subject: [PATCH 2/4] bf --- torchrl/envs/transforms/transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index d9805e654ac..62c0d6b22f4 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1985,7 +1985,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: i += 1 tensordict = parent.rand_step(tensordict) tensordict = step_mdp(tensordict, exclude_done=False) - if tensordict.get("is_done"): + if tensordict.get("done"): tensordict = parent.reset(td_reset.clone(False)) break else: @@ -1994,13 +1994,13 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: trial += 1 if trial > _MAX_NOOPS_TRIALS: tensordict = parent.rand_step(tensordict) - if tensordict.get("is_done"): + if tensordict.get("done"): raise RuntimeError( f"parent is still done after a single random step (i={i})." ) break - if tensordict.get("is_done"): + if tensordict.get("done"): raise RuntimeError("NoopResetEnv concluded with done environment") return tensordict From c84fc02e7db5872a0b7f89ffac67047b579bfd69 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 16 Dec 2022 09:59:17 +0000 Subject: [PATCH 3/4] test --- test/test_transforms.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index ed953e7ce45..c63f6f69828 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1275,6 +1275,22 @@ def test_noop_reset_env(self, random, device, compose): else: assert transformed_env.step_count == 30 + @pytest.mark.parametrize("random", [True, False]) + @pytest.mark.parametrize("compose", [True, False]) + @pytest.mark.parametrize("device", get_available_devices()) + def test_noop_reset_env_error(self, random, device, compose): + torch.manual_seed(0) + env = SerialEnv(3, lambda: ContinuousActionVecMockEnv()) + env.set_seed(100) + noop_reset_env = NoopResetEnv(random=random) + transformed_env = TransformedEnv(env) + transformed_env.append_transform(noop_reset_env) + with pytest.raises( + ValueError, + match="there is more than one done state in the parent environment", + ): + transformed_env.reset() + @pytest.mark.parametrize( "default_keys", [["action"], ["action", "monkeys jumping on the bed"]] ) From 3726729c73935a6c7613687b25f3f19f04b7a18b Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 16 Dec 2022 11:34:05 +0000 Subject: [PATCH 4/4] bf --- torchrl/envs/common.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index c698c473200..dbd2c118893 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -401,7 +401,9 @@ def reset( self.is_done = tensordict_reset.set_default( "done", - torch.zeros(self.batch_size, dtype=torch.bool, device=self.device), + torch.zeros( + *tensordict_reset.batch_size, 1, dtype=torch.bool, device=self.device + ), ) if self.is_done: raise RuntimeError(