[CUDA][cuBLAS] Aten GEMM overload for FP32 output from FP16/BF16 inputs #150812
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150812
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 3dc77af with merge base daf2ccf.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Attention! native_functions.yaml was changed.
If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.
@@ -1375,6 +1375,16 @@
  SparseCUDA: bmm_out_sparse_cuda
  SparseCsrCUDA: bmm_out_sparse_csr_cuda

- func: bmm.out_dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor
can you also add optional compute_dtype args?
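For context, a minimal usage sketch of the new overload, assuming the codegen exposes `bmm.out_dtype` as an `at::bmm` overload taking a trailing `ScalarType` (the frontend surface isn't shown in this diff). An optional `compute_dtype`, as suggested above, would presumably become an additional optional `ScalarType` in the same schema.

#include <ATen/ATen.h>

// Hypothetical usage sketch (overload shape assumed from the schema above, not
// taken from generated headers): fp16 batched inputs, fp32 output requested via
// out_dtype; cuBLAS already accumulates these GEMMs in fp32.
at::Tensor bmm_fp32_out_sketch() {
  auto a = at::randn({8, 64, 128}, at::device(at::kCUDA).dtype(at::kHalf));
  auto b = at::randn({8, 128, 32}, at::device(at::kCUDA).dtype(at::kHalf));
  return at::bmm(a, b, /*out_dtype=*/at::kFloat);  // result dtype: float32
}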
aten/src/ATen/native/cuda/Blas.cpp (Outdated)

Tensor out = at::empty({batch1_sizes[0], batch1_sizes[1], batch2_sizes[2]}, batch1.options().dtype(out_dtype));
{
  NoNamesGuard guard;
  baddbmm_out_cuda_impl(out, out, batch1, batch2, beta, alpha, out_dtype);
At this point you no longer need the `out_dtype` argument; you can get the dtype from `out`.
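A minimal sketch of that suggestion, assuming the impl receives the already-allocated `out`; the function name and the check below are illustrative, not the PR's actual code.

#include <ATen/ATen.h>

// Illustrative only: once `out` has been allocated with the requested dtype,
// the impl can recover the dtype from the tensor instead of taking out_dtype.
static void baddbmm_impl_sketch(at::Tensor& out,
                                const at::Tensor& batch1,
                                const at::Tensor& batch2) {
  const at::ScalarType out_dtype = out.scalar_type();  // derived from out
  TORCH_CHECK(out_dtype == at::kFloat || out_dtype == batch1.scalar_type(),
              "out_dtype must match the input dtype or be float32");
  // ... dispatch to the cuBLAS path keyed on out_dtype ...
}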
aten/src/ATen/cuda/CUDABlas.cpp (Outdated)

    (void*)alpha_ptr, a, rocblas_datatype_f16_r, (int)lda, stridea,
    b, rocblas_datatype_f16_r, (int)ldb, strideb,
    (void*)beta_ptr, c, rocblas_datatype_f16_r, (int)ldc, stridec,
    c, rocblas_datatype_f16_r, (int)ldc, stridec,
Seems the datatype should be changed here as well? Or should we error out earlier for ROCm and not have an implementation at all?
Erroring out here is probably best; will do that.
aten/src/ATen/cuda/CUDABlas.cpp (Outdated)

@@ -617,6 +627,67 @@ void bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
#endif // USE_ROCM
}

template <>
void bgemm_internal_cublas<at::Half, float>(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::Half, float)) {
This looks like a copy-paste of the `<at::Half>` specialization with the only difference being the C dtype; perhaps factor it out into a helper function and set the C dtype with `std::is_same`:
template <>
void bgemm_internal_cublas<at::Half, at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
  bgemm_internal_cublas_helper<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
}

template <>
void bgemm_internal_cublas<at::Half, float>(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::Half, float)) {
  bgemm_internal_cublas_helper<float>(CUDABLAS_BGEMM_ARGS(at::Half));
}
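A minimal sketch of that factoring, under assumptions: the helper name, raw-pointer argument list, and transpose settings are illustrative, and the real code goes through the CUDABLAS_BGEMM_* macros. Only the cuBLAS C datatype is derived from the template parameter via std::is_same; everything else is shared.

#include <type_traits>
#include <cublas_v2.h>
#include <cuda_fp16.h>

// Hypothetical helper: fp16 A/B inputs, fp16 or fp32 C, fp32 accumulation.
template <typename CDtype>
void bgemm_fp16_inputs_sketch(cublasHandle_t handle, int m, int n, int k,
                              const __half* a, int lda, long long stridea,
                              const __half* b, int ldb, long long strideb,
                              CDtype* c, int ldc, long long stridec,
                              int num_batches) {
  static_assert(std::is_same_v<CDtype, __half> || std::is_same_v<CDtype, float>,
                "this sketch only supports fp16 or fp32 output");
  constexpr cudaDataType_t c_type =
      std::is_same_v<CDtype, float> ? CUDA_R_32F : CUDA_R_16F;
  const float alpha = 1.0f, beta = 0.0f;  // fp32 scalars for CUBLAS_COMPUTE_32F
  cublasStatus_t status = cublasGemmStridedBatchedEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
      &alpha,
      a, CUDA_R_16F, lda, stridea,
      b, CUDA_R_16F, ldb, strideb,
      &beta,
      c, c_type, ldc, stridec,
      num_batches,
      CUBLAS_COMPUTE_32F,   // accumulate in fp32 for both output dtypes
      CUBLAS_GEMM_DEFAULT);
  (void)status;             // error handling omitted in this sketch
}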
aten/src/ATen/cuda/CUDABlas.cpp (Outdated)

    0, flag)));
#else
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
Does this codepath make sense as an inclusion? For `half` it's a slightly esoteric option to do accumulation in `half` as well, for SKUs that have limited `float` accumulation performance. For `float` output, even if it worked, it would functionally be no different than a simple cast on top if the accumulation was done in `half`, IIUC.
Good point. I don't know, maybe just for API uniformity, even if it doesn't make practical sense.
aten/src/ATen/cuda/CUDABlas.cpp (Outdated)

    flag)));
#else
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
Same situation as above: does it make sense to add an output dtype as a param when accumulation was done in `half`?
Half accum for half inputs was something that was supported before, so I think leaving it in is fine, unless there is an explicit reason to remove it. As far as supporting an output dtype with `half` accum goes, even if the functionality can be simplified, I think it would be better to keep it as opposed to overriding the output dtype argument. Another possibility, if we don't want to support this functionality, is to throw an error? cc @ngimel for thoughts as well.
I think if cublas itself allows this combination, there's no reason to throw an error for it. If on the other hand cublas doesn't support it then throwing an error seems more reasonable than coming up with workarounds.
@ngimel @eqy will get this combination to throw an error, as fp16 accum with fp32 output isn't actually supported in cuBLAS based on the docs: https://docs.nvidia.com/cuda/cublas/#cublasgemmstridedbatchedex
I'm not sure I follow. As @ngimel pointed out, the unfortunately named `allowFP16Accumulation` and `allowBF16Accumulation` flags actually control the math mode, and if they are `true`, we would set e.g. `cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION)` (https://docs.nvidia.com/cuda/cublas/#cublassetmathmode). Since the math mode allows both, it would be surprising if it raised an unsupported error.
`allowFP16AccumulationCuBLAS` will set the compute type to `CUBLAS_COMPUTE_16F`, and according to the docs only an fp16 Ctype is compatible with that, so we should throw an error if the user-requested Ctype is fp32. `allowFP16ReductionCuBLAS` sets the math mode and should be compatible with different Ctypes. So I think @PaulZhang12 is on the right track.
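A minimal sketch of the guard being discussed, with assumed placement, naming, and message (not the PR's exact code): reject an fp32 `out_dtype` when the fp16-accumulation flag would force `CUBLAS_COMPUTE_16F`, since cuBLAS only pairs that compute type with an fp16 C matrix.

#include <ATen/ATen.h>

// Hypothetical check inside the fp16-input GEMM entry point.
inline void check_fp32_out_supported_sketch(at::ScalarType out_dtype) {
  TORCH_CHECK(
      !(out_dtype == at::kFloat &&
        at::globalContext().allowFP16AccumulationCuBLAS()),
      "FP32 output from FP16 inputs is not supported with FP16 accumulation");
}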
Enable FP32 output from FP16/BF16 GEMMs in aten with cuBLAS. Accumulation for these GEMMs is generally already done in FP32. Adds the functionality to the following aten operators:
* mm
* bmm
* addmm
* baddbmm

Follow-up of customer issue: #146241 (comment)
The following commit message is from PR #150654, which builds on this PR:

As a result of adding subgraph as a choice to Inductor (#149761) and enabling FP32 output from PyTorch GEMMs with FP16/BF16 inputs (#150812), this PR enables decompose_k as an autotuning choice for Inductor in generating the fastest matmuls with Triton. DecomposeK is currently only enabled for `torch.compile`.

Followups:
* decompose_k does not currently support epilogue fusion, which will take some work to enable
* Enable autotuning the bmm with Triton templates as well without requiring tons more compile time, via async compilation. Anecdotal evidence shows that Triton BMM usually performs better than aten BMM
* Add for addmm
* Enable for inference and AOTI

Below are the results of running TritonBench for split-K shapes, comparing aten performance versus pt2_triton, which now autotunes on decompose_k: >10% speedup over aten on average, and for some shapes over 3x the performance of the previous best Triton mm:

<img width="929" alt="Screenshot 2025-04-28 at 9 15 39 PM" src="https://github.com/user-attachments/assets/27d85bbc-4f3a-43a6-a8fa-d4a5bbb8c999" />

TorchInductor Benchmark Dashboard:

<img width="1727" alt="Screenshot 2025-04-30 at 2 02 53 PM" src="https://github.com/user-attachments/assets/4acd7ffc-407f-4cfd-98bb-2e3d8b1f00b3" />

We see speedups across all runs for training. Compile time increased as expected, with more `mm` options to tune over.

Differential Revision: [D73820115](https://our.internmc.facebook.com/intern/diff/D73820115)
Pull Request resolved: #150654
Approved by: https://github.com/eellison
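For readers unfamiliar with decompose_k, a rough sketch of the idea under simplifying assumptions (illustrative only, not Inductor's actual decomposition): the K dimension is split into chunks, the chunks are multiplied as a batched matmul, and the partial products are reduced. With this PR, the bmm partials can be produced directly in fp32 from fp16/bf16 inputs, which avoids losing precision in the final reduction.

#include <ATen/ATen.h>

// Hypothetical decompose_k-style mm: split K into `kparts`, bmm over the
// splits, then sum the partial products. Shapes: a is [m, k], b is [k, n].
at::Tensor decompose_k_mm_sketch(const at::Tensor& a, const at::Tensor& b,
                                 int64_t kparts) {
  const auto m = a.size(0), k = a.size(1), n = b.size(1);
  TORCH_CHECK(k % kparts == 0, "k must be divisible by kparts in this sketch");
  auto a3 = a.reshape({m, kparts, k / kparts}).transpose(0, 1).contiguous();  // [kparts, m, k/kparts]
  auto b3 = b.reshape({kparts, k / kparts, n});                               // [kparts, k/kparts, n]
  auto partial = at::bmm(a3, b3);  // [kparts, m, n]; fp32 output possible via this PR's overload
  return partial.sum(0);           // reduce over the K splits
}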
Stack from ghstack (oldest at bottom):
Enable FP32 output from FP16/BF16 GEMMs in aten with cuBLAS. Accumulation for these GEMMs is generally already done in FP32. Adds the functionality to the following aten operators:
* mm
* bmm
* addmm
* baddbmm
Follow up of customer issue: #146241 (comment)
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov
Differential Revision: D73126191