[BugFix] Updated TensorDict.expand to work as Tensor.expand #409
Conversation
High-level comment: better to create a separate branch on your fork and not work on main :)
I don't mind merging this PR, but I wanted to change something to fix the CI and run the tests; unfortunately I can't do that if you're working on main.
Otherwise LGTM.
I was wondering if we should test that the expanded shape is compatible with the old one?
We'll get an error if it isn't anyway, but things could be clearer for the user if we indicate what has gone wrong before the expand operation is applied to the tensors.
Let's wait for the CI to pass (I'll ping you when we can make it work) and see if we want to implement this exception before merging.
I made a PR on your fork to fix the tests!
[BugFix] Temporarily fix gym to 0.25.1 to fix CI (pytorch#411)
Thanks for the feedback ... I can implement the exception in the expand function to check that the expanded shape is compatible with the old one. Also, I noticed 3 out of the 12 unit test cases for test_expand failing (for stacked_td, sub_td and sub_td2) because expand is overridden in the SubTensorDict and LazyStackedTensorDict classes, and the functionality seems to be a little different there. For example, the expand function in LazyStackedTensorDict first expands the tensordicts in the stack independently and then stacks them back. I'll ping you to understand the reason for these overrides.
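A rough sketch of the pattern described above (illustrative only; `lazy_stacked_expand` is a hypothetical stand-in for the real `LazyStackedTensorDict.expand`, demonstrated here with plain tensors in place of the stacked tensordicts):

```python
import torch

def lazy_stacked_expand(items, stack_dim, *shape):
    # Expand each element of the stack independently, then stack
    # them back along the original stacking dimension.
    expanded = [item.expand(*shape) for item in items]
    return torch.stack(expanded, dim=stack_dim)

# Two (3, 1) tensors standing in for the lazily stacked tensordicts.
parts = [torch.zeros(3, 1), torch.ones(3, 1)]
out = lazy_stacked_expand(parts, 0, 3, 2)
print(out.shape)  # torch.Size([2, 3, 2])
```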
@AnshulSehgal there is an expand that has not been corrected in the |
Done!
Description
Cause of issue #398:
Currently, when tensor.expand is called inside the function, the argument passed is the concatenation of the new shape and the shape of the tensor stored as the value in the key-value pair, and the batch_size of the resulting TensorDict is likewise the concatenation of the new shape and the current batch size.
Because of this, the TensorDict and the tensors stored as values are always expanded to a larger number of dimensions, with the new dimensions (matching the shape passed to the expand function) prepended at the front.
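To illustrate the difference with plain tensors (a hypothetical reproduction of the shapes involved, not the actual TensorDict code):

```python
import torch

t = torch.zeros(3, 1)

# Tensor.expand treats its argument as the *final* shape:
# singleton dimensions are broadcast, others must already match.
print(t.expand(3, 2).shape)  # torch.Size([3, 2])

# The old TensorDict.expand instead concatenated the requested shape
# in front of the existing one, roughly equivalent to:
new_shape = (4,)
print(t.expand(*new_shape, *t.shape).shape)  # torch.Size([4, 3, 1])
```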
Expected behavior after my change
The shape passed as an argument should be the final shape of the TensorDict after expansion, and the function is updated to handle the following scenarios:
a. Singleton dimension in the old shape: e.g. new_shape = [3, 2], old_shape = [3, 1]. In this case the expanded TensorDict should have shape [3, 2], and all tensors inside the dict should change shape from [3, 1, x, x] to [3, 2, x, x].
b. No singleton dimension in the old shape (and the sizes differ): should throw an error.
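The compatibility check discussed above could look roughly like this (a minimal sketch; the helper name and error messages are my own, not the actual implementation):

```python
def check_expand_shape(old_shape, new_shape):
    """Validate that old_shape can be expanded to new_shape,
    following torch.Tensor.expand-style broadcasting rules."""
    if len(new_shape) < len(old_shape):
        raise RuntimeError(
            f"the number of sizes provided ({len(new_shape)}) must be "
            f"greater or equal to the number of dimensions ({len(old_shape)})"
        )
    # Align trailing dimensions, as expand does; a dimension may only
    # change if the old size is 1 (a singleton dimension).
    for old, new in zip(reversed(old_shape), reversed(new_shape)):
        if old != 1 and old != new:
            raise RuntimeError(
                f"incompatible expanded shape: cannot expand size "
                f"{old} to {new}"
            )

check_expand_shape((3, 1), (3, 2))     # ok: singleton broadcast
check_expand_shape((3, 1), (4, 3, 2))  # ok: leading dim added
```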
Motivation and Context
Closes #398
Types of changes
Checklist