diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py
index bcd54d433..cb520178e 100644
--- a/tensordict/tensordict.py
+++ b/tensordict/tensordict.py
@@ -36,14 +36,7 @@
 import numpy as np
 import torch
 
-from tensordict.utils import (
-    _get_item,
-    _is_shared,
-    _ndimension,
-    _requires_grad,
-    _set_item,
-    _shape,
-)
+from tensordict.utils import _get_item, _is_shared, _requires_grad, _set_item, _shape
 from torch import Tensor
 from torch.utils._pytree import tree_map
 
@@ -824,12 +817,6 @@ def _process_input(
                 f"={_shape(tensor)[: self.batch_dims]} with tensor {tensor}"
             )
 
-        # minimum ndimension is 1
-        if _ndimension(tensor) == self.ndimension() and not isinstance(
-            tensor, (TensorDictBase, KeyedJaggedTensor)
-        ):
-            tensor = tensor.unsqueeze(-1)
-
         return tensor
 
     @abc.abstractmethod
@@ -1325,17 +1312,7 @@ def to(
     def _check_new_batch_size(self, new_size: torch.Size):
         n = len(new_size)
         for key, meta_tensor in self.items_meta():
-            if not meta_tensor.is_kjt() and not meta_tensor.is_tensordict():
-                c1 = meta_tensor.ndimension() <= n
-            else:
-                c1 = meta_tensor.ndimension() < n
-            if c1 or (meta_tensor.shape[:n] != new_size):
-                if meta_tensor.ndimension() == n and meta_tensor.shape == new_size:
-                    raise RuntimeError(
-                        "TensorDict requires tensors that have at least one more "
-                        f'dimension than the batch_size. The tensor "{key}" has shape '
-                        f"{meta_tensor.shape} which is the same as the new size."
-                    )
+            if meta_tensor.shape[:n] != new_size:
                 raise RuntimeError(
                     f"the tensor {key} has shape {meta_tensor.shape} which "
                     f"is incompatible with the new shape {new_size}."
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 4c567a4de..1fcbcd918 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -357,14 +357,14 @@ def test_permute(device):
 
     td2 = torch.permute(td1, dims=(-1, -3, -2))
     assert td2.shape == torch.Size((6, 4, 5))
-    assert td2["c"].shape == torch.Size((6, 4, 5, 1))
+    assert td2["c"].shape == torch.Size((6, 4, 5))
 
     td2 = torch.permute(td1, dims=(0, 1, 2))
     assert td2["a"].shape == torch.Size((4, 5, 6, 9))
 
     t = TensorDict({"a": torch.randn(3, 4, 1)}, [3, 4])
     torch.permute(t, dims=(1, 0)).set("b", torch.randn(4, 3))
-    assert t["b"].shape == torch.Size((3, 4, 1))
+    assert t["b"].shape == torch.Size((3, 4))
 
     torch.permute(t, dims=(1, 0)).fill_("a", 0.0)
     assert torch.sum(t["a"]) == torch.Tensor([0])
@@ -2290,10 +2290,8 @@ def test_batchsize_reset():
     # test index
     td[torch.tensor([1, 2])]
     with pytest.raises(
-        RuntimeError,
-        match=re.escape(
-            "The shape torch.Size([3]) is incompatible with the index (slice(None, None, None), 0)."
-        ),
+        IndexError,
+        match=re.escape("too many indices for tensor of dimension 1"),
     ):
         td[:, 0]
 
@@ -2301,14 +2299,6 @@ def test_batchsize_reset():
     td = TensorDict(
         {"a": torch.randn(3, 4, 5, 6), "b": torch.randn(3, 4, 5)}, batch_size=[3, 4]
     )
-    with pytest.raises(
-        RuntimeError,
-        match=re.escape(
-            "TensorDict requires tensors that have at least one more dimension than the batch_size"
-        ),
-    ):
-        td.batch_size = torch.Size([3, 4, 5])
-
     del td["b"]
     td.batch_size = torch.Size([3, 4, 5])
     td.set("c", torch.randn(3, 4, 5, 6))
@@ -2467,7 +2457,7 @@ def test_create_on_device():
     a = torch.randn(2, 3)
     viewedtd.set("a", a)
     assert viewedtd.get("a").device == device
-    assert (a.unsqueeze(-1).to(device) == viewedtd.get("a")).all()
+    assert (a.to(device) == viewedtd.get("a")).all()
 
 
 def _remote_process(worker_id, command_pipe_child, command_pipe_parent, tensordict):
@@ -3082,7 +3072,7 @@ def test_memory_lock(method):
 class TestMakeTensorDict:
     def test_create_tensordict(self):
         tensordict = make_tensordict(a=torch.zeros(3, 4))
-        assert (tensordict["a"] == torch.zeros(3, 4, 1)).all()
+        assert (tensordict["a"] == torch.zeros(3, 4)).all()
 
     def test_tensordict_batch_size(self):
         tensordict = make_tensordict()
@@ -3337,7 +3327,7 @@ def test_lazy_stacked_insert(dim, index, device):
     assert lstd.batch_size == torch.Size(bs)
     assert set(lstd.keys()) == {"a"}
 
-    t = torch.zeros(*bs, 1, device=device)
+    t = torch.zeros(*bs, device=device)
 
     if dim == 0:
         t[index] = 1
@@ -3375,7 +3365,7 @@ def test_lazy_stacked_append(dim, device):
     assert lstd.batch_size == torch.Size(bs)
     assert set(lstd.keys()) == {"a"}
 
-    t = torch.zeros(*bs, 1, device=device)
+    t = torch.zeros(*bs, device=device)
 
     if dim == 0:
         t[-1] = 1
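
A minimal usage sketch of the behavior these hunks encode, mirroring test_create_tensordict and the updated test_permute assertions: a value whose shape matches the batch_size exactly is now stored as-is instead of being unsqueezed with a trailing singleton dimension. This assumes a tensordict build with this patch applied; the top-level import path for make_tensordict is also assumed.

import torch
from tensordict import TensorDict, make_tensordict

# A tensor whose shape equals the batch_size is kept unchanged; before this
# patch it would have been unsqueezed to shape (3, 4, 1).
td = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3, 4])
assert td["a"].shape == torch.Size([3, 4])

# make_tensordict infers the batch_size and likewise leaves the value as-is.
assert (make_tensordict(a=torch.zeros(3, 4))["a"] == torch.zeros(3, 4)).all()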