diff --git a/test/test_rb.py b/test/test_rb.py
index 5ed778e0867..f6a8316d72a 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -12,6 +12,7 @@
 import pytest
 import torch
 from _utils_internal import get_available_devices
+from tensordict.prototype import is_tensorclass, tensorclass
 from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase
 from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer, TensorDictReplayBuffer
 from torchrl.data.replay_buffers import (
@@ -185,7 +186,6 @@ def test_sample(self, rb_type, sampler, writer, storage, size):
             new_data = new_data[0]
 
         for d in new_data:
-            found_similar = False
             for b in data:
                 if isinstance(b, TensorDictBase):
                     keys = set(d.keys()).intersection(b.keys())
@@ -222,6 +222,27 @@ def test_index(self, rb_type, sampler, writer, storage, size):
 @pytest.mark.parametrize("shape", [[3, 4]])
 @pytest.mark.parametrize("storage", [LazyTensorStorage, LazyMemmapStorage])
 class TestStorages:
+    def _get_nested_tensorclass(self, shape):
+        @tensorclass
+        class NestedTensorClass:
+            key1: torch.Tensor
+            key2: torch.Tensor
+
+        @tensorclass
+        class TensorClass:
+            key1: torch.Tensor
+            key2: torch.Tensor
+            next: NestedTensorClass
+
+        return TensorClass(
+            key1=torch.ones(*shape),
+            key2=torch.ones(*shape),
+            next=NestedTensorClass(
+                key1=torch.ones(*shape), key2=torch.ones(*shape), batch_size=shape
+            ),
+            batch_size=shape,
+        )
+
     def _get_nested_td(self, shape):
         nested_td = TensorDict(
             {
@@ -245,6 +266,31 @@ def test_init(self, max_size, shape, storage):
         mystorage._init(td)
         assert mystorage._storage.shape == (max_size, *shape)
 
+    def test_set(self, max_size, shape, storage):
+        td = self._get_nested_td(shape)
+        mystorage = storage(max_size=max_size)
+        mystorage.set(list(range(td.shape[0])), td)
+        assert mystorage._storage.shape == (max_size, *shape[1:])
+        idx = list(range(1, td.shape[0] - 1))
+        tc_sample = mystorage.get(idx)
+        assert tc_sample.shape == torch.Size([td.shape[0] - 2, *td.shape[1:]])
+
+    def test_init_tensorclass(self, max_size, shape, storage):
+        tc = self._get_nested_tensorclass(shape)
+        mystorage = storage(max_size=max_size)
+        mystorage._init(tc)
+        assert is_tensorclass(mystorage._storage)
+        assert mystorage._storage.shape == (max_size, *shape)
+
+    def test_set_tensorclass(self, max_size, shape, storage):
+        tc = self._get_nested_tensorclass(shape)
+        mystorage = storage(max_size=max_size)
+        mystorage.set(list(range(tc.shape[0])), tc)
+        assert mystorage._storage.shape == (max_size, *shape[1:])
+        idx = list(range(1, tc.shape[0] - 1))
+        tc_sample = mystorage.get(idx)
+        assert tc_sample.shape == torch.Size([tc.shape[0] - 2, *tc.shape[1:]])
+
 
 @pytest.mark.parametrize("priority_key", ["pk", "td_error"])
 @pytest.mark.parametrize("contiguous", [True, False])
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index 1bba001c1c1..c7f53ba12c9 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -11,6 +11,7 @@
 import torch
 
 from tensordict.memmap import MemmapTensor
+from tensordict.prototype import is_tensorclass
 from tensordict.tensordict import TensorDict, TensorDictBase
 
 from torchrl._utils import _CKPT_BACKEND
@@ -235,6 +236,10 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
                 device=self.device,
                 dtype=data.dtype,
             )
+        elif is_tensorclass(data):
+            out = (
+                data.expand(self.max_size, *data.shape).clone().zero_().to(self.device)
+            )
         else:
             out = (
                 data.expand(self.max_size, *data.shape)
@@ -360,6 +365,21 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
             print(
                 f"The storage was created in {out.filename} and occupies {filesize} Mb of storage."
             )
+        elif is_tensorclass(data):
+            out = (
+                data.expand(self.max_size, *data.shape)
+                .clone()
+                .zero_()
+                .memmap_(prefix=self.scratch_dir)
+                .to(self.device)
+            )
+            for key, tensor in sorted(
+                out.items(include_nested=True, leaves_only=True), key=str
+            ):
+                filesize = os.path.getsize(tensor.filename) / 1024 / 1024
+                print(
+                    f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})."
+                )
         else:
             # out = TensorDict({}, [self.max_size, *data.shape])
             print("The storage is being created: ")
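
Usage sketch (not part of the patch): the snippet below illustrates the behaviour the new tests exercise, namely that a @tensorclass instance can now back a LazyTensorStorage or LazyMemmapStorage. The tensorclass name ReplayData, its field names, and the shapes are hypothetical; the storage calls and the import path mirror the tests above.

# Hedged usage sketch; assumes LazyMemmapStorage is exported from
# torchrl.data.replay_buffers as in test_rb.py.
import torch
from tensordict.prototype import tensorclass
from torchrl.data.replay_buffers import LazyMemmapStorage


# Hypothetical tensorclass; any @tensorclass with tensor fields should behave the same.
@tensorclass
class ReplayData:
    obs: torch.Tensor
    reward: torch.Tensor


data = ReplayData(
    obs=torch.ones(3, 4),
    reward=torch.zeros(3, 4),
    batch_size=[3],
)

storage = LazyMemmapStorage(max_size=10)
# The first set() lazily builds a zeroed, memory-mapped copy of the tensorclass
# with a leading dimension of max_size (see the new _init branch above).
storage.set(list(range(data.shape[0])), data)
# get() returns a tensorclass indexed along the storage dimension,
# mirroring the shape checks in test_set_tensorclass.
sample = storage.get([0, 1])
assert sample.shape == torch.Size([2])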