Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def test_parallel_env(
td = TensorDict(
source={"action": env0.action_spec.rand((N - 1,))}, batch_size=[N - 1]
)
- td1 = env_parallel.step(td)
+ _ = env_parallel.step(td)

td_reset = TensorDict(
source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()},
Expand Down Expand Up @@ -592,7 +592,7 @@ def test_parallel_env_with_policy(
td = TensorDict(
source={"action": env0.action_spec.rand((N - 1,))}, batch_size=[N - 1]
)
- td1 = env_parallel.step(td)
+ _ = env_parallel.step(td)

td_reset = TensorDict(
source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()},
Expand Down Expand Up @@ -710,7 +710,6 @@ def test_parallel_env_cast(
transformed_out,
device,
open_before,
- T=10,
N=3,
):
# tests casting to device
Expand Down
20 changes: 10 additions & 10 deletions torchrl/envs/vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from tensordict import TensorDict
from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
from torch import multiprocessing as mp

from torchrl._utils import _check_for_faulty_process
from torchrl.data import CompositeSpec, TensorSpec
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
Expand Down Expand Up @@ -781,10 +780,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
for i in range(self.num_workers):
msg, data = self.parent_channels[i].recv()
if msg != "step_result":
- if msg != "done":
- raise RuntimeError(
- f"Expected 'done' but received {msg} from worker {i}"
- )
+ raise RuntimeError(
+ f"Expected 'step_result' but received {msg} from worker {i}"
+ )
# data is the set of updated keys
keys = keys.union(data)
# We must pass a clone of the tensordict, as the values of this tensordict
Expand Down Expand Up @@ -1004,13 +1002,18 @@ def _run_worker_pipe_shared_mem(
raise RuntimeError("call 'init' before resetting")
# _td = tensordict.select("observation").to(env.device).clone()
_td = env._reset(**reset_kwargs)
+ done = _td.get("done", None)
+ if done is None:
+ _td["done"] = done = torch.zeros(
+ *_td.batch_size, 1, dtype=torch.bool, device=env.device
+ )
if reset_keys is None:
reset_keys = set(_td.keys())
if pin_memory:
_td.pin_memory()
tensordict.update_(_td)
child_pipe.send(("reset_obs", reset_keys))
- if _td.get("done").any():
+ if done.any():
raise RuntimeError(f"{env.__class__.__name__} is done after reset")

elif cmd == "step":
Expand All @@ -1029,10 +1032,7 @@ def _run_worker_pipe_shared_mem(
if pin_memory:
_td.pin_memory()
tensordict.update_(_td.select(*step_keys, strict=False))
- if _td.get("done"):
- msg = "done"
- else:
- msg = "step_result"
+ msg = "step_result"
data = (msg, step_keys)
child_pipe.send(data)

Expand Down