inductor: add input type check for fuse_attention #99296
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99296. Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit 0723f11. This comment was automatically generated by Dr. CI and updates every 15 minutes.
For TIMM `xcit_large_24_p8_224`, the scale factor is a tensor (https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/xcit.py#L205), and `scaled_dot_product_attention` doesn't support a tensor scale, so this PR adds a check that only does the fusion when the scale factor is a float/int value.
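As a rough sketch of the idea, reconstructed from the review hunks quoted below (the merged code may differ in detail): the scale-factor patterns get an extra check that inspects the matched graph and only allows the fusion when the scale is a plain Python scalar. `filter_nodes` and the bmm -> view -> mul(div) walk come from the diff; the import path and overall wiring are assumptions.

```python
# Hedged sketch of the scale-factor check, reconstructed from the review
# hunks below; the exact code that landed may differ.
import torch
from torch._inductor.pattern_matcher import filter_nodes  # assumed location of the helper

aten = torch.ops.aten

def _sfdp_scale_factor_check(scale_factor_op):
    def fn(match):
        # Walk bmm -> view -> mul(div) to find the node applying the scale factor.
        matmuls = filter_nodes(match.nodes, aten.bmm)
        if len(matmuls) < 2:
            return False
        view_node = list(matmuls[0].users.keys())[0]
        scale_factor_node = list(view_node.users.keys())[0]
        if scale_factor_node.target != scale_factor_op:
            return False
        if len(scale_factor_node.args) != 2:
            return False
        # scaled_dot_product_attention only supports a scalar scale, so skip
        # the fusion when the scale factor is a Tensor (e.g. xcit_large_24_p8_224).
        return isinstance(scale_factor_node.args[1], (int, float))
    return fn
```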
```python
        self._check_common(sfdp_pattern_6, contains=False)

    def test_pattern_fails_with_tensor_factor(self):
        # https://github.com/pytorch/pytorch/issues/99124
```
This issue was reported on the CPU side, but I checked and it also happens on the GPU side, so I added this case here.
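For reference, a minimal hedged sketch of such a test, assuming a module with a learnable tensor scale modeled on the linked xcit attention; `TensorScaleAttention` and `check_no_fusion` are hypothetical names used only for illustration, not the actual test in this PR.

```python
# Hedged sketch of a "tensor scale factor must not break compilation" check.
import torch

class TensorScaleAttention(torch.nn.Module):
    def __init__(self, head_dim):
        super().__init__()
        # Learnable (Tensor) scale factor, as in timm's xcit attention.
        self.scale = torch.nn.Parameter(torch.tensor(head_dim ** -0.5))

    def forward(self, q, k, v):
        attn = (q @ k.transpose(-2, -1)) * self.scale
        return attn.softmax(dim=-1) @ v

def check_no_fusion(device="cpu"):
    mod = TensorScaleAttention(64).to(device).eval()
    q, k, v = (torch.randn(2, 4, 16, 64, device=device) for _ in range(3))
    compiled = torch.compile(mod)
    with torch.no_grad():
        # The fusion should be skipped (the scale is a Tensor) and the compiled
        # output should still match eager mode.
        torch.testing.assert_close(
            compiled(q, k, v), mod(q, k, v), atol=2e-3, rtol=2e-3
        )
```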
```python
scale_factor_node = list(view_node.users.keys())[0]
if len(scale_factor_node.args) != 2:
    return False
# make sure the scale_factor is a float/int. SymInt?
```
Should we add a dynamic shape case to verify SymInt?
Yes, we may need to add a dynamic shape case, but the pattern doesn't match in the dynamic shape case yet. Let's do that as a next step.
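If the dynamic-shape patterns do start matching later, the type check could plausibly be widened to accept symbolic scalars as well; a minimal sketch of that direction (not part of this PR), using the public `torch.SymInt`/`torch.SymFloat` types:

```python
import torch

def _is_scalar_scale(scale_factor):
    # Hypothetical follow-up: besides plain Python scalars, also accept
    # symbolic scalars once the dynamic-shape patterns can match.
    return isinstance(scale_factor, (int, float, torch.SymInt, torch.SymFloat))
```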
```python
def _return_true(match):
    return True
```
I'm wondering if we could fold this function into `_sfdp_scale_factor_check`.
```python
# make sure the mul(div) for the scale factor is a scalar mul(div).
# bmm -> view -> mul(div)
matmuls = filter_nodes(match.nodes, aten.bmm)
if len(matmuls) < 2:
```
Shouldn't this already be checked by the pattern? Can you give an example and add a test for when this fails?
Yes, this is already checked by the pattern; I added it just to be safe. I will remove it to simplify the code.
```python
if len(matmuls) < 2:
    return False
if (
    len(matmuls[0].users) != 1
```
Shouldn't this already be checked by the pattern? Can you give an example and add a test for when this fails?
Removed this check.
```python
    return False
view_node = list(matmuls[0].users.keys())[0]
if (
    len(view_node.users) != 1
```
Shouldn't this already be checked by the pattern? Can you give an example and add a test for when this fails?
Removed this check.
```python
    or list(view_node.users.keys())[0].target != scale_factor_op
):
    return False
scale_factor_node = list(view_node.users.keys())[0]
```
Could you get here directly with `filter_nodes(match.nodes, scale_factor_op)`?
Yes, changed.
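In terms of the earlier sketch, the bmm -> view walk then collapses into a direct lookup, roughly (illustration only; not necessarily the exact final code):

```python
# Sketch of the simplified check: filter_nodes goes straight to the mul/div
# node instead of walking bmm -> view -> mul(div).
def fn(match):
    scale_factor_nodes = filter_nodes(match.nodes, scale_factor_op)
    if len(scale_factor_nodes) == 0 or len(scale_factor_nodes[0].args) != 2:
        return False
    return isinstance(scale_factor_nodes[0].args[1], (int, float))
```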
```python
    torch.randn(tensor_shape, device="cuda"),
    torch.randn(tensor_shape, device="cuda"),
]
with torch.no_grad():
```
Can you test training as well?
A training test has been added. Note that even without this PR, the training path works fine (the pattern doesn't match there), though it has an accuracy gap compared with eager mode.
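For reference, a hedged sketch of what exercising the training path could look like; `check_training` is a hypothetical helper, and the actual test added in the PR may be structured differently.

```python
# Hedged sketch of a training-path check: run the compiled module with grads
# enabled and compare outputs and input gradients against eager mode.
import torch

def check_training(mod, device="cpu"):
    q, k, v = (
        torch.randn(2, 4, 16, 64, device=device, requires_grad=True)
        for _ in range(3)
    )
    compiled = torch.compile(mod)

    ref = mod(q, k, v)
    ref.sum().backward()
    ref_grads = [t.grad.clone() for t in (q, k, v)]
    for t in (q, k, v):
        t.grad = None

    out = compiled(q, k, v)
    out.sum().backward()
    torch.testing.assert_close(out, ref, atol=2e-3, rtol=2e-3)
    for t, g in zip((q, k, v), ref_grads):
        torch.testing.assert_close(t.grad, g, atol=2e-3, rtol=2e-3)
```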
@jansel, please help review this PR again. Thanks!

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
For TIMM `xcit_large_24_p8_224`, the scale factor is a tensor (https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/xcit.py#L205), and `scaled_dot_product_attention` doesn't support a tensor scale, so this PR adds a check that only does the fusion when the scale factor is a float/int value.

cc @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire