From 06a147d9ee4db61c03c45e7ae321226f9edb1729 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Mon, 20 Oct 2025 13:35:07 -0700
Subject: [PATCH] Update QAT README

---
 README.md                          |   3 +-
 torchao/quantization/qat/README.md | 153 ++++++++++++++++++++++-------
 2 files changed, 117 insertions(+), 39 deletions(-)

diff --git a/README.md b/README.md
index ad3e0b6f97..9a6d45c06d 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@
 ### PyTorch-Native Training-to-Serving Model Optimization
 - Pre-train Llama-3.1-70B **1.5x faster** with float8 training
-- Recover **77% of quantized perplexity degradation** on Llama-3.2-3B with QAT
+- Recover **67% of quantized accuracy degradation** on Gemma3-4B with QAT
 - Quantize Llama-3-8B to int4 for **1.89x faster** inference with **58% less memory**
@@ -118,6 +118,7 @@ Please see the [torchao compability table](https://github.com/pytorch/ao/issues/
 TorchAO is integrated into some of the leading open-source libraries including:
 
+* Unsloth for QAT, blog post coming soon!
 * HuggingFace transformers with a [builtin inference backend](https://huggingface.co/docs/transformers/main/quantization/torchao) and [low bit optimizers](https://github.com/huggingface/transformers/pull/31865)
 * HuggingFace diffusers best practices with `torch.compile` and TorchAO in a standalone repo [diffusers-torchao](https://github.com/huggingface/diffusers/blob/main/docs/source/en/quantization/torchao.md)
 * HuggingFace PEFT for LoRA using TorchAO as their [quantization backend](https://huggingface.co/docs/peft/en/developer_guides/quantization#torchao-pytorch-architecture-optimization)
diff --git a/torchao/quantization/qat/README.md b/torchao/quantization/qat/README.md
index 9a11aa7b51..c699e9648d 100644
--- a/torchao/quantization/qat/README.md
+++ b/torchao/quantization/qat/README.md
@@ -142,7 +142,8 @@
 quantize_(m, qat_config, filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding))
 ```
 
-### Quantizer API (legacy)
+<details>
+
+<summary><h3>Quantizer API (legacy)</h3></summary>
 
 Alternatively, torchao provides a few hardcoded quantization settings through the following Quantizers, but these may be removed soon:
@@ -191,8 +192,51 @@
 model = qat_quantizer.prepare(model)
 train_loop(model)
 model = qat_quantizer.convert(model)
 ```
+</details>
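As an illustration (not part of this patch), here is a minimal, self-contained sketch of the legacy prepare/convert flow that the hunk above only shows a fragment of; the toy model, random data, and training loop are placeholders you would replace with your own:

```python
import torch
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Placeholder model; any nn.Module with Linear layers is handled the same way.
model = torch.nn.Sequential(
    torch.nn.Linear(512, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 256),
)

qat_quantizer = Int8DynActInt4WeightQATQuantizer()

# prepare: swap in fake-quantized linears so training sees the quantization error
model = qat_quantizer.prepare(model)

# Placeholder fine-tuning loop (stands in for train_loop(model) above).
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
for _ in range(10):
    x = torch.randn(8, 512)
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# convert: replace the fake-quantized linears with actually quantized ones
model = qat_quantizer.convert(model)
```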
 
-## torchtune integration
+## Axolotl integration
+
+[Axolotl](https://github.com/axolotl-ai-cloud) uses TorchAO to support quantization-aware fine-tuning. You can use the following commands to fine-tune, and then quantize a Llama-3.2-3B model:
+
+```bash
+axolotl train examples/llama-3/3b-qat-fsdp2.yaml
+# once training is complete, perform the quantization step
+axolotl quantize examples/llama-3/3b-qat-fsdp2.yaml
+# you should now have a quantized model saved in ./outputs/qat_out/quatized
+```
+
+Please see the [QAT documentation](https://docs.axolotl.ai/docs/qat.html) in axolotl for more details.
+
+
+## Unsloth integration
+
+[Unsloth](https://github.com/unslothai/unsloth) also leverages TorchAO for quantization-aware fine-tuning. Unsloth's QAT support can be used with both full and LoRA fine-tuning. For example:
+
+```python
+import torch
+from unsloth import FastLanguageModel
+
+model, tokenizer = FastLanguageModel.from_pretrained(
+    "unsloth/Qwen3-4B-Instruct-2507",
+    max_seq_length = 2048,
+    dtype = torch.bfloat16,
+    load_in_4bit = False,
+    full_finetuning = False,
+)
+
+model = FastLanguageModel.get_peft_model(
+    model,
+    r = 16,
+    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+    lora_alpha = 16,
+    qat_scheme = "int4",
+)
+```
+
+For a full notebook example, see: https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(14B)-Reasoning-Conversational.ipynb. A QAT-specific notebook is coming soon.
+
+
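The paragraph above notes that Unsloth's QAT support also covers full fine-tuning, but only the LoRA path is shown. The sketch below illustrates the full fine-tuning variant under an assumption: passing `qat_scheme` to `from_pretrained` together with `full_finetuning=True` is our guess at the argument placement, so confirm it against the Unsloth docs and the notebook linked above.

```python
import torch
from unsloth import FastLanguageModel

# Assumed full fine-tuning variant: qat_scheme is passed at load time instead of
# to get_peft_model. Verify the exact argument placement in the Unsloth docs.
model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/Qwen3-4B-Instruct-2507",
    max_seq_length = 2048,
    dtype = torch.bfloat16,
    load_in_4bit = False,
    full_finetuning = True,   # full fine-tuning instead of LoRA
    qat_scheme = "int4",      # assumption: same scheme string as the LoRA example
)
# ...then fine-tune as usual, e.g. with trl's SFTTrainer.
```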
+<details>
+
+<summary><h3>torchtune integration (legacy)</h3></summary>
 
 torchao QAT is integrated with [torchtune](https://github.com/pytorch/torchtune) to allow users to run quantized-aware fine-tuning as follows:
@@ -210,47 +254,80 @@ tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config ll
 ```
 
 For more detail, please refer to [this QAT tutorial](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html).
+</details>
 
-## Axolotl integration
+## Evaluation Results
 
-[Axolotl](https://github.com/axolotl-ai-cloud) uses torchao to support quantized-aware fine-tuning. You can use the following commands to fine-tune, and then quantize a Llama-3.2-3B model:
+Int4 weight-only QAT + LoRA using a group size of 128, fine-tuned using Unsloth.
+Both fine-tuning and evaluation were done on a single H100 GPU using the
+[mlabonne/FineTome-100k](https://huggingface.co/datasets/mlabonne/FineTome-100k)
+dataset. Learning rate was 2e-5 and batch size was 64 with no gradient accumulation.
 
-```bash
-axolotl train examples/llama-3/3b-qat-fsdp2.yaml
-# once training is complete, perform the quantization step
-axolotl quantize examples/llama-3/3b-qat-fsdp2.yaml
-# you should now have a quantized model saved in ./outputs/qat_out/quatized
+```
+# gemma3-12b-it
++-------------+-----------------+-----------------+------------+-------------+
+| Eval task   | bf16 baseline   | int4 baseline   | int4 QAT   | recovered   |
++=============+=================+=================+============+=============+
+| wikitext    | 9.1477          | 9.7745          | 9.5631     | 33.727%     |
++-------------+-----------------+-----------------+------------+-------------+
+| bbh         | 0.8079          | 0.7624          | 0.7831     | 45.495%     |
++-------------+-----------------+-----------------+------------+-------------+
+
+# gemma3-4b-it
++-------------+-----------------+-----------------+------------+-------------+
+| Eval task   | bf16 baseline   | int4 baseline   | int4 QAT   | recovered   |
++=============+=================+=================+============+=============+
+| wikitext    | 12.1155         | 13.247          | 12.797     | 39.770%     |
++-------------+-----------------+-----------------+------------+-------------+
+| bbh         | 0.7074          | 0.6415          | 0.6666     | 38.088%     |
++-------------+-----------------+-----------------+------------+-------------+
+| gpqa        | 0.3232          | 0.3081          | 0.3182     | 66.887%     |
++-------------+-----------------+-----------------+------------+-------------+
+
+# Qwen3-4B-Instruct
++-------------+-----------------+-----------------+------------+-------------+
+| Eval task   | bf16 baseline   | int4 baseline   | int4 QAT   | recovered   |
++=============+=================+=================+============+=============+
+| mmlu-pro    | 0.4909          | 0.4328          | 0.4524     | 33.735%     |
++-------------+-----------------+-----------------+------------+-------------+
+
+# Llama3.2-3B
++-------------+-----------------+-----------------+------------+-------------+
+| Eval task   | bf16 baseline   | int4 baseline   | int4 QAT   | recovered   |
++=============+=================+=================+============+=============+
+| wikitext    | 12.1322         | 13.3459         | 12.8796    | 38.420%     |
++-------------+-----------------+-----------------+------------+-------------+
+| bbh         | 0.5483          | 0.4967          | 0.5174     | 40.116%     |
++-------------+-----------------+-----------------+------------+-------------+
+| gpqa        | 0.3333          | 0.2879          | 0.303      | 33.260%     |
++-------------+-----------------+-----------------+------------+-------------+
+| mmlu-pro    | 0.2771          | 0.2562          | 0.2629     | 32.057%     |
++-------------+-----------------+-----------------+------------+-------------+
 ```
 
-Please see the [QAT documentation](https://docs.axolotl.ai/docs/qat.html) in axolotl for more details.
-
-## Evaluation Results
+NVFP4 QAT full fine-tuning, fine-tuned using Axolotl on 8x B200 GPUs on the
+[yahma/alpaca-cleaned](https://huggingface.co/datasets/yahma/alpaca-cleaned)
+dataset. Learning rate was 2e-5 and batch size was 128 for `gemma3-12b-it`
+and 32 for `Qwen3-8B`.
-
-Evaluation was performed on 6-8 A100 GPUs (80GB each) using the torchtune QAT
-integration described above. We fine-tune [Llama3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
-on the [C4 dataset](https://huggingface.co/datasets/allenai/c4) (en subset)
-for 5000 steps using a group size of 256 for the weights. Note that extensive
-hyperparameter tuning may further improve these results.
-
-Results for int8 per token dynamic activations + int4 per group weights, using a learning rate of 2e-5:
-
-| | hellaswag<br>(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext<br>(bits_per_byte) |
-| ---------------- | ------ | ------ | ------ | ------ | ------ |
-| No quantization | 57.86% | 76.60% | 8.905 | 1.505 | 0.590 |
-| PTQ | 51.74% | 70.66% | 11.878 | 1.588 | 0.668 |
-| QAT (quantized) | 57.25% | 76.51% | 9.859 | 1.534 | 0.617 |
-| PTQ degradation | -6.11% | -5.94% | +2.973 | +0.083 | +0.078 |
-| QAT degradation | -0.61% | -0.21% | +0.947 | +0.029 | +0.027 |
-
-Results for int4 per group weights, using a learning rate of 2e-6. For this quantization scheme, the
-quantized path uses the more efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097).
-
-| | hellaswag<br>(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext<br>(bits_per_byte) |
-| ---------------- | -------- | ------- | ------ | ------ | ------ |
-| No quantization | 57.16% | 77.02% | 8.858 | 1.504 | 0.589 |
-| PTQ | 55.06% | 74.24% | 10.311 | 1.547 | 0.630 |
-| QAT (quantized) | 55.86% | 75.06% | 10.134 | 1.542 | 0.625 |
-| PTQ degradation | -2.10% | -2.78% | +1.453 | +0.043 | +0.041 |
-| QAT degradation | -1.30% | -1.96% | +1.276 | +0.038 | +0.036 |
+```
+# gemma3-12b-it
++-------------+-----------------+------------------+-------------+-------------+
+| Eval task   | bf16 baseline   | nvfp4 baseline   | nvfp4 QAT   | recovered   |
++=============+=================+==================+=============+=============+
+| bbh         | 0.7527          | 0.7068           | 0.7222      | 33.551%     |
++-------------+-----------------+------------------+-------------+-------------+
+| mmlu-pro    | 0.4074          | 0.3621           | 0.3702      | 17.881%     |
++-------------+-----------------+------------------+-------------+-------------+
+
+# Qwen3-8B
++-------------+-----------------+------------------+-------------+-------------+
+| Eval task   | bf16 baseline   | nvfp4 baseline   | nvfp4 QAT   | recovered   |
++=============+=================+==================+=============+=============+
+| bbh         | 0.7771          | 0.7262           | 0.7397      | 26.523%     |
++-------------+-----------------+------------------+-------------+-------------+
+| mmlu-pro    | 0.4929          | 0.4519           | 0.4686      | 40.732%     |
++-------------+-----------------+------------------+-------------+-------------+
+```
 
 For more details, please refer to [this blog post](https://pytorch.org/blog/quantization-aware-training).
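A note on reading the tables: the `recovered` column is not defined above, but the published numbers are consistent with it being the share of the bf16-to-quantized-baseline degradation that QAT wins back. The small sketch below (the helper name is ours, not torchao's) reproduces two of the entries:

```python
def recovered_pct(bf16: float, quant_baseline: float, qat: float) -> float:
    """Share of the bf16 -> quantized-baseline gap closed by QAT.

    The same expression works for higher-is-better metrics (accuracy) and
    lower-is-better ones (perplexity), since the signs cancel.
    """
    return 100 * (qat - quant_baseline) / (bf16 - quant_baseline)

print(round(recovered_pct(0.3232, 0.3081, 0.3182), 3))  # gpqa, gemma3-4b-it: 66.887
print(round(recovered_pct(9.1477, 9.7745, 9.5631), 3))  # wikitext, gemma3-12b-it: 33.727
```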