inductor: support conv+binary folding for freezing path #105048
Conversation
@eellison, I wonder if we could remove the conv_bn folding pass by directly calling binary folding several times, to avoid re-tracing the aot_model.
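To illustrate the idea, here is a minimal sketch (under assumed shapes, not the actual inductor trace): inference-mode batch norm decomposes into a chain of elementwise binary ops against constant tensors, so applying conv+binary folding once per op would absorb the whole normalization, matching what the dedicated conv_bn pass does:

```python
import torch

conv = torch.nn.Conv2d(3, 8, 3, bias=True).eval()
bn = torch.nn.BatchNorm2d(8).eval()
with torch.no_grad():  # give the BN non-trivial constants
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 1.5)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-1, 1)

x = torch.randn(1, 3, 16, 16)

# In eval mode, bn(y) = (y - mean) * rsqrt(var + eps) * gamma + beta:
# a sub, two muls, and an add, each against a constant tensor, so each
# step is itself a foldable conv+binary pattern.
mean = bn.running_mean.view(1, -1, 1, 1)
inv_std = (bn.running_var.view(1, -1, 1, 1) + bn.eps).rsqrt()
gamma = bn.weight.view(1, -1, 1, 1)
beta = bn.bias.view(1, -1, 1, 1)

decomposed = (conv(x) - mean) * inv_std * gamma + beta
torch.testing.assert_close(decomposed, bn(conv(x)), rtol=1e-4, atol=1e-5)
```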
Still need to take a look at this, but FYI there's a similar pass in JIT; it may be worth checking that all of the same conditions are checked.
Can we reuse these tests?
6971149#diff-533101d7d55b7513c4966fcda7f2554d9a7e19b82d54c6b00c6db59686d0e18cR1428-R1510
```python
_binary_ops = [aten.add.Tensor, aten.sub.Tensor, aten.mul.Tensor, aten.div.Tensor]
_computation_calls = [CallFunction(aten.convolution.default, *_conv_args)]


def _is_constant_node(node):
```
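For readers unfamiliar with the helper, here is a plausible sketch of what a constant-node predicate checks on a frozen FX graph; the name matches the excerpt above, but the body is illustrative, not the PR's exact implementation:

```python
import torch.fx as fx

def _is_constant_node(node: fx.Node) -> bool:
    # After freezing, parameters/buffers are lifted into the graph as
    # get_attr nodes, so their tensor values are known at compile time
    # and a binary op against them can be folded into the conv weights.
    return node.op == "get_attr"
```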
Yes, the checks follow the JIT pass.
Cool! Would you mind copying over some of the comments and reusing the tests? And maybe moving this to a separate file?
I think it'd be cool to remove the batch norm pass and replace it with this.
Done: I copied over some of the comments and reused the tests. The scalar case isn't supported yet, because it needs `aten.full` to create a new tensor, and that tensor can't be folded (https://github.com/pytorch/pytorch/blob/main/torch/_inductor/freezing.py#L105). I will open another PR to remove the batch norm pass.
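To make the scalar limitation concrete, a hedged sketch (module and shapes are illustrative): folding `conv(x) + 2.0` would first require materializing the scalar as a tensor via `aten.full`, and `aten.full` outputs are exactly what the freezing constant folder at the linked line refuses to fold; a constant-tensor operand needs no such materialization:

```python
import torch

conv = torch.nn.Conv2d(3, 8, 3, bias=True)
x = torch.randn(1, 3, 16, 16)

# Scalar case the pass skips: 2.0 would have to become a tensor first
# (conceptually aten.full), and that tensor can't be constant-folded
# on the freezing path.
y_scalar = conv(x) + 2.0

# Tensor case the pass handles: `other` is already a constant tensor,
# so it can be folded straight into conv.bias.
other = torch.full((1, 8, 1, 1), 2.0)
y_tensor = conv(x) + other
torch.testing.assert_close(y_scalar, y_tensor)
```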
Yeah, sounds good!
Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
test/inductor/test_binary_folding.py (outdated)
```python
inp = torch.rand(inps).to(self.device)
out_eager = mod_eager(inp)
out_optimized = out_optimized(inp)
self.assertEqual(out_optimized, out_eager, atol=2e-04, rtol=1e-5)
```
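For context, a self-contained sketch of the test pattern above (the module, shapes, and the `torch._inductor.config.freezing` switch are assumptions; the actual harness in test/inductor/test_binary_folding.py parameterizes ops, dtypes, and devices):

```python
import torch
import torch._inductor.config as inductor_config

class ConvAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.other = torch.nn.Parameter(torch.randn(1, 8, 1, 1))

    def forward(self, x):
        return self.conv(x) + self.other

inductor_config.freezing = True  # enable the freezing path
mod_eager = ConvAdd().eval()
with torch.no_grad():  # freezing treats params as constants under no_grad
    out_optimized = torch.compile(mod_eager)
    inp = torch.rand(1, 3, 16, 16)
    out_eager = mod_eager(inp)
    out_optimized = out_optimized(inp)  # compiled callable -> output
torch.testing.assert_close(out_optimized, out_eager, atol=2e-4, rtol=1e-5)
```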
The CPU path works fine even with the default tolerance, but the GPU path seems to have a large variance. My local machine (A10) can't reproduce it; I need to find another machine that can.
Skipped the test when cuDNN is enabled; it passes when the cuDNN convolution is disabled.
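A hedged sketch of that workaround (illustrative only, requires a CUDA device; the real test uses the suite's own helpers): `torch.backends.cudnn.flags` is a context manager that toggles cuDNN, forcing ATen's native fallback convolution kernels, whose numerics the comparison then tracks within the default tolerance:

```python
import torch

conv = torch.nn.Conv2d(3, 8, 3, device="cuda")
x = torch.randn(1, 3, 16, 16, device="cuda")

with torch.backends.cudnn.flags(enabled=False):
    out_fallback = conv(x)  # ATen's native (non-cuDNN) conv kernel
out_cudnn = conv(x)

# The two can differ slightly; that gap is the variance the comment
# above is working around.
print((out_fallback - out_cudnn).abs().max())
```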
@pytorchbot merge

Merge failed. Reason: Not merging any PRs at the moment because there is a merge blocking https://github.com/pytorch/pytorch/labels/ci:%20sev issue open. Details for Dev Infra team: raised by workflow job.

@pytorchbot merge -f "unrelated failures"

Merge started: your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov