From a2f55cf94e646458e327e95f2fe201c04437bf08 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 3 Jan 2023 16:12:04 +0100 Subject: [PATCH 1/6] Fixing done issues --- torchrl/envs/vec_env.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index a74a70c3d86..48c29800583 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -18,7 +18,6 @@ from tensordict import TensorDict from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase from torch import multiprocessing as mp - from torchrl._utils import _check_for_faulty_process from torchrl.data import CompositeSpec, TensorSpec from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING @@ -781,10 +780,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: for i in range(self.num_workers): msg, data = self.parent_channels[i].recv() if msg != "step_result": - if msg != "done": - raise RuntimeError( - f"Expected 'done' but received {msg} from worker {i}" - ) + raise RuntimeError( + f"Expected 'step_result' but received {msg} from worker {i}" + ) # data is the set of updated keys keys = keys.union(data) # We must pass a clone of the tensordict, as the values of this tensordict @@ -1003,7 +1001,7 @@ def _run_worker_pipe_shared_mem( if not initialized: raise RuntimeError("call 'init' before resetting") # _td = tensordict.select("observation").to(env.device).clone() - _td = env._reset(**reset_kwargs) + _td = env.reset(**reset_kwargs) if reset_keys is None: reset_keys = set(_td.keys()) if pin_memory: @@ -1029,10 +1027,7 @@ def _run_worker_pipe_shared_mem( if pin_memory: _td.pin_memory() tensordict.update_(_td.select(*step_keys, strict=False)) - if _td.get("done"): - msg = "done" - else: - msg = "step_result" + msg = "step_result" data = (msg, step_keys) child_pipe.send(data) From ec9da9993f74e3cf44aa74c50c2de4e34c4d64e8 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 3 Jan 2023 22:40:25 +0100 Subject: [PATCH 2/6] Undo `reset()` --- torchrl/envs/vec_env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 48c29800583..bcac4273a2d 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -1001,14 +1001,14 @@ def _run_worker_pipe_shared_mem( if not initialized: raise RuntimeError("call 'init' before resetting") # _td = tensordict.select("observation").to(env.device).clone() - _td = env.reset(**reset_kwargs) + _td = env._reset(**reset_kwargs) if reset_keys is None: reset_keys = set(_td.keys()) if pin_memory: _td.pin_memory() tensordict.update_(_td) child_pipe.send(("reset_obs", reset_keys)) - if _td.get("done").any(): + if _td.get("done", torch.zeros([], dtype=torch.bool)).any(): raise RuntimeError(f"{env.__class__.__name__} is done after reset") elif cmd == "step": From 7602e09d3b418339170f9ef92fbbc673b8ed58a6 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 3 Jan 2023 23:42:16 +0100 Subject: [PATCH 3/6] use set_default instead --- torchrl/envs/vec_env.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index bcac4273a2d..2a1b3636e8d 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -1002,13 +1002,17 @@ def _run_worker_pipe_shared_mem( raise RuntimeError("call 'init' before resetting") # _td = tensordict.select("observation").to(env.device).clone() _td = env._reset(**reset_kwargs) + _td.set_default( + "done", + torch.zeros(*_td.batch_size, 1, dtype=torch.bool, device=_td.device), + ) if reset_keys is None: reset_keys = set(_td.keys()) if pin_memory: _td.pin_memory() tensordict.update_(_td) child_pipe.send(("reset_obs", reset_keys)) - if _td.get("done", torch.zeros([], dtype=torch.bool)).any(): + if _td.get("done").any(): raise RuntimeError(f"{env.__class__.__name__} is done after reset") elif cmd == "step": From a1c22387302289daa00b6e40c075e715cbde1dbc Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 4 Jan 2023 09:59:27 +0100 Subject: [PATCH 4/6] change device --- torchrl/envs/vec_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 2a1b3636e8d..d79dca7c494 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -1004,7 +1004,7 @@ def _run_worker_pipe_shared_mem( _td = env._reset(**reset_kwargs) _td.set_default( "done", - torch.zeros(*_td.batch_size, 1, dtype=torch.bool, device=_td.device), + torch.zeros(*_td.batch_size, 1, dtype=torch.bool, device=env.device), ) if reset_keys is None: reset_keys = set(_td.keys()) From d21705d1543932bfd42ace4fcc16560350a1d616 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 4 Jan 2023 08:59:35 +0000 Subject: [PATCH 5/6] empty From 7a6904bb795e2b50524999d901460e504cad747f Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 4 Jan 2023 10:43:22 +0000 Subject: [PATCH 6/6] minor --- test/test_env.py | 5 ++--- torchrl/envs/vec_env.py | 11 ++++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index a36c88b6282..4661211f863 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -508,7 +508,7 @@ def test_parallel_env( td = TensorDict( source={"action": env0.action_spec.rand((N - 1,))}, batch_size=[N - 1] ) - td1 = env_parallel.step(td) + _ = env_parallel.step(td) td_reset = TensorDict( source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()}, @@ -592,7 +592,7 @@ def test_parallel_env_with_policy( td = TensorDict( source={"action": env0.action_spec.rand((N - 1,))}, batch_size=[N - 1] ) - td1 = env_parallel.step(td) + _ = env_parallel.step(td) td_reset = TensorDict( source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()}, @@ -710,7 +710,6 @@ def test_parallel_env_cast( transformed_out, device, open_before, - T=10, N=3, ): # tests casting to device diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index d79dca7c494..af44cc533e5 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -1002,17 +1002,18 @@ def _run_worker_pipe_shared_mem( raise RuntimeError("call 'init' before resetting") # _td = tensordict.select("observation").to(env.device).clone() _td = env._reset(**reset_kwargs) - _td.set_default( - "done", - torch.zeros(*_td.batch_size, 1, dtype=torch.bool, device=env.device), - ) + done = _td.get("done", None) + if done is None: + _td["done"] = done = torch.zeros( + *_td.batch_size, 1, dtype=torch.bool, device=env.device + ) if reset_keys is None: reset_keys = set(_td.keys()) if pin_memory: _td.pin_memory() tensordict.update_(_td) child_pipe.send(("reset_obs", reset_keys)) - if _td.get("done").any(): + if done.any(): raise RuntimeError(f"{env.__class__.__name__} is done after reset") elif cmd == "step":