From 664ecb12a84e3d7fe0f0550044fbda80281fdb50 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Fri, 27 Mar 2026 10:09:59 -0700 Subject: [PATCH 1/9] Tune MoE kernel block sizes for M=1 decode Change fused_moe kernel config from (N=32, K=32, warps=2, stages=2) to (N=128, K=64, warps=4, stages=3). Benchmarked on A100 for Qwen3.5 MoE dimensions, delivering a 33.6% reduction in MoE kernel time and a 19.6% reduction in overall wall clock time, with zero impact on non-MoE kernels. --- backends/cuda/triton/kernels/fused_moe.py | 29 ++++++++++++++++------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/backends/cuda/triton/kernels/fused_moe.py b/backends/cuda/triton/kernels/fused_moe.py index 04e284b5186..54251586b06 100644 --- a/backends/cuda/triton/kernels/fused_moe.py +++ b/backends/cuda/triton/kernels/fused_moe.py @@ -37,6 +37,15 @@ import triton.language as tl from torch.library import triton_op, wrap_triton +# Block sizes tuned for M=1 decode on Qwen3.5 MoE dimensions. +# Benchmarked on H100: N=128, K=64, warps=4, stages=3 gives 1.73x speedup +# for GEMM1 (N=1024, K=2048) and 1.42x for GEMM2 (N=2048, K=512) vs the +# original N=32, K=32, warps=2, stages=2 baseline. 
+_BLOCK_SIZE_N: int = 128 +_BLOCK_SIZE_K: int = 64 +_NUM_WARPS: int = 4 +_NUM_STAGES: int = 3 + @triton.jit def _fused_moe_kernel( @@ -294,18 +303,16 @@ def fused_moe( N2 = w2.shape[1] # hidden_size num_pairs = M * top_k - BLOCK_SIZE_N = 32 - BLOCK_SIZE_K = 32 - # Flatten topk tensors topk_ids_flat = topk_ids.reshape(-1) topk_weights_flat = topk_weights.reshape(-1) # ---- GEMM1: gate + up projection ---- + # Grid is a lambda because BLOCK_SIZE_N is selected by autotune cache1 = torch.empty( num_pairs, N1, dtype=hidden_states.dtype, device=hidden_states.device ) - grid1 = (num_pairs * triton.cdiv(N1, BLOCK_SIZE_N),) + grid1 = (num_pairs * triton.cdiv(N1, _BLOCK_SIZE_N),) wrap_triton(_fused_moe_kernel)[grid1]( hidden_states, w1, @@ -327,18 +334,20 @@ def fused_moe( stride_bsk=w1_scale.stride(2), stride_bsn=w1_scale.stride(1), group_size=group_size, - BLOCK_SIZE_N=BLOCK_SIZE_N, - BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_N=_BLOCK_SIZE_N, + BLOCK_SIZE_K=_BLOCK_SIZE_K, MUL_ROUTED_WEIGHT=False, top_k=top_k, compute_type=tl.bfloat16, + num_warps=_NUM_WARPS, + num_stages=_NUM_STAGES, ) # ---- GEMM2 with fused SiLU: reads gate+up from cache1, no intermediate buffer ---- cache3 = torch.empty( num_pairs, N2, dtype=hidden_states.dtype, device=hidden_states.device ) - grid2 = (num_pairs * triton.cdiv(N2, BLOCK_SIZE_N),) + grid2 = (num_pairs * triton.cdiv(N2, _BLOCK_SIZE_N),) wrap_triton(_fused_moe_silu_kernel)[grid2]( cache1, w2, @@ -360,9 +369,11 @@ def fused_moe( stride_bsk=w2_scale.stride(2), stride_bsn=w2_scale.stride(1), group_size=group_size, - BLOCK_SIZE_N=BLOCK_SIZE_N, - BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_N=_BLOCK_SIZE_N, + BLOCK_SIZE_K=_BLOCK_SIZE_K, compute_type=tl.bfloat16, + num_warps=_NUM_WARPS, + num_stages=_NUM_STAGES, ) # ---- Sum across top-k experts ---- From 5323c3a2d8f736016912aba4e66084bb0b9b83a3 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Sun, 29 Mar 2026 17:31:17 -0700 Subject: [PATCH 2/9] Use @triton.autotune for MoE kernels instead of static block 
sizes Replace hardcoded block sizes with @triton.autotune for both GEMM1 (_fused_moe_kernel) and GEMM2 (_fused_moe_silu_kernel). Each kernel gets its own set of 4 autotune configs derived from standalone benchmarking on A100 with Qwen3.5 MoE dimensions (M=1 decode, INT4 HQQ, group_size=128). GEMM1 configs optimized for N=1024, K=2048. GEMM2 configs optimized for N=2048, K=512. Results vs baseline (N=32, K=32): - GEMM1: 190.1 -> 124.5 us/call (-34.5%) - GEMM2: 87.2 -> 48.2 us/call (-44.7%) - Overall MoE: -37.7% - E2E decode: 45.13 -> 53.86 tok/s (+19.3%) Keeping configs to 4 per kernel avoids the AOTI fatbin OOM issue seen with larger autotune config sets. --- backends/cuda/triton/kernels/fused_moe.py | 44 +++++++++++++---------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/backends/cuda/triton/kernels/fused_moe.py b/backends/cuda/triton/kernels/fused_moe.py index 54251586b06..38cb5309871 100644 --- a/backends/cuda/triton/kernels/fused_moe.py +++ b/backends/cuda/triton/kernels/fused_moe.py @@ -37,16 +37,29 @@ import triton.language as tl from torch.library import triton_op, wrap_triton -# Block sizes tuned for M=1 decode on Qwen3.5 MoE dimensions. -# Benchmarked on H100: N=128, K=64, warps=4, stages=3 gives 1.73x speedup -# for GEMM1 (N=1024, K=2048) and 1.42x for GEMM2 (N=2048, K=512) vs the -# original N=32, K=32, warps=2, stages=2 baseline. -_BLOCK_SIZE_N: int = 128 -_BLOCK_SIZE_K: int = 64 -_NUM_WARPS: int = 4 -_NUM_STAGES: int = 3 - +# Autotune configs for GEMM1 (_fused_moe_kernel). +# Top performers from standalone benchmark on A100, Qwen3.5 MoE dimensions +# (M=1, N=1024, K=2048, 8 experts, group_size=128). 
+_GEMM1_CONFIGS = [ + triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=3), + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=3), +] + +# Autotune configs for GEMM2 (_fused_moe_silu_kernel). +# Top performers from standalone benchmark on A100, Qwen3.5 MoE dimensions +# (M=1, N=2048, K=512, 8 experts, group_size=128). +_GEMM2_CONFIGS = [ + triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, num_warps=2, num_stages=3), + triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256}, num_warps=4, num_stages=5), +] + + +@triton.autotune(configs=_GEMM1_CONFIGS, key=["N", "K"]) @triton.jit def _fused_moe_kernel( # Pointers @@ -156,6 +169,7 @@ def _fused_moe_kernel( tl.store(c_ptrs, acc.to(compute_type), mask=n_mask) +@triton.autotune(configs=_GEMM2_CONFIGS, key=["N", "K"]) @triton.jit def _fused_moe_silu_kernel( # Pointers @@ -312,7 +326,7 @@ def fused_moe( cache1 = torch.empty( num_pairs, N1, dtype=hidden_states.dtype, device=hidden_states.device ) - grid1 = (num_pairs * triton.cdiv(N1, _BLOCK_SIZE_N),) + grid1 = lambda meta: (num_pairs * triton.cdiv(N1, meta["BLOCK_SIZE_N"]),) wrap_triton(_fused_moe_kernel)[grid1]( hidden_states, w1, @@ -334,20 +348,16 @@ def fused_moe( stride_bsk=w1_scale.stride(2), stride_bsn=w1_scale.stride(1), group_size=group_size, - BLOCK_SIZE_N=_BLOCK_SIZE_N, - BLOCK_SIZE_K=_BLOCK_SIZE_K, MUL_ROUTED_WEIGHT=False, top_k=top_k, compute_type=tl.bfloat16, - num_warps=_NUM_WARPS, - num_stages=_NUM_STAGES, ) # ---- GEMM2 with fused SiLU: reads gate+up from cache1, no intermediate buffer ---- cache3 = torch.empty( num_pairs, N2, 
dtype=hidden_states.dtype, device=hidden_states.device ) - grid2 = (num_pairs * triton.cdiv(N2, _BLOCK_SIZE_N),) + grid2 = lambda meta: (num_pairs * triton.cdiv(N2, meta["BLOCK_SIZE_N"]),) wrap_triton(_fused_moe_silu_kernel)[grid2]( cache1, w2, @@ -369,11 +379,7 @@ def fused_moe( stride_bsk=w2_scale.stride(2), stride_bsn=w2_scale.stride(1), group_size=group_size, - BLOCK_SIZE_N=_BLOCK_SIZE_N, - BLOCK_SIZE_K=_BLOCK_SIZE_K, compute_type=tl.bfloat16, - num_warps=_NUM_WARPS, - num_stages=_NUM_STAGES, ) # ---- Sum across top-k experts ---- From df548634b7e9816ef9861ac827843ff2bb892620 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Sun, 29 Mar 2026 20:21:04 -0700 Subject: [PATCH 3/9] Add baseline (32,32) to autotune configs and GPU diagnostics to CI Include the original default block sizes (N=32, K=32) in both GEMM1 and GEMM2 autotune candidate lists to prevent perf regression on hardware where smaller block sizes are optimal. Add GPU diagnostic output to test_model_e2e.sh to help investigate perf discrepancies between local and CI environments (GPU variant, memory bandwidth, CUDA version, etc.). 
--- .ci/scripts/test_model_e2e.sh | 15 +++++++++++++++ backends/cuda/triton/kernels/fused_moe.py | 2 ++ 2 files changed, 17 insertions(+) diff --git a/.ci/scripts/test_model_e2e.sh b/.ci/scripts/test_model_e2e.sh index 85bd327cfc6..008cb9caac2 100755 --- a/.ci/scripts/test_model_e2e.sh +++ b/.ci/scripts/test_model_e2e.sh @@ -101,6 +101,21 @@ fi echo "Testing model: $HF_MODEL (quantization: $QUANT_NAME)" +# GPU diagnostics — helps compare CI vs local performance +echo "::group::GPU Diagnostics" +if command -v nvidia-smi &> /dev/null; then + nvidia-smi --query-gpu=name,memory.total,pcie.link.gen.max,pcie.link.width.max,clocks.max.sm,clocks.max.mem --format=csv + echo "---" + nvidia-smi -q | grep -E "Product Name|Product Brand|GPU UUID|GPU Part Number|FB Memory Usage|BAR1 Memory Usage|GPU Current Temp|GPU Max Operating Temp|Power Draw|Power Limit|Max Clocks|Clocks$" | head -20 + echo "---" + echo "CUDA version (nvcc):" + nvcc --version 2>/dev/null || echo "nvcc not found" + echo "---" + echo "PyTorch CUDA info:" + python -c "import torch; print(f'torch.version.cuda={torch.version.cuda}'); print(f'torch.cuda.get_device_name()={torch.cuda.get_device_name()}'); print(f'torch.cuda.get_device_properties(0)={torch.cuda.get_device_properties(0)}')" 2>/dev/null || echo "PyTorch not available yet" +fi +echo "::endgroup::" + # Make sure model.pte exists if [ ! -f "$MODEL_DIR/model.pte" ]; then echo "Error: model.pte not found in $MODEL_DIR" diff --git a/backends/cuda/triton/kernels/fused_moe.py b/backends/cuda/triton/kernels/fused_moe.py index 38cb5309871..dea52f843e2 100644 --- a/backends/cuda/triton/kernels/fused_moe.py +++ b/backends/cuda/triton/kernels/fused_moe.py @@ -42,6 +42,7 @@ # Top performers from standalone benchmark on A100, Qwen3.5 MoE dimensions # (M=1, N=1024, K=2048, 8 experts, group_size=128). 
_GEMM1_CONFIGS = [ + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=2), triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3), triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=3), triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=3), @@ -52,6 +53,7 @@ # Top performers from standalone benchmark on A100, Qwen3.5 MoE dimensions # (M=1, N=2048, K=512, 8 experts, group_size=128). _GEMM2_CONFIGS = [ + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=2), triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128}, num_warps=2, num_stages=2), triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, num_warps=2, num_stages=3), triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=3), From dacd3a4f7597b1df441c5893249be245fbb21e74 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Sun, 29 Mar 2026 23:32:45 -0700 Subject: [PATCH 4/9] Add artifact checksums and version diagnostics to CI export Print md5sum of exported model.pte and aoti_cuda_blob.ptd on CI, along with local reference checksums and PyTorch/Triton/torchao versions, to help diagnose cross-machine perf discrepancies. 
--- .ci/scripts/export_model_artifact.sh | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh index 98f0d7adfa4..1dfc02ec737 100755 --- a/.ci/scripts/export_model_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -424,6 +424,17 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then test -f "${OUTPUT_DIR}/model.pte" test -f "${OUTPUT_DIR}/aoti_cuda_blob.ptd" ls -al "${OUTPUT_DIR}" + + # Diagnostic: print checksums for cross-machine comparison + echo "::group::Artifact checksums" + md5sum "${OUTPUT_DIR}/model.pte" "${OUTPUT_DIR}/aoti_cuda_blob.ptd" + echo "Local reference checksums:" + echo " model.pte: 3b79cbc9d921b6eaa2d655ede993f6a7" + echo " aoti_cuda_blob.ptd: 2c8d0d31004acbd6dc43118eddabf700" + echo "---" + python -c "import torch; print(f'torch={torch.__version__}'); import triton; print(f'triton={triton.__version__}'); import torchao; print(f'torchao={torchao.__version__}')" + echo "::endgroup::" + exit 0 fi From d08fa9f35bb1572445be3a10c0ca1a4c8a4bdaf1 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 30 Mar 2026 00:34:01 -0700 Subject: [PATCH 5/9] Add MoE kernel benchmark to CI export for block size tuning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Run a standalone benchmark sweep of (N, K, warps, stages) on the CI GPU during the Qwen3.5 MoE export job. This finds the optimal block sizes for the CI hardware (A100-SXM4-80GB) so we can use them as triton autotune candidates. The benchmark runs before export and is non-fatal — export proceeds even if the benchmark fails. 
--- .ci/scripts/export_model_artifact.sh | 5 + .ci/scripts/moe_kernel_benchmark.py | 289 +++++++++++++++++++++++++++ 2 files changed, 294 insertions(+) create mode 100644 .ci/scripts/moe_kernel_benchmark.py diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh index 1dfc02ec737..6a6a7db9ef9 100755 --- a/.ci/scripts/export_model_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -413,6 +413,11 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then # Copy tokenizer for the runner cp "$LOCAL_MODEL_DIR/tokenizer.json" "${OUTPUT_DIR}/tokenizer.json" + # Run MoE kernel benchmark to find optimal block sizes for this GPU + echo "::group::MoE Kernel Benchmark" + python .ci/scripts/moe_kernel_benchmark.py 2>&1 || echo "Benchmark failed (non-fatal)" + echo "::endgroup::" + # Export to .pte/.ptd (short cache dir avoids objcopy symbol length issues) echo "::group::Export" TORCHINDUCTOR_CACHE_DIR="$INDUCTOR_CACHE" \ diff --git a/.ci/scripts/moe_kernel_benchmark.py b/.ci/scripts/moe_kernel_benchmark.py new file mode 100644 index 00000000000..52148d6fe10 --- /dev/null +++ b/.ci/scripts/moe_kernel_benchmark.py @@ -0,0 +1,289 @@ +"""Standalone MoE kernel benchmark for tuning block sizes. + +Sweeps (BLOCK_SIZE_N, BLOCK_SIZE_K, num_warps, num_stages) on the actual +Qwen3.5 MoE dimensions with INT4 quantized weights. 
+ +GEMM1: M=1, N=1024 (2*intermediate), K=2048 (hidden), 8 experts +GEMM2: M=1, N=2048 (hidden), K=512 (intermediate), 8 experts +""" + +import torch +import triton +import triton.language as tl +import itertools + +# Qwen3.5 MoE dimensions +HIDDEN = 2048 +INTERMEDIATE = 512 +NUM_EXPERTS = 256 +TOP_K = 8 +GROUP_SIZE = 128 # HQQ group size + +# GEMM1: N=2*INTERMEDIATE=1024, K=HIDDEN=2048 +# GEMM2: N=HIDDEN=2048, K=INTERMEDIATE=512 + + +@triton.jit +def _fused_moe_kernel( + A, B, C, B_scale, topk_ids, topk_weights, + N: tl.constexpr, K: tl.constexpr, + num_token_expert_pairs, + stride_am, stride_ak, stride_be, stride_bk, stride_bn, + stride_cm, stride_cn, stride_bse, stride_bsk, stride_bsn, + group_size: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr, + compute_type: tl.constexpr, +): + pid = tl.program_id(0) + num_n_blocks = tl.cdiv(N, BLOCK_SIZE_N) + pair_idx = pid // num_n_blocks + n_block = pid % num_n_blocks + if pair_idx >= num_token_expert_pairs: + return + expert_id = tl.load(topk_ids + pair_idx).to(tl.int64) + token_idx = pair_idx // top_k + offs_n = n_block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + n_mask = offs_n < N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A + token_idx * stride_am + offs_k * stride_ak + b_ptrs = B + expert_id * stride_be + (offs_k[:, None] // 2) * stride_bk + offs_n[None, :] * stride_bn + b_shifter = (offs_k[:, None] % 2) * 4 + acc = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32) + for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + k_remaining = K - k_step * BLOCK_SIZE_K + k_mask = offs_k < k_remaining + a = tl.load(a_ptrs, mask=k_mask, other=0.0) + b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0) + b = (b >> b_shifter) & 0xF + scale_ptrs = B_scale + expert_id * stride_bse + offs_n[None, :] * stride_bsn + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk + b_scale = tl.load(scale_ptrs, 
mask=k_mask[:, None] & n_mask[None, :], other=0.0).to(tl.float32) + b_dequant = ((b.to(tl.float32) - 8.0) * b_scale).to(compute_type) + acc += tl.sum(a[:, None].to(compute_type) * b_dequant, axis=0) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + if MUL_ROUTED_WEIGHT: + weight = tl.load(topk_weights + pair_idx) + acc = acc * weight + c_ptrs = C + pair_idx * stride_cm + offs_n * stride_cn + tl.store(c_ptrs, acc.to(compute_type), mask=n_mask) + + +@triton.jit +def _fused_moe_silu_kernel( + A, B, C, B_scale, topk_ids, topk_weights, + N: tl.constexpr, K: tl.constexpr, + num_token_expert_pairs, + stride_am, stride_ak, stride_be, stride_bk, stride_bn, + stride_cm, stride_cn, stride_bse, stride_bsk, stride_bsn, + group_size: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + compute_type: tl.constexpr, +): + pid = tl.program_id(0) + num_n_blocks = tl.cdiv(N, BLOCK_SIZE_N) + pair_idx = pid // num_n_blocks + n_block = pid % num_n_blocks + if pair_idx >= num_token_expert_pairs: + return + expert_id = tl.load(topk_ids + pair_idx).to(tl.int64) + offs_n = n_block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + n_mask = offs_n < N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_gate_ptrs = A + pair_idx * stride_am + offs_k * stride_ak + a_up_ptrs = a_gate_ptrs + K * stride_ak + b_ptrs = B + expert_id * stride_be + (offs_k[:, None] // 2) * stride_bk + offs_n[None, :] * stride_bn + b_shifter = (offs_k[:, None] % 2) * 4 + acc = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32) + for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + k_remaining = K - k_step * BLOCK_SIZE_K + k_mask = offs_k < k_remaining + gate = tl.load(a_gate_ptrs, mask=k_mask, other=0.0).to(tl.float32) + up = tl.load(a_up_ptrs, mask=k_mask, other=0.0) + a = (gate * tl.sigmoid(gate) * up).to(compute_type) + b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0) + b = (b >> b_shifter) & 0xF + scale_ptrs = B_scale + expert_id * stride_bse + 
offs_n[None, :] * stride_bsn + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk + b_scale = tl.load(scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0).to(tl.float32) + b_dequant = ((b.to(tl.float32) - 8.0) * b_scale).to(compute_type) + acc += tl.sum(a[:, None].to(compute_type) * b_dequant, axis=0) + a_gate_ptrs += BLOCK_SIZE_K * stride_ak + a_up_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + weight = tl.load(topk_weights + pair_idx) + acc = acc * weight + c_ptrs = C + pair_idx * stride_cm + offs_n * stride_cn + tl.store(c_ptrs, acc.to(compute_type), mask=n_mask) + + +def bench_gemm1(N, K, num_pairs, top_k, group_size, block_n, block_k, warps, stages): + A = torch.randn(1, K, dtype=torch.bfloat16, device='cuda') + B = torch.randint(-128, 127, (NUM_EXPERTS, N, K // 2), dtype=torch.int8, device='cuda') + C = torch.empty(num_pairs, N, dtype=torch.bfloat16, device='cuda') + B_scale = torch.randn(NUM_EXPERTS, N, K // group_size, dtype=torch.bfloat16, device='cuda') + topk_ids = torch.randint(0, NUM_EXPERTS, (num_pairs,), dtype=torch.int64, device='cuda') + topk_weights = torch.randn(num_pairs, dtype=torch.float32, device='cuda') + + grid = (num_pairs * triton.cdiv(N, block_n),) + + def run(): + _fused_moe_kernel[grid]( + A, B, C, B_scale, topk_ids, topk_weights, + N=N, K=K, num_token_expert_pairs=num_pairs, + stride_am=A.stride(0), stride_ak=A.stride(1), + stride_be=B.stride(0), stride_bk=B.stride(2), stride_bn=B.stride(1), + stride_cm=C.stride(0), stride_cn=C.stride(1), + stride_bse=B_scale.stride(0), stride_bsk=B_scale.stride(2), stride_bsn=B_scale.stride(1), + group_size=group_size, BLOCK_SIZE_N=block_n, BLOCK_SIZE_K=block_k, + MUL_ROUTED_WEIGHT=False, top_k=top_k, compute_type=tl.bfloat16, + num_warps=warps, num_stages=stages, + ) + + # Warmup + for _ in range(10): + run() + torch.cuda.synchronize() + + # Benchmark + import time + iters = 200 + torch.cuda.synchronize() + t0 = time.perf_counter() + for 
_ in range(iters): + run() + torch.cuda.synchronize() + t1 = time.perf_counter() + return (t1 - t0) / iters * 1e6 # us + + +def bench_gemm2(N, K, num_pairs, top_k, group_size, block_n, block_k, warps, stages): + A = torch.randn(num_pairs, 2 * K, dtype=torch.bfloat16, device='cuda') + B = torch.randint(-128, 127, (NUM_EXPERTS, N, K // 2), dtype=torch.int8, device='cuda') + C = torch.empty(num_pairs, N, dtype=torch.bfloat16, device='cuda') + B_scale = torch.randn(NUM_EXPERTS, N, K // group_size, dtype=torch.bfloat16, device='cuda') + topk_ids = torch.randint(0, NUM_EXPERTS, (num_pairs,), dtype=torch.int64, device='cuda') + topk_weights = torch.randn(num_pairs, dtype=torch.float32, device='cuda') + + grid = (num_pairs * triton.cdiv(N, block_n),) + + def run(): + _fused_moe_silu_kernel[grid]( + A, B, C, B_scale, topk_ids, topk_weights, + N=N, K=K, num_token_expert_pairs=num_pairs, + stride_am=A.stride(0), stride_ak=A.stride(1), + stride_be=B.stride(0), stride_bk=B.stride(2), stride_bn=B.stride(1), + stride_cm=C.stride(0), stride_cn=C.stride(1), + stride_bse=B_scale.stride(0), stride_bsk=B_scale.stride(2), stride_bsn=B_scale.stride(1), + group_size=group_size, BLOCK_SIZE_N=block_n, BLOCK_SIZE_K=block_k, + compute_type=tl.bfloat16, + num_warps=warps, num_stages=stages, + ) + + for _ in range(10): + run() + torch.cuda.synchronize() + + import time + iters = 200 + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(iters): + run() + torch.cuda.synchronize() + t1 = time.perf_counter() + return (t1 - t0) / iters * 1e6 + + +def main(): + N1 = 2 * INTERMEDIATE # 1024 + K1 = HIDDEN # 2048 + N2 = HIDDEN # 2048 + K2 = INTERMEDIATE # 512 + num_pairs = TOP_K # 8 + + # Search space + block_ns = [32, 64, 128, 256] + block_ks = [32, 64, 128, 256] + warp_counts = [2, 4, 8] + stage_counts = [2, 3, 4, 5] + + print(f"GEMM1: M=1, N={N1}, K={K1}, pairs={num_pairs}, group_size={GROUP_SIZE}") + print(f"GEMM2: M=1, N={N2}, K={K2}, pairs={num_pairs}, group_size={GROUP_SIZE}") + 
print() + + # GEMM1 + print("=== GEMM1 (_fused_moe_kernel) ===") + print(f"{'N':>4} {'K':>4} {'warps':>5} {'stages':>6} {'time_us':>8}") + best1 = (float('inf'), None) + results1 = [] + for bn, bk, w, s in itertools.product(block_ns, block_ks, warp_counts, stage_counts): + # Skip invalid: K must be divisible, and block_k must divide group_size evenly or vice versa + if bk > K1 or bn > N1: + continue + try: + t = bench_gemm1(N1, K1, num_pairs, TOP_K, GROUP_SIZE, bn, bk, w, s) + results1.append((t, bn, bk, w, s)) + if t < best1[0]: + best1 = (t, (bn, bk, w, s)) + print(f"{bn:>4} {bk:>4} {w:>5} {s:>6} {t:>8.1f}") + except Exception as e: + print(f"{bn:>4} {bk:>4} {w:>5} {s:>6} FAILED: {e}") + + print(f"\nBest GEMM1: {best1[1]} -> {best1[0]:.1f} us") + + # GEMM2 + print("\n=== GEMM2 (_fused_moe_silu_kernel) ===") + print(f"{'N':>4} {'K':>4} {'warps':>5} {'stages':>6} {'time_us':>8}") + best2 = (float('inf'), None) + results2 = [] + for bn, bk, w, s in itertools.product(block_ns, block_ks, warp_counts, stage_counts): + if bk > K2 or bn > N2: + continue + try: + t = bench_gemm2(N2, K2, num_pairs, TOP_K, GROUP_SIZE, bn, bk, w, s) + results2.append((t, bn, bk, w, s)) + if t < best2[0]: + best2 = (t, (bn, bk, w, s)) + print(f"{bn:>4} {bk:>4} {w:>5} {s:>6} {t:>8.1f}") + except Exception as e: + print(f"{bn:>4} {bk:>4} {w:>5} {s:>6} FAILED: {e}") + + print(f"\nBest GEMM2: {best2[1]} -> {best2[0]:.1f} us") + + # Summary + print("\n=== SUMMARY ===") + # Baseline (N=32, K=32, default warps/stages) + t1_base = bench_gemm1(N1, K1, num_pairs, TOP_K, GROUP_SIZE, 32, 32, 4, 2) + t2_base = bench_gemm2(N2, K2, num_pairs, TOP_K, GROUP_SIZE, 32, 32, 4, 2) + bn1, bk1, w1, s1 = best1[1] + bn2, bk2, w2, s2 = best2[1] + t1_best = best1[0] + t2_best = best2[0] + print(f"Baseline (32,32): GEMM1={t1_base:.1f}us, GEMM2={t2_base:.1f}us, total={t1_base+t2_base:.1f}us") + print(f"Best GEMM1 ({bn1},{bk1},w{w1},s{s1}): {t1_best:.1f}us ({(1-t1_best/t1_base)*100:.1f}% faster)") + print(f"Best GEMM2 
({bn2},{bk2},w{w2},s{s2}): {t2_best:.1f}us ({(1-t2_best/t2_base)*100:.1f}% faster)") + + # If GEMM1 and GEMM2 best configs differ, also show a unified config + if (bn1, bk1, w1, s1) != (bn2, bk2, w2, s2): + # Try best GEMM1 config on GEMM2 and vice versa + t2_with_g1 = bench_gemm2(N2, K2, num_pairs, TOP_K, GROUP_SIZE, bn1, bk1, w1, s1) + t1_with_g2 = bench_gemm1(N1, K1, num_pairs, TOP_K, GROUP_SIZE, bn2, bk2, w2, s2) + unified_a = t1_best + t2_with_g1 + unified_b = t1_with_g2 + t2_best + print(f"\nUnified option A (GEMM1-best {bn1},{bk1},w{w1},s{s1}): GEMM1={t1_best:.1f}+GEMM2={t2_with_g1:.1f}={unified_a:.1f}us") + print(f"Unified option B (GEMM2-best {bn2},{bk2},w{w2},s{s2}): GEMM1={t1_with_g2:.1f}+GEMM2={t2_best:.1f}={unified_b:.1f}us") + print(f"Separate configs: GEMM1={t1_best:.1f}+GEMM2={t2_best:.1f}={t1_best+t2_best:.1f}us") + + # Top 5 for each + results1.sort() + results2.sort() + print("\nTop 5 GEMM1:") + for t, bn, bk, w, s in results1[:5]: + print(f" N={bn}, K={bk}, warps={w}, stages={s}: {t:.1f} us") + print("\nTop 5 GEMM2:") + for t, bn, bk, w, s in results2[:5]: + print(f" N={bn}, K={bk}, warps={w}, stages={s}: {t:.1f} us") + + +if __name__ == '__main__': + main() From b6e396785d7ae03cced2d03ac27b1656aa5165bc Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 30 Mar 2026 00:50:01 -0700 Subject: [PATCH 6/9] Add MoE kernel benchmark to CI export for block size tuning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Run a standalone benchmark sweep of (N, K, warps, stages) on the CI GPU during the Qwen3.5 MoE export job. Uses the actual Triton kernels from executorch.backends.cuda to ensure consistency. Finds optimal block sizes for the CI hardware so we can use them as autotune candidates. Non-fatal — export proceeds even if benchmark fails. 
--- .ci/scripts/moe_kernel_benchmark.py | 299 +++++++++++++--------------- 1 file changed, 139 insertions(+), 160 deletions(-) diff --git a/.ci/scripts/moe_kernel_benchmark.py b/.ci/scripts/moe_kernel_benchmark.py index 52148d6fe10..f29dcd9805b 100644 --- a/.ci/scripts/moe_kernel_benchmark.py +++ b/.ci/scripts/moe_kernel_benchmark.py @@ -1,16 +1,23 @@ """Standalone MoE kernel benchmark for tuning block sizes. -Sweeps (BLOCK_SIZE_N, BLOCK_SIZE_K, num_warps, num_stages) on the actual -Qwen3.5 MoE dimensions with INT4 quantized weights. +Imports the actual Triton kernels from executorch.backends.cuda and sweeps +(BLOCK_SIZE_N, BLOCK_SIZE_K, num_warps, num_stages) on real Qwen3.5 MoE +dimensions with INT4 quantized weights. GEMM1: M=1, N=1024 (2*intermediate), K=2048 (hidden), 8 experts GEMM2: M=1, N=2048 (hidden), K=512 (intermediate), 8 experts """ +import itertools +import time + import torch import triton import triton.language as tl -import itertools +from executorch.backends.cuda.triton.kernels.fused_moe import ( + _fused_moe_kernel, + _fused_moe_silu_kernel, +) # Qwen3.5 MoE dimensions HIDDEN = 2048 @@ -19,124 +26,52 @@ TOP_K = 8 GROUP_SIZE = 128 # HQQ group size -# GEMM1: N=2*INTERMEDIATE=1024, K=HIDDEN=2048 -# GEMM2: N=HIDDEN=2048, K=INTERMEDIATE=512 - - -@triton.jit -def _fused_moe_kernel( - A, B, C, B_scale, topk_ids, topk_weights, - N: tl.constexpr, K: tl.constexpr, - num_token_expert_pairs, - stride_am, stride_ak, stride_be, stride_bk, stride_bn, - stride_cm, stride_cn, stride_bse, stride_bsk, stride_bsn, - group_size: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr, - compute_type: tl.constexpr, -): - pid = tl.program_id(0) - num_n_blocks = tl.cdiv(N, BLOCK_SIZE_N) - pair_idx = pid // num_n_blocks - n_block = pid % num_n_blocks - if pair_idx >= num_token_expert_pairs: - return - expert_id = tl.load(topk_ids + pair_idx).to(tl.int64) - token_idx = pair_idx // top_k - offs_n = 
n_block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) - n_mask = offs_n < N - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A + token_idx * stride_am + offs_k * stride_ak - b_ptrs = B + expert_id * stride_be + (offs_k[:, None] // 2) * stride_bk + offs_n[None, :] * stride_bn - b_shifter = (offs_k[:, None] % 2) * 4 - acc = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32) - for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - k_remaining = K - k_step * BLOCK_SIZE_K - k_mask = offs_k < k_remaining - a = tl.load(a_ptrs, mask=k_mask, other=0.0) - b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0) - b = (b >> b_shifter) & 0xF - scale_ptrs = B_scale + expert_id * stride_bse + offs_n[None, :] * stride_bsn + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk - b_scale = tl.load(scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0).to(tl.float32) - b_dequant = ((b.to(tl.float32) - 8.0) * b_scale).to(compute_type) - acc += tl.sum(a[:, None].to(compute_type) * b_dequant, axis=0) - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk - if MUL_ROUTED_WEIGHT: - weight = tl.load(topk_weights + pair_idx) - acc = acc * weight - c_ptrs = C + pair_idx * stride_cm + offs_n * stride_cn - tl.store(c_ptrs, acc.to(compute_type), mask=n_mask) - - -@triton.jit -def _fused_moe_silu_kernel( - A, B, C, B_scale, topk_ids, topk_weights, - N: tl.constexpr, K: tl.constexpr, - num_token_expert_pairs, - stride_am, stride_ak, stride_be, stride_bk, stride_bn, - stride_cm, stride_cn, stride_bse, stride_bsk, stride_bsn, - group_size: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - compute_type: tl.constexpr, -): - pid = tl.program_id(0) - num_n_blocks = tl.cdiv(N, BLOCK_SIZE_N) - pair_idx = pid // num_n_blocks - n_block = pid % num_n_blocks - if pair_idx >= num_token_expert_pairs: - return - expert_id = tl.load(topk_ids + pair_idx).to(tl.int64) - offs_n = n_block * BLOCK_SIZE_N + tl.arange(0, 
BLOCK_SIZE_N).to(tl.int64) - n_mask = offs_n < N - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_gate_ptrs = A + pair_idx * stride_am + offs_k * stride_ak - a_up_ptrs = a_gate_ptrs + K * stride_ak - b_ptrs = B + expert_id * stride_be + (offs_k[:, None] // 2) * stride_bk + offs_n[None, :] * stride_bn - b_shifter = (offs_k[:, None] % 2) * 4 - acc = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32) - for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - k_remaining = K - k_step * BLOCK_SIZE_K - k_mask = offs_k < k_remaining - gate = tl.load(a_gate_ptrs, mask=k_mask, other=0.0).to(tl.float32) - up = tl.load(a_up_ptrs, mask=k_mask, other=0.0) - a = (gate * tl.sigmoid(gate) * up).to(compute_type) - b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0) - b = (b >> b_shifter) & 0xF - scale_ptrs = B_scale + expert_id * stride_bse + offs_n[None, :] * stride_bsn + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk - b_scale = tl.load(scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0).to(tl.float32) - b_dequant = ((b.to(tl.float32) - 8.0) * b_scale).to(compute_type) - acc += tl.sum(a[:, None].to(compute_type) * b_dequant, axis=0) - a_gate_ptrs += BLOCK_SIZE_K * stride_ak - a_up_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk - weight = tl.load(topk_weights + pair_idx) - acc = acc * weight - c_ptrs = C + pair_idx * stride_cm + offs_n * stride_cn - tl.store(c_ptrs, acc.to(compute_type), mask=n_mask) - def bench_gemm1(N, K, num_pairs, top_k, group_size, block_n, block_k, warps, stages): - A = torch.randn(1, K, dtype=torch.bfloat16, device='cuda') - B = torch.randint(-128, 127, (NUM_EXPERTS, N, K // 2), dtype=torch.int8, device='cuda') - C = torch.empty(num_pairs, N, dtype=torch.bfloat16, device='cuda') - B_scale = torch.randn(NUM_EXPERTS, N, K // group_size, dtype=torch.bfloat16, device='cuda') - topk_ids = torch.randint(0, NUM_EXPERTS, (num_pairs,), dtype=torch.int64, device='cuda') - topk_weights = 
torch.randn(num_pairs, dtype=torch.float32, device='cuda') + A = torch.randn(1, K, dtype=torch.bfloat16, device="cuda") + B = torch.randint( + -128, 127, (NUM_EXPERTS, N, K // 2), dtype=torch.int8, device="cuda" + ) + C = torch.empty(num_pairs, N, dtype=torch.bfloat16, device="cuda") + B_scale = torch.randn( + NUM_EXPERTS, N, K // group_size, dtype=torch.bfloat16, device="cuda" + ) + topk_ids = torch.randint( + 0, NUM_EXPERTS, (num_pairs,), dtype=torch.int64, device="cuda" + ) + topk_weights = torch.randn(num_pairs, dtype=torch.float32, device="cuda") grid = (num_pairs * triton.cdiv(N, block_n),) def run(): _fused_moe_kernel[grid]( - A, B, C, B_scale, topk_ids, topk_weights, - N=N, K=K, num_token_expert_pairs=num_pairs, - stride_am=A.stride(0), stride_ak=A.stride(1), - stride_be=B.stride(0), stride_bk=B.stride(2), stride_bn=B.stride(1), - stride_cm=C.stride(0), stride_cn=C.stride(1), - stride_bse=B_scale.stride(0), stride_bsk=B_scale.stride(2), stride_bsn=B_scale.stride(1), - group_size=group_size, BLOCK_SIZE_N=block_n, BLOCK_SIZE_K=block_k, - MUL_ROUTED_WEIGHT=False, top_k=top_k, compute_type=tl.bfloat16, - num_warps=warps, num_stages=stages, + A, + B, + C, + B_scale, + topk_ids, + topk_weights, + N=N, + K=K, + num_token_expert_pairs=num_pairs, + stride_am=A.stride(0), + stride_ak=A.stride(1), + stride_be=B.stride(0), + stride_bk=B.stride(2), + stride_bn=B.stride(1), + stride_cm=C.stride(0), + stride_cn=C.stride(1), + stride_bse=B_scale.stride(0), + stride_bsk=B_scale.stride(2), + stride_bsn=B_scale.stride(1), + group_size=group_size, + BLOCK_SIZE_N=block_n, + BLOCK_SIZE_K=block_k, + MUL_ROUTED_WEIGHT=False, + top_k=top_k, + compute_type=tl.bfloat16, + num_warps=warps, + num_stages=stages, ) # Warmup @@ -145,7 +80,6 @@ def run(): torch.cuda.synchronize() # Benchmark - import time iters = 200 torch.cuda.synchronize() t0 = time.perf_counter() @@ -157,33 +91,54 @@ def run(): def bench_gemm2(N, K, num_pairs, top_k, group_size, block_n, block_k, warps, stages): - A = 
torch.randn(num_pairs, 2 * K, dtype=torch.bfloat16, device='cuda') - B = torch.randint(-128, 127, (NUM_EXPERTS, N, K // 2), dtype=torch.int8, device='cuda') - C = torch.empty(num_pairs, N, dtype=torch.bfloat16, device='cuda') - B_scale = torch.randn(NUM_EXPERTS, N, K // group_size, dtype=torch.bfloat16, device='cuda') - topk_ids = torch.randint(0, NUM_EXPERTS, (num_pairs,), dtype=torch.int64, device='cuda') - topk_weights = torch.randn(num_pairs, dtype=torch.float32, device='cuda') + A = torch.randn(num_pairs, 2 * K, dtype=torch.bfloat16, device="cuda") + B = torch.randint( + -128, 127, (NUM_EXPERTS, N, K // 2), dtype=torch.int8, device="cuda" + ) + C = torch.empty(num_pairs, N, dtype=torch.bfloat16, device="cuda") + B_scale = torch.randn( + NUM_EXPERTS, N, K // group_size, dtype=torch.bfloat16, device="cuda" + ) + topk_ids = torch.randint( + 0, NUM_EXPERTS, (num_pairs,), dtype=torch.int64, device="cuda" + ) + topk_weights = torch.randn(num_pairs, dtype=torch.float32, device="cuda") grid = (num_pairs * triton.cdiv(N, block_n),) def run(): _fused_moe_silu_kernel[grid]( - A, B, C, B_scale, topk_ids, topk_weights, - N=N, K=K, num_token_expert_pairs=num_pairs, - stride_am=A.stride(0), stride_ak=A.stride(1), - stride_be=B.stride(0), stride_bk=B.stride(2), stride_bn=B.stride(1), - stride_cm=C.stride(0), stride_cn=C.stride(1), - stride_bse=B_scale.stride(0), stride_bsk=B_scale.stride(2), stride_bsn=B_scale.stride(1), - group_size=group_size, BLOCK_SIZE_N=block_n, BLOCK_SIZE_K=block_k, + A, + B, + C, + B_scale, + topk_ids, + topk_weights, + N=N, + K=K, + num_token_expert_pairs=num_pairs, + stride_am=A.stride(0), + stride_ak=A.stride(1), + stride_be=B.stride(0), + stride_bk=B.stride(2), + stride_bn=B.stride(1), + stride_cm=C.stride(0), + stride_cn=C.stride(1), + stride_bse=B_scale.stride(0), + stride_bsk=B_scale.stride(2), + stride_bsn=B_scale.stride(1), + group_size=group_size, + BLOCK_SIZE_N=block_n, + BLOCK_SIZE_K=block_k, compute_type=tl.bfloat16, - num_warps=warps, 
num_stages=stages, + num_warps=warps, + num_stages=stages, ) for _ in range(10): run() torch.cuda.synchronize() - import time iters = 200 torch.cuda.synchronize() t0 = time.perf_counter() @@ -196,14 +151,14 @@ def run(): def main(): N1 = 2 * INTERMEDIATE # 1024 - K1 = HIDDEN # 2048 - N2 = HIDDEN # 2048 - K2 = INTERMEDIATE # 512 - num_pairs = TOP_K # 8 - - # Search space - block_ns = [32, 64, 128, 256] - block_ks = [32, 64, 128, 256] + K1 = HIDDEN # 2048 + N2 = HIDDEN # 2048 + K2 = INTERMEDIATE # 512 + num_pairs = TOP_K # 8 + + # Search space (including small sizes 8, 16 per user request) + block_ns = [8, 16, 32, 64, 128, 256] + block_ks = [8, 16, 32, 64, 128, 256] warp_counts = [2, 4, 8] stage_counts = [2, 3, 4, 5] @@ -214,10 +169,11 @@ def main(): # GEMM1 print("=== GEMM1 (_fused_moe_kernel) ===") print(f"{'N':>4} {'K':>4} {'warps':>5} {'stages':>6} {'time_us':>8}") - best1 = (float('inf'), None) + best1 = (float("inf"), None) results1 = [] - for bn, bk, w, s in itertools.product(block_ns, block_ks, warp_counts, stage_counts): - # Skip invalid: K must be divisible, and block_k must divide group_size evenly or vice versa + for bn, bk, w, s in itertools.product( + block_ns, block_ks, warp_counts, stage_counts + ): if bk > K1 or bn > N1: continue try: @@ -234,9 +190,11 @@ def main(): # GEMM2 print("\n=== GEMM2 (_fused_moe_silu_kernel) ===") print(f"{'N':>4} {'K':>4} {'warps':>5} {'stages':>6} {'time_us':>8}") - best2 = (float('inf'), None) + best2 = (float("inf"), None) results2 = [] - for bn, bk, w, s in itertools.product(block_ns, block_ks, warp_counts, stage_counts): + for bn, bk, w, s in itertools.product( + block_ns, block_ks, warp_counts, stage_counts + ): if bk > K2 or bn > N2: continue try: @@ -250,31 +208,52 @@ def main(): print(f"\nBest GEMM2: {best2[1]} -> {best2[0]:.1f} us") - # Summary + # Summary — extract best configs + t1_best, (bn1, bk1, w1, s1) = best1 + t2_best, (bn2, bk2, w2, s2) = best2 + print("\n=== SUMMARY ===") - # Baseline (N=32, K=32, 
default warps/stages) t1_base = bench_gemm1(N1, K1, num_pairs, TOP_K, GROUP_SIZE, 32, 32, 4, 2) t2_base = bench_gemm2(N2, K2, num_pairs, TOP_K, GROUP_SIZE, 32, 32, 4, 2) - bn1, bk1, w1, s1 = best1[1] - bn2, bk2, w2, s2 = best2[1] - t1_best = best1[0] - t2_best = best2[0] - print(f"Baseline (32,32): GEMM1={t1_base:.1f}us, GEMM2={t2_base:.1f}us, total={t1_base+t2_base:.1f}us") - print(f"Best GEMM1 ({bn1},{bk1},w{w1},s{s1}): {t1_best:.1f}us ({(1-t1_best/t1_base)*100:.1f}% faster)") - print(f"Best GEMM2 ({bn2},{bk2},w{w2},s{s2}): {t2_best:.1f}us ({(1-t2_best/t2_base)*100:.1f}% faster)") - - # If GEMM1 and GEMM2 best configs differ, also show a unified config + print( + f"Baseline (32,32): GEMM1={t1_base:.1f}us, GEMM2={t2_base:.1f}us, " + f"total={t1_base+t2_base:.1f}us" + ) + print( + f"Best GEMM1 ({bn1},{bk1},w{w1},s{s1}): {t1_best:.1f}us " + f"({(1-t1_best/t1_base)*100:.1f}% faster)" + ) + print( + f"Best GEMM2 ({bn2},{bk2},w{w2},s{s2}): {t2_best:.1f}us " + f"({(1-t2_best/t2_base)*100:.1f}% faster)" + ) + if (bn1, bk1, w1, s1) != (bn2, bk2, w2, s2): - # Try best GEMM1 config on GEMM2 and vice versa t2_with_g1 = bench_gemm2(N2, K2, num_pairs, TOP_K, GROUP_SIZE, bn1, bk1, w1, s1) t1_with_g2 = bench_gemm1(N1, K1, num_pairs, TOP_K, GROUP_SIZE, bn2, bk2, w2, s2) unified_a = t1_best + t2_with_g1 unified_b = t1_with_g2 + t2_best - print(f"\nUnified option A (GEMM1-best {bn1},{bk1},w{w1},s{s1}): GEMM1={t1_best:.1f}+GEMM2={t2_with_g1:.1f}={unified_a:.1f}us") - print(f"Unified option B (GEMM2-best {bn2},{bk2},w{w2},s{s2}): GEMM1={t1_with_g2:.1f}+GEMM2={t2_best:.1f}={unified_b:.1f}us") - print(f"Separate configs: GEMM1={t1_best:.1f}+GEMM2={t2_best:.1f}={t1_best+t2_best:.1f}us") + print( + f"\nUnified option A (GEMM1-best {bn1},{bk1},w{w1},s{s1}): " + f"GEMM1={t1_best:.1f}+GEMM2={t2_with_g1:.1f}={unified_a:.1f}us" + ) + print( + f"Unified option B (GEMM2-best {bn2},{bk2},w{w2},s{s2}): " + f"GEMM1={t1_with_g2:.1f}+GEMM2={t2_best:.1f}={unified_b:.1f}us" + ) + print( + f"Separate 
configs: GEMM1={t1_best:.1f}+GEMM2={t2_best:.1f}" + f"={t1_best+t2_best:.1f}us" + ) + + # Overall improvement + total_base = t1_base + t2_base + total_best = t1_best + t2_best + print( + f"\nOverall: baseline total={total_base:.1f}us, best total={total_best:.1f}us, " + f"improvement={((1-total_best/total_base)*100):.1f}%" + ) - # Top 5 for each results1.sort() results2.sort() print("\nTop 5 GEMM1:") @@ -285,5 +264,5 @@ def main(): print(f" N={bn}, K={bk}, warps={w}, stages={s}: {t:.1f} us") -if __name__ == '__main__': +if __name__ == "__main__": main() From a499c22e91bbbe223e505f42eb415b5de5b0ef52 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 30 Mar 2026 10:40:45 -0700 Subject: [PATCH 7/9] Add MoE kernel benchmark to CI export for block size tuning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Run a standalone benchmark sweep of (N, K, warps, stages) on the CI GPU during the Qwen3.5 MoE export job. Uses the actual Triton kernels from executorch.backends.cuda to ensure consistency. Finds optimal block sizes for the CI hardware so we can use them as autotune candidates. Non-fatal — export proceeds even if benchmark fails. --- .ci/scripts/moe_kernel_benchmark.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/.ci/scripts/moe_kernel_benchmark.py b/.ci/scripts/moe_kernel_benchmark.py index f29dcd9805b..272fa19be97 100644 --- a/.ci/scripts/moe_kernel_benchmark.py +++ b/.ci/scripts/moe_kernel_benchmark.py @@ -19,6 +19,11 @@ _fused_moe_silu_kernel, ) +# .fn bypasses @triton.autotune to get the raw JIT kernel, +# allowing us to pass BLOCK_SIZE_N/BLOCK_SIZE_K directly. 
+_gemm1_kernel = _fused_moe_kernel.fn +_gemm2_kernel = _fused_moe_silu_kernel.fn + # Qwen3.5 MoE dimensions HIDDEN = 2048 INTERMEDIATE = 512 @@ -44,7 +49,7 @@ def bench_gemm1(N, K, num_pairs, top_k, group_size, block_n, block_k, warps, sta grid = (num_pairs * triton.cdiv(N, block_n),) def run(): - _fused_moe_kernel[grid]( + _gemm1_kernel[grid]( A, B, C, @@ -107,7 +112,7 @@ def bench_gemm2(N, K, num_pairs, top_k, group_size, block_n, block_k, warps, sta grid = (num_pairs * triton.cdiv(N, block_n),) def run(): - _fused_moe_silu_kernel[grid]( + _gemm2_kernel[grid]( A, B, C, From 74274f5a37623052d6bdc37457f7fff0305166e1 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 30 Mar 2026 14:24:50 -0700 Subject: [PATCH 8/9] Update autotune configs with CI benchmark results, remove profiling scripts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace autotune candidates with top-5 configs from CI A100-SXM4-80GB benchmark sweep (block sizes 8-256). Key findings: - GEMM1 best: N=8, K=256, warps=2 → 32.8us (45.8% faster than baseline) - GEMM2 best: N=8, K=128, warps=2 → 26.1us (10.6% faster than baseline) - Overall: 58.9us vs 89.6us baseline (34.3% improvement) Baseline (32,32) retained in both config lists for safety. Clean up: remove moe_kernel_benchmark.py, GPU diagnostics, and artifact checksums from CI scripts. 
--- .ci/scripts/export_model_artifact.sh | 15 -- .ci/scripts/moe_kernel_benchmark.py | 273 ---------------------- .ci/scripts/test_model_e2e.sh | 15 -- backends/cuda/triton/kernels/fused_moe.py | 22 +- 4 files changed, 12 insertions(+), 313 deletions(-) delete mode 100644 .ci/scripts/moe_kernel_benchmark.py diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh index 6a6a7db9ef9..1797663bd1d 100755 --- a/.ci/scripts/export_model_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -413,11 +413,6 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then # Copy tokenizer for the runner cp "$LOCAL_MODEL_DIR/tokenizer.json" "${OUTPUT_DIR}/tokenizer.json" - # Run MoE kernel benchmark to find optimal block sizes for this GPU - echo "::group::MoE Kernel Benchmark" - python .ci/scripts/moe_kernel_benchmark.py 2>&1 || echo "Benchmark failed (non-fatal)" - echo "::endgroup::" - # Export to .pte/.ptd (short cache dir avoids objcopy symbol length issues) echo "::group::Export" TORCHINDUCTOR_CACHE_DIR="$INDUCTOR_CACHE" \ @@ -430,16 +425,6 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then test -f "${OUTPUT_DIR}/aoti_cuda_blob.ptd" ls -al "${OUTPUT_DIR}" - # Diagnostic: print checksums for cross-machine comparison - echo "::group::Artifact checksums" - md5sum "${OUTPUT_DIR}/model.pte" "${OUTPUT_DIR}/aoti_cuda_blob.ptd" - echo "Local reference checksums:" - echo " model.pte: 3b79cbc9d921b6eaa2d655ede993f6a7" - echo " aoti_cuda_blob.ptd: 2c8d0d31004acbd6dc43118eddabf700" - echo "---" - python -c "import torch; print(f'torch={torch.__version__}'); import triton; print(f'triton={triton.__version__}'); import torchao; print(f'torchao={torchao.__version__}')" - echo "::endgroup::" - exit 0 fi diff --git a/.ci/scripts/moe_kernel_benchmark.py b/.ci/scripts/moe_kernel_benchmark.py deleted file mode 100644 index 272fa19be97..00000000000 --- a/.ci/scripts/moe_kernel_benchmark.py +++ /dev/null @@ -1,273 +0,0 @@ -"""Standalone MoE kernel benchmark for tuning block 
sizes. - -Imports the actual Triton kernels from executorch.backends.cuda and sweeps -(BLOCK_SIZE_N, BLOCK_SIZE_K, num_warps, num_stages) on real Qwen3.5 MoE -dimensions with INT4 quantized weights. - -GEMM1: M=1, N=1024 (2*intermediate), K=2048 (hidden), 8 experts -GEMM2: M=1, N=2048 (hidden), K=512 (intermediate), 8 experts -""" - -import itertools -import time - -import torch -import triton -import triton.language as tl -from executorch.backends.cuda.triton.kernels.fused_moe import ( - _fused_moe_kernel, - _fused_moe_silu_kernel, -) - -# .fn bypasses @triton.autotune to get the raw JIT kernel, -# allowing us to pass BLOCK_SIZE_N/BLOCK_SIZE_K directly. -_gemm1_kernel = _fused_moe_kernel.fn -_gemm2_kernel = _fused_moe_silu_kernel.fn - -# Qwen3.5 MoE dimensions -HIDDEN = 2048 -INTERMEDIATE = 512 -NUM_EXPERTS = 256 -TOP_K = 8 -GROUP_SIZE = 128 # HQQ group size - - -def bench_gemm1(N, K, num_pairs, top_k, group_size, block_n, block_k, warps, stages): - A = torch.randn(1, K, dtype=torch.bfloat16, device="cuda") - B = torch.randint( - -128, 127, (NUM_EXPERTS, N, K // 2), dtype=torch.int8, device="cuda" - ) - C = torch.empty(num_pairs, N, dtype=torch.bfloat16, device="cuda") - B_scale = torch.randn( - NUM_EXPERTS, N, K // group_size, dtype=torch.bfloat16, device="cuda" - ) - topk_ids = torch.randint( - 0, NUM_EXPERTS, (num_pairs,), dtype=torch.int64, device="cuda" - ) - topk_weights = torch.randn(num_pairs, dtype=torch.float32, device="cuda") - - grid = (num_pairs * triton.cdiv(N, block_n),) - - def run(): - _gemm1_kernel[grid]( - A, - B, - C, - B_scale, - topk_ids, - topk_weights, - N=N, - K=K, - num_token_expert_pairs=num_pairs, - stride_am=A.stride(0), - stride_ak=A.stride(1), - stride_be=B.stride(0), - stride_bk=B.stride(2), - stride_bn=B.stride(1), - stride_cm=C.stride(0), - stride_cn=C.stride(1), - stride_bse=B_scale.stride(0), - stride_bsk=B_scale.stride(2), - stride_bsn=B_scale.stride(1), - group_size=group_size, - BLOCK_SIZE_N=block_n, - BLOCK_SIZE_K=block_k, - 
MUL_ROUTED_WEIGHT=False, - top_k=top_k, - compute_type=tl.bfloat16, - num_warps=warps, - num_stages=stages, - ) - - # Warmup - for _ in range(10): - run() - torch.cuda.synchronize() - - # Benchmark - iters = 200 - torch.cuda.synchronize() - t0 = time.perf_counter() - for _ in range(iters): - run() - torch.cuda.synchronize() - t1 = time.perf_counter() - return (t1 - t0) / iters * 1e6 # us - - -def bench_gemm2(N, K, num_pairs, top_k, group_size, block_n, block_k, warps, stages): - A = torch.randn(num_pairs, 2 * K, dtype=torch.bfloat16, device="cuda") - B = torch.randint( - -128, 127, (NUM_EXPERTS, N, K // 2), dtype=torch.int8, device="cuda" - ) - C = torch.empty(num_pairs, N, dtype=torch.bfloat16, device="cuda") - B_scale = torch.randn( - NUM_EXPERTS, N, K // group_size, dtype=torch.bfloat16, device="cuda" - ) - topk_ids = torch.randint( - 0, NUM_EXPERTS, (num_pairs,), dtype=torch.int64, device="cuda" - ) - topk_weights = torch.randn(num_pairs, dtype=torch.float32, device="cuda") - - grid = (num_pairs * triton.cdiv(N, block_n),) - - def run(): - _gemm2_kernel[grid]( - A, - B, - C, - B_scale, - topk_ids, - topk_weights, - N=N, - K=K, - num_token_expert_pairs=num_pairs, - stride_am=A.stride(0), - stride_ak=A.stride(1), - stride_be=B.stride(0), - stride_bk=B.stride(2), - stride_bn=B.stride(1), - stride_cm=C.stride(0), - stride_cn=C.stride(1), - stride_bse=B_scale.stride(0), - stride_bsk=B_scale.stride(2), - stride_bsn=B_scale.stride(1), - group_size=group_size, - BLOCK_SIZE_N=block_n, - BLOCK_SIZE_K=block_k, - compute_type=tl.bfloat16, - num_warps=warps, - num_stages=stages, - ) - - for _ in range(10): - run() - torch.cuda.synchronize() - - iters = 200 - torch.cuda.synchronize() - t0 = time.perf_counter() - for _ in range(iters): - run() - torch.cuda.synchronize() - t1 = time.perf_counter() - return (t1 - t0) / iters * 1e6 - - -def main(): - N1 = 2 * INTERMEDIATE # 1024 - K1 = HIDDEN # 2048 - N2 = HIDDEN # 2048 - K2 = INTERMEDIATE # 512 - num_pairs = TOP_K # 8 - - # 
Search space (including small sizes 8, 16 per user request) - block_ns = [8, 16, 32, 64, 128, 256] - block_ks = [8, 16, 32, 64, 128, 256] - warp_counts = [2, 4, 8] - stage_counts = [2, 3, 4, 5] - - print(f"GEMM1: M=1, N={N1}, K={K1}, pairs={num_pairs}, group_size={GROUP_SIZE}") - print(f"GEMM2: M=1, N={N2}, K={K2}, pairs={num_pairs}, group_size={GROUP_SIZE}") - print() - - # GEMM1 - print("=== GEMM1 (_fused_moe_kernel) ===") - print(f"{'N':>4} {'K':>4} {'warps':>5} {'stages':>6} {'time_us':>8}") - best1 = (float("inf"), None) - results1 = [] - for bn, bk, w, s in itertools.product( - block_ns, block_ks, warp_counts, stage_counts - ): - if bk > K1 or bn > N1: - continue - try: - t = bench_gemm1(N1, K1, num_pairs, TOP_K, GROUP_SIZE, bn, bk, w, s) - results1.append((t, bn, bk, w, s)) - if t < best1[0]: - best1 = (t, (bn, bk, w, s)) - print(f"{bn:>4} {bk:>4} {w:>5} {s:>6} {t:>8.1f}") - except Exception as e: - print(f"{bn:>4} {bk:>4} {w:>5} {s:>6} FAILED: {e}") - - print(f"\nBest GEMM1: {best1[1]} -> {best1[0]:.1f} us") - - # GEMM2 - print("\n=== GEMM2 (_fused_moe_silu_kernel) ===") - print(f"{'N':>4} {'K':>4} {'warps':>5} {'stages':>6} {'time_us':>8}") - best2 = (float("inf"), None) - results2 = [] - for bn, bk, w, s in itertools.product( - block_ns, block_ks, warp_counts, stage_counts - ): - if bk > K2 or bn > N2: - continue - try: - t = bench_gemm2(N2, K2, num_pairs, TOP_K, GROUP_SIZE, bn, bk, w, s) - results2.append((t, bn, bk, w, s)) - if t < best2[0]: - best2 = (t, (bn, bk, w, s)) - print(f"{bn:>4} {bk:>4} {w:>5} {s:>6} {t:>8.1f}") - except Exception as e: - print(f"{bn:>4} {bk:>4} {w:>5} {s:>6} FAILED: {e}") - - print(f"\nBest GEMM2: {best2[1]} -> {best2[0]:.1f} us") - - # Summary — extract best configs - t1_best, (bn1, bk1, w1, s1) = best1 - t2_best, (bn2, bk2, w2, s2) = best2 - - print("\n=== SUMMARY ===") - t1_base = bench_gemm1(N1, K1, num_pairs, TOP_K, GROUP_SIZE, 32, 32, 4, 2) - t2_base = bench_gemm2(N2, K2, num_pairs, TOP_K, GROUP_SIZE, 32, 32, 4, 2) - 
print( - f"Baseline (32,32): GEMM1={t1_base:.1f}us, GEMM2={t2_base:.1f}us, " - f"total={t1_base+t2_base:.1f}us" - ) - print( - f"Best GEMM1 ({bn1},{bk1},w{w1},s{s1}): {t1_best:.1f}us " - f"({(1-t1_best/t1_base)*100:.1f}% faster)" - ) - print( - f"Best GEMM2 ({bn2},{bk2},w{w2},s{s2}): {t2_best:.1f}us " - f"({(1-t2_best/t2_base)*100:.1f}% faster)" - ) - - if (bn1, bk1, w1, s1) != (bn2, bk2, w2, s2): - t2_with_g1 = bench_gemm2(N2, K2, num_pairs, TOP_K, GROUP_SIZE, bn1, bk1, w1, s1) - t1_with_g2 = bench_gemm1(N1, K1, num_pairs, TOP_K, GROUP_SIZE, bn2, bk2, w2, s2) - unified_a = t1_best + t2_with_g1 - unified_b = t1_with_g2 + t2_best - print( - f"\nUnified option A (GEMM1-best {bn1},{bk1},w{w1},s{s1}): " - f"GEMM1={t1_best:.1f}+GEMM2={t2_with_g1:.1f}={unified_a:.1f}us" - ) - print( - f"Unified option B (GEMM2-best {bn2},{bk2},w{w2},s{s2}): " - f"GEMM1={t1_with_g2:.1f}+GEMM2={t2_best:.1f}={unified_b:.1f}us" - ) - print( - f"Separate configs: GEMM1={t1_best:.1f}+GEMM2={t2_best:.1f}" - f"={t1_best+t2_best:.1f}us" - ) - - # Overall improvement - total_base = t1_base + t2_base - total_best = t1_best + t2_best - print( - f"\nOverall: baseline total={total_base:.1f}us, best total={total_best:.1f}us, " - f"improvement={((1-total_best/total_base)*100):.1f}%" - ) - - results1.sort() - results2.sort() - print("\nTop 5 GEMM1:") - for t, bn, bk, w, s in results1[:5]: - print(f" N={bn}, K={bk}, warps={w}, stages={s}: {t:.1f} us") - print("\nTop 5 GEMM2:") - for t, bn, bk, w, s in results2[:5]: - print(f" N={bn}, K={bk}, warps={w}, stages={s}: {t:.1f} us") - - -if __name__ == "__main__": - main() diff --git a/.ci/scripts/test_model_e2e.sh b/.ci/scripts/test_model_e2e.sh index 008cb9caac2..85bd327cfc6 100755 --- a/.ci/scripts/test_model_e2e.sh +++ b/.ci/scripts/test_model_e2e.sh @@ -101,21 +101,6 @@ fi echo "Testing model: $HF_MODEL (quantization: $QUANT_NAME)" -# GPU diagnostics — helps compare CI vs local performance -echo "::group::GPU Diagnostics" -if command -v nvidia-smi &> 
/dev/null; then - nvidia-smi --query-gpu=name,memory.total,pcie.link.gen.max,pcie.link.width.max,clocks.max.sm,clocks.max.mem --format=csv - echo "---" - nvidia-smi -q | grep -E "Product Name|Product Brand|GPU UUID|GPU Part Number|FB Memory Usage|BAR1 Memory Usage|GPU Current Temp|GPU Max Operating Temp|Power Draw|Power Limit|Max Clocks|Clocks$" | head -20 - echo "---" - echo "CUDA version (nvcc):" - nvcc --version 2>/dev/null || echo "nvcc not found" - echo "---" - echo "PyTorch CUDA info:" - python -c "import torch; print(f'torch.version.cuda={torch.version.cuda}'); print(f'torch.cuda.get_device_name()={torch.cuda.get_device_name()}'); print(f'torch.cuda.get_device_properties(0)={torch.cuda.get_device_properties(0)}')" 2>/dev/null || echo "PyTorch not available yet" -fi -echo "::endgroup::" - # Make sure model.pte exists if [ ! -f "$MODEL_DIR/model.pte" ]; then echo "Error: model.pte not found in $MODEL_DIR" diff --git a/backends/cuda/triton/kernels/fused_moe.py b/backends/cuda/triton/kernels/fused_moe.py index dea52f843e2..3e572eca7b6 100644 --- a/backends/cuda/triton/kernels/fused_moe.py +++ b/backends/cuda/triton/kernels/fused_moe.py @@ -39,25 +39,27 @@ # Autotune configs for GEMM1 (_fused_moe_kernel). -# Top performers from standalone benchmark on A100, Qwen3.5 MoE dimensions +# Top performers from CI benchmark on A100-SXM4-80GB, Qwen3.5 MoE dimensions # (M=1, N=1024, K=2048, 8 experts, group_size=128). 
_GEMM1_CONFIGS = [ triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3), - triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=3), - triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=3), - triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=4), + triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=5), + triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=3), + triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=5), ] # Autotune configs for GEMM2 (_fused_moe_silu_kernel). -# Top performers from standalone benchmark on A100, Qwen3.5 MoE dimensions +# Top performers from CI benchmark on A100-SXM4-80GB, Qwen3.5 MoE dimensions # (M=1, N=2048, K=512, 8 experts, group_size=128). 
_GEMM2_CONFIGS = [ triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128}, num_warps=2, num_stages=2), - triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, num_warps=2, num_stages=3), - triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=3), - triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256}, num_warps=4, num_stages=5), + triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 128}, num_warps=2, num_stages=4), + triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=4), + triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=3), + triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=4, num_stages=3), ] From d3b31907c33be13d600bacc68cf4c557a0027b76 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 30 Mar 2026 23:26:39 -0700 Subject: [PATCH 9/9] solve lint issue --- backends/cuda/triton/kernels/fused_moe.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/backends/cuda/triton/kernels/fused_moe.py b/backends/cuda/triton/kernels/fused_moe.py index 3e572eca7b6..98a86698bc4 100644 --- a/backends/cuda/triton/kernels/fused_moe.py +++ b/backends/cuda/triton/kernels/fused_moe.py @@ -330,7 +330,10 @@ def fused_moe( cache1 = torch.empty( num_pairs, N1, dtype=hidden_states.dtype, device=hidden_states.device ) - grid1 = lambda meta: (num_pairs * triton.cdiv(N1, meta["BLOCK_SIZE_N"]),) + + def grid1(meta): + return (num_pairs * triton.cdiv(N1, meta["BLOCK_SIZE_N"]),) + wrap_triton(_fused_moe_kernel)[grid1]( hidden_states, w1, @@ -361,7 +364,10 @@ def fused_moe( cache3 = torch.empty( num_pairs, N2, dtype=hidden_states.dtype, device=hidden_states.device ) - grid2 = lambda meta: (num_pairs * triton.cdiv(N2, meta["BLOCK_SIZE_N"]),) + + def grid2(meta): + return (num_pairs * 
triton.cdiv(N2, meta["BLOCK_SIZE_N"]),) + wrap_triton(_fused_moe_silu_kernel)[grid2]( cache1, w2,