🚀 The feature, motivation and pitch
Motivation
Grouped GEMM is a common pattern in modeling. For example, in the LlamaMLP module (https://github.com/huggingface/transformers/blob/d5aebc64653d09660818109f2fac55b5e1031023/src/transformers/models/llama/modeling_llama.py#L187-L188), the gate_proj and up_proj layers have the same dimensions and share the same input activation. An activation function is applied to the output of gate_proj, and the result is multiplied elementwise by the output of up_proj. Fusing the gate_proj and up_proj layers into a Grouped GEMM improves memory locality when applying the activation and multiplication operations. In this RFC, we propose an approach to implement this Grouped GEMM optimization.
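As a framework-free illustration of the motivation above, the sketch below compares an unfused gate/up computation, which materializes both GEMM outputs before applying the activation and multiplication, with a grouped version that computes both dot products for an output element and applies the epilogue immediately. The function names, the block size, and the choice of SiLU as the activation are assumptions made for illustration; this is not the proposed Inductor implementation.

```python
import math

def silu(v):
    # SiLU activation: v * sigmoid(v)
    return v / (1.0 + math.exp(-v))

def matmul(x, w):
    # x is M x K, w is K x N (nested lists); returns M x N
    k, n = len(w), len(w[0])
    return [[sum(row[p] * w[p][j] for p in range(k)) for j in range(n)] for row in x]

def mlp_unfused(x, w_gate, w_up):
    # Two separate GEMMs, each writing a full intermediate matrix
    gate = matmul(x, w_gate)
    up = matmul(x, w_up)
    return [[silu(g) * u for g, u in zip(g_row, u_row)] for g_row, u_row in zip(gate, up)]

def mlp_grouped(x, w_gate, w_up, block_n=2):
    # Both GEMMs are computed block by block; the epilogue (silu, mul) runs
    # while the two accumulators are still "hot", so no full intermediate
    # gate/up matrices are stored.
    m, k, n = len(x), len(w_gate), len(w_gate[0])
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j0 in range(0, n, block_n):
            for j in range(j0, min(j0 + block_n, n)):
                g = sum(x[i][p] * w_gate[p][j] for p in range(k))
                u = sum(x[i][p] * w_up[p][j] for p in range(k))
                out[i][j] = silu(g) * u
    return out

# Small illustrative inputs (arbitrary values)
X = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
W_GATE = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
W_UP = [[-0.3, 0.2, 0.1], [0.6, -0.5, 0.4], [0.9, 0.8, -0.7]]

REFERENCE = mlp_unfused(X, W_GATE, W_UP)
GROUPED = mlp_grouped(X, W_GATE, W_UP)
```

Both functions return the same values; the difference is only in when the epilogue runs relative to the GEMM accumulation, which is what the fusion is meant to exploit.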
Approaches
We propose to implement the Grouped GEMM optimization with a CPP Template, as it is more flexible in supporting different numbers of GEMMs and different epilogue fusions. Here is the proposed design of some key components.
Pattern Matcher
We introduce grouped_gemm_pass to find the pattern of an anchor node (the activation shared by the Grouped GEMM) and a group of GEMMs. The pass replaces this pattern with the grouped_gemm_lowering lowering function, which further lowers it into the GEMM Template.
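The matching idea described above can be sketched on a toy graph: collect the GEMM nodes that read the same anchor activation and have identically shaped weights, and report each such group as a Grouped GEMM candidate. The Node class, the find_gemm_groups function, and the weight shapes are hypothetical and exist only for this illustration; the real pass operates on Inductor's graph representation rather than on this data structure.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class Node:
    name: str
    op: str                              # e.g. "placeholder", "linear", "silu", "mul"
    inputs: tuple = ()                   # names of the nodes this node reads
    weight_shape: Optional[tuple] = None # only set for "linear" nodes

def find_gemm_groups(nodes, min_group_size=2):
    # Group linear nodes by (anchor activation, weight shape).
    groups = defaultdict(list)
    for node in nodes:
        if node.op == "linear":
            anchor = node.inputs[0]
            groups[(anchor, node.weight_shape)].append(node.name)
    # A single GEMM is not a group, so drop candidates below the threshold.
    return {key: names for key, names in groups.items() if len(names) >= min_group_size}

# Toy version of the LlamaMLP graph; shapes are small placeholders.
GRAPH = [
    Node("x", "placeholder"),
    Node("gate_proj", "linear", ("x",), (8, 4)),
    Node("up_proj", "linear", ("x",), (8, 4)),
    Node("act", "silu", ("gate_proj",)),
    Node("prod", "mul", ("act", "up_proj")),
    Node("down_proj", "linear", ("prod",), (4, 8)),
]

CANDIDATES = find_gemm_groups(GRAPH)
# CANDIDATES == {("x", (8, 4)): ["gate_proj", "up_proj"]}
```

Because the grouping key does not fix the number of members, this style of matching handles any number of GEMMs that share an anchor.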
We also evaluate MultiOutputPattern as a way to enable the pattern matcher and fusion in the post-grad fusion passes. The current limitation is that MultiOutputPattern requires a fixed number of output nodes when the pattern is defined.
Inductor Lowering
After lowering into the Grouped GEMM Template, most of the flow is the same as for a standard template. The only extension is that the Grouped GEMM Template may have multiple output nodes. We define the template node with MultiOutputLayout and the multiple output buffers with MultiOutput (each corresponding to one GEMM output).
Inductor Scheduler Nodes Fusions
In the scheduler node fusion phase:
- First, we fuse the template node (with layout MultiOutputLayout) and each GEMM output (MultiOutput) into a FusedSchedulerNode.
- Then, we further fuse this FusedSchedulerNode with its epilogues, e.g. silu, mul, relu.
After this phase, we have a FusedSchedulerNode containing the Grouped GEMM and its epilogues. Next, the CPP backend generates code for it using the CPP Grouped GEMM Template.
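The two fusion steps described above can be simulated on a toy node list. In this sketch, SchedNode, its kind strings, and fuse_grouped_gemm are hypothetical stand-ins, and the fusion rule (a pointwise node is absorbed once all of its inputs are produced inside the fused group) is a simplifying assumption; the actual scheduler applies more detailed legality and profitability checks.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class SchedNode:
    name: str
    kind: str          # "template", "multi_output", "pointwise", or "other"
    reads: tuple = ()  # names of the buffers this node reads

def fuse_grouped_gemm(nodes):
    template = next(n for n in nodes if n.kind == "template")
    fused = [template.name]

    # Step 1: fuse the template node with each of its GEMM outputs.
    for n in nodes:
        if n.kind == "multi_output" and template.name in n.reads:
            fused.append(n.name)

    # Step 2: repeatedly absorb pointwise epilogues whose inputs are all
    # produced inside the fused group, until nothing more can be added.
    changed = True
    while changed:
        changed = False
        for n in nodes:
            if (n.kind == "pointwise" and n.name not in fused
                    and n.reads and all(r in fused for r in n.reads)):
                fused.append(n.name)
                changed = True
    return fused

NODES = [
    SchedNode("grouped_gemm", "template", ("x", "w_gate", "w_up")),
    SchedNode("out0", "multi_output", ("grouped_gemm",)),
    SchedNode("out1", "multi_output", ("grouped_gemm",)),
    SchedNode("silu", "pointwise", ("out0",)),
    SchedNode("mul", "pointwise", ("silu", "out1")),
    SchedNode("down_proj", "other", ("mul",)),
]

FUSED = fuse_grouped_gemm(NODES)
# FUSED == ["grouped_gemm", "out0", "out1", "silu", "mul"]
```

The down_proj node stays outside the group because it is not a pointwise epilogue in this toy model.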
CPP Grouped GEMM Template
We define a CPP Grouped GEMM Template which extends the current CPP GEMM Template implementation with:
- Flexible number of GEMMs
- Each GEMM can have independent or shared activations
- Each GEMM can have its own weight, but all weights must have the same size
- Each GEMM can have a unique bias or None
- Each GEMM has its own epilogues
Specifically, we introduce a CppGroupedGemmTemplate class that inherits from CppGemmTemplate. Key methods, such as add_choices and render, are overridden to support the aforementioned features.
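To make the listed features concrete, here is a plain-Python reference for the semantics such a template would need to reproduce: any number of GEMMs, shared or per-GEMM activations, same-shaped weights, optional biases, and one epilogue per GEMM. The function grouped_gemm_reference and its argument layout are assumptions made for this sketch only; they do not reflect the signatures of CppGroupedGemmTemplate or of any Inductor API.

```python
def grouped_gemm_reference(activations, weights, biases, epilogues):
    # activations: list with 1 matrix (shared) or one M x K matrix per GEMM
    # weights:     list of K x N matrices, all with the same shape
    # biases:      list with, per GEMM, either None or a length-N vector
    # epilogues:   list with one elementwise callable per GEMM
    num_gemms = len(weights)
    k, n = len(weights[0]), len(weights[0][0])
    if any(len(w) != k or any(len(row) != n for row in w) for w in weights):
        raise ValueError("all weights must have the same shape")
    if len(activations) not in (1, num_gemms):
        raise ValueError("pass one shared activation or one per GEMM")
    if len(biases) != num_gemms or len(epilogues) != num_gemms:
        raise ValueError("need one bias entry and one epilogue per GEMM")

    outputs = []
    for g in range(num_gemms):
        x = activations[0] if len(activations) == 1 else activations[g]
        bias = biases[g]
        out = []
        for row in x:
            out_row = []
            for j in range(n):
                acc = sum(row[p] * weights[g][p][j] for p in range(k))
                if bias is not None:
                    acc += bias[j]
                # The epilogue is applied right after accumulation.
                out_row.append(epilogues[g](acc))
            out.append(out_row)
        outputs.append(out)
    return outputs  # one output matrix per GEMM

def relu(v):
    return max(v, 0.0)

OUT0, OUT1 = grouped_gemm_reference(
    activations=[[[1.0, 2.0]]],                 # one shared 1 x 2 activation
    weights=[[[1.0, 0.0], [0.0, 1.0]],          # GEMM 0: identity weight
             [[2.0, 0.0], [0.0, 2.0]]],         # GEMM 1: scale by 2
    biases=[None, [1.0, -10.0]],                # GEMM 0 has no bias
    epilogues=[lambda v: v, relu],
)
# OUT0 == [[1.0, 2.0]] and OUT1 == [[3.0, 0.0]]
```

Returning a list of outputs mirrors the multi-output nature of the template node; an epilogue that combines several GEMM outputs, such as silu followed by mul, would consume these outputs together.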
Alternatives
No response
Additional context
No response
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov