[Inductor CUTLASS backend] Step 5: Gemm CUTLASS templates #108015
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/108015
Note: Links to docs will display an error until the docs builds have completed.
✅ You can merge normally! (3 unrelated failures) As of commit 746185f with merge base f9a250c:
- BROKEN TRUNK - The following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
- UNSTABLE - The following jobs failed, but likely due to flakiness present on trunk, and have been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: d6120b12bc05f952ca042b6c2a8431836e965db8 Pull Request resolved: #108015
log.debug(
    "Got cutlass configs: total number of ops: %s, "
    "total number of 3x ops: %s, total number of 2x ops: %s",
    str(len(res)),
Just use %d formatting instead of str().
torch/_inductor/kernel/mm.py
Outdated
    choices,
    op=op,
)
log.debug("Added %s cutlass gemm configs.", str(len(ops)))
Suggested change:
- log.debug("Added %s cutlass gemm configs.", str(len(ops)))
+ log.debug("Added %d cutlass gemm configs.", len(ops))
Hi @atalman, tests are broken because the cutlass python scripts are not installed in the torch package used in unit tests. Do you know how to package the cutlass python scripts correctly in a torch package? Thanks!
This is step 5 of adding cutlass as an alternative inductor backend. Feature request: #106991. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned]
ghstack-source-id: bde629beeab60212233adeeb1335f3fff78551c8 Pull Request resolved: #108015
# TODO: Enable dynamic test cases when dynamic support is added.
@unittest.skipIf(not SM75OrLater, "need sm_75")
@parametrize("dynamic", (False,))
@parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen, Triton, CUTLASS"))
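For context, stacked `@parametrize` decorators (from torch's internal test utilities) expand into the cross-product of their parameter values, so the stack above yields two test cases. That expansion can be sketched as:

```python
from itertools import product

dynamic_values = (False,)
backend_values = ("CUTLASS", "ATen, Triton, CUTLASS")

# Each @parametrize decorator multiplies the generated test cases;
# the stack expands to the cross-product of all parameter values.
cases = list(product(dynamic_values, backend_values))
print(len(cases))  # 2
```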
We don't have control over the generated backend in the second case, right?
Correct, the final kernel is selected based on the autotune result.
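In other words, with multiple backends enabled, every candidate kernel is benchmarked and the fastest one wins. A toy sketch of that selection (hypothetical names, not the actual autotuner API):

```python
import time

def pick_fastest(candidates, *args):
    """Return the candidate callable with the lowest measured wall time."""
    def bench(fn):
        start = time.perf_counter()
        fn(*args)
        return time.perf_counter() - start
    return min(candidates, key=bench)

def slow_kernel():
    time.sleep(0.05)  # stand-in for a slower generated kernel

def fast_kernel():
    pass              # stand-in for the winning kernel

print(pick_fastest([slow_kernel, fast_kernel]) is fast_kernel)  # True
```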
int64_t batch_stride_x = {{kernel.stride(X, -3)}};
int64_t row_stride_x = {{kernel.row_or_column_stride(X)}};
int64_t batch_stride_w = {{kernel.stride(W, -3)}};
int64_t row_stride_w = {{kernel.row_or_column_stride(W)}};
int64_t batch_stride_bias = {{kernel.stride(Bias, -3)}};
int64_t row_stride_bias = {{kernel.row_or_column_stride(Bias)}};
int64_t batch_stride_y = {{kernel.stride(Y, -3)}};
int64_t row_stride_y = {{kernel.row_or_column_stride(Y)}};
Do we need these defined separately, given that each line where they're used below already has a comment?
I feel this has better readability, though it's just a personal choice.
if op.gemm_kind in {
    cutlass_lib.GemmKind.Sparse,
    cutlass_lib.GemmKind.Grouped,
}:
Should we instead explicitly list the types we support (and return otherwise)? As written, this may break if a new cutlass_lib.GemmKind is added.
The existing implementation is copied from AIT and is quite hacky. It actually relies on GemmKinds like Planar, Complex, etc., which doesn't seem to make sense. However, if I only use GemmKind.Gemm, there are no configs left. So for now I'd like to keep it as is. The CUTLASS folks suggest using the new Python cutlass interface to generate configs, which might be a follow-up.
I believe AIT uses an allow-list approach, accepting only GemmKind.Universal and GemmKind.Universal3x: https://github.com/facebookincubator/AITemplate/blob/main/python/aitemplate/backend/cuda/gemm_universal/common.py#L823-L826 . I guess this should work here, too? GemmKind.Gemm must be something pre-GemmUniversal.
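An allow-list version of that filter could look like the following sketch. `GemmKind` here is a hypothetical stand-in enum using the member names quoted in this thread, not the real `cutlass_lib` definition:

```python
from enum import Enum, auto

class GemmKind(Enum):  # hypothetical stand-in for cutlass_lib.GemmKind
    Gemm = auto()
    Sparse = auto()
    Grouped = auto()
    Universal = auto()
    Universal3x = auto()

# Allow-list: accept only kinds we know how to emit, so a GemmKind
# added in a future cutlass release is rejected by default rather
# than silently slipping through a deny-list.
SUPPORTED_GEMM_KINDS = {GemmKind.Universal, GemmKind.Universal3x}

def keep_op(gemm_kind: GemmKind) -> bool:
    return gemm_kind in SUPPORTED_GEMM_KINDS

print(keep_op(GemmKind.Universal3x))  # True
print(keep_op(GemmKind.Grouped))      # False
```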
I see. I tried GemmKind.Gemm because we are replacing Gemm with GemmUniversal in the generated code, but GemmKind.Gemm doesn't work well. Now I see that AIT actually uses GemmKind.GemmUniversal to emit GemmInstance, which I missed before. Thanks for pointing this out!
Summary: This PR adds epilogue fusion code generation support for the new experimental [Inductor Cutlass backend](#108015).

Details: A fusion happens at the GEMM template level by taking a Cutlass 3.x GEMM Universal Matmul kernel template and adding a custom template functor based on Cutlass's new "Epilogue Visitor Trees" (EVT) on top, which represents and performs the computation of the fused pointwise/elementwise computation nodes. This is the approach dictated by [NVIDIA/cutlass example 49](https://github.com/NVIDIA/cutlass/blob/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu), which is currently the only documentation and example of Cutlass Epilogue Visitor Trees. The EVT functor is in turn a hierarchical template expression which represents an abstract syntax tree of the fused computation to perform. A second codegen task is to create a hierarchical initializer expression, which provides potentially necessary arguments to each of the functor subexpressions.

Step 1 functionality:
* End-to-end code generation is possible using the above approach.
* Supports simple elementwise expression fusion of chains of elementwise operations (with scalar constants) after a matmul.
* Elementwise operation support includes addition, subtraction, multiplication, division, minimum, maximum, etc.
* Examples / unit tests include ReLU and ReLU6 fusion.
* Support for fp16 and fp16-with-fp32-accumulation data types.
* Generates SM90 (Hopper) based CUDA kernels (Cutlass up to 3.2.0 only supported EVT for SM90).

The following is not yet supported and is left for future work:
* Full operation support (e.g. the full set of ops usually handled via V.ops handlers).
* Cutlass EVT with SM80 support (possible in Cutlass 3.2.1 according to the release notes, but not yet documented).
* Support for additional (auxiliary) inputs, which changes the template kernels' call signature.
* Support for additional (auxiliary) outputs (requires support for full computation graphs).
* Support for reduction operations and operations which use different output layouts than the input.
* Support for additional dtypes (as far as Cutlass allows).

This PR updates third_party/cutlass to v3.2.2, which has some important improvements and features for the inductor backend. See also the Cutlass release notes: https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1 and https://github.com/NVIDIA/cutlass/releases/tag/v3.2.2

Notable changes in Cutlass 3.2.1:
* The Cutlass codegen python code has moved into a package with the "cutlass_library" namespace, which prevents namespace clashes without resorting to monkey-patching (which was done earlier).
* Support for SM80 epilogue visitor trees (according to the release notes; not tried yet).
* Small API changes to the cutlass_library API (requires adapting the inductor backend code).

Notable changes in Cutlass 3.2.2:
* Bugfix for a CUDA illegal memory access in some PyTorch unit tests involving flash attention.

Test Plan:
* CI
* pytest test/inductor/test_max_autotune.py

Note: So far, the CUTLASS backend is still disabled by default. Benchmarks are planned once more advanced fusions are enabled.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
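The "hierarchical template expression" idea can be pictured as a small expression tree over the matmul accumulator. Below is a hypothetical Python sketch of how a fused ReLU6 epilogue (min(max(x, 0), 6)) might be represented before being rendered into a CUTLASS EVT functor; the class names and interpreter are illustrative, not the actual backend data structures:

```python
from dataclasses import dataclass

# Hypothetical node types for a fused elementwise epilogue AST;
# the real backend renders such a tree into nested CUTLASS EVT templates.
@dataclass
class Accum:          # the GEMM accumulator (matmul output)
    pass

@dataclass
class Const:
    value: float

@dataclass
class BinOp:
    op: str           # "add", "mul", "minimum", "maximum", ...
    lhs: object
    rhs: object

def evaluate(node, accum_value):
    """Interpret the tree on a scalar, mimicking per-element semantics."""
    if isinstance(node, Accum):
        return accum_value
    if isinstance(node, Const):
        return node.value
    fns = {"add": lambda a, b: a + b,
           "mul": lambda a, b: a * b,
           "minimum": min,
           "maximum": max}
    return fns[node.op](evaluate(node.lhs, accum_value),
                        evaluate(node.rhs, accum_value))

# ReLU6 epilogue: min(max(accum, 0), 6)
relu6 = BinOp("minimum", BinOp("maximum", Accum(), Const(0.0)), Const(6.0))
print(evaluate(relu6, 7.5))   # 6.0
print(evaluate(relu6, -2.0))  # 0.0
print(evaluate(relu6, 3.0))   # 3.0
```

A second tree-shaped pass over the same structure would emit the initializer expression that supplies arguments (here, the scalar constants) to each functor subexpression.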
Summary: This PR adds epilogue fusion code generation support for the new experimental [Inductor Cutlass backend]([#108015]). Details: A fusion happens on the GEMM template level by taking a Cutlass 3.x GEMM Universal Matmul Kernel template and adding a custom template functor based on Cutlass new “Epilogue Visitor Trees” (EVT) on top, which represents and performs the computation of the fused Pointwise / Elementwise computation nodes. This is the approach dictated by [NVIDIA/cutlass example 49](https://github.com/NVIDIA/cutlass/blob/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu), which is currently the only documentation and example of Cutlass Epilogue Visitor Trees. This EVT functor in turn is a hierarchical template expression which represents an abstract syntax tree of the fused computation to perform. A second codegen task is to create a hierarchical initializer expression, which provides potentially necessary arguments to each of the functor subexpressions. Step 1 functionality: * End to end code generation is possible using the above approach. * Supports simple elementwise expression fusion of chains of elementwise operations (with scalar constants ) after a matmul. * Elementwise operation support includes addition, subtraction, multiplication, division, minimum, maximum etc. * Examples / Unit tests include ReLU and ReLU6 fusion. * Support for fp16 and fp16 with fp32 accumulation data types. * Generates SM90 ( Hopper ) based CUDA Kernels ( as Cutlass up to 3.2.0 only supported EVT for SM90 ) The following is not yet supported, and is left for future work: ghstack-source-id: 84219ff2f5eb6c60e4bfaa8dd3db7cdb99f55e46 Pull Request resolved: #110890
… codegen (Step 1)" Summary: This PR adds epilogue fusion code generation support for the new experimental [Inductor Cutlass backend]([#108015]). Details: A fusion happens on the GEMM template level by taking a Cutlass 3.x GEMM Universal Matmul Kernel template and adding a custom template functor based on Cutlass new “Epilogue Visitor Trees” (EVT) on top, which represents and performs the computation of the fused Pointwise / Elementwise computation nodes. This is the approach dictated by [NVIDIA/cutlass example 49](https://github.com/NVIDIA/cutlass/blob/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu), which is currently the only documentation and example of Cutlass Epilogue Visitor Trees. This EVT functor in turn is a hierarchical template expression which represents an abstract syntax tree of the fused computation to perform. A second codegen task is to create a hierarchical initializer expression, which provides potentially necessary arguments to each of the functor subexpressions. Step 1 functionality: * End to end code generation is possible using the above approach. * Supports simple elementwise expression fusion of chains of elementwise operations (with scalar constants ) after a matmul. * Elementwise operation support includes addition, subtraction, multiplication, division, minimum, maximum etc. * Examples / Unit tests include ReLU and ReLU6 fusion. * Support for fp16 and fp16 with fp32 accumulation data types. * Generates SM90 ( Hopper ) based CUDA Kernels ( as Cutlass up to 3.2.0 only supported EVT for SM90 ) The following is not yet supported, and is left for future work: * Full operation support ( e.g. 
full set of all ops usually handled via V.ops handlers ) * Cutlass EVT with SM80 support ( possible in Cutlass 3.2.1 according to release notes, but not yet documented ) * Add support for additional (auxiliary) inputs, which changes the Template Kernels' call signature * Add support for additional (auxiliary) outputs ( requires support for full computation graphs ) * Add support for reduction operations and operations which use different output layouts than the input * Add support for additional dtypes ( as far as Cutlass allows ) This PR updates third_party/cutlass to v3.2.2, which has some important improvements and features for the inductor backend. See also Cutlass release notes: https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1 and https://github.com/NVIDIA/cutlass/releases/tag/v3.2.2 Notable changes in Cutlass 3.2.1 include: * Cutlass codegen python code has moved into a package with the "cutlass_library" namespace, which allows to prevent namespace clashes without resolving to monkey-patching ( which was done earlier ). * Support for SM80 epilogue visitor trees ( according to the Release Notes, not tried yet ) * Small API changes to the cutlass_library API ( requires adapting the inductor backend code ) Notable changes in Cutlass 3.2.2 include: * Bugfix that led to CUDA Illegal memory access in some Pytorch unit tests involving flash attention Test Plan: * CI * pytest test/inductor/test_max_autotune.py Note: So far, the CUTLASS backend is still disabled by default. Benchmarks are planned once more advanced fusions are enabled. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
Summary: This PR adds epilogue fusion code generation support for the new experimental [Inductor Cutlass backend]([#108015]). Details: A fusion happens on the GEMM template level by taking a Cutlass 3.x GEMM Universal Matmul Kernel template and adding a custom template functor based on Cutlass new “Epilogue Visitor Trees” (EVT) on top, which represents and performs the computation of the fused Pointwise / Elementwise computation nodes. This is the approach dictated by [NVIDIA/cutlass example 49](https://github.com/NVIDIA/cutlass/blob/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu), which is currently the only documentation and example of Cutlass Epilogue Visitor Trees. This EVT functor in turn is a hierarchical template expression which represents an abstract syntax tree of the fused computation to perform. A second codegen task is to create a hierarchical initializer expression, which provides potentially necessary arguments to each of the functor subexpressions. Step 1 functionality: * End to end code generation is possible using the above approach. * Supports simple elementwise expression fusion of chains of elementwise operations (with scalar constants ) after a matmul. * Elementwise operation support includes addition, subtraction, multiplication, division, minimum, maximum etc. * Examples / Unit tests include ReLU and ReLU6 fusion. * Support for fp16 and fp16 with fp32 accumulation data types. * Generates SM90 ( Hopper ) based CUDA Kernels ( as Cutlass up to 3.2.0 only supported EVT for SM90 ) The following is not yet supported, and is left for future work: * Full operation support ( e.g. 
full set of all ops usually handled via V.ops handlers ) * Cutlass EVT with SM80 support ( possible in Cutlass 3.2.1 according to release notes, but not yet documented ) * Add support for additional (auxiliary) inputs, which changes the Template Kernels' call signature * Add support for additional (auxiliary) outputs ( requires support for full computation graphs ) * Add support for reduction operations and operations which use different output layouts than the input * Add support for additional dtypes ( as far as Cutlass allows ) This PR updates third_party/cutlass to v3.2.2, which has some important improvements and features for the inductor backend. See also Cutlass release notes: https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1 and https://github.com/NVIDIA/cutlass/releases/tag/v3.2.2 Notable changes in Cutlass 3.2.1 include: * Cutlass codegen python code has moved into a package with the "cutlass_library" namespace, which allows to prevent namespace clashes without resolving to monkey-patching ( which was done earlier ). * Support for SM80 epilogue visitor trees ( according to the Release Notes, not tried yet ) * Small API changes to the cutlass_library API ( requires adapting the inductor backend code ) Notable changes in Cutlass 3.2.2 include: * Bugfix that led to CUDA Illegal memory access in some Pytorch unit tests involving flash attention Test Plan: * CI * pytest test/inductor/test_max_autotune.py Note: So far, the CUTLASS backend is still disabled by default. Benchmarks are planned once more advanced fusions are enabled. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
Summary: This PR adds epilogue fusion code generation support for the new experimental [Inductor Cutlass backend]([#108015]). Details: A fusion happens on the GEMM template level by taking a Cutlass 3.x GEMM Universal Matmul Kernel template and adding a custom template functor based on Cutlass new “Epilogue Visitor Trees” (EVT) on top, which represents and performs the computation of the fused Pointwise / Elementwise computation nodes. This is the approach dictated by [NVIDIA/cutlass example 49](https://github.com/NVIDIA/cutlass/blob/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu), which is currently the only documentation and example of Cutlass Epilogue Visitor Trees. This EVT functor in turn is a hierarchical template expression which represents an abstract syntax tree of the fused computation to perform. A second codegen task is to create a hierarchical initializer expression, which provides potentially necessary arguments to each of the functor subexpressions. Step 1 functionality: * End to end code generation is possible using the above approach. * Supports simple elementwise expression fusion of chains of elementwise operations (with scalar constants ) after a matmul. * Elementwise operation support includes addition, subtraction, multiplication, division, minimum, maximum etc. * Examples / Unit tests include ReLU and ReLU6 fusion. * Support for fp16 and fp16 with fp32 accumulation data types. * Generates SM90 ( Hopper ) based CUDA Kernels ( as Cutlass up to 3.2.0 only supported EVT for SM90 ) The following is not yet supported, and is left for future work: ghstack-source-id: a7536e686e17ef016adc45211fcde5effd9f82ed Pull Request resolved: #110890
… codegen (Step 1)" Summary: This PR adds epilogue fusion code generation support for the new experimental [Inductor Cutlass backend]([#108015]). Details: A fusion happens on the GEMM template level by taking a Cutlass 3.x GEMM Universal Matmul Kernel template and adding a custom template functor based on Cutlass new “Epilogue Visitor Trees” (EVT) on top, which represents and performs the computation of the fused Pointwise / Elementwise computation nodes. This is the approach dictated by [NVIDIA/cutlass example 49](https://github.com/NVIDIA/cutlass/blob/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu), which is currently the only documentation and example of Cutlass Epilogue Visitor Trees. This EVT functor in turn is a hierarchical template expression which represents an abstract syntax tree of the fused computation to perform. A second codegen task is to create a hierarchical initializer expression, which provides potentially necessary arguments to each of the functor subexpressions. Step 1 functionality: * End to end code generation is possible using the above approach. * Supports simple elementwise expression fusion of chains of elementwise operations (with scalar constants ) after a matmul. * Elementwise operation support includes addition, subtraction, multiplication, division, minimum, maximum etc. * Examples / Unit tests include ReLU and ReLU6 fusion. * Support for fp16 and fp16 with fp32 accumulation data types. * Generates SM90 ( Hopper ) based CUDA Kernels ( as Cutlass up to 3.2.0 only supported EVT for SM90 ) The following is not yet supported, and is left for future work: * Full operation support ( e.g. 
full set of all ops usually handled via V.ops handlers ) * Cutlass EVT with SM80 support ( possible in Cutlass 3.2.1 according to release notes, but not yet documented ) * Add support for additional (auxiliary) inputs, which changes the Template Kernels' call signature * Add support for additional (auxiliary) outputs ( requires support for full computation graphs ) * Add support for reduction operations and operations which use different output layouts than the input * Add support for additional dtypes ( as far as Cutlass allows ) This PR updates third_party/cutlass to v3.2.2, which has some important improvements and features for the inductor backend. See also Cutlass release notes: https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1 and https://github.com/NVIDIA/cutlass/releases/tag/v3.2.2 Notable changes in Cutlass 3.2.1 include: * Cutlass codegen python code has moved into a package with the "cutlass_library" namespace, which allows to prevent namespace clashes without resolving to monkey-patching ( which was done earlier ). * Support for SM80 epilogue visitor trees ( according to the Release Notes, not tried yet ) * Small API changes to the cutlass_library API ( requires adapting the inductor backend code ) Notable changes in Cutlass 3.2.2 include: * Bugfix that led to CUDA Illegal memory access in some Pytorch unit tests involving flash attention Test Plan: * CI * pytest test/inductor/test_max_autotune.py Note: So far, the CUTLASS backend is still disabled by default. Benchmarks are planned once more advanced fusions are enabled. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
Summary: This PR adds epilogue fusion code generation support for the new experimental [Inductor Cutlass backend]([#108015]). Details: A fusion happens on the GEMM template level by taking a Cutlass 3.x GEMM Universal Matmul Kernel template and adding a custom template functor based on Cutlass new “Epilogue Visitor Trees” (EVT) on top, which represents and performs the computation of the fused Pointwise / Elementwise computation nodes. This is the approach dictated by [NVIDIA/cutlass example 49](https://github.com/NVIDIA/cutlass/blob/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu), which is currently the only documentation and example of Cutlass Epilogue Visitor Trees. This EVT functor in turn is a hierarchical template expression which represents an abstract syntax tree of the fused computation to perform. A second codegen task is to create a hierarchical initializer expression, which provides potentially necessary arguments to each of the functor subexpressions. Step 1 functionality: * End to end code generation is possible using the above approach. * Supports simple elementwise expression fusion of chains of elementwise operations (with scalar constants ) after a matmul. * Elementwise operation support includes addition, subtraction, multiplication, division, minimum, maximum etc. * Examples / Unit tests include ReLU and ReLU6 fusion. * Support for fp16 and fp16 with fp32 accumulation data types. * Generates SM90 ( Hopper ) based CUDA Kernels ( as Cutlass up to 3.2.0 only supported EVT for SM90 ) The following is not yet supported, and is left for future work: * Full operation support ( e.g. 
full set of all ops usually handled via V.ops handlers ) * Cutlass EVT with SM80 support ( possible in Cutlass 3.2.1 according to release notes, but not yet documented ) * Add support for additional (auxiliary) inputs, which changes the Template Kernels' call signature * Add support for additional (auxiliary) outputs ( requires support for full computation graphs ) * Add support for reduction operations and operations which use different output layouts than the input * Add support for additional dtypes ( as far as Cutlass allows ) This PR updates third_party/cutlass to v3.2.2, which has some important improvements and features for the inductor backend. See also Cutlass release notes: https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1 and https://github.com/NVIDIA/cutlass/releases/tag/v3.2.2 Notable changes in Cutlass 3.2.1 include: * Cutlass codegen python code has moved into a package with the "cutlass_library" namespace, which allows to prevent namespace clashes without resolving to monkey-patching ( which was done earlier ). * Support for SM80 epilogue visitor trees ( according to the Release Notes, not tried yet ) * Small API changes to the cutlass_library API ( requires adapting the inductor backend code ) Notable changes in Cutlass 3.2.2 include: * Bugfix that led to CUDA Illegal memory access in some Pytorch unit tests involving flash attention Test Plan: * CI * pytest test/inductor/test_max_autotune.py Note: So far, the CUTLASS backend is still disabled by default. Benchmarks are planned once more advanced fusions are enabled. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
Summary: This PR adds epilogue fusion code generation support for the new experimental [Inductor Cutlass backend]([#108015]). Details: A fusion happens on the GEMM template level by taking a Cutlass 3.x GEMM Universal Matmul Kernel template and adding a custom template functor based on Cutlass new “Epilogue Visitor Trees” (EVT) on top, which represents and performs the computation of the fused Pointwise / Elementwise computation nodes. This is the approach dictated by [NVIDIA/cutlass example 49](https://github.com/NVIDIA/cutlass/blob/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu), which is currently the only documentation and example of Cutlass Epilogue Visitor Trees. This EVT functor in turn is a hierarchical template expression which represents an abstract syntax tree of the fused computation to perform. A second codegen task is to create a hierarchical initializer expression, which provides potentially necessary arguments to each of the functor subexpressions. Step 1 functionality: * End to end code generation is possible using the above approach. * Supports simple elementwise expression fusion of chains of elementwise operations (with scalar constants ) after a matmul. * Elementwise operation support includes addition, subtraction, multiplication, division, minimum, maximum etc. * Examples / Unit tests include ReLU and ReLU6 fusion. * Support for fp16 and fp16 with fp32 accumulation data types. * Generates SM90 ( Hopper ) based CUDA Kernels ( as Cutlass up to 3.2.0 only supported EVT for SM90 ) The following is not yet supported, and is left for future work: ghstack-source-id: 31e0ecdadf693cbb4be87934b75e5e7fb0e366a5 Pull Request resolved: #110890
… codegen (Step 1)" Summary: This PR adds epilogue fusion code generation support for the new experimental [Inductor Cutlass backend]([#108015]). Details: A fusion happens on the GEMM template level by taking a Cutlass 3.x GEMM Universal Matmul Kernel template and adding a custom template functor based on Cutlass new “Epilogue Visitor Trees” (EVT) on top, which represents and performs the computation of the fused Pointwise / Elementwise computation nodes. This is the approach dictated by [NVIDIA/cutlass example 49](https://github.com/NVIDIA/cutlass/blob/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu), which is currently the only documentation and example of Cutlass Epilogue Visitor Trees. This EVT functor in turn is a hierarchical template expression which represents an abstract syntax tree of the fused computation to perform. A second codegen task is to create a hierarchical initializer expression, which provides potentially necessary arguments to each of the functor subexpressions. Step 1 functionality: * End to end code generation is possible using the above approach. * Supports simple elementwise expression fusion of chains of elementwise operations (with scalar constants ) after a matmul. * Elementwise operation support includes addition, subtraction, multiplication, division, minimum, maximum etc. * Examples / Unit tests include ReLU and ReLU6 fusion. * Support for fp16 and fp16 with fp32 accumulation data types. * Generates SM90 ( Hopper ) based CUDA Kernels ( as Cutlass up to 3.2.0 only supported EVT for SM90 ) The following is not yet supported, and is left for future work: * Full operation support ( e.g. 
full set of all ops usually handled via V.ops handlers ) * Cutlass EVT with SM80 support ( possible in Cutlass 3.2.1 according to release notes, but not yet documented ) * Add support for additional (auxiliary) inputs, which changes the Template Kernels' call signature * Add support for additional (auxiliary) outputs ( requires support for full computation graphs ) * Add support for reduction operations and operations which use different output layouts than the input * Add support for additional dtypes ( as far as Cutlass allows ) This PR updates third_party/cutlass to v3.2.2, which has some important improvements and features for the inductor backend. See also Cutlass release notes: https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1 and https://github.com/NVIDIA/cutlass/releases/tag/v3.2.2 Notable changes in Cutlass 3.2.1 include: * Cutlass codegen python code has moved into a package with the "cutlass_library" namespace, which allows to prevent namespace clashes without resolving to monkey-patching ( which was done earlier ). * Support for SM80 epilogue visitor trees ( according to the Release Notes, not tried yet ) * Small API changes to the cutlass_library API ( requires adapting the inductor backend code ) Notable changes in Cutlass 3.2.2 include: * Bugfix that led to CUDA Illegal memory access in some Pytorch unit tests involving flash attention Test Plan: * CI * pytest test/inductor/test_max_autotune.py Note: So far, the CUTLASS backend is still disabled by default. Benchmarks are planned once more advanced fusions are enabled. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
Summary: This PR adds epilogue fusion code generation support for the new experimental [Inductor Cutlass backend]([#108015]). Details: A fusion happens on the GEMM template level by taking a Cutlass 3.x GEMM Universal Matmul Kernel template and adding a custom template functor based on Cutlass new “Epilogue Visitor Trees” (EVT) on top, which represents and performs the computation of the fused Pointwise / Elementwise computation nodes. This is the approach dictated by [NVIDIA/cutlass example 49](https://github.com/NVIDIA/cutlass/blob/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu), which is currently the only documentation and example of Cutlass Epilogue Visitor Trees. This EVT functor in turn is a hierarchical template expression which represents an abstract syntax tree of the fused computation to perform. A second codegen task is to create a hierarchical initializer expression, which provides potentially necessary arguments to each of the functor subexpressions. Step 1 functionality: * End to end code generation is possible using the above approach. * Supports simple elementwise expression fusion of chains of elementwise operations (with scalar constants ) after a matmul. * Elementwise operation support includes addition, subtraction, multiplication, division, minimum, maximum etc. * Examples / Unit tests include ReLU and ReLU6 fusion. * Support for fp16 and fp16 with fp32 accumulation data types. * Generates SM90 ( Hopper ) based CUDA Kernels ( as Cutlass up to 3.2.0 only supported EVT for SM90 ) The following is not yet supported, and is left for future work: * Full operation support ( e.g. 
full set of all ops usually handled via V.ops handlers ) * Cutlass EVT with SM80 support ( possible in Cutlass 3.2.1 according to release notes, but not yet documented ) * Add support for additional (auxiliary) inputs, which changes the Template Kernels' call signature * Add support for additional (auxiliary) outputs ( requires support for full computation graphs ) * Add support for reduction operations and operations which use different output layouts than the input * Add support for additional dtypes ( as far as Cutlass allows ) This PR updates third_party/cutlass to v3.2.2, which has some important improvements and features for the inductor backend. See also Cutlass release notes: https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1 and https://github.com/NVIDIA/cutlass/releases/tag/v3.2.2 Notable changes in Cutlass 3.2.1 include: * Cutlass codegen python code has moved into a package with the "cutlass_library" namespace, which allows to prevent namespace clashes without resolving to monkey-patching ( which was done earlier ). * Support for SM80 epilogue visitor trees ( according to the Release Notes, not tried yet ) * Small API changes to the cutlass_library API ( requires adapting the inductor backend code ) Notable changes in Cutlass 3.2.2 include: * Bugfix that led to CUDA Illegal memory access in some Pytorch unit tests involving flash attention Test Plan: * CI * pytest test/inductor/test_max_autotune.py Note: So far, the CUTLASS backend is still disabled by default. Benchmarks are planned once more advanced fusions are enabled. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
… codegen (Step 1)" Summary: This PR adds epilogue fusion code generation support for the new experimental [Inductor Cutlass backend]([#108015]). Details: A fusion happens on the GEMM template level by taking a Cutlass 3.x GEMM Universal Matmul Kernel template and adding a custom template functor based on Cutlass new “Epilogue Visitor Trees” (EVT) on top, which represents and performs the computation of the fused Pointwise / Elementwise computation nodes. This is the approach dictated by [NVIDIA/cutlass example 49](https://github.com/NVIDIA/cutlass/blob/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu), which is currently the only documentation and example of Cutlass Epilogue Visitor Trees. This EVT functor in turn is a hierarchical template expression which represents an abstract syntax tree of the fused computation to perform. A second codegen task is to create a hierarchical initializer expression, which provides potentially necessary arguments to each of the functor subexpressions. Step 1 functionality: * End to end code generation is possible using the above approach. * Supports simple elementwise expression fusion of chains of elementwise operations (with scalar constants ) after a matmul. * Elementwise operation support includes addition, subtraction, multiplication, division, minimum, maximum etc. * Examples / Unit tests include ReLU and ReLU6 fusion. * Support for fp16 and fp16 with fp32 accumulation data types. * Generates SM90 ( Hopper ) based CUDA Kernels ( as Cutlass up to 3.2.0 only supported EVT for SM90 ) The following is not yet supported, and is left for future work: * Full operation support ( e.g. 
full set of all ops usually handled via V.ops handlers ) * Cutlass EVT with SM80 support ( possible in Cutlass 3.2.1 according to release notes, but not yet documented ) * Add support for additional (auxiliary) inputs, which changes the Template Kernels' call signature * Add support for additional (auxiliary) outputs ( requires support for full computation graphs ) * Add support for reduction operations and operations which use different output layouts than the input * Add support for additional dtypes ( as far as Cutlass allows ) This PR updates third_party/cutlass to v3.2.2, which has some important improvements and features for the inductor backend. See also Cutlass release notes: https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1 and https://github.com/NVIDIA/cutlass/releases/tag/v3.2.2 Notable changes in Cutlass 3.2.1 include: * Cutlass codegen python code has moved into a package with the "cutlass_library" namespace, which allows to prevent namespace clashes without resolving to monkey-patching ( which was done earlier ). * Support for SM80 epilogue visitor trees ( according to the Release Notes, not tried yet ) * Small API changes to the cutlass_library API ( requires adapting the inductor backend code ) Notable changes in Cutlass 3.2.2 include: * Bugfix that led to CUDA Illegal memory access in some Pytorch unit tests involving flash attention Test Plan: * CI * pytest test/inductor/test_max_autotune.py Note: So far, the CUTLASS backend is still disabled by default. Benchmarks are planned once more advanced fusions are enabled. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
Differential Revision: [D50988161](https://our.internmc.facebook.com/intern/diff/D50988161)
Pull Request resolved: #110890
Approved by: https://github.com/jansel
ghstack dependencies: #112762
This is step 5 of adding CUTLASS as an alternative Inductor backend.
Feature request: #106991.
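Since the CUTLASS backend is disabled by default, trying these GEMM templates requires opting in through the Inductor config. A hedged sketch of what that looks like follows; the config knob names reflect the Inductor config around this PR series and may change, and the `cutlass_dir` path is an assumption (it must point at a CUTLASS checkout). Running the compiled function requires a CUDA GPU.

```python
import torch
import torch._inductor.config as inductor_config

# Opt in to max-autotune and include CUTLASS among the candidate GEMM backends.
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "ATEN,TRITON,CUTLASS"
inductor_config.cuda.cutlass_dir = "third_party/cutlass"  # path assumption

@torch.compile(mode="max-autotune")
def mm(a, b):
    return torch.relu(a @ b)

# On a CUDA device, autotuning would then benchmark ATen, Triton, and
# CUTLASS GEMM candidates and pick the fastest:
# out = mm(torch.randn(256, 256, device="cuda", dtype=torch.float16),
#          torch.randn(256, 256, device="cuda", dtype=torch.float16))
```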
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov