
[aten decomp] Update sdpa decomp #108371

Closed
wants to merge 5 commits

Conversation

kimishpatel
Contributor

@kimishpatel kimishpatel commented Aug 31, 2023

Stack from ghstack (oldest at bottom):

Summary:
Earlier, the decomp was routing the _flash* variant to the _math variant, and this
was resulting in a failure during torch.export, for a reason I couldn't trace.

However, it seems that we should really have a decomp for
scaled_dot_product_attention, instead of
scaled_dot_product_flash_attention. Right?

This diff adds that. Plus it adds a test to check that the model exported
via two-stage export has decomposed the op. This test needs improvement
to figure out what the core ATen opset is and check for anything that is
not inside it.
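
For reference, the math-only formulation such a decomp ultimately bottoms out in is roughly the following. This is a minimal sketch of the standard softmax(Q Kᵀ · scale) V formula, not the exact decomposition added in this diff; mask and dropout handling are simplified:

```python
import math
import torch
import torch.nn.functional as F

def sdpa_math_sketch(query, key, value, attn_mask=None, dropout_p=0.0,
                     is_causal=False, scale=None):
    # Rough sketch of what a "math" decomposition of scaled_dot_product_attention
    # computes: softmax(Q @ K^T * scale) @ V, with optional masking and dropout.
    scale = scale if scale is not None else 1.0 / math.sqrt(query.size(-1))
    attn = query @ key.transpose(-2, -1) * scale
    if is_causal:
        L, S = query.size(-2), key.size(-2)
        causal = torch.ones(L, S, device=query.device).tril().bool()
        attn = attn.masked_fill(~causal, float("-inf"))
    if attn_mask is not None:
        attn = attn + attn_mask
    attn = F.softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = F.dropout(attn, p=dropout_p)
    return attn @ value
```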

Test Plan:
test_model_exports_to_core_aten
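
Roughly, the kind of check this test performs could look like the sketch below. The module, the allow-list, and the direct use of torch.export here are illustrative assumptions; the real test uses the two-stage export flow and the actual core ATen opset:

```python
import torch

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

# Hypothetical, incomplete allow-list; a real check would be driven by the
# core ATen opset definition.
CORE_ATEN_PREFIXES = ("aten.mul", "aten.bmm", "aten._softmax", "aten.expand", "aten.view")

def non_core_ops(graph_module: torch.fx.GraphModule):
    # Collect call_function nodes whose target does not match the allow-list.
    return sorted({
        str(node.target)
        for node in graph_module.graph.nodes
        if node.op == "call_function"
        and not str(node.target).startswith(CORE_ATEN_PREFIXES)
    })

q = k = v = torch.randn(1, 4, 16, 8)
exported = torch.export.export(Attn(), (q, k, v))
print(non_core_ops(exported.graph_module))  # would flag e.g. a surviving sdpa op
```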

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: D48917461

@pytorch-bot

pytorch-bot bot commented Aug 31, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/108371

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 49068e0 with merge base b9fc6d7:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

kimishpatel added a commit that referenced this pull request Aug 31, 2023
ghstack-source-id: 326253868dfe705e7f84ff62043a8d333139d5c1
Pull Request resolved: #108371
@kimishpatel
Contributor Author

Talked to @larryliu0820 and realized that this is wrong. Need a different solution.

kimishpatel added a commit that referenced this pull request Sep 1, 2023
ghstack-source-id: 22b36d764c6feacc7a5d51117009cab57f451092
Pull Request resolved: #108371
kimishpatel added a commit that referenced this pull request Sep 1, 2023
ghstack-source-id: 9f853418211822789fb20d8621bbf94d1a0772f1
Pull Request resolved: #108371
@@ -3993,7 +3993,7 @@ def scaled_dot_product_flash_attention(
         query, key, value, attn_mask, dropout_p, is_causal, None, scale=scale
     )
     return (
-        output,
+        output.transpose(1, 2),
Contributor

oh wow this is so weird

Contributor Author

It is indeed. Flash attention, https://fburl.com/38pwyabk, does that, while the other one does not. I am going to check what happens if I trigger the non-flash variant. But surprisingly, the calls to contiguous, https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py#L5441, do not appear in the trace, which is also very weird.
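
To make the layout point concrete, a tiny illustration (the (B, H, S, D) shape here is just an assumed example): transpose(1, 2) swaps the head and sequence dimensions, which is the layout difference the flash variant's output shows in the trace.

```python
import torch

# Hypothetical (B, H, S, D) tensor; transpose(1, 2) yields (B, S, H, D).
x = torch.randn(2, 4, 16, 8)
print(x.shape)                  # torch.Size([2, 4, 16, 8])
print(x.transpose(1, 2).shape)  # torch.Size([2, 16, 4, 8])
```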

Contributor

Is it a bug? It would be great to add this explanation in a comment. What's the relationship to the decomposition statement in the summary?

@kimishpatel
Contributor Author

@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

kimishpatel added a commit that referenced this pull request Sep 1, 2023
ghstack-source-id: 730e4790241d63376f8c0652f1f4289ce90b105a
Pull Request resolved: #108371
@kimishpatel
Contributor Author

@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

kimishpatel added a commit that referenced this pull request Sep 2, 2023
ghstack-source-id: 44a94e7bb6ec354833c4b5049311bf80fd17119b
Pull Request resolved: #108371
@kimishpatel
Contributor Author

@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Sep 3, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@facebook-github-bot facebook-github-bot deleted the gh/kimishpatel/178/head branch September 7, 2023 14:25
Labels
ciflow/inductor ciflow/trunk Merged
5 participants