In [1]:
import os
import tempfile
import time
import numpy as np
import faiss
import h5py
import requests
from collections import defaultdict
import bisect
import matplotlib.pyplot as plt

In [2]:
# --- Data Utilities ---

def download_fashion_mnist(cache_path, url):
    if not os.path.exists(cache_path):
        print("Downloading Fashion-MNIST (~300 MB)…")
        headers = {"User-Agent": "Mozilla/5.0"}
        response = requests.get(url, headers=headers)
        if response.status_code == 200:
            with open(cache_path, "wb") as f:
                f.write(response.content)
        else:
            raise Exception(f"Failed to download dataset: HTTP{response.status_code}")

def load_fashion_mnist(cache_path):
    with h5py.File(cache_path, "r") as f:
        xb = f["train"][:].astype(np.float32)
        xq = f["test"][:].astype(np.float32)
        gt = f["neighbors"][:]
    return xb, xq, gt

In [17]:
# --- Indexing Utilities ---

def build_kmeans(xb, d, n_clusters, niter=20):
    train_sample = xb[np.random.choice(len(xb), size=min(50000, len(xb)), replace=False)]
    kmeans = faiss.Kmeans(d, n_clusters, niter=niter, verbose=True, spherical=False)
    kmeans.train(train_sample)
    return kmeans

def assign_vectors_to_clusters(xb, kmeans, n_assign):
    _, assignments = kmeans.index.search(xb, n_assign)
    return assignments

def cross_pollinate_metadata(xb, xb_inner_assignments, inner_centroids, inner_to_outer, N_CROSS):
    vector_metadata = defaultdict(dict)
    for idx, inner_ids in enumerate(xb_inner_assignments):
        for inner_id in inner_ids[:N_CROSS]:
            outer_id = inner_to_outer[inner_id][0]
            centroid = inner_centroids[inner_id]
            vec = xb[idx]
            dist = np.linalg.norm(vec - centroid)
            cos_sim = np.dot(vec, centroid) / (np.linalg.norm(vec) * np.linalg.norm(centroid) + 1e-8)
            vector_metadata[outer_id].setdefault(inner_id, []).append((dist, cos_sim, idx))
    # Sort each list by Euclidean distance to centroid
    for outer_id in vector_metadata:
        for inner_id in vector_metadata[outer_id]:
            vector_metadata[outer_id][inner_id].sort()
    return vector_metadata

def build_outer_to_num_inner_clusters(vector_metadata):
    """
    Build a dictionary mapping each outer cluster id to the number of inner clusters it contains.
    Args:
        vector_metadata: dict[outer_id][inner_id] -> list of vectors
    Returns:
        outer_to_num_inner_clusters: dict[outer_id] -> int (number of inner clusters)
    """
    outer_to_num_inner_clusters = {}
    for outer_id in vector_metadata:
        outer_to_num_inner_clusters[outer_id] = len(vector_metadata[outer_id])

    print("\nNumber of inner clusters for each outer cluster:")
    for outer_id in sorted(outer_to_num_inner_clusters):
        print(f"Outer cluster {outer_id}: {outer_to_num_inner_clusters[outer_id]} inner clusters")
    
    return outer_to_num_inner_clusters

In [34]:
# --- Search Utilities ---
def pad_to_k(arr, k, pad_value):
    arr = list(arr)
    if len(arr) < k:
        arr += [pad_value] * (k - len(arr))
    return arr[:k]

def search_query_cross_pollination(
    x, outer_ids, inner_kmeans, vector_metadata, xb, k, d,
    N_PROBE=1, probe_strategy="nprobe", tshirt_size="small", n_outer_total=10
):
    """
    Probes clusters using either N_PROBE or T-SHIRT size strategy.
    T-SHIRT size: Small/Medium/Large, each probes a percentage of outer and inner clusters.
    """
    # T-SHIRT size settings
    tshirt_settings = {
        "small": 0.10,
        "medium": 0.20,
        "large": 0.30
    }
    best_heap = []
    tau = float("inf")
    probed_inner_ids = set()
    inner_probed = 0

    if probe_strategy == "tshirt":
        pct = tshirt_settings[tshirt_size]
        n_outer_probe = max(1, int(np.ceil(n_outer_total * pct)))
        outer_ids = outer_ids[:n_outer_probe]
    else:
        # nprobe strategy: probe as many outer clusters as needed to reach N_PROBE inner clusters
        n_outer_probe = len(outer_ids)

    outer_idx = 0
    total_outer = len(outer_ids)

    while (probe_strategy == "nprobe" and inner_probed < N_PROBE and outer_idx < total_outer) or \
          (probe_strategy == "tshirt" and outer_idx < total_outer):
        outer_id = outer_ids[outer_idx]
        outer_idx += 1
        if outer_id not in vector_metadata:
            continue
        inner_ids = list(vector_metadata[outer_id].keys())
        if not inner_ids:
            continue
        if probe_strategy == "tshirt":
            n_inner_probe = max(1, int(np.ceil(len(inner_ids) * tshirt_settings[tshirt_size])))
        else:
            n_inner_probe = min(N_PROBE - inner_probed, len(inner_ids))
        inner_ids_to_probe = [iid for iid in inner_ids if iid not in probed_inner_ids][:n_inner_probe]
        if not inner_ids_to_probe:
            continue
        inner_centroids_subset = inner_kmeans.centroids[inner_ids_to_probe]
        index_l2 = faiss.IndexFlatL2(d)
        index_l2.add(inner_centroids_subset)
        _, inner_ranks_local = index_l2.search(x.reshape(1, -1), len(inner_ids_to_probe))
        selected_inner_ids = [inner_ids_to_probe[j] for j in inner_ranks_local[0] if j < len(inner_ids_to_probe)]
        for inner_id in selected_inner_ids:
            probed_inner_ids.add(inner_id)
            idxs_meta = vector_metadata[outer_id][inner_id]
            if not idxs_meta:
                continue
            centroid = inner_kmeans.centroids[inner_id]
            d_qc = np.linalg.norm(x - centroid)
            for dist_ic, cos_theta, idx2 in idxs_meta:
                lower_bound = abs(d_qc - dist_ic)
                if lower_bound > tau:
                    continue
                est_dist = np.sqrt(max(0.0, d_qc ** 2 + dist_ic ** 2 - 2 * d_qc * dist_ic * cos_theta))
                if est_dist > tau:
                    continue
                actual_dist = np.linalg.norm(x - xb[idx2])
                best_heap.append((actual_dist, idx2))
                if len(best_heap) > k:
                    best_heap.sort()
                    best_heap = best_heap[:k]
                    tau = best_heap[-1][0]
            inner_probed += 1
            if probe_strategy == "nprobe" and inner_probed >= N_PROBE:
                break
    return best_heap

def search_all_queries_cross_pollination(
    xq, xq_outer_assignments, inner_kmeans, vector_metadata, xb, k, d, gt,
    N_PROBE=1, probe_strategy="nprobe", tshirt_size="small", n_outer_total=10
):
    I = []
    D = []
    start_time = time.time()
    for i, x in enumerate(xq):
        outer_ids = xq_outer_assignments[i]
        best_heap = search_query_cross_pollination(
            x, outer_ids, inner_kmeans, vector_metadata, xb, k, d,
            N_PROBE=N_PROBE, probe_strategy=probe_strategy, tshirt_size=tshirt_size, n_outer_total=n_outer_total
        )
        if best_heap:
            best_heap.sort()
            idxs = [idx for _, idx in best_heap]
            dists = [dist for dist, _ in best_heap]
            I.append(pad_to_k(idxs, k, -1))  # Use -1 or another invalid index as pad
            D.append(pad_to_k(dists, k, float('inf')))
        else:
            dists = np.linalg.norm(xb - x.reshape(1, -1), axis=1)
            idx = np.argsort(dists)[:k]
            I.append(idx)
            D.append(dists[idx])
    D = np.array(D)
    I = np.array(I)
    elapsed_time = time.time() - start_time
    qps = len(xq) / elapsed_time
    recall = (I == gt[:, :k]).sum() / (gt.shape[0] * k)
    return I, D, recall, qps                                  

In [32]:
# --- Experiment Utilities ---

def evaluate_cross_pollination(
    xb, xq, gt, inner_kmeans, inner_to_outer, xq_outer_assignments, k, d,
    N_PROBE=1, min_cross=1, max_cross=5, probe_strategy="nprobe", tshirt_size="small", n_outer_total=10
):
    recalls = []
    qps_list = []
    cross_range = range(min_cross, max_cross + 1)
    inner_centroids = inner_kmeans.centroids
    for N_CROSS in cross_range:
        print(f"Evaluating N_CROSS = {N_CROSS}")
        xb_inner_assignments = assign_vectors_to_clusters(xb, inner_kmeans, N_CROSS)
        vector_metadata = cross_pollinate_metadata(
            xb, xb_inner_assignments, inner_centroids, inner_to_outer, N_CROSS
        )
        outer_to_num_inner_clusters = build_outer_to_num_inner_clusters(vector_metadata)
        I, D, recall, qps = search_all_queries_cross_pollination(
            xq, xq_outer_assignments, inner_kmeans, vector_metadata, xb, k, d, gt,
            N_PROBE=N_PROBE, probe_strategy=probe_strategy, tshirt_size=tshirt_size, n_outer_total=n_outer_total
        )
        recalls.append(recall)
        qps_list.append(qps)
        print(f"N_CROSS={N_CROSS}: recall={recall:.4f}, qps={qps:.2f}")
    return cross_range, recalls, qps_list

def plot_cross_pollination_results(cross_range, recalls, qps_list):
    fig, ax1 = plt.subplots()
    color = 'tab:blue'
    ax1.set_xlabel('N_CROSS (number of clusters each vector is inserted into)')
    ax1.set_ylabel('Recall', color=color)
    ax1.plot(cross_range, recalls, marker='o', color=color)
    ax1.tick_params(axis='y', labelcolor=color)
    ax2 = ax1.twinx()
    color = 'tab:red'
    ax2.set_ylabel('QPS', color=color)
    ax2.plot(cross_range, qps_list, marker='x', color=color)
    ax2.tick_params(axis='y', labelcolor=color)
    plt.title('Recall and QPS vs N_CROSS (cross-pollination)')
    plt.show()


In [26]:
def run_cross_pollination_experiment(
    dataset_path,
    n_inner_clusters=400,
    probe_strategy="nprobe",
    N_PROBE=2,
    min_cross=1,
    max_cross=6,
    tshirt_size="small"
):
    # Load data
    xb, xq, gt = load_fashion_mnist(dataset_path)
    d = xb.shape[1]
    k = 10

    # Build KMeans
    inner_kmeans = build_kmeans(xb, d, n_inner_clusters)
    _, xq_inner_assignments = inner_kmeans.index.search(xq, 1)
    inner_centroids = inner_kmeans.centroids

    outer_kmeans = build_kmeans(inner_centroids, d, 10)
    _, inner_to_outer = outer_kmeans.index.search(inner_centroids, 1)
    _, xq_outer_assignments = outer_kmeans.index.search(xq, 3)

    # Evaluate cross-pollination
    cross_range, recalls, qps_list = evaluate_cross_pollination(
        xb, xq, gt, inner_kmeans, inner_to_outer, xq_outer_assignments, k, d,
        N_PROBE=N_PROBE, min_cross=min_cross, max_cross=max_cross, probe_strategy=probe_strategy,
        tshirt_size=tshirt_size, n_outer_total=10
    )

    # Plot results
    plot_cross_pollination_results(cross_range, recalls, qps_list)
    return cross_range, recalls, qps_list

In [None]:
# Specify which dataset to use: "fashion-mnist", "gist", or "sift"
selected_dataset = "fashion-mnist"

DATA_URLS = {
    "fashion-mnist": "http://ann-benchmarks.com/fashion-mnist-784-euclidean.hdf5",
    "gist": "http://ann-benchmarks.com/gist-960-euclidean.hdf5",
    "sift": "http://ann-benchmarks.com/sift-128-euclidean.hdf5"
}
CACHES = {
    name: os.path.join(tempfile.gettempdir(), url.split('/')[-1])
    for name, url in DATA_URLS.items()
}

# Download the selected dataset if not present
cache_path = CACHES[selected_dataset]
download_fashion_mnist(cache_path, DATA_URLS[selected_dataset])

# Run cross-pollination experiment for the selected dataset
print(f"\n=== Running experiment for {selected_dataset} ===")
run_cross_pollination_experiment(
    dataset_path=cache_path,
    n_inner_clusters=100,
    probe_strategy="nprobe",
    N_PROBE=5,
    min_cross=4,
    max_cross=5,
    tshirt_size="small"
)


=== Running experiment for fashion-mnist ===
Sampling a subset of 25600 / 50000 for training
Clustering 25600 points in 784D to 100 clusters, redo 1 times, 20 iterations
  Preprocessing in 0.02 s
  Iteration 19 (0.74 s, search 0.64 s): objective=3.38648e+10 imbalance=1.173 nsplit=0       
Clustering 100 points in 784D to 10 clusters, redo 1 times, 20 iterations
  Preprocessing in 0.00 s
  Iteration 19 (0.01 s, search 0.00 s): objective=1.01799e+08 imbalance=1.434 nsplit=0       
Evaluating N_CROSS = 4





Number of inner clusters for each outer cluster:
Outer cluster 0: 22 inner clusters
Outer cluster 1: 5 inner clusters
Outer cluster 2: 6 inner clusters
Outer cluster 3: 10 inner clusters
Outer cluster 4: 23 inner clusters
Outer cluster 5: 4 inner clusters
Outer cluster 6: 5 inner clusters
Outer cluster 7: 11 inner clusters
Outer cluster 8: 7 inner clusters
Outer cluster 9: 7 inner clusters
N_CROSS=4: recall=0.1746, qps=15.37
Evaluating N_CROSS = 5

Number of inner clusters for each outer cluster:
Outer cluster 0: 22 inner clusters
Outer cluster 1: 5 inner clusters
Outer cluster 2: 6 inner clusters
Outer cluster 3: 10 inner clusters
Outer cluster 4: 23 inner clusters
Outer cluster 5: 4 inner clusters
Outer cluster 6: 5 inner clusters
Outer cluster 7: 11 inner clusters
Outer cluster 8: 7 inner clusters
Outer cluster 9: 7 inner clusters


In [12]:
# Parameters for PQ
pq_m = 8      # Number of PQ sub-vectors (should divide d)
pq_nbits = 8  # Bits per sub-vector

# --- Build PQ codebooks and encode inner clusters ---
def build_pq_for_inner_clusters(xb, xb_inner_assignments, inner_centroids):
    pq_codebooks = {}
    pq_codes = {}
    for inner_id in range(inner_centroids.shape[0]):
        # Get all vectors assigned to this inner cluster
        idxs = np.where(xb_inner_assignments[:,0] == inner_id)[0]
        if len(idxs) == 0:
            continue
        # Compute residuals
        residuals = xb[idxs] - inner_centroids[inner_id]
        # Train PQ on residuals
        pq = faiss.ProductQuantizer(residuals.shape[1], pq_m, pq_nbits)
        pq.train(residuals)
        # Encode residuals
        codes = np.empty((len(residuals), pq.code_size), dtype='uint8')
        pq.compute_codes(residuals, codes)
        pq_codebooks[inner_id] = pq
        pq_codes[inner_id] = (idxs, codes)
    return pq_codebooks, pq_codes

In [13]:
# --- Search using PQ codes in inner clusters ---
def search_query_pq_inner(x, inner_kmeans, inner_centroids, pq_codebooks, pq_codes, xb, k):
    # Find closest inner cluster
    _, inner_assign = inner_kmeans.index.search(x.reshape(1, -1), 1)
    inner_id = int(inner_assign[0,0])
    if inner_id not in pq_codebooks:
        # fallback to brute-force
        dists = np.linalg.norm(xb - x.reshape(1, -1), axis=1)
        idx = np.argsort(dists)[:k]
        return idx, dists[idx]
    # Compute residual for query
    residual = x - inner_centroids[inner_id]
    pq = pq_codebooks[inner_id]
    idxs, codes = pq_codes[inner_id]
    # Compute distances from query residual to all PQ codes in this cluster
    dis = pq.compute_distance_table(residual.reshape(1, -1)).reshape(-1, pq.ksub * pq.M)
    # Use FAISS's search with PQ codes
    # For each code, sum the lookup table entries
    lookup = pq.compute_distance_table(residual.reshape(1, -1)).reshape(pq.M, pq.ksub)
    dists = np.zeros(len(codes), dtype='float32')
    for i, code in enumerate(codes):
        dists[i] = sum(lookup[m, code[m]] for m in range(pq.M))
    topk = np.argsort(dists)[:k]
    return idxs[topk], dists[topk]

In [15]:
# --- Product Quantization Example ---
# Assign each vector to its closest inner cluster
_, xb_inner_assignments = inner_kmeans.index.search(xb, 1)

# Build PQ codebooks and codes for each inner cluster
pq_codebooks, pq_codes = build_pq_for_inner_clusters(xb, xb_inner_assignments, inner_centroids)

# Search all queries using PQ in inner clusters
I = []
D = []
for x in xq:
    idxs, dists = search_query_pq_inner(x, inner_kmeans, inner_centroids, pq_codebooks, pq_codes, xb, k)
    I.append(idxs)
    D.append(dists)
I = np.array(I)
D = np.array(D)
recall = (I == gt[:, :k]).sum() / (gt.shape[0] * k)
print(f"Recall@{k} using PQ in inner clusters: {recall:.4f}")

RuntimeError: Error in void faiss::Clustering::train_encoded(idx_t, const uint8_t *, const Index *, Index &, const float *) at /Users/runner/miniconda3/conda-bld/faiss-pkg_1745590552381/work/faiss/Clustering.cpp:279: Error: 'nx >= k' failed: Number of training points (82) should be at least as large as number of clusters (256)