
Hoist W4A8 activation quantization out of GEMM K-loop #19209

Merged

Gasoonjia merged 9 commits into main from hoist-activation-quant on May 9, 2026

Conversation


@Gasoonjia Gasoonjia commented Apr 29, 2026

Context

The original K-loop performed a tl.max(tl.abs(a)) reduction plus an INT8 cast on every tile (16 K-tiles × 16 rows = 256 reductions per program). Hoisting the quantization eliminates this redundant work and halves activation HBM bandwidth in the GEMM (bf16 → int8).

Improvement

Pre-quantize activations to INT8 once into a dedicated buffer (with per-row-per-tile FP32 scales) before the W4A8 batched MoE GEMMs, instead of re-quantizing inside the K-loop on every tile.
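For reference, a minimal PyTorch sketch of that hoisted pre-quantization step (an illustration of the data layout only, not the Triton kernel this PR adds; PREQUANT_BLOCK_K = 128 and the per-row-per-tile FP32 scales are taken from the description):

```python
import torch

PREQUANT_BLOCK_K = 128  # per-tile quantization width used by the GEMM K-loop

def prequantize_activations_int8(a: torch.Tensor):
    """a: [M, K] bf16/fp32 activations -> (int8 [M, K], fp32 scales [M, K // PREQUANT_BLOCK_K])."""
    M, K = a.shape
    num_tiles = K // PREQUANT_BLOCK_K          # assumes K is a multiple of the tile width
    a_tiles = a.float().view(M, num_tiles, PREQUANT_BLOCK_K)
    # One absmax scale per (row, K-tile); 127 covers the symmetric INT8 range.
    scales = a_tiles.abs().amax(dim=-1).clamp_min(1e-12) / 127.0
    a_int8 = torch.round(a_tiles / scales.unsqueeze(-1)).clamp(-127, 127).to(torch.int8)
    return a_int8.view(M, K), scales
```

The GEMM K-loop then loads the INT8 buffer and the matching per-tile scale instead of re-reducing over bf16 on every tile.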

Perf (1600-token prefill)

| Metric | Baseline (gh/digantdesai/53/head) | Optimized | Speedup |
| --- | --- | --- | --- |
| Prefill | 5727 tok/s (5296–5963) | 6171 tok/s (5941–6313) | 1.08× |

Correctness

7/7 microbenchmark configs (incl. qwen3.5-like M=128, K=2048, gs=128) pass with relative diff <1.5% vs BF16 reference — within INT8 quantization noise.
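For context, that pass criterion can be read as a max relative error check along these lines (sketch; the microbenchmark's exact reduction may differ):

```python
import torch

def max_rel_diff(out_int8_path: torch.Tensor, ref_bf16: torch.Tensor) -> float:
    """Worst-case relative difference of the INT8-path MoE output vs. the BF16 reference."""
    out, ref = out_int8_path.float(), ref_bf16.float()
    return ((out - ref).abs().max() / ref.abs().max().clamp_min(1e-12)).item()

# Pass criterion from the PR description: max_rel_diff(out, ref) < 0.015 for all 7 configs
```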

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell

@Gasoonjia Gasoonjia requested a review from lucylq as a code owner April 29, 2026 19:47

pytorch-bot Bot commented Apr 29, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19209

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure, 2 Pending, 1 Unrelated Failure

As of commit 28c87c0 with merge base b3baac5:

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Apr 29, 2026
@Gasoonjia Gasoonjia force-pushed the hoist-activation-quant branch 2 times, most recently from 385bf6d to 793be8d on April 29, 2026 20:19
@Gasoonjia Gasoonjia force-pushed the hoist-activation-quant branch from 793be8d to d936717 on April 29, 2026 21:14
Base automatically changed from gh/digantdesai/53/head to gh/digantdesai/53/base April 30, 2026 15:05

mergennachin commented May 1, 2026

One ask: the existing INT8_TEST_CONFIGS only test toy dimensions (max hidden=128). Please add a config or two at realistic scale so we catch any precision or alignment issues with K > PREQUANT_BLOCK_K:

(77, 512, 2048, 1024, 8, 2, 128, "512tok_real_dims"),
(21, 1, 2048, 1024, 8, 2, 128, "1tok_decode"),
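A sketch of how those entries could be appended (hypothetical; the tuple field order is assumed to match the existing INT8_TEST_CONFIGS entries in backends/cuda/tests/test_fused_moe.py):

```python
# Hypothetical addition to INT8_TEST_CONFIGS; field order assumed from the existing entries.
EXTRA_INT8_TEST_CONFIGS = [
    # K=2048 spans 16 tiles of PREQUANT_BLOCK_K=128, so per-row-per-tile scales and
    # K > PREQUANT_BLOCK_K alignment are actually exercised.
    (77, 512, 2048, 1024, 8, 2, 128, "512tok_real_dims"),
    (21, 1, 2048, 1024, 8, 2, 128, "1tok_decode"),  # single-token decode shape
]
```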

backends/cuda/tests/test_fused_moe.py


offs_k = tl.arange(0, BLOCK_SIZE_K)

for k_tile in range(NUM_K_TILES):

Here we are quantizing 128 values of A at a time; that has accuracy implications and is not apples-to-apples with llama.cpp. We should use 32.
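A quick self-contained illustration of that accuracy point (not code from this PR): symmetric per-group absmax INT8 quantization with 128-wide groups typically round-trips with more error than 32-wide groups, because one large value coarsens the scale for more neighbors.

```python
import torch

def groupwise_int8_rel_error(x: torch.Tensor, group_size: int) -> float:
    """Mean relative error after symmetric per-group absmax INT8 round-trip quantization."""
    g = x.reshape(-1, group_size)
    scale = g.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12) / 127.0
    q = torch.round(g / scale).clamp(-127, 127)
    return ((q * scale - g).abs().mean() / g.abs().mean()).item()

torch.manual_seed(0)
x = torch.randn(8, 2048) * (10.0 ** torch.randn(8, 2048))  # heavy-tailed magnitudes
print("group=128:", groupwise_int8_rel_error(x, 128))
print("group= 32:", groupwise_int8_rel_error(x, 32))
```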

Add dedicated _quantize_activations_int8_kernel and _silu_quantize_int8_kernel
that pre-quantize activations to INT8 with per-row-per-tile FP32 scales before
GEMM1 and GEMM2 respectively. The existing _fused_moe_batched_int8_kernel and
_fused_moe_silu_batched_int8_kernel are rewritten to consume pre-quantized
activations + scales, eliminating ~256 redundant tl.max reductions per program
(cdiv(K, BLOCK_K) tiles * BLOCK_M rows) and halving activation HBM bandwidth in
the K-loop (bf16 -> int8). BLOCK_SIZE_K is fixed at PREQUANT_BLOCK_K (= 128)
so per-tile activation scales align with the GEMM K-loop.

Correctness: 7/7 microbenchmark configs pass with rel diff <1.5% vs BF16 ref.
End-to-end (Qwen3.5 MoE 1600 prefill + 512 decode, --cuda_graph, A100):
prefill 5727 -> 6171 tok/s (+7.7%), decode 92.6 -> 99.0 tok/s (+6.9%).
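For illustration, a hedged Triton sketch of what a pre-quantization kernel along these lines could look like (names, layout, and the 2D launch grid are assumptions, not the repo's actual _quantize_activations_int8_kernel): one program handles one (row, K-tile) pair and emits an INT8 tile plus one FP32 scale.

```python
import triton
import triton.language as tl

@triton.jit
def quantize_activations_int8_sketch(
    a_ptr,       # [M, K] activations, row-major contiguous (assumed)
    a_q_ptr,     # [M, K] INT8 output buffer
    scales_ptr,  # [M, cdiv(K, BLOCK_K)] FP32 per-row-per-tile scales
    K,
    BLOCK_K: tl.constexpr,  # = PREQUANT_BLOCK_K = 128, matching the GEMM K-loop
):
    row = tl.program_id(0)   # one program per activation row ...
    tile = tl.program_id(1)  # ... and per K-tile
    offs_k = tile * BLOCK_K + tl.arange(0, BLOCK_K)
    mask = offs_k < K
    a = tl.load(a_ptr + row * K + offs_k, mask=mask, other=0.0).to(tl.float32)
    # One absmax FP32 scale per (row, tile); 127 covers the symmetric INT8 range.
    scale = tl.max(tl.abs(a), axis=0) / 127.0 + 1e-12
    q = a / scale
    q = tl.where(q >= 0, q + 0.5, q - 0.5)                    # round half away from zero
    q = tl.minimum(tl.maximum(q, -127.0), 127.0).to(tl.int8)
    tl.store(a_q_ptr + row * K + offs_k, q, mask=mask)
    tl.store(scales_ptr + row * tl.cdiv(K, BLOCK_K) + tile, scale)
```

On the host the grid would be (M, triton.cdiv(K, PREQUANT_BLOCK_K)); the rewritten GEMM K-loop then loads the INT8 tile and its scale instead of re-reducing over bf16.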
@Gasoonjia Gasoonjia force-pushed the hoist-activation-quant branch from d936717 to 93cf373 on May 4, 2026 22:54
@github-actions github-actions Bot added the ciflow/trunk and module: arm labels May 4, 2026
@Gasoonjia Gasoonjia changed the base branch from gh/digantdesai/53/base to main May 4, 2026 22:55

github-actions Bot commented May 4, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@Gasoonjia Gasoonjia force-pushed the hoist-activation-quant branch from 311944d to 4563eb2 on May 6, 2026 04:48
Gasoonjia and others added 7 commits May 5, 2026 21:48
Microbenchmarked on Qwen3.5 MoE prefill (M=1696, top_k=8, 256 experts):
  BLOCK_M=16: 3.62 ms
  BLOCK_M=32: 2.85 ms (1.27x)
  BLOCK_M=64: 2.75 ms (1.32x)

E2E (Qwen3.5-35B-A3B prefill, --moe-activation-dtype int8 --dense-prefill
dequant --cuda_graph, p=1600 d=512, run_1..5 median):
  BLOCK_M=16: 5897 tok/s prefill (273 ms), 98.1 tok/s decode
  BLOCK_M=64: 6793 tok/s prefill (237 ms), 98.1 tok/s decode
  Speedup:    1.152x prefill, decode unchanged (decode uses non-batched
              fused_moe kernel)

Outputs are bit-identical between BLOCK_M=16 and BLOCK_M=64 in the
microbenchmark (max abs diff = 0).
…ofiling

- int4_matmul.py: 5 new _MATVEC_CONFIGS targeting decode shapes
  (N=1 shared_expert_gate, N=2048 o_proj, N=248320 lm_head,
   N=12352 GDN.in_proj, N=256/1024 router/gate_up_proj)
- fused_moe.py: added num_warps=2 configs in _BATCHED_GEMM1_INT8_CONFIGS
  and _BATCHED_GEMM2_INT8_CONFIGS for prefill INT8 path
- fused_moe.py: added (BLOCK_SIZE_N=8, BLOCK_SIZE_K=256, num_warps=2,
  num_stages=2) to _GEMM2_CONFIGS for decode MoE GEMM2

Validated end-to-end: decode +4% tok/s avg, prefill at S=2048 6375 tok/s.
Op-level sweep showed best swept configs are now in default lists.
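As a hedged illustration of what the decode GEMM2 config addition looks like in Triton autotune terms (the real _GEMM2_CONFIGS list in fused_moe.py is not reproduced here):

```python
import triton

# Hypothetical excerpt; the actual list contains more entries.
_GEMM2_CONFIGS = [
    # ... existing configs ...
    triton.Config(
        {"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256},  # decode MoE GEMM2 addition from this commit
        num_warps=2,
        num_stages=2,
    ),
]
```

The list is then fed to @triton.autotune(configs=_GEMM2_CONFIGS, key=[...]) on the kernel; the autotune keys are repo-specific and not shown here.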
@Gasoonjia Gasoonjia merged commit 93b764e into main May 9, 2026
482 of 484 checks passed
@Gasoonjia Gasoonjia deleted the hoist-activation-quant branch May 9, 2026 01:48

Labels

ciflow/trunk, CLA Signed, module: arm
