From b4149b1d201fa4278d8ccca6e46fa2f18e9275bc Mon Sep 17 00:00:00 2001 From: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> Date: Thu, 16 Oct 2025 11:33:12 +0200 Subject: [PATCH 1/4] storing_device for rollout --- torchrl/envs/common.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 0bf2d4bc34b..da8caa54a77 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -3085,6 +3085,7 @@ def rollout( set_truncated: bool = False, out=None, trust_policy: bool = False, + storing_device: DEVICE_TYPING | None = None, ) -> TensorDictBase: """Executes a rollout in the environment. @@ -3140,6 +3141,8 @@ def rollout( trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules and ``False`` otherwise. + storing_device (Device, optional): if provided, the tensordict will be stored on this device. + Defaults to ``None``. Returns: TensorDict object containing the resulting trajectory. @@ -3372,6 +3375,7 @@ def rollout( "policy": policy, "policy_device": policy_device, "env_device": env_device, + "storing_device": storing_device, "callback": callback, } if break_when_any_done or break_when_all_done: @@ -3508,6 +3512,7 @@ def _rollout_stop_early( policy, policy_device, env_device, + storing_device, callback, ): # Get the sync func @@ -3531,7 +3536,7 @@ def _rollout_stop_early( else: tensordict.clear_device_() tensordict = self.step(tensordict) - td_append = tensordict.copy() + td_append = tensordict.copy().to(storing_device) if break_when_all_done: if partial_steps is not True and not partial_steps.all(): # At least one step is partial @@ -3589,6 +3594,7 @@ def _rollout_nonstop( policy, policy_device, env_device, + storing_device, callback, ): if auto_cast_to_device: @@ -3614,7 +3620,7 @@ def _rollout_nonstop( tensordict = self.step(tensordict_) else: tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_) - tensordicts.append(tensordict) + tensordicts.append(tensordict.to(storing_device)) if i == max_steps - 1: # we don't truncate as one could potentially continue the run break From a62a2feab135ab35d452dc707252b3787c9dd48f Mon Sep 17 00:00:00 2001 From: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> Date: Thu, 16 Oct 2025 11:33:23 +0200 Subject: [PATCH 2/4] test --- test/test_env.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/test_env.py b/test/test_env.py index fe0c52dce22..b092be80e1c 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -827,6 +827,28 @@ def test_batch_unlocked_with_batch_size(self, device): # env.observation_spec = env.observation_spec.clone() # assert not env._cache + @pytest.mark.parametrize("storing_device", get_default_devices()) + def test_storing_device(self, storing_device): + """Ensure rollout data tensors are moved to the requested storing_device.""" + env = ContinuousActionVecMockEnv(device="cpu") + + td = env.rollout( + 10, + storing_device=torch.device(storing_device) + if storing_device is not None + else None, + ) + + expected_device = ( + torch.device(storing_device) if storing_device is not None else env.device + ) + + assert td.device == expected_device + + for _, item in td.items(True, True): + if isinstance(item, torch.Tensor): + assert item.device == expected_device + class TestRollout: @pytest.mark.skipif(not _has_gym, reason="no gym") From 541c3b125c51cc5935da38d1b91cc91a40b18d87 Mon Sep 17 00:00:00 2001 From: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> Date: Thu, 16 Oct 2025 13:41:55 +0200 Subject: [PATCH 3/4] reduced overhead --- torchrl/envs/common.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index da8caa54a77..bf204449e07 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -3536,7 +3536,10 @@ def _rollout_stop_early( else: tensordict.clear_device_() tensordict = self.step(tensordict) - td_append = tensordict.copy().to(storing_device) + if storing_device is None or tensordict.device == storing_device: + td_append = tensordict.copy() + else: + td_append = tensordict.to(storing_device) if break_when_all_done: if partial_steps is not True and not partial_steps.all(): # At least one step is partial @@ -3620,7 +3623,10 @@ def _rollout_nonstop( tensordict = self.step(tensordict_) else: tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_) - tensordicts.append(tensordict.to(storing_device)) + if storing_device is None or tensordict.device == storing_device: + tensordicts.append(tensordict) + else: + tensordicts.append(tensordict.to(storing_device)) if i == max_steps - 1: # we don't truncate as one could potentially continue the run break From 658b219d79e875e548ea7ee380d446edfa22ff22 Mon Sep 17 00:00:00 2001 From: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> Date: Thu, 16 Oct 2025 14:14:57 +0200 Subject: [PATCH 4/4] ensure we compare devices --- torchrl/envs/common.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index bf204449e07..097dcd24df4 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -3375,7 +3375,9 @@ def rollout( "policy": policy, "policy_device": policy_device, "env_device": env_device, - "storing_device": storing_device, + "storing_device": None + if storing_device is None + else torch.device(storing_device), "callback": callback, } if break_when_any_done or break_when_all_done: