diff --git a/torchao/prototype/moe_training/README.md b/torchao/prototype/moe_training/README.md index 99e64e259a..a2d8d03a79 100644 --- a/torchao/prototype/moe_training/README.md +++ b/torchao/prototype/moe_training/README.md @@ -159,22 +159,23 @@ CUDA_VISIBLE_DEVICES=6 python benchmarks/prototype/moe_training/bench_moe_layer. ### Individual bfloat16 torch._grouped_mm op vs torchao_scaled_grouped_mm -MXFP8: - -| M,N,K,G | bf16_fwd_bwd_us | scaled_fwd_bwd_us | scaled_fwd_bwd_speedup | -|------------------------|-----------------|-------------------|------------------------| -| (128000, 8192, 5120, 1) | 40463 | 24406 | 1.658x | -| (128000, 8192, 5120, 2) | 35494.5 | 24705.1 | 1.437x | -| (128000, 8192, 5120, 4) | 38879.3 | 24508.5 | 1.586x | -| (128000, 8192, 5120, 8) | 35714.6 | 25937.6 | 1.377x | -| (128000, 1536, 5120, 1) | 6353.06 | 7401.54 | 0.858x | -| (128000, 1536, 5120, 2) | 6511.65 | 6729.33 | 0.968x | -| (128000, 1536, 5120, 4) | 6455.2 | 6626.5 | 0.974x | -| (128000, 1536, 5120, 8) | 7716.13 | 6516.74 | 1.184x | -| (128000, 2048, 7168, 1) | 11758 | 11255.7 | 1.045x | -| (128000, 2048, 7168, 2) | 15012.9 | 9917.9 | 1.514x | -| (128000, 2048, 7168, 4) | 14904.2 | 10493.8 | 1.42x | -| (128000, 2048, 7168, 8) | 13178 | 9638.38 | 1.367x | +**MXFP8 with Llama4 17b 16e shapes** (with G=1-8 to simulate different degrees of expert parallelism) + +| M,N,K,G | bf16_fwd_bwd_us | scaled_fwd_bwd_us | scaled_fwd_bwd_speedup | +| ----------------------- | --------------: | ----------------: | ---------------------: | +| (128000, 8192, 5120, 1) | 43140.20 | 23867.00 | 1.808x | +| (128000, 8192, 5120, 2) | 39487.60 | 23359.00 | 1.690x | +| (128000, 8192, 5120, 4) | 39189.20 | 23945.50 | 1.637x | +| (128000, 8192, 5120, 8) | 37700.70 | 22170.60 | 1.700x | + +**MXFP8 with DeepSeekV3** (with G=-8 to simulate different degrees of expert parallelism) + +| M,N,K,G | bf16_fwd_bwd_us | scaled_fwd_bwd_us | scaled_fwd_bwd_speedup | +| ----------------------- | --------------: | ----------------: | ---------------------: | +| (128000, 2048, 7168, 1) | 13064.80 | 10996.00 | 1.188x | +| (128000, 2048, 7168, 2) | 14900.20 | 11283.40 | 1.321x | +| (128000, 2048, 7168, 4) | 15823.60 | 9919.36 | 1.595x | +| (128000, 2048, 7168, 8) | 14966.80 | 10397.20 | 1.440x | To reproduce this benchmark, on a B200 GPU machine, run the following command: