In [None]:
import numpy as np
import faiss
import matplotlib.pyplot as plt

In [None]:
# Parameters for PQ
pq_m = 8      # Number of PQ sub-vectors (must divide the vector dimension d)
pq_nbits = 8  # Bits per sub-vector code, i.e. 2**pq_nbits centroids per sub-quantizer

# --- Build PQ codebooks and encode inner clusters ---
def build_pq_for_inner_clusters(xb, xb_inner_assignments, inner_centroids):
    """Train one product quantizer per inner cluster and encode that cluster's residuals.

    Parameters
    ----------
    xb : array of shape (n, d)
        Database vectors.
    xb_inner_assignments : integer array of shape (n, >=1)
        Column 0 holds the inner-cluster id assigned to each row of ``xb``.
    inner_centroids : array of shape (n_inner, d)
        Inner-cluster centroids; row ``i`` is subtracted from the members of
        cluster ``i`` to form the residuals that are quantized.

    Returns
    -------
    pq_codebooks : dict
        ``inner_id -> faiss.ProductQuantizer`` trained on that cluster's residuals.
    pq_codes : dict
        ``inner_id -> (idxs, codes)`` where ``idxs`` are the row indices into
        ``xb`` of the cluster members and ``codes`` is the uint8 array of their
        PQ codes, in the same order.

    Notes
    -----
    Reads the module-level ``pq_m`` and ``pq_nbits``. Clusters that are empty,
    or that have fewer than ``2**pq_nbits`` members, are left out of both
    dicts: FAISS k-means raises when given fewer training points than
    centroids. A warning reports how many non-empty clusters were skipped.
    """
    import warnings

    # Each sub-quantizer learns 2**pq_nbits centroids and needs at least that many points.
    min_train_points = 2 ** pq_nbits
    # Hoisted: the column slice does not change between clusters.
    assignments = xb_inner_assignments[:, 0]

    pq_codebooks = {}
    pq_codes = {}
    skipped = []
    for inner_id in range(inner_centroids.shape[0]):
        # Get all vectors assigned to this inner cluster
        idxs = np.where(assignments == inner_id)[0]
        if len(idxs) == 0:
            continue
        if len(idxs) < min_train_points:
            skipped.append(inner_id)
            continue
        # Compute residuals; FAISS expects C-contiguous float32 input.
        residuals = np.ascontiguousarray(xb[idxs] - inner_centroids[inner_id], dtype='float32')
        # Train PQ on residuals
        pq = faiss.ProductQuantizer(residuals.shape[1], pq_m, pq_nbits)
        pq.train(residuals)
        # Encode residuals. The Python wrapper allocates and returns the
        # (n, code_size) uint8 array itself; it does not take an output buffer.
        codes = pq.compute_codes(residuals)
        pq_codebooks[inner_id] = pq
        pq_codes[inner_id] = (idxs, codes)

    if skipped:
        warnings.warn(
            f"{len(skipped)} inner cluster(s) have fewer than {min_train_points} "
            "vectors and were not product-quantized"
        )
    return pq_codebooks, pq_codes

In [None]:
# --- Search using PQ codes in inner clusters ---
def search_query_pq_inner(x, inner_kmeans, inner_centroids, pq_codebooks, pq_codes, xb, k):
    """Return the (approximate) k nearest database vectors to a single query.

    The query is routed to its closest inner cluster via
    ``inner_kmeans.index.search``. If that cluster has a product quantizer,
    the members of the cluster are ranked by asymmetric PQ distance between
    the query residual and their stored codes. Otherwise the function falls
    back to an exact scan over all of ``xb``.

    Parameters
    ----------
    x : 1-D array of length d
        Query vector.
    inner_kmeans : object exposing ``index.search(queries, 1)``
        Used to find the nearest inner centroid.
    inner_centroids : array of shape (n_inner, d)
    pq_codebooks : dict
        ``inner_id -> faiss.ProductQuantizer``.
    pq_codes : dict
        ``inner_id -> (idxs, codes)`` as produced by ``build_pq_for_inner_clusters``.
    xb : array of shape (n, d)
        Database vectors (only read in the brute-force fallback).
    k : int
        Number of neighbours requested.

    Returns
    -------
    (indices, distances)
        Row indices into ``xb`` and the matching squared-L2 distances, sorted
        by increasing distance. Fewer than ``k`` entries are returned when the
        searched set holds fewer than ``k`` vectors.
    """
    query = np.ascontiguousarray(x, dtype='float32').reshape(1, -1)

    # Find closest inner cluster
    _, inner_assign = inner_kmeans.index.search(query, 1)
    inner_id = int(inner_assign[0, 0])

    if inner_id not in pq_codebooks:
        # Fallback to brute force. Squared L2 is used so the distances are on
        # the same scale as the PQ lookup-table distances returned below.
        dists = np.square(xb - query).sum(axis=1)
        idx = np.argsort(dists)[:k]
        return idx, dists[idx]

    pq = pq_codebooks[inner_id]
    idxs, codes = pq_codes[inner_id]

    # Compute residual for query
    residual = np.ascontiguousarray(query[0] - inner_centroids[inner_id], dtype='float32')

    # compute_distance_table is a raw SWIG binding: it fills a caller-provided
    # (M, ksub) float32 table and returns nothing, so pointers must be passed.
    lookup = np.empty((pq.M, pq.ksub), dtype='float32')
    pq.compute_distance_table(faiss.swig_ptr(residual), faiss.swig_ptr(lookup))

    # For each code, sum the lookup-table entries of its M sub-codes.
    # NOTE: treats column m of `codes` as the m-th sub-code, which holds when
    # every sub-code occupies exactly one byte (pq_nbits == 8).
    dists = lookup[np.arange(pq.M), codes].sum(axis=1)

    topk = np.argsort(dists)[:k]
    return idxs[topk], dists[topk]

In [None]:
# --- Product Quantization Example ---
# Relies on xb, xq, gt, k, inner_kmeans and inner_centroids defined in earlier cells.

# Assign each vector to its closest inner cluster
_, xb_inner_assignments = inner_kmeans.index.search(xb, 1)

# Build PQ codebooks and codes for each inner cluster
pq_codebooks, pq_codes = build_pq_for_inner_clusters(xb, xb_inner_assignments, inner_centroids)

# Search all queries using PQ in inner clusters.
# Results are pre-filled with -1 / inf (the FAISS convention for "no result")
# so a query whose inner cluster holds fewer than k vectors still produces a
# full-width row instead of a ragged list.
I = np.full((len(xq), k), -1, dtype='int64')
D = np.full((len(xq), k), np.inf, dtype='float32')
for qi, x in enumerate(xq):
    idxs, dists = search_query_pq_inner(x, inner_kmeans, inner_centroids, pq_codebooks, pq_codes, xb, k)
    I[qi, :len(idxs)] = idxs
    D[qi, :len(dists)] = dists

# Recall@k: fraction of the true k nearest neighbours that appear anywhere in
# the k returned results. Comparing I and gt position by position would only
# count a hit when the neighbour is also at the same rank, which under-reports
# recall for approximate distances.
hits = sum(np.intersect1d(found, truth).size for found, truth in zip(I, gt[:, :k]))
recall = hits / (gt.shape[0] * k)
print(f"Recall@{k} using PQ in inner clusters: {recall:.4f}")