New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[aten decomp] Update sdpa decom #108371
[aten decomp] Update sdpa decom #108371
Conversation
Summary: Earlier decomp was routing _flash* variant to _match variant and this was result in failure during torch.export, for some reason that I couldnt trace. However, it seems that we should really have a decomp for scaled_dot_product_attention, instead of scaled_dot_product_flash_attention. Right? This diff adds that. Plus it adds a test to check if the model exported via two stage export, has decomposed the op. This test needs improvement to figur eout what the core aten opset is and check for anything that is not inside. Test Plan: test_model_exports_to_core_aten Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/108371
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 49068e0 with merge base b9fc6d7 (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Summary: Earlier decomp was routing _flash* variant to _match variant and this was result in failure during torch.export, for some reason that I couldnt trace. However, it seems that we should really have a decomp for scaled_dot_product_attention, instead of scaled_dot_product_flash_attention. Right? This diff adds that. Plus it adds a test to check if the model exported via two stage export, has decomposed the op. This test needs improvement to figur eout what the core aten opset is and check for anything that is not inside. Test Plan: test_model_exports_to_core_aten Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 326253868dfe705e7f84ff62043a8d333139d5c1 Pull Request resolved: #108371
talked to @larryliu0820 and realized that this is wrong. Need a different solution |
Summary: Earlier decomp was routing _flash* variant to _match variant and this was result in failure during torch.export, for some reason that I couldnt trace. However, it seems that we should really have a decomp for scaled_dot_product_attention, instead of scaled_dot_product_flash_attention. Right? This diff adds that. Plus it adds a test to check if the model exported via two stage export, has decomposed the op. This test needs improvement to figur eout what the core aten opset is and check for anything that is not inside. Test Plan: test_model_exports_to_core_aten Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: Earlier decomp was routing _flash* variant to _match variant and this was result in failure during torch.export, for some reason that I couldnt trace. However, it seems that we should really have a decomp for scaled_dot_product_attention, instead of scaled_dot_product_flash_attention. Right? This diff adds that. Plus it adds a test to check if the model exported via two stage export, has decomposed the op. This test needs improvement to figur eout what the core aten opset is and check for anything that is not inside. Test Plan: test_model_exports_to_core_aten Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 22b36d764c6feacc7a5d51117009cab57f451092 Pull Request resolved: #108371
Summary: Earlier decomp was routing _flash* variant to _match variant and this was result in failure during torch.export, for some reason that I couldnt trace. However, it seems that we should really have a decomp for scaled_dot_product_attention, instead of scaled_dot_product_flash_attention. Right? This diff adds that. Plus it adds a test to check if the model exported via two stage export, has decomposed the op. This test needs improvement to figur eout what the core aten opset is and check for anything that is not inside. Test Plan: test_model_exports_to_core_aten Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: Earlier decomp was routing _flash* variant to _match variant and this was result in failure during torch.export, for some reason that I couldnt trace. However, it seems that we should really have a decomp for scaled_dot_product_attention, instead of scaled_dot_product_flash_attention. Right? This diff adds that. Plus it adds a test to check if the model exported via two stage export, has decomposed the op. This test needs improvement to figur eout what the core aten opset is and check for anything that is not inside. Test Plan: test_model_exports_to_core_aten Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 9f853418211822789fb20d8621bbf94d1a0772f1 Pull Request resolved: #108371
@@ -3993,7 +3993,7 @@ def scaled_dot_product_flash_attention( | |||
query, key, value, attn_mask, dropout_p, is_causal, None, scale=scale | |||
) | |||
return ( | |||
output, | |||
output.transpose(1, 2), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh wow this is so weird
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is indeed. flash attnetion, https://fburl.com/38pwyabk, does that while the other one does not. I am gonna check if i trigger non flash variant what happens. But surprising calls to contiguous, https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py#L5441, does not appear in trace. which is also very weird.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it a bug? It would be great to add this explanation in the comment. What's the relationship to decomposition statement in the summary?
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
Summary: Earlier decomp was routing _flash* variant to _match variant and this was result in failure during torch.export, for some reason that I couldnt trace. However, it seems that we should really have a decomp for scaled_dot_product_attention, instead of scaled_dot_product_flash_attention. Right? This diff adds that. Plus it adds a test to check if the model exported via two stage export, has decomposed the op. This test needs improvement to figur eout what the core aten opset is and check for anything that is not inside. Test Plan: test_model_exports_to_core_aten Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D48917461](https://our.internmc.facebook.com/intern/diff/D48917461) [ghstack-poisoned]
Summary: Earlier decomp was routing _flash* variant to _match variant and this was result in failure during torch.export, for some reason that I couldnt trace. However, it seems that we should really have a decomp for scaled_dot_product_attention, instead of scaled_dot_product_flash_attention. Right? This diff adds that. Plus it adds a test to check if the model exported via two stage export, has decomposed the op. This test needs improvement to figur eout what the core aten opset is and check for anything that is not inside. Test Plan: test_model_exports_to_core_aten Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 730e4790241d63376f8c0652f1f4289ce90b105a Pull Request resolved: #108371
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
Summary: Earlier decomp was routing _flash* variant to _match variant and this was result in failure during torch.export, for some reason that I couldnt trace. However, it seems that we should really have a decomp for scaled_dot_product_attention, instead of scaled_dot_product_flash_attention. Right? This diff adds that. Plus it adds a test to check if the model exported via two stage export, has decomposed the op. This test needs improvement to figur eout what the core aten opset is and check for anything that is not inside. Test Plan: test_model_exports_to_core_aten Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D48917461](https://our.internmc.facebook.com/intern/diff/D48917461) [ghstack-poisoned]
Summary: Earlier decomp was routing _flash* variant to _match variant and this was result in failure during torch.export, for some reason that I couldnt trace. However, it seems that we should really have a decomp for scaled_dot_product_attention, instead of scaled_dot_product_flash_attention. Right? This diff adds that. Plus it adds a test to check if the model exported via two stage export, has decomposed the op. This test needs improvement to figur eout what the core aten opset is and check for anything that is not inside. Test Plan: test_model_exports_to_core_aten Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 44a94e7bb6ec354833c4b5049311bf80fd17119b Pull Request resolved: #108371
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged) |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Stack from ghstack (oldest at bottom):
Summary:
Earlier decomp was routing _flash* variant to _match variant and this
was result in failure during torch.export, for some reason that I
couldnt trace.
However, it seems that we should really have a decomp for
scaled_dot_product_attention, instead of
scaled_dot_product_flash_attention. Right?
This diff adds that. Plus it adds a test to check if the model exported
via two stage export, has decomposed the op. This test needs improvement
to figur eout what the core aten opset is and check for anything that is
not inside.
Test Plan:
test_model_exports_to_core_aten
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: D48917461