README.md: 3 changes (2 additions, 1 deletion)
@@ -6,7 +6,7 @@

### PyTorch-Native Training-to-Serving Model Optimization
- Pre-train Llama-3.1-70B **1.5x faster** with float8 training
- Recover **77% of quantized perplexity degradation** on Llama-3.2-3B with QAT
- Recover **67% of quantized accuracy degradation** on Gemma3-4B with QAT
- Quantize Llama-3-8B to int4 for **1.89x faster** inference with **58% less memory**

<div align="center">
@@ -118,6 +118,7 @@ Please see the [torchao compatibility table](https://github.com/pytorch/ao/issues/

TorchAO is integrated into some of the leading open-source libraries including:

* Unsloth for QAT, blog post coming soon!
* HuggingFace transformers with a [builtin inference backend](https://huggingface.co/docs/transformers/main/quantization/torchao) and [low bit optimizers](https://github.com/huggingface/transformers/pull/31865)
* HuggingFace diffusers best practices with `torch.compile` and TorchAO in a standalone repo [diffusers-torchao](https://github.com/huggingface/diffusers/blob/main/docs/source/en/quantization/torchao.md)
* HuggingFace PEFT for LoRA using TorchAO as their [quantization backend](https://huggingface.co/docs/peft/en/developer_guides/quantization#torchao-pytorch-architecture-optimization)
torchao/quantization/qat/README.md: 153 changes (115 additions, 38 deletions)
@@ -142,7 +142,8 @@ quantize_(m, qat_config, filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding
```
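For context, the hunk above sits in the README's `quantize_`-based QAT flow. A minimal sketch of that flow is shown below, assuming recent torchao with the `QATConfig` API; the base config, group size, and toy model are illustrative choices, not the README's exact setup:

```python
import torch
from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
from torchao.quantization.qat import QATConfig  # available in recent torchao releases

# Toy stand-in for a real model; the config and group size are illustrative.
m = torch.nn.Sequential(torch.nn.Linear(64, 64))
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)

# Prepare: insert fake quantization so training sees quantization error.
quantize_(m, QATConfig(base_config, step="prepare"))
# The hunk above additionally applies the config to embedding layers:
#   quantize_(m, qat_config, filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding))

# ... fine-tune `m` here with fake quantization active ...

# Convert: swap the fake-quantized modules for actually quantized ones.
quantize_(m, QATConfig(base_config, step="convert"))
```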


<details>
<summary><h3>Quantizer API (legacy)</h3></summary>

Alternatively, torchao provides a few hardcoded quantization settings through
the following Quantizers, but these may be removed soon:
@@ -191,8 +192,51 @@ model = qat_quantizer.prepare(model)
train_loop(model)
model = qat_quantizer.convert(model)
```
</details>

## Axolotl integration

[Axolotl](https://github.com/axolotl-ai-cloud) uses TorchAO to support quantization-aware fine-tuning. You can use the following commands to fine-tune and then quantize a Llama-3.2-3B model:

```bash
axolotl train examples/llama-3/3b-qat-fsdp2.yaml
# once training is complete, perform the quantization step
axolotl quantize examples/llama-3/3b-qat-fsdp2.yaml
# you should now have a quantized model saved in ./outputs/qat_out/quantized
```

Please see the [QAT documentation](https://docs.axolotl.ai/docs/qat.html) in Axolotl for more details.


## Unsloth integration

[Unsloth](https://github.com/unslothai/unsloth) also leverages TorchAO for quantization-aware fine-tuning. Unsloth's QAT support can be used with both full and LoRA fine-tuning. For example:

```python
import torch
from unsloth import FastLanguageModel

# Load the base model in bf16 (not pre-quantized), since QAT fake-quantizes
# the weights during fine-tuning.
model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/Qwen3-4B-Instruct-2507",
    max_seq_length = 2048,
    dtype = torch.bfloat16,
    load_in_4bit = False,
    full_finetuning = False,
)

# Attach LoRA adapters with int4 QAT enabled via `qat_scheme`.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,
    qat_scheme = "int4",
)
```

For a full notebook example, see: https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(14B)-Reasoning-Conversational.ipynb. A QAT-specific notebook is coming soon.
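Training then proceeds as in any Unsloth LoRA run. The following is a hedged sketch using TRL's `SFTTrainer`; the dataset and hyperparameters mirror the evaluation setup reported below, and the chat-template preprocessing step is deliberately omitted, so treat this as an outline rather than Unsloth's exact recipe:

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Dataset used in the evaluation results below; chat-template formatting
# (e.g. via Unsloth's chat template utilities) is omitted here for brevity.
dataset = load_dataset("mlabonne/FineTome-100k", split="train")

trainer = SFTTrainer(
    model = model,           # the QAT-enabled LoRA model from above
    tokenizer = tokenizer,
    train_dataset = dataset,
    args = SFTConfig(
        per_device_train_batch_size = 64,  # matches the reported batch size
        gradient_accumulation_steps = 1,
        learning_rate = 2e-5,
        num_train_epochs = 1,
        output_dir = "outputs",
    ),
)
trainer.train()
```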


<details>
<summary><h2>torchtune integration (legacy)</h2></summary>

torchao QAT is integrated with [torchtune](https://github.com/pytorch/torchtune)
to allow users to run quantization-aware fine-tuning as follows:
Expand All @@ -210,47 +254,80 @@ tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config ll
```

For more detail, please refer to [this QAT tutorial](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html).
</details>

## Evaluation Results

Int4 weight-only QAT + LoRA using a group size of 128, fine-tuned using Unsloth.
Both fine-tuning and evaluation were done on a single H100 GPU using the
[mlabonne/FineTome-100k](https://huggingface.co/datasets/mlabonne/FineTome-100k)
dataset. Learning rate was 2e-5 and batch size was 64 with no gradient accumulation.
In the tables below, "recovered" is the fraction of the degradation from the bf16
baseline to the int4 baseline that QAT wins back.

```
# gemma3-12b-it
+-------------+-----------------+-----------------+------------+-------------+
| Eval task | bf16 baseline | int4 baseline | int4 QAT | recovered |
+=============+=================+=================+============+=============+
| wikitext | 9.1477 | 9.7745 | 9.5631 | 33.727% |
+-------------+-----------------+-----------------+------------+-------------+
| bbh | 0.8079 | 0.7624 | 0.7831 | 45.495% |
+-------------+-----------------+-----------------+------------+-------------+

# gemma3-4b-it
+-------------+-----------------+-----------------+------------+-------------+
| Eval task | bf16 baseline | int4 baseline | int4 QAT | recovered |
+=============+=================+=================+============+=============+
| wikitext | 12.1155 | 13.247 | 12.797 | 39.770% |
+-------------+-----------------+-----------------+------------+-------------+
| bbh         | 0.7074          | 0.6415          | 0.6666     | 38.088%     |
+-------------+-----------------+-----------------+------------+-------------+
| gpqa | 0.3232 | 0.3081 | 0.3182 | 66.887% |
+-------------+-----------------+-----------------+------------+-------------+

# Qwen3-4B-Instruct
+-------------+-----------------+-----------------+------------+-------------+
| Eval task | bf16 baseline | int4 baseline | int4 QAT | recovered |
+=============+=================+=================+============+=============+
| mmlu-pro | 0.4909 | 0.4328 | 0.4524 | 33.735% |
+-------------+-----------------+-----------------+------------+-------------+

# Llama3.2-3B
+-------------+-----------------+-----------------+------------+-------------+
| Eval task | bf16 baseline | int4 baseline | int4 QAT | recovered |
+=============+=================+=================+============+=============+
| wikitext | 12.1322 | 13.3459 | 12.8796 | 38.420% |
+-------------+-----------------+-----------------+------------+-------------+
| bbh | 0.5483 | 0.4967 | 0.5174 | 40.116% |
+-------------+-----------------+-----------------+------------+-------------+
| gpqa | 0.3333 | 0.2879 | 0.303 | 33.260% |
+-------------+-----------------+-----------------+------------+-------------+
| mmlu-pro | 0.2771 | 0.2562 | 0.2629 | 32.057% |
+-------------+-----------------+-----------------+------------+-------------+
```
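As a worked check of the "recovered" column (an illustrative snippet, not part of the README, using numbers from the tables above):

```python
# "recovered" = fraction of the quantization degradation won back by QAT.
# For accuracy-style metrics (higher is better):
#   recovered = (qat - ptq) / (bf16 - ptq)
# For perplexity-style metrics like wikitext (lower is better), the signs flip:
#   recovered = (ptq - qat) / (ptq - bf16)
bf16, ptq, qat = 0.3232, 0.3081, 0.3182   # gemma3-4b-it, gpqa row
print(f"{(qat - ptq) / (bf16 - ptq):.3%}")  # -> 66.887%, matching the table
```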


The older results below used the legacy torchtune QAT integration described
above. Evaluation was performed on 6-8 A100 GPUs (80GB each). We fine-tuned
[Llama3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
on the [C4 dataset](https://huggingface.co/datasets/allenai/c4) (en subset)
for 5000 steps using a group size of 256 for the weights. Note that extensive
hyperparameter tuning may further improve these results.

Results for int8 per-token dynamic activations + int4 per-group weights, using a learning rate of 2e-5:

| | hellaswag<br>(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext<br>(bits_per_byte) |
| ---------------- | ------ | ------ | ------ | ------ | ------ |
| No quantization | 57.86% | 76.60% | 8.905 | 1.505 | 0.590 |
| PTQ | 51.74% | 70.66% | 11.878 | 1.588 | 0.668 |
| QAT (quantized) | 57.25% | 76.51% | 9.859 | 1.534 | 0.617 |
| PTQ degradation | -6.11% | -5.94% | +2.973 | +0.083 | +0.078 |
| QAT degradation | -0.61% | -0.21% | +0.947 | +0.029 | +0.027 |

Results for int4 per-group weights, using a learning rate of 2e-6. For this quantization scheme, the
quantized path uses the more efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097).

| | hellaswag<br>(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext<br>(bits_per_byte) |
| ---------------- | -------- | ------- | ------ | ------ | ------ |
| No quantization | 57.16% | 77.02% | 8.858 | 1.504 | 0.589 |
| PTQ | 55.06% | 74.24% | 10.311 | 1.547 | 0.630 |
| QAT (quantized) | 55.86% | 75.06% | 10.134 | 1.542 | 0.625 |
| PTQ degradation | -2.10% | -2.78% | +1.453 | +0.043 | +0.041 |
| QAT degradation | -1.30% | -1.96% | +1.276 | +0.038 | +0.036 |

NVFP4 QAT with full fine-tuning, run using Axolotl on 8x B200 GPUs on the
[yahma/alpaca-cleaned](https://huggingface.co/datasets/yahma/alpaca-cleaned)
dataset. Learning rate was 2e-5 and batch size was 128 for `gemma3-12b-it`
and 32 for `Qwen3-8B`.
```
# gemma3-12b-it
+-------------+-----------------+------------------+-------------+-------------+
| Eval task | bf16 baseline | nvfp4 baseline | nvfp4 QAT | recovered |
+=============+=================+==================+=============+=============+
| bbh | 0.7527 | 0.7068 | 0.7222 | 33.551% |
+-------------+-----------------+------------------+-------------+-------------+
| mmlu-pro | 0.4074 | 0.3621 | 0.3702 | 17.881% |
+-------------+-----------------+------------------+-------------+-------------+

# Qwen3-8B
+-------------+-----------------+------------------+-------------+-------------+
| Eval task | bf16 baseline | nvfp4 baseline | nvfp4 QAT | recovered |
+=============+=================+==================+=============+=============+
| bbh | 0.7771 | 0.7262 | 0.7397 | 26.523% |
+-------------+-----------------+------------------+-------------+-------------+
| mmlu-pro | 0.4929 | 0.4519 | 0.4686 | 40.732% |
+-------------+-----------------+------------------+-------------+-------------+
```

For more details, please refer to [this blog post](https://pytorch.org/blog/quantization-aware-training).