In [1]:
import os
import tempfile
import time
import numpy as np
import faiss
import h5py
import requests
import matplotlib.pyplot as plt
from collections import defaultdict

# Product Quantization with Multi-Level Clustering

This notebook demonstrates how to use Product Quantization (PQ) within a multi-level clustering approach for efficient vector similarity search. PQ compresses vectors by quantizing sub-vectors independently, reducing memory usage while maintaining search accuracy.

## Key Concepts:
- **Product Quantization**: Divides vectors into sub-vectors and quantizes each independently
- **Residual Encoding**: Stores PQ codes of residuals after subtracting cluster centroids
- **Multi-level Search**: Uses hierarchical clustering to organize data before applying PQ

## Implementation Overview:
1. Load and prepare dataset
2. Build hierarchical K-means clustering
3. Apply PQ to residuals within each cluster
4. Implement PQ-based search
5. Evaluate performance vs. baseline methods

In [2]:
# --- Data Loading Utilities ---

def download_dataset(cache_path, url, dataset_name):
    """Download a dataset to ``cache_path`` unless it is already cached.

    The body is streamed in chunks to a temporary ``<cache_path>.part`` file
    and renamed to ``cache_path`` only after the transfer has finished. An
    interrupted download therefore never leaves a truncated file that a later
    run would mistake for a valid cache entry, and the whole payload is never
    held in memory at once.

    Args:
        cache_path: Destination path of the cached dataset file.
        url: URL to fetch the dataset from.
        dataset_name: Human-readable name used in progress messages.

    Raises:
        Exception: If the server answers with a status code other than 200.
    """
    if os.path.exists(cache_path):
        print(f"{dataset_name} dataset already cached.")
        return

    print(f"Downloading {dataset_name} dataset (~300-500 MB)...")
    headers = {"User-Agent": "Mozilla/5.0"}
    partial_path = f"{cache_path}.part"

    # timeout=(connect, read) so a stalled server cannot hang the notebook forever.
    with requests.get(url, headers=headers, stream=True, timeout=(10, 60)) as response:
        if response.status_code != 200:
            raise Exception(f"Failed to download dataset: HTTP {response.status_code}")
        try:
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        except BaseException:
            # Do not leave a half-written temporary file behind.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise

    # Atomic publish: the cache path only ever points at a complete file.
    os.replace(partial_path, cache_path)
    print(f"Downloaded {dataset_name} successfully!")

def load_dataset(cache_path):
    """Read the base, query and ground-truth arrays from an HDF5 file.

    Args:
        cache_path: Path of the HDF5 file to open read-only.

    Returns:
        Tuple ``(xb, xq, gt)``: the ``"train"`` and ``"test"`` datasets cast to
        float32, and the ``"neighbors"`` dataset as stored in the file.
    """
    with h5py.File(cache_path, "r") as hdf:
        # Slicing with [:] materialises each dataset before the file closes.
        base_vectors = hdf["train"][:].astype(np.float32)
        query_vectors = hdf["test"][:].astype(np.float32)
        ground_truth = hdf["neighbors"][:]
    return base_vectors, query_vectors, ground_truth

In [3]:
# --- K-means Building Utilities ---

def build_kmeans(data, d, n_clusters, niter=20, spherical=False):
    """Train a FAISS k-means model on (a random sample of) ``data``.

    Args:
        data: Array of training vectors, one per row.
        d: Vector dimensionality passed to ``faiss.Kmeans``.
        n_clusters: Number of centroids to learn.
        niter: Number of k-means iterations.
        spherical: Forwarded to ``faiss.Kmeans``.

    Returns:
        The trained ``faiss.Kmeans`` object.
    """
    n_points = len(data)
    # Cap the training set at 50k rows, drawn without replacement.
    sample_size = min(50000, n_points)
    chosen_rows = np.random.choice(n_points, size=sample_size, replace=False)

    model = faiss.Kmeans(d, n_clusters, niter=niter, verbose=True, spherical=spherical)
    model.train(data[chosen_rows])
    return model

def assign_vectors_to_clusters(vectors, kmeans, n_assign=1):
    """Return, for each vector, the ids of its ``n_assign`` nearest centroids.

    Args:
        vectors: Array of vectors to assign, one per row.
        kmeans: Object exposing ``index.search(vectors, n)``.
        n_assign: How many centroids to return per vector.

    Returns:
        The id array produced by ``kmeans.index.search`` (second element of
        its result); the distances are discarded.
    """
    search_result = kmeans.index.search(vectors, n_assign)
    return search_result[1]

In [8]:
# Default parameters for PQ (used when the builder is called without overrides)
pq_m = 8      # Number of PQ sub-vectors (should divide d)
pq_nbits = 8  # Bits per sub-vector -> 2**pq_nbits centroids per sub-quantizer

# --- Build PQ codebooks and encode inner clusters ---
def build_pq_for_inner_clusters(xb, xb_inner_assignments, inner_centroids, m=None, nbits=None):
    """Train one product quantizer per inner cluster and encode its residuals.

    For every inner cluster, the residuals ``xb[idxs] - centroid`` of the
    vectors assigned to it are used to train a ``faiss.ProductQuantizer`` and
    are then encoded with it.

    Clusters with fewer than ``2 ** nbits`` vectors are skipped: each
    sub-quantizer runs k-means with ``2 ** nbits`` centroids on *all* residuals
    of the cluster, so that many points is the minimum needed for training.
    (The number of sub-vectors does not raise this minimum.)

    Args:
        xb: Base vectors, one per row.
        xb_inner_assignments: 2-D array whose first column holds the inner
            cluster id of each base vector.
        inner_centroids: Inner cluster centroids, one per row.
        m: Number of PQ sub-vectors; defaults to the module-level ``pq_m``.
        nbits: Bits per sub-vector code; defaults to the module-level
            ``pq_nbits``.

    Returns:
        Tuple ``(pq_codebooks, pq_codes)`` of dicts keyed by inner cluster id:
        ``pq_codebooks[id]`` is the trained quantizer and ``pq_codes[id]`` is
        ``(idxs, codes)`` with the base-vector indices and their PQ codes.
        Skipped clusters are absent from both dicts.
    """
    m = pq_m if m is None else m
    nbits = pq_nbits if nbits is None else nbits
    min_train_points = 2 ** nbits  # one point per sub-quantizer centroid at least

    assigned_cluster = np.asarray(xb_inner_assignments)[:, 0]
    pq_codebooks = {}
    pq_codes = {}
    for inner_id in range(inner_centroids.shape[0]):
        # Get all vectors assigned to this inner cluster
        idxs = np.where(assigned_cluster == inner_id)[0]

        # Skip empty clusters and clusters too small to train a codebook on
        if len(idxs) < min_train_points:
            continue

        # Residuals w.r.t. the cluster centroid (contiguous float32 for FAISS)
        residuals = np.ascontiguousarray(xb[idxs] - inner_centroids[inner_id], dtype=np.float32)

        # Train PQ on residuals, then encode them
        pq = faiss.ProductQuantizer(residuals.shape[1], m, nbits)
        pq.train(residuals)
        pq_codebooks[inner_id] = pq
        pq_codes[inner_id] = (idxs, pq.compute_codes(residuals))
    return pq_codebooks, pq_codes

In [4]:
# --- Dataset Configuration ---

# Choose dataset: "fashion-mnist", "gist", or "sift"
selected_dataset = "fashion-mnist"

DATA_URLS = {
    "fashion-mnist": "http://ann-benchmarks.com/fashion-mnist-784-euclidean.hdf5",
    "gist": "http://ann-benchmarks.com/gist-960-euclidean.hdf5",
    "sift": "http://ann-benchmarks.com/sift-128-euclidean.hdf5"
}

# Each dataset is cached in the system temp directory under its remote file name
CACHES = {
    name: os.path.join(tempfile.gettempdir(), url.split('/')[-1])
    for name, url in DATA_URLS.items()
}

# Download and load the selected dataset
cache_path = CACHES[selected_dataset]
download_dataset(cache_path, DATA_URLS[selected_dataset], selected_dataset)
xb, xq, gt = load_dataset(cache_path)

print(f"Dataset: {selected_dataset}")
print(f"Base vectors: {xb.shape}")
print(f"Query vectors: {xq.shape}")
print(f"Ground truth: {gt.shape}")

d = xb.shape[1]
k = 10  # Number of nearest neighbors to find

# Adjust PQ parameters based on dimensionality: use the largest candidate that
# divides d. Falling back to 1 keeps the "m divides d" requirement true for odd
# dimensions, where a fixed fallback of 2 would not divide d.
pq_m = next(m for m in (8, 4, 2, 1) if d % m == 0)

print(f"Using PQ with m={pq_m} sub-vectors (dimension={d})")

fashion-mnist dataset already cached.
Dataset: fashion-mnist
Base vectors: (60000, 784)
Query vectors: (10000, 784)
Ground truth: (10000, 100)
Using PQ with m=8 sub-vectors (dimension=784)


In [5]:
# --- Build Multi-Level K-means Clustering ---

print("\n=== Building Multi-Level K-means Clustering ===")

# Fine-grained level: k-means directly on the base vectors
n_inner_clusters = 100
print(f"Building {n_inner_clusters} inner clusters...")
inner_kmeans = build_kmeans(xb, d, n_inner_clusters)
inner_centroids = inner_kmeans.centroids
print(f"Inner clustering completed. Centroids shape: {inner_centroids.shape}")

# Coarse-grained level: k-means on the inner centroids themselves
n_outer_clusters = 10
print(f"\nBuilding {n_outer_clusters} outer clusters from inner centroids...")
outer_kmeans = build_kmeans(inner_centroids, d, n_outer_clusters)
print("Outer clustering completed.")

# Nearest outer centroid of every inner centroid (inner -> outer mapping)
inner_to_outer = assign_vectors_to_clusters(inner_centroids, outer_kmeans, n_assign=1)
print(f"Mapped {n_inner_clusters} inner clusters to {n_outer_clusters} outer clusters")

# Three nearest outer clusters for every query vector
xq_outer_assignments = assign_vectors_to_clusters(xq, outer_kmeans, n_assign=3)
print(f"Assigned {len(xq)} queries to outer clusters")


=== Building Multi-Level K-means Clustering ===
Building 100 inner clusters...
Sampling a subset of 25600 / 50000 for training
Clustering 25600 points in 784D to 100 clusters, redo 1 times, 20 iterations
  Preprocessing in 0.02 s
  Iteration 19 (0.56 s, search 0.46 s): objective=3.37934e+10 imbalance=1.150 nsplit=0       
Inner clustering completed. Centroids shape: (100, 784)

Building 10 outer clusters from inner centroids...
Clustering 100 points in 784D to 10 clusters, redo 1 times, 20 iterations
  Preprocessing in 0.00 s
  Iteration 19 (0.56 s, search 0.46 s): objective=3.37934e+10 imbalance=1.150 nsplit=0       
Inner clustering completed. Centroids shape: (100, 784)

Building 10 outer clusters from inner centroids...
Clustering 100 points in 784D to 10 clusters, redo 1 times, 20 iterations
  Preprocessing in 0.00 s
  Iteration 0 (0.01 s, search 0.01 s): objective=1.77772e+08 imbalance=1.402 nsplit=0       



  Iteration 19 (0.24 s, search 0.18 s): objective=9.44027e+07 imbalance=1.240 nsplit=0       
Outer clustering completed.
Mapped 100 inner clusters to 10 outer clusters
Assigned 10000 queries to outer clusters


In [None]:
# --- Search using PQ codes in inner clusters ---
def search_query_pq_inner(x, inner_kmeans, inner_centroids, pq_codebooks, pq_codes, xb, k):
    """Approximate k-NN search for one query inside its nearest inner cluster.

    The query is routed to its closest inner centroid. If that cluster has a
    trained product quantizer, the query residual is compared against the
    cluster's PQ codes (asymmetric distance); otherwise the function falls back
    to an exact scan of all of ``xb``.

    Args:
        x: 1-D query vector.
        inner_kmeans: Object exposing ``index.search(queries, n)``.
        inner_centroids: Inner cluster centroids, one per row.
        pq_codebooks: Dict mapping inner cluster id -> trained product quantizer.
        pq_codes: Dict mapping inner cluster id -> ``(idxs, codes)``.
        xb: Base vectors, used only by the brute-force fallback.
        k: Number of neighbours to return.

    Returns:
        Tuple ``(ids, dists)`` of the (at most) ``k`` best base-vector indices
        and their distances, best first. Note that the PQ path returns
        *squared* approximate L2 distances while the fallback returns exact
        (non-squared) L2 distances.
    """
    query = np.ascontiguousarray(x, dtype=np.float32).reshape(-1)

    # Find closest inner cluster
    _, inner_assign = inner_kmeans.index.search(query.reshape(1, -1), 1)
    inner_id = int(inner_assign[0, 0])
    if inner_id not in pq_codebooks:
        # fallback to brute-force
        dists = np.linalg.norm(xb - query.reshape(1, -1), axis=1)
        idx = np.argsort(dists)[:k]
        return idx, dists[idx]

    # Compute residual for query
    residual = np.ascontiguousarray(query - inner_centroids[inner_id], dtype=np.float32)
    pq = pq_codebooks[inner_id]
    idxs, codes = pq_codes[inner_id]

    if pq.nbits == 8:
        # 8-bit codes are stored one byte per sub-vector, so they can index the
        # (M, ksub) table of squared sub-vector distances directly.
        # compute_distance_table is a raw SWIG method: it needs pointers to the
        # input vector and to a preallocated output buffer.
        lookup = np.empty((pq.M, pq.ksub), dtype=np.float32)
        pq.compute_distance_table(faiss.swig_ptr(residual), faiss.swig_ptr(lookup))
        # dists[i] = sum_m lookup[m, codes[i, m]]
        dists = lookup[np.arange(pq.M), codes].sum(axis=1)
    else:
        # Other code widths are bit-packed; decoding and measuring the squared
        # distance to the reconstruction yields the same asymmetric distance.
        reconstructed = pq.decode(codes)
        dists = ((reconstructed - residual) ** 2).sum(axis=1)

    topk = np.argsort(dists)[:k]
    return idxs[topk], dists[topk]

In [None]:
# --- Baseline Search Methods ---

def search_brute_force(xq, xb, k):
    """Exact k-NN by scanning every base vector for every query.

    Args:
        xq: Query vectors, one per row.
        xb: Base vectors, one per row.
        k: Number of neighbours to keep per query.

    Returns:
        Tuple ``(I, D, elapsed_time)``: neighbour indices, their Euclidean
        distances (both stacked per query) and the wall-clock seconds spent.
    """
    print("Running brute force search...")
    t_start = time.time()

    neighbor_ids, neighbor_dists = [], []
    for query in xq:
        # Euclidean distance from this query to every base vector
        all_dists = np.linalg.norm(xb - query.reshape(1, -1), axis=1)
        nearest = np.argsort(all_dists)[:k]
        neighbor_ids.append(nearest)
        neighbor_dists.append(all_dists[nearest])

    elapsed_time = time.time() - t_start
    return np.array(neighbor_ids), np.array(neighbor_dists), elapsed_time

def search_kmeans_only(xq, inner_kmeans, inner_centroids, xb, k):
    """Cluster-pruned k-NN search with exact re-ranking (no PQ).

    Base vectors are grouped by their nearest inner centroid. Each query is
    compared exactly against the members of its 3 nearest clusters; if those
    clusters hold no vectors at all, the query falls back to a full scan.

    Args:
        xq: Query vectors, one per row.
        inner_kmeans: Object exposing ``index.search(vectors, n)``.
        inner_centroids: Unused here; kept for a signature parallel to the
            other search functions.
        xb: Base vectors, one per row.
        k: Number of neighbours to keep per query.

    Returns:
        Tuple ``(I, D, elapsed_time)``: neighbour indices, their Euclidean
        distances and the wall-clock seconds spent (the base-vector
        assignment step is included in the timing).
    """
    print("Running K-means only search...")
    t_start = time.time()

    # Group base-vector row numbers by their nearest cluster id
    _, base_assign = inner_kmeans.index.search(xb, 1)
    members = defaultdict(list)
    for row, cid in enumerate(base_assign[:, 0]):
        members[cid].append(row)

    I, D = [], []
    for query in xq:
        q_row = query.reshape(1, -1)
        _, probed = inner_kmeans.index.search(q_row, 3)

        # Gather candidates (and their exact distances) from the probed clusters
        cand_ids, cand_dists = [], []
        for cid in probed[0]:
            if cid not in members:
                continue
            rows = members[cid]
            cand_ids.extend(rows)
            cand_dists.extend(np.linalg.norm(xb[rows] - q_row, axis=1))

        if not cand_ids:
            # No candidates at all: exact scan of the whole base set
            full_dists = np.linalg.norm(xb - q_row, axis=1)
            order = np.argsort(full_dists)[:k]
            I.append(order)
            D.append(full_dists[order])
            continue

        # Rank by (distance, id) and keep the k best
        ranked = sorted(zip(cand_dists, cand_ids))[:k]
        I.append([row for _, row in ranked])
        D.append([dist for dist, _ in ranked])

    elapsed_time = time.time() - t_start
    return np.array(I), np.array(D), elapsed_time

In [9]:
# --- Product Quantization Example ---
# Assign each vector to its closest inner cluster
_, xb_inner_assignments = inner_kmeans.index.search(xb, 1)

# Build PQ codebooks and codes for each inner cluster
pq_codebooks, pq_codes = build_pq_for_inner_clusters(xb, xb_inner_assignments, inner_centroids)

# Search all queries using PQ in inner clusters
I = []
D = []
for x in xq:
    idxs, dists = search_query_pq_inner(x, inner_kmeans, inner_centroids, pq_codebooks, pq_codes, xb, k)
    I.append(idxs)
    D.append(dists)
I = np.array(I)
D = np.array(D)

# Recall@k is the fraction of the true top-k neighbours that were retrieved,
# regardless of the rank at which they appear. Comparing I and gt element by
# element would only count hits that land at exactly the same position.
recall = np.mean([
    len(set(I[i]) & set(gt[i, :k])) / k
    for i in range(len(I))
])
print(f"Recall@{k} using PQ in inner clusters: {recall:.4f}")

NameError: name 'search_query_pq_inner' is not defined

In [None]:
# --- Comprehensive Evaluation ---

print("\n=== Performance Evaluation ===")

# Per-method metrics, keyed by method name
results = {}

# 1. Brute Force (Ground Truth)
print("\n1. Brute Force Search:")
bf_queries = xq[:100]  # Use subset for speed
I_bf, D_bf, time_bf = search_brute_force(bf_queries, xb, k)
qps_bf = len(bf_queries) / time_bf
print(f"Time: {time_bf:.4f}s, QPS: {qps_bf:.2f}")
results['Brute Force'] = dict(qps=qps_bf, recall=1.0, time=time_bf)

# 2. K-means Only
print("\n2. K-means Only Search:")
I_km, D_km, time_km = search_kmeans_only(xq, inner_kmeans, inner_centroids, xb, k)
n_queries = len(xq)
qps_km = n_queries / time_km
# Recall@k: overlap between retrieved ids and the true top-k, averaged over queries
km_hit_rates = [len(set(I_km[i]) & set(gt[i, :k])) / k for i in range(n_queries)]
recall_km = np.mean(km_hit_rates)
print(f"Time: {time_km:.4f}s, QPS: {qps_km:.2f}, Recall: {recall_km:.4f}")
results['K-means Only'] = dict(qps=qps_km, recall=recall_km, time=time_km)

# 3. Product Quantization
print("\n3. Product Quantization Search:")
start_time = time.time()
I_pq = []
D_pq = []
for x in xq:
    idxs, dists = search_query_pq_inner(
        x, inner_kmeans, inner_centroids, pq_codebooks, pq_codes, xb, k
    )
    I_pq.append(idxs)
    D_pq.append(dists)
time_pq = time.time() - start_time
I_pq, D_pq = np.array(I_pq), np.array(D_pq)
qps_pq = n_queries / time_pq
pq_hit_rates = [len(set(I_pq[i]) & set(gt[i, :k])) / k for i in range(n_queries)]
recall_pq = np.mean(pq_hit_rates)
print(f"Time: {time_pq:.4f}s, QPS: {qps_pq:.2f}, Recall: {recall_pq:.4f}")
results['Product Quantization'] = dict(qps=qps_pq, recall=recall_pq, time=time_pq)

In [None]:
# --- Memory Usage Analysis ---

print("\n=== Memory Usage Analysis ===")

# Calculate memory usage for different approaches
original_size = xb.nbytes  # Original vectors in bytes

# K-means approach: cluster assignments (4 bytes per vector) and centroids are
# stored *in addition to* the raw vectors, because search_kmeans_only re-ranks
# candidates with exact distances computed on xb.
kmeans_size = original_size + len(xb) * 4 + inner_centroids.nbytes

# PQ approach: centroids, plus for every encoded cluster its PQ codes, the ids
# of the encoded vectors, and the cluster's own codebook
# (M * ksub * dsub float32 values).
float32_bytes = np.dtype(np.float32).itemsize
pq_size = inner_centroids.nbytes  # Centroids
for inner_id, (idxs, codes) in pq_codes.items():
    codebook = pq_codebooks[inner_id]
    pq_size += codes.nbytes + idxs.nbytes
    pq_size += codebook.M * codebook.ksub * codebook.dsub * float32_bytes
# NOTE: clusters without a codebook are served by a brute-force fallback that
# still needs the raw vectors; that cost is not included in pq_size.

print(f"Original vectors: {original_size / (1024**2):.2f} MB")
print(f"K-means approach: {kmeans_size / (1024**2):.2f} MB (compression: {original_size/kmeans_size:.1f}x)")
print(f"PQ approach: {pq_size / (1024**2):.2f} MB (compression: {original_size/pq_size:.1f}x)")

results['Original Size (MB)'] = original_size / (1024**2)
results['K-means Size (MB)'] = kmeans_size / (1024**2)
results['PQ Size (MB)'] = pq_size / (1024**2)

In [None]:
# --- Results Visualization ---

# `results` mixes per-method metric dicts with plain float entries (the memory
# sizes). Only the dict entries describe a search method; testing `'qps' in`
# a float would raise TypeError, so filter on the type first.
method_metrics = {
    name: metrics
    for name, metrics in results.items()
    if isinstance(metrics, dict) and 'qps' in metrics
}

print("\n=== Results Summary ===")
print(f"{'Method':<20} {'QPS':<10} {'Recall':<10} {'Time (s)':<12}")
print("-" * 55)
for method, metrics in method_metrics.items():
    print(f"{method:<20} {metrics['qps']:<10.2f} {metrics['recall']:<10.4f} {metrics['time']:<12.4f}")

# Create comparison plots
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Product Quantization Performance Analysis', fontsize=16)

methods = list(method_metrics)
palette = ['blue', 'orange', 'green']
# Cycle the palette so any number of methods gets a colour
colors = [palette[i % len(palette)] for i in range(len(methods))]

# Plot 1: QPS Comparison
qps_values = [method_metrics[m]['qps'] for m in methods]
ax1.bar(methods, qps_values, color=colors)
ax1.set_ylabel('Queries Per Second (QPS)')
ax1.set_title('Search Speed Comparison')
ax1.tick_params(axis='x', rotation=45)

# Plot 2: Recall Comparison
recall_values = [method_metrics[m]['recall'] for m in methods]
ax2.bar(methods, recall_values, color=colors)
ax2.set_ylabel('Recall@10')
ax2.set_title('Search Accuracy Comparison')
ax2.set_ylim(0, 1.1)
ax2.tick_params(axis='x', rotation=45)

# Plot 3: Memory Usage
memory_methods = ['Original', 'K-means', 'PQ']
memory_values = [
    results['Original Size (MB)'],
    results['K-means Size (MB)'],
    results['PQ Size (MB)']
]
ax3.bar(memory_methods, memory_values, color=['red', 'orange', 'green'])
ax3.set_ylabel('Memory Usage (MB)')
ax3.set_title('Memory Efficiency Comparison')

# Plot 4: QPS vs Recall Trade-off
qps_plot = qps_values
recall_plot = recall_values
for i, method in enumerate(methods):
    ax4.scatter(recall_plot[i], qps_plot[i], s=100, c=colors[i], label=method, alpha=0.7)
ax4.set_xlabel('Recall@10')
ax4.set_ylabel('Queries Per Second (QPS)')
ax4.set_title('Speed vs Accuracy Trade-off')
ax4.legend()
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# --- Parameter Sensitivity Analysis ---

print("\n=== PQ Parameter Sensitivity Analysis ===")

# Test different PQ configurations
pq_configs = [
    {'m': 2, 'nbits': 8},
    {'m': 4, 'nbits': 8},
    {'m': 8, 'nbits': 8},
    {'m': 4, 'nbits': 4},
    {'m': 4, 'nbits': 16}
]

if d >= 8:  # Only run if dimension allows
    sensitivity_results = []
    float32_bytes = np.dtype(np.float32).itemsize

    for config in pq_configs:
        if d % config['m'] != 0:
            continue  # Skip if dimension not divisible by m

        print(f"\nTesting PQ with m={config['m']}, nbits={config['nbits']}")

        # Parameters for this run (the global pq_m / pq_nbits are left untouched)
        temp_pq_m = config['m']
        temp_pq_nbits = config['nbits']
        # Each sub-quantizer learns 2**nbits centroids, so a cluster needs at
        # least that many vectors or FAISS refuses to train on it.
        min_train_points = 2 ** temp_pq_nbits

        # Rebuild PQ with new parameters
        temp_pq_codebooks = {}
        temp_pq_codes = {}
        temp_codebook_bytes = 0

        for inner_id in range(inner_centroids.shape[0]):
            idxs = np.where(xb_inner_assignments[:,0] == inner_id)[0]
            if len(idxs) < min_train_points:
                continue  # empty or too small to train this configuration
            residuals = np.ascontiguousarray(xb[idxs] - inner_centroids[inner_id], dtype=np.float32)
            pq = faiss.ProductQuantizer(residuals.shape[1], temp_pq_m, temp_pq_nbits)
            pq.train(residuals)
            # The Python wrapper takes only the vectors and returns the codes
            codes = pq.compute_codes(residuals)
            temp_pq_codebooks[inner_id] = pq
            temp_pq_codes[inner_id] = (idxs, codes)
            temp_codebook_bytes += pq.M * pq.ksub * pq.dsub * float32_bytes

        if not temp_pq_codebooks:
            # Every query would hit the brute-force fallback, so the numbers
            # would say nothing about this PQ configuration.
            print(f"  Skipped: no cluster has the {min_train_points} vectors needed for training")
            continue

        # Test search with new PQ
        start_time = time.time()
        I_temp = []
        for x in xq[:500]:  # Use subset for faster evaluation
            idxs, _ = search_query_pq_inner(x, inner_kmeans, inner_centroids,
                                          temp_pq_codebooks, temp_pq_codes, xb, k)
            I_temp.append(idxs)
        test_time = time.time() - start_time

        # Calculate metrics
        I_temp = np.array(I_temp)
        test_qps = len(xq[:500]) / test_time
        test_recall = np.mean([
            len(set(I_temp[i]) & set(gt[i, :k])) / k
            for i in range(len(I_temp))
        ])

        # Calculate memory usage: centroids + per-cluster codebooks + PQ codes
        temp_memory = inner_centroids.nbytes + temp_codebook_bytes
        for inner_id, (idxs, codes) in temp_pq_codes.items():
            temp_memory += codes.nbytes
        temp_compression = xb.nbytes / temp_memory

        sensitivity_results.append({
            'm': config['m'],
            'nbits': config['nbits'],
            'recall': test_recall,
            'qps': test_qps,
            'compression': temp_compression,
            'memory_mb': temp_memory / (1024**2)
        })

        print(f"  Recall: {test_recall:.4f}, QPS: {test_qps:.2f}, Compression: {temp_compression:.1f}x")

    # Plot sensitivity analysis
    if sensitivity_results:
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle('PQ Parameter Sensitivity Analysis', fontsize=16)

        config_labels = [f"m={r['m']}, bits={r['nbits']}" for r in sensitivity_results]

        # Recall comparison
        recalls = [r['recall'] for r in sensitivity_results]
        ax1.bar(range(len(config_labels)), recalls)
        ax1.set_xticks(range(len(config_labels)))
        ax1.set_xticklabels(config_labels, rotation=45)
        ax1.set_ylabel('Recall@10')
        ax1.set_title('Recall vs PQ Configuration')

        # QPS comparison
        qps_vals = [r['qps'] for r in sensitivity_results]
        ax2.bar(range(len(config_labels)), qps_vals, color='orange')
        ax2.set_xticks(range(len(config_labels)))
        ax2.set_xticklabels(config_labels, rotation=45)
        ax2.set_ylabel('QPS')
        ax2.set_title('Speed vs PQ Configuration')

        # Compression comparison
        compressions = [r['compression'] for r in sensitivity_results]
        ax3.bar(range(len(config_labels)), compressions, color='green')
        ax3.set_xticks(range(len(config_labels)))
        ax3.set_xticklabels(config_labels, rotation=45)
        ax3.set_ylabel('Compression Ratio')
        ax3.set_title('Memory Efficiency vs PQ Configuration')

        # Trade-off plot
        for i, result in enumerate(sensitivity_results):
            ax4.scatter(result['recall'], result['qps'], s=100,
                       label=config_labels[i], alpha=0.7)
        ax4.set_xlabel('Recall@10')
        ax4.set_ylabel('QPS')
        ax4.set_title('Speed vs Accuracy Trade-off')
        ax4.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax4.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()
else:
    print(f"Skipping sensitivity analysis - dimension {d} too small for multiple PQ configurations")

## Conclusions and Recommendations

### Key Findings:

1. **Performance Trade-offs**:
   - **Brute Force**: Perfect recall but slow (baseline)
   - **K-means Only**: Good balance of speed and accuracy
   - **Product Quantization**: Excellent compression with competitive search speed

2. **Memory Efficiency**:
   - PQ achieves significant memory compression (typically 10-50x)
   - Trade-off between compression ratio and accuracy
   - Higher compression with fewer sub-vectors (smaller m) or fewer bits per sub-vector, since each code occupies m × nbits bits

3. **Parameter Selection**:
   - `m` (number of sub-vectors): Should divide dimension evenly
   - `nbits`: 8 bits often provides good balance
   - More bits = better accuracy but larger memory usage

### Recommendations:

1. **For Memory-Constrained Applications**:
   - Use PQ with m=8, nbits=8 as starting point
   - Adjust based on available memory and accuracy requirements

2. **For Speed-Critical Applications**:
   - K-means clustering alone may be sufficient
   - PQ adds compression benefits with minimal speed penalty

3. **For Large-Scale Deployment**:
   - Combine multi-level clustering with PQ
   - Use coarse clustering to reduce search space
   - Apply PQ within clusters for memory efficiency

### Next Steps:

1. **Advanced PQ Variants**:
   - Optimized Product Quantization (OPQ)
   - Additive Quantization (AQ)
   - Composite Quantization (CQ)

2. **Integration with Modern Indexes**:
   - Combine with HNSW for graph-based search
   - Use with LSH for approximate search
   - Apply to transformer embeddings

3. **Hardware Optimization**:
   - GPU acceleration for PQ distance computations
   - SIMD optimizations for CPU implementations
   - Memory-mapped file storage for large datasets