13 changes: 10 additions & 3 deletions test/test_rb.py
@@ -2678,16 +2678,21 @@ def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls):
             ],
         ],
     )
+    @pytest.mark.parametrize("env_device", get_default_devices())
     def test_rb_multidim_collector(
-        self, rbtype, storage_cls, writer_cls, sampler_cls, transform
+        self, rbtype, storage_cls, writer_cls, sampler_cls, transform, env_device
     ):
         from _utils_internal import CARTPOLE_VERSIONED

         torch.manual_seed(0)
-        env = SerialEnv(2, lambda: GymEnv(CARTPOLE_VERSIONED()))
+        env = SerialEnv(2, lambda: GymEnv(CARTPOLE_VERSIONED()), device=env_device)
         env.set_seed(0)
         collector = SyncDataCollector(
-            env, RandomPolicy(env.action_spec), frames_per_batch=4, total_frames=16
+            env,
+            RandomPolicy(env.action_spec),
+            frames_per_batch=4,
+            total_frames=16,
+            device=env_device,
         )
         if writer_cls is TensorDictMaxValueWriter:
             with pytest.raises(
@@ -2712,6 +2717,7 @@ def test_rb_multidim_collector(
                 rb.append_transform(t())
         try:
             for i, data in enumerate(collector):  # noqa: B007
+                assert data.device == torch.device(env_device)
                 rb.extend(data)
                 if isinstance(rb, TensorDictReplayBuffer) and transform is not None:
                     # this should fail bc we can't set the indices after executing the transform.
@@ -2721,6 +2727,7 @@
                         rb.sample()
                     return
             s = rb.sample()
+            assert s.device == torch.device("cpu")
             rbtot = rb[:]
             assert rbtot.shape[0] == 2
             assert len(rb) == rbtot.numel()
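Taken together, the test changes parametrize the existing multidim collector test over devices: the env and collector are placed on env_device, every collected batch is expected to live on that device, and samples drawn from the replay buffer still come back on CPU because the storage stays on its default device. A minimal sketch of that behaviour outside the test suite is below; the import paths, the "CartPole-v1" env string, the LazyTensorStorage choice and the device selection are assumptions for illustration, not part of this PR.

# Illustrative sketch only (not part of the diff); import paths may differ across torchrl versions.
import torch
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import GymEnv, SerialEnv
from torchrl.envs.utils import RandomPolicy  # assumed location of RandomPolicy

device = "cuda:0" if torch.cuda.is_available() else "cpu"  # stand-in for env_device

env = SerialEnv(2, lambda: GymEnv("CartPole-v1"), device=device)
collector = SyncDataCollector(
    env,
    RandomPolicy(env.action_spec),
    frames_per_batch=4,
    total_frames=16,
    device=device,
)
# The storage is left on its default (CPU) device, with ndim=2 to keep the [env, time] layout.
rb = ReplayBuffer(storage=LazyTensorStorage(100, ndim=2))

for data in collector:
    assert data.device == torch.device(device)  # batches carry the collector's device
    rb.extend(data)

sample = rb.sample(4)
assert sample.device == torch.device("cpu")  # samples come from the CPU storage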
8 changes: 7 additions & 1 deletion torchrl/data/replay_buffers/writers.py
@@ -81,7 +81,13 @@ def _replicate_index(self, index):
         if self._storage.ndim == 1:
             return index
         mesh = torch.stack(
-            torch.meshgrid(*(torch.arange(dim) for dim in self._storage.shape[1:])), -1
+            torch.meshgrid(
+                *(
+                    torch.arange(dim, device=index.device)
+                    for dim in self._storage.shape[1:]
+                )
+            ),
+            -1,
         ).flatten(0, -2)
         if _is_int(index):
             index0 = torch.as_tensor(int(index)).expand(mesh.shape[0], 1)
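On the writer side, _replicate_index builds a mesh of coordinates over the trailing storage dimensions and combines it with the incoming write index; creating that mesh with torch.arange(dim, device=index.device) keeps it on the same device as the index, so a CUDA index no longer triggers a cross-device error when the two are combined. A standalone sketch of the pattern is below; the helper name, the tensor-index-only branch and the indexing="ij" argument are illustrative assumptions, not a copy of the library method.

# Hypothetical, self-contained version of the coordinate-replication pattern.
import torch


def replicate_index(index: torch.Tensor, storage_shape: tuple) -> torch.Tensor:
    """Expand a first-dim index into full coordinates over the trailing storage dims."""
    if len(storage_shape) == 1:
        return index
    # Build the grid of trailing coordinates on the same device as the index,
    # mirroring the device= fix in the diff above.
    mesh = torch.stack(
        torch.meshgrid(
            *(torch.arange(dim, device=index.device) for dim in storage_shape[1:]),
            indexing="ij",
        ),
        -1,
    ).flatten(0, -2)
    # Pair every first-dim index with every trailing coordinate (illustrative branch
    # for tensor indices only; the library also handles plain ints).
    index0 = index.reshape(-1, 1).repeat_interleave(mesh.shape[0], dim=0)
    return torch.cat([index0, mesh.repeat(index.numel(), 1)], dim=-1)


# With storage_shape=(100, 2) and index=torch.tensor([3, 7]) on a CUDA device,
# the result is [[3, 0], [3, 1], [7, 0], [7, 1]], all on that same device.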