diff --git a/test/test_rb.py b/test/test_rb.py
index c95bcfdf936..74bce8d770d 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -65,7 +65,7 @@
 )
 @pytest.mark.parametrize("writer", [writers.RoundRobinWriter])
 @pytest.mark.parametrize("storage", [ListStorage, LazyTensorStorage, LazyMemmapStorage])
-@pytest.mark.parametrize("size", [3, 100])
+@pytest.mark.parametrize("size", [3, 5, 100])
 class TestPrototypeBuffers:
     def _get_rb(self, rb_type, size, sampler, writer, storage):
@@ -124,6 +124,26 @@ def test_add(self, rb_type, sampler, writer, storage, size):
         else:
             assert (s == data).all()
 
+    def test_cursor_position(self, rb_type, sampler, writer, storage, size):
+        storage = storage(size)
+        writer = writer()
+        writer.register_storage(storage)
+        batch1 = self._get_data(rb_type, size=5)
+        writer.extend(batch1)
+
+        # Added less data than storage max size
+        if size > 5:
+            assert writer._cursor == 5
+        # Added more data than storage max size
+        elif size < 5:
+            assert writer._cursor == 5 - size
+        # Added as much data as storage max size
+        else:
+            assert writer._cursor == 0
+        batch2 = self._get_data(rb_type, size=size - 1)
+        writer.extend(batch2)
+        assert writer._cursor == size - 1
+
     def test_extend(self, rb_type, sampler, writer, storage, size):
         torch.manual_seed(0)
         rb = self._get_rb(
@@ -342,7 +362,7 @@ def test_rb_prototype_trajectories(stack):
         (TensorDictPrioritizedReplayBuffer, LazyMemmapStorage),
     ],
 )
-@pytest.mark.parametrize("size", [3, 100])
+@pytest.mark.parametrize("size", [3, 5, 100])
 @pytest.mark.parametrize("prefetch", [0])
 class TestBuffers:
     _default_params_rb = {}
@@ -404,6 +424,25 @@ def _get_data(self, rbtype, size):
             raise NotImplementedError(rbtype)
         return data
 
+    def test_cursor_position2(self, rbtype, storage, size, prefetch):
+        torch.manual_seed(0)
+        rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch)
+        batch1 = self._get_data(rbtype, size=5)
+        rb.extend(batch1)
+
+        # Added less data than storage max size
+        if size > 5:
+            assert rb._cursor == 5
+        # Added more data than storage max size
+        elif size < 5:
+            assert rb._cursor == 5 - size
+        # Added as much data as storage max size
+        else:
+            assert rb._cursor == 0
+        batch2 = self._get_data(rbtype, size=size - 1)
+        rb.extend(batch2)
+        assert rb._cursor == size - 1
+
     def test_add(self, rbtype, storage, size, prefetch):
         torch.manual_seed(0)
         rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch)
diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index 9dccba7ad07..053384e6bb8 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -195,37 +195,15 @@ def extend(self, data: Sequence[Any]):
         if not len(data):
             raise Exception("extending with empty data is not supported")
         with self._replay_lock:
-            cur_size = len(self._storage)
             batch_size = len(data)
-            # storage = self._storage
-            # cursor = self._cursor
-            if cur_size + batch_size <= self._capacity:
-                index = np.arange(cur_size, cur_size + batch_size)
-                # self._storage += data
-                self._cursor = (self._cursor + batch_size) % self._capacity
-            elif cur_size < self._capacity:
-                d = self._capacity - cur_size
-                index = np.empty(batch_size, dtype=np.int64)
-                index[:d] = np.arange(cur_size, self._capacity)
-                index[d:] = np.arange(batch_size - d)
-                # storage += data[:d]
-                # for i, v in enumerate(data[d:]):
-                #     storage[i] = v
-                self._cursor = batch_size - d
-            elif self._cursor + batch_size <= self._capacity:
+            if self._cursor + batch_size <= self._capacity:
                 index = np.arange(self._cursor, self._cursor + batch_size)
-                # for i, v in enumerate(data):
-                #     storage[cursor + i] = v
                 self._cursor = (self._cursor + batch_size) % self._capacity
             else:
                 d = self._capacity - self._cursor
                 index = np.empty(batch_size, dtype=np.int64)
                 index[:d] = np.arange(self._cursor, self._capacity)
                 index[d:] = np.arange(batch_size - d)
-                # for i, v in enumerate(data[:d]):
-                #     storage[cursor + i] = v
-                # for i, v in enumerate(data[d:]):
-                #     storage[i] = v
                 self._cursor = batch_size - d
             # storage must convert the data to the appropriate format if needed
             self._storage[index] = data
diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py
index 88657e1fcba..f058dd32f2d 100644
--- a/torchrl/data/replay_buffers/writers.py
+++ b/torchrl/data/replay_buffers/writers.py
@@ -46,21 +46,12 @@ def extend(self, data: Sequence) -> torch.Tensor:
         if cur_size + batch_size <= self._storage.max_size:
             index = np.arange(cur_size, cur_size + batch_size)
             self._cursor = (self._cursor + batch_size) % self._storage.max_size
-        elif cur_size < self._storage.max_size:
+        else:
             d = self._storage.max_size - cur_size
             index = np.empty(batch_size, dtype=np.int64)
             index[:d] = np.arange(cur_size, self._storage.max_size)
             index[d:] = np.arange(batch_size - d)
             self._cursor = batch_size - d
-        elif self._cursor + batch_size <= self._storage.max_size:
-            index = np.arange(self._cursor, self._cursor + batch_size)
-            self._cursor = (self._cursor + batch_size) % self._storage.max_size
-        else:
-            d = self._storage.max_size - self._cursor
-            index = np.empty(batch_size, dtype=np.int64)
-            index[:d] = np.arange(self._cursor, self._storage.max_size)
-            index[d:] = np.arange(batch_size - d)
-            self._cursor = batch_size - d
         # storage must convert the data to the appropriate format if needed
         self._storage[index] = data
         return index
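Aside: a minimal standalone sketch of the round-robin cursor arithmetic that both `extend` implementations converge on after this patch. The helper name `round_robin_indices` is illustrative only, not a torchrl API; it mirrors the two remaining branches, deriving write indices from the cursor and wrapping around when a batch crosses the end of the storage.

import numpy as np

def round_robin_indices(cursor, batch_size, max_size):
    # Illustrative only: same index/cursor arithmetic as the patched extend().
    if cursor + batch_size <= max_size:
        # Batch fits before the end of the storage: one contiguous slice.
        index = np.arange(cursor, cursor + batch_size)
        new_cursor = (cursor + batch_size) % max_size
    else:
        # Batch crosses the end: fill the tail, then wrap to the front.
        d = max_size - cursor
        index = np.empty(batch_size, dtype=np.int64)
        index[:d] = np.arange(cursor, max_size)
        index[d:] = np.arange(batch_size - d)
        new_cursor = batch_size - d
    return index, new_cursor

# Mirrors test_cursor_position with size=3: extending an empty buffer with
# 5 items wraps, leaving the cursor at 5 - size == 2.
index, cursor = round_robin_indices(cursor=0, batch_size=5, max_size=3)
assert index.tolist() == [0, 1, 2, 0, 1]
assert cursor == 2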