Proposal to improve performance
Currently, at the start of fused_moe, we pad the hidden dim of the activations to comply with the fused MoE kernel's requirements. This results in an extra copy following the router GEMM (GPT-OSS, DeepSeek). In the captured fx.Graph, it looks something like:
```python
...
mul_22: "bf16[s72, 2880]" = torch.ops.aten.mul.Tensor(convert_element_type_6, arg4_1); convert_element_type_6 = arg4_1 = None
constant_pad_nd: "bf16[s72, 3072]" = torch.ops.aten.constant_pad_nd.default(mul_22, [0, 192], 0.0)
...
```
Instead, we should write the output of mul_22 directly into a pre-padded tensor, replacing the sequence of these two operations with an out-variant mm call (one that takes a pre-allocated output tensor as an argument).
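A minimal sketch of the pattern, using the elementwise mul from the snippet above as a stand-in (shapes and tensor names are illustrative, not the actual vLLM graph or API):

```python
import torch

# Illustrative shapes matching the snippet above: hidden dim 2880 padded to 3072.
num_tokens, hidden, padded_hidden = 4, 2880, 3072

routed = torch.randn(num_tokens, hidden, dtype=torch.bfloat16)  # stand-in for convert_element_type_6
scale = torch.randn(hidden, dtype=torch.bfloat16)               # stand-in for arg4_1

# Current pattern: compute, then pad -> an extra full copy of the activations.
padded_via_copy = torch.nn.functional.pad(routed * scale, (0, padded_hidden - hidden))

# Proposed pattern (sketch): allocate the padded buffer up front and write the
# result directly into its leading slice, so no constant_pad_nd copy is needed.
padded = torch.zeros(num_tokens, padded_hidden, dtype=torch.bfloat16)
torch.mul(routed, scale, out=padded[:, :hidden])

assert torch.equal(padded, padded_via_copy)
```

The same idea applies if the producer is the router GEMM itself, e.g. `torch.mm(a, b, out=padded[:, :hidden])`.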
Report of performance regression
No response
Misc discussion on performance
No response
Your current environment (if you think it is necessary)
The output of `python collect_env.py`
Before submitting a new issue...
- Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.