
Conversation

@XiaobingSuper (Collaborator) commented on Apr 17, 2023

Stack from ghstack (oldest at bottom):

For the TIMM model `xcit_large_24_p8_224`, the scale factor is a tensor (https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/xcit.py#L205), which `scaled_dot_product_attention` doesn't support. This PR adds a check so the fusion is only done when the scale factor is a float/int value.
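As a rough illustration (not the PR's exact code; the helper name `_scale_is_scalar` is made up here), the added guard boils down to checking the scale factor's Python type before allowing the rewrite into `scaled_dot_product_attention`:

```python
# Minimal sketch with a hypothetical helper name: only allow the SDPA fusion
# when the pattern's scale factor is a plain Python number. A torch.Tensor
# scale (e.g. timm xcit's learnable temperature) is rejected.
def _scale_is_scalar(scale_factor) -> bool:
    return isinstance(scale_factor, (int, float))
```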

cc @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire


[ghstack-poisoned]
pytorch-bot (bot) commented on Apr 17, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99296

❗ 1 Active SEV

There is 1 currently active SEV; if your PR is affected, please check it on the HUD.

✅ No Failures

As of commit 0723f11:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.


self._check_common(sfdp_pattern_6, contains=False)

def test_pattern_fails_with_tensor_factor(self):
# https://github.com/pytorch/pytorch/issues/99124
Collaborator (Author):

This issue was reported on the CPU side, but I checked and it also happens on the GPU side, so the test case is added here as well.
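For context, a rough sketch of the failing scenario (the module and names below are illustrative, not the actual test in this PR): an attention block whose scale factor is a learnable tensor, as in timm's xcit.

```python
# Illustrative reproduction of the tensor-scale-factor case; not the PR's test.
import torch

class TensorScaleAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # The scale is a Tensor parameter, not a Python float, so the
        # SDPA fusion must be skipped for this graph.
        self.scale = torch.nn.Parameter(torch.ones(1))

    def forward(self, q, k, v):
        attn = (q @ k.transpose(-2, -1)) * self.scale
        return attn.softmax(dim=-1) @ v

q, k, v = (torch.randn(2, 4, 16, 16) for _ in range(3))
mod = TensorScaleAttention()
# With the check added in this PR, torch.compile falls back to the unfused
# graph instead of emitting an SDPA call with a tensor scale.
out = torch.compile(mod)(q, k, v)
```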

scale_factor_node = list(view_node.users.keys())[0]
if len(scale_factor_node.args) != 2:
return False
# make sure the scale_factor is a float/int. SymInt?
Collaborator:

Should we add a dynamic shape case to verify SymInt?

Collaborator (Author):

Yes, we may need to add a dynamic-shape case, but the patterns don't match for dynamic shapes yet. Let's do that as a next step.
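As a hedged side note on the SymInt question (not part of this PR), if dynamic shapes were handled, the scalar check could be widened to also accept symbolic numbers, for example:

```python
# Sketch only: broaden the scalar check to symbolic ints/floats under dynamic
# shapes. Whether SymInt/SymFloat should be accepted here is exactly the open
# question in this thread.
import torch

def _scale_is_scalar_or_symbolic(scale_factor) -> bool:
    return isinstance(scale_factor, (int, float, torch.SymInt, torch.SymFloat))
```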

Comment on lines +183 to +184
def _return_true(match):
return True
Collaborator:

I'm wondering if we could fold this function into `_sfdp_scale_factor_check`.

# make sure the mul(div) for the scale factor is scalar mul(div).
# bmm->view->mul(div)
matmuls = filter_nodes(match.nodes, aten.bmm)
if len(matmuls) < 2:
Contributor:

Shouldn't this already be checked by the pattern? Can you give an example and add a test for when this fails?

Collaborator (Author):

Yes, this is already checked by the pattern; I added it as an extra safety check. I will remove it to simplify the code.

if len(matmuls) < 2:
return False
if (
len(matmuls[0].users) != 1
Contributor:

Shouldn't this already be checked by the pattern? Can you give an example and add a test for when this fails?

Collaborator (Author):

Removed this check.

return False
view_node = list(matmuls[0].users.keys())[0]
if (
len(view_node.users) != 1
Contributor:

Shouldn't this already be checked by the pattern? Can you give an example and add a test for when this fails?

Collaborator (Author):

Removed this check.

or list(view_node.users.keys())[0].target != scale_factor_op
):
return False
scale_factor_node = list(view_node.users.keys())[0]
Contributor:

Could you go directly here with filter_nodes(match.nodes, scale_factor_op)?

Collaborator (Author):

Yes, changed.
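Putting the two review suggestions above together (folding `_return_true` into the check factory and looking the scale node up directly with `filter_nodes`), a hedged sketch of what the simplified extra check could look like; not necessarily the exact code that landed:

```python
# Sketch of a simplified extra-check factory. filter_nodes is the helper used
# in the diff above (from torch._inductor.pattern_matcher in recent PyTorch);
# the exact module path and the final shape of the PR's code may differ.
from torch._inductor.pattern_matcher import filter_nodes

def _sfdp_scale_factor_check(scale_factor_op=None):
    def fn(match):
        if scale_factor_op is None:
            # Pattern has no scale factor to validate; always allow the match.
            return True
        scale_nodes = filter_nodes(match.nodes, scale_factor_op)
        if not scale_nodes or len(scale_nodes[0].args) != 2:
            return False
        # args[1] is the scalar operand of the mul/div in the matched pattern.
        return isinstance(scale_nodes[0].args[1], (int, float))
    return fn
```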

torch.randn(tensor_shape, device="cuda"),
torch.randn(tensor_shape, device="cuda"),
]
with torch.no_grad():
Contributor:

Can you test training as well?

Collaborator (Author):

A training test is added. Please note that even without this PR, the training path works (the pattern doesn't match there), and it has an accuracy gap compared with eager mode.
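For reference, a rough sketch of exercising the training (backward) path with a plain float scale; illustrative only, not the actual test added in this PR:

```python
# Illustrative training-path exercise: compile a small attention function with
# a plain float scale and run a backward pass through it.
import torch

def attention(q, k, v, scale):
    attn = (q @ k.transpose(-2, -1)) * scale
    return attn.softmax(dim=-1) @ v

q, k, v = (torch.randn(2, 4, 16, 16, requires_grad=True) for _ in range(3))
compiled = torch.compile(attention)
out = compiled(q, k, v, 0.25)  # float scale: eligible for the SDPA fusion
out.sum().backward()           # also exercises the backward/training path
```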


XiaobingSuper added a commit that referenced this pull request Apr 18, 2023
ghstack-source-id: 27301c3
Pull Request resolved: #99296
@XiaobingSuper requested a review from jansel on April 18, 2023 03:03

XiaobingSuper added a commit that referenced this pull request Apr 18, 2023
ghstack-source-id: 0e95ee9
Pull Request resolved: #99296

@XiaobingSuper (Collaborator, Author):

@jansel, please review this PR again. Thanks!

@XiaobingSuper added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Apr 20, 2023
@XiaobingSuper (Collaborator, Author):

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@facebook-github-bot deleted the gh/XiaobingSuper/92/head branch on June 8, 2023 15:05


Development

Successfully merging this pull request may close these issues.

[Inductor] [CPU] scaled_dot_product_attention() unexpected a value type caused crash in xcit_large_24_p8_224
