[functorch] Fix torch.cat batching rule #86932
Conversation
The bug was discovered in #86842. torch.cat has an edge case where it ignores all tensors of shape [0]. So if any of the BatchedTensors have logical shape [0] but physical shape [B, 0], then we coerce them to shape [0] by slicing them.

Why don't we just ignore those Tensors? We need to propagate requires_grad-ness somehow (e.g. if the BatchedTensor wraps a Tensor of shape [B, 0] that requires grad, then the output must require grad).

Test Plan:
- new tests
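For context, here is a minimal sketch of the edge case (illustrative only, not code from the PR; it assumes the functorch-era vmap API and the legacy torch.cat behavior in effect at the time, where 1-D empty tensors of shape [0] are skipped during shape checking):

```python
import torch
from functorch import vmap  # functorch-era API; torch.vmap in later releases

B = 4
x = torch.randn(B, 0, requires_grad=True)  # logical shape [0] per batch element
y = torch.randn(B, 2, 3)                   # logical shape [2, 3]

def f(x, y):
    # Per batch element this is torch.cat of a [0] tensor with a [2, 3] tensor:
    # the legacy rule skips the [0] tensor, so the result has shape [2, 3].
    return torch.cat([x, y])

out = vmap(f)(x, y)
print(out.shape)          # expected: torch.Size([4, 2, 3])
print(out.requires_grad)  # expected: True, matching eager-mode torch.cat
```

Slicing the physical [B, 0] tensor down to shape [0] (rather than dropping it from the list) keeps it in the autograd graph of the physical torch.cat call, which is how the requires_grad-ness of x reaches the output.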
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/86932
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 Failures, 1 Pending as of commit 167c1a6.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Should we also port this batching rule to not use the old API? (doesn't need to be done in this PR, just wondering).
Yes. The problem is that the vmap codegen doesn't handle operators that accept TensorList yet.
@pytorchbot merge -f "failures look unrelated (they occur in the non-functorch shards) (also there are no logs)"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.