Fused attention patterns #97741
Patterns based on #94729, mainly as a forcing function for implementing joint graph replacements.

Up until now, we had two places to do pattern matching:
1) Pre-grad has janky infra (the graph is not normalized or functional), but is desirable for many types of passes where you want your change to affect grad formulas.
2) Post-grad has good infra, but can't change grad formulas.

This PR adds a third place to do pattern matching: the joint forward+backward graph. The idea is to take the patterns, lower them to a joint graph, and replace both the forward and backward parts before we partition them. This allows us to do something similar to pre-grad transforms, but run after normalization and functionalization.
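To make the ordering argument concrete, here is a toy sketch in plain Python. It is not the pattern matcher from this PR and does not use torch.fx: the `Node` class, the flat-list "graph", and op names such as `fused_attention` are all hypothetical stand-ins. It only shows that a replacement applied to the joint graph before partitioning rewrites both halves, while the same replacement applied after partitioning can only touch the forward half.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    op: str     # operator name (hypothetical, for illustration)
    phase: str  # "fwd" or "bwd"


# Hypothetical unfused attention as it might appear in a joint graph:
# forward ops followed by the matching backward ops.
PATTERN = [
    Node("matmul", "fwd"), Node("div", "fwd"),
    Node("softmax", "fwd"), Node("matmul", "fwd"),
    Node("matmul_bwd", "bwd"), Node("softmax_bwd", "bwd"),
    Node("div_bwd", "bwd"), Node("matmul_bwd", "bwd"),
]
REPLACEMENT = [Node("fused_attention", "fwd"), Node("fused_attention_bwd", "bwd")]

# A joint graph with one extra op on each side of the attention block.
JOINT = [Node("relu", "fwd")] + PATTERN + [Node("relu_bwd", "bwd")]


def replace_subsequence(graph, pattern, replacement):
    """Replace every contiguous occurrence of `pattern` in `graph`."""
    size = len(pattern)
    out, i = [], 0
    while i < len(graph):
        if graph[i:i + size] == pattern:
            out.extend(replacement)
            i += size
        else:
            out.append(graph[i])
            i += 1
    return out


def partition(graph):
    """Split a joint graph into forward and backward op lists."""
    fwd = [node.op for node in graph if node.phase == "fwd"]
    bwd = [node.op for node in graph if node.phase == "bwd"]
    return fwd, bwd


# Joint-graph style: replace first, then partition. Both halves change.
joint_fwd, joint_bwd = partition(replace_subsequence(JOINT, PATTERN, REPLACEMENT))

# Post-grad style: the backward is already fixed, only the forward is rewritten.
fwd_nodes = [node for node in JOINT if node.phase == "fwd"]
post_fwd = [n.op for n in replace_subsequence(fwd_nodes, PATTERN[:4], REPLACEMENT[:1])]
post_bwd = partition(JOINT)[1]
```

In this toy, `joint_bwd` contains the fused backward op, while `post_bwd` still holds the four unfused backward ops. Real graphs are DAGs rather than flat lists, so actual matching is considerably more involved.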
Seems fine to me - I think it's reasonable to add optimizations to the joint graph, although I wonder how many of these would be better done at the pre-dispatch IR level.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours).
Note that we don't seem to have kernels for all of these patterns; some get decomposed in the dispatcher.
Pull Request resolved: pytorch#97741
Approved by: https://github.com/Chillee
@jansel there appears to be a regression here with 'phantom compilations' popping up (dynamo is compiling extra stuff nobody asked for): #98778. Also, I was attempting earlier to bisect an apparent compile-time regression on the dashboard, but it was hard to bisect. This commit was in the 10-commit range between the bad and good commits, so I wonder about that too.
The extra calls to AotAutograd are intended in this PR. They should only happen on warmup and when pattern matching triggers. We should change logging to hide these calls by default.
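As a rough illustration of "hide these calls by default", the sketch below uses only Python's standard `logging` module. The logger name `example.pattern_warmup` and the function `trace_pattern_for_warmup` are made up for this example and are not PyTorch APIs. The idea is to log the warmup traces at DEBUG level on a dedicated logger whose default level is WARNING, so they only appear when someone opts in.

```python
import logging

# Hypothetical logger name, for illustration only.
log = logging.getLogger("example.pattern_warmup")


class ListHandler(logging.Handler):
    """Collects formatted messages in memory so they can be inspected."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())


handler = ListHandler()
log.addHandler(handler)
log.propagate = False
log.setLevel(logging.WARNING)  # default: warmup traces are hidden


def trace_pattern_for_warmup(name):
    """Stand-in for an internal trace that happens during warmup."""
    log.debug("tracing pattern %s for joint-graph matching", name)


trace_pattern_for_warmup("attention_pattern")
hidden = list(handler.messages)  # nothing recorded at the default level

log.setLevel(logging.DEBUG)      # opt in to see the internal traces
trace_pattern_for_warmup("attention_pattern")
shown = list(handler.messages)
```

Whether the real fix uses a separate logger, a lower log level, or a flag is a design choice this sketch does not settle.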
cc @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire