Closed
Labels: bug (Something isn't working)
Description
Describe the bug
Currently, TensorDict.expand does not follow the semantics of torch.Tensor.expand:
>>> import torch
>>> tensor = torch.arange(6).view(3, 2)
>>> tensor.expand(2, 3, 2)
tensor([[[0, 1],
         [2, 3],
         [4, 5]],

        [[0, 1],
         [2, 3],
         [4, 5]]])
>>> t = TensorDict({"a": tensor}, [3])
>>> t.expand(2)["a"]  # gives the same result as above
tensor([[[0, 1],
         [2, 3],
         [4, 5]],

        [[0, 1],
         [2, 3],
         [4, 5]]])
However, one would expect this expand to take the arguments 2, 3 (the full target batch size, as torch.Tensor.expand does) and not just 2.
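Under the proposed semantics, the sizes passed to expand describe the full target batch size, and each leaf tensor keeps its trailing (non-batch) dimensions unchanged. A minimal sketch of the shape computation for one leaf, in plain Python; leaf_expand_shape is a hypothetical helper name, not part of the library:

```python
from typing import Sequence, Tuple


def leaf_expand_shape(
    leaf_shape: Sequence[int], batch_ndim: int, new_batch_size: Sequence[int]
) -> Tuple[int, ...]:
    """Hypothetical helper: shape to pass to torch.Tensor.expand for one leaf.

    new_batch_size is the *full* target batch size (as with
    torch.Tensor.expand); the leaf's non-batch dimensions are kept as-is.
    """
    if batch_ndim > len(leaf_shape):
        raise ValueError("leaf has fewer dims than the tensordict batch size")
    feature_dims = tuple(leaf_shape[batch_ndim:])
    return tuple(new_batch_size) + feature_dims


# Leaf "a" from the report: shape (3, 2), tensordict batch size [3].
print(leaf_expand_shape((3, 2), batch_ndim=1, new_batch_size=(2, 3)))  # (2, 3, 2)
```

With PyTorch, the result could then be passed straight to the leaf, e.g. tensor.expand(*leaf_expand_shape(tensor.shape, 1, (2, 3))), which yields the same [2, 3, 2] tensor as in the first example.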
Here is a list of the expected behaviours:
>>> t = TensorDict({"a": tensor}, [3])
>>> t.expand(2, 3)  # should return a tensordict of size [2, 3]. Currently the 2, 3 dimensions are prepended, resulting in size [2, 3, 3]
>>> t = TensorDict({"a": tensor[:1]}, [1])
>>> t.expand(2, 3)  # should return a tensordict of size [2, 3]. Currently the 2, 3 dimensions are prepended, resulting in size [2, 3, 1]
>>> t = TensorDict({"a": tensor[:2]}, [2])
>>> t.expand(2, 3)  # should raise an error: the last batch dim is 2, but the size requested for it by expand is 3
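The three cases above amount to applying the torch.Tensor.expand rules to the batch dimensions only: extra leading sizes become new dimensions (and may not be -1), while each existing dimension must either match, be 1, or be given as -1. A sketch of that rule in plain Python, assuming expand_batch_size as a hypothetical helper name:

```python
from typing import Sequence, Tuple


def expand_batch_size(batch_size: Sequence[int], *shape: int) -> Tuple[int, ...]:
    """Hypothetical helper: batch size that TensorDict.expand(*shape) should
    produce, following the torch.Tensor.expand rules."""
    batch_size = tuple(batch_size)
    n_new = len(shape) - len(batch_size)
    if n_new < 0:
        raise ValueError(
            f"expand needs at least {len(batch_size)} sizes, got {len(shape)}"
        )
    result = []
    # Leading sizes create new dimensions; -1 is not allowed there.
    for size in shape[:n_new]:
        if size < 0:
            raise ValueError("-1 is not allowed for new leading dimensions")
        result.append(size)
    # Trailing sizes are aligned with the existing batch dimensions.
    for old, size in zip(batch_size, shape[n_new:]):
        if size == -1 or size == old:
            result.append(old)  # keep the existing size
        elif size < 0:
            raise ValueError(f"invalid size {size}")
        elif old == 1:
            result.append(size)  # singleton dims can be broadcast
        else:
            raise ValueError(f"cannot expand dim of size {old} to {size}")
    return tuple(result)


print(expand_batch_size((3,), 2, 3))  # (2, 3)
print(expand_batch_size((1,), 2, 3))  # (2, 3)
try:
    expand_batch_size((2,), 2, 3)
except ValueError as err:
    print(err)  # cannot expand dim of size 2 to 3
```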
This will require some refactoring of the code, including the tutorials and tests. A quick look through all uses of .expand in the library is advised. In general, after the refactoring, existing calls to tensordict.expand(*new_dims) can simply be replaced by tensordict.expand(*new_dims, *tensordict.shape).
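Call sites that relied on the old prepend-only behaviour could be migrated with the rewrite rule described above. A minimal sketch, where expand_prepend and RecordingTD are hypothetical names; the stub does not expand anything, it only returns the sizes it receives so the rewrite can be checked:

```python
class RecordingTD:
    """Hypothetical stand-in for a TensorDict: expand just returns its
    arguments, so we can see what the migrated call would pass."""

    def __init__(self, *shape: int) -> None:
        self.shape = tuple(shape)

    def expand(self, *sizes: int):
        return sizes


def expand_prepend(td, *new_dims: int):
    """Reproduce the old behaviour on top of the new expand semantics:
    new_dims are placed in front of the existing batch dims."""
    return td.expand(*new_dims, *td.shape)


print(expand_prepend(RecordingTD(3), 2))        # (2, 3)
print(expand_prepend(RecordingTD(4, 5), 2, 3))  # (2, 3, 4, 5)
```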