# Appendix: Quantisation and Quantisation-Aware Training

This tutorial-style appendix introduces the motivations, numerics, and tooling behind **model quantisation** with a focus on transformers. We provide conceptual explanations and concise PyTorch code snippets so you can reason about precision trade-offs without running large-scale LLM training.

## Learning goals

By the end of this appendix you should be able to:

- Explain why quantisation is attractive for deployment and how it differs from pruning or distillation.
- Describe the most common integer formats (e.g. int8, int4) and how scale/zero-point pairs represent floating-point ranges.
- Compare post-training quantisation (PTQ) with quantisation-aware training (QAT).
- Prototype a minimal PTQ and QAT workflow in PyTorch using toy data.

## 1. Why quantise large language models?

Quantisation maps high-precision parameters and activations (usually 16-bit or 32-bit floats) to lower-precision integer representations. The key benefits are:

- **Latency:** Integer matrix multiplications are faster on CPUs and accelerators that ship with dedicated int8/int4 instructions.
- **Memory footprint:** Reducing precision from 16-bit to 8-bit halves the storage required for weights and activations.
- **Bandwidth:** Smaller tensors move more quickly across device memory hierarchies, improving throughput for autoregressive decoding.
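To make the memory point concrete, the short calculation below estimates weight storage at several bit-widths. The 7-billion-parameter count is a hypothetical example. The helper `weight_storage_gib` is illustrative and ignores the small overhead of storing scales and zero-points.

```python
# Hypothetical parameter count, chosen only for illustration
NUM_PARAMS = 7_000_000_000

def weight_storage_gib(num_params: int, bits: int) -> float:
    """Storage for the weights alone (no scales/zero-points, no activations), in GiB."""
    return num_params * bits / 8 / 2**30

for bits in (32, 16, 8, 4):
    print(f"{bits:>2}-bit weights: {weight_storage_gib(NUM_PARAMS, bits):6.2f} GiB")
```

Each halving of the bit-width halves the weight storage, and with it the bytes that must be streamed from memory for every decoded token.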

The trade-off is reduced representational fidelity. Effective quantisation strategies aim to minimise task degradation by carefully choosing calibration data, numeric formats, and training procedures.

### Quantisation versus other efficiency techniques

| Technique | Core idea | Typical gain | Trade-offs |
|-----------|-----------|--------------|------------|
| Pruning | Remove parameters entirely | Smaller models, sparse compute | Requires sparse kernels, may harm accuracy |
| Distillation | Train a smaller student on teacher outputs | Compact model with similar behaviour | Needs extra training, may miss rare behaviours |
| Quantisation | Lower the precision of weights/activations | Faster + smaller with same architecture | Numerical noise can degrade quality |

Quantisation composes well with pruning and distillation, but each solves a distinct optimisation problem.

## 2. Quantisation fundamentals

A uniform affine quantiser represents a real value $x$ with:

$$\hat{x} = \text{clip}\left(\left\lfloor \frac{x}{s} \right\rceil + z, q_{\min}, q_{\max} \right),$$

where $s$ is the **scale**, $z$ is the **zero-point**, $\lfloor \cdot \rceil$ denotes rounding to the nearest integer, and $q_{\min}, q_{\max}$ bound the integer range (e.g. $[-128, 127]$ for signed int8).

To recover an approximate float, we dequantise via:

$$x \approx s \cdot (\hat{x} - z).$$
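Here is a minimal sketch of these two formulas. The helpers below (illustrative names, not a PyTorch API) derive $s$ and $z$ from an observed min/max range, a common calibration choice assumed here, and then round-trip a tensor. For values inside the calibrated range, the reconstruction error is at most about $s/2$.

```python
import torch

def affine_qparams(x_min: float, x_max: float, q_min: int = -128, q_max: int = 127):
    """Scale s and zero-point z mapping [x_min, x_max] onto [q_min, q_max]."""
    # Widen the range to include 0.0 so that zero is exactly representable
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    s = (x_max - x_min) / (q_max - q_min)
    z = q_min - round(x_min / s)
    return s, z

def quantise(x: torch.Tensor, s: float, z: int, q_min: int = -128, q_max: int = 127) -> torch.Tensor:
    # x_hat = clip(round(x / s) + z, q_min, q_max)
    return torch.clamp(torch.round(x / s) + z, q_min, q_max).to(torch.int32)

def dequantise(x_hat: torch.Tensor, s: float, z: int) -> torch.Tensor:
    # x ≈ s * (x_hat - z)
    return s * (x_hat.float() - z)

torch.manual_seed(0)
x = torch.rand(1000) * 6 - 1  # skewed toy "activations" in [-1, 5)
s, z = affine_qparams(x.min().item(), x.max().item())
x_rec = dequantise(quantise(x, s, z), s, z)
max_err = (x - x_rec).abs().max().item()
print(f"s={s:.5f}, z={z}, max error={max_err:.5f}, s/2={s / 2:.5f}")
```

Because the range is skewed, $z$ ends up far from zero. This is exactly the case where an asymmetric quantiser uses the integer range better than a symmetric one.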

Two practical choices dominate transformer quantisation:

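- **Per-tensor vs. per-channel scales:** A single scale per tensor is cheap but must cover the tensor's largest value. Linear-layer weights, including the attention projections, often benefit from per-channel (per-output-row) scales, because row magnitudes can differ widely, for example across attention heads.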
- **Symmetric vs. asymmetric:** Symmetric quantisers fix $z = 0$ and work well when data are zero-centred (common for weights). Asymmetric quantisers add a zero-point to better capture skewed activation distributions.
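The sketch below compares per-tensor and per-channel symmetric int8 quantisation on a toy weight matrix whose rows have very different magnitudes (an assumed, deliberately exaggerated setup). With a single scale, the small rows collapse onto a handful of integer levels. Per-row scales keep each row's error proportional to its own range.

```python
import torch

def symmetric_quant_error(w: torch.Tensor, per_channel: bool, bits: int = 8) -> float:
    """Mean absolute round-trip error of symmetric (z = 0) quantisation."""
    q_max = 2 ** (bits - 1) - 1  # 127 for int8 (restricted range [-127, 127])
    if per_channel:
        # One scale per output channel (row), as is common for linear-layer weights
        max_abs = w.abs().amax(dim=1, keepdim=True)
    else:
        # A single scale shared by the whole tensor
        max_abs = w.abs().max()
    scale = max_abs / q_max
    w_hat = torch.clamp(torch.round(w / scale), -q_max, q_max)
    return (w - w_hat * scale).abs().mean().item()

torch.manual_seed(0)
# Toy weight whose row magnitudes span 0.01x to 10x
row_scales = torch.logspace(-2, 1, steps=64).unsqueeze(1)
w = torch.randn(64, 128) * row_scales

print(f"per-tensor  error: {symmetric_quant_error(w, per_channel=False):.6f}")
print(f"per-channel error: {symmetric_quant_error(w, per_channel=True):.6f}")
```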

### Choosing bit-widths

| Bit-width | Common usage | Notes |
|-----------|--------------|-------|
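| 16-bit float (FP16/BF16) | Training and mixed-precision inference | Floating formats preserve wide dynamic range. |
| 8-bit float (FP8) | Training and inference on recent GPUs | Narrower range than FP16/BF16; needs per-tensor scaling factors. |
| 8-bit int | Widely adopted for weights + activations | Broadest hardware support across CPU/GPU/TPU. |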
| 4-bit int | Aggressive compression for decoder-only LLMs | Needs more careful calibration, often weight-only. |
| 2-bit / Binary | Specialised research | Requires custom kernels; accuracy drops are larger. |
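To see why lower bit-widths need more care, the sketch below applies symmetric per-tensor quantisation to toy Gaussian weights at 8, 4 and 2 bits. It assumes a restricted range $[-q_{\max}, q_{\max}]$. The relative error grows sharply as the number of levels shrinks, which is one reason 4-bit schemes usually rely on finer-grained scales and careful calibration.

```python
import torch

torch.manual_seed(0)
w = torch.randn(4096)  # toy, roughly Gaussian "weights"

rel_errors = {}
for bits in (8, 4, 2):
    q_max = 2 ** (bits - 1) - 1        # 127, 7, 1 (restricted symmetric range)
    scale = w.abs().max() / q_max       # symmetric per-tensor scale (z = 0)
    w_hat = torch.clamp(torch.round(w / scale), -q_max, q_max) * scale
    rel_errors[bits] = ((w - w_hat).norm() / w.norm()).item()
    print(f"int{bits}: {2 * q_max + 1:>3} levels, relative error {rel_errors[bits]:.3f}")
```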

## 3. PyTorch quantisation toolchain overview

PyTorch exposes its built-in quantisation APIs under `torch.ao.quantization` (older releases used `torch.quantization`). The typical eager-mode workflow is:

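1. **Define** a float32 model. In eager mode, mark where tensors enter and leave the quantised region with `QuantStub` and `DeQuantStub`.
2. **Prepare** it with a quantisation configuration (qconfig), attached as the model's `qconfig` attribute. The qconfig specifies which observers record the ranges of weights and activations for the target backend.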
3. **Calibrate** or **train** to collect activation statistics.
4. **Convert** to a quantised model with integer kernels.

Depending on whether the quantisation ranges are determined after training (by calibration) or during training (by simulating quantisation in the training loop), we distinguish PTQ and QAT.

## 4. Post-training quantisation (PTQ)

PTQ transforms a pre-trained float model to a quantised one using a small calibration dataset. No gradient updates are required, making PTQ attractive when the original training data are unavailable or expensive to reproduce.

Below we quantise a toy multilayer perceptron. Although tiny, the code mirrors what you would do with transformer blocks: prepare, calibrate, convert.

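```python
import torch
from torch import nn
from torch.ao.quantization import DeQuantStub, QuantStub

# A toy float32 model (e.g. mimicking an MLP block in a transformer).
# Eager-mode quantisation needs explicit stubs: QuantStub marks where float
# inputs get quantised and DeQuantStub where outputs return to float.
# Before conversion both stubs behave as identity functions.
float_model = nn.Sequential(
    QuantStub(),
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Linear(32, 16),
    DeQuantStub(),
)

float_model.eval()
example_inputs = torch.randn(64, 16)

# Baseline inference
with torch.inference_mode():
    float_outputs = float_model(example_inputs)
print(f"Float model dtype: {float_outputs.dtype}, shape: {float_outputs.shape}")
```

```python
# Configure PTQ for the fbgemm backend (x86 CPUs), falling back to qnnpack (ARM)
backend = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
torch.backends.quantized.engine = backend

# The qconfig is attached to the model; prepare() reads it from there and
# returns a copy with observers inserted (its second argument is `inplace`).
float_model.qconfig = torch.ao.quantization.get_default_qconfig(backend)
prepared_model = torch.ao.quantization.prepare(float_model)

# Calibration pass: run a few batches to collect activation statistics
with torch.no_grad():
    for batch in example_inputs.split(16):
        prepared_model(batch)

# Convert to a quantised model with int8 kernels
quantized_model = torch.ao.quantization.convert(prepared_model)

with torch.no_grad():
    quantized_outputs = quantized_model(example_inputs)  # DeQuantStub returns float32

print(quantized_model)
print(f"Quantised weight dtype: {quantized_model[1].weight().dtype}")
print(f"Max absolute diff vs float32: {(float_outputs - quantized_outputs).abs().max():.4f}")
```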

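Even without additional fine-tuning, int8 PTQ often causes only mild degradation for decoder-only LLMs, particularly when only the weights are quantised. Activations are harder. Large models develop a few channels with outlier magnitudes, and a per-tensor activation scale then spends most of the int8 range on them. Several measures can reduce the error:

- calibrating with a few hundred representative prompts;
- using finer-grained activation scales (e.g. one per token);
- applying SmoothQuant-style transformations that shift the outlier scale from the activations into the weights.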
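The smoothing idea can be sketched in a few lines. For a linear layer $Y = XW^\top$, per-input-channel factors $s_j = \max|X_{:,j}|^{\alpha} / \max|W_{:,j}|^{1-\alpha}$ let us divide the activations by $s$ and multiply the matching weight columns by $s$ without changing the output. The example makes several illustrative assumptions rather than following the exact recipe of any library:

- toy data with one injected outlier channel;
- per-tensor int8 scales for both operands;
- $\alpha = 0.5$.

```python
import torch

torch.manual_seed(0)
tokens, d_in, d_out = 256, 64, 32
x = torch.randn(tokens, d_in)
x[:, 3] *= 50.0                          # toy outlier activation channel (assumption)
weight = torch.randn(d_out, d_in) * 0.1  # nn.Linear layout: (out_features, in_features)

def fake_quant_sym(t: torch.Tensor, bits: int = 8) -> torch.Tensor:
    """Per-tensor symmetric quantise-dequantise round trip."""
    q_max = 2 ** (bits - 1) - 1
    scale = t.abs().max() / q_max
    return torch.clamp(torch.round(t / scale), -q_max, q_max) * scale

alpha = 0.5                              # migration strength (assumed)
act_max = x.abs().amax(dim=0)            # max |X_{:,j}| over tokens
w_max = weight.abs().amax(dim=0)         # max |W_{:,j}| over output rows
s = act_max.pow(alpha) / w_max.pow(1 - alpha)

x_smooth = x / s                         # X diag(s)^-1
w_smooth = weight * s                    # scale input column j of W by s_j

y_ref = x @ weight.T
# The transformation is exact in real arithmetic
assert torch.allclose(x_smooth @ w_smooth.T, y_ref, atol=1e-3)

def rel_err(y: torch.Tensor) -> float:
    return ((y - y_ref).norm() / y_ref.norm()).item()

plain = rel_err(fake_quant_sym(x) @ fake_quant_sym(weight).T)
smoothed = rel_err(fake_quant_sym(x_smooth) @ fake_quant_sym(w_smooth).T)
print(f"int8 weights+activations relative error: plain={plain:.4f}, smoothed={smoothed:.4f}")
```

Smoothing makes the activations easier to quantise at the cost of slightly harder weights. $\alpha$ controls how much of the difficulty moves across.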

## 5. Quantisation-aware training (QAT)

PTQ struggles when activation distributions shift significantly between calibration data and real workloads, or when you push to 4-bit precision. **Quantisation-aware training** mitigates this by simulating quantisation effects during fine-tuning.

The high-level idea:

1. Insert fake-quantisation modules that emulate quant/dequant in the forward pass.
2. Backpropagate through straight-through estimators so gradients flow to float weights.
3. After training, convert to real integer kernels.
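The core trick in step 2 can be written directly. The function below is a hand-rolled sketch, and the name `fake_quant_ste` is hypothetical. In the forward pass it returns the quantise-dequantise value. In the backward pass it treats rounding as the identity and zeroes the gradient for clipped values. It is compared on random data against PyTorch's built-in `torch.fake_quantize_per_tensor_affine`.

```python
import torch

def fake_quant_ste(x: torch.Tensor, scale: float, zero_point: int,
                   q_min: int = -128, q_max: int = 127) -> torch.Tensor:
    """Quantise-dequantise in the forward pass, straight-through gradient in the backward pass."""
    q = torch.round(x / scale) + zero_point
    # Gradient mask: 1 where the value was representable, 0 where it was clipped
    in_range = ((q >= q_min) & (q <= q_max)).to(x.dtype)
    x_hat = (torch.clamp(q, q_min, q_max) - zero_point) * scale
    # Value equals x_hat; d(output)/dx equals in_range (rounding treated as identity)
    return x_hat.detach() + (x - x.detach()) * in_range

torch.manual_seed(0)
scale, zero_point = 0.0625, 0       # power-of-two scale keeps x / scale exact
x = (torch.randn(1000) * 4).requires_grad_()
y = fake_quant_ste(x, scale, zero_point)
y.sum().backward()

# Reference: PyTorch's built-in fake-quantisation op on a copy of the same input
x_ref = x.detach().clone().requires_grad_()
y_ref = torch.fake_quantize_per_tensor_affine(x_ref, scale, zero_point, -128, 127)
y_ref.sum().backward()

print("forward matches:", torch.equal(y, y_ref))
print("gradient matches:", torch.equal(x.grad, x_ref.grad))
print("clipped fraction:", (x.grad == 0).float().mean().item())
```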

Below we fine-tune our toy network for a few gradient steps with QAT. Instead of optimising a real task, we match a random target tensor just to illustrate the mechanics.

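```python
torch.manual_seed(0)

# Same toy MLP, with stubs marking the quantised region (required in eager mode)
qat_model = nn.Sequential(
    torch.ao.quantization.QuantStub(),
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Linear(32, 16),
    torch.ao.quantization.DeQuantStub(),
)

qat_model.train()

# fbgemm targets x86 CPUs; qnnpack is the usual fallback on ARM
backend = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
torch.backends.quantized.engine = backend

# The QAT qconfig is attached to the model; prepare_qat returns a copy
# with fake-quantisation modules inserted (its second argument is `mapping`)
qat_model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
prepared_qat = torch.ao.quantization.prepare_qat(qat_model)

# Create the optimiser after prepare_qat so it tracks the copied parameters
optimizer = torch.optim.SGD(prepared_qat.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

for step in range(5):
    optimizer.zero_grad()
    inputs = torch.randn(32, 16)
    targets = torch.randn(32, 16)
    outputs = prepared_qat(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    if step == 0 or (step + 1) % 2 == 0:
        print(f"Step {step+1}: loss={loss.item():.4f}")

# Switch to eval mode, then replace fake-quantised modules with int8 kernels
prepared_qat.eval()
quantized_qat_model = torch.ao.quantization.convert(prepared_qat)

print(quantized_qat_model)
```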

In a real LLM fine-tuning session you would:

- Start from a float checkpoint, insert fake quantisers in the linear/attention modules, and resume training with a small learning rate.
- Use task-representative prompts so the model adapts to quantisation noise where it matters most (e.g. long-context decoding).
- Optionally freeze embedding layers or layer norms if they destabilise under fake-quantisation.

### Practical tips

- **Stability:** Lower precisions amplify rounding error, so gradient clipping and slightly higher weight decay can help stabilise QAT.
- **Progressive bit-width schedules:** Begin with int8 fake quantisers and anneal to int4 once the model adapts.
- **Layer-wise decisions:** Some components (e.g. logits projection) may remain in higher precision to safeguard perplexity.
- **Evaluation:** Always compare quantised and float baselines on downstream metrics (perplexity, accuracy) and monitor decoding latency to ensure the trade-off is worthwhile.
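The layer-wise tip can be expressed in eager mode by setting a submodule's `qconfig` to `None`, which excludes it from quantisation. In the sketch below, the final `nn.Linear` is a hypothetical stand-in for a logits projection. A `DeQuantStub` hands it float activations, so it runs in float32 while the rest of the block is quantised.

```python
import torch
from torch import nn
from torch.ao.quantization import DeQuantStub, QuantStub

# fbgemm targets x86; qnnpack is the usual fallback on ARM
backend = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
torch.backends.quantized.engine = backend

model = nn.Sequential(
    QuantStub(),
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Linear(32, 32),
    DeQuantStub(),       # leave the quantised region before the head
    nn.Linear(32, 10),   # hypothetical stand-in for a logits projection
)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
model[5].qconfig = None  # a submodule's own qconfig overrides the inherited one

model.train()
prepared = torch.ao.quantization.prepare_qat(model)
prepared(torch.randn(8, 16))  # a forward pass so the fake quantisers record ranges

prepared.eval()
quantized = torch.ao.quantization.convert(prepared)
print(quantized)  # modules 1 and 3 become QuantizedLinear; module 5 stays nn.Linear

with torch.no_grad():
    out = quantized(torch.randn(4, 16))
print(out.dtype, out.shape)
```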

## 6. Beyond standard QAT

Research is rapidly evolving:

- **GPTQ/AWQ:** Optimise weight quantisation with second-order approximations or activation-aware scaling.
- **LLM.int8():** Hybrid scheme keeping outlier activations in FP16 while quantising the rest.
- **FP8 training/inference:** Uses 8-bit floating-point formats with per-tensor scaling factors (typically derived from recent absolute-maximum statistics), now supported natively on some recent GPUs.
- **Structured sparsity + quantisation:** Combining 2:4 sparsity with int8 kernels for maximal throughput.
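To make the LLM.int8() bullet concrete, here is a simplified sketch of its mixed-precision decomposition. Input feature columns containing any value above a magnitude threshold (6.0 here, the paper's default) are multiplied in float. The remaining columns use int8, with one absmax scale per row of $X$ and per column of $W$. This is an illustrative CPU emulation, not the bitsandbytes implementation, and the injected outlier column is a toy assumption.

```python
import torch

def absmax_quantise_rows(t: torch.Tensor):
    """Symmetric int8 quantisation with one absmax scale per row."""
    scale = t.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127
    return torch.round(t / scale).to(torch.int8), scale

def mixed_int8_matmul(x: torch.Tensor, w: torch.Tensor, threshold: float = 6.0) -> torch.Tensor:
    """Approximate x @ w with x: (tokens, d_in) and w: (d_in, d_out)."""
    # Feature columns of x containing any value above the threshold are outliers
    outlier = (x.abs() > threshold).any(dim=0)
    # Float path for the (few) outlier feature dimensions
    y_float = x[:, outlier] @ w[outlier, :]
    # int8 path for the rest: per-row scales for x, per-column scales for w
    xq, sx = absmax_quantise_rows(x[:, ~outlier])
    wq_t, sw = absmax_quantise_rows(w[~outlier, :].T)
    # Integer accumulation (emulated with int32 on the CPU), then rescale to float
    acc = xq.to(torch.int32) @ wq_t.to(torch.int32).T
    y_int8 = acc.float() * sx * sw.T
    return y_float + y_int8

torch.manual_seed(0)
x = torch.randn(16, 64)
x[:, 7] *= 20.0             # toy outlier feature dimension
w = torch.randn(64, 32) * 0.1
y_ref = x @ w
y = mixed_int8_matmul(x, w)
print(f"relative error: {((y - y_ref).norm() / y_ref.norm()).item():.4f}")
```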

These techniques build upon the PTQ/QAT primitives described above.

## 7. Key takeaways

- Quantisation trades numerical precision for latency and memory gains, making it crucial for serving LLMs on cost-sensitive hardware.
- PTQ is simple and works well when you have decent calibration data and stick to 8-bit weights/activations.
- QAT injects quantisation noise during training, enabling more aggressive bit-widths or challenging domains at the cost of additional fine-tuning.
- PyTorch's `torch.ao.quantization` module provides utilities to prepare, calibrate, and convert models, so you can experiment with only small changes to your architecture (in eager mode, adding quant/dequant stubs).