
[BUG] TensorDict.expand does not work as Tensor.expand #398

@vmoens

Describe the bug

Currently, TensorDict.expand does not follow the semantics of torch.Tensor.expand:

>>> import torch
>>> tensor = torch.arange(6).view(3, 2)
>>> tensor.expand(2, 3, 2)
tensor([[[0, 1],
         [2, 3],
         [4, 5]],

        [[0, 1],
         [2, 3],
         [4, 5]]])
>>> t = TensorDict({"a": tensor}, [3])
>>> t.expand(2)["a"]  # gives the same result as tensor.expand(2, 3, 2) above
tensor([[[0, 1],
         [2, 3],
         [4, 5]],

        [[0, 1],
         [2, 3],
         [4, 5]]])

However, one would expect expand to take the full target shape, i.e. the arguments 2, 3 and not just 2, exactly as Tensor.expand does.

Here is a list of the expected behaviours:

>>> t = TensorDict({"a": tensor}, [3])
>>> t.expand(2, 3)  # should return a tensordict of size [2, 3]. Currently it prepends the 2, 3 dimensions, resulting in size [2, 3, 3]
>>> t = TensorDict({"a": torch.zeros(1, 2)}, [1])
>>> t.expand(2, 3)  # should return a tensordict of size [2, 3]. Currently it prepends the 2, 3 dimensions, resulting in size [2, 3, 1]
>>> t = TensorDict({"a": torch.zeros(2, 2)}, [2])
>>> t.expand(2, 3)  # should raise an error, as the last batch dim is 2 but the corresponding size passed to expand is 3
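To make the expected behaviours above concrete, here is a minimal pure-Python sketch of the shape rule that torch.Tensor.expand applies, written against a batch size given as a tuple of ints. expand_batch_size is a hypothetical helper, not part of TensorDict; it only shows which batch size each call should produce, assuming TensorDict.expand adopts Tensor.expand's rules (sizes aligned from the right, new leading dims prepended, -1 keeps a dim, singleton dims broadcast).

```python
def expand_batch_size(batch_size, *sizes):
    """Batch size produced by torch.Tensor.expand-style rules (hypothetical helper).

    Accepts sizes either unpacked, expand_batch_size((3,), 2, 3),
    or as one tuple, expand_batch_size((3,), (2, 3)), like Tensor.expand.
    """
    if len(sizes) == 1 and isinstance(sizes[0], (tuple, list)):
        sizes = tuple(sizes[0])
    batch_size = tuple(batch_size)
    n_new = len(sizes) - len(batch_size)
    if n_new < 0:
        raise RuntimeError(
            f"expand needs at least {len(batch_size)} sizes, got {len(sizes)}"
        )
    # New leading dims are taken as given; -1 is not allowed there.
    result = list(sizes[:n_new])
    if any(s < 0 for s in result):
        raise RuntimeError("-1 is not allowed in a new leading dimension")
    # Existing dims are aligned from the right with the trailing sizes.
    for old, new in zip(batch_size, sizes[n_new:]):
        if new == -1 or new == old:
            result.append(old)  # -1 (or the same size) keeps the existing dim
        elif old == 1 and new >= 0:
            result.append(new)  # a singleton dim can be broadcast
        else:
            raise RuntimeError(
                f"expanded size {new} must match existing size {old} "
                "at a non-singleton dimension"
            )
    return tuple(result)


print(expand_batch_size((3,), 2, 3))  # (2, 3)
print(expand_batch_size((1,), 2, 3))  # (2, 3)
try:
    expand_batch_size((2,), 2, 3)  # last dim 2 cannot be expanded to 3
except RuntimeError as err:
    print("RuntimeError:", err)
```

Under these rules, the call shown earlier, t.expand(2) on a tensordict of batch size [3], would raise an error rather than prepend a dim, since 3 is not a singleton dimension.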

This will require some refactoring of the code, including the tutorials and tests.
A review of every .expand call in the library is advised. In general, after the refactoring, existing calls to tensordict.expand(*new_dims) can simply be replaced by tensordict.expand(*new_dims, *tensordict.shape).
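The suggested migration can be checked on shapes alone. The sketch below uses the "a" entry from the issue (shape (3, 2), batch size [3]) and shows that the arguments of the new-style call, (*new_dims, *tensordict.shape), spell out exactly the batch size the old-style call produced, while each leaf keeps its trailing, non-batch dims. expand_leaf_shape is a hypothetical helper for illustration only and assumes the leaves follow Tensor.expand semantics.

```python
def expand_leaf_shape(leaf_shape, batch_size, new_batch_size):
    """Shape a leaf gets when its tensordict is expanded (hypothetical helper).

    The leaf's leading batch dims are replaced by new_batch_size and its
    trailing (feature) dims are kept, i.e. the shape returned by
    leaf.expand(*new_batch_size, *leaf.shape[len(batch_size):]).
    """
    leaf_shape, batch_size = tuple(leaf_shape), tuple(batch_size)
    if leaf_shape[: len(batch_size)] != batch_size:
        raise ValueError("leaf shape must start with the tensordict batch size")
    return tuple(new_batch_size) + leaf_shape[len(batch_size):]


# Entry "a" from the issue: a (3, 2) tensor in a tensordict of batch size [3].
batch_size, leaf_shape = (3,), (3, 2)
new_dims = (2,)

# Old call tensordict.expand(*new_dims): new_dims were prepended to the batch size.
old_result = new_dims + batch_size
# New call tensordict.expand(*new_dims, *tensordict.shape): the arguments are
# already the full target batch size, as with Tensor.expand.
new_args = (*new_dims, *batch_size)

assert old_result == new_args == (2, 3)
print(expand_leaf_shape(leaf_shape, batch_size, new_args))  # (2, 3, 2)
```

The resulting leaf shape (2, 3, 2) matches the output of tensor.expand(2, 3, 2) shown at the top of the issue.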
