# Quantization-Aware Training (QAT) on a 3-Layer Neural Network

**Author**: Srimugunthan  
**Date**: February 2026

## Overview

This notebook demonstrates **Quantization-Aware Training (QAT)**, a technique where quantization effects are simulated during training to help the model learn to be robust to reduced precision.

### What is Quantization?

Quantization converts high-precision floating-point weights (32-bit) to lower-precision integers (8-bit, 4-bit, etc.) to:
- Reduce model size (4x smaller for 8-bit)
- Speed up inference
- Enable deployment on edge devices

### Why QAT vs Post-Training Quantization (PTQ)?

- **PTQ**: Quantize after training → can cause a significant accuracy drop
- **QAT**: Simulate quantization during training → the model learns to be robust to quantization noise

### Key Technique: Fake Quantization

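- **Forward pass**: Quantize weights/activations (round to the integer grid, then dequantize back to float) to simulate low precision
- **Backward pass**: Treat the rounding step as the identity (the straight-through estimator, STE), so gradients flow unchanged to the full-precision weights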
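
Why the straight-through estimator is needed: rounding is piecewise constant, so autograd assigns it a zero gradient and nothing upstream of it would learn. The sketch below shows this with `torch.round`, then applies one common way of writing an STE in PyTorch, the `detach` trick. This is an alternative to the custom `torch.autograd.Function` used later in this notebook; the input values are arbitrary.

```python
import torch

x = torch.tensor([0.3, 1.7, -2.4], requires_grad=True)

# torch.round is piecewise constant, so autograd gives it a zero gradient:
# a network trained through plain rounding would never update its weights
torch.round(x).sum().backward()
round_grad = x.grad.clone()
print(round_grad)  # tensor([0., 0., 0.])

# STE via the detach trick: the forward value equals round(x), but the
# rounding residual is detached, so autograd sees out = x + constant
x.grad = None
out = x + (torch.round(x) - x).detach()
out.sum().backward()
ste_grad = x.grad.clone()
print(out.detach())  # same values as torch.round(x)
print(ste_grad)      # tensor([1., 1., 1.])
```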

---

## 1. Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, HTML

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

## 2. Understanding Quantization

### 2.1 Quantization Formula

For a floating-point value $x$, quantization to $n$ bits:

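$$x_{\text{quantized}} = \text{clamp}\left(\text{round}\left(\frac{x}{s} + z\right), 0, 2^n - 1\right), \qquad \hat{x} = s\left(x_{\text{quantized}} - z\right)$$

Where:
- $s$ = scale = $\frac{x_{\max} - x_{\min}}{2^n - 1}$
- $z$ = zero point = $\text{round}\left(-\frac{x_{\min}}{s}\right)$, the integer code that represents $0.0$ (when $x_{\min} \le 0 \le x_{\max}$)
- $\hat{x}$ = the dequantized value; for $x$ in $[x_{\min}, x_{\max}]$, the rounding error $|x - \hat{x}|$ is at most $s/2$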
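
A worked example of the scale and zero-point formulas, using round-to-nearest as the notebook's code does. The range $[-1, 3]$ and the helper names `qparams`, `quantize`, and `dequantize` are illustrative choices for this sketch, not part of any library.

```python
def qparams(x_min, x_max, num_bits=8):
    """Scale and zero point for [x_min, x_max] (assumes x_min <= 0 <= x_max)."""
    qmin, qmax = 0, 2**num_bits - 1
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = round(qmin - x_min / scale)  # integer code that represents 0.0
    return scale, zero_point, qmin, qmax

def quantize(x, scale, zero_point, qmin, qmax):
    # Round to nearest, then clamp into the integer range
    return max(qmin, min(qmax, round(x / scale + zero_point)))

def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale

scale, zp, qmin, qmax = qparams(-1.0, 3.0)
q = quantize(0.5, scale, zp, qmin, qmax)
print(f"scale={scale:.6f}, zero_point={zp}")                # scale=0.015686, zero_point=64
print(f"0.5 -> code {q} -> {dequantize(q, scale, zp):.5f}")  # 0.5 -> code 96 -> 0.50196

# Round-trip error stays within scale/2 across the whole range
xs = [-1.0 + 4.0 * i / 1000 for i in range(1001)]
max_err = max(abs(x - dequantize(quantize(x, scale, zp, qmin, qmax), scale, zp)) for x in xs)
print(max_err <= scale / 2 + 1e-12)  # True
```

Zero maps to code 64 and back to exactly 0.0, which is why the zero point is kept an integer: padding and ReLU outputs of 0 stay exact after quantization.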

### 2.2 Visualization

In [None]:
def visualize_quantization(num_bits=8):
    """
    Visualize how quantization affects a continuous signal
    """
    # Generate continuous signal
    x = np.linspace(-5, 5, 1000)
    y = np.sin(x) * 2 + np.random.randn(1000) * 0.1
    
    # Quantize
    qmin, qmax = 0, 2**num_bits - 1
    scale = (y.max() - y.min()) / (qmax - qmin)
    zero_point = np.round(qmin - y.min() / scale)  # integer zero point, as in the formula
    
    y_int = np.clip(np.round(y / scale + zero_point), qmin, qmax)
    y_quant = (y_int - zero_point) * scale
    
    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Original vs Quantized
    axes[0].plot(x, y, label='Original (FP32)', alpha=0.7, linewidth=2)
    axes[0].plot(x, y_quant, label=f'Quantized ({num_bits}-bit)', alpha=0.7, linewidth=2)
    axes[0].set_xlabel('X')
    axes[0].set_ylabel('Value')
    axes[0].set_title('Original vs Quantized Signal')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Quantization error
    error = np.abs(y - y_quant)
    axes[1].hist(error, bins=50, alpha=0.7, edgecolor='black')
    axes[1].set_xlabel('Absolute Error')
    axes[1].set_ylabel('Frequency')
    axes[1].set_title(f'Quantization Error Distribution\nMean Error: {error.mean():.6f}')
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nQuantization Statistics ({num_bits}-bit):")
    print(f"  Scale: {scale:.6f}")
    print(f"  Zero point: {zero_point:.2f}")
    print(f"  Quantization levels: {2**num_bits}")
    print(f"  Mean absolute error: {error.mean():.6f}")
    print(f"  Max absolute error: {error.max():.6f}")

visualize_quantization(num_bits=8)
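
The plot above fixes the bit width at 8. The error is governed by the step size $s = (x_{\max} - x_{\min})/(2^n - 1)$: every in-range value lands within $s/2$ of a grid point, so each extra bit roughly halves the worst-case error. A quick sweep makes this concrete; it uses a deterministic ramp over $[-2, 2]$ (an assumption chosen for repeatability) instead of the noisy sine wave, with the same min/max scheme as `visualize_quantization`.

```python
import numpy as np

# Deterministic ramp over [-2, 2]
x = np.linspace(-2.0, 2.0, 10_001)
results = {}
for num_bits in (2, 4, 8):
    qmin, qmax = 0, 2**num_bits - 1
    scale = (x.max() - x.min()) / (qmax - qmin)
    zero_point = np.round(qmin - x.min() / scale)
    # Quantize to integer codes, then dequantize back to float
    x_hat = (np.clip(np.round(x / scale + zero_point), qmin, qmax) - zero_point) * scale
    max_err = np.abs(x - x_hat).max()
    results[num_bits] = (scale, max_err)
    print(f"{num_bits}-bit: scale={scale:.5f}  max error={max_err:.5f}  bound scale/2={scale / 2:.5f}")
```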

## 3. Implementing Fake Quantization

The core of QAT is **fake quantization**: quantize in the forward pass, but pass gradients unchanged in the backward pass.

In [None]:
class FakeQuantize(torch.autograd.Function):
    """
    Fake Quantization: Simulates quantization in forward pass,
    but passes gradients unchanged in backward pass (Straight-Through Estimator).
    
    This allows the network to learn to be robust to quantization noise.
    """
    @staticmethod
    def forward(ctx, x, num_bits=8):
        """
        Forward pass: Quantize the input
        
        Args:
            x: Input tensor (float32)
            num_bits: Number of bits for quantization (default: 8)
            
        Returns:
            Quantized tensor (still in float32, but with reduced precision)
        """
        # Define quantization range
        qmin = 0
        qmax = 2**num_bits - 1
        
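        # Calculate scale and zero point. The range is widened to include 0.0
        # so zero is exactly representable and nothing is clipped when the
        # tensor is all-positive or all-negative (e.g. a small bias vector)
        min_val = torch.clamp(x.min(), max=0.0)
        max_val = torch.clamp(x.max(), min=0.0)
        
        scale = (max_val - min_val) / (qmax - qmin)
        scale = torch.clamp(scale, min=1e-8)  # Avoid division by zero
        
        zero_point = qmin - min_val / scale
        zero_point = torch.clamp(torch.round(zero_point), qmin, qmax)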
        
        # Quantize: float -> int
        x_int = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
        
        # Dequantize: int -> float (simulates low precision inference)
        x_quant = (x_int - zero_point) * scale
        
        return x_quant
    
    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: Straight-Through Estimator (STE)
        Pass gradients unchanged, as if quantization didn't happen.
        
        This is key: we need gradients to flow for training,
        but quantization is not differentiable.
        """
        return grad_output, None  # None for num_bits (not trainable)


# Test fake quantization
print("Testing FakeQuantize:")
x = torch.randn(5, 5) * 10
print(f"\nOriginal tensor:\n{x}")

x_quant_8bit = FakeQuantize.apply(x, 8)
print(f"\n8-bit quantized:\n{x_quant_8bit}")

x_quant_4bit = FakeQuantize.apply(x, 4)
print(f"\n4-bit quantized:\n{x_quant_4bit}")

print(f"\n8-bit error: {(x - x_quant_8bit).abs().mean().item():.6f}")
print(f"4-bit error: {(x - x_quant_4bit).abs().mean().item():.6f}")
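
PyTorch also provides a built-in fake-quantization op, `torch.fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max)`, which takes the scale and zero point as arguments instead of computing them from the tensor. The sketch below feeds it min/max parameters (the helper `minmax_qparams` is hypothetical, not part of PyTorch) and compares the result with the hand-written formula. One difference from `FakeQuantize`: the built-in backward is a *clipped* STE, passing a gradient of 1 for elements whose integer code lands inside `[quant_min, quant_max]` and 0 for elements that were clamped.

```python
import torch

def minmax_qparams(x, num_bits=8):
    """Hypothetical helper: per-tensor min/max calibration with 0.0 kept in range."""
    qmin, qmax = 0, 2**num_bits - 1
    lo = min(x.min().item(), 0.0)
    hi = max(x.max().item(), 0.0)
    scale = max((hi - lo) / (qmax - qmin), 1e-8)
    zero_point = int(min(qmax, max(qmin, round(qmin - lo / scale))))
    return scale, zero_point, qmin, qmax

torch.manual_seed(0)
x = torch.randn(5, 5, requires_grad=True)
scale, zero_point, qmin, qmax = minmax_qparams(x.detach())

# Built-in op: quantize to integer codes and dequantize back, in one call
x_fq = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, qmin, qmax)

# The same arithmetic written out by hand
x_manual = (torch.clamp(torch.round(x.detach() / scale + zero_point), qmin, qmax) - zero_point) * scale
print("matches hand-written formula:", torch.allclose(x_fq.detach(), x_manual, atol=scale))

# Backward: gradient 1 where the code was in range, 0 where it was clamped
x_fq.sum().backward()
print(x.grad)
```

Because the range here is taken from the tensor itself, almost nothing is clamped and nearly every gradient is 1; with a fixed or learned range, the clipping mask starts to matter.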

## 4. Building Quantization-Aware Layers

Now we create a custom Linear layer that applies fake quantization to its weights, biases, and output activations.

In [None]:
class QATLinear(nn.Module):
    """
    Quantization-Aware Linear Layer
    
    Applies fake quantization to:
    1. Weights
    2. Biases
    3. Output activations
    """
    def __init__(self, in_features, out_features, num_bits=8):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.num_bits = num_bits
        
    def forward(self, x):
        # 1. Quantize weights
        weight_quant = FakeQuantize.apply(self.linear.weight, self.num_bits)
        
        # 2. Quantize bias (if exists)
        if self.linear.bias is not None:
            bias_quant = FakeQuantize.apply(self.linear.bias, self.num_bits)
        else:
            bias_quant = None
        
        # 3. Forward pass with quantized weights
        output = nn.functional.linear(x, weight_quant, bias_quant)
        
        # 4. Quantize output activations
        output_quant = FakeQuantize.apply(output, self.num_bits)
        
        return output_quant
    
    def extra_repr(self):
        return f'in_features={self.linear.in_features}, out_features={self.linear.out_features}, num_bits={self.num_bits}'


# Test QATLinear
print("Testing QATLinear:")
layer = QATLinear(10, 5, num_bits=8)
x = torch.randn(3, 10)
output = layer(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nLayer: {layer}")

## 5. Building the Neural Networks

We'll create two versions of a 3-layer fully connected network:
1. **Normal**: Standard network (for Post-Training Quantization)
2. **QAT**: Network with fake quantization layers (for Quantization-Aware Training)

In [None]:
class ThreeLayerNet(nn.Module):
    """
    Standard 3-layer fully connected network
    Architecture: input -> FC -> ReLU -> FC -> ReLU -> FC -> output
    """
    def __init__(self, input_dim=10, hidden_dim=64, output_dim=2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class ThreeLayerNetQAT(nn.Module):
    """
    3-layer fully connected network with Quantization-Aware Training
    Same architecture, but uses QATLinear layers with fake quantization
    """
    def __init__(self, input_dim=10, hidden_dim=64, output_dim=2, num_bits=8):
        super().__init__()
        self.fc1 = QATLinear(input_dim, hidden_dim, num_bits)
        self.fc2 = QATLinear(hidden_dim, hidden_dim, num_bits)
        self.fc3 = QATLinear(hidden_dim, output_dim, num_bits)
        self.relu = nn.ReLU()
        self.num_bits = num_bits
        
    def forward(self, x):
        # Quantize input activations
        x = FakeQuantize.apply(x, self.num_bits)
        
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# Display architectures
print("="*60)
print("NORMAL NETWORK ARCHITECTURE")
print("="*60)
model_normal = ThreeLayerNet(input_dim=10, hidden_dim=64, output_dim=2)
print(model_normal)
print(f"\nTotal parameters: {sum(p.numel() for p in model_normal.parameters()):,}")

print("\n" + "="*60)
print("QAT NETWORK ARCHITECTURE")
print("="*60)
model_qat = ThreeLayerNetQAT(input_dim=10, hidden_dim=64, output_dim=2, num_bits=8)
print(model_qat)
print(f"\nTotal parameters: {sum(p.numel() for p in model_qat.parameters()):,}")

## 6. Generate Synthetic Dataset

We'll create a simple binary classification task: the label is 1 when the features sum to a positive value, and 0 otherwise.

In [None]:
def generate_dataset(n_samples=2000, n_features=10, test_size=0.2, random_state=42):
    """
    Generate synthetic binary classification dataset
    Label = 1 if sum of features > 0, else 0
    """
    np.random.seed(random_state)
    torch.manual_seed(random_state)
    
    # Generate features
    X = torch.randn(n_samples, n_features)
    
    # Generate labels (binary classification)
    y = (X.sum(dim=1) > 0).long()
    
    # Split into train and test
    n_train = int(n_samples * (1 - test_size))
    
    X_train, X_test = X[:n_train], X[n_train:]
    y_train, y_test = y[:n_train], y[n_train:]
    
    print(f"Dataset Generated:")
    print(f"  Training samples: {len(X_train)}")
    print(f"  Test samples: {len(X_test)}")
    print(f"  Features: {n_features}")
    print(f"  Classes: {len(torch.unique(y))}")
    print(f"  Class distribution (train): {torch.bincount(y_train).tolist()}")
    print(f"  Class distribution (test): {torch.bincount(y_test).tolist()}")
    
    return X_train, X_test, y_train, y_test


# Generate data
X_train, X_test, y_train, y_test = generate_dataset(
    n_samples=2000, 
    n_features=10, 
    test_size=0.2,
    random_state=42
)

# Create DataLoaders
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print(f"\nDataLoader created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Test batches: {len(test_loader)}")

## 7. Training Functions

In [None]:
def train_model(model, train_loader, test_loader, epochs=20, lr=0.001, model_name="Model"):
    """
    Train a neural network and track metrics
    
    Returns:
        model: Trained model
        history: Dictionary with training history
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    history = {
        'train_loss': [],
        'train_acc': [],
        'test_loss': [],
        'test_acc': []
    }
    
    print(f"\nTraining {model_name}...")
    print("="*60)
    
    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0
        
        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_total += y_batch.size(0)
            train_correct += predicted.eq(y_batch).sum().item()
        
        train_loss = train_loss / len(train_loader)
        train_acc = 100. * train_correct / train_total
        
        # Evaluation phase
        model.eval()
        test_loss = 0
        test_correct = 0
        test_total = 0
        
        with torch.no_grad():
            for X_batch, y_batch in test_loader:
                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                test_total += y_batch.size(0)
                test_correct += predicted.eq(y_batch).sum().item()
        
        test_loss = test_loss / len(test_loader)
        test_acc = 100. * test_correct / test_total
        
        # Store history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)
        
        # Print progress
        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"Epoch [{epoch+1:2d}/{epochs}] | "
                  f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
                  f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")
    
    print("="*60)
    print(f"Training completed!")
    print(f"Final Test Accuracy: {test_acc:.2f}%")
    
    return model, history


def plot_training_history(histories, labels):
    """
    Plot training curves for multiple models
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    for history, label in zip(histories, labels):
        epochs = range(1, len(history['train_loss']) + 1)
        
        # Loss plot
        axes[0].plot(epochs, history['train_loss'], label=f'{label} (Train)', linewidth=2)
        axes[0].plot(epochs, history['test_loss'], '--', label=f'{label} (Test)', linewidth=2)
        
        # Accuracy plot
        axes[1].plot(epochs, history['train_acc'], label=f'{label} (Train)', linewidth=2)
        axes[1].plot(epochs, history['test_acc'], '--', label=f'{label} (Test)', linewidth=2)
    
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training and Test Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy (%)')
    axes[1].set_title('Training and Test Accuracy')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

## 8. Post-Training Quantization (PTQ)

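This function quantizes an already-trained model's weights and biases in place. It *simulates* PTQ: each parameter is rounded to a num_bits-bit grid and immediately dequantized back to float32, so the stored model is not actually smaller. Activations stay in full precision, so this is weight-only PTQ, whereas the QAT model also fake-quantizes its activations.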

In [None]:
def quantize_model_weights(model, num_bits=8):
    """
    Apply Post-Training Quantization to a trained model
    Quantizes all weights and biases in-place
    """
    qmin = 0
    qmax = 2**num_bits - 1
    
    print(f"\nApplying {num_bits}-bit Post-Training Quantization...")
    
    with torch.no_grad():
        for name, param in model.named_parameters():
            if 'weight' in name or 'bias' in name:
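                # Calculate scale and zero point (range widened to include 0.0
                # so all-positive or all-negative tensors are not clipped)
                min_val = torch.clamp(param.min(), max=0.0)
                max_val = torch.clamp(param.max(), min=0.0)
                
                scale = (max_val - min_val) / (qmax - qmin)
                scale = torch.clamp(scale, min=1e-8)  # Avoid division by zero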
                
                zero_point = qmin - min_val / scale
                zero_point = torch.clamp(torch.round(zero_point), qmin, qmax)
                
                # Quantize
                param_int = torch.clamp(torch.round(param / scale + zero_point), qmin, qmax)
                
                # Dequantize and update parameter
                param_quant = (param_int - zero_point) * scale
                param.copy_(param_quant)
                
                print(f"  Quantized: {name} (shape: {param.shape})")
    
    print("Quantization complete!")
    return model


def evaluate_model(model, test_loader):
    """
    Evaluate model accuracy on test set
    """
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            outputs = model(X_batch)
            _, predicted = outputs.max(1)
            total += y_batch.size(0)
            correct += predicted.eq(y_batch).sum().item()
    
    accuracy = 100. * correct / total
    return accuracy
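
`quantize_model_weights` only simulates PTQ: it writes dequantized values back into float32 parameters, so the model in memory is no smaller. To see what is actually stored after quantization, PyTorch's `torch.quantize_per_tensor` produces a quantized tensor whose integer codes can be read with `.int_repr()`. A minimal sketch on a freshly initialized `nn.Linear(64, 64)`, used here as an arbitrary stand-in for a trained layer, with the same zero-including min/max parameters:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
w = nn.Linear(64, 64).weight.detach()  # stand-in for a trained weight matrix

# Min/max parameters with 0.0 inside the range, which keeps zero_point in 0..255
qmin, qmax = 0, 255
lo = min(w.min().item(), 0.0)
hi = max(w.max().item(), 0.0)
scale = (hi - lo) / (qmax - qmin)
zero_point = int(round(qmin - lo / scale))

w_q = torch.quantize_per_tensor(w, scale, zero_point, torch.quint8)
codes = w_q.int_repr()    # the stored integer codes, dtype torch.uint8
w_hat = w_q.dequantize()  # float32 reconstruction

print(codes.dtype, codes.element_size(), "byte/element vs", w.element_size(), "for float32")
print("max abs error:", (w - w_hat).abs().max().item(), "<= scale/2 =", scale / 2)
```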

## 9. Experiment 1: Train Normal Model + Apply PTQ

In [None]:
print("\n" + "#"*70)
print("# EXPERIMENT 1: Normal Model with Post-Training Quantization (PTQ)")
print("#"*70)

# Train normal model
model_normal = ThreeLayerNet(input_dim=10, hidden_dim=64, output_dim=2)
model_normal, history_normal = train_model(
    model_normal, 
    train_loader, 
    test_loader, 
    epochs=20, 
    lr=0.001,
    model_name="Normal Model"
)

# Evaluate before quantization
acc_before_ptq = evaluate_model(model_normal, test_loader)
print(f"\n✓ Accuracy BEFORE quantization: {acc_before_ptq:.2f}%")

# Apply PTQ
model_normal = quantize_model_weights(model_normal, num_bits=8)

# Evaluate after quantization
acc_after_ptq = evaluate_model(model_normal, test_loader)
print(f"\n✓ Accuracy AFTER 8-bit quantization: {acc_after_ptq:.2f}%")
print(f"✗ Accuracy drop: {acc_before_ptq - acc_after_ptq:.2f} percentage points")

## 10. Experiment 2: Train QAT Model

In [None]:
print("\n" + "#"*70)
print("# EXPERIMENT 2: Quantization-Aware Training (QAT)")
print("#"*70)

# Train QAT model
model_qat = ThreeLayerNetQAT(input_dim=10, hidden_dim=64, output_dim=2, num_bits=8)
model_qat, history_qat = train_model(
    model_qat, 
    train_loader, 
    test_loader, 
    epochs=20, 
    lr=0.001,
    model_name="QAT Model"
)

# Evaluate QAT model (already trained with quantization)
acc_qat = evaluate_model(model_qat, test_loader)
print(f"\n✓ QAT Model accuracy: {acc_qat:.2f}%")

## 11. Compare Results

In [None]:
print("\n" + "="*70)
print("FINAL COMPARISON: Post-Training Quantization vs QAT")
print("="*70)

print(f"\n{'Method':<30} {'Accuracy':<15} {'Notes'}")
print("-" * 70)
print(f"{'Normal (FP32 - Before PTQ)':<30} {acc_before_ptq:>6.2f}%        {'Baseline (full precision)'}")
print(f"{'Normal (8-bit PTQ)':<30} {acc_after_ptq:>6.2f}%        {'After post-training quantization'}")
print(f"{'QAT (8-bit)':<30} {acc_qat:>6.2f}%        {'Trained with quantization'}")
print("-" * 70)

ptq_drop = acc_before_ptq - acc_after_ptq
qat_advantage = acc_qat - acc_after_ptq

print(f"\nüìä Key Insights:")
print(f"   ‚Ä¢ PTQ accuracy drop: {ptq_drop:.2f}%")
print(f"   ‚Ä¢ QAT advantage over PTQ: {qat_advantage:.2f}%")
print(f"   ‚Ä¢ QAT recovery: {(acc_qat/acc_before_ptq)*100:.1f}% of original accuracy")

if qat_advantage > 0:
    print(f"\n‚úÖ QAT successfully maintains higher accuracy than PTQ!")
else:
    print(f"\n‚ö†Ô∏è  In this case, PTQ performed similarly to QAT (dataset may be too simple)")

## 12. Visualize Training Curves

In [None]:
plot_training_history(
    [history_normal, history_qat],
    ['Normal Model', 'QAT Model']
)

## 13. Analyze Weight Distributions

In [None]:
def visualize_weight_distributions(model_normal, model_qat):
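    """
    Compare weight distributions between normal and QAT models.

    Note: model_normal was quantized in place by PTQ in Experiment 1, so its
    histograms show the dequantized 8-bit weights. The QAT histograms show the
    latent FP32 weights held by each QATLinear (the values the optimizer updates).
    """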
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    fig.suptitle('Weight Distributions: Normal vs QAT Models', fontsize=14, fontweight='bold')
    
    layers = ['fc1', 'fc2', 'fc3']
    
    for idx, layer_name in enumerate(layers):
        # Get weights
        normal_weights = getattr(model_normal, layer_name).weight.detach().cpu().numpy().flatten()
        qat_weights = getattr(model_qat, layer_name).linear.weight.detach().cpu().numpy().flatten()
        
        # Normal model
        axes[0, idx].hist(normal_weights, bins=50, alpha=0.7, color='blue', edgecolor='black')
        axes[0, idx].set_title(f'Normal Model - {layer_name.upper()}')
        axes[0, idx].set_xlabel('Weight Value')
        axes[0, idx].set_ylabel('Frequency')
        axes[0, idx].grid(True, alpha=0.3)
        
        # QAT model
        axes[1, idx].hist(qat_weights, bins=50, alpha=0.7, color='green', edgecolor='black')
        axes[1, idx].set_title(f'QAT Model - {layer_name.upper()}')
        axes[1, idx].set_xlabel('Weight Value')
        axes[1, idx].set_ylabel('Frequency')
        axes[1, idx].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

visualize_weight_distributions(model_normal, model_qat)

## 14. Model Size Comparison

In [None]:
def calculate_model_size(model):
    """
    Calculate model size in MB (assuming float32)
    """
    param_size = 0
    for param in model.parameters():
        param_size += param.numel() * 4  # 4 bytes per float32
    
    size_mb = param_size / (1024 * 1024)
    return size_mb

# Calculate sizes
size_fp32 = calculate_model_size(model_normal)
size_int8 = size_fp32 / 4  # assumes every parameter is stored as int8; ignores scale/zero-point overhead

print("\n" + "="*70)
print("MODEL SIZE COMPARISON")
print("="*70)
print(f"\n{'Precision':<20} {'Size (MB)':<15} {'Compression Ratio'}")
print("-" * 70)
print(f"{'FP32 (Original)':<20} {size_fp32:>8.4f}        {'1.0x (baseline)'}")
print(f"{'INT8 (Quantized)':<20} {size_int8:>8.4f}        {f'{size_fp32/size_int8:.1f}x smaller'}")
print("-" * 70)
print(f"\nüíæ Storage savings: {size_fp32 - size_int8:.4f} MB ({((size_fp32-size_int8)/size_fp32)*100:.1f}% reduction)")

## 15. Key Takeaways

### What is QAT?
- **Quantization-Aware Training** simulates low-precision inference during training
- Uses **fake quantization**: quantize in the forward pass, pass gradients straight through to the full-precision weights in the backward pass
- Allows model to learn to be robust to quantization noise

### QAT vs Post-Training Quantization (PTQ)
1. **PTQ**: Train in FP32 → Quantize afterwards → Can lose significant accuracy
2. **QAT**: Train with quantization simulation → Better accuracy retention

### When to Use QAT?
✅ **Use QAT when:**
- Deploying to edge devices (mobile, IoT)
- Aggressive quantization is needed (4-bit or lower)
- Accuracy is critical
- You have time and resources for retraining

✅ **Use PTQ when:**
- Quick quantization is needed
- The model is already robust to noise
- Training resources are limited
- 8-bit quantization is sufficient

### Implementation Details
1. **Straight-Through Estimator (STE)**: Pass gradients through non-differentiable quantization
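2. **Quantize weights and activations**: This notebook also fake-quantizes biases to 8 bits; integer inference kernels commonly keep biases in 32-bit integers instead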
3. **Scale and zero-point**: Map float range to integer range

### Benefits
- **Up to 4x smaller weights** (FP32 → INT8)
- **Faster inference** on hardware with efficient integer kernels
- **Lower power consumption**
- **Typically better accuracy** than PTQ, especially at low bit widths

---

## Further Reading
- [PyTorch Quantization Documentation](https://pytorch.org/docs/stable/quantization.html)
- [Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference](https://arxiv.org/abs/1712.05877)
- [A Survey of Quantization Methods for Efficient Neural Network Inference](https://arxiv.org/abs/2103.13630)