We want to support DeepSeekV3-style FP8 blockwise training in torchao for both dense and MoE models.
## Support for dense models (linears)
We can extend the FP8 blockwise training prototype for dense models here, which has the core functionality complete but unoptimized performance.
The work has been broken down into the following tasks, which anyone is free to work on:
- Functionality
- 1x128 quantization for LHS activations, write to row major layout (see the quantization sketch after this task list)
- 128x1 quantization for RHS activations, write to col major layout
- 128x128 quantization for weights, write to col major layout
- 1x128 @ 128x128 gemm, use for:
  - `output = input @ weight.t()`
  - `dgrad = grad_output @ weight`
- 1x128 @ 128x1 gemm, use for:
  - `wgrad = grad_output.t() @ input`
- Autograd function implementing forward and backward
- DTensor handling for TP support
- Custom ops around all custom kernels for `torch.compile` composability
- Tests for FSDP, TP
- `quantize_model` conversion API performing module swap of `nn.Linear` to `FP8BlockwiseLinear` (wraps the autograd func)
- [P1] FP8 blockwise all-gather for FSDP (would need to ensure weight shards are padded to be divisible by 128x128 blocks, design TBD)
- Performance
- all quantization kernels run at 80%+ of peak achievable memory bandwidth on Hopper
- benchmark scripts for each quantization kernel
- all gemm kernels run at 80%+ of peak achievable TFLOPs/sec on Hopper
- benchmark scripts for each gemm
- Integration into torchtitan
- Validate loss convergence is virtually identical to bf16 for 3k+ steps on full-size Llama3 8B/70B
- Validate e2e throughput (TPS) improvement in the same training run as above
- Documentation
- README
- torchao docsite
- Migrate out of the prototype directory, integrate into the `torchao.float8` module
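For reference, here is a minimal sketch of the blockwise quantization geometry described in the tasks above, assuming `torch.float8_e4m3fn` as the FP8 dtype. The helper names (`quant_1x128`, `quant_128x128`) are hypothetical, and the code only illustrates how the per-block scales are computed; it ignores the row/col major layout requirements, the Triton kernels, and the DTensor handling the tasks call for.

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max  # 448.0 for e4m3fn
EPS = 1e-12


def quant_1x128(x: torch.Tensor):
    """1x128 blockwise quantization (e.g. for LHS activations).

    x: (M, K) with K divisible by 128. Returns fp8 data of shape (M, K)
    and float32 dequant scales of shape (M, K // 128).
    """
    M, K = x.shape
    xb = x.reshape(M, K // 128, 128)
    amax = xb.abs().amax(dim=-1, keepdim=True).clamp(min=EPS)
    scale = amax.float() / FP8_MAX  # one scale per 1x128 block
    x_fp8 = (xb / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return x_fp8.reshape(M, K), scale.reshape(M, K // 128)


def quant_128x128(w: torch.Tensor):
    """128x128 blockwise quantization (e.g. for weights).

    w: (N, K) with N and K divisible by 128. Returns fp8 data of shape
    (N, K) and float32 dequant scales of shape (N // 128, K // 128).
    """
    N, K = w.shape
    # Split into a (N/128, K/128) grid of 128x128 tiles.
    wb = w.reshape(N // 128, 128, K // 128, 128).permute(0, 2, 1, 3)
    amax = wb.abs().amax(dim=(-1, -2), keepdim=True).clamp(min=EPS)
    scale = amax.float() / FP8_MAX  # one scale per 128x128 tile
    w_fp8 = (wb / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return w_fp8.permute(0, 2, 1, 3).reshape(N, K), scale.reshape(N // 128, K // 128)
```

Per the gemm mapping above, the forward and dgrad gemms would consume the 1x128-quantized activations against the 128x128-quantized weight, while wgrad pairs 1x128-quantized `grad_output.t()` with 128x1-quantized `input` (both blocked along the contracting token dim).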
## Support for MoE layers (grouped GEMMs)
We can extend the low precision MoE training code here to support FP8 blockwise training by doing the following:
- Functionality
- Quantization
- 128x128 quantization compatible with 3d expert weights, write to per-expert col major layout (e.g. shape (E,N,K) with strides (N*K,1,N)); see the per-expert layout sketch after this task list
- Per-token group 1x128 scale conversion where group boundaries are along K/contracting dim
- Alternative: pad token group alignment to multiples of 128, as is done in torchtitan here
- Per-token group 128x1 scale conversion where group boundaries are along K/contracting dim
- Alternative: pad token group alignment to multiples of 128, as is done in torchtitan here
- GEMMs
- 1x128 @ 128x128 scaled grouped gemm, use for:
  - `output = input @ weight.transpose(-2,-1)`
  - `dgrad = grad_output @ weight`
- 1x128 @ 128x1 scaled grouped gemm, use for:
  - `wgrad = grad_output.transpose(-2,-1) @ input`
- Autograd function implementing forward and backward with dynamic quant on inputs (see mxfp8 example)
- DTensor handling for TP support
- Custom ops around all custom kernels for `torch.compile` composability
- Tests for FSDP, TP
- `quantize_model` conversion API performing module swap of `nn.Linear` to `FP8BlockwiseLinear` (wraps the autograd func)
- Performance
- all quantization kernels run at 80%+ of peak achievable memory bandwidth on Hopper
- benchmark scripts for each quantization kernel
- all gemm kernels run at 80%+ of peak achievable TFLOPs/sec on Hopper
- benchmark scripts for each gemm
- Integration into torchtitan
- Validate loss convergence is virtually identical to bf16 for 3k+ steps on full-size DeepSeekV3 671B
- Validate e2e throughput (TPS) improvement in the same training run as above
- Documentation
- README
- torchao docsite
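Along the same lines, here is a minimal sketch of the per-expert 128x128 weight quantization with the col major output layout noted above, again assuming `torch.float8_e4m3fn`. The function name is hypothetical; a real implementation would be a fused kernel and would also cover the per-token-group activation scaling.

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max
EPS = 1e-12


def quant_expert_weights_128x128(w: torch.Tensor):
    """Per-expert 128x128 blockwise quantization for 3D expert weights.

    w: (E, N, K) with N and K divisible by 128. Returns fp8 weights in the
    per-expert col major layout (shape (E, N, K), strides (N*K, 1, N)) and
    float32 dequant scales of shape (E, N // 128, K // 128).
    """
    E, N, K = w.shape
    # Split each expert's (N, K) weight into a grid of 128x128 tiles.
    wb = w.reshape(E, N // 128, 128, K // 128, 128).permute(0, 1, 3, 2, 4)
    amax = wb.abs().amax(dim=(-1, -2), keepdim=True).clamp(min=EPS)
    scale = amax.float() / FP8_MAX  # one scale per expert per 128x128 tile
    w_fp8 = (wb / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    w_fp8 = w_fp8.permute(0, 1, 3, 2, 4).reshape(E, N, K)
    # Materialize the per-expert col major layout: transpose the last two
    # dims, copy to contiguous memory, then transpose back so the logical
    # shape stays (E, N, K) while memory is column major per expert.
    w_fp8 = w_fp8.transpose(-2, -1).contiguous().transpose(-2, -1)
    assert w_fp8.stride() == (N * K, 1, N)
    return w_fp8, scale.reshape(E, N // 128, K // 128)
```

The double transpose with `contiguous()` in between is just the simplest way to get the `(N*K, 1, N)` strides in eager mode; an optimized kernel would write that layout directly.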