# MX training and inference with native PyTorch

This is a workflow for e2e training and inference with mxfp8, mxfp4, and nvfp4 formats from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) in native PyTorch.

> :warning: We are currently in prototype. Use nightly versions of PyTorch and torchao (or build from source) for best results.

## Overall status

### mxfp8

| workflow | emulation | performance | accuracy | API polish |
| --- | --- | --- | --- | --- |
| training for `torch.nn.Linear` | ✅ | 🟡 / 🟢 | 🟢 | 🟡 |
| inference for `torch.nn.Linear` | ✅ | 🟡 / 🟢 | 🟢 | 🟡 |

### nvfp4

| workflow | emulation | performance | accuracy | API polish |
| --- | --- | --- | --- | --- |
| training for `torch.nn.Linear` | ✅ | 🔴 | 🟡 | 🟡 |
| QAT for `torch.nn.Linear` | ✅ | n/a | 🟢 | 🟡 |
| inference for `torch.nn.Linear` | ✅ | 🟡 / 🟢 | 🟢 | 🟡 |

### mxfp4

| workflow | emulation | performance | accuracy | API polish |
| --- | --- | --- | --- | --- |
| training for `torch.nn.Linear` | ✅ | 🔴 | 🟡 | 🟡 |
| QAT for `torch.nn.Linear` | planned | n/a | planned | planned |
| inference for `torch.nn.Linear` | ✅ | 🔴 | 🟢 | 🟡 |

ℹ️ <em>See the [feature tracker](https://github.com/pytorch/ao/issues/556) and the [performance tracker](https://github.com/pytorch/ao/issues/1768) for upcoming features.</em>

### planned improvements

* mxfp8 support for grouped_gemm and all2all for MoE training (see https://github.com/pytorch/ao/tree/main/torchao/prototype/moe_training)
* mxfp8, nvfp4, and mxfp4 performance optimizations for inference
* polish the nvfp4 QAT recipe, and enable mxfp4 QAT
* blocked formats for faster training
* stochastic rounding and Hadamard transforms for improved fp4 training numerics

## Training e2e benchmarks on NVIDIA B200


## MX training

Below is a toy training loop. For a real training loop example, see our torchtitan integration: https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/quantization/mx.py.

```python
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice

# NOTE: the config below is a sketch; see torchao's mx_formats configs for
# the current set of options. CUBLAS requires NVIDIA B200 hardware; EMULATED
# runs the same numerics on other GPUs without the fast gemm.
gemm_kernel_choice = MXGemmKernelChoice.CUBLAS

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
config = MXLinearConfig(
    elem_dtype=torch.float8_e4m3fn,
    block_size=32,
    gemm_kernel_choice=gemm_kernel_choice,
)
quantize_(m, config)

# toy training loop
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
for _ in range(10):
    x = torch.randn(32, 32, device="cuda")
    loss = m(x).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```
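For best performance, the quantized model is typically wrapped in `torch.compile` so the mxfp8 cast kernels fuse with surrounding ops (an assumption consistent with the compiled dim0 cast numbers quoted later in this README):

```python
m = torch.compile(m)  # compile the quantized model from the toy example above
```
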
## mxfp8 gemm

On NVIDIA B200 machines, we use the cuBLAS mxfp8 gemm exposed via the `torch._scaled_mm` op.
We observe a speedup of **up to ~2x** vs the bf16 baseline on common shapes. To reproduce this
on supported hardware, you can run the following command:

```bash
# NOTE: illustrative invocation; the script path and recipe flag are assumptions
python benchmarks/float8/bench_matmul.py --recipe mxfp8_cublas
```
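As an end-to-end smoke test of this gemm path, here is a hedged sketch that routes a single linear through the mxfp8 kernels via `quantize_`; it reuses the `MXLinearConfig` and `MXGemmKernelChoice` names from the training example above, and the shapes are illustrative:

```python
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice

# route a single bf16 linear through the cuBLAS mxfp8 gemm (requires B200)
m = torch.nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
config = MXLinearConfig(
    elem_dtype=torch.float8_e4m3fn,
    block_size=32,
    gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
)
quantize_(m, config)

x = torch.randn(2048, 4096, device="cuda", dtype=torch.bfloat16)
y = m(x)  # the matmul dispatches to torch._scaled_mm with mxfp8 inputs
```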

## to_mx cast across dim0 and dim1

On NVIDIA B200 machines, our to_mx kernels for mxfp8 achieve **up to 6.3 TB/s** for the dim0 cast (with torch.compile),
and **up to 3.9 TB/s** for the dim1 cast (with a triton kernel). We are actively working on improving
the performance of this cast ([details](https://github.com/pytorch/ao/issues/1768)).
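For reference, a minimal sketch of the cast being measured, using the `MXTensor` subclass (the `to_mx` / `to_dtype` calls follow this README's `MXTensor` example; the exact signature is an assumption):

```python
import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

# dim0 cast: one e8m0 scale per 32-element block along the last dim
x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, block_size=32)

# dim1 cast: training also needs scales along the other dim, commonly
# realized by casting the transpose
xt_mx = MXTensor.to_mx(x.t().contiguous(), torch.float8_e4m3fn, block_size=32)

x_hp = x_mx.to_dtype(torch.float)  # dequantize back to high precision
```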

To reproduce this on supported hardware, you can run the following command:

```bash
# NOTE: illustrative invocation; the benchmark script path is an assumption
python benchmarks/mx_formats/cast_bench.py
// example output: https://gist.github.com/vkuzo/7ac5fce44c9b90bfb9eae2a07b721cda
```

# accuracy

## training

* LLaMa 3 8B pretraining on 4 GPUs for 500 iterations shows that loss convergence is not meaningfully degraded (via torchtitan)

## inference
