Hoist W4A8 activation quantization out of GEMM K-loop #19209
Conversation
One ask: the existing INT8_TEST_CONFIGS only test toy dimensions (max hidden=128). Please add a config or two at realistic scale so we catch any precision or alignment issues with K > PREQUANT_BLOCK_K, e.g. `(77, 512, 2048, 1024, 8, 2, 128, "512tok_real_dims")` in backends/cuda/tests/test_fused_moe.py.
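A minimal sketch of that ask (assuming the tuple field order matches the existing INT8_TEST_CONFIGS entries, which aren't shown here):

```python
# Sketch: appending the suggested realistic-scale case to INT8_TEST_CONFIGS
# in backends/cuda/tests/test_fused_moe.py. The field order is assumed to
# match the file's existing entries.
INT8_TEST_CONFIGS.append(
    (77, 512, 2048, 1024, 8, 2, 128, "512tok_real_dims"),
)
```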
offs_k = tl.arange(0, BLOCK_SIZE_K)
for k_tile in range(NUM_K_TILES):
Here we are quantizing 128 values of A at a time; that has accuracy implications and is not apples-to-apples with llama.cpp. We should use 32.
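To see why the group size matters, here is a small PyTorch sketch (not from the PR) that measures per-group absmax INT8 round-trip error at group sizes 32 and 128; smaller groups adapt the scale to local magnitudes and typically reduce error:

```python
import torch

def quantize_dequantize_int8(x: torch.Tensor, group_size: int) -> torch.Tensor:
    """Per-group symmetric absmax INT8 quantize + dequantize along the last dim."""
    groups = x.reshape(-1, group_size)
    scale = groups.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(groups / scale), -127, 127)
    return (q * scale).reshape(x.shape)

x = torch.randn(16, 2048) * torch.logspace(-2, 1, 2048)  # mixed-magnitude rows
for gs in (32, 128):
    err = (quantize_dequantize_int8(x, gs) - x).norm() / x.norm()
    print(f"group_size={gs}: rel err {err:.4f}")  # smaller group -> smaller error
```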
Add dedicated _quantize_activations_int8_kernel and _silu_quantize_int8_kernel that pre-quantize activations to INT8 with per-row-per-tile FP32 scales before GEMM1 and GEMM2 respectively. The existing _fused_moe_batched_int8_kernel and _fused_moe_silu_batched_int8_kernel are rewritten to consume pre-quantized activations + scales, eliminating ~256 redundant tl.max reductions per program (cdiv(K, BLOCK_K) tiles * BLOCK_M rows) and halving activation HBM bandwidth in the K-loop (bf16 -> int8). BLOCK_SIZE_K is fixed at PREQUANT_BLOCK_K (= 128) so per-tile activation scales align with the GEMM K-loop.

Correctness: 7/7 microbenchmark configs pass with rel diff <1.5% vs BF16 ref.

End-to-end (Qwen3.5 MoE, 1600 prefill + 512 decode, --cuda_graph, A100): prefill 5727 -> 6171 tok/s (+7.7%), decode 92.6 -> 99.0 tok/s (+6.9%).
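For reference, a minimal PyTorch sketch of the semantics the pre-quantization step computes (illustrative only, not the Triton kernel; the function name is made up):

```python
import torch

PREQUANT_BLOCK_K = 128  # per-tile scale granularity, matching BLOCK_SIZE_K

def prequantize_activations_int8(a_bf16: torch.Tensor):
    """Quantize [M, K] bf16 activations to INT8 once, producing per-row-per-tile
    FP32 scales of shape [M, K // PREQUANT_BLOCK_K], before the GEMM.
    Assumes K is a multiple of PREQUANT_BLOCK_K."""
    M, K = a_bf16.shape
    tiles = a_bf16.float().reshape(M, K // PREQUANT_BLOCK_K, PREQUANT_BLOCK_K)
    scales = tiles.abs().amax(dim=-1).clamp(min=1e-8) / 127.0  # [M, K//128] fp32
    q = torch.clamp(torch.round(tiles / scales.unsqueeze(-1)), -127, 127)
    return q.reshape(M, K).to(torch.int8), scales
```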
Microbenchmarked on Qwen3.5 MoE prefill (M=1696, top_k=8, 256 experts):
- BLOCK_M=16: 3.62 ms
- BLOCK_M=32: 2.85 ms (1.27x)
- BLOCK_M=64: 2.75 ms (1.32x)

E2E (Qwen3.5-35B-A3B prefill, --moe-activation-dtype int8 --dense-prefill dequant --cuda_graph, p=1600 d=512, run_1..5 median):
- BLOCK_M=16: 5897 tok/s prefill (273 ms), 98.1 tok/s decode
- BLOCK_M=64: 6793 tok/s prefill (237 ms), 98.1 tok/s decode

Speedup: 1.152x prefill, decode unchanged (decode uses the non-batched fused_moe kernel). Outputs are bit-identical between BLOCK_M=16 and BLOCK_M=64 in the microbenchmark (max abs diff = 0).
…ofiling
- int4_matmul.py: 5 new _MATVEC_CONFIGS targeting decode shapes (N=1 shared_expert_gate, N=2048 o_proj, N=248320 lm_head, N=12352 GDN.in_proj, N=256/1024 router/gate_up_proj)
- fused_moe.py: added num_warps=2 configs in _BATCHED_GEMM1_INT8_CONFIGS and _BATCHED_GEMM2_INT8_CONFIGS for the prefill INT8 path
- fused_moe.py: added (BLOCK_SIZE_N=8, BLOCK_SIZE_K=256, num_warps=2, num_stages=2) to _GEMM2_CONFIGS for decode MoE GEMM2

Validated end-to-end: decode +4% tok/s avg, prefill at S=2048 6375 tok/s. The op-level sweep confirmed the best swept configs are now in the default lists.
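For readers unfamiliar with these config lists, a sketch of what the new _GEMM2_CONFIGS entry looks like as a triton.Config (the kwargs are the meta-parameter names given in the commit message above):

```python
import triton

# Sketch of the new decode-oriented entry added to _GEMM2_CONFIGS.
_GEMM2_CONFIGS = [
    # ... existing configs ...
    triton.Config(
        {"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256},
        num_warps=2,
        num_stages=2,
    ),
]
```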
Context
The original K-loop did tl.max(tl.abs(a)) + INT8 cast on every tile (16 tiles × 16 rows = 256 reductions per program). Hoisting eliminates this redundant work and halves activation HBM bandwidth in the GEMM (bf16 → int8).
Improvement
Pre-quantize activations to INT8 once into a dedicated buffer (with per-row-per-tile FP32 scales) before the W4A8 batched MoE GEMMs, instead of re-quantizing inside the K-loop on every tile.
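A PyTorch sketch of what the rewritten K-loop then computes, assuming weight scales share the 128-wide K-tile granularity so both scales factor out of each tile's dot product (names and shapes are illustrative, not the PR's Triton code):

```python
import torch

def w4a8_gemm_reference(a_q, a_scales, w_q, w_scales, block_k=128):
    """Sketch: consume pre-quantized INT8 activations a_q [M, K] with
    per-row-per-tile scales a_scales [M, K // block_k], and already-unpacked
    INT4 weights w_q [K, N] with per-K-tile scales w_scales [K // block_k, N].
    No absmax reduction or INT8 cast remains inside the loop."""
    M, K = a_q.shape
    N = w_q.shape[1]
    acc = torch.zeros(M, N, dtype=torch.float32)
    for t in range(K // block_k):  # NUM_K_TILES iterations
        ks = slice(t * block_k, (t + 1) * block_k)
        tile = a_q[:, ks].float() @ w_q[ks, :].float()  # integer-valued dot
        acc += tile * a_scales[:, t:t + 1] * w_scales[t:t + 1, :]
    return acc.to(torch.bfloat16)
```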
Perf (1600-token prefill)
Prefill 5727 → 6171 tok/s (+7.7%), decode 92.6 → 99.0 tok/s (+6.9%) (1600 prefill + 512 decode, --cuda_graph, A100).
Correctness
7/7 microbenchmark configs (incl. qwen3.5-like M=128, K=2048, gs=128) pass with relative diff <1.5% vs BF16 reference — within INT8 quantization noise.
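The pass criterion can be expressed as a relative L2 difference, sketched here (out_int8 and out_bf16 stand for the INT8-path and BF16-reference outputs):

```python
import torch

def rel_diff(out_int8: torch.Tensor, out_bf16: torch.Tensor) -> float:
    """Relative difference used as the pass criterion (< 1.5% here)."""
    a, b = out_int8.float(), out_bf16.float()
    return ((a - b).norm() / b.norm().clamp(min=1e-12)).item()
```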
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell