From 9e30bff4112f25546a1bdbb90709c6bf57479a44 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 22 Oct 2025 12:26:20 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- test/test_specs.py | 20 ++++++++++++++++++++ torchrl/data/tensor_specs.py | 22 ++++++++++++++++------ 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/test/test_specs.py b/test/test_specs.py index 470aa3b4b0b..0edbd9de88e 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -4585,6 +4585,26 @@ def test_names_repr(self): assert "Composite" in repr_str assert "obs" in repr_str + def test_zero_create_names(self): + """Test that creating tensors with 'zero' propagates names.""" + spec = Composite( + {"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))}, + shape=(10,), + names=["batch"], + ) + td = spec.zero() + td.names = ["batch"] + + def test_rand_create_names(self): + """Test that creating tensors with 'rand' propagates names.""" + spec = Composite( + {"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))}, + shape=(10,), + names=["batch"], + ) + td = spec.rand() + td.names = ["batch"] + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index f6eca44be41..30240c48c72 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -5740,16 +5740,22 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase: for key, item in self.items(): if item is not None: _dict[key] = item.rand(shape) - if self.data_cls is None: - cls = TensorDict + + cls = self.data_cls if self.data_cls is not None else TensorDict + if cls is not TensorDict: + kwargs = {} + if self._td_dim_names is not None: + warnings.warn(f"names for cls {cls} is not supported for rand.") else: - cls = self.data_cls + kwargs = {"names": self._td_dim_names} + # No need to run checks since we know Composite is compliant with # TensorDict requirements return cls.from_dict( _dict, batch_size=_size([*shape, *_remove_neg_shapes(self.shape)]), device=self.device, + **kwargs, ) def keys( @@ -6017,10 +6023,13 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase: except RuntimeError: device = self._device - if self.data_cls is not None: - cls = self.data_cls + cls = self.data_cls if self.data_cls is not None else TensorDict + if cls is not TensorDict: + kwargs = {} + if self._td_dim_names is not None: + warnings.warn(f"names for cls {cls} is not supported for zero.") else: - cls = TensorDict + kwargs = {"names": self._td_dim_names} return cls.from_dict( { @@ -6030,6 +6039,7 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase: }, batch_size=_size([*shape, *self._safe_shape]), device=device, + **kwargs, ) def __eq__(self, other: object) -> bool: