From 1834af5a7f322d5eade225fa8e43dffdea5cfe9f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 20 Mar 2024 15:09:02 +0000 Subject: [PATCH] init --- torchrl/data/replay_buffers/writers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index d1c7836181c..3de8d7f4773 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -80,12 +80,12 @@ def _replicate_index(self, index): # elements truly written when the storage is multidim if self._storage.ndim == 1: return index + device = ( + index.device if isinstance(index, torch.Tensor) else torch.device("cpu") + ) mesh = torch.stack( torch.meshgrid( - *( - torch.arange(dim, device=index.device) - for dim in self._storage.shape[1:] - ) + *(torch.arange(dim, device=device) for dim in self._storage.shape[1:]) ), -1, ).flatten(0, -2)