# Introduction to Weight Quantization
## Reducing Large Language Model Size with 8-bit Quantization

This notebook is a hands-on introduction to weight quantization techniques for Large Language Models (LLMs), using GPT-2 as the running example.

## Overview
- **Goal**: Understand fundamental quantization techniques
- **Methods Covered**: Absmax, Zero-point, and LLM.int8()
- **Benefits**: Reduced memory footprint; faster inference where hardware and kernels support low-precision arithmetic
- **Use Case**: Educational and practical applications

## What is Quantization?
Quantization reduces the precision of model weights from high-precision formats (e.g., FP32, FP16) to lower-precision formats (e.g., INT8), significantly reducing model size while maintaining acceptable performance.
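
To make the size difference concrete, here is a small sketch. It reads the per-element storage cost of each format from PyTorch and applies it to a hypothetical 124M-parameter model. The estimate covers weights only and ignores buffers and runtime memory.

```python
import torch

# Storage cost per element, read from PyTorch rather than hard-coded
dtypes = {"FP32": torch.float32, "FP16": torch.float16, "INT8": torch.int8}
bytes_per_param = {
    name: torch.tensor([], dtype=dt).element_size() for name, dt in dtypes.items()
}

# INT8 can only represent 256 distinct integers
int8_info = torch.iinfo(torch.int8)
print(f"INT8 range: [{int8_info.min}, {int8_info.max}]")

# Weight memory for a hypothetical 124M-parameter model (weights only)
n_params = 124_000_000
sizes_mb = {name: n_params * b / 1e6 for name, b in bytes_per_param.items()}
for name, mb in sizes_mb.items():
    print(f"{name}: {bytes_per_param[name]} byte(s)/param -> {mb:.0f} MB")
```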


## Part 1: Basic Quantization Functions

We'll implement two fundamental quantization methods:
1. **Absmax Quantization**: Symmetric quantization using absolute maximum
2. **Zero-point Quantization**: Asymmetric quantization with zero-point offset
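
The two functions in the next cell implement the formulas below, written to match the code term by term. $\max|X|$, $\max X$ and $\min X$ are taken over the whole tensor.

Absmax:

$$
\text{scale} = \frac{127}{\max|X|}, \qquad
X_{\text{quant}} = \operatorname{round}(\text{scale}\cdot X), \qquad
X_{\text{dequant}} = \frac{X_{\text{quant}}}{\text{scale}}
$$

Zero-point:

$$
\text{scale} = \frac{255}{\max X - \min X}, \qquad
\text{zeropoint} = \operatorname{round}(-\text{scale}\cdot\min X - 128)
$$

$$
X_{\text{quant}} = \operatorname{clip}\big(\operatorname{round}(\text{scale}\cdot X + \text{zeropoint}),\,-128,\,127\big), \qquad
X_{\text{dequant}} = \frac{X_{\text{quant}} - \text{zeropoint}}{\text{scale}}
$$

Absmax maps $0$ exactly to $0$ and uses the symmetric range $[-127, 127]$. Zero-point shifts the grid so that $\min X$ lands near $-128$ and $\max X$ near $127$, which uses all 256 codes even when the data is one-sided.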


In [None]:
import torch

def absmax_quantize(X):
    """
    Absmax quantization: Symmetric quantization method
    Maps values to [-127, 127] range using absolute maximum
    """
    # Calculate scale factor
    scale = 127 / torch.max(torch.abs(X))
    
    # Quantize: scale and round
    X_quant = (scale * X).round()
    
    # Dequantize: reverse the scaling
    X_dequant = X_quant / scale
    
    return X_quant.to(torch.int8), X_dequant


def zeropoint_quantize(X):
    """
    Zero-point quantization: Asymmetric quantization method
    Maps values to [-128, 127] range with zero-point offset
    """
    # Calculate value range
    x_range = torch.max(X) - torch.min(X)
    x_range = 1 if x_range == 0 else x_range
    
    # Calculate scale factor
    scale = 255 / x_range
    
    # Calculate zero-point offset
    zeropoint = (-scale * torch.min(X) - 128).round()
    
    # Quantize: scale, shift, and clip
    X_quant = torch.clip((X * scale + zeropoint).round(), -128, 127)
    
    # Dequantize: reverse shift and scaling
    X_dequant = (X_quant - zeropoint) / scale
    
    return X_quant.to(torch.int8), X_dequant

print("✓ Quantization functions defined")
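
Absmax quantization is sensitive to outliers. The scale is set by the single largest magnitude, so one extreme value stretches the grid and pushes small values to zero. Below is a minimal sketch on synthetic data. The `absmax_roundtrip` helper is a hypothetical restatement of the same absmax formula, with a clamp added for safety.

```python
import torch

def absmax_roundtrip(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: absmax-quantize x to int8, then dequantize to float."""
    scale = 127 / x.abs().max()
    x_q = (x * scale).round().clamp(-127, 127).to(torch.int8)
    return x_q.float() / scale

torch.manual_seed(0)
x = torch.randn(1000) * 0.02   # small, roughly Gaussian "weights"
x_outlier = x.clone()
x_outlier[0] = 10.0            # inject a single large outlier

# Mean absolute round-trip error on the ordinary values (outlier excluded)
err_plain = (x[1:] - absmax_roundtrip(x)[1:]).abs().mean().item()
deq_outlier = absmax_roundtrip(x_outlier)
err_outlier = (x_outlier[1:] - deq_outlier[1:]).abs().mean().item()
frac_zero = (deq_outlier[1:] == 0).float().mean().item()

print(f"mean |error| without outlier: {err_plain:.2e}")
print(f"mean |error| with outlier:    {err_outlier:.2e}")
print(f"values collapsed to 0 with outlier: {frac_zero:.0%}")
```

This outlier problem is what LLM.int8(), covered later in the notebook, is designed to address.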


## Part 2: Install Dependencies

Install required libraries for model quantization.


In [None]:
# Install required packages
!pip install -q "bitsandbytes>=0.39.0"  # quotes stop the shell from treating >= as a redirect
!pip install -q git+https://github.com/huggingface/accelerate.git
!pip install -q git+https://github.com/huggingface/transformers.git


## Part 3: Load Model and Extract Weights

Load the GPT-2 model (the 124M-parameter `gpt2` checkpoint) and its tokenizer, then report the model's memory footprint in FP32.


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
torch.manual_seed(0)

# Set device
device = 'cpu'

# Load model and tokenizer
model_id = 'gpt2'
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Print model size
print(f"Model size: {model.get_memory_footprint():,} bytes")
print(f"Model size: {model.get_memory_footprint() / 1e6:.2f} MB")
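
`get_memory_footprint()` roughly reports the bytes held by the model's parameters, plus its buffers by default. The sketch below shows the parameter part. It uses a toy module instead of GPT-2 so it runs without downloading weights, and `param_bytes` is a hypothetical helper.

```python
import torch.nn as nn

def param_bytes(module: nn.Module) -> int:
    """Hypothetical helper: total bytes of a module's parameters (numel * bytes per element)."""
    return sum(p.numel() * p.element_size() for p in module.parameters())

# A toy two-layer MLP with GPT-2-like widths stands in for a real model
toy = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768))
fp32_bytes = param_bytes(toy)
fp16_bytes = param_bytes(toy.half())  # .half() converts the parameters in place

print(f"FP32: {fp32_bytes:,} bytes")
print(f"FP16: {fp16_bytes:,} bytes")
```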


## Part 4: Apply Quantization to Model Weights

Extract the weight matrix of the fused query/key/value projection (`c_attn`) in the first transformer block and apply both quantization methods.


In [None]:
# Fused query/key/value projection weights of the first transformer block
weights = model.transformer.h[0].attn.c_attn.weight.data
print("Original weights:")
print(weights)
print(f"\nShape: {weights.shape}")
print(f"Data type: {weights.dtype}")

# Apply absmax quantization
weights_abs_quant, weights_abs_dequant = absmax_quantize(weights)
print("\n" + "="*50)
print("Absmax quantized weights:")
print(weights_abs_quant)
print(f"Data type: {weights_abs_quant.dtype}")

# Apply zero-point quantization
weights_zp_quant, weights_zp_dequant = zeropoint_quantize(weights)
print("\n" + "="*50)
print("Zero-point quantized weights:")
print(weights_zp_quant)
print(f"Data type: {weights_zp_quant.dtype}")
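
Weight matrices like this one tend to be roughly centered on zero, where the two methods behave similarly. Zero-point quantization pays off on one-sided data. The sketch below uses synthetic, all-positive values. The helper names are hypothetical, and the formulas match the two functions defined earlier.

```python
import torch

def absmax_error(x: torch.Tensor) -> float:
    """Mean absolute round-trip error of absmax quantization."""
    scale = 127 / x.abs().max()
    x_deq = (x * scale).round() / scale
    return (x - x_deq).abs().mean().item()

def zeropoint_error(x: torch.Tensor) -> float:
    """Mean absolute round-trip error of zero-point quantization."""
    scale = 255 / (x.max() - x.min())
    zp = (-scale * x.min() - 128).round()
    x_q = torch.clip((x * scale + zp).round(), -128, 127)
    x_deq = (x_q - zp) / scale
    return (x - x_deq).abs().mean().item()

torch.manual_seed(0)
x = torch.rand(10_000)  # all-positive values in [0, 1): an asymmetric distribution

err_abs = absmax_error(x)
err_zp = zeropoint_error(x)
print(f"absmax mean |error|:     {err_abs:.2e}")
print(f"zero-point mean |error|: {err_zp:.2e}")
```

Here absmax spends half of its codes on negative values that never occur. Its step size is therefore roughly twice that of zero-point.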


## Part 5: Quantize Entire Model

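Apply each method to every parameter tensor, using one scale per tensor. Each tensor is immediately dequantized back to FP32. This simulates the rounding error of 8-bit storage without actually shrinking the model in memory.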


In [None]:
import numpy as np
from copy import deepcopy

# Store original weights
weights = [param.data.clone() for param in model.parameters()]

# Create model with absmax quantization
model_abs = deepcopy(model)
weights_abs = []
print("Applying absmax quantization to all layers...")
for param in model_abs.parameters():
    _, dequantized = absmax_quantize(param.data)
    param.data = dequantized
    weights_abs.append(dequantized)
print("✓ Absmax quantization complete")

# Create model with zero-point quantization
model_zp = deepcopy(model)
weights_zp = []
print("Applying zero-point quantization to all layers...")
for param in model_zp.parameters():
    _, dequantized = zeropoint_quantize(param.data)
    param.data = dequantized
    weights_zp.append(dequantized)
print("✓ Zero-point quantization complete")
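
The loop above uses a single scale per parameter tensor, so one large value anywhere in a matrix coarsens every other entry. A common refinement is to use one scale per row or column, which is the vector-wise idea that LLM.int8() builds on. Below is a minimal sketch comparing the two granularities on a synthetic matrix. `absmax_fake_quant` is a hypothetical helper.

```python
import torch

def absmax_fake_quant(x: torch.Tensor, dim=None) -> torch.Tensor:
    """Hypothetical helper: absmax quantize then dequantize, with one scale for the
    whole tensor (dim=None) or one scale per slice along `dim` (vector-wise)."""
    if dim is None:
        max_abs = x.abs().max()
    else:
        max_abs = x.abs().amax(dim=dim, keepdim=True)
    scale = 127 / max_abs
    return (x * scale).round() / scale

torch.manual_seed(0)
w = torch.randn(8, 512)
w[0] *= 100  # one row with much larger magnitude than the others

# Error measured on the seven ordinary rows only
err_tensor = (w[1:] - absmax_fake_quant(w)[1:]).abs().mean().item()
err_row = (w[1:] - absmax_fake_quant(w, dim=1)[1:]).abs().mean().item()
print(f"per-tensor scale: mean |error| = {err_tensor:.2e}")
print(f"per-row scales:   mean |error| = {err_row:.2e}")
```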


## Part 6: Visualize Weight Distributions

Compare the distributions of original and quantized weights.


In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

# Flatten weight tensors for visualization
weights_flat = np.concatenate([t.cpu().numpy().flatten() for t in weights])
weights_abs_flat = np.concatenate([t.cpu().numpy().flatten() for t in weights_abs])
weights_zp_flat = np.concatenate([t.cpu().numpy().flatten() for t in weights_zp])

# Set plot style
plt.style.use('ggplot')

# Create figure with two subplots
fig, axs = plt.subplots(2, figsize=(10, 10), dpi=300, sharex=True)

# Plot 1: Original vs Absmax
axs[0].hist(weights_flat, bins=150, alpha=0.5, label='Original weights', 
            color='blue', range=(-2, 2))
axs[0].hist(weights_abs_flat, bins=150, alpha=0.5, label='Absmax weights', 
            color='red', range=(-2, 2))

# Plot 2: Original vs Zero-point
axs[1].hist(weights_flat, bins=150, alpha=0.5, label='Original weights', 
            color='blue', range=(-2, 2))
axs[1].hist(weights_zp_flat, bins=150, alpha=0.5, label='Zero-point weights', 
            color='green', range=(-2, 2))

# Add grid and legends
for ax in axs:
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.legend()
    ax.set_ylabel('Count', fontsize=14)
    ax.yaxis.set_major_formatter(ticker.EngFormatter())

# Add titles and labels
axs[0].set_title('Comparison of Original and Absmax Quantized Weights', fontsize=16)
axs[1].set_title('Comparison of Original and Zero-point Quantized Weights', fontsize=16)
axs[1].set_xlabel('Weight Values', fontsize=14)

plt.rc('font', size=12)
plt.tight_layout()
plt.show()

print(f"Total weights visualized: {len(weights_flat):,}")


## Part 7: Test Generation Quality

Generate text from the same prompt with the original and quantized models. Generation uses top-k sampling, so outputs vary between runs; treat this as a qualitative check.


In [None]:
def generate_text(model, input_text, max_length=50):
    """Generate text using the model"""
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
    output = model.generate(
        inputs=input_ids,
        max_length=max_length,
        do_sample=True,
        top_k=30,
        pad_token_id=tokenizer.eos_token_id,
        attention_mask=input_ids.new_ones(input_ids.shape)
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Generate text with each model
prompt = "I have a dream"
print(f"Prompt: '{prompt}'\n")
print("="*70)

original_text = generate_text(model, prompt)
print("Original model:")
print(original_text)
print("\n" + "="*70)

absmax_text = generate_text(model_abs, prompt)
print("Absmax quantized model:")
print(absmax_text)
print("\n" + "="*70)

zp_text = generate_text(model_zp, prompt)
print("Zero-point quantized model:")
print(zp_text)
print("="*70)
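
With `do_sample=True, top_k=30`, each new token is drawn at random from the 30 most likely candidates, with their probabilities renormalized. Below is a minimal sketch of that sampling step on fake logits. The `sample_top_k` helper is hypothetical, not the Hugging Face implementation.

```python
import torch

def sample_top_k(logits: torch.Tensor, k: int, generator=None) -> int:
    """Hypothetical helper: sample one token id from the k highest logits (1-D tensor)."""
    top_values, top_indices = torch.topk(logits, k)
    probs = torch.softmax(top_values, dim=-1)  # renormalize over the top-k only
    choice = torch.multinomial(probs, num_samples=1, generator=generator)
    return top_indices[choice].item()

gen = torch.Generator().manual_seed(0)
logits = torch.randn(50257, generator=gen)  # fake next-token logits, GPT-2-sized vocabulary
token_id = sample_top_k(logits, k=30, generator=gen)
allowed = set(torch.topk(logits, 30).indices.tolist())
print(token_id, token_id in allowed)
```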


## Part 8: Calculate Perplexity

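Evaluate each model with perplexity: the exponential of the average negative log-likelihood the model assigns to each next token. All three models are scored on the same text (the original model's generated output), which can slightly favor the original model. With a single short sample, small differences are likely noise.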


In [None]:
def calculate_perplexity(model, text):
    """Calculate perplexity of generated text"""
    # Encode the text
    encodings = tokenizer(text, return_tensors='pt').to(device)
    
    # Define input and target ids
    input_ids = encodings.input_ids
    target_ids = input_ids.clone()
    
    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)
    
    # Calculate perplexity from loss
    neg_log_likelihood = outputs.loss
    ppl = torch.exp(neg_log_likelihood)
    
    return ppl

# Calculate perplexity for each model
ppl = calculate_perplexity(model, original_text)
ppl_abs = calculate_perplexity(model_abs, original_text)
ppl_zp = calculate_perplexity(model_zp, original_text)

print("Perplexity Comparison:")
print("="*50)
print(f"Original model:       {ppl.item():.2f}")
print(f"Absmax quantized:     {ppl_abs.item():.2f}")
print(f"Zero-point quantized: {ppl_zp.item():.2f}")
print("="*50)
print("\nNote: Lower perplexity indicates better model quality")
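
Passing `labels=target_ids` makes the Hugging Face causal-LM head compute the mean next-token cross-entropy, shifting the labels by one position internally. Perplexity is the exponential of that mean. Below is a self-contained sketch, with random logits standing in for model outputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, seq_len = 100, 12
logits = torch.randn(1, seq_len, vocab_size)            # stand-in model outputs
input_ids = torch.randint(0, vocab_size, (1, seq_len))  # stand-in token ids

# Position t predicts token t+1: drop the last logit and the first label
shift_logits = logits[:, :-1, :]
shift_labels = input_ids[:, 1:]
nll = F.cross_entropy(shift_logits.reshape(-1, vocab_size), shift_labels.reshape(-1))
ppl = torch.exp(nll)

# Same quantity computed by hand from per-token log-probabilities
log_probs = torch.log_softmax(shift_logits, dim=-1)
token_log_probs = log_probs.gather(-1, shift_labels.unsqueeze(-1)).squeeze(-1)
ppl_manual = torch.exp(-token_log_probs.mean())
print(f"perplexity: {ppl.item():.2f}  (manual: {ppl_manual.item():.2f})")
```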


## Part 9: LLM.int8() - Advanced 8-bit Quantization

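LLM.int8(), implemented in the bitsandbytes library, performs most of each matrix multiplication in int8 while handling outlier features in FP16. Loading a model this way through `transformers` typically requires a CUDA GPU.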
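
LLM.int8() rests on two ideas. The first is vector-wise absmax quantization: one scale per row of the activations and one per column of the weights. The second is a mixed-precision decomposition: hidden dimensions where any activation reaches a threshold (6.0 by default, exposed as `llm_int8_threshold` in `BitsAndBytesConfig`) are multiplied in floating point. The sketch below emulates both in plain PyTorch on CPU. It is an illustration under those assumptions, not the bitsandbytes kernels, and the helper names and the injected outlier are hypothetical.

```python
import torch

def absmax_int8(t: torch.Tensor, dim: int):
    """Vector-wise absmax: one scale per slice along `dim`. Returns int8 codes and scales."""
    scale = 127 / t.abs().amax(dim=dim, keepdim=True).clamp(min=1e-8)
    codes = (t * scale).round().clamp(-127, 127).to(torch.int8)
    return codes, scale

def int8_matmul_with_outliers(x: torch.Tensor, w: torch.Tensor, threshold: float = 6.0):
    """Emulate an LLM.int8()-style product x @ w (x: tokens x hidden, w: hidden x out)."""
    # Hidden dimensions where any activation reaches the threshold are "outliers"
    outlier = (x.abs() >= threshold).any(dim=0)

    # Outlier dimensions: multiply in floating point
    out_outlier = x[:, outlier] @ w[outlier, :]

    # Remaining dimensions: per-row scales for x, per-column scales for w
    x_q, sx = absmax_int8(x[:, ~outlier], dim=1)  # sx: (tokens, 1)
    w_q, sw = absmax_int8(w[~outlier, :], dim=0)  # sw: (1, out)
    # Integer dot products, emulated exactly in float64 (values stay far below 2**53)
    acc = x_q.double() @ w_q.double()
    out_regular = (acc / (sx * sw)).float()       # dequantize with the outer product of scales

    return out_outlier + out_regular

torch.manual_seed(0)
x = torch.randn(4, 64)
x[:, 3] = 20.0                # hypothetical outlier feature dimension
w = torch.randn(64, 32) * 0.05
exact = x @ w

rel_err = ((int8_matmul_with_outliers(x, w) - exact).norm() / exact.norm()).item()
# threshold=inf disables the decomposition: every dimension is quantized to int8
approx_naive = int8_matmul_with_outliers(x, w, threshold=float("inf"))
rel_err_naive = ((approx_naive - exact).norm() / exact.norm()).item()
print(f"with outlier decomposition:  {rel_err:.4f}")
print(f"all-int8 (no decomposition): {rel_err_naive:.4f}")
```

Keeping the outlier column out of the int8 path lets the other rows of `x` use a fine-grained scale. This is the same effect shown in the single-outlier absmax example earlier.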


In [None]:
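# LLM.int8() via bitsandbytes typically requires a CUDA GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

from transformers import BitsAndBytesConfig

# Load model with 8-bit quantization; recent transformers versions expect a
# BitsAndBytesConfig instead of passing load_in_8bit directly
model_int8 = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map='auto',
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)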

# Compare model sizes
original_size = model.get_memory_footprint()
int8_size = model_int8.get_memory_footprint()

print("\nModel Size Comparison:")
print("="*50)
print(f"Original model: {original_size:,} bytes ({original_size / 1e6:.2f} MB)")
print(f"LLM.int8() model: {int8_size:,} bytes ({int8_size / 1e6:.2f} MB)")
print(f"Compression ratio: {original_size / int8_size:.2f}x")
print(f"Size reduction: {(1 - int8_size/original_size)*100:.1f}%")
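
Why is the compression ratio below 4x? The back-of-the-envelope sketch below rests on an assumption about the bitsandbytes integration that the notebook does not check: only the projection weights inside the transformer blocks become int8, while embeddings, biases, LayerNorm parameters and the tied `lm_head` stay in FP16. It counts parameters from GPT-2 small's configuration and ignores buffers and the small per-row scaling statistics.

```python
# GPT-2 small configuration values: vocabulary, context length, width, depth
vocab, n_positions, d, n_layer = 50257, 1024, 768, 12

embeddings = vocab * d + n_positions * d  # wte + wpe (lm_head is tied to wte)
# Per block: c_attn (d x 3d), attn c_proj (d x d), c_fc (d x 4d), mlp c_proj (4d x d)
linear_weights_per_block = 3 * d * d + d * d + 4 * d * d + 4 * d * d  # = 12 d^2
biases_per_block = 3 * d + d + 4 * d + d                             # = 9 d
layernorms_per_block = 2 * (2 * d)                                   # ln_1, ln_2 (weight + bias)
final_ln = 2 * d

int8_params = n_layer * linear_weights_per_block
other_params = embeddings + n_layer * (biases_per_block + layernorms_per_block) + final_ln
total = int8_params + other_params

fp32_bytes = 4 * total
mixed_bytes = 1 * int8_params + 2 * other_params  # int8 projections, FP16 everything else
print(f"total parameters: {total:,}")
print(f"FP32 weights:     {fp32_bytes / 1e6:.1f} MB")
print(f"int8 + FP16:      {mixed_bytes / 1e6:.1f} MB  ({fp32_bytes / mixed_bytes:.2f}x smaller)")
```

The embedding matrix alone holds about 31% of GPT-2 small's parameters, so leaving it in FP16 caps the saving at roughly 3x.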


## Part 10: Visualize LLM.int8() Weights

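Compare the parameters of the LLM.int8() model with the original weights. For layers converted by bitsandbytes, `param.data` holds the stored int8 codes (integers in [-127, 127]) rather than dequantized floats. Unconverted parameters such as the embeddings are FP16. The histogram therefore mixes two scales and is only a rough visual comparison.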


In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

# Extract int8 weights
weights_int8 = [param.data.clone() for param in model_int8.parameters()]
weights_int8_flat = np.concatenate([t.cpu().numpy().flatten() for t in weights_int8])

# Create visualization
plt.style.use('ggplot')
fig, ax = plt.subplots(figsize=(10, 5), dpi=300)

# Plot histograms
ax.hist(weights_flat, bins=150, alpha=0.5, label='Original weights',
        color='blue', range=(-2, 2))
ax.hist(weights_int8_flat, bins=150, alpha=0.5, label='LLM.int8() weights',
        color='red', range=(-2, 2))

# Formatting
ax.grid(True, linestyle='--', alpha=0.6)
ax.legend()
ax.set_title('Comparison of Original and LLM.int8() Quantized Weights', fontsize=16)
ax.set_xlabel('Weight Values', fontsize=14)
ax.set_ylabel('Count', fontsize=14)
ax.yaxis.set_major_formatter(ticker.EngFormatter())

plt.rc('font', size=12)
plt.tight_layout()
plt.show()


## Part 11: Test LLM.int8() Generation Quality


In [None]:
# Generate text with int8 model
text_int8 = generate_text(model_int8, prompt)

print("Text Generation Comparison:")
print("="*70)
print("Original model:")
print(original_text)
print("\n" + "="*70)
print("LLM.int8() model:")
print(text_int8)
print("="*70)


## Part 12: Final Perplexity Comparison


In [None]:
# Score the int8 model on the same reference text used in Part 8
ppl_int8 = calculate_perplexity(model_int8, original_text)

print("Final Perplexity Comparison:")
print("="*50)
print(f"Original model:       {ppl.item():.2f}")
print(f"LLM.int8() model:     {ppl_int8.item():.2f}")
print("="*50)
print(f"\nPerplexity difference: {abs(ppl.item() - ppl_int8.item()):.2f}")
print(f"\nLLM.int8() model uses {original_size / int8_size:.2f}x less memory than the FP32 original.")


## Summary

This notebook demonstrated three quantization approaches:

1. **Absmax Quantization**: Simple symmetric quantization
   - Pros: Fast, simple implementation
   - Cons: Less accurate for asymmetric distributions; a single outlier coarsens the grid for every other value

2. **Zero-point Quantization**: Asymmetric quantization with offset
   - Pros: Better handling of asymmetric distributions
   - Cons: Slightly more complex

3. **LLM.int8()**: Outlier-aware 8-bit quantization from bitsandbytes
   - Pros: Minimal quality loss, significant memory savings
   - Cons: Typically requires a CUDA GPU; inference can be slower than FP16 because of the outlier decomposition

### Key Takeaways
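- 8-bit quantization stores each quantized weight in 1 byte instead of 4 (FP32) or 2 (FP16); parameters kept in higher precision, such as embeddings, keep the overall reduction below 4x
- Outlier-aware methods such as LLM.int8() aim to keep quality close to the original model; the perplexity comparisons above show how much is lost for a given model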
- Choice of quantization method depends on hardware and quality requirements
