
[RFC] Add CPP Grouped GEMM Template for Inductor CPU #144012

@leslie-fang-intel


🚀 The feature, motivation and pitch

Motivation

Grouped GEMM is a common pattern in modeling. For example, in the LlamaMLP module (https://github.com/huggingface/transformers/blob/d5aebc64653d09660818109f2fac55b5e1031023/src/transformers/models/llama/modeling_llama.py#L187-L188), the gate_proj and up_proj layers have the same dimensions and consume the same input activation. An activation function is applied to the output of gate_proj, and the result is multiplied element-wise by the output of up_proj; this product is then fed to down_proj. Fusing the gate_proj and up_proj layers into a Grouped GEMM improves memory locality when applying the activation and multiplication operations, because both GEMM outputs can be consumed while they are still in cache instead of being written to memory and read back. In this RFC, we propose an approach to implement this Grouped GEMM optimization.
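The sketch below is a minimal pure-Python illustration of the data flow, not the proposed C++ kernel. It assumes SiLU as the activation (the usual `hidden_act` for Llama), stores weights as K x N nested lists (nn.Linear actually stores them as N x K and multiplies by the transpose), and uses hypothetical helper names (`matmul`, `mlp_unfused`, `mlp_grouped`). The unfused version materializes the full gate and up outputs before the element-wise passes; the grouped version computes both GEMM results for an output element from the same input row and applies `silu(gate) * up` immediately.

```python
import math


def silu(v):
    # SiLU(v) = v * sigmoid(v) = v / (1 + exp(-v))
    return v / (1.0 + math.exp(-v))


def matmul(x, w):
    # x: M x K, w: K x N, both row-major nested lists
    k_dim, n_dim = len(w), len(w[0])
    return [[sum(row[k] * w[k][n] for k in range(k_dim)) for n in range(n_dim)]
            for row in x]


def mlp_unfused(x, w_gate, w_up):
    # Two separate GEMMs, each producing a full intermediate buffer,
    # followed by separate element-wise passes for the activation and the mul.
    gate = matmul(x, w_gate)
    up = matmul(x, w_up)
    act = [[silu(v) for v in row] for row in gate]
    return [[a * u for a, u in zip(ra, ru)] for ra, ru in zip(act, up)]


def mlp_grouped(x, w_gate, w_up):
    # Grouped GEMM: both GEMMs read the same input row, and the epilogue
    # silu(gate) * up runs as soon as each pair of results is ready,
    # so no full gate/up intermediate buffer is materialized.
    k_dim, n_dim = len(w_gate), len(w_gate[0])
    out = []
    for row in x:
        out_row = []
        for n in range(n_dim):
            g = sum(row[k] * w_gate[k][n] for k in range(k_dim))
            u = sum(row[k] * w_up[k][n] for k in range(k_dim))
            out_row.append(silu(g) * u)
        out.append(out_row)
    return out


x = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]
w_gate = [[0.1, 0.2], [0.3, -0.4], [-0.5, 0.6]]
w_up = [[1.0, 0.0], [0.0, 1.0], [0.5, -0.5]]
ref = mlp_unfused(x, w_gate, w_up)
res = mlp_grouped(x, w_gate, w_up)
assert all(abs(a - b) < 1e-12 for ra, rb in zip(ref, res) for a, b in zip(ra, rb))
```

A real kernel would operate on blocked tiles with vectorized micro-kernels rather than single elements, but the locality argument is the same: the epilogue reads values that were just produced.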

Approaches

We propose implementing the Grouped GEMM optimization with a CPP Template, since a template can flexibly support a varying number of GEMMs and different epilogue fusions. The proposed design of the key components is described below.

Pattern Matcher

We introduce a grouped_gemm_pass that finds a pattern consisting of an anchor node (the input activation shared by the GEMMs in the group) and the group of GEMMs that consume it. The pass replaces this pattern with the grouped_gemm_lowering lowering function, which further lowers it into the Grouped GEMM Template.
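As a rough illustration of what the matching step looks for, here is a toy pass over a hypothetical, simplified node list. The `Node` class and `find_grouped_gemms` are illustrative names, not torch.fx or Inductor APIs. The sketch assumes the first input of each GEMM node is its activation and groups GEMMs that share that anchor and have identically shaped weights; the actual grouped_gemm_pass may check further conditions.

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class Node:
    # Hypothetical, simplified graph node (not torch.fx.Node).
    name: str
    op: str                                          # e.g. "placeholder", "linear", "silu", "mul"
    inputs: List[str] = field(default_factory=list)  # names of input nodes
    weight_shape: Optional[Tuple[int, int]] = None   # set only for GEMM nodes


def find_grouped_gemms(nodes, gemm_ops=("linear", "mm")):
    """Return [(anchor_name, [gemm_name, ...]), ...] for groups of >= 2 GEMMs.

    A group is a set of GEMMs that read the same anchor node (assumed to be
    their first input) and whose weights have identical shapes.
    """
    groups = defaultdict(list)
    for node in nodes:
        if node.op in gemm_ops and node.inputs:
            groups[(node.inputs[0], node.weight_shape)].append(node.name)
    return [(anchor, names) for (anchor, _shape), names in groups.items() if len(names) >= 2]


# Illustrative LlamaMLP-like graph (shapes are examples only).
graph = [
    Node("x", "placeholder"),
    Node("gate_proj", "linear", ["x"], (11008, 4096)),
    Node("up_proj", "linear", ["x"], (11008, 4096)),
    Node("act", "silu", ["gate_proj"]),
    Node("mul", "mul", ["act", "up_proj"]),
    Node("down_proj", "linear", ["mul"], (4096, 11008)),
]
print(find_grouped_gemms(graph))
# [('x', ['gate_proj', 'up_proj'])]
```

down_proj is not grouped because it is the only GEMM reading its anchor (`mul`).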

We also evaluated using MultiOutputPattern to enable pattern matching and fusion in the post-grad fusion passes. Its current limitation is that MultiOutputPattern requires a fixed number of output nodes when the pattern is defined, which does not fit a group with a variable number of GEMMs.

Inductor Lowering

After lowering into the Grouped GEMM Template, most of the flow is the same as for a standard template. The only extension is that the Grouped GEMM Template may have multiple output nodes. We define the template node with a MultiOutputLayout and represent its output buffers with MultiOutput nodes, each corresponding to one GEMM output.
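The resulting IR structure can be pictured with the toy classes below. They are hypothetical stand-ins, not Inductor's MultiOutputLayout or MultiOutput classes, and only capture the relationship described above: one template node whose layout marks it as multi-output, and one buffer per GEMM that refers back to the template node by name and selects its output by index.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class ToyMultiOutputLayout:
    # Hypothetical marker: the node produces a group of buffers, not one tensor.
    num_outputs: int


@dataclass
class ToyMultiOutput:
    # Hypothetical per-GEMM output buffer: refers back to the template node
    # by name and selects one of its outputs by index.
    name: str
    parent: str
    index: int
    size: Tuple[int, int]


@dataclass
class ToyGroupedGemmNode:
    name: str
    layout: ToyMultiOutputLayout
    outputs: List[ToyMultiOutput] = field(default_factory=list)


def lower_grouped_gemm(name: str, m: int, n: int, num_gemms: int) -> ToyGroupedGemmNode:
    """Build one template node plus one output buffer per GEMM (all M x N)."""
    node = ToyGroupedGemmNode(name, ToyMultiOutputLayout(num_gemms))
    for i in range(num_gemms):
        node.outputs.append(ToyMultiOutput(f"{name}_out{i}", name, i, (m, n)))
    return node


node = lower_grouped_gemm("grouped_gemm", m=4, n=11008, num_gemms=2)
print([(o.name, o.index, o.size) for o in node.outputs])
# [('grouped_gemm_out0', 0, (4, 11008)), ('grouped_gemm_out1', 1, (4, 11008))]
```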

Inductor Scheduler Nodes Fusions

In the scheduler node fusion phase:

  • First, we fuse the template node (with a MultiOutputLayout) and each GEMM output (MultiOutput) into a FusedSchedulerNode.
  • Then, we further fuse this FusedSchedulerNode with its epilogues, e.g., silu, mul, relu.

After this phase, we have a FusedSchedulerNode containing the Grouped GEMM and its epilogues. Next, the CPP backend generates code for this node using the CPP Grouped GEMM Template.
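The two fusion steps can be sketched on a toy scheduler graph. `SchedNode` and `fuse_grouped_gemm` are hypothetical names, and the legality check is deliberately simplified to "every buffer the node reads is produced inside the fused group"; the real scheduler applies additional checks before fusing an epilogue.

```python
from dataclasses import dataclass


@dataclass
class SchedNode:
    # Hypothetical, simplified scheduler node: what it reads and writes.
    name: str
    kind: str            # "template", "multi_output", "pointwise", ...
    reads: frozenset
    writes: frozenset


def fuse_grouped_gemm(template, nodes):
    # Step 1: the template node plus every MultiOutput that reads from it.
    fused = [template]
    produced = set(template.writes)
    for node in nodes:
        if node.kind == "multi_output" and node.reads <= produced:
            fused.append(node)
            produced |= node.writes
    # Step 2: repeatedly absorb pointwise epilogues whose inputs are all
    # produced inside the fused group (e.g. silu on one output, then a mul).
    changed = True
    while changed:
        changed = False
        for node in nodes:
            if node.kind == "pointwise" and node not in fused and node.reads <= produced:
                fused.append(node)
                produced |= node.writes
                changed = True
    return [n.name for n in fused]


template = SchedNode("grouped_gemm", "template", frozenset({"x"}), frozenset({"buf0"}))
nodes = [
    SchedNode("buf1", "multi_output", frozenset({"buf0"}), frozenset({"buf1"})),
    SchedNode("buf2", "multi_output", frozenset({"buf0"}), frozenset({"buf2"})),
    SchedNode("mul", "pointwise", frozenset({"buf3", "buf2"}), frozenset({"buf4"})),
    SchedNode("silu", "pointwise", frozenset({"buf1"}), frozenset({"buf3"})),
    SchedNode("down_proj", "template", frozenset({"buf4"}), frozenset({"buf5"})),
]
print(fuse_grouped_gemm(template, nodes))
# ['grouped_gemm', 'buf1', 'buf2', 'silu', 'mul']
```

The fixed-point loop in step 2 lets `mul` be fused even though it appears before `silu` in the node list.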

CPP Grouped GEMM Template

We define a CPP Grouped GEMM Template that extends the current CPP GEMM Template implementation with:

  • A flexible number of GEMMs
  • Each GEMM can have an independent input activation or share one with the other GEMMs
  • Each GEMM can have its own weight, but all weights have the same size
  • Each GEMM can have its own bias or no bias (None)
  • Each GEMM has its own epilogues
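The semantics of the features listed above can be pinned down with a pure-Python reference, assuming row-major nested lists, K x N weights, and epilogues given as element-wise Python callables (`grouped_gemm_reference` is a hypothetical name). Passing the same matrix object for several entries of `inputs` models a shared activation. Epilogues that combine several GEMM outputs, such as the `mul` in LlamaMLP, are outside this per-GEMM sketch.

```python
from typing import Callable, List, Optional, Sequence

Matrix = List[List[float]]


def grouped_gemm_reference(
    inputs: Sequence[Matrix],                   # one per GEMM; may repeat (shared activation)
    weights: Sequence[Matrix],                  # one per GEMM, all K x N
    biases: Sequence[Optional[List[float]]],    # length-N vector or None per GEMM
    epilogues: Sequence[Optional[Callable[[float], float]]],  # element-wise or None
) -> List[Matrix]:
    num_gemms = len(weights)
    if not (len(inputs) == len(biases) == len(epilogues) == num_gemms):
        raise ValueError("inputs, weights, biases and epilogues need one entry per GEMM")
    k_dim, n_dim = len(weights[0]), len(weights[0][0])
    for w in weights:
        if len(w) != k_dim or any(len(r) != n_dim for r in w):
            raise ValueError("all weights must have the same size")
    outputs = []
    for x, w, b, epi in zip(inputs, weights, biases, epilogues):
        out = []
        for row in x:
            out_row = []
            for n in range(n_dim):
                acc = sum(row[k] * w[k][n] for k in range(k_dim))
                if b is not None:
                    acc += b[n]
                out_row.append(epi(acc) if epi is not None else acc)
            out.append(out_row)
        outputs.append(out)
    return outputs


x = [[1.0, 2.0]]
w0 = [[1.0, 0.0], [0.0, 1.0]]
w1 = [[-1.0, 0.0], [0.0, 1.0]]
outs = grouped_gemm_reference(
    [x, x], [w0, w1], [[0.5, -0.5], None], [None, lambda v: max(v, 0.0)]
)
print(outs)
# [[[1.5, 1.5]], [[0.0, 2.0]]]
```

Here both GEMMs share `x`, only the first has a bias, and only the second has a ReLU-style epilogue.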

Specifically, we introduce a CppGroupedGemmTemplate class that inherits from CppGemmTemplate. Key methods, such as add_choices and render, are overridden to support the aforementioned features.
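To illustrate how overriding `render` in a subclass can generalize single-GEMM code generation to N GEMMs, here is a toy Python class pair that emits a naive C++ loop nest as a string. `ToyGemmTemplate` and `ToyGroupedGemmTemplate` are hypothetical and do not reflect the code that CppGemmTemplate actually generates; `add_choices` is not modeled. Each epilogue is given as a C++ expression in terms of `acc`.

```python
class ToyGemmTemplate:
    """Hypothetical single-GEMM template that renders a naive C++ loop nest."""

    def __init__(self, m, n, k, has_bias=False, epilogue="acc"):
        self.m, self.n, self.k = m, n, k
        self.has_bias = has_bias
        self.epilogue = epilogue  # C++ expression in terms of `acc`

    def render(self):
        return self._render_loops(1, [self.has_bias], [self.epilogue], shared_input=True)

    def _render_loops(self, num_gemms, has_bias, epilogues, shared_input):
        lines = [
            f"for (int64_t m = 0; m < {self.m}; ++m) {{",
            f"  for (int64_t n = 0; n < {self.n}; ++n) {{",
        ]
        for g in range(num_gemms):
            x = "X" if shared_input else f"X{g}"
            lines.append(f"    float acc{g} = 0.0f;")
            lines.append(
                f"    for (int64_t k = 0; k < {self.k}; ++k) "
                f"acc{g} += {x}[m * {self.k} + k] * W{g}[k * {self.n} + n];"
            )
            if has_bias[g]:
                lines.append(f"    acc{g} += B{g}[n];")
            # Naive textual substitution: assumes `acc` appears only as the accumulator name.
            expr = epilogues[g].replace("acc", f"acc{g}")
            lines.append(f"    Y{g}[m * {self.n} + n] = {expr};")
        lines += ["  }", "}"]
        return "\n".join(lines)


class ToyGroupedGemmTemplate(ToyGemmTemplate):
    """Hypothetical grouped variant: N GEMMs with per-GEMM bias and epilogue."""

    def __init__(self, m, n, k, has_bias, epilogues, shared_input=True):
        super().__init__(m, n, k)
        if len(has_bias) != len(epilogues):
            raise ValueError("need one bias flag and one epilogue per GEMM")
        self.group_has_bias = list(has_bias)
        self.group_epilogues = list(epilogues)
        self.shared_input = shared_input

    def render(self):
        # Override: emit every GEMM of the group inside one (m, n) loop nest.
        return self._render_loops(
            len(self.group_epilogues), self.group_has_bias,
            self.group_epilogues, self.shared_input,
        )


src = ToyGroupedGemmTemplate(4, 8, 16, [True, False], ["acc", "std::max(acc, 0.0f)"]).render()
print(src)
```

The subclass reuses the base loop-nest generator and only changes what `render` passes to it, which mirrors the idea of overriding key methods while inheriting the rest.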

Alternatives

No response

Additional context

No response

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov

Labels: module: inductor, oncall: pt2, triaged
