diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 559935ee976..986fe31b1e0 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -2451,13 +2451,15 @@ def look_for_cuda(tensor, has_cuda=has_cuda): if event is not None: event.record() event.synchronize() - mp_event.set() if _non_tensor_keys: child_pipe.send( ("non_tensor", cur_td.select(*_non_tensor_keys, strict=False)) ) + # Set event only after non-tensor data is sent to avoid race condition + mp_event.set() + del cur_td elif cmd == "step": @@ -2483,7 +2485,6 @@ def look_for_cuda(tensor, has_cuda=has_cuda): if event is not None: event.record() event.synchronize() - mp_event.set() # Make sure the root is updated root_shared_tensordict.update_(env._step_mdp(input)) @@ -2493,6 +2494,9 @@ def look_for_cuda(tensor, has_cuda=has_cuda): ("non_tensor", next_td.select(*_non_tensor_keys, strict=False)) ) + # Set event only after non-tensor data is sent to avoid race condition + mp_event.set() + del next_td elif cmd == "step_and_maybe_reset": @@ -2525,13 +2529,15 @@ def look_for_cuda(tensor, has_cuda=has_cuda): if event is not None: event.record() event.synchronize() - mp_event.set() if _non_tensor_keys: ntd = root_next_td.select(*_non_tensor_keys) ntd.set("next", td_next.select(*_non_tensor_keys)) child_pipe.send(("non_tensor", ntd)) + # Set event only after non-tensor data is sent to avoid race condition + mp_event.set() + del td, root_next_td elif cmd == "close":