tensordict/tensordict.py: 27 changes (2 additions & 25 deletions)
@@ -36,14 +36,7 @@
 import numpy as np
 import torch
 
-from tensordict.utils import (
-    _get_item,
-    _is_shared,
-    _ndimension,
-    _requires_grad,
-    _set_item,
-    _shape,
-)
+from tensordict.utils import _get_item, _is_shared, _requires_grad, _set_item, _shape
 from torch import Tensor
 from torch.utils._pytree import tree_map
 
@@ -824,12 +817,6 @@ def _process_input(
f"={_shape(tensor)[: self.batch_dims]} with tensor {tensor}"
)

# minimum ndimension is 1
if _ndimension(tensor) == self.ndimension() and not isinstance(
tensor, (TensorDictBase, KeyedJaggedTensor)
):
tensor = tensor.unsqueeze(-1)

return tensor

@abc.abstractmethod
@@ -1325,17 +1312,7 @@ def to(
     def _check_new_batch_size(self, new_size: torch.Size):
         n = len(new_size)
         for key, meta_tensor in self.items_meta():
-            if not meta_tensor.is_kjt() and not meta_tensor.is_tensordict():
-                c1 = meta_tensor.ndimension() <= n
-            else:
-                c1 = meta_tensor.ndimension() < n
-            if c1 or (meta_tensor.shape[:n] != new_size):
-                if meta_tensor.ndimension() == n and meta_tensor.shape == new_size:
-                    raise RuntimeError(
-                        "TensorDict requires tensors that have at least one more "
-                        f'dimension than the batch_size. The tensor "{key}" has shape '
-                        f"{meta_tensor.shape} which is the same as the new size."
-                    )
+            if meta_tensor.shape[:n] != new_size:
                 raise RuntimeError(
                     f"the tensor {key} has shape {meta_tensor.shape} which "
                     f"is incompatible with the new shape {new_size}."
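For context, the removal of the auto-unsqueeze step in _process_input means a value whose number of dimensions equals the number of batch dims is now stored with its original shape rather than gaining a trailing singleton dimension. Below is a minimal sketch of the user-visible effect, assuming the package's top-level TensorDict import; the shapes mirror the updated assertions in the tests that follow.

import torch
from tensordict import TensorDict

# Before this change, "a" was unsqueezed to shape (3, 4, 1) because its
# ndim equalled the number of batch dims; it now keeps its original shape.
td = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3, 4])
assert td["a"].shape == torch.Size([3, 4])
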
test/test_tensordict.py: 26 changes (8 additions & 18 deletions)
@@ -357,14 +357,14 @@ def test_permute(device):

     td2 = torch.permute(td1, dims=(-1, -3, -2))
     assert td2.shape == torch.Size((6, 4, 5))
-    assert td2["c"].shape == torch.Size((6, 4, 5, 1))
+    assert td2["c"].shape == torch.Size((6, 4, 5))
 
     td2 = torch.permute(td1, dims=(0, 1, 2))
     assert td2["a"].shape == torch.Size((4, 5, 6, 9))
 
     t = TensorDict({"a": torch.randn(3, 4, 1)}, [3, 4])
     torch.permute(t, dims=(1, 0)).set("b", torch.randn(4, 3))
-    assert t["b"].shape == torch.Size((3, 4, 1))
+    assert t["b"].shape == torch.Size((3, 4))
 
     torch.permute(t, dims=(1, 0)).fill_("a", 0.0)
     assert torch.sum(t["a"]) == torch.Tensor([0])
@@ -2290,25 +2290,15 @@ def test_batchsize_reset():
     # test index
     td[torch.tensor([1, 2])]
     with pytest.raises(
-        RuntimeError,
-        match=re.escape(
-            "The shape torch.Size([3]) is incompatible with the index (slice(None, None, None), 0)."
-        ),
+        IndexError,
+        match=re.escape("too many indices for tensor of dimension 1"),
     ):
         td[:, 0]
 
     # test a greater batch_size
     td = TensorDict(
         {"a": torch.randn(3, 4, 5, 6), "b": torch.randn(3, 4, 5)}, batch_size=[3, 4]
     )
-    with pytest.raises(
-        RuntimeError,
-        match=re.escape(
-            "TensorDict requires tensors that have at least one more dimension than the batch_size"
-        ),
-    ):
-        td.batch_size = torch.Size([3, 4, 5])
+    del td["b"]
+    td.batch_size = torch.Size([3, 4, 5])
 
     td.set("c", torch.randn(3, 4, 5, 6))
@@ -2467,7 +2457,7 @@ def test_create_on_device():
     a = torch.randn(2, 3)
     viewedtd.set("a", a)
     assert viewedtd.get("a").device == device
-    assert (a.unsqueeze(-1).to(device) == viewedtd.get("a")).all()
+    assert (a.to(device) == viewedtd.get("a")).all()
 
 
 def _remote_process(worker_id, command_pipe_child, command_pipe_parent, tensordict):
@@ -3082,7 +3072,7 @@ def test_memory_lock(method):
 class TestMakeTensorDict:
     def test_create_tensordict(self):
         tensordict = make_tensordict(a=torch.zeros(3, 4))
-        assert (tensordict["a"] == torch.zeros(3, 4, 1)).all()
+        assert (tensordict["a"] == torch.zeros(3, 4)).all()
 
     def test_tensordict_batch_size(self):
         tensordict = make_tensordict()
@@ -3337,7 +3327,7 @@ def test_lazy_stacked_insert(dim, index, device):
     assert lstd.batch_size == torch.Size(bs)
     assert set(lstd.keys()) == {"a"}
 
-    t = torch.zeros(*bs, 1, device=device)
+    t = torch.zeros(*bs, device=device)
 
     if dim == 0:
         t[index] = 1
@@ -3375,7 +3365,7 @@ def test_lazy_stacked_append(dim, device):
     assert lstd.batch_size == torch.Size(bs)
     assert set(lstd.keys()) == {"a"}
 
-    t = torch.zeros(*bs, 1, device=device)
+    t = torch.zeros(*bs, device=device)
 
     if dim == 0:
         t[-1] = 1
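The relaxed _check_new_batch_size similarly changes what resizing batch_size requires: the only remaining condition is that every entry's leading dimensions match the new size. A small sketch under the same assumptions as above, mirroring the updated test_batchsize_reset.

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4])

# Growing the batch size is accepted as long as each entry's shape starts
# with the new size; "a" has shape (3, 4, 5, 6), which begins with (3, 4, 5).
td.batch_size = torch.Size([3, 4, 5])
assert td.batch_size == torch.Size([3, 4, 5])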