From e6d13098bd8da0869223347391d5b9f2cbd84f2f Mon Sep 17 00:00:00 2001
From: vmoens
Date: Tue, 13 Dec 2022 16:04:27 +0000
Subject: [PATCH] init

---
 test/test_rb.py                         | 39 ++++++++++++++-
 torchrl/data/replay_buffers/samplers.py | 64 ++++++++++++++++++++++++-
 2 files changed, 100 insertions(+), 3 deletions(-)

diff --git a/test/test_rb.py b/test/test_rb.py
index 74bce8d770d..5ed778e0867 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -20,7 +20,11 @@
     TensorDictPrioritizedReplayBuffer,
     writers,
 )
-from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
+from torchrl.data.replay_buffers.samplers import (
+    PrioritizedSampler,
+    RandomSampler,
+    SamplerWithoutReplacement,
+)
 from torchrl.data.replay_buffers.storages import (
     LazyMemmapStorage,
     LazyTensorStorage,
@@ -796,6 +800,39 @@ def test_smoke_replay_buffer_transform_no_inkeys(transform):
     assert rb._transform.called
 
 
+@pytest.mark.parametrize("size", [10, 15, 20])
+@pytest.mark.parametrize("samples", [5, 9, 11, 14, 16])
+@pytest.mark.parametrize("drop_last", [True, False])
+def test_samplerwithoutrep(size, samples, drop_last):
+    torch.manual_seed(0)
+    storage = ListStorage(size)
+    storage.set(range(size), range(size))
+    assert len(storage) == size
+    sampler = SamplerWithoutReplacement(drop_last=drop_last)
+    visited = False
+    for _ in range(10):
+        _n_left = (
+            sampler._sample_list.numel() if sampler._sample_list is not None else size
+        )
+        if samples > size and drop_last:
+            with pytest.raises(
+                ValueError,
+                match=r"The batch size .* is greater than the storage capacity",
+            ):
+                idx, _ = sampler.sample(storage, samples)
+            break
+        idx, _ = sampler.sample(storage, samples)
+        assert idx.numel() == samples
+        if drop_last or _n_left >= samples:
+            assert idx.unique().numel() == idx.numel()
+        else:
+            visited = True
+    if not drop_last and (size % samples > 0):
+        assert visited
+    else:
+        assert not visited
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index 2bb159d0b8d..93dd50142a3 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -44,8 +44,68 @@ def default_priority(self) -> float:
 class RandomSampler(Sampler):
     """A uniformly random sampler for composable replay buffers."""
 
-    def sample(self, storage: Storage, batch_size: int) -> Tuple[np.array, dict]:
-        index = np.random.randint(0, len(storage), size=batch_size)
+    def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
+        index = torch.randint(0, len(storage), (batch_size,))
+        return index, {}
+
+
+class SamplerWithoutReplacement(Sampler):
+    """A data-consuming sampler that ensures that the same sample is not present in consecutive batches.
+
+    Args:
+        drop_last (bool, optional): if True, the last incomplete sample (if any) will be dropped.
+            If False, this last sample will be kept and (unlike with torch dataloaders)
+            completed with other samples from a fresh indices permutation.
+
+    *Caution*: If the size of the storage changes in between two calls, the samples will be re-shuffled
+    (as we can't generally keep track of which samples have been sampled before and which haven't).
+
+    Similarly, it is expected that the storage content remains the same in between two calls,
+    but this is not enforced.
+
+    When the sampler reaches the end of the list of available indices, a new sample order
+    will be generated and the resulting indices will be completed with this new draw, which
+    can lead to duplicated indices, unless the :obj:`drop_last` argument is set to :obj:`True`.
+
+    """
+
+    def __init__(self, drop_last: bool = False):
+        self._sample_list = None
+        self.len_storage = 0
+        self.drop_last = drop_last
+
+    def _single_sample(self, len_storage, batch_size):
+        index = self._sample_list[:batch_size]
+        self._sample_list = self._sample_list[batch_size:]
+        if not self._sample_list.numel():
+            self._sample_list = torch.randperm(len_storage)
+        return index
+
+    def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
+        len_storage = len(storage)
+        if not len_storage:
+            raise RuntimeError("An empty storage was passed")
+        if self.len_storage != len_storage or self._sample_list is None:
+            self._sample_list = torch.randperm(len_storage)
+        if len_storage < batch_size and self.drop_last:
+            raise ValueError(
+                f"The batch size ({batch_size}) is greater than the storage capacity ({len_storage}). "
+                "This makes it impossible to return a sample without repeating indices. "
+                "Consider changing the sampler class or setting the 'drop_last' argument to False."
+            )
+        self.len_storage = len_storage
+        index = self._single_sample(len_storage, batch_size)
+        while index.numel() < batch_size:
+            if self.drop_last:
+                index = self._single_sample(len_storage, batch_size)
+            else:
+                index = torch.cat(
+                    [
+                        index,
+                        self._single_sample(len_storage, batch_size - index.numel()),
+                    ],
+                    0,
+                )
         return index, {}
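
Usage sketch: the snippet below mirrors test_samplerwithoutrep above and shows how the new
SamplerWithoutReplacement consumes a ListStorage under drop_last=True. It only uses objects
exercised in this patch; the ListStorage import path is an assumption (the test file imports
its storages from torchrl.data.replay_buffers.storages).

    import torch

    from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
    from torchrl.data.replay_buffers.storages import ListStorage

    torch.manual_seed(0)

    # Fill a small storage with ten dummy entries (plain integers, as in the test above).
    storage = ListStorage(10)
    storage.set(range(10), range(10))

    # drop_last=True guarantees that every returned batch holds distinct indices.
    sampler = SamplerWithoutReplacement(drop_last=True)

    # Each call consumes part of a shuffled index list. After two batches of 4, only
    # 2 indices remain; because drop_last=True they are discarded and the third call
    # draws a full batch from a fresh permutation instead.
    for _ in range(3):
        idx, info = sampler.sample(storage, batch_size=4)
        assert idx.numel() == 4
        assert idx.unique().numel() == idx.numel()  # no repeated index within a batch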