From a2a135d047e763270e521d2f646135261e56a5e7 Mon Sep 17 00:00:00 2001
From: Vasiliy Kuznetsov
Date: Sat, 18 Oct 2025 15:27:02 -0400
Subject: [PATCH] mx_formats: Update README.md

---
 torchao/prototype/mx_formats/README.md | 53 ++++++++++++++++++--------
 1 file changed, 38 insertions(+), 15 deletions(-)

diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md
index c3869f2761..b5ba91326d 100644
--- a/torchao/prototype/mx_formats/README.md
+++ b/torchao/prototype/mx_formats/README.md
@@ -1,16 +1,42 @@
 # MX training and inference with native PyTorch
 
-This is a workflow for e2e training and inference with MX dtypes from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
-in native PyTorch. We are currently in prototype and are actively working on optimizing these workflows on the NVIDIA B200 hardware.
+End-to-end training and inference in native PyTorch with the mxfp8 and mxfp4 formats from the
+[MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf), as well as NVIDIA's nvfp4 format.
+
+> :warning: This workflow is currently a prototype. Use nightly versions of PyTorch and torchao (or build from source) for best results.
 
 ## Overall status
 
-| workflow | emulation | performance | accuracy |
-| --- | --- | --- | --- |
-| training with mxfp8 | ✅ | ✅ | ✅ |
-| inference with mxfp8, mxfp6, mxfp4 | ✅ | 🔲 | 🔲 |
+Key: ✅ = done, đŸŸĸ = good, 🟡 = needs improvement, 🔴 = significant work remaining; a split entry such as 🟡 / đŸŸĸ means results vary across settings.
+
+### mxfp8
+
+| workflow | emulation | performance | accuracy | API polish |
+| --- | --- | --- | --- | --- |
+| training for `torch.nn.Linear` | ✅ | 🟡 / đŸŸĸ | đŸŸĸ | 🟡 |
+| inference for `torch.nn.Linear` | ✅ | 🟡 / đŸŸĸ | đŸŸĸ | 🟡 |
+
+### nvfp4
+
+| workflow | emulation | performance | accuracy | API polish |
+| --- | --- | --- | --- | --- |
+| training for `torch.nn.Linear` | ✅ | 🔴 | 🟡 | 🟡 |
+| QAT for `torch.nn.Linear` | ✅ | n/a | đŸŸĸ | 🟡 |
+| inference for `torch.nn.Linear` | ✅ | 🟡 / đŸŸĸ | đŸŸĸ | 🟡 |
+
+### mxfp4
 
-â„šī¸ See the [feature tracker](https://github.com/pytorch/ao/issues/556) and the [performance tracker](https://github.com/pytorch/ao/issues/1768) for upcoming features.
+| workflow | emulation | performance | accuracy | API polish |
+| --- | --- | --- | --- | --- |
+| training for `torch.nn.Linear` | ✅ | 🔴 | 🟡 | 🟡 |
+| QAT for `torch.nn.Linear` | planned | n/a | planned | planned |
+| inference for `torch.nn.Linear` | ✅ | 🔴 | đŸŸĸ | 🟡 |
+
+### planned improvements
+
+* mxfp8 support for grouped_gemm and all2all for MoE training (see https://github.com/pytorch/ao/tree/main/torchao/prototype/moe_training)
+* mxfp8, nvfp4, and mxfp4 performance optimizations for inference
+* polish the nvfp4 QAT recipe and enable mxfp4 QAT
+* blocked formats for faster training
+* stochastic rounding and Hadamard transforms for improved fp4 training numerics
 
 ## Training e2e benchmarks on NVIDIA B200
 
@@ -42,6 +68,8 @@ including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=re
 
 ## MX training
 
+Below is a toy training loop. For an example of a real training loop, see our [torchtitan integration](https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/quantization/mx.py).
+
 ```python
 import torch
 from torchao.quantization import quantize_
@@ -150,7 +178,7 @@ x_hp = x_mx.to_dtype(torch.float)
 
 ## mxfp8 gemm
 
 On NVIDIA B200 machines, we use the cuBLAS mxfp8 gemm exposed via the `torch._scaled_mm` op.
-We observe a speedup of **2x to 3x** vs the bf16 baseline on common shapes. To reproduce this
+We observe a speedup of **up to ~2x** vs the bf16 baseline on common shapes. To reproduce this
 on supported hardware, you can run the following command:
 
 ```bash
@@ -160,7 +188,7 @@ on supported hardware, you can run the following command:
 
 ## to_mx cast across dim0 and dim1
 
-On NVIDIA B200 machines, our to_mx kernels for mxfp8 achieve **up to 5.5 TB/s** for the dim0 cast (with torch.compile),
+On NVIDIA B200 machines, our to_mx kernels for mxfp8 achieve **up to 6.3 TB/s** for the dim0 cast (with torch.compile),
 and **up to 3.9 TB/s** for the dim1 cast (with a triton kernel). We are actively working on
 improving the performance of this cast ([details](https://github.com/pytorch/ao/issues/1768)).
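+
+As a reference for what the dim0 and dim1 casts are, here is a minimal sketch using the `MXTensor` casting API shown earlier in this README; the tensor shape and the transpose-based dim1 cast are illustrative assumptions, not the exact benchmark code:
+
+```python
+import torch
+from torchao.prototype.mx_formats.mx_tensor import MXTensor
+
+x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
+
+# dim0 cast: one shared e8m0 scale per block of 32 elements along the last dim
+x_mx_dim0 = MXTensor.to_mx(x, torch.float8_e4m3fn, block_size=32)
+
+# dim1 cast: transpose first, so the 32-element scale blocks run along the other dim
+x_mx_dim1 = MXTensor.to_mx(x.t().contiguous(), torch.float8_e4m3fn, block_size=32)
+```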
@@ -176,16 +204,11 @@ To reproduce this on supported hardware, you can run the following command:
 
 // example output: https://gist.github.com/vkuzo/7ac5fce44c9b90bfb9eae2a07b721cda
 ```
 
-## performance tracker
-
-Please see our [performance tracker](https://github.com/pytorch/ao/issues/1768) for the latest on MX training and inference performance!
-
 # accuracy
 
 ## training
 
-* LLaMa 3 8B pretraining on 4 GPUs for 500 iterations shows that loss convergence is not meaningfully degraded (code not in this repo)
-* we match bitwise to other implementations of the OCP MX spec (code not in this repo), with a couple of edge cases left to resolve
+* LLaMa 3 8B pretraining on 4 GPUs for 500 iterations shows that loss convergence is not meaningfully degraded (via torchtitan)
 
 ## inference