
Conversation

AnshulSehgal
Contributor

Description

  1. Updated TensorDict.expand to work like tensor.expand
  2. Updated the usage of TensorDict.expand in other files to pass (*new_shape, *tensordict.batch_size) as arguments instead of (*new_shape)
  3. Updated test cases and added a new test case for the singleton-dimension expand case
  4. Updated the tutorial IPython notebook to reflect the updated usage

Cause of issue #398:
Currently, when tensor.expand is called inside the function, the argument passed is the concatenation of the new shape and the shape of each tensor stored as a value in the key-value pairs, and the batch_size of the resulting TensorDict is likewise the concatenation of the new shape and the current batch size.
Because of this setup, the TensorDict and the tensors stored as values are always expanded to a larger number of dimensions, with the new ones (matching the shape passed to the expand function) prepended at the front.
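The old behavior can be sketched with a small shape-only helper (a hypothetical illustration, not the actual torchrl code; the function name is made up for this example):

```python
def old_expand_shapes(new_shape, batch_size, tensor_shape):
    """Sketch of the old behavior: the new dims are always prepended,
    both to each stored tensor's shape and to the batch_size."""
    expanded_tensor_shape = tuple(new_shape) + tuple(tensor_shape)
    expanded_batch_size = tuple(new_shape) + tuple(batch_size)
    return expanded_batch_size, expanded_tensor_shape

# A TensorDict with batch_size (3, 2) holding a (3, 2, 5) tensor,
# expanded with (4,), always grew to batch_size (4, 3, 2):
print(old_expand_shapes((4,), (3, 2), (3, 2, 5)))
# -> ((4, 3, 2), (4, 3, 2, 5))
```

So there was no way to express "expand this singleton dimension in place", which is what tensor.expand supports.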

Expected behavior after my change
The shape passed as an argument is the final shape of the TensorDict after expansion, and the function is updated to handle the following three scenarios:

  1. The new shape is the same as the old shape: e.g. new_shape = [3, 2], old_shape = [3, 2]. In this case nothing changes.
  2. The new shape has the same length as the old shape but different values. This has two sub-cases:
    a. A singleton dimension in old_shape: e.g. new_shape = [3, 2], old_shape = [3, 1]. The expanded TensorDict should have shape [3, 2], and every tensor inside the dict should also change shape from [3, 1, x, x] to [3, 2, x, x].
    b. No singleton dimension in old_shape: an error should be thrown.
  3. The new shape is longer than the old shape: e.g. new_shape = [3, 3, 2], old_shape = [3, 2]. The expanded TensorDict should have shape [3, 3, 2], and every tensor stored in the dict should have shape [3, 3, 2, x, x, ...]. Another example: new_shape = [3, 3, 2], old_shape = [1, 2]. The output is the same as above, except the second dimension also grows from 1 to 3 for the TensorDict and all tensors within it.
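The three scenarios above follow tensor.expand's right-aligned broadcasting rules. A minimal shape-resolution sketch (an illustrative stand-in, not the actual implementation; the function name is hypothetical):

```python
def expand_batch_size(old_shape, new_shape):
    """Sketch of tensor.expand-style shape resolution: new_shape is the
    final shape; extra dims may only be added at the front, and an
    existing dim may only change if it is a singleton (size 1)."""
    if len(new_shape) < len(old_shape):
        raise ValueError("new shape must have at least as many dims as the old one")
    # align from the right, like tensor.expand
    offset = len(new_shape) - len(old_shape)
    for i, old_dim in enumerate(old_shape):
        new_dim = new_shape[offset + i]
        if old_dim != new_dim and old_dim != 1:
            raise ValueError(
                f"cannot expand dim of size {old_dim} to {new_dim}: not a singleton"
            )
    return tuple(new_shape)

# scenario 1: same shape -> unchanged
print(expand_batch_size((3, 2), (3, 2)))     # -> (3, 2)
# scenario 2a: singleton dimension expands in place
print(expand_batch_size((3, 1), (3, 2)))     # -> (3, 2)
# scenario 3: new leading dims, singleton broadcast
print(expand_batch_size((1, 2), (3, 3, 2)))  # -> (3, 3, 2)
# scenario 2b: non-singleton mismatch raises ValueError
```

Raising the error here, before touching the stored tensors, is also the shape-compatibility check discussed in the review below.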

Motivation and Context

Closes #398, if this solves the issue.

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • Example (update in the folder of examples)

Checklist

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 8, 2022
Collaborator

@vmoens vmoens left a comment


High-level comment: better to create a separate branch on your fork and not work on main :)
I don't mind merging this PR, but I wanted to change something to fix the CI and run the tests; unfortunately I can't do that if you're working on main.

Otherwise LGTM

I was wondering if we should test that the expanded shape is compatible with the old one?
We'll get an error if it isn't anyway, but things could be clearer for the user if we indicate what went wrong before the expand operation is applied to the tensors.

Let's wait for the CI to pass (I'll ping you when we can make it work) and see if we want to implement this exception before merging.

@vmoens
Collaborator

vmoens commented Sep 9, 2022

I made a PR on your fork to fix the tests!

@vmoens added labels: bug (Something isn't working), bc breaking (backward-compatibility breaking change), quality (code quality) on Sep 9, 2022
[BugFix] Temporarily fix gym to 0.25.1 to fix CI (pytorch#411)
@AnshulSehgal
Contributor Author

Thanks for the feedback! I can implement the exception to check that the expanded shape is compatible with the old one in the expand function. I also noticed that 3 of the 12 unit test cases for test_expand fail (for stacked_td, sub_td and sub_td2), because the expand function is overridden in the SubTensorDict and LazyStackedTensorDict classes and the functionality seems a little different there. For example, the expand function in LazyStackedTensorDict first expands the tensordicts in the stack independently and then stacks them back. I'll ping you to understand the reason for these overrides.

@vmoens
Collaborator

vmoens commented Sep 9, 2022

@AnshulSehgal there is an expand that has not been corrected in the redq.py loss, can you patch that one too?

@AnshulSehgal
Contributor Author

Done!
