16 changes: 16 additions & 0 deletions test/test_transforms.py
@@ -1275,6 +1275,22 @@ def test_noop_reset_env(self, random, device, compose):
else:
assert transformed_env.step_count == 30

@pytest.mark.parametrize("random", [True, False])
@pytest.mark.parametrize("compose", [True, False])
@pytest.mark.parametrize("device", get_available_devices())
def test_noop_reset_env_error(self, random, device, compose):
torch.manual_seed(0)
env = SerialEnv(3, lambda: ContinuousActionVecMockEnv())
env.set_seed(100)
noop_reset_env = NoopResetEnv(random=random)
transformed_env = TransformedEnv(env)
transformed_env.append_transform(noop_reset_env)
with pytest.raises(
ValueError,
match="there is more than one done state in the parent environment",
):
transformed_env.reset()

@pytest.mark.parametrize(
"default_keys", [["action"], ["action", "monkeys jumping on the bed"]]
)
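For reference, a minimal sketch of the usage pattern the new error message points to: NoopResetEnv is appended to each single sub-environment inside the batched env rather than to the batch itself, so every worker owns exactly one done state. Import paths are assumed to mirror this test file, and `ContinuousActionVecMockEnv` is reused from the test above purely for illustration.

    # Hedged sketch, not part of the PR: the supported nesting for batched envs.
    from torchrl.envs import SerialEnv, TransformedEnv          # assumed import paths
    from torchrl.envs.transforms import NoopResetEnv

    def make_noop_env():
        # the transform wraps the single env, so noops run independently per worker
        return TransformedEnv(ContinuousActionVecMockEnv(), NoopResetEnv(random=True))

    batched = TransformedEnv(SerialEnv(3, make_noop_env))  # batch-level transforms go here
    batched.reset()  # no ValueError: noops were already handled per sub-env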
6 changes: 4 additions & 2 deletions torchrl/envs/common.py
@@ -399,9 +399,11 @@ def reset(
f"env._reset returned an object of type {type(tensordict_reset)} but a TensorDict was expected."
)

self.is_done = tensordict_reset.get(
self.is_done = tensordict_reset.set_default(
"done",
torch.zeros(self.batch_size, dtype=torch.bool, device=self.device),
torch.zeros(
*tensordict_reset.batch_size, 1, dtype=torch.bool, device=self.device
),
)
if self.is_done:
raise RuntimeError(
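The switch from `get` to `set_default` matters because `get(key, default)` only returns the fallback, whereas `set_default(key, default)` also writes it into the tensordict when the key is missing, so a `_reset` output that omits `"done"` now ends up carrying a correctly shaped `done` entry (note the trailing singleton dimension, `(*batch_size, 1)`). The plain-dict sketch below illustrates the same distinction; it is ordinary Python, not TensorDict code.

    # Plain-dict sketch of the semantics relied on above (not TorchRL code).
    reset_out = {}                        # imagine a _reset() result without "done"

    value = reset_out.get("done", False)
    assert "done" not in reset_out        # .get() leaves the mapping untouched

    value = reset_out.setdefault("done", False)
    assert reset_out["done"] is False     # .setdefault() stores the fallback as well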
60 changes: 34 additions & 26 deletions torchrl/envs/transforms/transforms.py
@@ -1960,40 +1960,48 @@ def base_env(self):

def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
"""Do no-op action for a number of steps in [1, noop_max]."""
td_reset = tensordict.clone(False)
tensordict = tensordict.clone(False)
# check that there is a single done state -- behaviour is undefined for multiple dones
parent = self.parent
# keys = tensordict.keys()
if tensordict.get("done").numel() > 1:
raise ValueError(
"there is more than one done state in the parent environment. "
"NoopResetEnv is designed to work on single env instances, as partial reset "
"is currently not supported. If you feel like this is a missing feature, submit "
"an issue on TorchRL github repo. "
"In case you are trying to use NoopResetEnv over a batch of environments, know "
"that you can have a transformed batch of transformed envs, such as: "
"`TransformedEnv(ParallelEnv(3, lambda: TransformedEnv(MyEnv(), NoopResetEnv(3))), OtherTransform())`."
)
noops = (
self.noops if not self.random else torch.randint(self.noops, (1,)).item()
)
i = 0
trial = 0

while i < noops:
i += 1
tensordict = parent.rand_step(tensordict)
tensordict = step_mdp(tensordict)
if parent.is_done:
parent.reset()
i = 0
trial += 1
if trial > _MAX_NOOPS_TRIALS:
tensordict = parent.reset(tensordict)
tensordict = parent.rand_step(tensordict)
while True:
i = 0
while i < noops:
i += 1
tensordict = parent.rand_step(tensordict)
tensordict = step_mdp(tensordict, exclude_done=False)
if tensordict.get("done"):
tensordict = parent.reset(td_reset.clone(False))
break
if parent.is_done:
raise RuntimeError("NoopResetEnv concluded with done environment")
# td = step_mdp(
# tensordict, exclude_done=False, exclude_reward=True, exclude_action=True
# )

# for k in keys:
# if k not in td.keys():
# td.set(k, tensordict.get(k))

# # replace the next_ prefix
# for out_key in parent.observation_spec:
# td.rename_key(out_key[5:], out_key)
else:
break

trial += 1
if trial > _MAX_NOOPS_TRIALS:
tensordict = parent.rand_step(tensordict)
if tensordict.get("done"):
raise RuntimeError(
f"parent is still done after a single random step (i={i})."
)
break

if tensordict.get("done"):
raise RuntimeError("NoopResetEnv concluded with done environment")
return tensordict

def __repr__(self) -> str:
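The rewritten loop leans on Python's `while ... else`: the `else` clause runs only when the inner noop loop finishes without hitting `break` (i.e. no done state was seen), and it then exits the outer retry loop. Below is a minimal stand-alone sketch of that retry pattern, with a random flag standing in for the environment reporting done; it is an illustration of the control flow only, not TorchRL code.

    # Hedged sketch of the while/else retry pattern used in reset() above.
    import random

    _MAX_TRIALS = 10
    noops = 5
    trial = 0
    while True:
        i = 0
        while i < noops:
            i += 1
            if random.random() < 0.1:   # stand-in for the env coming back "done"
                break                   # abandon this attempt, retry the noop sequence
        else:
            break                       # all noops ran without a done: success
        trial += 1
        if trial > _MAX_TRIALS:
            raise RuntimeError("still done after too many retries")
    print(f"completed {i} noops after {trial} retries")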