39 changes: 38 additions & 1 deletion test/test_rb.py
@@ -20,7 +20,11 @@
    TensorDictPrioritizedReplayBuffer,
    writers,
)
-from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
+from torchrl.data.replay_buffers.samplers import (
+    PrioritizedSampler,
+    RandomSampler,
+    SamplerWithoutReplacement,
+)
from torchrl.data.replay_buffers.storages import (
    LazyMemmapStorage,
    LazyTensorStorage,
@@ -796,6 +800,39 @@ def test_smoke_replay_buffer_transform_no_inkeys(transform):
    assert rb._transform.called


@pytest.mark.parametrize("size", [10, 15, 20])
@pytest.mark.parametrize("samples", [5, 9, 11, 14, 16])
@pytest.mark.parametrize("drop_last", [True, False])
def test_samplerwithoutrep(size, samples, drop_last):
torch.manual_seed(0)
storage = ListStorage(size)
storage.set(range(size), range(size))
assert len(storage) == size
sampler = SamplerWithoutReplacement(drop_last=drop_last)
visited = False
for _ in range(10):
_n_left = (
sampler._sample_list.numel() if sampler._sample_list is not None else size
)
if samples > size and drop_last:
with pytest.raises(
ValueError,
match=r"The batch size .* is greater than the storage capacity",
):
idx, _ = sampler.sample(storage, samples)
break
idx, _ = sampler.sample(storage, samples)
assert idx.numel() == samples
if drop_last or _n_left >= samples:
assert idx.unique().numel() == idx.numel()
else:
visited = True
if not drop_last and (size % samples > 0):
assert visited
else:
assert not visited


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
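
To make the test's expected `visited` flag concrete, here is a hypothetical walk-through of one parameter combination (the variable names are illustrative only, not part of the PR):

```python
# Straddle condition for size=10, samples=9, drop_last=False.
size, samples = 10, 9
n_left = size - samples        # 1 index left over after the first batch
straddles = n_left < samples   # the second batch must wrap into a fresh shuffle
assert straddles
assert size % samples > 0      # the condition the test uses to expect `visited`
```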
64 changes: 62 additions & 2 deletions torchrl/data/replay_buffers/samplers.py
@@ -44,8 +44,68 @@ def default_priority(self) -> float:
class RandomSampler(Sampler):
    """A uniformly random sampler for composable replay buffers."""

-    def sample(self, storage: Storage, batch_size: int) -> Tuple[np.array, dict]:
-        index = np.random.randint(0, len(storage), size=batch_size)
+    def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
+        index = torch.randint(0, len(storage), (batch_size,))
        return index, {}
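
Worth noting: `torch.randint` draws uniformly *with* replacement, so the sampler's contract is unchanged; only the return type moves from NumPy to torch. A minimal sketch of the new behaviour (the storage setup mirrors the test file above):

```python
import torch

from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.data.replay_buffers.storages import ListStorage

storage = ListStorage(10)
storage.set(range(10), range(10))

sampler = RandomSampler()
index, info = sampler.sample(storage, batch_size=4)
assert isinstance(index, torch.Tensor)  # previously an np.ndarray
assert index.shape == (4,)
# Duplicates are possible: sampling is uniform with replacement.
```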


+class SamplerWithoutReplacement(Sampler):
+    """A data-consuming sampler that ensures that the same sample is not present in consecutive batches.
+
+    Args:
+        drop_last (bool, optional): if ``True``, the last incomplete batch (if any) will be dropped.
+            If ``False``, this last batch will be kept and (unlike with torch dataloaders)
+            completed with indices drawn from a fresh permutation.
+
+    *Caution*: If the size of the storage changes between two calls, the samples will be re-shuffled
+    (as we cannot generally keep track of which samples have already been returned and which haven't).
+
+    Similarly, the storage content is expected to remain the same between two calls,
+    but this is not enforced.
+
+    When the sampler reaches the end of the list of available indices, a new permutation
+    is generated and the remainder of the batch is filled from this new draw, which
+    can lead to duplicate indices within a batch unless the :obj:`drop_last` argument
+    is set to :obj:`True`.
+
+    """
+
+    def __init__(self, drop_last: bool = False):
+        self._sample_list = None
+        self.len_storage = 0
+        self.drop_last = drop_last
+
+    def _single_sample(self, len_storage, batch_size):
+        # Consume up to batch_size indices from the current permutation.
+        index = self._sample_list[:batch_size]
+        self._sample_list = self._sample_list[batch_size:]
+        if not self._sample_list.numel():
+            # Permutation exhausted: draw a fresh one for the next call.
+            self._sample_list = torch.randperm(len_storage)
+        return index
+
+    def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
+        len_storage = len(storage)
+        if not len_storage:
+            raise RuntimeError("An empty storage was passed")
+        if self.len_storage != len_storage or self._sample_list is None:
+            # The storage size changed (or this is the first call): re-shuffle.
+            self._sample_list = torch.randperm(len_storage)
+        if len_storage < batch_size and self.drop_last:
+            raise ValueError(
+                f"The batch size ({batch_size}) is greater than the storage capacity ({len_storage}). "
+                "This makes it impossible to return a sample without repeating indices. "
+                "Consider changing the sampler class or setting the 'drop_last' argument to False."
+            )
+        self.len_storage = len_storage
+        index = self._single_sample(len_storage, batch_size)
+        while index.numel() < batch_size:
+            if self.drop_last:
+                # Drop the incomplete remainder and draw a full batch from the
+                # fresh permutation instead.
+                index = self._single_sample(len_storage, batch_size)
+            else:
+                # Complete the batch with indices from the fresh permutation.
+                index = torch.cat(
+                    [
+                        index,
+                        self._single_sample(len_storage, batch_size - index.numel()),
+                    ],
+                    0,
+                )
+        return index, {}
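
Beyond the diff itself, a minimal usage sketch of the new sampler (the ListStorage setup mirrors the test file; local names are illustrative):

```python
import torch

from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import ListStorage

torch.manual_seed(0)
storage = ListStorage(10)
storage.set(range(10), range(10))  # fill the storage with 10 items

# With drop_last=False, a batch may straddle two permutations and thus
# contain duplicates; within a single permutation there are none.
sampler = SamplerWithoutReplacement(drop_last=False)
first, _ = sampler.sample(storage, 6)   # 6 of the 10 shuffled indices
second, _ = sampler.sample(storage, 6)  # 4 leftovers + 2 from a fresh shuffle
assert first.unique().numel() == 6      # first batch has no duplicates

# With drop_last=True, a batch larger than the storage raises instead.
strict = SamplerWithoutReplacement(drop_last=True)
try:
    strict.sample(storage, 11)
except ValueError:
    pass  # "The batch size (11) is greater than the storage capacity (10)."
```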

