# Module 5: Quantization Aware Training (QAT) from Scratch

**Objective:**
In Module 4, we saw that certain "Diva" layers break when quantized to INT4. Standard Post-Training Quantization (PTQ) fails here because the weights were never trained to tolerate the rounding error that quantization introduces.

**The Problem: Zero Gradients Through Rounding**
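We cannot simply add `x = round(x)` to our forward pass. The round function is a staircase: its derivative is **0** almost everywhere (on the flat steps) and undefined at the jumps. If we backpropagate through it, every gradient upstream of the rounding becomes zero and the weights stop updating.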

**The Solution: The Straight Through Estimator (STE)**
We will implement the industry-standard "hack" to bypass this:
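1.  **Forward Pass:** Apply quantize-then-dequantize ("fake quantization"), $\hat{x} = s\,\big(\text{clamp}(\text{round}(x/s + z),\, q_{\min},\, q_{\max}) - z\big)$ with scale $s$ and zero point $z$, so the model experiences the rounding noise.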
2.  **Backward Pass:** Lie to autograd. We pretend the function was the identity ($y=x$, so $\partial y / \partial x = 1$), allowing gradients to flow "straight through" the quantization block.

**Goal:**
Compare a **Naive PTQ** approach (Train FP32 -> Quantize) vs. **QAT** (Train Quantized) on a challenging INT4 regression task.

In [None]:
# @title Setup & Synthetic Data
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np

# Fix the global RNG so the target noise and the train/test split are repeatable.
torch.manual_seed(42)

# Task: regress y = sin(x) from 200 evenly spaced points on [-3, 3],
# with Gaussian noise (std 0.1) added to the targets.
X = torch.linspace(-3, 3, 200)[:, None]        # (200, 1) inputs
y = torch.sin(X) + 0.1 * torch.randn(X.shape)   # noisy targets

# Random 160 / 40 train/test split driven by one permutation.
indices = torch.randperm(len(X))
X_train, X_test = X[indices[:160]], X[indices[160:]]
y_train, y_test = y[indices[:160]], y[indices[160:]]

plt.figure(figsize=(10, 5))
plt.scatter(X_train, y_train, s=10, label='Training Data')
plt.plot(X, torch.sin(X), color='red', alpha=0.5, label='True Function')
plt.title("The Task: Fit the Sine Wave")
plt.legend()
plt.show()

In [None]:
# @title The Straight Through Estimator (STE)
class STEQuantize(torch.autograd.Function):
    """Fake-quantize a tensor, with a clipped Straight-Through Estimator (STE).

    Forward computes the quantize -> dequantize round trip

        q  = clamp(round(input / scale + zero_point), qmin, qmax)
        dq = (q - zero_point) * scale

    so downstream layers see the rounding noise of the integer grid.

    Backward treats the rounding as the identity (the STE), but only for
    elements whose rounded value fell inside ``[qmin, qmax]``. Elements that
    were clamped produce an output that does not change with the input, so
    their gradient is set to zero instead of being passed through.
    """

    @staticmethod
    def forward(ctx, input, scale, zero_point, qmin, qmax):
        # Map onto the integer grid. The range check is done on the *rounded*
        # value so entries that overshoot qmin/qmax only by float error
        # (e.g. 7.0000001 -> 7) still count as representable.
        q_input = (input / scale + zero_point).round()
        in_range = (q_input >= qmin) & (q_input <= qmax)
        ctx.save_for_backward(in_range)

        q_input = q_input.clamp(qmin, qmax)

        # Dequantize back to float ("fake quantization").
        return (q_input - zero_point) * scale

    @staticmethod
    def backward(ctx, grad_output):
        (in_range,) = ctx.saved_tensors
        # STE: identity gradient where the input was representable, zero where
        # it was clipped. One return value per forward argument; scale,
        # zero_point, qmin and qmax receive no gradient (None).
        grad_input = grad_output.masked_fill(~in_range, 0.0)
        return grad_input, None, None, None, None

# Helper wrapper to use it like a function
def ste_quantize(x, num_bits=4):
    """Fake-quantize ``x`` onto a signed ``num_bits`` grid with min-max scaling.

    A single per-tensor affine grid of ``2**num_bits`` levels is fitted so that
    ``x.min()`` maps to the lowest level and ``x.max()`` to the highest. The
    zero point is left as a float (it is not rounded). A constant tensor
    (min == max) is returned unchanged, because its scale would be zero.
    """
    half_range = 2 ** (num_bits - 1)
    qmin, qmax = -half_range, half_range - 1

    lo, hi = x.min(), x.max()
    if lo == hi:
        # Degenerate range: a zero scale would divide by zero below.
        return x

    scale = (hi - lo) / (qmax - qmin)
    zero_point = qmin - lo / scale
    return STEQuantize.apply(x, scale, zero_point, qmin, qmax)

In [None]:
# @title The QAT Layer (Linear + STE)
import torch.nn.functional as F  # used by QATLinear.forward, so import it before the class


class QATLinear(nn.Linear):
    """``nn.Linear`` whose weights (and optionally inputs) are fake-quantized.

    Every forward pass routes the weight through ``ste_quantize``. Training
    therefore sees ``num_bits`` rounding noise, while the straight-through
    estimator still lets gradients reach the full-precision weights. The bias
    is kept in full precision.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        num_bits: Bit width of the fake-quantization grid.
        bias: Whether the layer has an additive bias (as in ``nn.Linear``).
        quantize_input: Also fake-quantize the incoming activations with the
            same bit width. Off by default (weight-only QAT).
    """

    def __init__(self, in_features, out_features, num_bits=4, bias=True,
                 quantize_input=False):
        super().__init__(in_features, out_features, bias=bias)
        self.num_bits = num_bits
        self.quantize_input = quantize_input

    def forward(self, input):
        # Weight QAT: the layer always computes with the rounded weights.
        w_quant = ste_quantize(self.weight, self.num_bits)

        # Optional activation QAT (per-tensor min-max over the current input).
        if self.quantize_input:
            input = ste_quantize(input, self.num_bits)

        return F.linear(input, w_quant, self.bias)


print("QAT Linear Layer defined.")

In [None]:
# @title Define Models: Baseline vs. QAT
# Two MLPs with the same 1 -> 64 -> 64 -> 1 topology. Keep the construction
# order (FP32 first, then QAT): layer initialisation draws from the global
# torch RNG, so swapping them would change both models' starting weights.

# Baseline: plain full-precision linear layers.
model_fp32 = nn.Sequential(
    nn.Linear(1, 64), nn.ReLU(),
    nn.Linear(64, 64), nn.ReLU(),
    nn.Linear(64, 1),
)

# QAT twin: identical shape, but each linear layer fake-quantizes its weights
# to 4 bits inside its forward pass.
model_qat = nn.Sequential(
    QATLinear(1, 64, num_bits=4), nn.ReLU(),
    QATLinear(64, 64, num_bits=4), nn.ReLU(),
    QATLinear(64, 1, num_bits=4),
)

print("Models initialized.")

In [None]:
# @title Train Both Models
def train(model, name, epochs=500, lr=0.01, X=None, y=None):
    """Fit ``model`` with full-batch Adam on an MSE regression objective.

    Args:
        model: Module mapping ``X`` to predictions shaped like ``y``.
        name: Label shown in the message printed after training.
        epochs: Number of full-batch optimisation steps.
        lr: Adam learning rate.
        X, y: Training inputs and targets. When omitted they fall back to the
            notebook globals ``X_train`` / ``y_train``.

    Returns:
        List of per-step training losses as Python floats (``epochs`` entries).
    """
    # Resolve the data at call time so the function does not depend on the
    # notebook globals unless the caller relies on the defaults.
    if X is None:
        X = X_train
    if y is None:
        y = y_train

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    losses = []

    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    if losses:  # epochs may be 0
        print(f"  {name}: final train MSE = {losses[-1]:.5f}")
    return losses

# Train the baseline first, then the QAT model, with identical settings.
print("Training FP32 Baseline...")
losses_fp32 = train(model_fp32, "FP32")
print("Training QAT Model (INT4)...")
losses_qat = train(model_qat, "QAT")

# Overlay both loss histories on a log-scale y axis.
for curve, curve_label in ((losses_fp32, 'FP32 Loss'), (losses_qat, 'QAT Loss')):
    plt.plot(curve, label=curve_label)
plt.yscale('log')
plt.title("Training Convergence")
plt.legend()
plt.show()

In [None]:
# @title Final Benchmark: PTQ vs QAT
def evaluate(model, X, y, quantize_weights=False, num_bits=4):
    """Return the MSE of ``model`` on ``(X, y)`` together with its predictions.

    Args:
        model: Module to evaluate. It is switched to eval mode.
        X, y: Inputs and targets.
        quantize_weights: If True, simulate naive post-training quantization
            (PTQ). The weight of every ``nn.Linear`` in ``model`` is overwritten
            with its ``num_bits`` min-max fake-quantized version. This MUTATES
            ``model`` in place, so pass a copy if the original weights are
            still needed.
        num_bits: Bit width used when ``quantize_weights`` is True.

    Returns:
        ``(loss, preds)``: the MSE as a Python float and the prediction tensor.
    """
    model.eval()
    with torch.no_grad():
        if quantize_weights:
            for module in model.modules():
                if isinstance(module, nn.Linear):
                    # copy_ under no_grad writes into the existing Parameter
                    # rather than swapping its storage through ``.data``.
                    module.weight.copy_(ste_quantize(module.weight, num_bits=num_bits))

        preds = model(X)
        loss = nn.MSELoss()(preds, y)
    return loss.item(), preds

import copy

# 1. FP32 baseline ("golden" reference).
loss_fp32, preds_fp32 = evaluate(model_fp32, X_test, y_test, quantize_weights=False)

# 2. Naive PTQ: take the trained FP32 model and round its weights to INT4.
#    evaluate() overwrites the weights in place when quantize_weights=True,
#    so quantize a deep copy and leave model_fp32 untouched.
model_ptq = copy.deepcopy(model_fp32)
loss_ptq, preds_ptq = evaluate(model_ptq, X_test, y_test, quantize_weights=True)

# 3. QAT: the QATLinear layers already fake-quantize their weights in every
#    forward pass, so no extra quantization step is applied here.
loss_qat, preds_qat = evaluate(model_qat, X_test, y_test, quantize_weights=False)

print("MSE Loss Comparison (Lower is Better):")
print(f"1. FP32 Baseline:  {loss_fp32:.5f}")
print(f"2. Naive PTQ:      {loss_ptq:.5f}  (Usually fails)")
print(f"3. Trained QAT:    {loss_qat:.5f}  (Should recover)")

# X_test is in random-permutation order; sort everything by x so the line
# plots run left to right instead of zig-zagging between points.
sort_indices = X_test.flatten().argsort()

X_sorted = X_test[sort_indices]
y_sorted = y_test[sort_indices]
fp32_sorted = preds_fp32[sort_indices]
ptq_sorted = preds_ptq[sort_indices]
qat_sorted = preds_qat[sort_indices]

plt.figure(figsize=(10, 6))
plt.scatter(X_sorted, y_sorted, color='gray', alpha=0.5, label='Test Data')

plt.plot(X_sorted, fp32_sorted, label='FP32 (Golden)', linestyle='--')
plt.plot(X_sorted, ptq_sorted, label='Naive PTQ (Failure)', color='red', alpha=0.7)
plt.plot(X_sorted, qat_sorted, label='QAT (Success)', color='green', linewidth=2.5)

plt.title("Visualizing the Recovery: PTQ vs QAT")
plt.legend()
plt.show()