Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pytest
import torch
from _utils_internal import get_available_devices
from tensordict.prototype import is_tensorclass, tensorclass
from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase
from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer, TensorDictReplayBuffer
from torchrl.data.replay_buffers import (
Expand Down Expand Up @@ -185,7 +186,6 @@ def test_sample(self, rb_type, sampler, writer, storage, size):
new_data = new_data[0]

for d in new_data:
found_similar = False
for b in data:
if isinstance(b, TensorDictBase):
keys = set(d.keys()).intersection(b.keys())
Expand Down Expand Up @@ -222,6 +222,27 @@ def test_index(self, rb_type, sampler, writer, storage, size):
@pytest.mark.parametrize("shape", [[3, 4]])
@pytest.mark.parametrize("storage", [LazyTensorStorage, LazyMemmapStorage])
class TestStorages:
def _get_nested_tensorclass(self, shape):
@tensorclass
class NestedTensorClass:
key1: torch.Tensor
key2: torch.Tensor

@tensorclass
class TensorClass:
key1: torch.Tensor
key2: torch.Tensor
next: NestedTensorClass

return TensorClass(
key1=torch.ones(*shape),
key2=torch.ones(*shape),
next=NestedTensorClass(
key1=torch.ones(*shape), key2=torch.ones(*shape), batch_size=shape
),
batch_size=shape,
)

def _get_nested_td(self, shape):
nested_td = TensorDict(
{
Expand All @@ -245,6 +266,31 @@ def test_init(self, max_size, shape, storage):
mystorage._init(td)
assert mystorage._storage.shape == (max_size, *shape)

def test_set(self, max_size, shape, storage):
td = self._get_nested_td(shape)
mystorage = storage(max_size=max_size)
mystorage.set(list(range(td.shape[0])), td)
assert mystorage._storage.shape == (max_size, *shape[1:])
idx = list(range(1, td.shape[0] - 1))
tc_sample = mystorage.get(idx)
assert tc_sample.shape == torch.Size([td.shape[0] - 2, *td.shape[1:]])

def test_init_tensorclass(self, max_size, shape, storage):
tc = self._get_nested_tensorclass(shape)
mystorage = storage(max_size=max_size)
mystorage._init(tc)
assert is_tensorclass(mystorage._storage)
assert mystorage._storage.shape == (max_size, *shape)

def test_set_tensorclass(self, max_size, shape, storage):
tc = self._get_nested_tensorclass(shape)
mystorage = storage(max_size=max_size)
mystorage.set(list(range(tc.shape[0])), tc)
assert mystorage._storage.shape == (max_size, *shape[1:])
idx = list(range(1, tc.shape[0] - 1))
tc_sample = mystorage.get(idx)
assert tc_sample.shape == torch.Size([tc.shape[0] - 2, *tc.shape[1:]])


@pytest.mark.parametrize("priority_key", ["pk", "td_error"])
@pytest.mark.parametrize("contiguous", [True, False])
Expand Down
20 changes: 20 additions & 0 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import torch
from tensordict.memmap import MemmapTensor
from tensordict.prototype import is_tensorclass
from tensordict.tensordict import TensorDict, TensorDictBase

from torchrl._utils import _CKPT_BACKEND
Expand Down Expand Up @@ -235,6 +236,10 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
device=self.device,
dtype=data.dtype,
)
elif is_tensorclass(data):
out = (
data.expand(self.max_size, *data.shape).clone().zero_().to(self.device)
)
else:
out = (
data.expand(self.max_size, *data.shape)
Expand Down Expand Up @@ -360,6 +365,21 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
print(
f"The storage was created in {out.filename} and occupies {filesize} Mb of storage."
)
elif is_tensorclass(data):
out = (
data.expand(self.max_size, *data.shape)
.clone()
.zero_()
.memmap_(prefix=self.scratch_dir)
.to(self.device)
)
for key, tensor in sorted(
out.items(include_nested=True, leaves_only=True), key=str
):
filesize = os.path.getsize(tensor.filename) / 1024 / 1024
print(
f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})."
)
else:
# out = TensorDict({}, [self.max_size, *data.shape])
print("The storage is being created: ")
Expand Down