From 6a7dd630861d16f2f61f4341e0ec007fa9c4c9c3 Mon Sep 17 00:00:00 2001
From: Tom Begley
Date: Tue, 20 Dec 2022 11:41:31 +0000
Subject: [PATCH 1/4] Support tensorclass in Lazy{Tensor,Memmap}Storage

---
 torchrl/data/replay_buffers/storages.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index 1bba001c1c1..c7f53ba12c9 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -11,6 +11,7 @@
 import torch
 
 from tensordict.memmap import MemmapTensor
+from tensordict.prototype import is_tensorclass
 from tensordict.tensordict import TensorDict, TensorDictBase
 
 from torchrl._utils import _CKPT_BACKEND
@@ -235,6 +236,10 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
                 device=self.device,
                 dtype=data.dtype,
             )
+        elif is_tensorclass(data):
+            out = (
+                data.expand(self.max_size, *data.shape).clone().zero_().to(self.device)
+            )
         else:
             out = (
                 data.expand(self.max_size, *data.shape)
@@ -360,6 +365,21 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
             print(
                 f"The storage was created in {out.filename} and occupies {filesize} Mb of storage."
             )
+        elif is_tensorclass(data):
+            out = (
+                data.expand(self.max_size, *data.shape)
+                .clone()
+                .zero_()
+                .memmap_(prefix=self.scratch_dir)
+                .to(self.device)
+            )
+            for key, tensor in sorted(
+                out.items(include_nested=True, leaves_only=True), key=str
+            ):
+                filesize = os.path.getsize(tensor.filename) / 1024 / 1024
+                print(
+                    f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})."
+                )
         else:
             # out = TensorDict({}, [self.max_size, *data.shape])
             print("The storage is being created: ")

From dea40c616058c4d4fb46ea958574baffabb16543 Mon Sep 17 00:00:00 2001
From: Tom Begley
Date: Tue, 20 Dec 2022 11:57:39 +0000
Subject: [PATCH 2/4] Add test for lazy tensorclass storage

---
 test/test_rb.py | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/test/test_rb.py b/test/test_rb.py
index 5ed778e0867..d5b3c814d6a 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -12,6 +12,7 @@
 import pytest
 import torch
 from _utils_internal import get_available_devices
+from tensordict.prototype import is_tensorclass, tensorclass
 from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase
 from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer, TensorDictReplayBuffer
 from torchrl.data.replay_buffers import (
@@ -222,6 +223,27 @@ def test_index(self, rb_type, sampler, writer, storage, size):
 @pytest.mark.parametrize("shape", [[3, 4]])
 @pytest.mark.parametrize("storage", [LazyTensorStorage, LazyMemmapStorage])
 class TestStorages:
+    def _get_nested_tensorclass(self, shape):
+        @tensorclass
+        class NestedTensorClass:
+            key1: torch.Tensor
+            key2: torch.Tensor
+
+        @tensorclass
+        class TensorClass:
+            key1: torch.Tensor
+            key2: torch.Tensor
+            next: NestedTensorClass
+
+        return TensorClass(
+            key1=torch.ones(*shape),
+            key2=torch.ones(*shape),
+            next=NestedTensorClass(
+                key1=torch.ones(*shape), key2=torch.ones(*shape), batch_size=shape
+            ),
+            batch_size=shape,
+        )
+
     def _get_nested_td(self, shape):
         nested_td = TensorDict(
             {
@@ -245,6 +267,13 @@ def test_init(self, max_size, shape, storage):
         mystorage._init(td)
         assert mystorage._storage.shape == (max_size, *shape)
 
+    def test_init_tensorclass(self, max_size, shape, storage):
+        tc = self._get_nested_tensorclass(shape)
+        mystorage = storage(max_size=max_size)
+        mystorage._init(tc)
+        assert is_tensorclass(mystorage._storage)
+        assert mystorage._storage.shape == (max_size, *shape)
+
 
 @pytest.mark.parametrize("priority_key", ["pk", "td_error"])
 @pytest.mark.parametrize("contiguous", [True, False])

From 4761fab632675cd87734bc454d91f16d3ceb2ad2 Mon Sep 17 00:00:00 2001
From: Tom Begley
Date: Tue, 20 Dec 2022 13:38:16 +0000
Subject: [PATCH 3/4] Rerun CI

From 741dbdcdcf9f968131625f466274f175a00514a3 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 2 Jan 2023 07:48:04 +0000
Subject: [PATCH 4/4] some more tests

---
 test/test_rb.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/test/test_rb.py b/test/test_rb.py
index d5b3c814d6a..f6a8316d72a 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -186,7 +186,6 @@ def test_sample(self, rb_type, sampler, writer, storage, size):
             new_data = new_data[0]
 
         for d in new_data:
-            found_similar = False
             for b in data:
                 if isinstance(b, TensorDictBase):
                     keys = set(d.keys()).intersection(b.keys())
@@ -267,6 +266,15 @@ def test_init(self, max_size, shape, storage):
         mystorage._init(td)
         assert mystorage._storage.shape == (max_size, *shape)
 
+    def test_set(self, max_size, shape, storage):
+        td = self._get_nested_td(shape)
+        mystorage = storage(max_size=max_size)
+        mystorage.set(list(range(td.shape[0])), td)
+        assert mystorage._storage.shape == (max_size, *shape[1:])
+        idx = list(range(1, td.shape[0] - 1))
+        tc_sample = mystorage.get(idx)
+        assert tc_sample.shape == torch.Size([td.shape[0] - 2, *td.shape[1:]])
+
     def test_init_tensorclass(self, max_size, shape, storage):
         tc = self._get_nested_tensorclass(shape)
         mystorage = storage(max_size=max_size)
@@ -274,6 +282,15 @@ def test_init_tensorclass(self, max_size, shape, storage):
         assert is_tensorclass(mystorage._storage)
         assert mystorage._storage.shape == (max_size, *shape)
 
+    def test_set_tensorclass(self, max_size, shape, storage):
+        tc = self._get_nested_tensorclass(shape)
+        mystorage = storage(max_size=max_size)
+        mystorage.set(list(range(tc.shape[0])), tc)
+        assert mystorage._storage.shape == (max_size, *shape[1:])
+        idx = list(range(1, tc.shape[0] - 1))
+        tc_sample = mystorage.get(idx)
+        assert tc_sample.shape == torch.Size([tc.shape[0] - 2, *tc.shape[1:]])
+
 
 @pytest.mark.parametrize("priority_key", ["pk", "td_error"])
 @pytest.mark.parametrize("contiguous", [True, False])
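
A minimal usage sketch of what this patch series enables, mirroring
test_set_tensorclass above. It is a sketch only, written against the prototype
tensorclass API that tensordict shipped at the time (tensordict.prototype);
the ReplayData class and its field names are illustrative and not part of the
patches.

    import torch
    from tensordict.prototype import tensorclass
    from torchrl.data.replay_buffers import LazyMemmapStorage

    # A tensorclass is a dataclass-like container of tensors with TensorDict
    # batch-size semantics, which is what lets the lazy storages expand it.
    @tensorclass
    class ReplayData:
        obs: torch.Tensor
        reward: torch.Tensor

    data = ReplayData(
        obs=torch.ones(3, 4),
        reward=torch.zeros(3, 1),
        batch_size=[3],
    )

    storage = LazyMemmapStorage(max_size=10)
    # The first set() lazily builds memory-mapped buffers shaped like a single
    # element of `data`, expanded to max_size (see LazyMemmapStorage._init in
    # patch 1/4).
    storage.set(list(range(data.shape[0])), data)
    sample = storage.get([0, 2])  # a ReplayData with batch_size [2]

Indexing the storage hands back an instance of the same tensorclass rather
than a plain TensorDict, which is the behaviour the new tests assert.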