[Inductor] GEMM shape padding improvements #118522
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/118522
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit baf3b9a with merge base 68c3cb7.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from eac9bc0 to d9c820f.
@kadeng has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Cool! Would you add a few tests?
```python
    size=[batchsize, m, n],
    stride=[n * m, n, 1],
)
if use_cutlass_template(fake_layout):
```
Do you still want to pad bandwidth-bound mms in Cutlass? I found that for small matmuls it wasn't profitable to pad.
We're not going to use Cutlass for small matmuls; in practice there will be a size threshold (in terms of MNK) below which it won't be used. This is how I did it here (Meta-internal link: https://fb.workplace.com/notes/347669491400529/) and arrived at whole-model speedups of up to 14%.
And bandwidth-bound large matmuls are definitely what I want the padding to be applied to; these can be sped up considerably by Cutlass. This has a benefit on average because the padding can often be fused with other (previous) ops by Triton (such that no memory-IO overhead remains from the padding) or be done via no-op memory reinterpretation (e.g. transpose, etc.). In addition, the Cutlass speedup can be so large that it would speed things up even if the above were not the case.
See also this (really large) PR where the threshold is introduced: #118416
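A minimal sketch of the size-threshold idea (the helper name and threshold value are illustrative assumptions, not the actual code from #118416):

```python
# Illustrative sketch only: gate Cutlass (and thus unconditional padding) on GEMM size,
# so small, bandwidth-bound matmuls keep using the default backends and are not padded here.
def big_enough_for_cutlass(m: int, n: int, k: int, min_mnk: int = 2**21) -> bool:
    # Compare the MNK product against a tunable threshold; below it, skip Cutlass.
    return m * n * k >= min_mnk
```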
> This has a benefit on average because the padding can often be fused with other (previous) ops by Triton (such that no memory-IO overhead remains from the padding) or be done via no-op memory reinterpretation (e.g. transpose, etc.).

For both of these cases we should be able to analyze the graph to know that they are going to fire... A general TODO in this file is to skip the padding on a tensor during benchmarking when it comes from a fusable operator.
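A hedged sketch of that TODO; the op set and helper name are assumptions for illustration, not Inductor's actual analysis:

```python
# Illustrative only: when benchmarking whether padding pays off, treat the pad as
# (nearly) free if the mm input comes from a pointwise op that Triton could fuse it into.
import torch
import torch.fx

POINTWISE_OPS = {
    torch.ops.aten.relu.default,
    torch.ops.aten.add.Tensor,
    torch.ops.aten.mul.Tensor,
}

def padding_likely_fused(input_node: torch.fx.Node) -> bool:
    # A pointwise producer usually means the pad's extra memory traffic can be absorbed.
    return input_node.op == "call_function" and input_node.target in POINTWISE_OPS
```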
```python
elif n_padded_length == 0 and m_padded_length != 0:
    m_padded_length = 0
```
For my own understanding: if A[M, K] @ B[K, N] and only dimension M needs padding, do we consider this matmul as not needing padding? Or do you actually mean that dimension M is not worth padding?
Right, because A[M, K] @ B[K, N] -> [M, N], and this is fine: since N will be the last dimension, M won't have to meet alignment requirements (assuming a row-major output layout, which will then certainly be picked by the matmul backend if M is not aligned).
So, in other words, we only need K and either M or N to be aligned.
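A minimal sketch of that rule; the helper names and alignment value are illustrative, not the actual pad_mm functions:

```python
# Illustrative sketch: for A[M, K] @ B[K, N] with a row-major output, only K and
# one of M or N need to be aligned, so padding the remaining dimension is skipped.
def padded_length(dim: int, align: int = 8) -> int:
    # Elements needed to round `dim` up to a multiple of `align` (0 if already aligned).
    rem = dim % align
    return 0 if rem == 0 else align - rem

def plan_padding(m: int, k: int, n: int, align: int = 8):
    m_pad, k_pad, n_pad = (padded_length(d, align) for d in (m, k, n))
    if m_pad == 0 and n_pad != 0:
        # M already aligned: an explicitly transposed GEMM avoids having to pad N.
        n_pad = 0
    elif n_pad == 0 and m_pad != 0:
        # N already aligned: the row-major [M, N] output imposes no requirement on M.
        m_pad = 0
    return m_pad, k_pad, n_pad
```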
Force-pushed from d9c820f to 64fb65a.
@kadeng has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Force-pushed from 64fb65a to 4b5cd88.
@kadeng has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Just did that. I focused on testing the padding logic itself, not the pattern matching etc., which is unchanged.
Force-pushed from 4b5cd88 to a9ba985.
Improvements to shape padding logic in torch/_inductor/pad_mm.py

Most notably:
* Enable shape padding for Cutlass
* Add a flag to always pad shapes
* Use the aten.const_pad_nd operation to pad tensors in a single op instead of using multiple steps involving intermediate buffers
* Make many paddings unnecessary when either the M or N dimension is properly aligned but the other is not (configurable, on by default)

Updates:
* Addressed reviewer comments
* Removed the config setting to only pad the K dimension
* Added detailed unit tests in test/inductor/test_pad_mm.py
Force-pushed from a9ba985 to baf3b9a.
@kadeng has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@pytorchbot merge -f 'Landed internally'
(Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)
Improvements to shape padding logic in torch/_inductor/pad_mm.py

These changes could lead to up to 14% perf improvement for certain Meta-internal models in experiments.

Most notably:
1. Use the aten.const_pad_nd operation to pad tensors in a single op instead of using multiple steps involving intermediate buffers. This appears to be more performant than the previous logic, confirmed by profiling and benchmarking results (Meta-internal).
2. Make many paddings unnecessary by using an explicitly transposed GEMM when either the M or N dimension is properly aligned but the other is not, configurable via config.shape_pad_use_transpose (default: True).
3. Enable shape padding for the Inductor CUDA / Cutlass backend for all GEMM ops where Cutlass would be enabled, without benchmarking in that case.
* Add a config flag to always pad shapes (without benchmarking first), configurable via config.force_shape_pad (default: False)
* Added several new unit tests to ensure tensors are padded such that they meet all alignment requirements after padding

Pull Request resolved: #118522
Approved by: https://github.com/jansel, https://github.com/eellison
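A minimal sketch of the single-op padding from point 1 above; the shapes, pad amounts, and helper name are illustrative, and this is not the actual pad_mm implementation:

```python
# Illustrative sketch: pad both dimensions of a GEMM operand with a single
# constant_pad_nd call instead of chaining cats through intermediate buffers.
import torch

def pad_operand(a: torch.Tensor, m_pad: int, k_pad: int) -> torch.Tensor:
    # constant_pad_nd takes pads in reverse dimension order:
    # (last-dim left, last-dim right, second-to-last left, second-to-last right, ...)
    return torch.ops.aten.constant_pad_nd(a, [0, k_pad, 0, m_pad], 0.0)

a = torch.randn(127, 250)
a_padded = pad_operand(a, m_pad=1, k_pad=6)
assert a_padded.shape == (128, 256)
```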
This reverts commit cc46829. Reverted #118522 on behalf of https://github.com/eellison due to regressing the HF benchmarks by ~4-5% ([comment](#118522)).
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as
Superseded by #119578
Relanding just the pad-in-a-single-pass portion of [the pr](#118522), not including the transpose logic. This was previously accepted and reviewed.

Pull Request resolved: #125773
Approved by: https://github.com/shunting314
ghstack dependencies: #125772