In [1]:
import numpy as np
from sklearn.cluster import KMeans

In [4]:
class ProductQuantizer:
    """
    Product quantizer for fixed-length vectors.

    Each vector is split into ``n_subvectors`` contiguous, equal-width
    subvectors. One KMeans codebook per subvector position maps a subvector
    to the index of its nearest centroid, so a vector is stored as
    ``n_subvectors`` integer codes and reconstructed by concatenating the
    selected centroids.
    """

    def __init__(self, n_subvectors, n_clusters, random_state=None):
        """
        Initialize the Product Quantizer.
        :param n_subvectors: Number of subvectors to divide each vector into.
            Must be >= 1 and must evenly divide the feature dimension.
        :param n_clusters: Number of clusters (codebook size) for each subvector.
        :param random_state: Optional seed forwarded to every KMeans codebook so
            training is reproducible; ``None`` keeps KMeans' default behaviour.
        :raises ValueError: If ``n_subvectors`` is less than 1.
        """
        if n_subvectors < 1:
            raise ValueError(f"n_subvectors must be >= 1, got {n_subvectors}")
        self.n_subvectors = n_subvectors
        self.n_clusters = n_clusters
        self.kmeans = [
            KMeans(n_clusters=n_clusters, random_state=random_state)
            for _ in range(n_subvectors)
        ]

    def _subvector_len(self, data):
        """Validate a 2-D (n_samples, n_features) input and return the subvector width."""
        if data.ndim != 2:
            raise ValueError(
                "expected a 2-D array of shape (n_samples, n_features), "
                f"got a {data.ndim}-D array"
            )
        n_features = data.shape[1]
        remainder = n_features % self.n_subvectors
        if remainder:
            # Integer division would otherwise silently discard the last
            # `remainder` features and make reconstructions shorter than inputs.
            raise ValueError(
                f"n_features={n_features} is not divisible by "
                f"n_subvectors={self.n_subvectors}; the trailing {remainder} "
                "feature(s) would be silently dropped"
            )
        return n_features // self.n_subvectors

    def fit(self, data):
        """
        Fit one KMeans codebook per subvector position.
        :param data: Array of shape (n_samples, n_features); n_features must be
            divisible by ``n_subvectors``.
        :return: self, so calls can be chained.
        :raises ValueError: If ``data`` is not 2-D or cannot be split evenly.
        """
        data = np.asarray(data)
        width = self._subvector_len(data)
        for i in range(self.n_subvectors):
            start = i * width
            self.kmeans[i].fit(data[:, start : start + width])
        return self

    def quantize(self, data):
        """
        Quantize the data using the fitted model.
        :param data: Array of shape (n_samples, n_features) with the same
            n_features the model was fitted on.
        :return: Integer codes of shape (n_samples, n_subvectors); entry [r, i]
            is the centroid index of row r's i-th subvector.
        :raises ValueError: If ``data`` is not 2-D or cannot be split evenly.
        """
        data = np.asarray(data)
        width = self._subvector_len(data)
        codes = [
            self.kmeans[i].predict(data[:, i * width : (i + 1) * width])
            for i in range(self.n_subvectors)
        ]
        # One column per subvector position.
        return np.column_stack(codes)

    def inverse_transform(self, quantized_data):
        """
        Convert quantized data back to approximate vectors.
        :param quantized_data: Integer codes of shape (n_samples, n_subvectors),
            as returned by ``quantize``.
        :return: Array of shape (n_samples, subvector_len * n_subvectors) built
            by concatenating the selected centroids.
        :raises ValueError: If ``quantized_data`` does not have one column per
            subvector.
        """
        codes = np.asarray(quantized_data)
        if codes.ndim != 2 or codes.shape[1] != self.n_subvectors:
            raise ValueError(
                f"expected codes of shape (n_samples, {self.n_subvectors}), "
                f"got {codes.shape}"
            )
        # Every codebook was trained on subvectors of the same width.
        width = self.kmeans[0].cluster_centers_.shape[1]
        approx_data = np.zeros((codes.shape[0], width * self.n_subvectors))
        for i in range(self.n_subvectors):
            start = i * width
            approx_data[:, start : start + width] = self.kmeans[i].cluster_centers_[
                codes[:, i]
            ]
        return approx_data

In [5]:
# Synthetic dataset: 100 vectors with 64 features each, drawn uniformly
# from [0, 1). Seeding numpy's legacy global RNG makes the sample
# reproducible on Restart & Run All.
np.random.seed(seed=0)
data = np.random.random_sample(size=(100, 64))

In [6]:
# Split each 64-d vector into 8 subvectors, each encoded with a 10-centroid codebook.
pq = ProductQuantizer(n_subvectors=8, n_clusters=10)
pq.fit(data)
quantized_data = pq.quantize(data)
approx_data = pq.inverse_transform(quantized_data)

# Show the first vector at each stage: raw input, PQ codes, reconstruction.
for label, rows in (
    ("Original data (first vector):", data),
    ("Compressed data (first vector):", quantized_data),
    ("Approximated data (first vector):", approx_data),
):
    print(label, rows[0])



Original data (first vector): [0.5488135  0.71518937 0.60276338 0.54488318 0.4236548  0.64589411
 0.43758721 0.891773   0.96366276 0.38344152 0.79172504 0.52889492
 0.56804456 0.92559664 0.07103606 0.0871293  0.0202184  0.83261985
 0.77815675 0.87001215 0.97861834 0.79915856 0.46147936 0.78052918
 0.11827443 0.63992102 0.14335329 0.94466892 0.52184832 0.41466194
 0.26455561 0.77423369 0.45615033 0.56843395 0.0187898  0.6176355
 0.61209572 0.616934   0.94374808 0.6818203  0.3595079  0.43703195
 0.6976312  0.06022547 0.66676672 0.67063787 0.21038256 0.1289263
 0.31542835 0.36371077 0.57019677 0.43860151 0.98837384 0.10204481
 0.20887676 0.16130952 0.65310833 0.2532916  0.46631077 0.24442559
 0.15896958 0.11037514 0.65632959 0.13818295]
Compressed data (first vector): [0 6 6 4 2 2 3 8]
Approximated data (first vector): [0.54019654 0.36485555 0.80442656 0.52945912 0.26097166 0.43403487
 0.35285817 0.79455641 0.5989435  0.21066257 0.73069621 0.33785665
 0.68675651 0.60267202 0.57986277 0.37

In [7]:
# Nearest neighbor search using quantized vectors
def nearest_neighbor(query, quantized_data, pq):
    """
    Find the database vector closest to ``query`` using PQ reconstructions.

    The comparison happens in reconstructed (centroid) space on both sides.
    The query is quantized and decoded with ``pq``, and the database codes
    in ``quantized_data`` are decoded with the same quantizer.

    :param query: 1-D vector with the dimensionality the quantizer was fitted on.
    :param quantized_data: PQ codes of shape (n_samples, n_subvectors), as
        returned by ``pq.quantize``.
    :param pq: Fitted quantizer providing ``quantize`` and ``inverse_transform``.
    :return: Row index into ``quantized_data`` of the nearest vector.
    :raises ValueError: If ``quantized_data`` contains no rows.
    """
    # Decode the codes passed in (not a notebook-global reconstruction) so the
    # returned index always refers to the `quantized_data` argument.
    database = pq.inverse_transform(quantized_data)
    if len(database) == 0:
        raise ValueError("quantized_data is empty; nothing to search")
    approx_query = pq.inverse_transform(
        pq.quantize(np.asarray(query).reshape(1, -1))
    )[0]
    distances = np.linalg.norm(database - approx_query, axis=1)
    return np.argmin(distances)

In [8]:
# Example query: a fresh random vector with the same dimensionality as the
# indexed data (taken from `data` rather than hard-coding 64, so this cell
# keeps working if the dataset shape changes).
query = np.random.rand(data.shape[1])
nn_index = nearest_neighbor(query, quantized_data, pq)
print(f"Nearest neighbor index for the query: {nn_index}")

Nearest neighbor index for the query: 58
