Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[inductor] Make sure unfuse_addmm and addmm patterns don't overlap #110235

Closed
wants to merge 5 commits into from

Conversation

peterbell10
Copy link
Collaborator

@peterbell10 peterbell10 commented Sep 28, 2023

Stack from ghstack (oldest at bottom):

Inductor has two opposing patterns,

addmm -> add + mm
add + mm -> addmm

This uses the extra_check to disable the addmm fusion pattern when the
heuristic to unfuse add is met, for consistency.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler

@pytorch-bot
Copy link

pytorch-bot bot commented Sep 28, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110235

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit fab9c1e with merge base 419ec3b (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Inductor has two opposing patterns,
```
addmm -> add + mm
add + mm -> addmm
```

This uses the `extra_check` to disable the addmm fusion pattern when the
heuristic to unfuse add is met, for consistency.

[ghstack-poisoned]
peterbell10 added a commit that referenced this pull request Sep 28, 2023
Inductor has two opposing patterns,
```
addmm -> add + mm
add + mm -> addmm
```

This uses the `extra_check` to disable the addmm fusion pattern when the
heuristic to unfuse add is met, for consistency.

ghstack-source-id: 51dde8a7d748c44ce3e04d0f733dba3c14dc539a
Pull Request resolved: #110235
…verlap"

Inductor has two opposing patterns,
```
addmm -> add + mm
add + mm -> addmm
```

This uses the `extra_check` to disable the addmm fusion pattern when the
heuristic to unfuse add is met, for consistency.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
peterbell10 added a commit that referenced this pull request Sep 28, 2023
Inductor has two opposing patterns,
```
addmm -> add + mm
add + mm -> addmm
```

This uses the `extra_check` to disable the addmm fusion pattern when the
heuristic to unfuse add is met, for consistency.

ghstack-source-id: 19fb7d152103d3967d79f47ec9ca020705c3e19b
Pull Request resolved: #110235
… overlap"

Inductor has two opposing patterns,
```
addmm -> add + mm
add + mm -> addmm
```

This uses the `extra_check` to disable the addmm fusion pattern when the
heuristic to unfuse add is met, for consistency.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
peterbell10 added a commit that referenced this pull request Sep 29, 2023
Inductor has two opposing patterns,
```
addmm -> add + mm
add + mm -> addmm
```

This uses the `extra_check` to disable the addmm fusion pattern when the
heuristic to unfuse add is met, for consistency.

ghstack-source-id: 3f78107ecdb5d1a4df52b369d2e11241b845f8f8
Pull Request resolved: #110235
@peterbell10 peterbell10 marked this pull request as ready for review September 29, 2023 12:39
e1, e2 = fn(*args)
a1, a2 = torch.compile(fn)(*args)
torch.testing.assert_close(a1, e1)
torch.testing.assert_close(a2, e2)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
count, nodes = (2, 4) if should_fuse else (0, 0)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that these cases weren't actually fused previously, it's just that the pattern replaced them with a lowering that did add + mm.

Copy link
Collaborator

@lezcano lezcano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough. I think the code has a preexisting issue that we should fix tho.

torch/_inductor/fx_passes/post_grad.py Outdated Show resolved Hide resolved
… overlap"

Inductor has two opposing patterns,
```
addmm -> add + mm
add + mm -> addmm
```

This uses the `extra_check` to disable the addmm fusion pattern when the
heuristic to unfuse add is met, for consistency.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
peterbell10 added a commit that referenced this pull request Sep 29, 2023
Inductor has two opposing patterns,
```
addmm -> add + mm
add + mm -> addmm
```

This uses the `extra_check` to disable the addmm fusion pattern when the
heuristic to unfuse add is met, for consistency.

ghstack-source-id: 2bb268e61749862f7c32f6c7efb40673cf0189f3
Pull Request resolved: #110235
Copy link
Collaborator

@lezcano lezcano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even better

@peterbell10
Copy link
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 29, 2023
@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team Raised by workflow job

@peterbell10
Copy link
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Copy link
Contributor

@eellison eellison left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good just one comment about checking for input being a tensor

Comment on lines -462 to -470
def addmm(match, mat1, mat2, inp):
if isinstance(inp, ir.TensorBox):
inp_shape = inp.get_size()
matched = len(inp_shape) <= 2
mm_shape = shape_of_mm(mat1, mat2)
for i, m in zip(inp_shape, mm_shape):
matched &= i == 1 or i == m
else: # inp is a Number
matched = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice to move this away from graph lowering pattern, this was overdue..

return not should_prefer_unfused_addmm(match)


@register_graph_pattern(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @yanboliang @jansel we should have some sort of commutative concept that would avoid this duplication

Comment on lines 787 to 788
if not isinstance(inp, torch.fx.Node):
return False # Input is a number
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i've made this check before which was fixed by #108160, you can have a fx.Node input which is a SymInt/SymFloat

…atterns don't overlap"

Inductor has two opposing patterns,
```
addmm -> add + mm
add + mm -> addmm
```

This uses the `extra_check` to disable the addmm fusion pattern when the
heuristic to unfuse add is met, for consistency.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
peterbell10 added a commit that referenced this pull request Sep 29, 2023
Inductor has two opposing patterns,
```
addmm -> add + mm
add + mm -> addmm
```

This uses the `extra_check` to disable the addmm fusion pattern when the
heuristic to unfuse add is met, for consistency.

ghstack-source-id: f1ea08f1c736ada4dc88545dd0965753bbaf7bf2
Pull Request resolved: #110235
@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: New commits were pushed while merging. Please rerun the merge command.

Details for Dev Infra team Raised by workflow job

@peterbell10
Copy link
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@facebook-github-bot facebook-github-bot deleted the gh/peterbell10/625/head branch October 3, 2023 14:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants