In [None]:
# In this notebook we will implement product quantization (PQ) using simple, readable, Python code.
x = [1, 8, 3, 9, 1, 2, 9, 4, 5, 4, 6, 2]

# The first step is the creation of m subvectors:
m = 4
D = len(x)

# ensure D is divisible by m
assert D % m == 0, f"D={D} must be divisible by m={m}"

# length of each subvector will be D / m (D* in notation)
# floor division keeps the result an exact int without a float round-trip
D_ = D // m

In [None]:
# now create the subvectors
# pair every chunk's start offset with its end offset (D_ values further on) and slice x between them
u = [x[lo:hi] for lo, hi in zip(range(0, D, D_), range(D_, D + D_, D_))]
print(u)

In [None]:
# Now we must create a set of clusters for each subvector space - giving us m separate codebooks (each codebook maps our subvectors to their assigned cluster centroids - the reproduction values).
# The clusters would usually be trained, but we will not do that here as this example uses only one vector. We will use randomly generated centroid positions instead.
# We need to decide how many centroids to create - more centroids == lower error between vector positions and the centroids they are assigned to (more centroids increases the chances of vectors being assigned to a closer centroid).
# This value is set by k, which must be divisible by m to create equal (sub)centroid ranges for each subvector.

k = 2**5
assert k % m == 0, f"k={k} must be divisible by m={m}"
# number of centroids per subspace; floor division keeps this an exact int
k_ = k // m
print(f"{k=}, {k_=}")

# We have 32 centroids in total, and 8 centroids per subvector space (subspace).
# Each of these centroids will have three dimensions - aligned to our subvector dimensionality. Let's generate them.

In [None]:
from random import randint, seed

# Fix the RNG seed so the "random" centroids (and everything derived from them
# further down: ids, q, the MSE) are identical on every Restart & Run All.
seed(0)

c = []  # our overall list of reproduction values
for j in range(m):
    # each j represents a subvector (and therefore subquantizer) position;
    # build its k_ centroids, each a list of D_ random integer coordinates in 0..9
    c_j = [[randint(0, 9) for _ in range(D_)] for _ in range(k_)]
    # add subspace list of centroids to overall list
    c.append(c_j)

In [None]:
# There are a lot of centroids in here so the easiest way for us to see them is to visualize:
import matplotlib.pyplot as plt

fig = plt.figure()

for j in range(m):
    # one 3D panel per subspace, placed on a 2x2 grid
    ax = fig.add_subplot(2, 2, j + 1, projection='3d')
    # collect the first, second and third coordinate of every centroid in subspace j
    X, Y, Z = ([c[j][i][dim] for i in range(k_)] for dim in range(3))
    # plot
    ax.scatter(X, Y, Z)
    ax.set_title(f"c_{j}")
    # remove tick values (they're messy)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_ticklabels([])

In [None]:
# These are the centroids for each of our subspaces, subvector u_0 will be mapped to a centroid within subspace c_0, u_1 to c_1, etc, etc.
# Let's go ahead and do this. First, we will define a function to find the nearest centroid using Euclidean distance.

def euclidean(v, u):
    """Return the Euclidean (L2) distance between vectors `v` and `u`.

    Components are paired positionally with zip, so if the inputs differ in
    length the trailing components of the longer one are ignored.
    """
    distance = sum((x - y) ** 2 for x, y in zip(v, u)) ** .5
    return distance

def nearest(c_j, u_j):
    """Return the index of the centroid in `c_j` that is closest to `u_j`.

    Args:
        c_j: sequence of centroids (each a sequence of numbers) for one subspace.
        u_j: the subvector to assign to a centroid.

    Returns:
        The position in `c_j` of the centroid with the smallest Euclidean
        distance to `u_j`. On ties the first (lowest-index) centroid wins.

    Raises:
        ValueError: if `c_j` contains no centroids.
    """
    if len(c_j) == 0:
        raise ValueError("c_j must contain at least one centroid")
    nearest_idx = 0
    distance = float("inf")
    # iterate over the codebook that was actually passed in, rather than
    # relying on a global centroid count matching its length
    for i, centroid in enumerate(c_j):
        new_dist = euclidean(centroid, u_j)
        if new_dist < distance:
            nearest_idx = i
            distance = new_dist
    return nearest_idx

In [None]:
# And now we calculate the nearest centroids for each subspace.
# ids[j] is the position, within codebook c[j], of the centroid closest to subvector u[j]
ids = [nearest(c[j], u[j]) for j in range(m)]
ids

In [None]:
# Finally, we need a way to translate these IDs back into the centroid co-ordinates. We already have one: our codebook c.
# When it comes to comparing vectors we don't use the centroid IDs, we use the centroids themselves (our reproduction values).
# q is the concatenation, in subspace order, of the centroid selected for each subspace
q = [value for j in range(m) for value in c[j][ids[j]]]
q

In [None]:
# We typically measure the error between our quantized vectors q and the originals x using mean squared error (MSE):
def mse(v, u):
    """Return the mean squared error between equal-length vectors `v` and `u`.

    Args:
        v: sequence of numbers.
        u: sequence of numbers, same length as `v`.

    Returns:
        The sum of squared component-wise differences divided by the length.

    Raises:
        ValueError: if the vectors differ in length (zip would silently drop
            the extra components while the sum was still divided by len(v)),
            or if they are empty.
    """
    if len(v) != len(u):
        raise ValueError(
            f"vectors must have the same length, got {len(v)} and {len(u)}"
        )
    if len(v) == 0:
        raise ValueError("cannot compute the MSE of empty vectors")
    error = sum((x - y) ** 2 for x, y in zip(v, u)) / len(v)
    return error
mse(x, q)

# When using many vectors, we can minimize the MSE between our original vectors and the centroids by increasing the number of centroids. However, this will increase index size and so must be balanced.
# Lower MSE == more accurate search results and higher memory usage.

In [None]:
# Now let's try PQ with FAISS

import numpy as np

# now define a function to read the fvecs file format of Sift1M dataset
def read_fvecs(fp):
    """Load the file at `fp` as a 2D float32 array, one vector per row.

    The file is read as a flat run of 32-bit integers. The first integer is
    taken as the dimensionality d, the data is cut into records of d + 1
    values, the leading value of every record is dropped, and the remaining
    d values per record are reinterpreted as float32.
    """
    raw = np.fromfile(fp, dtype='int32')
    dim = raw[0]
    records = raw.reshape(-1, dim + 1)
    # copy so the sliced payload is contiguous before its bytes are reinterpreted
    payload = records[:, 1:].copy()
    return payload.view('float32')

# 1M samples, cut down to 500K
# (only the first 500,000 rows of the base file are kept as the vectors to index)
xb = read_fvecs('/mnt/sda/vectors/sift/sift_base.fvecs')[:500_000]
# queries
# (take just the first query vector and reshape it into a single-row 2D array)
xq = read_fvecs('/mnt/sda/vectors/sift/sift_query.fvecs')[0].reshape(1, -1)

# show the shape of the base vectors and of the query
print(xb.shape)
print(xq.shape)

In [None]:
# Our first index is a pure PQ implementation using IndexPQ. To initialize the index we need to define three parameters.
import faiss


# We have our vector dimensionality D, the number of subvectors we’d like to split our full vectors into (we must assert that D is divisible by m).

# Finally, we include the nbits parameter. This defines the number of bits that each subquantizer can use, we can translate this into the number
# of centroids assigned to each subspace as k_ = 2**nbits. An nbits of 8 leaves us with 256 centroids per subspace.
# NOTE: D and m are rebound here; they previously described the 12-dimensional toy vector (D = 12, m = 4).
D = xb.shape[1]
m = 8
assert D % m == 0
nbits = 8  # number of bits per subquantizer, k* = 2**nbits
index = faiss.IndexPQ(D, m, nbits)

In [None]:
# Time to train
# print the index's is_trained flag before and after training on xb
print(index.is_trained)
index.train(xb)
print(index.is_trained)

In [None]:
index.add(xb)  # this is also very slow for large nbits
# NOTE: this rebinds k - earlier it was the total number of centroids (2**5);
# from here on it is the number of results requested from each search
k = 100  # return top k results

In [None]:
# query the PQ index for the k nearest stored vectors to xq; per the FAISS API the
# result unpacks into distances (dist) and the ids of the matched vectors (I), one row per query
dist, I = index.search(xq, k)

In [None]:
%%timeit
index.search(xq, k)

In [None]:
# Search time is nothing special, PQ alone is still an exhaustive search so we would expect nothing spectacular here - but we can make it fast as we'll see later.
# Let's compare our results against those produced by a non-quantized flat index.
# non-quantized flat L2 index over the same vectors; its results are the baseline
# that the quantized indexes are compared against
l2_index = faiss.IndexFlatL2(D)
l2_index.add(xb)

In [None]:
%%time
l2_dist, l2_I = l2_index.search(xq, k)

In [None]:
# count how many of the ids returned for the query also appear in the flat index's results
sum(1 for found_id in I[0] if found_id in l2_I)

In [None]:
# A recall of 50%, not cutting-edge, but a reasonable sacrifice if this allows us to search larger datasets.
# Let's see if PQ has made good on its promise of reduced memory usage.

dir_path = "/mnt/sda/vectors/perf/"

import os
def get_memory(filename, index):
    """Return the size in bytes of `index` when written to disk.

    The index is serialized to `filename` with faiss.write_index, the size of
    that file is measured, and the file is removed again.

    Args:
        filename: path of the temporary file to write (overwritten if present).
        index: the FAISS index to measure.

    Returns:
        The file size in bytes as reported by os.path.getsize.
    """
    try:
        faiss.write_index(index, filename)
        return os.path.getsize(filename)
    finally:
        # always clean up, even if writing or measuring failed part-way,
        # so no stray temp file is left behind
        if os.path.exists(filename):
            os.remove(filename)

# serialized size in bytes: the flat L2 index first, then the PQ index
print(get_memory(dir_path + "temp.index", l2_index))
print(get_memory(dir_path + "temp.index", index))

In [None]:
# This is a reduction from 256 MB to 4.1 MB
# But what about this not so great search-speed? Is there anything we can do about that? Fortunately yes!
# We can improve search-speed by using another quantization step - we add a coarse quantizer, IndexIVF to the process.

# The flat L2 index created below is handed to IndexIVFPQ to act as that coarse quantizer:
# vectors are grouped into nlist Voronoi cells, and nprobe (set further down) controls how many cells are searched.

vecs = faiss.IndexFlatL2(D)  # flat L2 index passed to IndexIVFPQ as its first argument

nlist = 2048  # how many Voronoi cells (must be >= k* which is 2**nbits)
nbits = 8  # when using IVF+PQ, higher nbits values are not supported
# NOTE: this rebinds `index`, replacing the IndexPQ built earlier
index = faiss.IndexIVFPQ(vecs, D, nlist, m, nbits)
print(f"{2**nbits=}")  # k* (centroids per subspace), to compare against nlist - this is not nlist itself

In [None]:
# is_trained flag before training
print(index.is_trained)
# train the IVFPQ index on xb, then add the same vectors to it
index.train(xb)
index.add(xb)
# is_trained flag after training and adding
print(index.is_trained)

In [43]:
%%timeit
dist, I = index.search(xq, k)

197 µs ± 1.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [44]:
# Lightning fast search... But how is the recall?
# Names assigned inside a %%timeit cell are not kept in the notebook namespace, so
# without this call `I` would still hold the results of the earlier IndexPQ search.
dist, I = index.search(xq, k)
sum([1 for i in I[0] if i in l2_I])

41

In [45]:
# We can improve the recall by increasing nprobe
index.nprobe = 2
dist, I = index.search(xq, k)
# number of returned ids that also appear in the flat index's results
print(sum([1 for i in I[0] if i in l2_I]))

41


In [48]:
# Raise nprobe much further (from 2 to 48) and measure the recall again
index.nprobe = 48
dist, I = index.search(xq, k)
# number of returned ids that also appear in the flat index's results
print(sum([1 for i in I[0] if i in l2_I]))

47


In [49]:
%%timeit
dist, I = index.search(xq, k)

274 µs ± 4.01 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [50]:
%%timeit
l2_dist, l2_I = l2_index.search(xq, k)

44.4 ms ± 736 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [52]:
# A significant speed increase from 44ms to 274µs, and what are the differences in memory usage?
# serialized size in bytes: the flat L2 index first, then the IVFPQ index
print(get_memory(dir_path + "temp.index", l2_index))
print(get_memory(dir_path + "temp.index", index))

256000045
9196212


In [None]:
# This is a reduction from 256 MB to 9.2 MB. Slightly more than 4.1 MB but still worth it.