diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 463629fa5..455f93ee4 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -423,7 +423,7 @@ def _stack( from tensordict.tensorclass import NonTensorData return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) - elif is_tc: + if is_tc: tc_type = type(list_of_tensordicts[0]) list_of_tensordicts = [tc._tensordict for tc in list_of_tensordicts] diff --git a/tensordict/base.py b/tensordict/base.py index f1493e1af..79d5ab4ed 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -8559,13 +8559,13 @@ def _convert_to_tensor( castable = None if isinstance(array, (float, int, bool)): castable = True - elif isinstance(array, np.ndarray) and array.dtype.names is not None: - return TensorDictBase.from_struct_array(array, device=self.device) - elif isinstance(array, np.ndarray): - castable = array.dtype.kind in ("c", "i", "f", "b", "u") elif isinstance(array, np.bool_): castable = True array = array.item() + elif isinstance(array, (np.ndarray, np.number)): + if array.dtype.names is not None: + return TensorDictBase.from_struct_array(array, device=self.device) + castable = array.dtype.kind in ("c", "i", "f", "b", "u") elif isinstance(array, (list, tuple)): array = np.asarray(array) castable = array.dtype.kind in ("c", "i", "f", "b", "u") diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 5ff085150..34c068693 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -10386,7 +10386,7 @@ def test_memmap_stack_updates(self, tmpdir): data = torch.stack([NonTensorData(data=0), NonTensorData(data=1)], 0) assert is_non_tensor(data) data = torch.stack([data] * 3) - assert is_non_tensor(data) + assert is_non_tensor(data), data data = data.clone() assert is_non_tensor(data) data.memmap_(tmpdir)