Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,28 @@ def test_batch_unlocked_with_batch_size(self, device):
# env.observation_spec = env.observation_spec.clone()
# assert not env._cache

@pytest.mark.parametrize("storing_device", get_default_devices())
def test_storing_device(self, storing_device):
    """Check that every tensor produced by ``rollout`` lands on ``storing_device``.

    The env itself runs on CPU; when a storing device is requested, the
    rollout output (and all of its leaf tensors) must live on that device,
    otherwise it stays on the env's own device.
    """
    env = ContinuousActionVecMockEnv(device="cpu")

    # Forward the parametrized device as a torch.device, or None to keep
    # the default behavior.
    requested = None if storing_device is None else torch.device(storing_device)
    td = env.rollout(10, storing_device=requested)

    # With no storing_device, data stays on the env device.
    target = env.device if storing_device is None else torch.device(storing_device)

    assert td.device == target

    # Walk nested leaves and verify each tensor was moved.
    for _, leaf in td.items(True, True):
        if isinstance(leaf, torch.Tensor):
            assert leaf.device == target


class TestRollout:
@pytest.mark.skipif(not _has_gym, reason="no gym")
Expand Down
18 changes: 16 additions & 2 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3085,6 +3085,7 @@ def rollout(
set_truncated: bool = False,
out=None,
trust_policy: bool = False,
storing_device: DEVICE_TYPING | None = None,
) -> TensorDictBase:
"""Executes a rollout in the environment.

Expand Down Expand Up @@ -3140,6 +3141,8 @@ def rollout(
trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be
assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules
and ``False`` otherwise.
storing_device (Device, optional): if provided, the tensordict will be stored on this device.
Defaults to ``None``.

Returns:
TensorDict object containing the resulting trajectory.
Expand Down Expand Up @@ -3372,6 +3375,9 @@ def rollout(
"policy": policy,
"policy_device": policy_device,
"env_device": env_device,
"storing_device": None
if storing_device is None
else torch.device(storing_device),
"callback": callback,
}
if break_when_any_done or break_when_all_done:
Expand Down Expand Up @@ -3508,6 +3514,7 @@ def _rollout_stop_early(
policy,
policy_device,
env_device,
storing_device,
callback,
):
# Get the sync func
Expand All @@ -3531,7 +3538,10 @@ def _rollout_stop_early(
else:
tensordict.clear_device_()
tensordict = self.step(tensordict)
td_append = tensordict.copy()
if storing_device is None or tensordict.device == storing_device:
td_append = tensordict.copy()
else:
td_append = tensordict.to(storing_device)
if break_when_all_done:
if partial_steps is not True and not partial_steps.all():
# At least one step is partial
Expand Down Expand Up @@ -3589,6 +3599,7 @@ def _rollout_nonstop(
policy,
policy_device,
env_device,
storing_device,
callback,
):
if auto_cast_to_device:
Expand All @@ -3614,7 +3625,10 @@ def _rollout_nonstop(
tensordict = self.step(tensordict_)
else:
tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
tensordicts.append(tensordict)
if storing_device is None or tensordict.device == storing_device:
tensordicts.append(tensordict)
else:
tensordicts.append(tensordict.to(storing_device))
if i == max_steps - 1:
# we don't truncate as one could potentially continue the run
break
Expand Down