
[Performance]: Fuse padding onto GEMM by making the GEMM out-of-place #24917

@ProExpertProg

Description


Proposal to improve performance

Currently, at the start of `fused_moe`, we pad the hidden dim of the activations to comply with the fused MoE kernel's requirements. This results in a copy following the router GEMM (GPT-OSS, DeepSeek). In the captured `fx.Graph`, it looks something like:

```py
        ...
        mul_22: "bf16[s72, 2880]" = torch.ops.aten.mul.Tensor(convert_element_type_6, arg4_1);  convert_element_type_6 = arg4_1 = None

        constant_pad_nd: "bf16[s72, 3072]" = torch.ops.aten.constant_pad_nd.default(mul_22, [0, 192], 0.0)
        ...
```

Instead, we should write the output of `mul_22` directly into a pre-padded tensor by replacing the sequence of these two operations with an out-of-place mm call that takes a pre-allocated output tensor as an argument.
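
A minimal sketch of the idea in plain PyTorch (not vLLM's actual fusion pass): the result is written straight into a view of a pre-allocated padded buffer via the `out=` variant, so the separate `constant_pad_nd` copy disappears. All names and shapes here (`num_tokens`, `hidden`, `padded_hidden`, `router_weight`, `scale`) are illustrative, and float32 is used only so the snippet runs anywhere.

```py
import torch

num_tokens, hidden, padded_hidden = 16, 2880, 3072  # pad 2880 -> 3072 as in the graph above

x = torch.randn(num_tokens, hidden)
router_weight = torch.randn(hidden, hidden)
scale = torch.randn(hidden)

# Current pattern: GEMM -> mul -> constant_pad_nd (an extra full-tensor copy).
reference = torch.nn.functional.pad(torch.mm(x, router_weight) * scale,
                                    (0, padded_hidden - hidden))

# Proposed pattern: pre-allocate the padded buffer and write the result into a
# view of it, skipping the pad copy. The GEMM itself could equally take the view
# as its out= destination.
padded = torch.zeros(num_tokens, padded_hidden)
torch.mul(torch.mm(x, router_weight), scale, out=padded[:, :hidden])

torch.testing.assert_close(padded, reference)
```

Whether a given backend kernel writes into the strided view directly or falls back to an internal temporary likely depends on the GEMM implementation, so the rewrite should be confirmed with a profile.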

Report of performance regression

No response

Misc discussion on performance

No response

Your current environment (if you think it is necessary)

The output of `python collect_env.py`

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
