-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
inductor: rewrite mkldnn fx fusion using pattern_matcher(conv_unary) #97007
Conversation
[ghstack-poisoned]
…onv_unary)" cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
…onv_unary)" cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/97007
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 186a588: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
ghstack-source-id: 779777fad1d385b53a45524b12c3401f787277d5 Pull Request resolved: #97007
ghstack-source-id: ae74b1f32376fe8941ea8c8e3e01e3f0f45bd470 Pull Request resolved: #97007
ghstack-source-id: e5f15609e22d0171b3945f0f27aff7039170815e Pull Request resolved: #97007
…onv_unary)" cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
ghstack-source-id: b6ab144c346a478f87e8f7b77cf35ea55ea05500 Pull Request resolved: #97007
…onv_unary)" cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
ghstack-source-id: 9484d4056c62d3cd1cc3c0e501c4410977f589ee Pull Request resolved: #97007
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
torch/_inductor/mkldnn.py
Outdated
@@ -138,7 +130,7 @@ def _update_module_params(self, conv, unary, input_size): | |||
|
|||
def _conv_forward(self, input, weight, bias): | |||
if self.padding_mode != "zeros": | |||
return torch.ops.mkldnn._convolution_pointwise( | |||
return torch.mkldnn_convolution( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add a single conv
and linear for the repacked conv
. The semantic of _conv_pointwise
is confusing as it serves both the single conv and the fused conv.
torch/_inductor/pattern_matcher.py
Outdated
return L[aten.clamp_max]( | ||
L[aten.clamp_min](conv_out, min_value), max_value | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the decompose version of hardtanh
, right? Is there any way to call the decomposed function directly but not duplicate it again.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't find a way to do it.
…onv_unary)" cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
…onv_unary)" cc jgong5 mingfeima sanchitintel ashokei jingxu10 soumith voznesenskym penguinwu anijain2305 EikanWang Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
…onv_unary)" cc jgong5 mingfeima sanchitintel ashokei jingxu10 soumith voznesenskym penguinwu anijain2305 EikanWang Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Rebased |
…onv_unary)" cc jgong5 mingfeima sanchitintel ashokei jingxu10 soumith voznesenskym penguinwu anijain2305 EikanWang Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
…onv_unary)" cc jgong5 mingfeima sanchitintel ashokei jingxu10 soumith voznesenskym penguinwu anijain2305 EikanWang Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
torch/_inductor/pattern_matcher.py
Outdated
@@ -608,6 +608,279 @@ def addmm(match, mat1, mat2, inp): | |||
return L[aten.add](inp, L[aten.mm](mat1, mat2)) | |||
|
|||
|
|||
# TODO(XiaobingSuper): move it to fx_passes/mkldnn_fuse.py |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am ok with this stack of changes, but please move them into a separate file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can create a fx_passes/
folder to match with #97741
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, move it to a separate file.
torch/_inductor/pattern_matcher.py
Outdated
def relu_fusion(computation_call): | ||
return CallFunction(aten.relu, computation_call) | ||
|
||
def sigmoid_fusion(computation_call): | ||
return CallFunction(aten.sigmoid, computation_call) | ||
|
||
def tanh_fusion(computation_call): | ||
return CallFunction(aten.tanh, computation_call) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can match multiple functions in a single pattern like:
def relu_fusion(computation_call): | |
return CallFunction(aten.relu, computation_call) | |
def sigmoid_fusion(computation_call): | |
return CallFunction(aten.sigmoid, computation_call) | |
def tanh_fusion(computation_call): | |
return CallFunction(aten.tanh, computation_call) | |
def combined_fusion(computation_call): | |
return CallFunction([aten.relu, aten.sigmoid, aten.tanh], computation_call) |
Then figure out which one you hit with the match
arg to the lowering.
This should allow you to have fewer patterns throughout the stack.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, it is simplified.
…onv_unary)" cc jgong5 mingfeima sanchitintel ashokei jingxu10 soumith voznesenskym penguinwu anijain2305 EikanWang Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
…onv_unary)" cc jgong5 mingfeima sanchitintel ashokei jingxu10 soumith voznesenskym penguinwu anijain2305 EikanWang Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
…onv_unary)" cc jgong5 mingfeima sanchitintel ashokei jingxu10 soumith voznesenskym penguinwu anijain2305 EikanWang Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
@jansel @desertfire, please help review this code again. Thanks! |
…onv_unary)" cc jgong5 mingfeima sanchitintel ashokei jingxu10 soumith voznesenskym penguinwu anijain2305 EikanWang Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
for pattern, computation_op in zip(_hardtanh_patterns, computation_ops): | ||
_register_hardtanh_fusion_lowering(pattern, computation_op) | ||
|
||
def _mkldnn_fusion_init(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def _mkldnn_fusion_init(): | |
@functools.lru_cache(None) | |
def _mkldnn_fusion_init(): |
So this code only runs once.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed.
if torch._C.has_mkldnn: | ||
from .mkldnn_fusion import _mkldnn_fusion_init | ||
|
||
_mkldnn_fusion_init() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move this code into a lazy_init()
function with an lru_cache(None)
decorator so it only runs once.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved the code to lazy_init(). Thanks!
…onv_unary)" cc jgong5 mingfeima sanchitintel ashokei jingxu10 soumith voznesenskym penguinwu anijain2305 EikanWang Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
…onv_unary)" cc jgong5 mingfeima sanchitintel ashokei jingxu10 soumith voznesenskym penguinwu anijain2305 EikanWang Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: Command
Details for Dev Infra teamRaised by workflow job |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…97007) Pull Request resolved: #97007 Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jansel
Stack from ghstack (oldest at bottom):
cc @jgong5 @mingfeima @sanchitintel @ashokei @jingxu10 @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire