
[CUDA][cuBLAS] Aten GEMM overload for FP32 output from FP16/BF16 inputs #150812


Closed
wants to merge 27 commits

Conversation

Contributor

@PaulZhang12 PaulZhang12 commented Apr 8, 2025

Stack from ghstack (oldest at bottom):

Enable FP32 output from FP16/BF16 GEMMs in aten with cuBLAS. Accumulation for these GEMMs is generally already done in FP32. This adds the functionality to the following aten operators:

  • mm
  • bmm
  • addmm
  • baddbmm

Follow-up to customer issue: #146241 (comment)
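
For context, a minimal sketch of what calling one of the new overloads might look like from Python. The `out_dtype` overload name is taken from the native_functions.yaml hunk reviewed below and is an assumption about the eventual surface, not a settled public API.

```python
import torch

a = torch.randn(8, 64, 128, device="cuda", dtype=torch.bfloat16)
b = torch.randn(8, 128, 32, device="cuda", dtype=torch.bfloat16)

# Today: cuBLAS accumulates in FP32, but the result is rounded back to BF16.
out_bf16 = torch.bmm(a, b)

# With this PR: keep the FP32 accumulator and return an FP32 tensor directly.
out_fp32 = torch.ops.aten.bmm.out_dtype(a, b, torch.float32)
assert out_fp32.dtype == torch.float32
```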

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

Differential Revision: D73126191


pytorch-bot bot commented Apr 8, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150812

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3dc77af with merge base daf2ccf:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

PaulZhang12 added a commit that referenced this pull request Apr 8, 2025
ghstack-source-id: 0bbd38c
Pull Request resolved: #150812
Contributor

github-actions bot commented Apr 8, 2025

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.



PaulZhang12 added a commit that referenced this pull request Apr 9, 2025
ghstack-source-id: 375228f
Pull Request resolved: #150812
@@ -1375,6 +1375,16 @@
SparseCUDA: bmm_out_sparse_cuda
SparseCsrCUDA: bmm_out_sparse_csr_cuda

- func: bmm.out_dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor
Collaborator


can you also add optional compute_dtype args?

Tensor out = at::empty({batch1_sizes[0], batch1_sizes[1], batch2_sizes[2]}, batch1.options().dtype(out_dtype));
{
NoNamesGuard guard;
baddbmm_out_cuda_impl(out, out, batch1, batch2, beta, alpha, out_dtype);
Collaborator


At this point you no longer need the out_dtype argument; you can get the dtype from out.

PaulZhang12 added a commit that referenced this pull request Apr 10, 2025
ghstack-source-id: 702dde1
Pull Request resolved: #150812
@PaulZhang12 PaulZhang12 changed the title from "FP32 output GEMM from FP16/BF16 inputs" to "[CUDA][cuBLAS] Aten GEMM overload for FP32 output from FP16/BF16 inputs" Apr 10, 2025
@PaulZhang12 PaulZhang12 marked this pull request as ready for review April 10, 2025 17:50
@PaulZhang12 PaulZhang12 added the topic: not user facing label Apr 10, 2025
PaulZhang12 added a commit that referenced this pull request Apr 10, 2025
ghstack-source-id: 64f87d8
Pull Request resolved: #150812
(void*)alpha_ptr, a, rocblas_datatype_f16_r, (int)lda, stridea,
b, rocblas_datatype_f16_r, (int)ldb, strideb,
(void*)beta_ptr, c, rocblas_datatype_f16_r, (int)ldc, stridec,
c, rocblas_datatype_f16_r, (int)ldc, stridec,
Collaborator


Seems the datatype should be changed here as well? Or should we error out earlier for ROCm and not have an implementation at all?

Contributor Author


Erroring out here is probably best, will do that before

@@ -617,6 +627,67 @@ void bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
#endif // USE_ROCM
}

template <>
void bgemm_internal_cublas<at::Half, float>(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::Half, float)) {
Collaborator


This looks like a copy-paste of the <at::Half> specialization, with the only difference being the C dtype; perhaps factor it out into a helper function and set the C dtype with std::is_same.

template <>
void bgemm_internal_cublas<at::Half, at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
  bgemm_internal_cublas_helper<at::Half>(CUDABLAS_BGEMM_ARGS_AND_C_DTYPE(at::Half, at::Half));
}

template <>
void bgemm_internal_cublas<at::Half, float>(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::Half, float)) {
  bgemm_internal_cublas_helper<float>(CUDABLAS_BGEMM_ARGS(at::Half));
}

PaulZhang12 added a commit that referenced this pull request Apr 11, 2025
ghstack-source-id: c530ae1
Pull Request resolved: #150812
PaulZhang12 added a commit that referenced this pull request Apr 11, 2025
ghstack-source-id: 4c3ef15
Pull Request resolved: #150812
PaulZhang12 added a commit that referenced this pull request Apr 11, 2025
ghstack-source-id: 7b383b9
Pull Request resolved: #150812
0, flag)));
#else
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
Collaborator

@eqy eqy Apr 11, 2025


Does this codepath make sense as an inclusion? For half, it's a slightly esoteric option to also do accumulation in half, for SKUs that have limited float accumulation performance. For float output, even if it worked, it would functionally be no different than a simple cast on top if the accumulation was done in half, IIUC.

Collaborator


Good point, I don't know, maybe just for API uniformity, even if it doesn't make practical sense.
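
A rough numerical illustration of the point above, assuming `torch.backends.cuda.matmul.allow_fp16_accumulation` is the Python-level switch for allowFP16AccumulationCuBLAS (an assumption, not quoted from this PR): if the accumulator is FP16, an FP32 output can never hold more information than upcasting the FP16 result.

```python
import torch

a = torch.randn(256, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 256, device="cuda", dtype=torch.float16)

torch.backends.cuda.matmul.allow_fp16_accumulation = True
fp16_accum = a @ b                # FP16 inputs, FP16 accumulation, FP16 output
as_fp32 = fp16_accum.float()      # the most an FP32 "output" could return here

torch.backends.cuda.matmul.allow_fp16_accumulation = False  # restore the default
fp32_ref = a.float() @ b.float()  # FP32 accumulation reference
print((as_fp32 - fp32_ref).abs().max())  # shows the error introduced by FP16 accumulation
```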

flag)));
#else
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
Collaborator


Same situation as above, does it make sense to add output dtype as a param when accumulation was done in half?

Contributor Author


Half accum for half inputs was something that was supported before, so I think leaving it in is fine, unless there is an explicit reason to remove it?

As far as supporting an output dtype with half accum, even if the functionality could be simplified, I think it would be better to keep it rather than overriding the output dtype argument. Another possibility, if we don't want to support this, is to throw an error? cc @ngimel for thoughts as well

Collaborator


I think if cublas itself allows this combination, there's no reason to throw an error for it. If on the other hand cublas doesn't support it then throwing an error seems more reasonable than coming up with workarounds.

Contributor Author


@ngimel @eqy will make this combination throw an error, as FP16 accumulation with FP32 output isn't actually supported in cuBLAS based on the docs: https://docs.nvidia.com/cuda/cublas/#cublasgemmstridedbatchedex

Collaborator


I'm not sure I follow. As @ngimel pointed out, the unfortunately named allowFP16Accumulation and allowBF16Accumulation flags actually control the math mode, and if they are true we would set e.g. cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION) (https://docs.nvidia.com/cuda/cublas/#cublassetmathmode). Since I assume that allows both, it would be surprising if it raised an unsupported error.

Collaborator


allowFP16AccumulationCuBLAS will set the compute type to CUBLAS_COMPUTE_16F, and according to the docs only an fp16 Ctype is compatible with that, so we should throw an error if the user-requested Ctype is fp32. allowFP16ReductionCuBLAS sets the math mode and should be compatible with different Ctypes. So I think @PaulZhang12 is on the right track.
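
For reference, a minimal sketch of how the two knobs discussed above differ at the Python level. The attribute names are assumptions about the frontend switches for allowFP16ReductionCuBLAS and allowFP16AccumulationCuBLAS, not something introduced by this PR.

```python
import torch

# Math-mode knob: FP16 GEMMs still accumulate in FP32, but cuBLAS may use
# reduced-precision (split) reductions. This stays compatible with an FP32 C type.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True

# Compute-type knob: switches the GEMM to CUBLAS_COMPUTE_16F. Per the cuBLAS docs
# only an FP16 Ctype is supported there, so combining it with the new FP32
# out_dtype is the case this PR makes error out.
torch.backends.cuda.matmul.allow_fp16_accumulation = True
```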

PaulZhang12 added a commit that referenced this pull request Apr 14, 2025
ghstack-source-id: ad86b5c
Pull Request resolved: #150812
PaulZhang12 added a commit that referenced this pull request Apr 14, 2025
ghstack-source-id: 6a8c970
Pull Request resolved: #150812
pytorchmergebot pushed a commit that referenced this pull request May 3, 2025
As a result of adding subgraph as a choice to Inductor (#149761) and enabling FP32 output from PyTorch GEMMs with FP16/BF16 inputs (#150812), this PR enables decompose_k as an autotuning choice for Inductor when generating the fastest matmuls with Triton. DecomposeK is currently only enabled for `torch.compile`.

Followups:
* decompose_k does not currently support epilogue fusion, which will take some work to enable
* Enable autotuning the bmm with Triton templates as well, without requiring much more compile time (async compilation). Anecdotal evidence shows that Triton bmm usually performs better than aten bmm
* Add for addmm
* Enable for Inference and AOTI

Below are the results of running TritonBench on split-K shapes, comparing aten performance against pt2_triton, which now autotunes over decompose_k. On average we see a >10% speedup over aten, and for some shapes over 3x the performance of the previous best Triton mm:

<img width="929" alt="Screenshot 2025-04-28 at 9 15 39 PM" src="https://github.com/user-attachments/assets/27d85bbc-4f3a-43a6-a8fa-d4a5bbb8c999" />

TorchInductor Benchmark Dashboard:
<img width="1727" alt="Screenshot 2025-04-30 at 2 02 53 PM" src="https://github.com/user-attachments/assets/4acd7ffc-407f-4cfd-98bb-2e3d8b1f00b3" />

We see speedups across all runs for training. Compile time increased as expected, with more `mm` options to tune over.

Differential Revision: [D73820115](https://our.internmc.facebook.com/intern/diff/D73820115)
Pull Request resolved: #150654
Approved by: https://github.com/eellison
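
For readers unfamiliar with decompose_k, here is a rough sketch of the idea behind the merge commit above, using a hypothetical `decompose_k_mm` helper for illustration: split the K dimension into chunks, compute the partial products as a bmm, and reduce in FP32. The FP32-output GEMMs added in this PR are what let the intermediate bmm keep its FP32 accumulator for FP16/BF16 inputs; the real Inductor implementation generates this via a subgraph choice, not this exact code.

```python
import torch

def decompose_k_mm(a: torch.Tensor, b: torch.Tensor, k_splits: int) -> torch.Tensor:
    # a: (m, k), b: (k, n); requires k % k_splits == 0
    m, k = a.shape
    _, n = b.shape
    assert k % k_splits == 0
    # Regroup K into k_splits contiguous chunks and batch over them.
    a_split = a.reshape(m, k_splits, k // k_splits).transpose(0, 1)  # (k_splits, m, k/k_splits)
    b_split = b.reshape(k_splits, k // k_splits, n)                  # (k_splits, k/k_splits, n)
    # Stand-in for the FP32-output bmm from this PR: accumulate and keep FP32.
    partial = torch.bmm(a_split.float(), b_split.float())
    return partial.sum(dim=0)  # reduce the K-split partial products in FP32

a = torch.randn(128, 16384, device="cuda", dtype=torch.bfloat16)
b = torch.randn(16384, 128, device="cuda", dtype=torch.bfloat16)
ref = a.float() @ b.float()
torch.testing.assert_close(decompose_k_mm(a, b, 8), ref, rtol=1e-3, atol=1e-2)
```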
@github-actions github-actions bot deleted the gh/PaulZhang12/10/head branch May 28, 2025 02:20
Labels
ciflow/inductor, ciflow/trunk, Merged, module: inductor, release notes: python_frontend, topic: not user facing