From 6667b06be4572955ffb7d7a2d73590d48a758f31 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Thu, 15 Dec 2022 21:39:26 +0000
Subject: [PATCH 1/3] init

---
 tensordict/tensordict.py | 19 +------------------
 test/test_tensordict.py  | 24 +++++++-----------------
 2 files changed, 8 insertions(+), 35 deletions(-)

diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py
index 8dc1be200..eb7eda037 100644
--- a/tensordict/tensordict.py
+++ b/tensordict/tensordict.py
@@ -39,7 +39,6 @@
 from tensordict.utils import (
     _get_item,
     _is_shared,
-    _ndimension,
     _requires_grad,
     _set_item,
     _shape,
@@ -829,12 +828,6 @@ def _process_input(
                 f"={_shape(tensor)[: self.batch_dims]} with tensor {tensor}"
             )
 
-        # minimum ndimension is 1
-        if _ndimension(tensor) == self.ndimension() and not isinstance(
-            tensor, (TensorDictBase, KeyedJaggedTensor)
-        ):
-            tensor = tensor.unsqueeze(-1)
-
         return tensor
 
     @abc.abstractmethod
@@ -1293,17 +1286,7 @@ def to(
     def _check_new_batch_size(self, new_size: torch.Size):
         n = len(new_size)
         for key, meta_tensor in self.items_meta():
-            if not meta_tensor.is_kjt() and not meta_tensor.is_tensordict():
-                c1 = meta_tensor.ndimension() <= n
-            else:
-                c1 = meta_tensor.ndimension() < n
-            if c1 or (meta_tensor.shape[:n] != new_size):
-                if meta_tensor.ndimension() == n and meta_tensor.shape == new_size:
-                    raise RuntimeError(
-                        "TensorDict requires tensors that have at least one more "
-                        f'dimension than the batch_size. The tensor "{key}" has shape '
-                        f"{meta_tensor.shape} which is the same as the new size."
-                    )
+            if meta_tensor.shape[:n] != new_size:
                 raise RuntimeError(
                     f"the tensor {key} has shape {meta_tensor.shape} which "
                     f"is incompatible with the new shape {new_size}."
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index b54b968de..6ab363525 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -356,14 +356,14 @@ def test_permute(device):
 
     td2 = torch.permute(td1, dims=(-1, -3, -2))
     assert td2.shape == torch.Size((6, 4, 5))
-    assert td2["c"].shape == torch.Size((6, 4, 5, 1))
+    assert td2["c"].shape == torch.Size((6, 4, 5))
 
     td2 = torch.permute(td1, dims=(0, 1, 2))
     assert td2["a"].shape == torch.Size((4, 5, 6, 9))
 
     t = TensorDict({"a": torch.randn(3, 4, 1)}, [3, 4])
     torch.permute(t, dims=(1, 0)).set("b", torch.randn(4, 3))
-    assert t["b"].shape == torch.Size((3, 4, 1))
+    assert t["b"].shape == torch.Size((3, 4))
 
     torch.permute(t, dims=(1, 0)).fill_("a", 0.0)
     assert torch.sum(t["a"]) == torch.Tensor([0])
@@ -2205,10 +2205,8 @@ def test_batchsize_reset():
     # test index
     td[torch.tensor([1, 2])]
     with pytest.raises(
-        RuntimeError,
-        match=re.escape(
-            "The shape torch.Size([3]) is incompatible with the index (slice(None, None, None), 0)."
-        ),
+        IndexError,
+        match=re.escape("too many indices for tensor of dimension 1"),
     ):
         td[:, 0]
 
@@ -2216,14 +2214,6 @@
     td = TensorDict(
         {"a": torch.randn(3, 4, 5, 6), "b": torch.randn(3, 4, 5)}, batch_size=[3, 4]
     )
-    with pytest.raises(
-        RuntimeError,
-        match=re.escape(
-            "TensorDict requires tensors that have at least one more dimension than the batch_size"
-        ),
-    ):
-        td.batch_size = torch.Size([3, 4, 5])
-
     del td["b"]
     td.batch_size = torch.Size([3, 4, 5])
     td.set("c", torch.randn(3, 4, 5, 6))
@@ -3011,7 +3001,7 @@ def test_memory_lock(method):
 class TestMakeTensorDict:
     def test_create_tensordict(self):
         tensordict = make_tensordict(a=torch.zeros(3, 4))
-        assert (tensordict["a"] == torch.zeros(3, 4, 1)).all()
+        assert (tensordict["a"] == torch.zeros(3, 4)).all()
 
     def test_tensordict_batch_size(self):
         tensordict = make_tensordict()
@@ -3266,7 +3256,7 @@ def test_lazy_stacked_insert(dim, index, device):
     assert lstd.batch_size == torch.Size(bs)
     assert set(lstd.keys()) == {"a"}
 
-    t = torch.zeros(*bs, 1, device=device)
+    t = torch.zeros(*bs, device=device)
 
     if dim == 0:
         t[index] = 1
@@ -3304,7 +3294,7 @@ def test_lazy_stacked_append(dim, device):
     assert lstd.batch_size == torch.Size(bs)
     assert set(lstd.keys()) == {"a"}
 
-    t = torch.zeros(*bs, 1, device=device)
+    t = torch.zeros(*bs, device=device)
 
     if dim == 0:
         t[-1] = 1

From 18650d04930757c5b81a95ce1cf4a744c0cb13a1 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Thu, 15 Dec 2022 21:51:43 +0000
Subject: [PATCH 2/3] lint

---
 tensordict/tensordict.py | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py
index eb7eda037..f70869895 100644
--- a/tensordict/tensordict.py
+++ b/tensordict/tensordict.py
@@ -36,13 +36,7 @@
 import numpy as np
 import torch
 
-from tensordict.utils import (
-    _get_item,
-    _is_shared,
-    _requires_grad,
-    _set_item,
-    _shape,
-)
+from tensordict.utils import _get_item, _is_shared, _requires_grad, _set_item, _shape
 from torch import Tensor
 from torch.utils._pytree import tree_map
 

From 17ad75a833988ca79027bc8f69f0ae9eaac383b6 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Thu, 15 Dec 2022 21:54:38 +0000
Subject: [PATCH 3/3] bf

---
 test/test_tensordict.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 6ab363525..36b7c26e2 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -2372,7 +2372,7 @@ def test_create_on_device():
     a = torch.randn(2, 3)
     viewedtd.set("a", a)
     assert viewedtd.get("a").device == device
-    assert (a.unsqueeze(-1).to(device) == viewedtd.get("a")).all()
+    assert (a.to(device) == viewedtd.get("a")).all()
 
 
 def _remote_process(worker_id, command_pipe_child, command_pipe_parent, tensordict):
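
A minimal sketch of the behavior this series changes, assuming the patched
tensordict revision above (illustration only, not part of the patches):
entries whose ndim equals the number of batch dims are no longer
auto-unsqueezed, and batch_size may now cover an entry's full shape.

    import torch
    from tensordict import TensorDict, make_tensordict

    # With _process_input's auto-unsqueeze removed, an entry whose ndim
    # equals the number of batch dims keeps its shape (previously "a"
    # would have been silently expanded to (3, 4, 1)).
    td = make_tensordict(a=torch.zeros(3, 4))
    assert td["a"].shape == torch.Size([3, 4])

    # With _check_new_batch_size relaxed, a batch_size covering an entry's
    # full shape is accepted (previously this raised "TensorDict requires
    # tensors that have at least one more dimension than the batch_size").
    td = TensorDict({"a": torch.randn(3, 4, 5)}, batch_size=[3, 4])
    td.batch_size = torch.Size([3, 4, 5])
    assert td.batch_size == torch.Size([3, 4, 5])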