From eeed744edb853f4d6947a19ecb4632c2ac6de6ee Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 17 Sep 2024 10:46:42 -0700 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- tensordict/_torch_func.py | 8 ++++---- tensordict/base.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 463629fa5..062a59fe5 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -419,11 +419,11 @@ def _stack( if not len(list_of_tensordicts): raise RuntimeError("list_of_tensordicts cannot be empty") is_tc = any(is_tensorclass(td) for td in list_of_tensordicts) - if all(is_non_tensor(td) for td in list_of_tensordicts): - from tensordict.tensorclass import NonTensorData + if is_tc: + if all(is_non_tensor(td) for td in list_of_tensordicts): + from tensordict.tensorclass import NonTensorData - return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) - elif is_tc: + return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) tc_type = type(list_of_tensordicts[0]) list_of_tensordicts = [tc._tensordict for tc in list_of_tensordicts] diff --git a/tensordict/base.py b/tensordict/base.py index f1493e1af..79d5ab4ed 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -8559,13 +8559,13 @@ def _convert_to_tensor( castable = None if isinstance(array, (float, int, bool)): castable = True - elif isinstance(array, np.ndarray) and array.dtype.names is not None: - return TensorDictBase.from_struct_array(array, device=self.device) - elif isinstance(array, np.ndarray): - castable = array.dtype.kind in ("c", "i", "f", "b", "u") elif isinstance(array, np.bool_): castable = True array = array.item() + elif isinstance(array, (np.ndarray, np.number)): + if array.dtype.names is not None: + return TensorDictBase.from_struct_array(array, device=self.device) + castable = array.dtype.kind in ("c", "i", "f", "b", "u") elif isinstance(array, (list, tuple)): array = np.asarray(array) castable = array.dtype.kind in ("c", "i", "f", "b", "u") From 4a0b18cf05a0f794bf6f4966dd3c95bbd3ee3df4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 17 Sep 2024 10:55:24 -0700 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- tensordict/_torch_func.py | 8 ++++---- test/test_tensordict.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 062a59fe5..455f93ee4 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -419,11 +419,11 @@ def _stack( if not len(list_of_tensordicts): raise RuntimeError("list_of_tensordicts cannot be empty") is_tc = any(is_tensorclass(td) for td in list_of_tensordicts) - if is_tc: - if all(is_non_tensor(td) for td in list_of_tensordicts): - from tensordict.tensorclass import NonTensorData + if all(is_non_tensor(td) for td in list_of_tensordicts): + from tensordict.tensorclass import NonTensorData - return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) + return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) + if is_tc: tc_type = type(list_of_tensordicts[0]) list_of_tensordicts = [tc._tensordict for tc in list_of_tensordicts] diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 5ff085150..34c068693 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -10386,7 +10386,7 @@ def test_memmap_stack_updates(self, tmpdir): data = torch.stack([NonTensorData(data=0), NonTensorData(data=1)], 0) assert is_non_tensor(data) data = torch.stack([data] * 3) - assert is_non_tensor(data) + assert is_non_tensor(data), data data = data.clone() assert is_non_tensor(data) data.memmap_(tmpdir)