Fuse linear projections, SiLU activation, and replace conv1d to reduce kernel launches #18392

Merged

mergennachin merged 1 commit into main from mnachin/qwen3_5_moe_fused_experts_v2 on Mar 23, 2026
Conversation

@mergennachin (Contributor) commented Mar 22, 2026

Apply 5 optimizations validated on nano_qwen35_moe:

  1. Fuse SiLU into MoE GEMM2: a new _fused_moe_silu_kernel reads gate+up
    from the GEMM1 output and applies SiLU on the fly during GEMM2,
    eliminating the intermediate buffer and 1 kernel launch per layer
    (reference semantics sketched after this list).

  2. Fuse QKV projections in full attention: separate q_proj, k_proj,
    v_proj replaced with a single qkv_proj. Saves 2 kernel launches per
    full-attention layer (10 layers = 20 launches). The same
    concat-and-split pattern, also sketched below, covers items 3 and 4.

  3. Fuse GDN input projections: separate in_proj_qkv, in_proj_z,
    in_proj_b, in_proj_a replaced with a single in_proj. Saves 3 kernel
    launches per GDN layer (30 layers = 90 launches).

  4. Fuse gate+up in the shared expert: separate gate_proj, up_proj
    replaced with a single gate_up_proj. Saves 1 kernel launch per layer
    (40 layers = 40 launches).

  5. Replace F.conv1d with a manual depthwise conv (4 slice-multiply-adds;
    equivalence check below). The conv1d -> conv2d decomposition generated
    a catastrophically slow Triton kernel (2.1 ms/call at 8192 channels),
    while the manual approach produces simple element-wise ops that
    Inductor fuses efficiently. Eliminates 81.8% of decode CUDA time
    (64 ms -> 4 ms per step).
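For reference, a minimal PyTorch sketch of the per-expert semantics the fused kernel must reproduce (the function and weight names expert_reference, w_gate_up, w_down are illustrative, not the PR's actual identifiers); the fusion moves the silu(gate) * up step out of a separate elementwise kernel and into GEMM2's load path:

import torch
import torch.nn.functional as F

def expert_reference(x, w_gate_up, w_down):
    # GEMM1 emits gate and up packed side by side in one [T, 2*I] buffer
    gate, up = (x @ w_gate_up.t()).chunk(2, dim=-1)
    # Unfused path: materialize this intermediate and launch an extra
    # elementwise kernel; the fused GEMM2 applies it while loading its input
    h = F.silu(gate) * up
    return h @ w_down.t()  # GEMM2 down-projection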
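The concat-and-split pattern behind items 2 to 4 looks roughly like the sketch below (module and dimension names are placeholders): one wide nn.Linear replaces several narrow ones, so a single GEMM launch produces all outputs, which are then split by slicing. Checkpoint weights are concatenated along dim 0 to match, as the _fuse_projection_weights excerpts later in this thread show.

import torch
import torch.nn as nn

class FusedQKV(nn.Module):
    # One GEMM instead of three; dimensions are illustrative.
    def __init__(self, dim: int, q_dim: int, kv_dim: int):
        super().__init__()
        self.q_dim, self.kv_dim = q_dim, kv_dim
        self.qkv_proj = nn.Linear(dim, q_dim + 2 * kv_dim, bias=False)

    def forward(self, x: torch.Tensor):
        qkv = self.qkv_proj(x)  # single kernel launch
        q, k, v = qkv.split([self.q_dim, self.kv_dim, self.kv_dim], dim=-1)
        return q, k, v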
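And a self-contained equivalence check for item 5 (shapes assumed for single-token decode; this is not the PR's exact code): the K slice-multiply-adds reproduce a depthwise F.conv1d exactly, but lower to plain elementwise ops that Inductor can fuse.

import torch
import torch.nn.functional as F

B, C, K, T = 2, 8192, 4, 1        # batch, channels, conv taps, decode length
x = torch.randn(B, C, T + K - 1)  # input with K-1 cached conv states prepended
w = torch.randn(C, K)             # one K-tap filter per channel

# Depthwise conv1d: groups == channels, weight shaped [C, 1, K]
ref = F.conv1d(x, w.unsqueeze(1), groups=C)

# Manual form: K slice-multiply-adds, no conv kernel involved
out = sum(x[:, :, k : k + T] * w[:, k : k + 1] for k in range(K))

torch.testing.assert_close(out, ref)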

Before (on A100):

Decode throughput: 12.41 tokens/s
Prefill throughput: 47.3 tokens/s

After (on A100):

Decode throughput: 58.5 tokens/s
Prefill throughput: 96 tokens/s

@mergennachin mergennachin requested a review from lucylq as a code owner March 22, 2026 19:43

pytorch-bot Bot commented Mar 22, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18392

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla Bot added the CLA Signed label Mar 22, 2026
@mergennachin mergennachin force-pushed the mnachin/qwen3_5_moe_fused_experts_v2 branch from 8e73144 to 0a92191 Compare March 22, 2026 19:44
@digantdesai (Contributor) commented:

Nit: Can you update the summary with before and after tps for prefill and decode?

@digantdesai (Contributor) left a review comment:

Thanks.

@mergennachin mergennachin force-pushed the mnachin/qwen3_5_moe_fused_experts branch from 12b2e74 to 08543c6 Compare March 23, 2026 19:39
Base automatically changed from mnachin/qwen3_5_moe_fused_experts to main March 23, 2026 20:27
Copilot AI review requested due to automatic review settings March 23, 2026 20:35
@mergennachin mergennachin force-pushed the mnachin/qwen3_5_moe_fused_experts_v2 branch from 0a92191 to 8d6f13a Compare March 23, 2026 20:35
github-actions Bot commented:
This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Copilot AI (Contributor) left a review comment:

Pull request overview

This PR optimizes the Qwen3.5 MoE example model to reduce CUDA kernel launches and improve decode/prefill throughput by fusing multiple linear projections, folding SiLU into the MoE GEMM2 kernel, and replacing a problematic conv1d lowering path with a manual depthwise convolution formulation.

Changes:

  • Fuse attention Q/K/V (+ gate) projections and GatedDeltaNet input projections into single nn.Linear calls, plus fuse shared-expert gate+up projection.
  • Add post-load checkpoint weight fusion to concatenate separate HF projection weights into the new fused parameter layouts.
  • Replace F.conv1d usage in GatedDeltaNet with a manual depthwise conv implementation and remove the export-time conv1d -> conv2d decomposition.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

  • examples/models/qwen3_5_moe/model.py: Fused projection modules, manual depthwise conv path, and checkpoint weight post-processing to match the new fused parameter layouts.
  • examples/models/qwen3_5_moe/export.py: Removes the now-unneeded conv1d decomposition during export.
  • backends/cuda/triton/kernels/fused_moe.py: Introduces a new Triton GEMM2 kernel that fuses the SiLU(gate)*up activation into the down-projection.


Comment thread on examples/models/qwen3_5_moe/model.py (Outdated), lines 365 to 368:
qkv_conv = sum(
    conv_input[:, :, k : k + T_conv] * w[:, k : k + 1]
    for k in range(self.conv_kernel_size)
)
Copilot AI commented Mar 23, 2026:

The manual depthwise conv accumulates in the input dtype (likely bf16) and uses Python sum(...) over a generator. This changes numerical behavior vs F.conv1d (which typically accumulates in fp32) and can also be brittle for torch.export/Dynamo tracing. Consider accumulating in fp32 (then cast back) and using an explicit loop or a stack+sum to avoid Python sum with an implicit 0 start value.

Suggested change
- qkv_conv = sum(
-     conv_input[:, :, k : k + T_conv] * w[:, k : k + 1]
-     for k in range(self.conv_kernel_size)
- )
+ # Accumulate depthwise conv in fp32 for numerical stability and export-friendliness
+ conv_input_f32 = conv_input.to(torch.float32)
+ w_f32 = w.to(torch.float32)
+ qkv_conv = torch.zeros(
+     conv_input_f32.size(0),
+     conv_input_f32.size(1),
+     T_conv,
+     dtype=torch.float32,
+     device=conv_input_f32.device,
+ )
+ for k in range(self.conv_kernel_size):
+     qkv_conv += conv_input_f32[:, :, k : k + T_conv] * w_f32[:, k : k + 1]
+ qkv_conv = qkv_conv.to(conv_input.dtype)

Comment on lines +678 to +689
if layer_type == "full_attention":
    # Fuse Q + K + V into qkv_proj
    q_key = f"layers.{i}.attn._q_proj.weight"
    k_key = f"layers.{i}.attn._k_proj.weight"
    v_key = f"layers.{i}.attn._v_proj.weight"
    if q_key in state_dict:
        state_dict[f"layers.{i}.attn.qkv_proj.weight"] = torch.cat(
            [
                state_dict.pop(q_key),
                state_dict.pop(k_key),
                state_dict.pop(v_key),
            ],
Copilot AI commented Mar 23, 2026:

In _fuse_projection_weights, checking only q_key in state_dict and then unconditionally popping k_key/v_key can raise KeyError on partially present/filtered checkpoints. Please guard on all required keys (q/k/v) before popping, or handle missing keys with a clear error/warning and skip fusion for that layer.

Comment on lines +692 to +705
else:
    # Fuse GDN in_proj_qkv + in_proj_z + in_proj_b + in_proj_a
    qkv_key = f"layers.{i}.attn._in_proj_qkv.weight"
    z_key = f"layers.{i}.attn._in_proj_z.weight"
    b_key = f"layers.{i}.attn._in_proj_b.weight"
    a_key = f"layers.{i}.attn._in_proj_a.weight"
    if qkv_key in state_dict:
        state_dict[f"layers.{i}.attn.in_proj.weight"] = torch.cat(
            [
                state_dict.pop(qkv_key),
                state_dict.pop(z_key),
                state_dict.pop(b_key),
                state_dict.pop(a_key),
            ],
Copilot AI commented Mar 23, 2026:

Similarly, the GatedDeltaNet fusion path checks only _in_proj_qkv.weight but then unconditionally pops _in_proj_z/_b/_a. If any are missing, loading will crash with KeyError. Please verify all four keys exist before popping (or make the fusion conditional per-key with an explicit diagnostic).

# Fuse shared expert gate + up into gate_up_proj
gate_key = f"layers.{i}.mlp.shared_expert._gate_proj.weight"
up_key = f"layers.{i}.mlp.shared_expert._up_proj.weight"
if gate_key in state_dict:
Copilot AI commented Mar 23, 2026:

Shared-expert fusion checks only _gate_proj.weight and then pops _up_proj.weight unconditionally. This can KeyError if only one tensor is present (e.g., stripped/partial checkpoints). Please require both keys before fusing, or emit a clear error and leave the original tensors in state_dict.

Suggested change
- if gate_key in state_dict:
+ if gate_key in state_dict and up_key in state_dict:

@mergennachin mergennachin force-pushed the mnachin/qwen3_5_moe_fused_experts_v2 branch from 8d6f13a to a18252c Compare March 23, 2026 21:47
@mergennachin mergennachin merged commit f479ecf into main Mar 23, 2026
144 of 147 checks passed
@mergennachin mergennachin deleted the mnachin/qwen3_5_moe_fused_experts_v2 branch March 23, 2026 21:51