From aff12fb5afdb5a626eba5434e1cefff425d48c5f Mon Sep 17 00:00:00 2001
From: albert bou
Date: Fri, 25 Nov 2022 10:30:43 +0100
Subject: [PATCH 1/5] removed extra code

---
 torchrl/data/replay_buffers/writers.py | 11 +----------
 1 file changed, 1 insertion(+), 10 deletions(-)

diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py
index 88657e1fcba..f058dd32f2d 100644
--- a/torchrl/data/replay_buffers/writers.py
+++ b/torchrl/data/replay_buffers/writers.py
@@ -46,21 +46,12 @@ def extend(self, data: Sequence) -> torch.Tensor:
         if cur_size + batch_size <= self._storage.max_size:
             index = np.arange(cur_size, cur_size + batch_size)
             self._cursor = (self._cursor + batch_size) % self._storage.max_size
-        elif cur_size < self._storage.max_size:
+        else:
             d = self._storage.max_size - cur_size
             index = np.empty(batch_size, dtype=np.int64)
             index[:d] = np.arange(cur_size, self._storage.max_size)
             index[d:] = np.arange(batch_size - d)
             self._cursor = batch_size - d
-        elif self._cursor + batch_size <= self._storage.max_size:
-            index = np.arange(self._cursor, self._cursor + batch_size)
-            self._cursor = (self._cursor + batch_size) % self._storage.max_size
-        else:
-            d = self._storage.max_size - self._cursor
-            index = np.empty(batch_size, dtype=np.int64)
-            index[:d] = np.arange(self._cursor, self._storage.max_size)
-            index[d:] = np.arange(batch_size - d)
-            self._cursor = batch_size - d
         # storage must convert the data to the appropriate format if needed
         self._storage[index] = data
         return index

From 730359e5dc97d19276c2acd8fea00f8cbf2c5579 Mon Sep 17 00:00:00 2001
From: albert bou
Date: Fri, 25 Nov 2022 17:02:07 +0100
Subject: [PATCH 2/5] added testing of cursor positions

---
 test/test_rb.py                               | 22 +++++++++++++++++++++-
 torchrl/data/replay_buffers/replay_buffers.py | 18 ------------------
 2 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/test/test_rb.py b/test/test_rb.py
index c95bcfdf936..84b9e412fc1 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -65,7 +65,7 @@
 )
 @pytest.mark.parametrize("writer", [writers.RoundRobinWriter])
 @pytest.mark.parametrize("storage", [ListStorage, LazyTensorStorage, LazyMemmapStorage])
-@pytest.mark.parametrize("size", [3, 100])
+@pytest.mark.parametrize("size", [3, 5, 100])
 class TestPrototypeBuffers:
     def _get_rb(self, rb_type, size, sampler, writer, storage):
@@ -124,6 +124,26 @@ def test_add(self, rb_type, sampler, writer, storage, size):
         else:
             assert (s == data).all()
 
+    def test_cursor_position(self, rb_type, sampler, writer, storage, size):
+        storage = storage(size)
+        writer = writer()
+        writer.register_storage(storage)
+        batch1 = self._get_data(rb_type, size=5)
+        writer.extend(batch1)
+
+        # Added less data than storage max size
+        if size > 5:
+            assert writer._cursor == 5
+        # Added more data than storage max size
+        elif size < 5:
+            assert writer._cursor == 5 - size
+        # Added as much data as storage max size
+        else:
+            assert writer._cursor == 0
+        batch2 = self._get_data(rb_type, size=size - 1)
+        writer.extend(batch2)
+        assert writer._cursor == size - 1
+
     def test_extend(self, rb_type, sampler, writer, storage, size):
         torch.manual_seed(0)
         rb = self._get_rb(
diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index 9dccba7ad07..6c0ee3f37ea 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -203,29 +203,11 @@ def extend(self, data: Sequence[Any]):
                 index = np.arange(cur_size, cur_size + batch_size)
                 # self._storage += data
                 self._cursor = (self._cursor + batch_size) % self._capacity
-            elif cur_size < self._capacity:
-                d = self._capacity - cur_size
-                index = np.empty(batch_size, dtype=np.int64)
-                index[:d] = np.arange(cur_size, self._capacity)
-                index[d:] = np.arange(batch_size - d)
-                # storage += data[:d]
-                # for i, v in enumerate(data[d:]):
-                #     storage[i] = v
-                self._cursor = batch_size - d
-            elif self._cursor + batch_size <= self._capacity:
-                index = np.arange(self._cursor, self._cursor + batch_size)
-                # for i, v in enumerate(data):
-                #     storage[cursor + i] = v
-                self._cursor = (self._cursor + batch_size) % self._capacity
             else:
                 d = self._capacity - self._cursor
                 index = np.empty(batch_size, dtype=np.int64)
                 index[:d] = np.arange(self._cursor, self._capacity)
                 index[d:] = np.arange(batch_size - d)
-                # for i, v in enumerate(data[:d]):
-                #     storage[cursor + i] = v
-                # for i, v in enumerate(data[d:]):
-                #     storage[i] = v
                 self._cursor = batch_size - d
             # storage must convert the data to the appropriate format if needed
             self._storage[index] = data

From 13b1f23d98755f89625feaf7b0050d69c88f6f02 Mon Sep 17 00:00:00 2001
From: albert bou
Date: Fri, 25 Nov 2022 17:07:06 +0100
Subject: [PATCH 3/5] reset changes

---
 torchrl/data/replay_buffers/replay_buffers.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index 6c0ee3f37ea..9dccba7ad07 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -203,11 +203,29 @@ def extend(self, data: Sequence[Any]):
                 index = np.arange(cur_size, cur_size + batch_size)
                 # self._storage += data
                 self._cursor = (self._cursor + batch_size) % self._capacity
+            elif cur_size < self._capacity:
+                d = self._capacity - cur_size
+                index = np.empty(batch_size, dtype=np.int64)
+                index[:d] = np.arange(cur_size, self._capacity)
+                index[d:] = np.arange(batch_size - d)
+                # storage += data[:d]
+                # for i, v in enumerate(data[d:]):
+                #     storage[i] = v
+                self._cursor = batch_size - d
+            elif self._cursor + batch_size <= self._capacity:
+                index = np.arange(self._cursor, self._cursor + batch_size)
+                # for i, v in enumerate(data):
+                #     storage[cursor + i] = v
+                self._cursor = (self._cursor + batch_size) % self._capacity
             else:
                 d = self._capacity - self._cursor
                 index = np.empty(batch_size, dtype=np.int64)
                 index[:d] = np.arange(self._cursor, self._capacity)
                 index[d:] = np.arange(batch_size - d)
+                # for i, v in enumerate(data[:d]):
+                #     storage[cursor + i] = v
+                # for i, v in enumerate(data[d:]):
+                #     storage[i] = v
                 self._cursor = batch_size - d
             # storage must convert the data to the appropriate format if needed
             self._storage[index] = data

From 80e28ca70b6bf775ac1c66027a36ce0f7d6af02a Mon Sep 17 00:00:00 2001
From: albert bou
Date: Fri, 25 Nov 2022 17:21:48 +0100
Subject: [PATCH 4/5] ReplayBuffer extend changes

---
 test/test_rb.py                               | 19 +++++++++++++++++++
 torchrl/data/replay_buffers/replay_buffers.py | 24 +-----------------------
 2 files changed, 20 insertions(+), 23 deletions(-)

diff --git a/test/test_rb.py b/test/test_rb.py
index 84b9e412fc1..c59a92d06a0 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -424,6 +424,25 @@ def _get_data(self, rbtype, size):
             raise NotImplementedError(rbtype)
         return data
 
+    def test_cursor_position2(self, rbtype, storage, size, prefetch):
+        torch.manual_seed(0)
+        rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch)
+        batch1 = self._get_data(rbtype, size=5)
+        rb.extend(batch1)
+
+        # Added less data than storage max size
+        if size > 5:
+            assert rb._cursor == 5
+        # Added more data than storage max size
+        elif size < 5:
+            assert rb._cursor == 5 - size
+        # Added as much data as storage max size
+        else:
+            assert rb._cursor == 0
+        batch2 = self._get_data(rb_type, size=size - 1)
+        rb.extend(batch2)
+        assert rb._cursor == size - 1
+
     def test_add(self, rbtype, storage, size, prefetch):
         torch.manual_seed(0)
         rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch)
diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index 9dccba7ad07..053384e6bb8 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -195,37 +195,15 @@ def extend(self, data: Sequence[Any]):
         if not len(data):
             raise Exception("extending with empty data is not supported")
         with self._replay_lock:
-            cur_size = len(self._storage)
             batch_size = len(data)
-            # storage = self._storage
-            # cursor = self._cursor
-            if cur_size + batch_size <= self._capacity:
-                index = np.arange(cur_size, cur_size + batch_size)
-                # self._storage += data
-                self._cursor = (self._cursor + batch_size) % self._capacity
-            elif cur_size < self._capacity:
-                d = self._capacity - cur_size
-                index = np.empty(batch_size, dtype=np.int64)
-                index[:d] = np.arange(cur_size, self._capacity)
-                index[d:] = np.arange(batch_size - d)
-                # storage += data[:d]
-                # for i, v in enumerate(data[d:]):
-                #     storage[i] = v
-                self._cursor = batch_size - d
-            elif self._cursor + batch_size <= self._capacity:
+            if self._cursor + batch_size <= self._capacity:
                 index = np.arange(self._cursor, self._cursor + batch_size)
-                # for i, v in enumerate(data):
-                #     storage[cursor + i] = v
                 self._cursor = (self._cursor + batch_size) % self._capacity
             else:
                 d = self._capacity - self._cursor
                 index = np.empty(batch_size, dtype=np.int64)
                 index[:d] = np.arange(self._cursor, self._capacity)
                 index[d:] = np.arange(batch_size - d)
-                # for i, v in enumerate(data[:d]):
-                #     storage[cursor + i] = v
-                # for i, v in enumerate(data[d:]):
-                #     storage[i] = v
                 self._cursor = batch_size - d
             # storage must convert the data to the appropriate format if needed
             self._storage[index] = data

From b818af8ee8326b389ee2ceca67737c2d9338f47a Mon Sep 17 00:00:00 2001
From: albert bou
Date: Fri, 25 Nov 2022 17:27:09 +0100
Subject: [PATCH 5/5] test fix

---
 test/test_rb.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/test_rb.py b/test/test_rb.py
index c59a92d06a0..74bce8d770d 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -362,7 +362,7 @@ def test_rb_prototype_trajectories(stack):
         (TensorDictPrioritizedReplayBuffer, LazyMemmapStorage),
     ],
 )
-@pytest.mark.parametrize("size", [3, 100])
+@pytest.mark.parametrize("size", [3, 5, 100])
 @pytest.mark.parametrize("prefetch", [0])
 class TestBuffers:
     _default_params_rb = {}
@@ -439,7 +439,7 @@ def test_cursor_position2(self, rbtype, storage, size, prefetch):
         # Added as much data as storage max size
         else:
             assert rb._cursor == 0
-        batch2 = self._get_data(rb_type, size=size - 1)
+        batch2 = self._get_data(rbtype, size=size - 1)
        rb.extend(batch2)
         assert rb._cursor == size - 1
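Note: the snippet below is a minimal, self-contained sketch of the round-robin wrap-around indexing that ReplayBuffer.extend keeps after this series and that test_cursor_position / test_cursor_position2 exercise. It is an illustration only, not part of the patches; the helper name round_robin_indices and the standalone-function form are assumptions made for the example.

    import numpy as np

    def round_robin_indices(cursor, batch_size, max_size):
        # Compute where batch_size new items land and where the cursor ends up,
        # mirroring the cursor-based branch kept in ReplayBuffer.extend.
        if cursor + batch_size <= max_size:
            index = np.arange(cursor, cursor + batch_size)
            new_cursor = (cursor + batch_size) % max_size
        else:
            d = max_size - cursor                     # free slots before the end
            index = np.empty(batch_size, dtype=np.int64)
            index[:d] = np.arange(cursor, max_size)   # fill the tail of the buffer
            index[d:] = np.arange(batch_size - d)     # wrap around to the front
            new_cursor = batch_size - d
        return index, new_cursor

    # A buffer of max_size 5 with the cursor at 3 receiving 4 items:
    idx, cur = round_robin_indices(cursor=3, batch_size=4, max_size=5)
    # idx -> [3, 4, 0, 1], cur -> 2

With this logic, extending an empty size-3 buffer with a batch of 5 leaves the cursor at 5 - size = 2, which is what the new tests assert for the size < 5 case (assert rb._cursor == 5 - size).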