quantized_graph_module = calibrate_and_quantize(
```

See [PyTorch 2 Export Post Training Quantization](https://docs.pytorch.org/ao/main/tutorials_source/pt2e_quant_ptq.html) for more information.

### Quantization Aware Training

The NeutronQuantizer supports two quantization modes: *Post-Training Quantization (PTQ)* and *Quantization Aware Training (QAT)*.
PTQ runs a calibration phase on an already-trained model to tune the quantization parameters and produce a model with integer weights.
While this optimization reduces model size, it introduces quantization noise and can degrade the model's accuracy.
QAT, in contrast, lets the model adapt its weights to the introduced quantization noise: the calibration step is replaced by training, which optimizes the quantization parameters and the model weights at the same time.
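The quantization noise both modes contend with comes from rounding and clamping values onto an integer grid. The sketch below illustrates the fake-quantization round trip that QAT simulates during training; it is a minimal plain-PyTorch illustration, not the NeutronQuantizer's internal implementation:

```python
import torch

def fake_quantize(x, scale, zero_point=0, qmin=-128, qmax=127):
    # Round onto the int8 grid, clamp to the representable range,
    # then dequantize back to float - the round trip is the "noise".
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale

x = torch.tensor([0.004, 0.05, 2.0])
print(fake_quantize(x, scale=0.01))
# small values are rounded away, large values are clipped at qmax * scale
```

During QAT training, gradients flow through this round trip (via a straight-through estimator), so the weights learn to compensate for it.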

See the [Quantization Aware Training blog post](https://pytorch.org/blog/quantization-aware-training/) for an introduction to the QAT method.

To use QAT with the Neutron backend, toggle the `is_qat` parameter:

```python
from executorch.backends.nxp.quantizer.neutron_quantizer import (
NeutronQuantizer,
NeutronTargetSpec,
)

target_spec = NeutronTargetSpec(target="imxrt700")
neutron_quantizer = NeutronQuantizer(neutron_target_spec=target_spec, is_qat=True)
```

The rest of the quantization pipeline works similarly to the PTQ workflow.
The most significant change is that the calibration step is replaced by training.

<div class="admonition tip">
Note: QAT uses the <code>prepare_qat_pt2e</code> prepare function instead of <code>prepare_pt2e</code>.
</div>

```python
import torch
from torch.utils.data import DataLoader
import torchvision.models as models
import torchvision.datasets as datasets
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from torchao.quantization.pt2e import (
    move_exported_model_to_eval,
    move_exported_model_to_train,
    disable_observer,
)

model = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()

neutron_target_spec = NeutronTargetSpec(target="imxrt700")
quantizer = NeutronQuantizer(neutron_target_spec, is_qat=True)  # (1)

sample_inputs = (torch.randn(1, 3, 224, 224),)
training_ep = torch.export.export(model, sample_inputs).module()  # (2)

# Steps that differ from PTQ (3-6)
prepared_model = prepare_qat_pt2e(training_ep, quantizer)  # (3) !!! Different prepare function
prepared_model = move_exported_model_to_train(prepared_model)  # (4)

# ---------------- Training phase (5) ----------------
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(prepared_model.parameters(), lr=1e-2, momentum=0.9)

train_data = datasets.ImageNet("./", split="train", transform=...)
train_loader = DataLoader(train_data, batch_size=5)

num_epochs = 6  # example value; tune for your model and dataset

# Training replaces calibration in QAT
for epoch in range(num_epochs):
    for imgs, labels in train_loader:
        optimizer.zero_grad()
        outputs = prepared_model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # It is recommended to freeze the quantization parameters
    # (disable observer updates) after a few epochs of training.
    if epoch >= num_epochs / 3:
        prepared_model.apply(disable_observer)
# --------------- End of training phase ---------------

prepared_model = move_exported_model_to_eval(prepared_model)  # (6)
quantized_model = convert_pt2e(prepared_model)  # (7)

# Optional step - fixes biasless convolutions (see Known Limitations of QAT)
quantized_model = QuantizeFusedConvBnBiasAtenPass(
    default_zero_bias=True
)(quantized_model).graph_module

...
```

Checklist for moving from PTQ to QAT:
- Set `is_qat=True` in `NeutronQuantizer`
- Use `prepare_qat_pt2e` instead of `prepare_pt2e`
- Call `move_exported_model_to_train()` before training
- Train the model instead of calibrating
- Call `move_exported_model_to_eval()` after training
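Assuming a standard PT2E flow, the checklist boils down to a small delta between the two pipelines. In the sketch below, `calibrate` and `train_loop` are hypothetical callables standing in for your own calibration and training code:

```python
from torchao.quantization.pt2e.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)
from torchao.quantization.pt2e import (
    move_exported_model_to_eval,
    move_exported_model_to_train,
)

def quantize_ptq(exported_module, quantizer, calibrate):
    prepared = prepare_pt2e(exported_module, quantizer)
    calibrate(prepared)  # run representative sample inputs through the model
    return convert_pt2e(prepared)

def quantize_qat(exported_module, quantizer, train_loop):
    prepared = prepare_qat_pt2e(exported_module, quantizer)
    prepared = move_exported_model_to_train(prepared)
    train_loop(prepared)  # training replaces calibration
    prepared = move_exported_model_to_eval(prepared)
    return convert_pt2e(prepared)
```

Everything before (export) and after (lowering to the Neutron backend) is identical in both flows.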

#### Known limitations of QAT

In the current ExecuTorch/TorchAO implementation, quantizing biasless convolutions during QAT is problematic: the pipeline produces a non-quantized empty bias, which causes the Neutron Converter to fail.
To mitigate this issue, apply the `QuantizeFusedConvBnBiasAtenPass` after quantization:

```python
...

# training

prepared_model = move_exported_model_to_eval(prepared_model) # (6)
quantized_model = convert_pt2e(prepared_model) # (7)

quantized_model = QuantizeFusedConvBnBiasAtenPass(
default_zero_bias=True
)(quantized_model).graph_module

...
```
These quantizers usually support configs that allow users to specify quantization options, such as:
* Precision (e.g., 8-bit or 4-bit)
* Quantization type (e.g., dynamic, static, or weight-only quantization)
* Granularity (e.g., per-tensor, per-channel)
* Post-Training Quantization vs. Quantization Aware Training
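These options typically surface as arguments on a quantizer's config helper. For illustration, here is how the XNNPACK quantizer exposes them (import path current as of recent ExecuTorch releases; consult the XNNPACK guide linked in this section for the authoritative API):

```python
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# 8-bit static per-channel quantization in PTQ mode.
config = get_symmetric_quantization_config(
    is_per_channel=True,  # granularity
    is_dynamic=False,     # quantization type
    is_qat=False,         # PTQ vs. QAT
)
quantizer = XNNPACKQuantizer()
quantizer.set_global(config)
```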

Not all quantization options are supported by all backends. Consult backend-specific guides for supported quantization modes and configuration, and how to initialize the backend-specific PT2E quantizer:

* [XNNPACK quantization](backends/xnnpack/xnnpack-quantization.md)
* [CoreML quantization](backends/coreml/coreml-quantization.md)
* [QNN quantization](backends-qualcomm.md#step-2-optional-quantize-your-model)
* [NXP quantization](backends/nxp/nxp-quantization.md)


