From 6c6c19eaba702b40526cce596967e3c4e4b17c1a Mon Sep 17 00:00:00 2001
From: vmoens
Date: Wed, 6 Mar 2024 13:00:28 -0800
Subject: [PATCH 1/2] init

---
 torchrl/data/replay_buffers/writers.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py
index 5f95f8f8be3..d1c7836181c 100644
--- a/torchrl/data/replay_buffers/writers.py
+++ b/torchrl/data/replay_buffers/writers.py
@@ -81,7 +81,13 @@ def _replicate_index(self, index):
         if self._storage.ndim == 1:
             return index
         mesh = torch.stack(
-            torch.meshgrid(*(torch.arange(dim) for dim in self._storage.shape[1:])), -1
+            torch.meshgrid(
+                *(
+                    torch.arange(dim, device=index.device)
+                    for dim in self._storage.shape[1:]
+                )
+            ),
+            -1,
         ).flatten(0, -2)
         if _is_int(index):
             index0 = torch.as_tensor(int(index)).expand(mesh.shape[0], 1)

From 3fafe1d91be70edda2426f4c68a5b09a0acf8e5d Mon Sep 17 00:00:00 2001
From: vmoens
Date: Wed, 6 Mar 2024 13:04:29 -0800
Subject: [PATCH 2/2] amend

---
 test/test_rb.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/test/test_rb.py b/test/test_rb.py
index 69d9b8e4faf..b6360682c95 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -2678,16 +2678,21 @@ def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls):
             ],
         ],
     )
+    @pytest.mark.parametrize("env_device", get_default_devices())
     def test_rb_multidim_collector(
-        self, rbtype, storage_cls, writer_cls, sampler_cls, transform
+        self, rbtype, storage_cls, writer_cls, sampler_cls, transform, env_device
     ):
         from _utils_internal import CARTPOLE_VERSIONED
 
         torch.manual_seed(0)
-        env = SerialEnv(2, lambda: GymEnv(CARTPOLE_VERSIONED()))
+        env = SerialEnv(2, lambda: GymEnv(CARTPOLE_VERSIONED()), device=env_device)
         env.set_seed(0)
         collector = SyncDataCollector(
-            env, RandomPolicy(env.action_spec), frames_per_batch=4, total_frames=16
+            env,
+            RandomPolicy(env.action_spec),
+            frames_per_batch=4,
+            total_frames=16,
+            device=env_device,
         )
         if writer_cls is TensorDictMaxValueWriter:
             with pytest.raises(
@@ -2712,6 +2717,7 @@ def test_rb_multidim_collector(
             rb.append_transform(t())
         try:
             for i, data in enumerate(collector):  # noqa: B007
+                assert data.device == torch.device(env_device)
                 rb.extend(data)
                 if isinstance(rb, TensorDictReplayBuffer) and transform is not None:
                     # this should fail bc we can't set the indices after executing the transform.
@@ -2721,6 +2727,7 @@ def test_rb_multidim_collector(
                         rb.sample()
                     return
                 s = rb.sample()
+                assert s.device == torch.device("cpu")
                 rbtot = rb[:]
                 assert rbtot.shape[0] == 2
                 assert len(rb) == rbtot.numel()
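
Note (illustrative sketch, not part of the patch): patch 1/2 builds the coordinate mesh in _replicate_index with torch.arange(dim, device=index.device), so that a multi-dimensional storage can be written to from a CUDA index without a CPU/CUDA mismatch when the mesh is combined with the incoming index. The standalone snippet below is a minimal sketch of that pattern only; replicate_index and extra_dims are hypothetical names for illustration, not torchrl API, and the replication logic is simplified relative to the real writer.

import torch


def replicate_index(index: torch.Tensor, extra_dims: tuple) -> torch.Tensor:
    # Build a grid of coordinates for the trailing storage dimensions on the
    # same device as `index` (the point of the patch: torch.arange defaults to
    # CPU, which would clash with a CUDA index when the tensors are combined).
    mesh = torch.stack(
        torch.meshgrid(
            *(torch.arange(dim, device=index.device) for dim in extra_dims),
            indexing="ij",
        ),
        -1,
    ).flatten(0, -2)
    # Pair every leading index with every trailing coordinate.
    index0 = index.reshape(-1, 1).repeat_interleave(mesh.shape[0], dim=0)
    return torch.cat([index0, mesh.repeat(index.numel(), 1)], dim=1)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    idx = torch.tensor([0, 3], device=device)
    out = replicate_index(idx, (2, 2))
    print(out.device, out.shape)  # all coordinates live on idx's device

On a CUDA machine every coordinate tensor stays on the index's device; on a CPU-only machine the same code simply runs on CPU, which is what patch 2/2 exercises by parametrizing the collector test over get_default_devices() and asserting the devices of the collected batch and of the (CPU-stored) samples.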