diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh index 98f0d7adfa4..1797663bd1d 100755 --- a/.ci/scripts/export_model_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -424,6 +424,7 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then test -f "${OUTPUT_DIR}/model.pte" test -f "${OUTPUT_DIR}/aoti_cuda_blob.ptd" ls -al "${OUTPUT_DIR}" + exit 0 fi diff --git a/backends/cuda/triton/kernels/fused_moe.py b/backends/cuda/triton/kernels/fused_moe.py index 04e284b5186..98a86698bc4 100644 --- a/backends/cuda/triton/kernels/fused_moe.py +++ b/backends/cuda/triton/kernels/fused_moe.py @@ -38,6 +38,32 @@ from torch.library import triton_op, wrap_triton +# Autotune configs for GEMM1 (_fused_moe_kernel). +# Top performers from CI benchmark on A100-SXM4-80GB, Qwen3.5 MoE dimensions +# (M=1, N=1024, K=2048, 8 experts, group_size=128). +_GEMM1_CONFIGS = [ + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=4), + triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=5), + triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=3), + triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=5), +] + +# Autotune configs for GEMM2 (_fused_moe_silu_kernel). +# Top performers from CI benchmark on A100-SXM4-80GB, Qwen3.5 MoE dimensions +# (M=1, N=2048, K=512, 8 experts, group_size=128). 
+_GEMM2_CONFIGS = [ + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 128}, num_warps=2, num_stages=4), + triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=4), + triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=3), + triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=4, num_stages=3), +] + + +@triton.autotune(configs=_GEMM1_CONFIGS, key=["N", "K"]) @triton.jit def _fused_moe_kernel( # Pointers @@ -147,6 +173,7 @@ def _fused_moe_kernel( tl.store(c_ptrs, acc.to(compute_type), mask=n_mask) +@triton.autotune(configs=_GEMM2_CONFIGS, key=["N", "K"]) @triton.jit def _fused_moe_silu_kernel( # Pointers @@ -294,18 +321,19 @@ def fused_moe( N2 = w2.shape[1] # hidden_size num_pairs = M * top_k - BLOCK_SIZE_N = 32 - BLOCK_SIZE_K = 32 - # Flatten topk tensors topk_ids_flat = topk_ids.reshape(-1) topk_weights_flat = topk_weights.reshape(-1) # ---- GEMM1: gate + up projection ---- + # Grid is a callable because BLOCK_SIZE_N is selected by autotune cache1 = torch.empty( num_pairs, N1, dtype=hidden_states.dtype, device=hidden_states.device ) - grid1 = (num_pairs * triton.cdiv(N1, BLOCK_SIZE_N),) + + def grid1(meta): + return (num_pairs * triton.cdiv(N1, meta["BLOCK_SIZE_N"]),) + wrap_triton(_fused_moe_kernel)[grid1]( hidden_states, w1, @@ -327,8 +355,6 @@ def fused_moe( stride_bsk=w1_scale.stride(2), stride_bsn=w1_scale.stride(1), group_size=group_size, - BLOCK_SIZE_N=BLOCK_SIZE_N, - BLOCK_SIZE_K=BLOCK_SIZE_K, MUL_ROUTED_WEIGHT=False, top_k=top_k, compute_type=tl.bfloat16, @@ -338,7 +364,10 @@ def fused_moe( cache3 = torch.empty( num_pairs, N2, dtype=hidden_states.dtype, device=hidden_states.device ) - grid2 = (num_pairs * triton.cdiv(N2, BLOCK_SIZE_N),) + + def grid2(meta): + return (num_pairs * triton.cdiv(N2, meta["BLOCK_SIZE_N"]),) + 
wrap_triton(_fused_moe_silu_kernel)[grid2]( cache1, w2, @@ -360,8 +389,6 @@ def fused_moe( stride_bsk=w2_scale.stride(2), stride_bsn=w2_scale.stride(1), group_size=group_size, - BLOCK_SIZE_N=BLOCK_SIZE_N, - BLOCK_SIZE_K=BLOCK_SIZE_K, compute_type=tl.bfloat16, )