[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1) (#110890)
Summary: This PR adds epilogue fusion code generation support for the new experimental Inductor Cutlass backend (#108015).

Details: Fusion happens at the GEMM template level by taking a Cutlass 3.x GEMM Universal Matmul kernel template and adding a custom template functor on top, based on Cutlass's new "Epilogue Visitor Trees" (EVT), which represents and performs the computation of the fused pointwise/elementwise nodes. This is the approach demonstrated by [NVIDIA/cutlass example 49](https://github.com/NVIDIA/cutlass/blob/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu), which is currently the only documentation and example of Cutlass Epilogue Visitor Trees. The EVT functor is a hierarchical template expression representing an abstract syntax tree of the fused computation. A second codegen task is to create a hierarchical initializer expression, which provides the arguments each of the functor's subexpressions may require.

Step 1 functionality:
* End-to-end code generation using the approach above.
* Fusion of simple chains of elementwise operations (with scalar constants) after a matmul.
* Supported elementwise operations include addition, subtraction, multiplication, division, minimum, and maximum.
* Examples / unit tests include ReLU and ReLU6 fusion.
* Support for fp16, and fp16 with fp32 accumulation, data types.
* Generates SM90 (Hopper) CUDA kernels (Cutlass up to 3.2.0 only supports EVT for SM90).

Not yet supported, left for future work:
* Full operation support (e.g. the full set of ops usually handled via V.ops handlers).
* Cutlass EVT with SM80 support (possible in Cutlass 3.2.1 according to the release notes, but not yet documented).
* Additional (auxiliary) inputs, which change the template kernels' call signature.
* Additional (auxiliary) outputs (requires support for full computation graphs).
* Reduction operations, and operations whose output layout differs from the input layout.
* Additional dtypes (as far as Cutlass allows them).

This PR updates third_party/cutlass to v3.2.2, which brings important improvements and features for the inductor backend. See also the Cutlass release notes: https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1 and https://github.com/NVIDIA/cutlass/releases/tag/v3.2.2

Notable changes in Cutlass 3.2.1:
* The Cutlass codegen Python code has moved into a package under the "cutlass_library" namespace, which prevents namespace clashes without resorting to monkey-patching (as was done earlier).
* Support for SM80 epilogue visitor trees (according to the release notes; not tried yet).
* Small changes to the cutlass_library API (requires adapting the inductor backend code).

Notable changes in Cutlass 3.2.2:
* Bugfix for a CUDA illegal memory access that occurred in some PyTorch unit tests involving flash attention.

Test Plan:
* CI
* pytest test/inductor/test_max_autotune.py

Note: So far, the CUTLASS backend is still disabled by default. Benchmarks are planned once more advanced fusions are enabled.

Differential Revision: [D50988161](https://our.internmc.facebook.com/intern/diff/D50988161)

Pull Request resolved: #110890

Approved by: https://github.com/jansel

ghstack dependencies: #112762