-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Bootcamp] [AI Accelerators] Convert attention equation to attention call in torch.compile() #94729
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94729
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 791358f: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
This pull request was exported from Phabricator. Differential Revision: D43234635 |
CLA is signed now. I am a Meta employee in Bootcamp. |
/easycla |
…rch#94729) Summary: Pull Request resolved: pytorch#94729 Many attention-based models directly implement scaled dot product attention used by transformer models using component pytorch operators. Recognize these with FX rewrites and convert them to calls to sdpa(), such that the optimized implementation is chosen where applicable. Implementation Given that the usual scaled dot product attention equation references a shape variable ( because everything gets normalized by a dividing by sqrt(embedding_dim_size), the graph rewrite is not entirely trivial. It was neccessary to introduce a new function ( torch.nn.scale_factor_dot_product_attention ). which accepts a scale factor as parameter. This in turn allows to rewrite any pattern like torch.matmul(query, key.transpose(-2, -1)) .div(scale_factor) .softmax(dim=-1) .matmul(value) into a corresponding (potentially optimized) API call like scale_factor_dot_product_attention( query.contiguous(), key.contiguous(), value.contiguous(), attn_mask=None, dropout_p=0.0, is_causal=False, scale_factor=scale_factor, ) without knowing the scale factor or query and key shapes at compile time. As I would consider this rewriter to be an alpha-quality feature, a new **feature switch** was introduced at *torch._inductor.pattern_replace_scaled_dot_product_attention* which defaults to False torch.nn.scale_factor_dot_product_attention would be unneccessary if we would introduce an optional scale_factor parameter to the pre-existing torch.nn.scaled_dot_product_attention. I chose not to do this, as it would suggest that an optimized function would be used, regardless of the choice of scale_factor, and due to the higher risk of introducing a breaking change. Also, it should be possible to push the logic which is now in native code within torch.nn.functional.scale_factor_dot_product_attention into Python code. Not clear what the benefits of that would be, but maybe you prefer it that way. Test Plan: Added several new unit tests. These can be executed selectively by running buck2 test @//mode/dev-nosan caffe2/test/inductor:pattern_matcher -- test_scaled_dot_product_attention buck2 test @//mode/dev-nosan caffe2/test:test_transformers -- test_scale_factor_dot_product_attention Differential Revision: D43234635
9889cdc
to
1b33070
Compare
Update: Had to do a force-commit from the patch file after I was unable to merge via export from Meta-internal diff. |
…call in torch.compile() (pytorch#94729) Summary: Pull Request resolved: pytorch#94729 ### Goals Many attention-based models directly implement scaled dot product attention used by transformer models using component pytorch operators. Recognize these with FX rewrites and convert them to calls to sdpa(), such that the optimized implementation is chosen where applicable. Implementation Given that the usual scaled dot product attention equation references a shape variable ( because everything gets normalized by a dividing by sqrt(embedding_dim_size), the graph rewrite is not entirely trivial. It was neccessary to introduce a new function ( torch.nn.scale_factor_dot_product_attention ). which accepts a scale factor as parameter. This in turn allows to rewrite any pattern like torch.matmul(query, key.transpose(-2, -1)) .div(scale_factor) .softmax(dim=-1) .matmul(value) into a corresponding (potentially optimized) API call like scale_factor_dot_product_attention( query.contiguous(), key.contiguous(), value.contiguous(), attn_mask=None, dropout_p=0.0, is_causal=False, scale_factor=scale_factor, ) without knowing the scale factor or query and key shapes at compile time. ### Feature Switch As I would consider this rewriter to be an alpha-quality feature, a new **feature switch** was introduced at *torch._inductor.pattern_replace_scaled_dot_product_attention* which defaults to False ### Reviewer Notes: torch.nn.scale_factor_dot_product_attention would be unneccessary if we would introduce an optional scale_factor parameter to the pre-existing torch.nn.scaled_dot_product_attention. I chose not to do this, as it would suggest that an optimized function would be used, regardless of the choice of scale_factor, and due to the higher risk of introducing a breaking change. Also, it should be possible to push the logic which is now in native code within torch.nn.functional.scale_factor_dot_product_attention into Python code. Not clear what the benefits of that would be, but maybe you prefer it that way. Test Plan: Added several new unit tests. These can be executed selectively by running buck2 test @//mode/dev-nosan caffe2/test/inductor:pattern_matcher -- test_scaled_dot_product_attention buck2 test @//mode/dev-nosan caffe2/test:test_transformers -- test_scale_factor_dot_product_attention Differential Revision: D43234635 fbshipit-source-id: d5517f78a1e39088bb604d4e27468b574f5eea21
1b33070
to
2940b59
Compare
This pull request was exported from Phabricator. Differential Revision: D43234635 |
Patterns based on #94729 mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. Note that we don't seem to have kernels for all of these patterns, some get decomposed in the dispatcher. cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Patterns based on #94729 -- mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. ghstack-source-id: 8a47b3bdf3ad335c9a100cd91ba1b00e01fdb82a Pull Request resolved: #97741
Patterns based on #94729 mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. Note that we don't seem to have kernels for all of these patterns, some get decomposed in the dispatcher. cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Patterns based on #94729 mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. Note that we don't seem to have kernels for all of these patterns, some get decomposed in the dispatcher. cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Patterns based on #94729 -- mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. ghstack-source-id: 9d151de57eec17de6f70e66b0cd17eab2c9ccb5c Pull Request resolved: #97741
Patterns based on #94729 -- mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. ghstack-source-id: fe121935b26fc1a9bfae4ac91342d2d2d0389e0a Pull Request resolved: #97741
Patterns based on #94729 mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. Note that we don't seem to have kernels for all of these patterns, some get decomposed in the dispatcher. cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Patterns based on #94729 mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. Note that we don't seem to have kernels for all of these patterns, some get decomposed in the dispatcher. cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Patterns based on #94729 -- mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. ghstack-source-id: 7af7bed3e6fccb9b32ef7b2087316ada15e4ba7b Pull Request resolved: #97741
Patterns based on #94729 mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. Note that we don't seem to have kernels for all of these patterns, some get decomposed in the dispatcher. cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Patterns based on #94729 mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. Note that we don't seem to have kernels for all of these patterns, some get decomposed in the dispatcher. cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Patterns based on #94729 mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. Note that we don't seem to have kernels for all of these patterns, some get decomposed in the dispatcher. cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Patterns based on #94729 mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. Note that we don't seem to have kernels for all of these patterns, some get decomposed in the dispatcher. cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Patterns based on #94729 -- mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. ghstack-source-id: ccd7b0e6f0454ee6760058762a7c806ba0c6bf2d Pull Request resolved: #97741
Patterns based on #94729 mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. Note that we don't seem to have kernels for all of these patterns, some get decomposed in the dispatcher. cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Patterns based on #94729 mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. Note that we don't seem to have kernels for all of these patterns, some get decomposed in the dispatcher. cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Patterns based on #94729 -- mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. ghstack-source-id: 5e1675469721abcad30970ed0de0de2a65e8e837 Pull Request resolved: #97741
Patterns based on #94729 mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. Note that we don't seem to have kernels for all of these patterns, some get decomposed in the dispatcher. cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Patterns based on #94729 mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. Note that we don't seem to have kernels for all of these patterns, some get decomposed in the dispatcher. cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Patterns based on #94729 -- mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. ghstack-source-id: db9be06f974bb6adf43a8c2ace0f58f5da3a65c0 Pull Request resolved: #97741
Patterns based on #94729 mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. Note that we don't seem to have kernels for all of these patterns, some get decomposed in the dispatcher. cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Patterns based on #94729 mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. Note that we don't seem to have kernels for all of these patterns, some get decomposed in the dispatcher. cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Patterns based on #94729 -- mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. ghstack-source-id: 59e3d747059acb3ab1423cf3f8e8bad8687f8b0f Pull Request resolved: #97741
Patterns based on #94729 mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. Note that we don't seem to have kernels for all of these patterns, some get decomposed in the dispatcher. Pull Request resolved: #97741 Approved by: https://github.com/Chillee
Patterns based on pytorch#94729 mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. Note that we don't seem to have kernels for all of these patterns, some get decomposed in the dispatcher. Pull Request resolved: pytorch#97741 Approved by: https://github.com/Chillee
Patterns based on #94729 mainly as a forcing function for implementing joint graph replacements. Up until now, we had two places to do pattern matching 1) Pre-grad has janky infra (graph not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas. 2) Post-grad has good infra, but cant change grad formulas. This PR adds a third place to do pattern matching: the joint forward+backwards graph. The idea is to take the patterns and lower them to a joint graph and replace both the forwards+backwards before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization. Note that we don't seem to have kernels for all of these patterns, some get decomposed in the dispatcher. Pull Request resolved: #97741 Approved by: https://github.com/Chillee
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as |
Summary:
Goals
Many attention-based models directly implement scaled dot product attention used by transformer models using component pytorch operators. Recognize these with FX rewrites and convert them to calls to sdpa(), such that
the optimized implementation is chosen where applicable.
Implementation
Given that the usual scaled dot product attention equation references a shape variable ( because everything gets normalized by a dividing by sqrt(embedding_dim_size), the graph rewrite is not entirely trivial.
It requires a new scale_factor parameter in torch.nn.functional.scaled_dot_product_attention ( not part of this PR, see discussion #94729 (comment) about earlier version of this PR )
This in turn will allow to rewrite any pattern like
torch.matmul(query, key.transpose(-2, -1))
.div(scale_factor)
.softmax(dim=-1)
.matmul(value)
into a corresponding (potentially optimized) API call like
scaled_dot_product_attention(
query.contiguous(),
key.contiguous(),
value.contiguous(),
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale_factor=scale_factor,
)
without knowing the scale factor or query and key shapes at compile time.
Reviewer Notes:
As discussed with drisspg, this code would be ready for integration into torch._inductor.overrides once the scale_factor parameter is available on torch.nn.functional.scaled_dot_product_attention. Once that is done, the corresponding sections in the unit tests
may be activated again to verify not just the graph rewrite, but also numerical equivalence.
Followup Tasks:
Once torch.nn.functional.scaled_dot_product_attention has a scale_factor parameter, this rewrite could be activated by configuring it as the sdpa implementation, and the inactive code sections in the unit tests could be made active again.
Test Plan:
Added several new unit tests. These can be executed selectively by running ( Meta Internal Tooling )
buck2 test @//mode/dev-nosan caffe2/test/inductor:pattern_matcher -- test_sdpa_rewriter
Differential Revision: D43234635
cc @mcarilli @ptrblck @leslie-fang-intel @jgong5 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @soumith @yanboliang @anijain2305 @desertfire @mlazos