Merged
85 changes: 72 additions & 13 deletions test/test_tensordict.py
@@ -198,9 +198,19 @@ def test_mask_td(device):
"key2": torch.randn(4, 5, 10, device=device),
}
mask = torch.zeros(4, 5, dtype=torch.bool, device=device).bernoulli_()
mask_list = mask.cpu().numpy().tolist()
td = TensorDict(batch_size=(4, 5), source=d)

td_masked = torch.masked_select(td, mask)
td_masked1 = td[mask_list]
assert len(td_masked.get("key1")) == td_masked.shape[0]
assert len(td_masked1.get("key1")) == td_masked1.shape[0]

mask_list = [False, True, False, True]

td_masked2 = td[mask_list, 0]
torch.testing.assert_allclose(td.get("key1")[mask_list, 0], td_masked2.get("key1"))
torch.testing.assert_allclose(td.get("key2")[mask_list, 0], td_masked2.get("key2"))


@pytest.mark.parametrize("device", get_available_devices())
@@ -782,6 +792,49 @@ def test_masking(self, td_name, device):
assert td_masked.batch_size[0] == mask.sum()
assert td_masked.batch_dims == 1

mask_list = mask.cpu().numpy().tolist()
td_masked3 = td[mask_list]
assert_allclose_td(td_masked3, td_masked2)
assert td_masked3.batch_size[0] == mask.sum()
assert td_masked3.batch_dims == 1

@pytest.mark.parametrize("from_list", [True, False])
def test_masking_set(self, td_name, device, from_list):
def zeros_like(item, n, d):
if isinstance(item, (MemmapTensor, torch.Tensor)):
return torch.zeros(n, *item.shape[d:], dtype=item.dtype, device=device)
elif isinstance(item, _TensorDict):
batch_size = item.batch_size
batch_size = [n, *batch_size[d:]]
out = TensorDict(
{k: zeros_like(_item, n, d) for k, _item in item.items()},
batch_size,
device=device,
)
return out

torch.manual_seed(1)
td = getattr(self, td_name)(device)
mask = torch.zeros(td.batch_size, dtype=torch.bool, device=device).bernoulli_(
0.8
)
n = mask.sum()
d = td.ndimension()
pseudo_td = TensorDict(
{k: zeros_like(item, n, d) for k, item in td.items()}, [n], device=device
)
if from_list:
td_mask = mask.cpu().numpy().tolist()
else:
td_mask = mask
if td_name == "stacked_td":
with pytest.raises(RuntimeError, match="is not supported"):
td[td_mask] = pseudo_td
else:
td[td_mask] = pseudo_td
for k, item in td.items():
assert (item[mask] == 0).all()

@pytest.mark.skipif(
torch.cuda.device_count() == 0, reason="No cuda device detected"
)
@@ -1779,15 +1832,9 @@ def test_stack_keys():
td.get("e")


def test_getitem_batch_size():
shape = [
10,
7,
11,
5,
]
mocking_tensor = torch.zeros(*shape)
for idx in [
@pytest.mark.parametrize(
"idx",
[
(slice(None),),
slice(None),
(3, 4),
@@ -1800,10 +1847,22 @@ def test_getitem_batch_size():
torch.tensor([0, 10, 2]),
torch.tensor([2, 4, 1]),
),
]:
expected_shape = mocking_tensor[idx].shape
resulting_shape = _getitem_batch_size(shape, idx)
assert expected_shape == resulting_shape, idx
torch.zeros(10, 7, 11, 5, dtype=torch.bool).bernoulli_(),
torch.zeros(10, 7, 11, dtype=torch.bool).bernoulli_(),
(0, torch.zeros(7, dtype=torch.bool).bernoulli_()),
],
)
def test_getitem_batch_size(idx):
shape = [
10,
7,
11,
5,
]
mocking_tensor = torch.zeros(*shape)
expected_shape = mocking_tensor[idx].shape
resulting_shape = _getitem_batch_size(shape, idx)
assert expected_shape == resulting_shape, (idx, expected_shape, resulting_shape)


@pytest.mark.parametrize("device", get_available_devices())
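The tests above pin down the new behavior: plain Python lists, boolean or integer, are accepted anywhere a tensor index is. A minimal sketch of what they assert, assuming the `torchrl.data.TensorDict` API at this revision:

```python
# Sketch only: assumes torchrl.data.TensorDict as exercised by the tests above.
import torch
from torchrl.data import TensorDict

td = TensorDict(
    {"key1": torch.randn(4, 5, 3), "key2": torch.randn(4, 5, 10)},
    batch_size=[4, 5],
)

# A boolean list behaves like a boolean mask tensor: the masked dim collapses
# into one dimension of length mask.sum().
td_masked = td[[False, True, False, True]]
assert td_masked.get("key1").shape == torch.Size([2, 5, 3])

# Lists inside tuple indices are promoted element-wise to index tensors.
td_sub = td[[0, 2], 1]
assert td_sub.get("key1").shape == torch.Size([2, 3])
```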
18 changes: 18 additions & 0 deletions torchrl/data/tensordict/memmap.py
@@ -262,6 +262,15 @@ def _load_item(
if idx is not None:
if isinstance(idx, torch.Tensor):
idx = idx.cpu()
elif isinstance(idx, tuple) and any(
isinstance(sub_index, torch.Tensor) for sub_index in idx
):
idx = tuple(
sub_index.cpu()
if isinstance(sub_index, torch.Tensor)
else sub_index
for sub_index in idx
)
memmap_array = memmap_array[idx]
out = self._np_to_tensor(memmap_array, from_numpy=from_numpy)
if (
@@ -465,6 +474,15 @@ def __setitem__(self, idx: INDEX_TYPING, value: torch.Tensor):
if self.device == torch.device("cpu"):
self._load_item()[idx] = value
else:
if isinstance(idx, torch.Tensor):
idx = idx.cpu()
elif isinstance(idx, tuple) and any(
isinstance(_idx, torch.Tensor) for _idx in idx
):
idx = tuple(
_idx.cpu() if isinstance(_idx, torch.Tensor) else _idx
for _idx in idx
)
self.memmap_array[idx] = to_numpy(value)

def __setstate__(self, state: dict) -> None:
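`MemmapTensor._load_item` and `MemmapTensor.__setitem__` now apply the same normalization before touching the underlying `np.memmap`, which cannot be indexed with CUDA tensors. The shared logic amounts to the sketch below, with `_index_to_cpu` as a hypothetical helper name (the diff inlines it at both call sites):

```python
import torch

def _index_to_cpu(idx):
    # Move any tensor (sub-)index to CPU so it can index a numpy memmap.
    if isinstance(idx, torch.Tensor):
        return idx.cpu()
    if isinstance(idx, tuple) and any(isinstance(s, torch.Tensor) for s in idx):
        return tuple(s.cpu() if isinstance(s, torch.Tensor) else s for s in idx)
    return idx
```

Factoring the check out this way would also remove the duplication between the read and write paths.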
78 changes: 75 additions & 3 deletions torchrl/data/tensordict/tensordict.py
@@ -1065,7 +1065,10 @@ def masked_select(self, mask: torch.Tensor) -> _TensorDict:
"""
d = dict()
for key, value in self.items():
mask_expand = mask.squeeze(-1)
# Squeeze trailing singleton dims until the mask matches the batch dims.
mask_expand = mask
while mask_expand.ndimension() > self.batch_dims:
    mask_expand = mask_expand.squeeze(-1)
value_select = value[mask_expand]
d[key] = value_select
dim = int(mask.sum().item())
@@ -1471,6 +1474,17 @@ def __getitem__(self, idx: INDEX_TYPING) -> _TensorDict:
>>> print(td.get("a")) # values have not changed

"""
if isinstance(idx, list):
idx = torch.tensor(idx, device=self.device)
if isinstance(idx, tuple) and any(
isinstance(sub_index, list) for sub_index in idx
):
idx = tuple(
torch.tensor(sub_index, device=self.device)
if isinstance(sub_index, list)
else sub_index
for sub_index in idx
)
if isinstance(idx, str):
return self.get(idx)
if isinstance(idx, tuple) and sum(
@@ -1487,8 +1501,8 @@ def __getitem__(self, idx: INDEX_TYPING) -> _TensorDict:
return out[idx[1:]]
else:
return out
elif isinstance(idx, torch.Tensor) and idx.dtype == torch.bool:
return self.masked_select(idx)
# elif isinstance(idx, torch.Tensor) and idx.dtype == torch.bool:
# return self.masked_select(idx)

contiguous_input = (int, slice)
return_simple_view = isinstance(idx, contiguous_input) or (
@@ -1521,6 +1535,17 @@ def __getitem__(self, idx: INDEX_TYPING) -> _TensorDict:
def __setitem__(self, index: INDEX_TYPING, value: _TensorDict) -> None:
if index is Ellipsis or (isinstance(index, tuple) and Ellipsis in index):
index = convert_ellipsis_to_idx(index, self.batch_size)
if isinstance(index, list):
index = torch.tensor(index, device=self.device)
if isinstance(index, tuple) and any(
isinstance(sub_index, list) for sub_index in index
):
index = tuple(
torch.tensor(sub_index, device=self.device)
if isinstance(sub_index, list)
else sub_index
for sub_index in index
)
if isinstance(index, tuple) and sum(
isinstance(_index, str) for _index in index
) not in [len(index), 0]:
@@ -3291,13 +3316,49 @@ def select(self, *keys: str, inplace: bool = False) -> _TensorDict:
stack_dim=self.stack_dim,
)

def __setitem__(self, item: INDEX_TYPING, value: _TensorDict) -> None:
if isinstance(item, list):
item = torch.tensor(item, device=self.device)
if isinstance(item, tuple) and any(
isinstance(sub_index, list) for sub_index in item
):
item = tuple(
torch.tensor(sub_index, device=self.device)
if isinstance(sub_index, list)
else sub_index
for sub_index in item
)
if (isinstance(item, torch.Tensor) and item.dtype is torch.bool) or (
isinstance(item, tuple)
and any(
isinstance(_item, torch.Tensor) and _item.dtype is torch.bool
for _item in item
)
):
raise RuntimeError(
"setting values to a LazyStackTensorDict using boolean values is not supported yet."
"If this feature is needed, feel free to raise an issue on github."
)
return super().__setitem__(item, value)

def __getitem__(self, item: INDEX_TYPING) -> _TensorDict:
if item is Ellipsis or (isinstance(item, tuple) and Ellipsis in item):
item = convert_ellipsis_to_idx(item, self.batch_size)
if isinstance(item, tuple) and sum(
isinstance(_item, str) for _item in item
) not in [len(item), 0]:
raise IndexError(_STR_MIXED_INDEX_ERROR)
if isinstance(item, list):
item = torch.tensor(item, device=self.device)
if isinstance(item, tuple) and any(
isinstance(sub_index, list) for sub_index in item
):
item = tuple(
torch.tensor(sub_index, device=self.device)
if isinstance(sub_index, list)
else sub_index
for sub_index in item
)
if isinstance(item, str):
return self.get(item)
elif isinstance(item, tuple) and all(
@@ -3761,6 +3822,17 @@ def __reduce__(self, *args, **kwargs):
return super().__reduce__(*args, **kwargs)

def __getitem__(self, idx: INDEX_TYPING) -> _TensorDict:
if isinstance(idx, list):
idx = torch.tensor(idx, device=self.device)
if isinstance(idx, tuple) and any(
isinstance(sub_index, list) for sub_index in idx
):
idx = tuple(
torch.tensor(sub_index, device=self.device)
if isinstance(sub_index, list)
else sub_index
for sub_index in idx
)
if idx is Ellipsis or (isinstance(idx, tuple) and Ellipsis in idx):
idx = convert_ellipsis_to_idx(idx, self.batch_size)

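The same list-to-tensor promotion now guards every indexing entry point: `TensorDict.__getitem__` and `__setitem__`, both `LazyStackedTensorDict` methods, and the saved/memmap-backed `__getitem__`. In isolation the conversion reduces to the sketch below (`_lists_to_tensors` is a hypothetical name; the diff inlines the logic each time):

```python
import torch

def _lists_to_tensors(index, device):
    # Bare lists become tensors on the TensorDict's device; torch.tensor's
    # dtype inference turns boolean lists into masks and integer lists into
    # long index tensors.
    if isinstance(index, list):
        return torch.tensor(index, device=device)
    # Lists nested inside tuple indices are promoted element-wise.
    if isinstance(index, tuple) and any(isinstance(s, list) for s in index):
        return tuple(
            torch.tensor(s, device=device) if isinstance(s, list) else s
            for s in index
        )
    return index
```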
7 changes: 6 additions & 1 deletion torchrl/data/tensordict/utils.py
@@ -48,6 +48,8 @@ def _getitem_batch_size(
items = items[0]
if isinstance(items, int):
return shape[1:]
if isinstance(items, torch.Tensor) and items.dtype is torch.bool:
return torch.Size([items.sum(), *shape[items.ndimension() :]])
if (
isinstance(items, (torch.Tensor, np.ndarray)) and len(items.shape) <= 1
) or isinstance(items, list):
@@ -78,7 +80,10 @@ def _getitem_batch_size(
v = len(range(*_item.indices(batch)))
elif isinstance(_item, (list, torch.Tensor, np.ndarray)):
batch = next(iter_bs)
v = len(_item)
if isinstance(_item, torch.Tensor) and _item.dtype is torch.bool:
v = _item.sum()
else:
v = len(_item)
elif _item is None:
v = 1
elif isinstance(_item, Number):
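The new boolean branch in `_getitem_batch_size` reproduces PyTorch's advanced-indexing rule: a boolean mask collapses the dimensions it spans into a single dimension of length `mask.sum()`, leaving the trailing dimensions untouched. A quick standalone check of that rule:

```python
import torch

shape = (10, 7, 11, 5)
mask = torch.zeros(10, 7, dtype=torch.bool).bernoulli_()

# A [10, 7] mask on a [10, 7, 11, 5] tensor collapses the first two dims
# into one of length mask.sum(); the trailing [11, 5] dims are preserved.
indexed = torch.zeros(*shape)[mask]
assert indexed.shape == torch.Size([int(mask.sum()), 11, 5])
```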