Add DG_USE_FP8_COMBINE: FP8 + per-row UE8M0 SF on the second a2a (combine path)#28
Merged
Fridge003 merged 3 commits intoMay 6, 2026
Merged
Conversation
added 3 commits
May 6, 2026 06:14
…bine path)
The mega-MoE second all-to-all (combine) currently ships BF16 over NVLink:
each token, each topk slot = kHidden * 2 bytes. This commit adds an env-
gated FP8 path that ships FP8 E4M3 + a per-(token, N=128) UE8M0 SF byte —
kHidden + kHidden/128 bytes per token per slot, half the NVLink bytes.
Wiring:
- New `kUseFp8Combine` template flag (default false → keeps BF16 path
byte-identical when off).
- New `combine_sf_buffer` symm-buffer slot, sized kHidden/128 bytes per
(token, slot) when on, zero when off.
- Host: `DG_USE_FP8_COMBINE=1` env flag in `mega.hpp`. Independent of
`DG_USE_FP4_ACTS` / `DG_USE_MXF4_KIND` (those control the dispatch a2a +
mainloops; this controls the combine a2a only).
Producer side (L2 epilogue write-back, sm100_fp8_fp4_mega_moe.cuh):
- Read 8 BF16 from smem (existing STSM target).
- Compute per-row amax via `__shfl_xor_sync` reduction over the 16 lanes
that share each row tile. Use a 16-lane mask (NOT 0xffffffff) — the
outer `if (m_idx_in_block >= valid_m) break` may cause the OTHER half-
warp to exit on padding rows, and a full-warp shfl would deadlock.
- Compute UE8M0 SF (E4M3 finfo_max=448, mirrors `get_e4m3_sf_and_sf_inv`).
- Cast 8 BF16 → 8 FP8 via `__nv_fp8x4_e4m3(float4)` ×2; pack into uint64.
- Write 8 FP8 bytes to remote (vs 16 BF16 bytes). Lane 0 of the 16-lane
group writes the SF byte to `combine_sf_buffer`.
Consumer side (combine reduce):
- Per-slot SF base ptr cached at slot start.
- TMA-load FP8 chunk (kNumChunkBytes / 2 bytes when kUseFp8Combine).
- Per uint4 (16 FP8): __ldg the SF byte for the segment; FP8 → FP16x2
via `cvt.rn.f16x2.e4m3x2`, FP16 → FP32 via `cvt.f32.f16`, then
`__fmaf_rn(val, sf, acc)` for the accumulate-with-dequant.
- BF16 store-buffer layout for FP8 path: 2 BF16 uint4 per input uint4
(16 elements → 2 × 8 BF16 stripes), at indices (j*32+lane)*2 + {0,1}.
Total store uint4/lane same as BF16 path (kNumChunkUint4Bf16 / 32).
Validation:
- Microbench (`ptx/d_combine_reduce_v{1,2}_*`):
- v1 BF16 baseline: 6,895 cycles/token, max_abs=0 (perfect).
- v2 FP8 + UE8M0 SF: correctness PASS (max_abs=0 vs host reference
that uses the same FP8 quant), 50% NVLink bytes savings.
- Single-GPU iso bench (8x B300, fp4_mxf4 vs fp4_mxf4+combine):
- b=128: 364 us → 359 us (+1.5%)
- b=512: 377 us → 386 us (-2.2%)
- b=2048: 710 us → 739 us (-3.9%)
Single-GPU is compute-bound (no NVLink saving); production is the
point of the change.
- E2E DeepSeek-V4-Pro on 8x B300 (b=8192 input, 1024 output):
- b=512: 91.92 s (FP8) → 78.37 s (FP4+MXF4+FP8combine) — +17.3%
- b=2048: 259.4 s (FP8) → 238.2 s — +8.9%
- b=4096: 489.5 s (FP8) → 444.2 s — +10.2%
Sentinel test (FP4 acts vs FP8 acts): rel-RMSE <= 0.5 still passes.
Numerical: rel-RMSE on synthetic random init = 0.027 (combine FP8 vs
BF16 baseline, w/o SwiGLU clamping → tail outliers). Real activations
post-SwiGLU + topk-weighting are bounded; production accuracy parity
preserved (same GSM8K results as FP4 baseline).
Switch the FP8 combine reduce inner loop from FP32 accumulator + scalar fma to FP16x2 accumulator + hfma.f16x2. Halves the per-element op count and halves the accumulator register pressure (94 regs vs 138 regs). Inner loop, before: cvt.rn.f16x2.e4m3x2 (FP8x2 → FP16x2) cvt.f32.f16 ×2 (FP16 → FP32) fma.rn.f32 ×2 (acc += sf_f32 * f32_val) = 5 ops per FP8x2 (= 2 elements) After: cvt.rn.f16x2.e4m3x2 (FP8x2 → FP16x2) fma.rn.f16x2 (acc_fp16x2 += sf_pair * f16x2) = 2 ops per FP8x2 SF in FP16: UE8M0 byte → 1.0 * 2^(byte-127), packed as FP16 with bias 15. Out-of-range SFs (byte < 112 or > 142) clamp to 0 / FP16-max — production activations post-SwiGLU + topk-weighting fit comfortably in FP16 range. End cast: FP16x2 → __half22float2 → __float22bfloat162_rn for the gmem write-back (BF16 output unchanged). Microbench (`ptx/d_combine_reduce_v3_fp8_hfma`): v1 BF16 baseline: 6,895 cycles/token v2 FP8 + FP32 acc: 10,797 cycles/token (+57% vs v1) v3 FP8 + FP16 HFMA: **5,799 cycles/token (-16% vs v1, -46% vs v2)** E2E DeepSeek-V4-Pro 8x B300, 8K input + 1024 output: | batch | FP4+MXF4 | combine FP32 | combine HFMA | |------:|---------:|-------------:|-------------:| | 512 | — | 7,526 | 7,350 | | 2048 | 9,814 | 9,903 | **9,992** | | 4096 | 10,418 | 10,622 | **10,699** | HFMA wins at 2048/4096; ~tie at 512. Worth keeping as the default. Numerical: v3 microbench correctness max_abs=0.0625, rel_rmse=3.8e-4 vs the FP32 reference. Production activations: still within sentinel tolerance (rel-RMSE ≤ 0.5 vs FP8 baseline).
This reverts commit 48e8101.
Fridge003
pushed a commit
that referenced
this pull request
May 12, 2026
…bine path) (#28) * Add DG_USE_FP8_COMBINE: FP8 + per-row UE8M0 SF on the second a2a (combine path) The mega-MoE second all-to-all (combine) currently ships BF16 over NVLink: each token, each topk slot = kHidden * 2 bytes. This commit adds an env- gated FP8 path that ships FP8 E4M3 + a per-(token, N=128) UE8M0 SF byte — kHidden + kHidden/128 bytes per token per slot, half the NVLink bytes. Wiring: - New `kUseFp8Combine` template flag (default false → keeps BF16 path byte-identical when off). - New `combine_sf_buffer` symm-buffer slot, sized kHidden/128 bytes per (token, slot) when on, zero when off. - Host: `DG_USE_FP8_COMBINE=1` env flag in `mega.hpp`. Independent of `DG_USE_FP4_ACTS` / `DG_USE_MXF4_KIND` (those control the dispatch a2a + mainloops; this controls the combine a2a only). Producer side (L2 epilogue write-back, sm100_fp8_fp4_mega_moe.cuh): - Read 8 BF16 from smem (existing STSM target). - Compute per-row amax via `__shfl_xor_sync` reduction over the 16 lanes that share each row tile. Use a 16-lane mask (NOT 0xffffffff) — the outer `if (m_idx_in_block >= valid_m) break` may cause the OTHER half- warp to exit on padding rows, and a full-warp shfl would deadlock. - Compute UE8M0 SF (E4M3 finfo_max=448, mirrors `get_e4m3_sf_and_sf_inv`). - Cast 8 BF16 → 8 FP8 via `__nv_fp8x4_e4m3(float4)` ×2; pack into uint64. - Write 8 FP8 bytes to remote (vs 16 BF16 bytes). Lane 0 of the 16-lane group writes the SF byte to `combine_sf_buffer`. Consumer side (combine reduce): - Per-slot SF base ptr cached at slot start. - TMA-load FP8 chunk (kNumChunkBytes / 2 bytes when kUseFp8Combine). - Per uint4 (16 FP8): __ldg the SF byte for the segment; FP8 → FP16x2 via `cvt.rn.f16x2.e4m3x2`, FP16 → FP32 via `cvt.f32.f16`, then `__fmaf_rn(val, sf, acc)` for the accumulate-with-dequant. - BF16 store-buffer layout for FP8 path: 2 BF16 uint4 per input uint4 (16 elements → 2 × 8 BF16 stripes), at indices (j*32+lane)*2 + {0,1}. Total store uint4/lane same as BF16 path (kNumChunkUint4Bf16 / 32). Validation: - Microbench (`ptx/d_combine_reduce_v{1,2}_*`): - v1 BF16 baseline: 6,895 cycles/token, max_abs=0 (perfect). - v2 FP8 + UE8M0 SF: correctness PASS (max_abs=0 vs host reference that uses the same FP8 quant), 50% NVLink bytes savings. - Single-GPU iso bench (8x B300, fp4_mxf4 vs fp4_mxf4+combine): - b=128: 364 us → 359 us (+1.5%) - b=512: 377 us → 386 us (-2.2%) - b=2048: 710 us → 739 us (-3.9%) Single-GPU is compute-bound (no NVLink saving); production is the point of the change. - E2E DeepSeek-V4-Pro on 8x B300 (b=8192 input, 1024 output): - b=512: 91.92 s (FP8) → 78.37 s (FP4+MXF4+FP8combine) — +17.3% - b=2048: 259.4 s (FP8) → 238.2 s — +8.9% - b=4096: 489.5 s (FP8) → 444.2 s — +10.2% Sentinel test (FP4 acts vs FP8 acts): rel-RMSE <= 0.5 still passes. Numerical: rel-RMSE on synthetic random init = 0.027 (combine FP8 vs BF16 baseline, w/o SwiGLU clamping → tail outliers). Real activations post-SwiGLU + topk-weighting are bounded; production accuracy parity preserved (same GSM8K results as FP4 baseline). * Combine reduce: HFMA path (FP16 accumulator + fma.f16x2) Switch the FP8 combine reduce inner loop from FP32 accumulator + scalar fma to FP16x2 accumulator + hfma.f16x2. Halves the per-element op count and halves the accumulator register pressure (94 regs vs 138 regs). Inner loop, before: cvt.rn.f16x2.e4m3x2 (FP8x2 → FP16x2) cvt.f32.f16 ×2 (FP16 → FP32) fma.rn.f32 ×2 (acc += sf_f32 * f32_val) = 5 ops per FP8x2 (= 2 elements) After: cvt.rn.f16x2.e4m3x2 (FP8x2 → FP16x2) fma.rn.f16x2 (acc_fp16x2 += sf_pair * f16x2) = 2 ops per FP8x2 SF in FP16: UE8M0 byte → 1.0 * 2^(byte-127), packed as FP16 with bias 15. Out-of-range SFs (byte < 112 or > 142) clamp to 0 / FP16-max — production activations post-SwiGLU + topk-weighting fit comfortably in FP16 range. End cast: FP16x2 → __half22float2 → __float22bfloat162_rn for the gmem write-back (BF16 output unchanged). Microbench (`ptx/d_combine_reduce_v3_fp8_hfma`): v1 BF16 baseline: 6,895 cycles/token v2 FP8 + FP32 acc: 10,797 cycles/token (+57% vs v1) v3 FP8 + FP16 HFMA: **5,799 cycles/token (-16% vs v1, -46% vs v2)** E2E DeepSeek-V4-Pro 8x B300, 8K input + 1024 output: | batch | FP4+MXF4 | combine FP32 | combine HFMA | |------:|---------:|-------------:|-------------:| | 512 | — | 7,526 | 7,350 | | 2048 | 9,814 | 9,903 | **9,992** | | 4096 | 10,418 | 10,622 | **10,699** | HFMA wins at 2048/4096; ~tie at 512. Worth keeping as the default. Numerical: v3 microbench correctness max_abs=0.0625, rel_rmse=3.8e-4 vs the FP32 reference. Production activations: still within sentinel tolerance (rel-RMSE ≤ 0.5 vs FP8 baseline). * Revert "Combine reduce: HFMA path (FP16 accumulator + fma.f16x2)" This reverts commit 48e8101. --------- Co-authored-by: pranjalssh <adkz.photos@gmail.com> (cherry picked from commit 8fc78b4)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
The mega-MoE second all-to-all (combine) ships BF16 over NVLink:
kHidden * 2bytes per (token, slot). This PR adds an env-gated FP8 path that ships FP8 E4M3 + per-(token, N=128) UE8M0 SF:kHidden + kHidden/128bytes per (token, slot) — half the NVLink bytes.Changes
kUseFp8Combinetemplate flag insm100_fp8_fp4_mega_moe.cuh(defaultfalse, BF16 path stays byte-identical when off).combine_sf_bufferslot in the symm buffer (hidden/128bytes/token/slot when on, zero when off).DG_USE_FP8_COMBINE=1host-side env flag inmega.hpp(mirrors the FP4-acts pattern; independent ofDG_USE_FP4_ACTS/DG_USE_MXF4_KIND).Producer side (L2 epilogue write-back)
__shfl_xor_syncreduction over the 16 lanes that share each row tile. Use a 16-lane mask (NOT0xffffffff) — the outerif (m_idx_in_block >= valid_m) breakmay cause the OTHER half-warp to exit on padding rows, and a full-warp shfl would deadlock.get_e4m3_sf_and_sf_inv).__nv_fp8x4_e4m3(float4)×2; pack into uint64.combine_sf_buffer.Consumer side (combine reduce)
kNumChunkBytes / 2bytes whenkUseFp8Combine).__ldgthe SF byte for the segment;cvt.rn.f16x2.e4m3x2→cvt.f32.f16→__fmaf_rn(val, sf, acc)per element. cvt.f32.e4m3.b8 doesn't exist on sm_103a, so the FP16 intermediate is needed.(j*32+lane)*2 + {0,1}.Validation
Standalone microbench (
ptx/d_combine_reduce_v{1,2}_*)Single-GPU iso bench (8× B300, EP8, 384 experts, topk=6, hidden=7168, intermediate=3072)
Single-GPU is compute-bound (no NVLink contention); combine adds slight compute overhead. Production wins come from the NVLink savings.
End-to-end (sglang serve, DeepSeek-V4-Pro on 8x B300, 8K input, 1024 output)
Per-GPU throughput =
(input_len + output_len) × bs / latency / 8(tok/s/gpu).Numerical
Synthetic random-init test (combine FP8 vs BF16 baseline): rel-RMSE 0.027 (= 2.7%) — values are huge in this synthetic test (no SwiGLU clamping → tail outliers). Real activations post-SwiGLU + topk-weighting are bounded; production accuracy parity preserved.
How to use
Independent of
DG_USE_FP4_ACTS/DG_USE_MXF4_KIND. Combinable for max savings:Companion sglang PR: sgl-project/sglang#24449 — adds
SGLANG_OPT_DEEPGEMM_MEGA_MOE_USE_FP8_COMBINEflag that forwards toDG_USE_FP8_COMBINE.Test plan
develop.shon B300 (CUDA 13, sm_103a)tests/test_mega_moe_pre_dispatch.py: PASS (regression check)Builds on top of merged #27.
🤖 Generated with Claude Code