From 7d43a47c8794830a8cc4f6bbf6d9e5176a2cf2d4 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Mon, 13 Oct 2025 09:18:48 -0700
Subject: [PATCH 1/3] [moe training] update readme

---
 torchao/prototype/moe_training/README.md | 164 +++++++++++++++++++----
 1 file changed, 141 insertions(+), 23 deletions(-)

diff --git a/torchao/prototype/moe_training/README.md b/torchao/prototype/moe_training/README.md
index e53278840e..03ea86dbcf 100644
--- a/torchao/prototype/moe_training/README.md
+++ b/torchao/prototype/moe_training/README.md
@@ -1,40 +1,92 @@
-# Float8 MoE Training
+# Low precision MoE Training
 
-This prototype feature provides a way to use float8 rowwise training on MoE layers.
+This prototype provides:
 
-Below is a simple runnable example of how to use this feature, using the MoE layer
-from the [torchtitan](https://github.com/pytorch/torchtitan) Llama4 implementation for demonstration.
+1. A quantized building block for low precision MoE training: `_scaled_grouped_mm`. It is a differentiable drop-in replacement for `torch._grouped_mm` that dynamically quantizes inputs using the given recipe, performs a scaled grouped GEMM, then returns the results in the original precision. See the runnable [example](#torchao_scaled_grouped_mm-example-forward--backward-pass) of a forward and backward pass below.
+    - Using MXFP8 on a B200 GPU, this provides:
+      - ~1.4x - 1.8x speedups over bfloat16 `torch._grouped_mm` for Llama4 17b 16e shapes (depending on the `M` dimension, i.e. batch_size * seq_len)
+      - ~1.15 - 1.3x speedups over bfloat16 `torch._grouped_mm` for DeepSeekV3 671b shapes (depending on the `M` dimension, i.e. batch_size * seq_len)
+
+2. [TorchTitan](https://github.com/pytorch/torchtitan/tree/main) integration: to get started with e2e pretraining of DeepSeekV3/Llama4 with torchtitan, simply add the flag to your training command: `--model.converters="quantize.grouped_mm.mx" [--quantize.grouped_mm.mx.fqns="experts"]`
+
+3. Model conversion via the torchao `quantize_(...)` API: this swaps all `torch._grouped_mm` ops in your model definition to use torchao `_scaled_grouped_mm` under the hood (see [example](#model-conversion-api-example-end-to-end-training) below).
+
+
+## Table of Contents
+
+- [Examples](#examples)
+- [Performance Benchmarks](#performance-benchmarks-mxfp8)
+- [System Requirements](#system-requirements)
+- [Implementation Details for Developers](#implementation-details-for-developers)
+- [Limitations](#limitations)
+
+## Examples
+#### torchao_scaled_grouped_mm example: forward + backward pass
+```python
+import torch
+from torch.nn import functional as F
+from torchao.prototype.moe_training import (
+    _scaled_grouped_mm as torchao_scaled_grouped_mm
+)
+from torchao.prototype.moe_training.conversion_utils import MoEScalingType
+from torchao.prototype.moe_training.utils import generate_jagged_offs
+
+num_groups, total_M, N, K = 8, 131072, 8192, 5120
+
+# A = input activations, B = expert weights
+A = torch.randn(total_M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
+B = torch.randn(num_groups, N, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
+
+# Token group offsets computed by router in actual MoE layer
+offs = generate_jagged_offs(num_groups, total_M, device="cuda")
+
+# Forward and backward example
+out = torchao_scaled_grouped_mm(
+    A,
+    B.transpose(-2, -1),
+    offs=offs,
+    scaling_type=MoEScalingType.MXFP8,
+)
+
+# (Fake labels for demonstration purposes)
+labels = torch.ones_like(out)
+loss = F.mse_loss(out, labels)
+loss.backward()
+```
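+
+For reference, the bfloat16 baseline that `torchao_scaled_grouped_mm` drops in for is the call below, reusing `A`, `B`, and `offs` from the example above. This is a sketch: it assumes your PyTorch build exposes the private `torch._grouped_mm` op with an `offs` keyword (which this prototype requires anyway). Only the op name and the `scaling_type` argument differ.
+
+```python
+# bfloat16 baseline: same operands, no dynamic quantization.
+# Assumes the private torch._grouped_mm op is available in your PyTorch build.
+out_bf16 = torch._grouped_mm(
+    A,
+    B.transpose(-2, -1),
+    offs=offs,
+)
+```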
+
+#### Model conversion API example: end-to-end training
 ```python
 import torch
 from torch import nn
 from torch.nn import functional as F
 
-# this feature requires CUDA and SM89+
-assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+# this feature requires CUDA 12.8+ and SM100+
+assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0)
 
 from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
 from torchao.quantization.quant_api import quantize_
 
 # this example uses torchtitan llama4 MoE, see
+# this example requires torchtitan
 try:
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.distributed.expert_parallel import (
+        set_token_group_alignment_size_m,
+    )
+    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError as e:
     raise ImportError(
         "torchtitan not installed, see installation instructions at https://github.com/pytorch/torchtitan"
     ) from e
 
 # initialize model
 device = torch.device("cuda")
-model_args = TransformerModelArgs(
-    moe_enabled=True,
+moe_args = MoEArgs(
     num_experts=8,
-    dim=256,
 )
-model = MoE(model_args).to(torch.bfloat16).to(device)
+dim, hidden_dim = 5120, 8192
+model = MoE(moe_args, dim, hidden_dim).to(torch.bfloat16).to(device)
 init_std = 0.02
 model.init_weights(init_std, device)
 
@@ -48,6 +100,9 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return True
     return False
 
+# Token group alignment size must be 32 for MXFP8 training (16 is sufficient for fp8 rowwise)
+alignment_size = 32
+set_token_group_alignment_size_m(alignment_size)
 
 # quantize the model
 config = MoETrainingConfig()
@@ -55,8 +110,8 @@ quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
 # training loop
 optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
+batch, seq = 2, 2048
 for step in range(10):
-    batch, seq, dim = 8, 2048, 256
     x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
@@ -75,11 +130,75 @@ for step in range(10):
 ```
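+
+After `quantize_(...)` returns, you can optionally spot-check which parameters were converted. The snippet below is a minimal sketch: it assumes the torchtitan MoE registers its expert weights under an `experts` FQN and that torchao swaps their data tensors for its `ScaledGroupedMMTensor` subclass (see the implementation details section below); the exact types/reprs printed may vary by version.
+
+```python
+# Optional sanity check (illustrative only): inspect which parameters were converted.
+for name, param in model.named_parameters():
+    if "experts" in name:
+        print(name, type(param), type(param.data))
+```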
 
-## Requirements
-- torchao nightly build
-- CUDA compute capability 8.9+ (SM89+)
+## System requirements
+- torchao 0.14+
+- For MXFP8 MoE training, CUDA 12.8+ and SM100+ GPU arch are required.
+- For FP8 rowwise MoE training, CUDA 12.4+ and SM89+ GPU arch are required.
+
+## Performance benchmarks: MXFP8
+#### Single device, torchao _scaled_grouped_mm forward + backward pass vs torch._grouped_mm
+
+To reproduce this benchmark, on a B200 GPU machine, run the following command:
+- `python benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py --compile`
+- torchao: `0.14.0+gitc7b8e13da`
+- torch: `2.10.0a0+gitf6de195`
+
+#### Single device, Llama4 16e MoE layer forward + backward pass vs bfloat16 baseline
+
+Llama4 16e shapes:
+```
+CUDA_VISIBLE_DEVICES=6 python benchmarks/prototype/moe_training/bench_moe_layer.py --recipe mxfp8 --local_batch_size=16 --dim=5120 --hidden_dim=8192 --local_num_experts=8
+total_M: 131072, N: 8192, K: 5120
+bf16 time: 275.270 ms
+mxfp8 time: 192.420 ms
+speedup: 1.431x
+```
+
+DeepSeekV3 shapes:
+```
+CUDA_VISIBLE_DEVICES=6 python benchmarks/prototype/moe_training/bench_moe_layer.py --recipe mxfp8 --local_batch_size=16 --dim=7168 --hidden_dim=2048 --local_num_experts=8
+total_M: 131072, N: 2048, K: 7168
+bf16 time: 92.032 ms
+mxfp8 time: 80.182 ms
+speedup: 1.148x
+```
+
+#### End-to-end training with TorchTitan of Llama4 16e MoE layer vs bfloat16 baseline
+- Single node benchmarks with 4xB200
+- Llama4 16e default configs; FSDP=4, EP=4; AC=none; compile=True; seq_len=8192; local_bs=8
+- Reduced num layers from 48 -> 2 to avoid OOM in single node setting
+
+- Debug model config:
+
+```python
+llama4_configs = {
+    "debugmodel": TransformerModelArgs(
+        dim=5120,
+        n_layers=2,
+        n_heads=40,
+        n_kv_heads=8,
+        ffn_dim_multiplier=1.2,
+        multiple_of=2048,
+        rope_theta=500000,
+        max_seq_len=10485760,
+        moe_args=MoEArgs(num_experts=16),
+        interleave_moe_layer_step=1,
+    ),
+```
+- [Full repro commands](https://www.internalfb.com/phabricator/paste/view/P1974482524)
+
+
+| Configuration | Throughput (Median Tokens/s) | Max Memory (GiB) |
+|:---------------------------------------------------------------------------|-----------------------------:|-----------------:|
+| bf16 baseline | 49381.0 | 145.55 |
+| MXFP8 for Linears only | 52038.0 | 146.62 |
+| MXFP8 for Grouped GEMMs only | 69350.0 | 144.71 |
+| MXFP8 for Linears + Grouped GEMMs | 70747.0 | 145.32 |
+| MXFP8 for Linears + Grouped GEMMs + A2A Dispatch | 72602.5 | 145.45 |
+| MXFP8 for Linears + Grouped GEMMs + A2A Dispatch + A2A Combine | 73152.0 | 146.08 |
+
-## Modeling requirements
+## Implementation details for developers
 This prototype is specifically designed to be used on MoE models using
 `torch._grouped_mm` to implement expert computation in token-choice routing,
 where expert weights are implemented as 3D nn.Parameters with `num_experts` as
@@ -97,5 +216,4 @@ operands in both the forward and backward pass.
 For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor.
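+
+For intuition, the dispatch pattern described above can be pictured with a minimal `__torch_function__` sketch. This is illustrative only and is not torchao's actual `ScaledGroupedMMTensor` implementation (which also handles scaling recipes, autograd, and distributed tensors); the class name below is made up for the sketch.
+
+```python
+import torch
+
+class _GroupedMMInterceptTensor(torch.Tensor):
+    # Illustrative sketch only -- not the real ScaledGroupedMMTensor.
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+        if func is torch._grouped_mm:
+            # The real subclass dispatches to torchao's _scaled_grouped_mm here, which
+            # dynamically quantizes the operands and runs a scaled grouped GEMM.
+            pass
+        # For every other op (and, in this sketch, for _grouped_mm too), fall back
+        # to regular torch.Tensor behavior.
+        return super().__torch_function__(func, types, args, kwargs)
+
+# Wrapping a 3D expert weight in the subclass is enough for the override to fire
+# whenever the weight is an operand of torch._grouped_mm.
+expert_weight = torch.randn(8, 16, 32, dtype=torch.bfloat16).as_subclass(_GroupedMMInterceptTensor)
+print(type(expert_weight))
+```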
 
 ## Limitations
-- Only tested with eager mode, single GPU training so far.
-- Composability with parallelisms and `torch.compile` are next steps.
+- The new CUDA kernel for MXFP8 quantization of the non-transposed expert weights in the backwards pass does not support TP yet.

From 5a34b0b6cb07adc366627ef1c50668f7f83fd7b3 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Mon, 13 Oct 2025 09:26:58 -0700
Subject: [PATCH 2/3] add repro commands for titan benchmarks

---
 torchao/prototype/moe_training/README.md | 26 ++++++++++++++++++++----
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/torchao/prototype/moe_training/README.md b/torchao/prototype/moe_training/README.md
index 03ea86dbcf..85ae1dc924 100644
--- a/torchao/prototype/moe_training/README.md
+++ b/torchao/prototype/moe_training/README.md
@@ -185,8 +185,6 @@ llama4_configs = {
         interleave_moe_layer_step=1,
     ),
 ```
-- [Full repro commands](https://www.internalfb.com/phabricator/paste/view/P1974482524)
-
 
 | Configuration | Throughput (Median Tokens/s) | Max Memory (GiB) |
 |:---------------------------------------------------------------------------|-----------------------------:|-----------------:|
@@ -194,8 +192,28 @@ llama4_configs = {
 | MXFP8 for Linears only | 52038.0 | 146.62 |
 | MXFP8 for Grouped GEMMs only | 69350.0 | 144.71 |
 | MXFP8 for Linears + Grouped GEMMs | 70747.0 | 145.32 |
-| MXFP8 for Linears + Grouped GEMMs + A2A Dispatch | 72602.5 | 145.45 |
-| MXFP8 for Linears + Grouped GEMMs + A2A Dispatch + A2A Combine | 73152.0 | 146.08 |
+
+#### Commands to reproduce these benchmarks:
+
+bfloat16 baseline:
+```
+rm -rf /tmp/torchinductor_${USER}; CUDA_VISIBLE_DEVICES="4,5,6,7" TORCHTITAN_ROOT=/home/${USER}/torchtitan NGPU=4 EXTRA_ARGS="--metrics.log_freq=10 --training.steps=200 --parallelism.data_parallel_shard_degree=4 --parallelism.expert_parallel_degree=4 --parallelism.tensor_parallel_degree=1 --compile.enable --training.seq_len=8192 --activation_checkpoint.mode=none --model.print_after_conversion" ./llama4.sh
+```
+
+MXFP8 dense only:
+```
+rm -rf /tmp/torchinductor_${USER}; CUDA_VISIBLE_DEVICES="4,5,6,7" TORCHTITAN_ROOT=/home/${USER}/torchtitan NGPU=4 EXTRA_ARGS="--metrics.log_freq=10 --training.steps=200 --parallelism.data_parallel_shard_degree=4 --parallelism.expert_parallel_degree=4 --parallelism.tensor_parallel_degree=1 --compile.enable --training.seq_len=8192 --activation_checkpoint.mode=none --model.print_after_conversion --model.converters="quantize.linear.mx"" ./llama4.sh
+```
+
+MXFP8 MoE only:
+```
+rm -rf /tmp/torchinductor_${USER}; CUDA_VISIBLE_DEVICES="4,5,6,7" TORCHTITAN_ROOT=/home/${USER}/torchtitan NGPU=4 EXTRA_ARGS="--metrics.log_freq=10 --training.steps=200 --parallelism.data_parallel_shard_degree=4 --parallelism.expert_parallel_degree=4 --parallelism.tensor_parallel_degree=1 --compile.enable --training.seq_len=8192 --activation_checkpoint.mode=none --model.print_after_conversion --model.converters="quantize.grouped_mm.mx"" ./llama4.sh
+```
+
+MXFP8 MoE + Dense:
+```
+rm -rf /tmp/torchinductor_${USER}; CUDA_VISIBLE_DEVICES="4,5,6,7" TORCHTITAN_ROOT=/home/${USER}/torchtitan NGPU=4 EXTRA_ARGS="--metrics.log_freq=10 --training.steps=50 --parallelism.data_parallel_shard_degree=4 --parallelism.expert_parallel_degree=4 --parallelism.tensor_parallel_degree=1 --compile.enable --training.seq_len=8192 --activation_checkpoint.mode=none --model.print_after_conversion --model.converters="quantize.grouped_mm.mx,quantize.linear.mx"" ./llama4.sh
+```
 
 ## Implementation details for developers

From 14082b60b86d4437a08dfc034cffc7b5f1035aa6 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Mon, 13 Oct 2025 17:48:50 -0700
Subject: [PATCH 3/3] clean up readme

---
 torchao/prototype/moe_training/README.md | 92 ++++++++++++------------
 1 file changed, 48 insertions(+), 44 deletions(-)

diff --git a/torchao/prototype/moe_training/README.md b/torchao/prototype/moe_training/README.md
index 85ae1dc924..99e64e259a 100644
--- a/torchao/prototype/moe_training/README.md
+++ b/torchao/prototype/moe_training/README.md
@@ -8,9 +8,9 @@ This prototype provides:
       - ~1.15 - 1.3x speedups over bfloat16 `torch._grouped_mm` for DeepSeekV3 671b shapes (depending on the `M` dimension, i.e. batch_size * seq_len)
 
-2. [TorchTitan](https://github.com/pytorch/torchtitan/tree/main) integration: to get started with e2e pretraining of DeepSeekV3/Llama4 with torchtitan, simply add the flag to your training command: `--model.converters="quantize.grouped_mm.mx" [--quantize.grouped_mm.mx.fqns="experts"]`
+2. [TorchTitan](https://github.com/pytorch/torchtitan/tree/main) integration of torchao's dynamically quantized `_scaled_grouped_mm`: pretrain DeepSeekV3/Llama4 with MXFP8 grouped GEMMs by adding the flag to your training command: `--model.converters="quantize.grouped_mm.mx" [--quantize.grouped_mm.mx.fqns="experts"]`
 
-3. Model conversion via the torchao `quantize_(...)` API: this swaps all `torch._grouped_mm` ops in your model definition to use torchao `_scaled_grouped_mm` under the hood (see [example](#model-conversion-api-example-end-to-end-training) below).
+3. `quantize_(...)` API support for model conversion: this swaps all `torch._grouped_mm` ops in your model definition to use torchao `_scaled_grouped_mm` under the hood (see [example](#model-conversion-api-example-end-to-end-training) below).
 
 
 ## Table of Contents
@@ -136,62 +136,66 @@ for step in range(10):
 - For FP8 rowwise MoE training, CUDA 12.4+ and SM89+ GPU arch are required.
 
 ## Performance benchmarks: MXFP8
-#### Single device, torchao _scaled_grouped_mm forward + backward pass vs torch._grouped_mm
-
-To reproduce this benchmark, on a B200 GPU machine, run the following command:
-- `python benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py --compile`
-- torchao: `0.14.0+gitc7b8e13da`
-- torch: `2.10.0a0+gitf6de195`
-
-#### Single device, Llama4 16e MoE layer forward + backward pass vs bfloat16 baseline
+
+### Single MoE layer forward + backward pass vs bfloat16 baseline
 
-Llama4 16e shapes:
-```
+| Model | total_M | N | K | bf16 time (ms) | mxfp8 time (ms) | speedup |
+|--------------|---------|------|------|---------------|-----------------|---------|
+| Llama4 16e | 131072 | 8192 | 5120 | 275.270 | 192.420 | 1.431x |
+| DeepSeekV3 | 131072 | 2048 | 7168 | 92.032 | 80.182 | 1.148x |
+
+To reproduce these benchmarks, on a B200 GPU machine, run the following commands:
+
+Llama4 17b 16e shapes:
+```bash
 CUDA_VISIBLE_DEVICES=6 python benchmarks/prototype/moe_training/bench_moe_layer.py --recipe mxfp8 --local_batch_size=16 --dim=5120 --hidden_dim=8192 --local_num_experts=8
-total_M: 131072, N: 8192, K: 5120
-bf16 time: 275.270 ms
-mxfp8 time: 192.420 ms
-speedup: 1.431x
 ```
 
-DeepSeekV3 shapes:
-```
+DeepSeekV3 671b shapes:
+```bash
 CUDA_VISIBLE_DEVICES=6 python benchmarks/prototype/moe_training/bench_moe_layer.py --recipe mxfp8 --local_batch_size=16 --dim=7168 --hidden_dim=2048 --local_num_experts=8
-total_M: 131072, N: 2048, K: 7168
-bf16 time: 92.032 ms
-mxfp8 time: 80.182 ms
-speedup: 1.148x
 ```
 
-#### End-to-end training with TorchTitan of Llama4 16e MoE layer vs bfloat16 baseline
+### Individual bfloat16 torch._grouped_mm op vs torchao_scaled_grouped_mm
+
+MXFP8:
+
+| M,N,K,G | bf16_fwd_bwd_us | scaled_fwd_bwd_us | scaled_fwd_bwd_speedup |
+|------------------------|-----------------|-------------------|------------------------|
+| (128000, 8192, 5120, 1) | 40463 | 24406 | 1.658x |
+| (128000, 8192, 5120, 2) | 35494.5 | 24705.1 | 1.437x |
+| (128000, 8192, 5120, 4) | 38879.3 | 24508.5 | 1.586x |
+| (128000, 8192, 5120, 8) | 35714.6 | 25937.6 | 1.377x |
+| (128000, 1536, 5120, 1) | 6353.06 | 7401.54 | 0.858x |
+| (128000, 1536, 5120, 2) | 6511.65 | 6729.33 | 0.968x |
+| (128000, 1536, 5120, 4) | 6455.2 | 6626.5 | 0.974x |
+| (128000, 1536, 5120, 8) | 7716.13 | 6516.74 | 1.184x |
+| (128000, 2048, 7168, 1) | 11758 | 11255.7 | 1.045x |
+| (128000, 2048, 7168, 2) | 15012.9 | 9917.9 | 1.514x |
+| (128000, 2048, 7168, 4) | 14904.2 | 10493.8 | 1.42x |
+| (128000, 2048, 7168, 8) | 13178 | 9638.38 | 1.367x |
+
+
+To reproduce this benchmark, on a B200 GPU machine, run the following command:
+- `python benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py --compile`
+- torchao: `0.14.0+gitc7b8e13da`
+- torch: `2.10.0a0+gitf6de195`
+
+
+### End-to-end training: Llama4 16e MoE layer vs bfloat16 baseline with TorchTitan
 - Single node benchmarks with 4xB200
 - Llama4 16e default configs; FSDP=4, EP=4; AC=none; compile=True; seq_len=8192; local_bs=8
 - Reduced num layers from 48 -> 2 to avoid OOM in single node setting
+- TorchTitan debug model config: same as Llama4 17bx16e, but with 2 layers
-
-- Debug model config:
-
-```python
-llama4_configs = {
-    "debugmodel": TransformerModelArgs(
-        dim=5120,
-        n_layers=2,
-        n_heads=40,
-        n_kv_heads=8,
-        ffn_dim_multiplier=1.2,
-        multiple_of=2048,
-        rope_theta=500000,
-        max_seq_len=10485760,
-        moe_args=MoEArgs(num_experts=16),
-        interleave_moe_layer_step=1,
-    ),
-```
-
-| Configuration | Throughput (Median Tokens/s) | Max Memory (GiB) |
-|:---------------------------------------------------------------------------|-----------------------------:|-----------------:|
-| bf16 baseline | 49381.0 | 145.55 |
-| MXFP8 for Linears only | 52038.0 | 146.62 |
-| MXFP8 for Grouped GEMMs only | 69350.0 | 144.71 |
-| MXFP8 for Linears + Grouped GEMMs | 70747.0 | 145.32 |
+
+| Configuration | Throughput (Median Tokens/s) | Max Memory (GiB) | Speedup over bf16 |
+|:---------------------------------------------------------------------------|-----------------------------:|-----------------:|------------------:|
+| bf16 baseline | 49381.0 | 145.55 | - |
+| MXFP8 for Linears only | 52038.0 | 146.62 | 1.053x |
+| MXFP8 for Grouped GEMMs only | 69350.0 | 144.71 | 1.404x |
+| MXFP8 for Linears + Grouped GEMMs | 70747.0 | 145.32 | 1.433x |
 
 #### Commands to reproduce these benchmarks: