From 1f564847e7dc0f9e25740525f483b82af9ef25bd Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 10 Sep 2025 09:52:06 +0100 Subject: [PATCH] [Test] Fix flaky parallel env test --- torchrl/envs/batched_envs.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 559935ee976..986fe31b1e0 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -2451,13 +2451,15 @@ def look_for_cuda(tensor, has_cuda=has_cuda): if event is not None: event.record() event.synchronize() - mp_event.set() if _non_tensor_keys: child_pipe.send( ("non_tensor", cur_td.select(*_non_tensor_keys, strict=False)) ) + # Set event only after non-tensor data is sent to avoid race condition + mp_event.set() + del cur_td elif cmd == "step": @@ -2483,7 +2485,6 @@ def look_for_cuda(tensor, has_cuda=has_cuda): if event is not None: event.record() event.synchronize() - mp_event.set() # Make sure the root is updated root_shared_tensordict.update_(env._step_mdp(input)) @@ -2493,6 +2494,9 @@ def look_for_cuda(tensor, has_cuda=has_cuda): ("non_tensor", next_td.select(*_non_tensor_keys, strict=False)) ) + # Set event only after non-tensor data is sent to avoid race condition + mp_event.set() + del next_td elif cmd == "step_and_maybe_reset": @@ -2525,13 +2529,15 @@ def look_for_cuda(tensor, has_cuda=has_cuda): if event is not None: event.record() event.synchronize() - mp_event.set() if _non_tensor_keys: ntd = root_next_td.select(*_non_tensor_keys) ntd.set("next", td_next.select(*_non_tensor_keys)) child_pipe.send(("non_tensor", ntd)) + # Set event only after non-tensor data is sent to avoid race condition + mp_event.set() + del td, root_next_td elif cmd == "close":