Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tensordict/_torch_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def _stack(
from tensordict.tensorclass import NonTensorData

return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim)
elif is_tc:
if is_tc:
tc_type = type(list_of_tensordicts[0])
list_of_tensordicts = [tc._tensordict for tc in list_of_tensordicts]

Expand Down
8 changes: 4 additions & 4 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8559,13 +8559,13 @@ def _convert_to_tensor(
castable = None
if isinstance(array, (float, int, bool)):
castable = True
elif isinstance(array, np.ndarray) and array.dtype.names is not None:
return TensorDictBase.from_struct_array(array, device=self.device)
elif isinstance(array, np.ndarray):
castable = array.dtype.kind in ("c", "i", "f", "b", "u")
elif isinstance(array, np.bool_):
castable = True
array = array.item()
elif isinstance(array, (np.ndarray, np.number)):
if array.dtype.names is not None:
return TensorDictBase.from_struct_array(array, device=self.device)
castable = array.dtype.kind in ("c", "i", "f", "b", "u")
elif isinstance(array, (list, tuple)):
array = np.asarray(array)
castable = array.dtype.kind in ("c", "i", "f", "b", "u")
Expand Down
2 changes: 1 addition & 1 deletion test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10386,7 +10386,7 @@ def test_memmap_stack_updates(self, tmpdir):
data = torch.stack([NonTensorData(data=0), NonTensorData(data=1)], 0)
assert is_non_tensor(data)
data = torch.stack([data] * 3)
assert is_non_tensor(data)
assert is_non_tensor(data), data
data = data.clone()
assert is_non_tensor(data)
data.memmap_(tmpdir)
Expand Down