diff --git a/test/test_rb.py b/test/test_rb.py
index 69d9b8e4faf..cf7ad4b4ef3 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -2555,12 +2555,19 @@ def test_rb_indexing(self, explicit):
 
 def _rbtype(datatype):
     if datatype in ("pytree", "tensorclass"):
-        return [ReplayBuffer, PrioritizedReplayBuffer]
+        return [
+            (ReplayBuffer, RandomSampler),
+            (PrioritizedReplayBuffer, RandomSampler),
+            (ReplayBuffer, SamplerWithoutReplacement),
+            (PrioritizedReplayBuffer, SamplerWithoutReplacement),
+        ]
     return [
-        ReplayBuffer,
-        PrioritizedReplayBuffer,
-        TensorDictReplayBuffer,
-        TensorDictPrioritizedReplayBuffer,
+        (ReplayBuffer, RandomSampler),
+        (ReplayBuffer, SamplerWithoutReplacement),
+        (PrioritizedReplayBuffer, None),
+        (TensorDictReplayBuffer, RandomSampler),
+        (TensorDictReplayBuffer, SamplerWithoutReplacement),
+        (TensorDictPrioritizedReplayBuffer, None),
     ]
 
 
@@ -2598,19 +2605,19 @@ def _make_data(self, datatype, datadim):
                 batch_size=shape,
             )
 
-    datatype_rb_pairs = [
-        [datatype, rbtype]
+    datatype_rb_tuples = [
+        [datatype, *rbtype]
         for datatype in ["pytree", "tensordict", "tensorclass"]
         for rbtype in _rbtype(datatype)
     ]
 
-    @pytest.mark.parametrize("datatype,rbtype", datatype_rb_pairs)
+    @pytest.mark.parametrize("datatype,rbtype,sampler_cls", datatype_rb_tuples)
     @pytest.mark.parametrize("datadim", [1, 2])
     @pytest.mark.parametrize("storage_cls", [LazyMemmapStorage, LazyTensorStorage])
-    def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls):
+    def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls, sampler_cls):
         data = self._make_data(datatype, datadim)
         if rbtype not in (PrioritizedReplayBuffer, TensorDictPrioritizedReplayBuffer):
-            rbtype = functools.partial(rbtype, sampler=RandomSampler())
+            rbtype = functools.partial(rbtype, sampler=sampler_cls())
         else:
             rbtype = functools.partial(rbtype, alpha=0.9, beta=1.1)
 
diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index 03ecdb76594..5b49b067956 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -177,10 +177,12 @@ def _get_sample_list(self, storage: Storage, len_storage: int):
             device = self._sample_list.device
         else:
             device = storage.device if hasattr(storage, "device") else None
+
         if self.shuffle:
-            self._sample_list = torch.randperm(len_storage, device=device)
+            _sample_list = torch.randperm(len_storage, device=device)
         else:
-            self._sample_list = torch.arange(len_storage, device=device)
+            _sample_list = torch.arange(len_storage, device=device)
+        self._sample_list = _sample_list
 
     def _single_sample(self, len_storage, batch_size):
         index = self._sample_list[:batch_size]
@@ -188,7 +190,7 @@ def _single_sample(self, len_storage, batch_size):
 
         # check if we have enough elements for one more batch, assuming same batch size
         # will be used each time sample is called
-        if self._sample_list.numel() == 0 or (
+        if self._sample_list.shape[0] == 0 or (
            self.drop_last and len(self._sample_list) < batch_size
         ):
             self.ran_out = True
@@ -201,7 +203,6 @@ def _storage_len(self, storage):
         return len(storage)
 
     def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
-        storage = storage.flatten()
         len_storage = self._storage_len(storage)
         if len_storage == 0:
             raise RuntimeError(_EMPTY_STORAGE_ERROR)
@@ -217,6 +218,8 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
             )
         self.len_storage = len_storage
         index = self._single_sample(len_storage, batch_size)
+        if storage.ndim > 1:
+            index = torch.unravel_index(index, storage.shape)
         # we 'always' return the indices. The 'drop_last' just instructs the
         # sampler to turn to 'ran_out = True` whenever the next sample
         # will be too short. This will be read by the replay buffer
@@ -834,10 +837,7 @@ def _find_start_stop_traj(*, trajectory=None, end=None, at_capacity: bool):
         )
     # Using transpose ensures the start and stop are sorted the same way
     stop_idx = end.transpose(0, -1).nonzero()
-    # beginnings = torch.cat([end[-1:], end[:-1]], 0)
-    # start_idx = beginnings.transpose(0, -1).nonzero()
-    # start_idx = torch.cat([start_idx[:, -1:], start_idx[:, :-1]], -1)
-    stop_idx = torch.cat([stop_idx[:, -1:], stop_idx[:, :-1]], -1)
+    stop_idx[:, [0, -1]] = stop_idx[:, [-1, 0]].clone()
     # First build the start indices as the stop + 1, we'll shift it later
     start_idx = stop_idx.clone()
     start_idx[:, 0] += 1
@@ -991,6 +991,8 @@ def _sample_slices(
         storage_length: int,
         traj_idx: torch.Tensor | None = None,
     ) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]:
+        # start_idx and stop_idx are 2d tensors organized like a nonzero() output
+
         def get_traj_idx(lengths=lengths):
             return torch.randint(lengths.shape[0], (num_slices,), device=lengths.device)
 
@@ -999,7 +1001,7 @@ def get_traj_idx(lengths=lengths):
             idx = lengths == seq_length
             if not idx.any():
                 raise RuntimeError(
-                    "Did not find a single trajectory with sufficient length."
+                    f"Did not find a single trajectory with sufficient length (length range: {lengths.min()} - {lengths.max()} / required={seq_length})."
                 )
             if (
                 isinstance(seq_length, torch.Tensor)
@@ -1307,6 +1309,16 @@ def sample(
         seq_length, num_slices = self._adjusted_batch_size(batch_size)
         indices, _ = SamplerWithoutReplacement.sample(self, storage, num_slices)
         storage_length = storage.shape[0]
+
+        # traj_idx will either be a single tensor or a tuple that can be reorganized
+        # like a nonzero() output through stacking.
+        def tuple_to_tensor(traj_idx, lengths=lengths):
+            if isinstance(traj_idx, tuple):
+                traj_idx = torch.arange(len(storage), device=lengths.device).view(
+                    storage.shape
+                )[traj_idx]
+            return traj_idx
+
         idx, info = self._sample_slices(
             lengths,
             start_idx,
@@ -1314,7 +1326,7 @@ def sample(
             seq_length,
             num_slices,
             storage_length,
-            traj_idx=indices,
+            traj_idx=tuple_to_tensor(indices),
         )
         return idx, info
 
diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py
index 5f95f8f8be3..d1c7836181c 100644
--- a/torchrl/data/replay_buffers/writers.py
+++ b/torchrl/data/replay_buffers/writers.py
@@ -81,7 +81,13 @@ def _replicate_index(self, index):
         if self._storage.ndim == 1:
             return index
         mesh = torch.stack(
-            torch.meshgrid(*(torch.arange(dim) for dim in self._storage.shape[1:])), -1
+            torch.meshgrid(
+                *(
+                    torch.arange(dim, device=index.device)
+                    for dim in self._storage.shape[1:]
+                )
+            ),
+            -1,
         ).flatten(0, -2)
         if _is_int(index):
             index0 = torch.as_tensor(int(index)).expand(mesh.shape[0], 1)
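
The heart of the samplers.py change is that SamplerWithoutReplacement.sample no longer flattens multi-dimensional storages; when storage.ndim > 1 it converts the flat indices drawn by _single_sample into per-dimension coordinates with torch.unravel_index. A minimal standalone sketch of that mapping follows; the shape and variable names are illustrative and not taken from the patch, and torch.unravel_index needs a recent PyTorch release:

import torch

storage_shape = (4, 3)              # assumed shape, e.g. 4 slots x 3 collection dims
flat = torch.randperm(4 * 3)[:5]    # flat ids, the kind of thing _single_sample returns
coords = torch.unravel_index(flat, storage_shape)
# `coords` is a tuple of per-dimension index tensors, the same layout as a
# nonzero() output, so a multi-dimensional storage can be indexed directly:
data = torch.arange(4 * 3).reshape(storage_shape)
assert (data[coords] == flat).all()  # row-major ids recovered at those coordinates

Returning a nonzero()-style tuple is also what the slice sampler's new tuple_to_tensor helper above expects when it folds the tuple back into flat trajectory ids.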