diff --git a/test/test_env.py b/test/test_env.py
index fe0c52dce22..b092be80e1c 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -827,6 +827,28 @@ def test_batch_unlocked_with_batch_size(self, device):
         # env.observation_spec = env.observation_spec.clone()
         # assert not env._cache
 
+    @pytest.mark.parametrize("storing_device", get_default_devices())
+    def test_storing_device(self, storing_device):
+        """Ensure rollout data tensors are moved to the requested storing_device."""
+        env = ContinuousActionVecMockEnv(device="cpu")
+
+        td = env.rollout(
+            10,
+            storing_device=torch.device(storing_device)
+            if storing_device is not None
+            else None,
+        )
+
+        expected_device = (
+            torch.device(storing_device) if storing_device is not None else env.device
+        )
+
+        assert td.device == expected_device
+
+        for _, item in td.items(True, True):
+            if isinstance(item, torch.Tensor):
+                assert item.device == expected_device
+
 
 class TestRollout:
     @pytest.mark.skipif(not _has_gym, reason="no gym")
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index 0bf2d4bc34b..097dcd24df4 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -3085,6 +3085,7 @@ def rollout(
         set_truncated: bool = False,
         out=None,
         trust_policy: bool = False,
+        storing_device: DEVICE_TYPING | None = None,
     ) -> TensorDictBase:
         """Executes a rollout in the environment.
 
@@ -3140,6 +3141,8 @@ def rollout(
             trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be
                 assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules
                 and ``False`` otherwise.
+            storing_device (Device, optional): if provided, the tensordict will be stored on this device.
+                Defaults to ``None``.
 
         Returns:
             TensorDict object containing the resulting trajectory.
@@ -3372,6 +3375,9 @@ def rollout(
             "policy": policy,
             "policy_device": policy_device,
             "env_device": env_device,
+            "storing_device": None
+            if storing_device is None
+            else torch.device(storing_device),
             "callback": callback,
         }
         if break_when_any_done or break_when_all_done:
@@ -3508,6 +3514,7 @@ def _rollout_stop_early(
         policy,
         policy_device,
         env_device,
+        storing_device,
         callback,
     ):
         # Get the sync func
@@ -3531,7 +3538,10 @@ def _rollout_stop_early(
                 else:
                     tensordict.clear_device_()
             tensordict = self.step(tensordict)
-            td_append = tensordict.copy()
+            if storing_device is None or tensordict.device == storing_device:
+                td_append = tensordict.copy()
+            else:
+                td_append = tensordict.to(storing_device)
             if break_when_all_done:
                 if partial_steps is not True and not partial_steps.all():
                     # At least one step is partial
@@ -3589,6 +3599,7 @@ def _rollout_nonstop(
         policy,
         policy_device,
         env_device,
+        storing_device,
         callback,
     ):
         if auto_cast_to_device:
@@ -3614,7 +3625,10 @@ def _rollout_nonstop(
                 tensordict = self.step(tensordict_)
             else:
                 tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
-            tensordicts.append(tensordict)
+            if storing_device is None or tensordict.device == storing_device:
+                tensordicts.append(tensordict)
+            else:
+                tensordicts.append(tensordict.to(storing_device))
             if i == max_steps - 1:
                 # we don't truncate as one could potentially continue the run
                 break
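
Usage sketch (not part of the diff): a minimal example of the new `storing_device`
argument, assuming the patch above is applied. `GymEnv("Pendulum-v1")` and the CUDA
device are illustrative assumptions; any `EnvBase` subclass works the same way.

    import torch
    from torchrl.envs import GymEnv

    # Illustrative environment; requires gymnasium and a CUDA-enabled torch build.
    env = GymEnv("Pendulum-v1", device="cuda:0")

    # With the patch, each step's tensordict is moved to `storing_device` before
    # being appended, so the stacked rollout lives on CPU even though the env
    # itself steps on CUDA. Strings are accepted: rollout() converts the argument
    # with torch.device() before passing it to the inner rollout loops.
    td = env.rollout(100, storing_device="cpu")

    assert td.device == torch.device("cpu")

This mirrors the `storing_device` argument that torchrl's data collectors already
expose: long rollouts can accumulate in CPU memory instead of on the GPU.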