diff --git a/test/test_env.py b/test/test_env.py index a36c88b6282..4661211f863 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -508,7 +508,7 @@ def test_parallel_env( td = TensorDict( source={"action": env0.action_spec.rand((N - 1,))}, batch_size=[N - 1] ) - td1 = env_parallel.step(td) + _ = env_parallel.step(td) td_reset = TensorDict( source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()}, @@ -592,7 +592,7 @@ def test_parallel_env_with_policy( td = TensorDict( source={"action": env0.action_spec.rand((N - 1,))}, batch_size=[N - 1] ) - td1 = env_parallel.step(td) + _ = env_parallel.step(td) td_reset = TensorDict( source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()}, @@ -710,7 +710,6 @@ def test_parallel_env_cast( transformed_out, device, open_before, - T=10, N=3, ): # tests casting to device diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index a74a70c3d86..af44cc533e5 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -18,7 +18,6 @@ from tensordict import TensorDict from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase from torch import multiprocessing as mp - from torchrl._utils import _check_for_faulty_process from torchrl.data import CompositeSpec, TensorSpec from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING @@ -781,10 +780,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: for i in range(self.num_workers): msg, data = self.parent_channels[i].recv() if msg != "step_result": - if msg != "done": - raise RuntimeError( - f"Expected 'done' but received {msg} from worker {i}" - ) + raise RuntimeError( + f"Expected 'step_result' but received {msg} from worker {i}" + ) # data is the set of updated keys keys = keys.union(data) # We must pass a clone of the tensordict, as the values of this tensordict @@ -1004,13 +1002,18 @@ def _run_worker_pipe_shared_mem( raise RuntimeError("call 'init' before resetting") # _td = tensordict.select("observation").to(env.device).clone() _td = env._reset(**reset_kwargs) + done = _td.get("done", None) + if done is None: + _td["done"] = done = torch.zeros( + *_td.batch_size, 1, dtype=torch.bool, device=env.device + ) if reset_keys is None: reset_keys = set(_td.keys()) if pin_memory: _td.pin_memory() tensordict.update_(_td) child_pipe.send(("reset_obs", reset_keys)) - if _td.get("done").any(): + if done.any(): raise RuntimeError(f"{env.__class__.__name__} is done after reset") elif cmd == "step": @@ -1029,10 +1032,7 @@ def _run_worker_pipe_shared_mem( if pin_memory: _td.pin_memory() tensordict.update_(_td.select(*step_keys, strict=False)) - if _td.get("done"): - msg = "done" - else: - msg = "step_result" + msg = "step_result" data = (msg, step_keys) child_pipe.send(data)