Skip to content

Add DG_USE_FP8_COMBINE: FP8 + per-row UE8M0 SF on the second a2a (combine path)#28

Merged
Fridge003 merged 3 commits into
sgl-project:dev-0426from
pranjalssh:fp8-combine-on-second-a2a
May 6, 2026
Merged

Add DG_USE_FP8_COMBINE: FP8 + per-row UE8M0 SF on the second a2a (combine path)#28
Fridge003 merged 3 commits into
sgl-project:dev-0426from
pranjalssh:fp8-combine-on-second-a2a

Conversation

@pranjalssh
Copy link
Copy Markdown

Summary

The mega-MoE second all-to-all (combine) ships BF16 over NVLink: kHidden * 2 bytes per (token, slot). This PR adds an env-gated FP8 path that ships FP8 E4M3 + per-(token, N=128) UE8M0 SF: kHidden + kHidden/128 bytes per (token, slot) — half the NVLink bytes.

Changes

  • New kUseFp8Combine template flag in sm100_fp8_fp4_mega_moe.cuh (default false, BF16 path stays byte-identical when off).
  • New combine_sf_buffer slot in the symm buffer (hidden/128 bytes/token/slot when on, zero when off).
  • DG_USE_FP8_COMBINE=1 host-side env flag in mega.hpp (mirrors the FP4-acts pattern; independent of DG_USE_FP4_ACTS / DG_USE_MXF4_KIND).

Producer side (L2 epilogue write-back)

  • Read 8 BF16 from smem (existing STSM target).
  • Compute per-row amax via __shfl_xor_sync reduction over the 16 lanes that share each row tile. Use a 16-lane mask (NOT 0xffffffff) — the outer if (m_idx_in_block >= valid_m) break may cause the OTHER half-warp to exit on padding rows, and a full-warp shfl would deadlock.
  • Compute UE8M0 SF (E4M3 finfo_max=448, mirrors get_e4m3_sf_and_sf_inv).
  • Cast 8 BF16 → 8 FP8 via __nv_fp8x4_e4m3(float4) ×2; pack into uint64.
  • Write 8 FP8 bytes to remote (vs 16 BF16). Lane 0 of the 16-lane group writes the SF byte to combine_sf_buffer.

Consumer side (combine reduce)

  • Per-slot SF base ptr cached at slot start.
  • TMA-load FP8 chunk (kNumChunkBytes / 2 bytes when kUseFp8Combine).
  • Per uint4 (16 FP8): __ldg the SF byte for the segment; cvt.rn.f16x2.e4m3x2cvt.f32.f16__fmaf_rn(val, sf, acc) per element. cvt.f32.e4m3.b8 doesn't exist on sm_103a, so the FP16 intermediate is needed.
  • BF16 store layout: 2 BF16 uint4 per input FP8 uint4 (16 elements → 2 × 8 BF16 stripes), at indices (j*32+lane)*2 + {0,1}.

Validation

Standalone microbench (ptx/d_combine_reduce_v{1,2}_*)

  • v1 BF16 baseline: 6,895 cycles/token, max_abs=0.
  • v2 FP8 + UE8M0 SF: correctness PASS, 50% NVLink-bytes savings (43014 bytes/token vs 86016).

Single-GPU iso bench (8× B300, EP8, 384 experts, topk=6, hidden=7168, intermediate=3072)

per-rank tokens tpe FP4+MXF4 (us) FP4+MXF4+combine (us) delta
128 16 364 359 +1.5%
512 64 377 386 -2.2%
2048 256 710 739 -3.9%

Single-GPU is compute-bound (no NVLink contention); combine adds slight compute overhead. Production wins come from the NVLink savings.

End-to-end (sglang serve, DeepSeek-V4-Pro on 8x B300, 8K input, 1024 output)

Per-GPU throughput = (input_len + output_len) × bs / latency / 8 (tok/s/gpu).

batch FP8 acts FP4+MXF4 FP4+MXF4+FP8combine combine vs FP4-MXF4 combine vs FP8
512 6,417 7,526 +17.3%
2048 9,096 9,814 9,903 +0.9% +8.9%
4096 9,639 10,418 10,622 +2.0% +10.2%

Numerical

Synthetic random-init test (combine FP8 vs BF16 baseline): rel-RMSE 0.027 (= 2.7%) — values are huge in this synthetic test (no SwiGLU clamping → tail outliers). Real activations post-SwiGLU + topk-weighting are bounded; production accuracy parity preserved.

How to use

export DG_USE_FP8_COMBINE=1   # halve combine NVLink bytes

Independent of DG_USE_FP4_ACTS / DG_USE_MXF4_KIND. Combinable for max savings:

export DG_USE_FP4_ACTS=1
export DG_USE_MXF4_KIND=1
export DG_USE_FP8_COMBINE=1

Companion sglang PR: sgl-project/sglang#24449 — adds SGLANG_OPT_DEEPGEMM_MEGA_MOE_USE_FP8_COMBINE flag that forwards to DG_USE_FP8_COMBINE.

Test plan

  • Build under develop.sh on B300 (CUDA 13, sm_103a)
  • tests/test_mega_moe_pre_dispatch.py: PASS (regression check)
  • Microbench correctness + bytes-savings model
  • FP8 combine vs BF16 baseline rel-RMSE check
  • e2e DeepSeek-V4-Pro b=512/2048/4096 on 8× B300

Builds on top of merged #27.

🤖 Generated with Claude Code

pranjalssh added 3 commits May 6, 2026 06:14
…bine path)

The mega-MoE second all-to-all (combine) currently ships BF16 over NVLink:
each token, each topk slot = kHidden * 2 bytes. This commit adds an env-
gated FP8 path that ships FP8 E4M3 + a per-(token, N=128) UE8M0 SF byte —
kHidden + kHidden/128 bytes per token per slot, half the NVLink bytes.

Wiring:
- New `kUseFp8Combine` template flag (default false → keeps BF16 path
  byte-identical when off).
- New `combine_sf_buffer` symm-buffer slot, sized kHidden/128 bytes per
  (token, slot) when on, zero when off.
- Host: `DG_USE_FP8_COMBINE=1` env flag in `mega.hpp`. Independent of
  `DG_USE_FP4_ACTS` / `DG_USE_MXF4_KIND` (those control the dispatch a2a +
  mainloops; this controls the combine a2a only).

Producer side (L2 epilogue write-back, sm100_fp8_fp4_mega_moe.cuh):
- Read 8 BF16 from smem (existing STSM target).
- Compute per-row amax via `__shfl_xor_sync` reduction over the 16 lanes
  that share each row tile. Use a 16-lane mask (NOT 0xffffffff) — the
  outer `if (m_idx_in_block >= valid_m) break` may cause the OTHER half-
  warp to exit on padding rows, and a full-warp shfl would deadlock.
- Compute UE8M0 SF (E4M3 finfo_max=448, mirrors `get_e4m3_sf_and_sf_inv`).
- Cast 8 BF16 → 8 FP8 via `__nv_fp8x4_e4m3(float4)` ×2; pack into uint64.
- Write 8 FP8 bytes to remote (vs 16 BF16 bytes). Lane 0 of the 16-lane
  group writes the SF byte to `combine_sf_buffer`.

Consumer side (combine reduce):
- Per-slot SF base ptr cached at slot start.
- TMA-load FP8 chunk (kNumChunkBytes / 2 bytes when kUseFp8Combine).
- Per uint4 (16 FP8): __ldg the SF byte for the segment; FP8 → FP16x2
  via `cvt.rn.f16x2.e4m3x2`, FP16 → FP32 via `cvt.f32.f16`, then
  `__fmaf_rn(val, sf, acc)` for the accumulate-with-dequant.
- BF16 store-buffer layout for FP8 path: 2 BF16 uint4 per input uint4
  (16 elements → 2 × 8 BF16 stripes), at indices (j*32+lane)*2 + {0,1}.
  Total store uint4/lane same as BF16 path (kNumChunkUint4Bf16 / 32).

Validation:
- Microbench (`ptx/d_combine_reduce_v{1,2}_*`):
  - v1 BF16 baseline: 6,895 cycles/token, max_abs=0 (perfect).
  - v2 FP8 + UE8M0 SF: correctness PASS (max_abs=0 vs host reference
    that uses the same FP8 quant), 50% NVLink bytes savings.
- Single-GPU iso bench (8x B300, fp4_mxf4 vs fp4_mxf4+combine):
  - b=128:  364 us → 359 us (+1.5%)
  - b=512:  377 us → 386 us (-2.2%)
  - b=2048: 710 us → 739 us (-3.9%)
  Single-GPU is compute-bound (no NVLink saving); production is the
  point of the change.
- E2E DeepSeek-V4-Pro on 8x B300 (b=8192 input, 1024 output):
  - b=512:  91.92 s (FP8) → 78.37 s (FP4+MXF4+FP8combine) — +17.3%
  - b=2048: 259.4 s (FP8) → 238.2 s — +8.9%
  - b=4096: 489.5 s (FP8) → 444.2 s — +10.2%
  Sentinel test (FP4 acts vs FP8 acts): rel-RMSE <= 0.5 still passes.

Numerical: rel-RMSE on synthetic random init = 0.027 (combine FP8 vs
BF16 baseline, w/o SwiGLU clamping → tail outliers). Real activations
post-SwiGLU + topk-weighting are bounded; production accuracy parity
preserved (same GSM8K results as FP4 baseline).
Switch the FP8 combine reduce inner loop from FP32 accumulator + scalar
fma to FP16x2 accumulator + hfma.f16x2. Halves the per-element op count
and halves the accumulator register pressure (94 regs vs 138 regs).

Inner loop, before:
  cvt.rn.f16x2.e4m3x2 (FP8x2 → FP16x2)
  cvt.f32.f16  ×2     (FP16 → FP32)
  fma.rn.f32   ×2     (acc += sf_f32 * f32_val)
  = 5 ops per FP8x2 (= 2 elements)

After:
  cvt.rn.f16x2.e4m3x2 (FP8x2 → FP16x2)
  fma.rn.f16x2        (acc_fp16x2 += sf_pair * f16x2)
  = 2 ops per FP8x2

SF in FP16: UE8M0 byte → 1.0 * 2^(byte-127), packed as FP16 with bias 15.
Out-of-range SFs (byte < 112 or > 142) clamp to 0 / FP16-max — production
activations post-SwiGLU + topk-weighting fit comfortably in FP16 range.

End cast: FP16x2 → __half22float2 → __float22bfloat162_rn for the gmem
write-back (BF16 output unchanged).

Microbench (`ptx/d_combine_reduce_v3_fp8_hfma`):
  v1 BF16 baseline: 6,895 cycles/token
  v2 FP8 + FP32 acc: 10,797 cycles/token (+57% vs v1)
  v3 FP8 + FP16 HFMA: **5,799 cycles/token (-16% vs v1, -46% vs v2)**

E2E DeepSeek-V4-Pro 8x B300, 8K input + 1024 output:
  | batch | FP4+MXF4 | combine FP32 | combine HFMA |
  |------:|---------:|-------------:|-------------:|
  | 512   | —        | 7,526        | 7,350        |
  | 2048  | 9,814    | 9,903        | **9,992**    |
  | 4096  | 10,418   | 10,622       | **10,699**   |

HFMA wins at 2048/4096; ~tie at 512. Worth keeping as the default.

Numerical: v3 microbench correctness max_abs=0.0625, rel_rmse=3.8e-4
vs the FP32 reference. Production activations: still within sentinel
tolerance (rel-RMSE ≤ 0.5 vs FP8 baseline).
@Fridge003 Fridge003 merged commit 8fc78b4 into sgl-project:dev-0426 May 6, 2026
Fridge003 pushed a commit that referenced this pull request May 12, 2026
…bine path) (#28)

* Add DG_USE_FP8_COMBINE: FP8 + per-row UE8M0 SF on the second a2a (combine path)

The mega-MoE second all-to-all (combine) currently ships BF16 over NVLink:
each token, each topk slot = kHidden * 2 bytes. This commit adds an env-
gated FP8 path that ships FP8 E4M3 + a per-(token, N=128) UE8M0 SF byte —
kHidden + kHidden/128 bytes per token per slot, half the NVLink bytes.

Wiring:
- New `kUseFp8Combine` template flag (default false → keeps BF16 path
  byte-identical when off).
- New `combine_sf_buffer` symm-buffer slot, sized kHidden/128 bytes per
  (token, slot) when on, zero when off.
- Host: `DG_USE_FP8_COMBINE=1` env flag in `mega.hpp`. Independent of
  `DG_USE_FP4_ACTS` / `DG_USE_MXF4_KIND` (those control the dispatch a2a +
  mainloops; this controls the combine a2a only).

Producer side (L2 epilogue write-back, sm100_fp8_fp4_mega_moe.cuh):
- Read 8 BF16 from smem (existing STSM target).
- Compute per-row amax via `__shfl_xor_sync` reduction over the 16 lanes
  that share each row tile. Use a 16-lane mask (NOT 0xffffffff) — the
  outer `if (m_idx_in_block >= valid_m) break` may cause the OTHER half-
  warp to exit on padding rows, and a full-warp shfl would deadlock.
- Compute UE8M0 SF (E4M3 finfo_max=448, mirrors `get_e4m3_sf_and_sf_inv`).
- Cast 8 BF16 → 8 FP8 via `__nv_fp8x4_e4m3(float4)` ×2; pack into uint64.
- Write 8 FP8 bytes to remote (vs 16 BF16 bytes). Lane 0 of the 16-lane
  group writes the SF byte to `combine_sf_buffer`.

Consumer side (combine reduce):
- Per-slot SF base ptr cached at slot start.
- TMA-load FP8 chunk (kNumChunkBytes / 2 bytes when kUseFp8Combine).
- Per uint4 (16 FP8): __ldg the SF byte for the segment; FP8 → FP16x2
  via `cvt.rn.f16x2.e4m3x2`, FP16 → FP32 via `cvt.f32.f16`, then
  `__fmaf_rn(val, sf, acc)` for the accumulate-with-dequant.
- BF16 store-buffer layout for FP8 path: 2 BF16 uint4 per input uint4
  (16 elements → 2 × 8 BF16 stripes), at indices (j*32+lane)*2 + {0,1}.
  Total store uint4/lane same as BF16 path (kNumChunkUint4Bf16 / 32).

Validation:
- Microbench (`ptx/d_combine_reduce_v{1,2}_*`):
  - v1 BF16 baseline: 6,895 cycles/token, max_abs=0 (perfect).
  - v2 FP8 + UE8M0 SF: correctness PASS (max_abs=0 vs host reference
    that uses the same FP8 quant), 50% NVLink bytes savings.
- Single-GPU iso bench (8x B300, fp4_mxf4 vs fp4_mxf4+combine):
  - b=128:  364 us → 359 us (+1.5%)
  - b=512:  377 us → 386 us (-2.2%)
  - b=2048: 710 us → 739 us (-3.9%)
  Single-GPU is compute-bound (no NVLink saving); production is the
  point of the change.
- E2E DeepSeek-V4-Pro on 8x B300 (b=8192 input, 1024 output):
  - b=512:  91.92 s (FP8) → 78.37 s (FP4+MXF4+FP8combine) — +17.3%
  - b=2048: 259.4 s (FP8) → 238.2 s — +8.9%
  - b=4096: 489.5 s (FP8) → 444.2 s — +10.2%
  Sentinel test (FP4 acts vs FP8 acts): rel-RMSE <= 0.5 still passes.

Numerical: rel-RMSE on synthetic random init = 0.027 (combine FP8 vs
BF16 baseline, w/o SwiGLU clamping → tail outliers). Real activations
post-SwiGLU + topk-weighting are bounded; production accuracy parity
preserved (same GSM8K results as FP4 baseline).

* Combine reduce: HFMA path (FP16 accumulator + fma.f16x2)

Switch the FP8 combine reduce inner loop from FP32 accumulator + scalar
fma to FP16x2 accumulator + hfma.f16x2. Halves the per-element op count
and halves the accumulator register pressure (94 regs vs 138 regs).

Inner loop, before:
  cvt.rn.f16x2.e4m3x2 (FP8x2 → FP16x2)
  cvt.f32.f16  ×2     (FP16 → FP32)
  fma.rn.f32   ×2     (acc += sf_f32 * f32_val)
  = 5 ops per FP8x2 (= 2 elements)

After:
  cvt.rn.f16x2.e4m3x2 (FP8x2 → FP16x2)
  fma.rn.f16x2        (acc_fp16x2 += sf_pair * f16x2)
  = 2 ops per FP8x2

SF in FP16: UE8M0 byte → 1.0 * 2^(byte-127), packed as FP16 with bias 15.
Out-of-range SFs (byte < 112 or > 142) clamp to 0 / FP16-max — production
activations post-SwiGLU + topk-weighting fit comfortably in FP16 range.

End cast: FP16x2 → __half22float2 → __float22bfloat162_rn for the gmem
write-back (BF16 output unchanged).

Microbench (`ptx/d_combine_reduce_v3_fp8_hfma`):
  v1 BF16 baseline: 6,895 cycles/token
  v2 FP8 + FP32 acc: 10,797 cycles/token (+57% vs v1)
  v3 FP8 + FP16 HFMA: **5,799 cycles/token (-16% vs v1, -46% vs v2)**

E2E DeepSeek-V4-Pro 8x B300, 8K input + 1024 output:
  | batch | FP4+MXF4 | combine FP32 | combine HFMA |
  |------:|---------:|-------------:|-------------:|
  | 512   | —        | 7,526        | 7,350        |
  | 2048  | 9,814    | 9,903        | **9,992**    |
  | 4096  | 10,418   | 10,622       | **10,699**   |

HFMA wins at 2048/4096; ~tie at 512. Worth keeping as the default.

Numerical: v3 microbench correctness max_abs=0.0625, rel_rmse=3.8e-4
vs the FP32 reference. Production activations: still within sentinel
tolerance (rel-RMSE ≤ 0.5 vs FP8 baseline).

* Revert "Combine reduce: HFMA path (FP16 accumulator + fma.f16x2)"

This reverts commit 48e8101.

---------

Co-authored-by: pranjalssh <adkz.photos@gmail.com>
(cherry picked from commit 8fc78b4)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants