# Quantization for Fine-Tuning LLMs: A Complete Guide

## What is Quantization?

**Quantization** is the process of reducing the precision of a model's numerical values (its weights and, in some schemes, its activations) by representing them with fewer bits. Instead of 32-bit floating-point numbers (FP32), we can use 16-bit floats (FP16), 8-bit integers (INT8), or even 4-bit representations (INT4/NF4).
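
As a minimal sketch of the idea, the snippet below stores one weight at lower precision and reads it back. The INT8 scale assumes the weights lie in [-1, 1], which is an illustrative assumption rather than anything this notebook prescribes.

```python
import numpy as np

w = np.float32(0.3457)           # a single weight in FP32

# FP16: same floating-point idea, fewer exponent and mantissa bits
w_fp16 = np.float16(w)

# INT8: store an integer plus a shared scale (assumed range: [-1, 1])
scale = np.float32(1.0 / 127)
q = np.int8(np.round(w / scale))  # integer code, here 44
w_int8 = np.float32(q) * scale    # value recovered at read time

print(f"FP32 value : {w}")
print(f"FP16 value : {w_fp16}  (abs error {abs(float(w_fp16) - float(w)):.6f})")
print(f"INT8 code  : {q}")
print(f"INT8 value : {w_int8}  (abs error {abs(float(w_int8) - float(w)):.6f})")
```

The integer code is all that needs to be stored per weight; the scale is shared across many weights, which is where the memory saving comes from.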

## Why Does Quantization Matter for Fine-Tuning?

When fine-tuning Large Language Models (LLMs), we face three main challenges:

1. **Memory Constraints**: A 7B parameter model in FP32 requires ~28GB of VRAM just for the weights
2. **Training Overhead**: Fine-tuning requires storing optimizer states (2x model size for Adam), gradients (1x model size), and activations
3. **Hardware Limitations**: Most researchers don't have access to 80GB A100 GPUs

**Quantization enables fine-tuning on consumer hardware** by cutting the memory needed for weights by 75% (8-bit) or 87.5% (4-bit) relative to FP32, while largely preserving model quality.
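
The arithmetic behind these figures can be sketched as follows. It assumes 1 GB = 10^9 bytes, a 7B-parameter model, and the 2x/1x multipliers for optimizer states and gradients quoted above; activations are left out because they depend on batch size and sequence length.

```python
PARAMS = 7e9   # 7B parameters
GB = 1e9       # assumption: decimal gigabytes

def weights_gb(bits_per_param):
    """Memory for the weights alone at a given bit width."""
    return PARAMS * bits_per_param / 8 / GB

fp32_gb = weights_gb(32)
savings = {bits: 1 - weights_gb(bits) / fp32_gb for bits in (16, 8, 4)}

# Full FP32 fine-tuning: weights (1x) + gradients (1x) + optimizer states (2x)
full_finetune_gb = fp32_gb * (1 + 1 + 2)

print(f"FP32 weights: {fp32_gb:.1f} GB")
for bits, saving in savings.items():
    print(f"{bits:>2}-bit weights: {weights_gb(bits):.1f} GB ({saving:.1%} smaller)")
print(f"Full FP32 fine-tuning, before activations: {full_finetune_gb:.0f} GB")
```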

## What You'll Learn

This notebook demonstrates:

1. **Precision Formats**: How FP32, FP16, INT8, and INT4 represent numbers
2. **Memory Impact**: Quantifying the memory savings from quantization
3. **Quantization Methods**: Symmetric vs asymmetric quantization
4. **QLoRA & NF4**: State-of-the-art 4-bit quantization for fine-tuning
5. **Practical Trade-offs**: When to use each precision format

## 1. Precision Formats

In [ ]:
import numpy as np
import struct
import matplotlib.pyplot as plt
import seaborn as sns

# Set style for better visualizations
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 6)

# Example weight value from a neural network
weight_value = 0.3457

# Function to visualize bit representation
def float_to_bin(num, precision='float32'):
    """Convert float to binary representation"""
    if precision == 'float32':
        packed = struct.pack('!f', num)
        bits = ''.join(f'{byte:08b}' for byte in packed)
        return bits, 32
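    elif precision == 'float16':
        packed = struct.pack('!e', num)
        bits = ''.join(f'{byte:08b}' for byte in packed)
        return bits, 16
    # Fail loudly instead of silently returning None for unknown formats
    raise ValueError(f"Unsupported precision: {precision}")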

# Compare representations
fp32_bits, fp32_size = float_to_bin(weight_value, 'float32')
fp16_bits, fp16_size = float_to_bin(weight_value, 'float16')

print(f"Original value: {weight_value}")
print(f"\nFP32 (32-bit): {fp32_bits}")
print(f"  └─ Sign: {fp32_bits[0]} | Exponent: {fp32_bits[1:9]} | Mantissa: {fp32_bits[9:]}")
print(f"  └─ Memory: {fp32_size} bits = {fp32_size // 8} bytes")

print(f"\nFP16 (16-bit): {fp16_bits}")
print(f"  └─ Sign: {fp16_bits[0]} | Exponent: {fp16_bits[1:6]} | Mantissa: {fp16_bits[6:]}")
print(f"  └─ Memory: {fp16_size} bits = {fp16_size // 8} bytes")

# Show precision comparison
print(f"\n{'Format':<10} {'Bits':<8} {'Range':<30} {'Precision':<15}")
print("-" * 70)
print(f"{'FP32':<10} {'32':<8} {'±3.4e38':<30} {'~7 digits':<15}")
print(f"{'FP16':<10} {'16':<8} {'±6.5e4':<30} {'~3 digits':<15}")
print(f"{'INT8':<10} {'8':<8} {'-128 to 127':<30} {'Integer only':<15}")
print(f"{'INT4':<10} {'4':<8} {'-8 to 7':<30} {'Integer only':<15}")

## 2. Memory Impact

In [ ]:
# Calculate memory for a 7B parameter LLM
model_size = 7_000_000_000  # 7 billion parameters

# Memory in different formats (in GB, using 1 GB = 10**9 bytes)
fp32_mem = (model_size * 4) / 1e9  # 4 bytes per parameter
fp16_mem = (model_size * 2) / 1e9  # 2 bytes per parameter
int8_mem = (model_size * 1) / 1e9  # 1 byte per parameter
int4_mem = (model_size * 0.5) / 1e9  # 0.5 bytes per parameter

print("Memory Requirements for 7B Parameter Model:")
print(f"  FP32:  {fp32_mem:.2f} GB")
print(f"  FP16:  {fp16_mem:.2f} GB ({fp32_mem/fp16_mem:.1f}x smaller)")
print(f"  INT8:  {int8_mem:.2f} GB ({fp32_mem/int8_mem:.1f}x smaller)")
print(f"  INT4:  {int4_mem:.2f} GB ({fp32_mem/int4_mem:.1f}x smaller)")

# Visualize
formats = ['FP32', 'FP16', 'INT8', 'INT4']
memory = [fp32_mem, fp16_mem, int8_mem, int4_mem]

plt.figure(figsize=(10, 6))
bars = plt.bar(formats, memory, color=['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4'])
plt.ylabel('Memory (GB)', fontsize=12)
plt.title('Memory Requirements for 7B LLM in Different Formats', fontsize=14, fontweight='bold')
plt.grid(axis='y', alpha=0.3)

for bar, mem in zip(bars, memory):
    plt.text(bar.get_x() + bar.get_width()/2., bar.get_height(),
            f'{mem:.1f} GB', ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.show()

## 3. Quantization Methods: Symmetric vs Asymmetric

In [ ]:
# Generate sample weights from a normal distribution (like real neural network weights)
np.random.seed(42)
weights = np.random.randn(10000) * 0.5  # Mean=0, std=0.5

def symmetric_quantization(weights, n_bits=8):
    """Symmetric quantization using absolute maximum (supports n_bits <= 8)"""
    # Calculate scale based on maximum absolute value
    qmax = 2**(n_bits-1) - 1
    max_val = np.max(np.abs(weights))
    scale = max_val / qmax
    
    # Quantize, clipping to the representable range before casting to int8
    quantized = np.clip(np.round(weights / scale), -qmax, qmax).astype(np.int8)
    
    # Dequantize back
    dequantized = quantized * scale
    
    return quantized, dequantized, scale

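def asymmetric_quantization(weights, n_bits=8):
    """Asymmetric quantization with zero-point (supports n_bits <= 8)"""
    min_val = np.min(weights)
    max_val = np.max(weights)
    
    # Calculate scale and zero-point (rounded, not truncated)
    qmax = 2**n_bits - 1
    scale = (max_val - min_val) / qmax
    zero_point = int(np.round(-min_val / scale))
    
    # Quantize, clipping to [0, qmax] so the uint8 cast cannot wrap around
    quantized = np.clip(np.round(weights / scale + zero_point), 0, qmax).astype(np.uint8)
    
    # Dequantize; cast to a signed type first so the subtraction cannot underflow
    dequantized = (quantized.astype(np.int32) - zero_point) * scale
    
    return quantized, dequantized, scale, zero_point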

# Apply both methods
sym_quant, sym_dequant, sym_scale = symmetric_quantization(weights)
asym_quant, asym_dequant, asym_scale, asym_zero = asymmetric_quantization(weights)

# Calculate errors
sym_error = np.mean(np.abs(weights - sym_dequant))
asym_error = np.mean(np.abs(weights - asym_dequant))

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Original distribution
axes[0].hist(weights, bins=50, alpha=0.7, color='blue', edgecolor='black')
axes[0].set_title('Original Weights Distribution', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Weight Value')
axes[0].set_ylabel('Frequency')
axes[0].grid(alpha=0.3)

# Symmetric quantization
axes[1].scatter(weights[:500], sym_dequant[:500], alpha=0.5, s=10, color='green')
axes[1].plot([-2, 2], [-2, 2], 'r--', linewidth=2, label='Perfect reconstruction')
axes[1].set_title(f'Symmetric Quantization\nError: {sym_error:.6f}', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Original Weight')
axes[1].set_ylabel('Dequantized Weight')
axes[1].legend()
axes[1].grid(alpha=0.3)

# Asymmetric quantization
axes[2].scatter(weights[:500], asym_dequant[:500], alpha=0.5, s=10, color='orange')
axes[2].plot([-2, 2], [-2, 2], 'r--', linewidth=2, label='Perfect reconstruction')
axes[2].set_title(f'Asymmetric Quantization\nError: {asym_error:.6f}', fontsize=14, fontweight='bold')
axes[2].set_xlabel('Original Weight')
axes[2].set_ylabel('Dequantized Weight')
axes[2].legend()
axes[2].grid(alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Symmetric quantization error: {sym_error:.6f}")
print(f"Asymmetric quantization error: {asym_error:.6f}")

## 4. QLoRA: 4-bit Quantization for Fine-Tuning

**QLoRA (Quantized Low-Rank Adaptation)** enables fine-tuning large language models on consumer GPUs.

### Key Concepts:

1. **4-bit NormalFloat (NF4)**: A data type designed for neural network weights, which are approximately normally distributed
2. **Frozen base model**: The original model is stored in 4-bit and is not updated during training
3. **Trainable adapters**: Small LoRA layers (kept in 16-bit precision, typically BF16) are added and trained
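
The interaction between the frozen base weights and the trainable adapters can be sketched in NumPy. This is an illustrative toy, not the PEFT implementation: the shapes, the rank `r`, and the scaling factor `alpha` are assumed values, `W_frozen` stands in for the dequantized 4-bit weight, and the zero initialization of `B` follows the usual LoRA convention.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 64, 32, 4, 8   # assumed toy dimensions

# Frozen base weight (stand-in for a dequantized 4-bit matrix); never updated
W_frozen = rng.normal(0, 0.5, size=(d_out, d_in))

# Trainable low-rank adapters: B starts at zero so training begins from the base model
A = rng.normal(0, 0.01, size=(r, d_in))
B = np.zeros((d_out, r))

def lora_forward(x, W, A, B, alpha, r):
    """Base projection plus the scaled low-rank update (alpha / r) * x A^T B^T."""
    base = x @ W.T
    update = (x @ A.T) @ B.T * (alpha / r)
    return base + update

x = rng.normal(size=(2, d_in))
y = lora_forward(x, W_frozen, A, B, alpha, r)

trainable = A.size + B.size
frozen = W_frozen.size
print(f"Output shape: {y.shape}")
print(f"Trainable parameters: {trainable} vs frozen: {frozen}")
```

Only `A` and `B` would receive gradients, so optimizer state is needed for a small fraction of the parameters.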

### Benefits:
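- Fine-tune a 65B model on a single 48GB GPU (the headline result of the QLoRA paper)
- Roughly 4x less memory for the base-model weights than 16-bit fine-tuning
- Matches 16-bit fine-tuning performance on the benchmarks reported in the QLoRA paper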
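
A rough way to check the weight-memory claim is to count the bits each parameter actually costs once the per-block scale is included. The sketch below assumes one FP32 scale per block of 64 weights and a 7B-parameter model with 1 GB = 10^9 bytes; `nf4_bits_per_param` is a hypothetical helper name.

```python
def nf4_bits_per_param(block_size=64, scale_bits=32):
    """4 bits for the code plus the block's scale amortized over its weights."""
    return 4 + scale_bits / block_size

params = 7e9
bits = nf4_bits_per_param()

nf4_gb = params * bits / 8 / 1e9
fp16_gb = params * 16 / 8 / 1e9
ratio = fp16_gb / nf4_gb

print(f"Effective bits per parameter: {bits}")
print(f"NF4 weights:  {nf4_gb:.2f} GB")
print(f"FP16 weights: {fp16_gb:.2f} GB")
print(f"Reduction vs 16-bit: {ratio:.2f}x")
```

Under these assumptions the reduction comes out somewhat below 4x because the scales add half a bit per weight.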

In [ ]:
# NF4 quantization levels (16 values optimized for normal distribution)
# These values are hardcoded based on quantiles of N(0,1) distribution
NF4_VALUES = np.array([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0
])

def nf4_quantization(weights, block_size=64):
    """
    Implement NF4 quantization similar to QLoRA
    
    Steps:
    1. Split weights into blocks
    2. For each block: normalize to [-1, 1]
    3. Map to closest NF4 value
    4. Store block scale factors
    """
    # Reshape into blocks (trailing weights that do not fill a whole block are dropped)
    n_blocks = len(weights) // block_size
    weights_blocked = weights[:n_blocks * block_size].reshape(n_blocks, block_size)
    
    quantized_blocks = []
    scale_factors = []
    
    for block in weights_blocked:
        # Calculate absmax for this block
        absmax = np.max(np.abs(block))
        scale_factors.append(absmax)
        
        # Normalize to [-1, 1]
        normalized = block / (absmax + 1e-8)  # Add small epsilon to avoid division by zero
        
        # Find the index of the closest NF4 value for every weight in the block
        # (broadcast to a block_size x 16 distance table, then take the row-wise argmin)
        distances = np.abs(normalized[:, None] - NF4_VALUES)
        quantized = np.argmin(distances, axis=1).astype(np.int8)
        
        quantized_blocks.append(quantized)
    
    return np.array(quantized_blocks).flatten(), np.array(scale_factors)

def nf4_dequantization(quantized, scale_factors, block_size=64):
    """Dequantize NF4 values back to original range"""
    n_blocks = len(scale_factors)
    quantized_blocked = quantized.reshape(n_blocks, block_size)
    
    dequantized = []
    for block_quant, scale in zip(quantized_blocked, scale_factors):
        # Map indices back to NF4 values
        block_nf4 = NF4_VALUES[block_quant]
        # Scale back to original range
        block_original = block_nf4 * scale
        dequantized.extend(block_original)
    
    return np.array(dequantized)

# Use the weights from the previous cell (generated in cell-3)
weights_sample = weights[:1024]  # Use 1024 weights for clean blocks
nf4_quant, nf4_scales = nf4_quantization(weights_sample, block_size=64)
nf4_dequant = nf4_dequantization(nf4_quant, nf4_scales, block_size=64)

# Calculate error and compression
nf4_error = np.mean(np.abs(weights_sample - nf4_dequant))
original_size = weights_sample.size * 4  # FP32 = 4 bytes per weight (the NumPy array itself is float64)
nf4_size = len(nf4_quant) * 0.5 + len(nf4_scales) * 4  # 4 bits per weight + FP32 scales
compression_ratio = original_size / nf4_size

# Visualize NF4 quantization
fig, axes = plt.subplots(2, 2, figsize=(16, 10))

# NF4 quantization levels
axes[0, 0].stem(range(len(NF4_VALUES)), NF4_VALUES, basefmt=' ')
axes[0, 0].axhline(y=0, color='red', linestyle='--', linewidth=2)
axes[0, 0].set_title('NF4 Quantization Levels (16 values)', fontsize=14, fontweight='bold')
axes[0, 0].set_xlabel('Index')
axes[0, 0].set_ylabel('Value')
axes[0, 0].grid(alpha=0.3)

# Weight distribution with NF4 levels
axes[0, 1].hist(weights_sample / np.max(np.abs(weights_sample)), bins=50, alpha=0.6, color='blue', label='Normalized weights', density=True)
for val in NF4_VALUES:
    axes[0, 1].axvline(x=val, color='red', alpha=0.5, linewidth=1)
axes[0, 1].set_title('Weight Distribution vs NF4 Levels', fontsize=14, fontweight='bold')
axes[0, 1].set_xlabel('Normalized Weight Value')
axes[0, 1].set_ylabel('Density')
axes[0, 1].legend()

# Original vs Dequantized
axes[1, 0].scatter(weights_sample[:200], nf4_dequant[:200], alpha=0.5, s=15, color='purple')
axes[1, 0].plot([-2, 2], [-2, 2], 'r--', linewidth=2, label='Perfect reconstruction')
axes[1, 0].set_title(f'NF4 Quantization Quality\nError: {nf4_error:.6f}', fontsize=14, fontweight='bold')
axes[1, 0].set_xlabel('Original Weight')
axes[1, 0].set_ylabel('Dequantized Weight')
axes[1, 0].legend()
axes[1, 0].grid(alpha=0.3)

# Compression comparison
methods = ['FP32', 'INT8', 'NF4']
sizes = [original_size, original_size / 4, nf4_size]
colors_comp = ['#FF6B6B', '#4ECDC4', '#96CEB4']

bars = axes[1, 1].bar(methods, sizes, color=colors_comp, edgecolor='black', linewidth=2)
axes[1, 1].set_title('Storage Size Comparison', fontsize=14, fontweight='bold')
axes[1, 1].set_ylabel('Size (bytes)')
axes[1, 1].grid(axis='y', alpha=0.3)

for bar, size in zip(bars, sizes):
    axes[1, 1].text(bar.get_x() + bar.get_width()/2., bar.get_height(),
                    f'{size:.0f}B\n({original_size/size:.1f}x)', 
                    ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.show()

print(f"\nNF4 Quantization Results:")
print(f"  Original size: {original_size} bytes (FP32)")
print(f"  NF4 size: {nf4_size:.0f} bytes")
print(f"  Compression ratio: {compression_ratio:.2f}x")
print(f"  Mean absolute error: {nf4_error:.6f}")
print(f"\n💡 NF4 achieves ~{compression_ratio:.1f}x compression with minimal accuracy loss!")

## 5. When to Use Each Precision Format

| Format | Best For | Memory Savings (vs FP32) | Quality | Fine-tuning? |
|--------|----------|--------------------------|---------|--------------|
| **FP32** | Research, maximum accuracy | Baseline | ⭐⭐⭐⭐⭐ | ✅ Full precision |
| **FP16/BF16** | Standard training | 2x | ⭐⭐⭐⭐⭐ | ✅ Mixed precision |
| **INT8** | Mainly inference | 4x | ⭐⭐⭐⭐ | ⚠️ Only as a frozen base with adapters |
| **INT4/NF4** | Fine-tuning on limited GPUs | 8x | ⭐⭐⭐⭐ | ✅ With QLoRA |

### Quick Guide:

✅ **For inference only**: Use INT8 quantization  
✅ **For fine-tuning with limited GPU**: Use QLoRA with NF4  
✅ **For maximum accuracy**: Use FP32 or FP16/BF16  
✅ **For production deployment**: Use INT8 or INT4
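
As a rough aid for applying this guide, the sketch below lists which formats let the weights alone fit in a given amount of VRAM. `formats_that_fit` is a hypothetical helper, 1 GB is taken as 10^9 bytes, and the check ignores activations, gradients, optimizer states, and quantization scales, so treat a "fit" as a lower bound.

```python
# Bytes per parameter for each format in the table above
BYTES_PER_PARAM = {"FP32": 4.0, "FP16/BF16": 2.0, "INT8": 1.0, "INT4/NF4": 0.5}

def formats_that_fit(n_params, vram_gb):
    """Return the formats whose weights alone fit in `vram_gb` (1 GB = 1e9 bytes)."""
    return [
        name
        for name, nbytes in BYTES_PER_PARAM.items()
        if n_params * nbytes / 1e9 <= vram_gb
    ]

# A 7B model on a 24 GB consumer GPU
print(formats_that_fit(7e9, 24))   # ['FP16/BF16', 'INT8', 'INT4/NF4']
```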

## 6. Key Takeaways

### What You Learned:

1. **Quantization reduces memory** by using fewer bits (FP32 → FP16 → INT8 → INT4)
2. **Trade-off**: Lower precision = less memory but some accuracy loss
3. **NF4 is special**: Optimized for neural network weights (normally distributed)
4. **QLoRA enables fine-tuning** on consumer GPUs with 4-bit quantization

### Best Practices:

✅ Use **4-bit NF4 with QLoRA** for fine-tuning large models on limited hardware  
✅ Use a **block size of 64** for a good balance between compression and accuracy  
✅ Monitor quantization error relative to weight magnitude; treat a fixed target such as 1% as a rough guide, since 4-bit formats will typically exceed it  
✅ Test on your specific task - impact varies by use case
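
One way to monitor quantization error relative to weight magnitude is sketched below. It uses plain absmax quantization on synthetic normal weights, so the helper names and the choice of metric (mean absolute error divided by mean absolute weight) are assumptions for illustration rather than a standard definition.

```python
import numpy as np

def absmax_quantize(weights, n_bits):
    """Symmetric absmax quantization followed by dequantization."""
    qmax = 2 ** (n_bits - 1) - 1
    scale = np.max(np.abs(weights)) / qmax
    codes = np.clip(np.round(weights / scale), -qmax, qmax)
    return codes * scale

def relative_error(weights, dequantized):
    """Mean absolute error as a fraction of the mean absolute weight."""
    return np.mean(np.abs(weights - dequantized)) / np.mean(np.abs(weights))

rng = np.random.default_rng(42)
weights = rng.normal(0, 0.5, size=10_000)   # synthetic stand-in for real weights

err8 = relative_error(weights, absmax_quantize(weights, 8))
err4 = relative_error(weights, absmax_quantize(weights, 4))

print(f"8-bit relative error: {err8:.2%}")
print(f"4-bit relative error: {err4:.2%}")
```

Running the same metric on a model's real weight tensors, layer by layer, gives a quick signal of which layers are most affected before any task-level evaluation.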

### Resources:

- **bitsandbytes**: Python library for 4-bit/8-bit quantization
- **PEFT (Hugging Face)**: LoRA adapter library, combined with bitsandbytes for QLoRA fine-tuning
- **QLoRA paper**: https://arxiv.org/abs/2305.14314
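
For orientation, a typical QLoRA setup with these libraries might look like the sketch below. The hyperparameters (`r=16`, `lora_alpha=32`, `lora_dropout=0.05`) and the `target_modules` names are illustrative assumptions that depend on the model architecture, and `build_qlora_configs` is a hypothetical helper, not part of either library.

```python
def build_qlora_configs():
    """Return (4-bit quantization config, LoRA config) for a QLoRA-style setup."""
    # Imported lazily so the sketch can be read without the libraries installed
    import torch
    from transformers import BitsAndBytesConfig
    from peft import LoraConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,                      # store base weights in 4-bit
        bnb_4bit_quant_type="nf4",              # use the NormalFloat4 data type
        bnb_4bit_compute_dtype=torch.bfloat16,  # run matmuls in BF16
    )
    lora_config = LoraConfig(
        r=16,                                   # assumed adapter rank
        lora_alpha=32,                          # assumed scaling factor
        lora_dropout=0.05,                      # assumed dropout
        target_modules=["q_proj", "v_proj"],    # assumption: names vary by model
        task_type="CAUSAL_LM",
    )
    return bnb_config, lora_config
```

The first object would be passed as `quantization_config` to `AutoModelForCausalLM.from_pretrained`, and the second to `peft.get_peft_model` once the model is loaded.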

---

**You can now fine-tune massive LLMs on accessible hardware!** 🚀