Fuse linear projections, SiLU activation, and replace conv1d to reduce kernel launches #18392
Conversation
Force-pushed from 8e73144 to 0a92191
Nit: Can you update the summary with before and after tps for prefill and decode?
Force-pushed from 12b2e74 to 08543c6
Force-pushed from 0a92191 to 8d6f13a
Pull request overview
This PR optimizes the Qwen3.5 MoE example model to reduce CUDA kernel launches and improve decode/prefill throughput by fusing multiple linear projections, folding SiLU into the MoE GEMM2 kernel, and replacing a problematic conv1d lowering path with a manual depthwise convolution formulation.
Changes:
- Fuse attention Q/K/V (+ gate) projections and GatedDeltaNet input projections into single `nn.Linear` calls, plus fuse the shared-expert gate+up projection.
- Add post-load checkpoint weight fusion to concatenate separate HF projection weights into the new fused parameter layouts.
- Replace `F.conv1d` usage in GatedDeltaNet with a manual depthwise conv implementation and remove the export-time `conv1d -> conv2d` decomposition.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| examples/models/qwen3_5_moe/model.py | Fused projection modules, manual depthwise conv path, and checkpoint weight post-processing to match new fused parameter layouts. |
| examples/models/qwen3_5_moe/export.py | Removes now-unneeded conv1d decomposition during export. |
| backends/cuda/triton/kernels/fused_moe.py | Introduces a new Triton GEMM2 kernel that fuses SiLU(gate)*up activation into the down-projection. |
```python
qkv_conv = sum(
    conv_input[:, :, k : k + T_conv] * w[:, k : k + 1]
    for k in range(self.conv_kernel_size)
)
```
The manual depthwise conv accumulates in the input dtype (likely bf16) and uses Python `sum(...)` over a generator. This changes numerical behavior vs `F.conv1d` (which typically accumulates in fp32) and can also be brittle for `torch.export`/Dynamo tracing. Consider accumulating in fp32 (then casting back) and using an explicit loop or a stack+sum to avoid Python `sum` with an implicit 0 start value.
Suggested change:

```diff
-qkv_conv = sum(
-    conv_input[:, :, k : k + T_conv] * w[:, k : k + 1]
-    for k in range(self.conv_kernel_size)
-)
+# Accumulate depthwise conv in fp32 for numerical stability and export-friendliness
+conv_input_f32 = conv_input.to(torch.float32)
+w_f32 = w.to(torch.float32)
+qkv_conv = torch.zeros(
+    conv_input_f32.size(0),
+    conv_input_f32.size(1),
+    T_conv,
+    dtype=torch.float32,
+    device=conv_input_f32.device,
+)
+for k in range(self.conv_kernel_size):
+    qkv_conv += conv_input_f32[:, :, k : k + T_conv] * w_f32[:, k : k + 1]
+qkv_conv = qkv_conv.to(conv_input.dtype)
```
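For completeness, the stack+sum variant mentioned in the comment, written against the same names as the snippet above (a hedged sketch, not the PR's code):

```python
# Sketch of the stack+sum alternative: stack the K shifted slices along a new
# leading dim, accumulate in fp32, then cast back to the input dtype.
slices = torch.stack(
    [
        conv_input[:, :, k : k + T_conv].to(torch.float32)
        * w[:, k : k + 1].to(torch.float32)
        for k in range(self.conv_kernel_size)
    ],
    dim=0,
)
qkv_conv = slices.sum(dim=0).to(conv_input.dtype)
```

This avoids Python `sum`'s implicit 0 start value, but it materializes a `conv_kernel_size`-times-larger intermediate, so the explicit fp32 loop above is likely preferable at 8192 channels.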
```python
if layer_type == "full_attention":
    # Fuse Q + K + V into qkv_proj
    q_key = f"layers.{i}.attn._q_proj.weight"
    k_key = f"layers.{i}.attn._k_proj.weight"
    v_key = f"layers.{i}.attn._v_proj.weight"
    if q_key in state_dict:
        state_dict[f"layers.{i}.attn.qkv_proj.weight"] = torch.cat(
            [
                state_dict.pop(q_key),
                state_dict.pop(k_key),
                state_dict.pop(v_key),
            ],
```
In `_fuse_projection_weights`, checking only `q_key in state_dict` and then unconditionally popping `k_key`/`v_key` can raise `KeyError` on partially present/filtered checkpoints. Please guard on all required keys (q/k/v) before popping, or handle missing keys with a clear error/warning and skip fusion for that layer.
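A minimal sketch of that guard, reusing the key names from the diff above; the `dim` argument is an assumption, since the original `torch.cat` call is truncated:

```python
# Sketch: fuse only when all three projection weights are present; otherwise skip this layer.
q_key = f"layers.{i}.attn._q_proj.weight"
k_key = f"layers.{i}.attn._k_proj.weight"
v_key = f"layers.{i}.attn._v_proj.weight"
if all(key in state_dict for key in (q_key, k_key, v_key)):
    state_dict[f"layers.{i}.attn.qkv_proj.weight"] = torch.cat(
        [state_dict.pop(q_key), state_dict.pop(k_key), state_dict.pop(v_key)],
        dim=0,  # assumed: fused nn.Linear stacks output rows; the PR's actual call is truncated above
    )
```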
```python
else:
    # Fuse GDN in_proj_qkv + in_proj_z + in_proj_b + in_proj_a
    qkv_key = f"layers.{i}.attn._in_proj_qkv.weight"
    z_key = f"layers.{i}.attn._in_proj_z.weight"
    b_key = f"layers.{i}.attn._in_proj_b.weight"
    a_key = f"layers.{i}.attn._in_proj_a.weight"
    if qkv_key in state_dict:
        state_dict[f"layers.{i}.attn.in_proj.weight"] = torch.cat(
            [
                state_dict.pop(qkv_key),
                state_dict.pop(z_key),
                state_dict.pop(b_key),
                state_dict.pop(a_key),
            ],
```
Similarly, the GatedDeltaNet fusion path checks only `_in_proj_qkv.weight` but then unconditionally pops `_in_proj_z`/`_b`/`_a`. If any are missing, loading will crash with `KeyError`. Please verify all four keys exist before popping (or make the fusion conditional per-key with an explicit diagnostic).
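If an explicit diagnostic is preferred over silently skipping, a small helper along these lines could be shared by all three fusion branches; the helper name and logging are hypothetical, not part of the PR:

```python
import logging

logger = logging.getLogger(__name__)


def _pop_all_or_skip(state_dict, keys):
    """Pop and return all tensors for `keys`, or None if any are missing.

    Hypothetical helper illustrating the guarded-fusion pattern; not the PR's code.
    """
    missing = [k for k in keys if k not in state_dict]
    if missing:
        if len(missing) < len(keys):
            # Partially present keys usually indicate a stripped/filtered checkpoint.
            logger.warning("Skipping projection fusion; missing keys: %s", missing)
        return None
    return [state_dict.pop(k) for k in keys]
```

The GDN branch would then call `tensors = _pop_all_or_skip(state_dict, (qkv_key, z_key, b_key, a_key))` and only build the fused `in_proj.weight` when `tensors` is not `None`.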
```python
# Fuse shared expert gate + up into gate_up_proj
gate_key = f"layers.{i}.mlp.shared_expert._gate_proj.weight"
up_key = f"layers.{i}.mlp.shared_expert._up_proj.weight"
if gate_key in state_dict:
```
Shared-expert fusion checks only `_gate_proj.weight` and then pops `_up_proj.weight` unconditionally. This can raise `KeyError` if only one tensor is present (e.g., stripped/partial checkpoints). Please require both keys before fusing, or emit a clear error and leave the original tensors in `state_dict`.
Suggested change:

```diff
-if gate_key in state_dict:
+if gate_key in state_dict and up_key in state_dict:
```
Force-pushed from 8d6f13a to a18252c
Apply 5 optimizations validated on nano_qwen35_moe:

1. Fuse SiLU into MoE GEMM2: the new `_fused_moe_silu_kernel` reads gate+up from the GEMM1 output and applies SiLU on-the-fly during GEMM2, eliminating the intermediate buffer and 1 kernel launch per layer (see the reference sketch after this list).
2. Fuse QKV projections in full attention: separate `q_proj`, `k_proj`, `v_proj` replaced with a single `qkv_proj`. Saves 2 kernel launches per full-attention layer (10 layers = 20 launches).
3. Fuse GDN input projections: separate `in_proj_qkv`, `in_proj_z`, `in_proj_b`, `in_proj_a` replaced with a single `in_proj`. Saves 3 kernel launches per GDN layer (30 layers = 90 launches).
4. Fuse gate+up in the shared expert: separate `gate_proj`, `up_proj` replaced with a single `gate_up_proj`. Saves 1 kernel launch per layer (40 layers = 40 launches).
5. Replace `F.conv1d` with a manual depthwise conv (4 slice-multiply-adds). The `conv1d -> conv2d` decomposition generated a catastrophically slow Triton kernel (2.1 ms/call at 8192 channels); the manual approach produces simple element-wise ops that Inductor fuses efficiently. Eliminates 81.8% of decode CUDA time (64 ms -> 4 ms per step).
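For reference, a plain-PyTorch sketch of what the SiLU-fused GEMM2 computes for the tokens routed to one expert. The tensor names, shapes, and the [gate | up] layout of the GEMM1 output are assumptions for illustration; the actual Triton kernel lives in backends/cuda/triton/kernels/fused_moe.py:

```python
import torch
import torch.nn.functional as F


def fused_gemm2_reference(gemm1_out: torch.Tensor, w_down: torch.Tensor) -> torch.Tensor:
    # gemm1_out: [num_tokens, 2 * intermediate_size], assumed to hold gate and up
    # concatenated along the last dim; w_down: [intermediate_size, hidden_size].
    gate, up = gemm1_out.chunk(2, dim=-1)
    # Previously silu(gate) * up was materialized by a separate kernel before GEMM2;
    # the fused kernel applies SiLU on-the-fly while loading gate/up tiles for the GEMM.
    return (F.silu(gate) * up) @ w_down
```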
Throughput on A100:

|  | Decode (tokens/s) | Prefill (tokens/s) |
|---|---|---|
| Before | 12.41 | 47.3 |
| After | 58.5 | 96 |