[inductor] Make sure unfuse_addmm and addmm patterns don't overlap #110235
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110235
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit fab9c1e with merge base 419ec3b. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Inductor has two opposing patterns,

```
addmm -> add + mm
add + mm -> addmm
```

This uses the `extra_check` to disable the addmm fusion pattern when the heuristic to unfuse add is met, for consistency.
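To see why the `extra_check` guard matters, here is a minimal, self-contained sketch (not Inductor's real pattern-matcher API) of two opposing rewrite rules. The names `fuse_rule`, `unfuse_rule`, `apply_rules`, and `should_prefer_unfused` are hypothetical; the point is that gating the fusion rule on the same heuristic that drives unfusion makes repeated rewriting reach a fixed point instead of ping-ponging.

```python
def should_prefer_unfused(expr):
    # Toy stand-in for Inductor's unfuse heuristic: prefer add + mm
    # when the bias would need broadcasting.
    op, inp, mat1, mat2 = expr
    return inp == "broadcast_bias"

def fuse_rule(expr):
    # add + mm -> addmm, gated by the same heuristic (the "extra_check")
    op, inp, mat1, mat2 = expr
    if op == "add_mm" and not should_prefer_unfused(expr):
        return ("addmm", inp, mat1, mat2)
    return expr

def unfuse_rule(expr):
    # addmm -> add + mm, when the heuristic fires
    op, inp, mat1, mat2 = expr
    if op == "addmm" and should_prefer_unfused(expr):
        return ("add_mm", inp, mat1, mat2)
    return expr

def apply_rules(expr, n=4):
    # Repeatedly apply both rules; with the guard, this is stable.
    for _ in range(n):
        expr = unfuse_rule(fuse_rule(expr))
    return expr
```

Without the `not should_prefer_unfused(expr)` check in `fuse_rule`, an expression matching the heuristic would be fused and unfused on every pass.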
```diff
 e1, e2 = fn(*args)
 a1, a2 = torch.compile(fn)(*args)
 torch.testing.assert_close(a1, e1)
 torch.testing.assert_close(a2, e2)
-self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2)
-self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
+count, nodes = (2, 4) if should_fuse else (0, 0)
```
Note that these cases weren't actually fused previously; it's just that the pattern replaced them with a lowering that did add + mm.
Fair enough. I think the code has a preexisting issue that we should fix, though.
… overlap" Inductor has two opposing patterns, ``` addmm -> add + mm add + mm -> addmm ``` This uses the `extra_check` to disable the addmm fusion pattern when the heuristic to unfuse add is met, for consistency. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
Inductor has two opposing patterns, ``` addmm -> add + mm add + mm -> addmm ``` This uses the `extra_check` to disable the addmm fusion pattern when the heuristic to unfuse add is met, for consistency. ghstack-source-id: 2bb268e61749862f7c32f6c7efb40673cf0189f3 Pull Request resolved: #110235
Even better
@pytorchbot merge
Merge failed. Reason: This PR needs a label; if not, please add the label. To add a label, you can comment to pytorchbot. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Looks good, just one comment about checking for the input being a tensor.
```python
def addmm(match, mat1, mat2, inp):
    if isinstance(inp, ir.TensorBox):
        inp_shape = inp.get_size()
        matched = len(inp_shape) <= 2
        mm_shape = shape_of_mm(mat1, mat2)
        for i, m in zip(inp_shape, mm_shape):
            matched &= i == 1 or i == m
    else:  # inp is a Number
        matched = False
```
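The shape check above can be restated standalone: a bias is eligible for addmm fusion only if it is at most 2-D and each of its dimensions either equals the corresponding mm output dimension or is 1. This is a pure-Python sketch mirroring that logic; the function name `matches_addmm_bias` is hypothetical, and plain tuples stand in for `TensorBox` sizes and `shape_of_mm`.

```python
def matches_addmm_bias(inp_shape, mm_shape):
    # mm_shape is always 2-D (the output of mat1 @ mat2); the bias must be
    # at most 2-D, and each paired dimension must match or be 1.
    matched = len(inp_shape) <= 2
    for i, m in zip(inp_shape, mm_shape):
        matched &= i == 1 or i == m
    return matched
```

For example, a `(1, 4)` bias against a `(3, 4)` mm output passes, while a `(2, 4)` bias does not.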
Nice to move this away from the graph lowering pattern; this was overdue.
```python
    return not should_prefer_unfused_addmm(match)


@register_graph_pattern(
```
cc @yanboliang @jansel: we should have some sort of commutative concept that would avoid this duplication.
```python
if not isinstance(inp, torch.fx.Node):
    return False  # Input is a number
```
I've made this check before, and it was fixed by #108160: you can have an fx.Node input which is a SymInt/SymFloat.
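The pitfall above can be illustrated without torch: a graph node may wrap a symbolic scalar rather than a tensor, so checking the wrapper type alone is not enough; you also have to inspect what the node carries. `Node`, `FakeTensor`, and `SymInt` below are stand-in classes, not the real torch.fx types, and `is_tensor_input` is a hypothetical helper.

```python
class SymInt:        # stand-in for a symbolic scalar
    pass

class FakeTensor:    # stand-in for a tensor-valued node payload
    pass

class Node:
    """Stand-in for torch.fx.Node: carries its example value in .meta."""
    def __init__(self, val):
        self.meta = {"val": val}

def is_tensor_input(inp):
    # A bare Python number is never a tensor input.
    if not isinstance(inp, Node):
        return False
    # A Node may still wrap a SymInt/SymFloat, so check the payload too.
    return isinstance(inp.meta["val"], FakeTensor)
```

With only the `isinstance(inp, Node)` check, `Node(SymInt())` would be misclassified as a tensor input.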
Merge failed. Reason: New commits were pushed while merging. Please rerun the merge command. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
- #110232 cast_to_fp64

Inductor has two opposing patterns,

```
addmm -> add + mm
add + mm -> addmm
```

This uses the `extra_check` to disable the addmm fusion pattern when the heuristic to unfuse add is met, for consistency.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler