Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 108 additions & 12 deletions backends/cuda/triton/kernels/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,109 @@ def _fused_moe_kernel(
tl.store(c_ptrs, acc.to(compute_type), mask=n_mask)


@triton.jit
def _fused_moe_silu_kernel(
# Pointers
A, # [M * top_k, 2*inter] bf16 GEMM1 output (gate | up)
B, # [E, N, K//2] int8 packed INT4 weights
C, # [M * top_k, N] bf16 output
B_scale, # [E, N, K//group_size] bf16 scales
topk_ids, # [M * top_k] int64 expert indices
topk_weights, # [M * top_k] float32 router weights
# Dimensions
N: tl.constexpr,
K: tl.constexpr, # intermediate_size
num_token_expert_pairs,
# Strides
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_bse,
stride_bsk,
stride_bsn,
# Config
group_size: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
compute_type: tl.constexpr,
):
"""GEMM2 with fused SiLU activation.

Reads gate and up columns from GEMM1 output (A), applies SiLU(gate)*up
on-the-fly, and multiplies by INT4 w2 weights. Router weights are applied
to the output. Eliminates the intermediate activation buffer.
"""
pid = tl.program_id(0)
num_n_blocks = tl.cdiv(N, BLOCK_SIZE_N)
pair_idx = pid // num_n_blocks
n_block = pid % num_n_blocks

if pair_idx >= num_token_expert_pairs:
return

expert_id = tl.load(topk_ids + pair_idx).to(tl.int64)

offs_n = n_block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
n_mask = offs_n < N
offs_k = tl.arange(0, BLOCK_SIZE_K)

# A pointers: gate at columns [0, K), up at columns [K, 2*K)
a_gate_ptrs = A + pair_idx * stride_am + offs_k * stride_ak
a_up_ptrs = a_gate_ptrs + K * stride_ak

# B pointer: [expert_id, offs_n, offs_k//2]
b_ptrs = (
B
+ expert_id * stride_be
+ (offs_k[:, None] // 2) * stride_bk
+ offs_n[None, :] * stride_bn
)
b_shifter = (offs_k[:, None] % 2) * 4

acc = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32)

for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
k_remaining = K - k_step * BLOCK_SIZE_K
k_mask = offs_k < k_remaining

# Load gate and up, apply SiLU(gate) * up
gate = tl.load(a_gate_ptrs, mask=k_mask, other=0.0).to(tl.float32)
up = tl.load(a_up_ptrs, mask=k_mask, other=0.0)
a = (gate * tl.sigmoid(gate) * up).to(compute_type)

# Load and dequantize INT4 weights
b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0)
b = (b >> b_shifter) & 0xF

scale_ptrs = (
B_scale
+ expert_id * stride_bse
+ offs_n[None, :] * stride_bsn
+ ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk
)
b_scale = tl.load(
scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
).to(tl.float32)

b_dequant = ((b.to(tl.float32) - 8.0) * b_scale).to(compute_type)
acc += tl.sum(a[:, None].to(compute_type) * b_dequant, axis=0)

a_gate_ptrs += BLOCK_SIZE_K * stride_ak
a_up_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk

# Multiply by router weight
weight = tl.load(topk_weights + pair_idx)
acc = acc * weight

c_ptrs = C + pair_idx * stride_cm + offs_n * stride_cn
tl.store(c_ptrs, acc.to(compute_type), mask=n_mask)


# ---------------------------------------------------------------------------
# triton_op wrapper
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -231,18 +334,13 @@ def fused_moe(
compute_type=tl.bfloat16,
)

# ---- Activation: SiLU(gate) * up ----
gate = cache1[:, :intermediate]
up = cache1[:, intermediate:]
cache2 = torch.nn.functional.silu(gate) * up

# ---- GEMM2: down projection, multiply by router weights ----
# ---- GEMM2 with fused SiLU: reads gate+up from cache1, no intermediate buffer ----
cache3 = torch.empty(
num_pairs, N2, dtype=hidden_states.dtype, device=hidden_states.device
)
grid2 = (num_pairs * triton.cdiv(N2, BLOCK_SIZE_N),)
wrap_triton(_fused_moe_kernel)[grid2](
cache2,
wrap_triton(_fused_moe_silu_kernel)[grid2](
cache1,
w2,
cache3,
w2_scale,
Expand All @@ -251,8 +349,8 @@ def fused_moe(
N=N2,
K=intermediate,
num_token_expert_pairs=num_pairs,
stride_am=cache2.stride(0),
stride_ak=cache2.stride(1),
stride_am=cache1.stride(0),
stride_ak=cache1.stride(1),
stride_be=w2.stride(0),
stride_bk=w2.stride(2),
stride_bn=w2.stride(1),
Expand All @@ -264,8 +362,6 @@ def fused_moe(
group_size=group_size,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
MUL_ROUTED_WEIGHT=True,
top_k=1,
compute_type=tl.bfloat16,
)

Expand Down
6 changes: 0 additions & 6 deletions examples/models/qwen3_5_moe/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,6 @@ def export_and_lower(model, config, args):
to_edge_transform_and_lower,
)
from executorch.exir.passes import MemoryPlanningPass
from torch._inductor.decomposition import conv1d_to_conv2d
from torch.export import Dim, export

# Coordinate descent recompiles each kernel trying config perturbations,
Expand All @@ -293,11 +292,6 @@ def export_and_lower(model, config, args):
)
print("Export successful!")

# conv1d → conv2d decomposition (required for CUDA backend)
exported = exported.run_decompositions(
{torch.ops.aten.conv1d.default: conv1d_to_conv2d}
)

# Lower with CUDA backend
print("Lowering to ExecuTorch with CUDA...")
compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")]
Expand Down
Loading
Loading