diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 8fd7693af0a..c28c57d33b5 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -1556,6 +1556,50 @@ def test_requires_grad(device):
         td5 = SavedTensorDict(tensordicts[5])
 
 
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize(
+    "td_type", ["tensordict", "view", "unsqueeze", "squeeze", "saved", "stack"]
+)
+@pytest.mark.parametrize("update", [True, False])
+def test_filling_empty_tensordict(device, td_type, update):
+    if td_type == "tensordict":
+        td = TensorDict({}, batch_size=[16], device=device)
+    elif td_type == "view":
+        td = TensorDict({}, batch_size=[4, 4], device=device).view(-1)
+    elif td_type == "unsqueeze":
+        td = TensorDict({}, batch_size=[16], device=device).unsqueeze(-1)
+    elif td_type == "squeeze":
+        td = TensorDict({}, batch_size=[16, 1], device=device).squeeze(-1)
+    elif td_type == "saved":
+        td = TensorDict({}, batch_size=[16], device=device).to(SavedTensorDict)
+    elif td_type == "stack":
+        td = torch.stack([TensorDict({}, [], device=device) for _ in range(16)], 0)
+    else:
+        raise NotImplementedError
+
+    for i in range(16):
+        other_td = TensorDict({"a": torch.randn(10), "b": torch.ones(1)}, [])
+        if td_type == "unsqueeze":
+            other_td = other_td.unsqueeze(-1).to_tensordict()
+        if update:
+            subtd = td.get_sub_tensordict(i)
+            subtd.update(other_td, inplace=True)
+        else:
+            td[i] = other_td
+
+    assert td.device == device
+    assert td.get("a").device == device
+    assert (td.get("b") == 1).all()
+    if td_type == "view":
+        assert td._source["a"].shape == torch.Size([4, 4, 10])
+    elif td_type == "unsqueeze":
+        assert td._source["a"].shape == torch.Size([16, 10])
+    elif td_type == "squeeze":
+        assert td._source["a"].shape == torch.Size([16, 1, 10])
+    elif td_type == "stack":
+        assert (td[-1] == other_td.to(device)).all()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/data/tensordict/tensordict.py b/torchrl/data/tensordict/tensordict.py
index 423fcd5f957..d17305973b6 100644
--- a/torchrl/data/tensordict/tensordict.py
+++ b/torchrl/data/tensordict/tensordict.py
@@ -1299,8 +1299,14 @@ def __setitem__(self, index: INDEX_TYPING, value: _TensorDict) -> None:
                 f"(batch_size = {self.batch_size}, index={index}), "
                 f"which differs from the source batch size {value.batch_size}"
             )
+        keys = set(self.keys())
+        if not all(key in keys for key in value.keys()):
+            subtd = self.get_sub_tensordict(index)
         for key, item in value.items():
-            self.set_at_(key, item, index)
+            if key in keys:
+                self.set_at_(key, item, index)
+            else:
+                subtd.set(key, item)
 
     def __delitem__(self, index: INDEX_TYPING) -> _TensorDict:
         if isinstance(index, str):
@@ -2246,9 +2252,10 @@ def set(
     ) -> _TensorDict:
         if self.is_locked:
             raise RuntimeError("Cannot modify immutable TensorDict")
-        if inplace and key in self.keys():
+        keys = set(self.keys())
+        if inplace and key in keys:
             return self.set_(key, tensor)
-        elif key in self.keys():
+        elif key in keys:
             raise RuntimeError(
                 "Calling `SubTensorDict.set(key, value, inplace=False)` is prohibited for existing tensors. "
                 "Consider calling `SubTensorDict.set_(...)` or cloning your tensordict first."
             )
@@ -2265,7 +2272,7 @@ def set(
             device=self.device,
         )
 
-        if self.is_shared():
+        if self.is_shared() and self.device == torch.device("cpu"):
             tensor_expand.share_memory_()
         elif self.is_memmap():
             tensor_expand = MemmapTensor(tensor_expand)