test/test_tensordict.py (44 additions, 0 deletions)
@@ -1556,6 +1556,50 @@ def test_requires_grad(device):
     td5 = SavedTensorDict(tensordicts[5])
 
 
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize(
+    "td_type", ["tensordict", "view", "unsqueeze", "squeeze", "saved", "stack"]
+)
+@pytest.mark.parametrize("update", [True, False])
+def test_filling_empty_tensordict(device, td_type, update):
+    if td_type == "tensordict":
+        td = TensorDict({}, batch_size=[16], device=device)
+    elif td_type == "view":
+        td = TensorDict({}, batch_size=[4, 4], device=device).view(-1)
+    elif td_type == "unsqueeze":
+        td = TensorDict({}, batch_size=[16], device=device).unsqueeze(-1)
+    elif td_type == "squeeze":
+        td = TensorDict({}, batch_size=[16, 1], device=device).squeeze(-1)
+    elif td_type == "saved":
+        td = TensorDict({}, batch_size=[16], device=device).to(SavedTensorDict)
+    elif td_type == "stack":
+        td = torch.stack([TensorDict({}, [], device=device) for _ in range(16)], 0)
+    else:
+        raise NotImplementedError
+
+    for i in range(16):
+        other_td = TensorDict({"a": torch.randn(10), "b": torch.ones(1)}, [])
+        if td_type == "unsqueeze":
+            other_td = other_td.unsqueeze(-1).to_tensordict()
+        if update:
+            subtd = td.get_sub_tensordict(i)
+            subtd.update(other_td, inplace=True)
+        else:
+            td[i] = other_td
+
+    assert td.device == device
+    assert td.get("a").device == device
+    assert (td.get("b") == 1).all()
+    if td_type == "view":
+        assert td._source["a"].shape == torch.Size([4, 4, 10])
+    elif td_type == "unsqueeze":
+        assert td._source["a"].shape == torch.Size([16, 10])
+    elif td_type == "squeeze":
+        assert td._source["a"].shape == torch.Size([16, 1, 10])
+    elif td_type == "stack":
+        assert (td[-1] == other_td.to(device)).all()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
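
The test above exercises the new behaviour end to end: an empty TensorDict (or a view, unsqueeze, squeeze, saved, or stacked variant of one) can now be filled entry by entry, with missing keys created on the parent storage. A minimal usage sketch, assuming the torchrl.data TensorDict API as it existed at the time of this PR (TensorDict later moved to the standalone tensordict package):

    # Illustrative only -- import path and defaults follow the torchrl version this PR targets.
    import torch
    from torchrl.data import TensorDict

    td = TensorDict({}, batch_size=[16])                              # no keys yet
    row = TensorDict({"a": torch.randn(10), "b": torch.ones(1)}, [])  # a single entry
    td[0] = row   # creates "a" and "b" over the full batch, then writes row 0
    assert td.get("a").shape == torch.Size([16, 10])
    assert td.get("b").shape == torch.Size([16, 1])
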
torchrl/data/tensordict/tensordict.py (11 additions, 4 deletions)
@@ -1299,8 +1299,14 @@ def __setitem__(self, index: INDEX_TYPING, value: _TensorDict) -> None:
f"(batch_size = {self.batch_size}, index={index}), "
f"which differs from the source batch size {value.batch_size}"
)
keys = set(self.keys())
if not all(key in keys for key in value.keys()):
subtd = self.get_sub_tensordict(index)
for key, item in value.items():
self.set_at_(key, item, index)
if key in keys:
self.set_at_(key, item, index)
else:
subtd.set(key, item)

def __delitem__(self, index: INDEX_TYPING) -> _TensorDict:
if isinstance(index, str):
@@ -2246,9 +2252,10 @@ def set(
     ) -> _TensorDict:
         if self.is_locked:
             raise RuntimeError("Cannot modify immutable TensorDict")
-        if inplace and key in self.keys():
+        keys = set(self.keys())
+        if inplace and key in keys:
             return self.set_(key, tensor)
-        elif key in self.keys():
+        elif key in keys:
             raise RuntimeError(
                 "Calling `SubTensorDict.set(key, value, inplace=False)` is prohibited for existing tensors. "
                 "Consider calling `SubTensorDict.set_(...)` or cloning your tensordict first."
@@ -2265,7 +2272,7 @@
                 device=self.device,
             )
 
-            if self.is_shared():
+            if self.is_shared() and self.device == torch.device("cpu"):
                 tensor_expand.share_memory_()
             elif self.is_memmap():
                 tensor_expand = MemmapTensor(tensor_expand)
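
For context on the last hunk: when SubTensorDict.set receives a key that does not yet exist, the value is expanded to the parent batch size before being written back, and shared-memory placement is now only requested for CPU tensors (Tensor.share_memory_() is a CPU shared-memory operation, and CUDA storage is already shareable across processes). A rough sketch of that branch, with the helper name and arguments being illustrative rather than the library's actual code:

    import torch

    def _expand_and_maybe_share(tensor, parent_batch_size, is_shared, device):
        # Allocate storage for the new key over the full parent batch
        # (illustrative stand-in for the tensor_expand construction in the diff).
        tensor_expand = torch.zeros(*parent_batch_size, *tensor.shape,
                                    dtype=tensor.dtype, device=device)
        if is_shared and device == torch.device("cpu"):
            # Mirrors the added guard: only CPU tensors are moved to shared memory.
            tensor_expand.share_memory_()
        return tensor_expand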