1 change: 1 addition & 0 deletions .ci/scripts/export_model_artifact.sh
@@ -424,6 +424,7 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then
test -f "${OUTPUT_DIR}/model.pte"
test -f "${OUTPUT_DIR}/aoti_cuda_blob.ptd"
ls -al "${OUTPUT_DIR}"

exit 0
fi

45 changes: 36 additions & 9 deletions backends/cuda/triton/kernels/fused_moe.py
@@ -38,6 +38,32 @@
from torch.library import triton_op, wrap_triton


# Autotune configs for GEMM1 (_fused_moe_kernel).
# Top performers from CI benchmark on A100-SXM4-80GB, Qwen3.5 MoE dimensions
# (M=1, N=1024, K=2048, 8 experts, group_size=128).
_GEMM1_CONFIGS = [
triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=2),
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=4),
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=5),
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=3),
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=5),
]

# Autotune configs for GEMM2 (_fused_moe_silu_kernel).
# Top performers from CI benchmark on A100-SXM4-80GB, Qwen3.5 MoE dimensions
# (M=1, N=2048, K=512, 8 experts, group_size=128).
_GEMM2_CONFIGS = [
triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 128}, num_warps=2, num_stages=4),
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=4),
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256}, num_warps=4, num_stages=4),
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=3),
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=4, num_stages=3),
]


@triton.autotune(configs=_GEMM1_CONFIGS, key=["N", "K"])
@triton.jit
def _fused_moe_kernel(
# Pointers
@@ -147,6 +173,7 @@ def _fused_moe_kernel(
tl.store(c_ptrs, acc.to(compute_type), mask=n_mask)


@triton.autotune(configs=_GEMM2_CONFIGS, key=["N", "K"])
@triton.jit
def _fused_moe_silu_kernel(
# Pointers
@@ -294,18 +321,19 @@ def fused_moe(
N2 = w2.shape[1] # hidden_size
num_pairs = M * top_k

BLOCK_SIZE_N = 32
BLOCK_SIZE_K = 32

# Flatten topk tensors
topk_ids_flat = topk_ids.reshape(-1)
topk_weights_flat = topk_weights.reshape(-1)

# ---- GEMM1: gate + up projection ----
# Grid is a callable because BLOCK_SIZE_N is selected by autotune
cache1 = torch.empty(
num_pairs, N1, dtype=hidden_states.dtype, device=hidden_states.device
)
grid1 = (num_pairs * triton.cdiv(N1, BLOCK_SIZE_N),)

def grid1(meta):
return (num_pairs * triton.cdiv(N1, meta["BLOCK_SIZE_N"]),)

wrap_triton(_fused_moe_kernel)[grid1](
hidden_states,
w1,
@@ -327,8 +355,6 @@ def fused_moe(
stride_bsk=w1_scale.stride(2),
stride_bsn=w1_scale.stride(1),
group_size=group_size,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
MUL_ROUTED_WEIGHT=False,
top_k=top_k,
compute_type=tl.bfloat16,
@@ -338,7 +364,10 @@ def fused_moe(
cache3 = torch.empty(
num_pairs, N2, dtype=hidden_states.dtype, device=hidden_states.device
)
grid2 = (num_pairs * triton.cdiv(N2, BLOCK_SIZE_N),)

def grid2(meta):
return (num_pairs * triton.cdiv(N2, meta["BLOCK_SIZE_N"]),)

wrap_triton(_fused_moe_silu_kernel)[grid2](
cache1,
w2,
@@ -360,8 +389,6 @@ def fused_moe(
stride_bsk=w2_scale.stride(2),
stride_bsn=w2_scale.stride(1),
group_size=group_size,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
compute_type=tl.bfloat16,
)
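
As a reference for how the pattern in this diff behaves, here is a minimal, self-contained sketch of a Triton kernel autotuned the same way: a list of `triton.Config` candidates is benchmarked the first time a new `(N, K)` key is seen, and the winner is cached for later launches. The kernel name, config values, and the row-sum computation are purely illustrative; none of this is part of the PR.

```python
# Illustrative sketch only -- a toy row-sum kernel, not the PR's fused MoE kernels.
import triton
import triton.language as tl

_TOY_CONFIGS = [
    triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=4),
]


# key=["N", "K"]: every config is benchmarked the first time a new (N, K) pair
# is seen; the fastest one is cached and reused for subsequent launches.
@triton.autotune(configs=_TOY_CONFIGS, key=["N", "K"])
@triton.jit
def _toy_rowsum_kernel(
    x_ptr, y_ptr, N, K,
    BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    pid = tl.program_id(0)
    offs_n = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask_n = offs_n < N
    acc = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        offs_k = k + tl.arange(0, BLOCK_SIZE_K)
        mask = mask_n[:, None] & (offs_k[None, :] < K)
        tile = tl.load(x_ptr + offs_n[:, None] * K + offs_k[None, :], mask=mask, other=0.0)
        acc += tl.sum(tile, axis=1)
    tl.store(y_ptr + offs_n, acc, mask=mask_n)
```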

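The call-site half of the same pattern mirrors the `grid1`/`grid2` change in the diff: once the block sizes come from the autotuner, they can no longer be passed as keyword arguments, and the launch grid has to be a callable that reads the chosen `BLOCK_SIZE_N` out of `meta`. This is a hypothetical wrapper around the toy kernel sketched above, not code from this PR.

```python
import torch


def toy_rowsum(x: torch.Tensor) -> torch.Tensor:
    """Row sums of a 2D tensor via the autotuned toy kernel above (illustrative)."""
    N, K = x.shape
    y = torch.empty(N, dtype=torch.float32, device=x.device)

    # Grid is a callable because BLOCK_SIZE_N is selected by the autotuner and
    # is only known at launch time, via `meta`.
    def grid(meta):
        return (triton.cdiv(N, meta["BLOCK_SIZE_N"]),)

    # No BLOCK_SIZE_N / BLOCK_SIZE_K kwargs here -- the autotuner supplies them;
    # passing them explicitly would conflict with @triton.autotune.
    _toy_rowsum_kernel[grid](x, y, N, K)
    return y
```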