43 changes: 41 additions & 2 deletions test/test_rb.py
@@ -65,7 +65,7 @@
 )
 @pytest.mark.parametrize("writer", [writers.RoundRobinWriter])
 @pytest.mark.parametrize("storage", [ListStorage, LazyTensorStorage, LazyMemmapStorage])
-@pytest.mark.parametrize("size", [3, 100])
+@pytest.mark.parametrize("size", [3, 5, 100])
 class TestPrototypeBuffers:
     def _get_rb(self, rb_type, size, sampler, writer, storage):

@@ -124,6 +124,26 @@ def test_add(self, rb_type, sampler, writer, storage, size):
         else:
             assert (s == data).all()
 
+    def test_cursor_position(self, rb_type, sampler, writer, storage, size):
+        storage = storage(size)
+        writer = writer()
+        writer.register_storage(storage)
+        batch1 = self._get_data(rb_type, size=5)
+        writer.extend(batch1)
+
+        # Added less data than storage max size
+        if size > 5:
+            assert writer._cursor == 5
+        # Added more data than storage max size
+        elif size < 5:
+            assert writer._cursor == 5 - size
+        # Added as much data as storage max size
+        else:
+            assert writer._cursor == 0
+        batch2 = self._get_data(rb_type, size=size - 1)
+        writer.extend(batch2)
+        assert writer._cursor == size - 1
+
     def test_extend(self, rb_type, sampler, writer, storage, size):
         torch.manual_seed(0)
         rb = self._get_rb(
@@ -342,7 +362,7 @@ def test_rb_prototype_trajectories(stack):
         (TensorDictPrioritizedReplayBuffer, LazyMemmapStorage),
     ],
 )
-@pytest.mark.parametrize("size", [3, 100])
+@pytest.mark.parametrize("size", [3, 5, 100])
 @pytest.mark.parametrize("prefetch", [0])
 class TestBuffers:
     _default_params_rb = {}
@@ -404,6 +424,25 @@ def _get_data(self, rbtype, size):
             raise NotImplementedError(rbtype)
         return data
 
+    def test_cursor_position2(self, rbtype, storage, size, prefetch):
+        torch.manual_seed(0)
+        rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch)
+        batch1 = self._get_data(rbtype, size=5)
+        rb.extend(batch1)
+
+        # Added less data than storage max size
+        if size > 5:
+            assert rb._cursor == 5
+        # Added more data than storage max size
+        elif size < 5:
+            assert rb._cursor == 5 - size
+        # Added as much data as storage max size
+        else:
+            assert rb._cursor == 0
+        batch2 = self._get_data(rbtype, size=size - 1)
+        rb.extend(batch2)
+        assert rb._cursor == size - 1
+
     def test_add(self, rbtype, storage, size, prefetch):
         torch.manual_seed(0)
         rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch)
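The invariant both new tests pin down: after writing n items into an empty round-robin storage of capacity size, the cursor lands on n modulo size. A minimal sketch of that arithmetic (plain Python, not TorchRL code; expected_cursor is an illustrative helper):

# Sketch of the cursor invariant asserted above (not TorchRL code):
# after writing n items into an empty round-robin buffer of capacity
# `size`, the cursor points at the next slot to overwrite, i.e. n % size.
def expected_cursor(num_written: int, size: int) -> int:
    return num_written % size

# Mirrors the three assertion branches after writer.extend(batch1):
assert expected_cursor(5, 100) == 5      # size > 5: cursor sits right after the batch
assert expected_cursor(5, 3) == 5 - 3    # size < 5: wrapped around once
assert expected_cursor(5, 5) == 0        # size == 5: exactly full, back to slot 0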
24 changes: 1 addition & 23 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -195,37 +195,15 @@ def extend(self, data: Sequence[Any]):
         if not len(data):
             raise Exception("extending with empty data is not supported")
         with self._replay_lock:
-            cur_size = len(self._storage)
             batch_size = len(data)
-            # storage = self._storage
-            # cursor = self._cursor
-            if cur_size + batch_size <= self._capacity:
-                index = np.arange(cur_size, cur_size + batch_size)
-                # self._storage += data
-                self._cursor = (self._cursor + batch_size) % self._capacity
-            elif cur_size < self._capacity:
-                d = self._capacity - cur_size
-                index = np.empty(batch_size, dtype=np.int64)
-                index[:d] = np.arange(cur_size, self._capacity)
-                index[d:] = np.arange(batch_size - d)
-                # storage += data[:d]
-                # for i, v in enumerate(data[d:]):
-                #     storage[i] = v
-                self._cursor = batch_size - d
-            elif self._cursor + batch_size <= self._capacity:
+            if self._cursor + batch_size <= self._capacity:
                 index = np.arange(self._cursor, self._cursor + batch_size)
-                # for i, v in enumerate(data):
-                #     storage[cursor + i] = v
                 self._cursor = (self._cursor + batch_size) % self._capacity
             else:
                 d = self._capacity - self._cursor
                 index = np.empty(batch_size, dtype=np.int64)
                 index[:d] = np.arange(self._cursor, self._capacity)
                 index[d:] = np.arange(batch_size - d)
-                # for i, v in enumerate(data[:d]):
-                #     storage[cursor + i] = v
-                # for i, v in enumerate(data[d:]):
-                #     storage[i] = v
                 self._cursor = batch_size - d
             # storage must convert the data to the appropriate format if needed
             self._storage[index] = data
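The deleted branches keyed off cur_size = len(self._storage), but for a buffer that has never wrapped the cursor coincides with the fill level, and once it has wrapped only the cursor matters, so the two cursor-based branches cover every case. A standalone sketch of the surviving logic (next_indices, cursor, and capacity are illustrative names standing in for the method's locals and self._capacity):

# Standalone sketch of the simplified two-branch write logic above;
# not the TorchRL method itself, just its index/cursor arithmetic.
import numpy as np

def next_indices(cursor, batch_size, capacity):
    if cursor + batch_size <= capacity:
        # Contiguous write: no wrap-around needed.
        index = np.arange(cursor, cursor + batch_size)
        cursor = (cursor + batch_size) % capacity
    else:
        # Wrap-around: fill to the end of the buffer, then restart at slot 0.
        d = capacity - cursor
        index = np.empty(batch_size, dtype=np.int64)
        index[:d] = np.arange(cursor, capacity)
        index[d:] = np.arange(batch_size - d)
        cursor = batch_size - d
    return index, cursor

# Five items into a capacity-3 buffer starting at cursor 0:
index, cursor = next_indices(0, 5, 3)
assert index.tolist() == [0, 1, 2, 0, 1] and cursor == 2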
11 changes: 1 addition & 10 deletions torchrl/data/replay_buffers/writers.py
@@ -46,21 +46,12 @@ def extend(self, data: Sequence) -> torch.Tensor:
         if cur_size + batch_size <= self._storage.max_size:
             index = np.arange(cur_size, cur_size + batch_size)
             self._cursor = (self._cursor + batch_size) % self._storage.max_size
-        elif cur_size < self._storage.max_size:
+        else:
             d = self._storage.max_size - cur_size
             index = np.empty(batch_size, dtype=np.int64)
             index[:d] = np.arange(cur_size, self._storage.max_size)
             index[d:] = np.arange(batch_size - d)
             self._cursor = batch_size - d
-        elif self._cursor + batch_size <= self._storage.max_size:
-            index = np.arange(self._cursor, self._cursor + batch_size)
-            self._cursor = (self._cursor + batch_size) % self._storage.max_size
-        else:
-            d = self._storage.max_size - self._cursor
-            index = np.empty(batch_size, dtype=np.int64)
-            index[:d] = np.arange(self._cursor, self._storage.max_size)
-            index[d:] = np.arange(batch_size - d)
-            self._cursor = batch_size - d
         # storage must convert the data to the appropriate format if needed
         self._storage[index] = data
         return index
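End to end, the writer-level behavior matches the new test_cursor_position. A usage sketch under assumptions (import paths and TensorDict contents are illustrative, taken from what the test exercises; any batch the storage accepts would do):

# Usage sketch mirroring test_cursor_position; the import paths and the
# TensorDict contents are illustrative assumptions.
import torch
from tensordict import TensorDict
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.data.replay_buffers.writers import RoundRobinWriter

storage = LazyTensorStorage(3)          # max_size = 3
writer = RoundRobinWriter()
writer.register_storage(storage)

# Extending with 5 items wraps around: slots 0, 1, 2 are filled, then
# slots 0 and 1 are overwritten, leaving the cursor on slot 2 (5 % 3).
batch = TensorDict({"obs": torch.randn(5, 4)}, batch_size=[5])
index = writer.extend(batch)
assert writer._cursor == 2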