diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index d1c7836181c..3de8d7f4773 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -80,12 +80,12 @@ def _replicate_index(self, index): # elements truly written when the storage is multidim if self._storage.ndim == 1: return index + device = ( + index.device if isinstance(index, torch.Tensor) else torch.device("cpu") + ) mesh = torch.stack( torch.meshgrid( - *( - torch.arange(dim, device=index.device) - for dim in self._storage.shape[1:] - ) + *(torch.arange(dim, device=device) for dim in self._storage.shape[1:]) ), -1, ).flatten(0, -2)