More comments and support for multiprobe inserts

tahle committed Apr 11, 2023
1 parent ef1927f commit 2df6a42
Showing 5 changed files with 281 additions and 139 deletions.
43 changes: 43 additions & 0 deletions examples/multiprobes.py
@@ -0,0 +1,43 @@
import numpy as np
from fast_pq import IVF, FastPQ, cdist

np.random.seed(10)

n = 1000
d = 10
nq = 30
at = 10
dpb = 2

X = np.random.randn(n, d).astype(np.float32)
qs = np.random.randn(nq, d).astype(np.float32)

def compute_recall(metric, build_probes, query_probes):
    if at < n:
        trus = cdist(qs, X).argpartition(axis=1, kth=at)[:, :at]
    else:
        trus = np.broadcast_to(np.arange(n), (nq, n))
    pq = FastPQ(dims_per_block=dpb)
    ivf = IVF(metric, int(n**0.5), pq)
    ivf.fit(X).build(X, n_probes=build_probes)
    recall_at = 0
    for q, tru in zip(qs, trus):
        guess = ivf.query(q, k=at, n_probes=query_probes)
        recall_at += len(set(guess) & set(tru))
    return recall_at / nq / at

# Print header row
print(f"Recall {at}@{at} using build_probes=b and query_probes=q.")
print("b/q", end=' ')
for query_probes in range(1, 5):
    print(f"{query_probes:5}", end=' ')
print()

# Print table content
for build_probes in range(1, 5):
    print(f"{build_probes:4}", end=' ')
    for query_probes in range(1, 5):
        recall = compute_recall("euclidean", build_probes, query_probes)
        print(f"{recall:.2f}", end=", ")
    print()
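
The build_probes knob appears to mirror query_probes: instead of storing each point only in its single nearest cluster, build(X, n_probes=b) inserts it into its b nearest clusters, trading index size for recall. A minimal numpy sketch of that assignment step, assuming centroids are already fitted (multiprobe_build and its names are illustrative, not the library's API):

import numpy as np

def multiprobe_build(X, centroids, n_probes):
    # Squared distance from every point to every centroid.
    d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    # Each point joins the inverted lists of its n_probes nearest centroids.
    nearest = np.argpartition(d2, n_probes, axis=1)[:, :n_probes]
    lists = {c: [] for c in range(len(centroids))}
    for i, row in enumerate(nearest):
        for c in row:
            lists[c].append(i)
    return lists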

19 changes: 17 additions & 2 deletions fast_pq/_fast_pq.pyx
@@ -112,8 +112,23 @@ cpdef void estimate_pq_sse(uint64_t[:,::1] data, uint64_t[::1] tables,

cpdef void query_pq_sse(uint64_t[:,::1] data, uint64_t[::1] tables, int[::1] indices,
                        int[::1] vals, bool signd) nogil:
    ''' Given a N x D dataset quantized into byte sizes chunks,
    looks up each value in a table out outputs into `out`. '''
    '''
    Given a N x D dataset quantized into byte-sized chunks,
    looks up each value in a table and outputs into `out`.

    Parameters
    ----------
    data : 2D memoryview of uint64_t
        The quantized dataset.
    tables : 1D memoryview of uint64_t
        The lookup tables.
    indices : 1D memoryview of int
        The output indices.
    vals : 1D memoryview of int
        The output values.
    signd : bool
        Determines if the distance is signed or unsigned.
    '''
    cdef:
        int i, j
        __m128i block_dists, top_bound, cmp_mask
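
The indices and vals arguments are caller-allocated output buffers. A sketch of the call, mirroring how ctop() in fast_pq.py (below) uses it; the packed codes and lookup tables are assumed to come from FastPQ.transform and distance_table:

import numpy as np
from fast_pq import FastPQ
from fast_pq._fast_pq import query_pq_sse

X = np.random.randn(100, 8).astype(np.float32)
pq = FastPQ(dims_per_block=2)
true_n, packed = pq.fit_transform(X)       # quantized codes, padded to a multiple of 16 rows
dtable = pq.distance_table(X[0])           # per-block lookup tables for one query

rescore = 20
indices = np.zeros((rescore,), dtype=np.int32)  # filled with candidate row ids
values = np.zeros((rescore,), dtype=np.int32)   # filled with quantized distances
query_pq_sse(packed, dtable.tables, indices, values, True)  # signd=True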
177 changes: 126 additions & 51 deletions fast_pq/fast_pq.py
@@ -1,10 +1,36 @@
import numpy as np
import sklearn.cluster

import warnings
from sklearn.exceptions import ConvergenceWarning

from ._transform import transform_data, transform_tables
from ._fast_pq import query_pq_sse, estimate_pq_sse

warnings.simplefilter("error", category=ConvergenceWarning)

def pad(arr, mults):
    """
    Pad an input array with zeros such that its dimensions are multiples of the specified values.

    Parameters
    ----------
    arr : numpy.ndarray
        The input array to be padded.
    mults : tuple of int
        A tuple containing the desired multiples for each dimension of the input array.
        The length of the tuple must match the number of dimensions in the input array.

    Returns
    -------
    new_arr : numpy.ndarray
        The padded array with dimensions that are multiples of the specified values.

    Notes
    -----
    The padding is added by extending the dimensions with zeros. The original data
    remains unchanged and is located at the beginning of each dimension in the output array.
    """
    new_shape = tuple(s + (-s) % m for s, m in zip(arr.shape, mults))
    # TODO: It would be nice to pad using the code with the largest possible distance
    # in each dimension, so as to maximize the likelihood of not seeing any of them
@@ -15,31 +41,60 @@ def pad(arr, mults):


def bottom_k(arr, k):
    """Returns the indices of the k smallest values in arr."""
    if k >= len(arr):
        return np.arange(len(arr))
    return np.argpartition(arr, k)[:k]
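
For reference, a quick sketch of what these two helpers do:

import numpy as np

a = np.ones((10, 5), dtype=np.float32)
print(pad(a, (16, 4)).shape)   # (16, 8): rows padded up to 16, columns up to 8

d = np.array([0.5, 0.1, 0.9, 0.3])
print(bottom_k(d, 2))          # indices of the two smallest values, in some order: {1, 3}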


class FastPQ:
    def __init__(self, dims_per_block):
        """
        Initializes the FastPQ class with the specified number of dimensions per block.

        Parameters
        ----------
        dims_per_block : int
            The number of dimensions per block.
        """
        self.dims_per_block = dims_per_block
        self.centers = None  # Shape: (n_blocks, 16, dims_per_block)
        self.center_norms_sq = None  # Shape: (n_blocks, 16)

    def fit(self, data, verbose=False):
        """
        Fits the FastPQ model to the given data.

        Parameters
        ----------
        data : array-like
            The input data to fit the FastPQ model.
        verbose : bool, optional, default=False
            If True, prints additional information during the fitting process.

        Returns
        -------
        self : FastPQ
            The fitted FastPQ model.
        """
        assert data.size > 0, "Can't fit no data"
        # SSE assumes the number of rows is divisible by 16.
        # It also needs the number of columns to be even, so we pad to a multiple
        # of 2 * self.dims_per_block.
        data = pad(data, (16, 2 * self.dims_per_block))
        n, d = data.shape
        assert d % self.dims_per_block == 0
        assert (d // self.dims_per_block) % 2 == 0
        dpb = self.dims_per_block
        cl = sklearn.cluster.KMeans(16, n_init=1)
        # We always use 16 clusters in FastPQ, since we want to use 4-bit SSE operations.
        cl = sklearn.cluster.KMeans(16, n_init=2)
        centers = []
        for i in range(d // self.dims_per_block):
            if verbose:
                print(f"Fitting block {i}")
            cl.fit(data[:, i * dpb : (i + 1) * dpb])
            try:
                cl.fit(data[:, i * dpb : (i + 1) * dpb])
            except ConvergenceWarning:
                pass
            # It doesn't give too much precision to do separate centers for each block,
            # but during queries we need separate distance tables per block anyway, so
            # it doesn't cost us much.
@@ -49,6 +104,19 @@ def fit(self, data, verbose=False):
        return self

    def transform(self, data):
        """
        Transforms the given data using the FastPQ model.

        Parameters
        ----------
        data : array-like
            The input data to transform using the FastPQ model.

        Returns
        -------
        tuple
            A tuple containing the transformed data and the number of true elements in the data.
        """
        assert self.centers is not None, "PQ has not been fitted"
        if data.size == 0:
            return data
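
Putting fit and transform together, mirroring the shapes used in examples/multiprobes.py above (a sketch, not part of the diff; note that ctop below unpacks the returned tuple as (true_n, packed)):

import numpy as np
from fast_pq import FastPQ

X = np.random.randn(1000, 10).astype(np.float32)
pq = FastPQ(dims_per_block=2)          # 10 dims are zero-padded to 12, giving 6 blocks
true_n, packed = pq.fit(X).transform(X)
print(true_n)                          # 1000; the packed codes are padded to a multiple of 16 rows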
@@ -67,38 +135,59 @@ def fit_transform(self, data):
        return self.fit(data).transform(data)

    def distance_table(self, q):
        """
        Computes the distance table for the given query vector q.

        Parameters
        ----------
        q : array-like
            The query vector.

        Returns
        -------
        _FastDistanceTable
            The FastDistanceTable object containing the computed distance table.
        """
        q = pad(q, (2 * self.dims_per_block,))
        dpb = self.dims_per_block
        n_blocks = q.size / dpb

        parts = q.reshape(-1, dpb)

        # Center the data in the range [-128, 128]
        # TODO: Is this the best scaling formula?
        # TODO: Would this be faster if we used the same centers for each block?
        dists = self.center_norms_sq - 2 * np.einsum("ijk,ik->ij", self.centers, parts)
        # dists += (parts * parts).sum(axis=1, keepdims=True)
        # shift = np.mean(dists)
        # print(np.mean(dists), np.median(dists))
        # shift = 1
        # shift = 128 / n_blocks
        # scale = 128 / (np.max(-(dists-shift)) * np.sqrt(n_blocks))
        # shift = np.mean(dists) / 2

        # We do this by shifting by the median and scaling by the number of blocks.
        # The idea is that we don't want to have an overflow as we add together
        # distances in uint8 format.
        # TODO: Is this the best scaling formula?
        shift = np.median(dists)
        # shift = 0
        # scale = 1
        scale = 128 / (np.max(np.abs(dists - shift)) * np.sqrt(n_blocks))
        table = np.round(
            (dists - shift) * scale
        )  # Round to nearest integer towards zero.

        # Round to nearest integer towards zero.
        table = np.round((dists - shift) * scale)

        # The transformation doesn't care about the sign, so we just use uint
        table = table.astype(np.uint8)
        trans = transform_tables(table)
        return _FastDistanceTable(q, trans, shift, scale, signed=True)
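
A small numeric sketch of the shift-and-scale step, under an assumed 25 blocks, showing why summing the quantized per-block distances tends to stay inside an 8-bit accumulator:

import numpy as np

n_blocks = 25
dists = np.random.randn(n_blocks, 16)   # stand-in for the per-block distance table
shift = np.median(dists)
scale = 128 / (np.max(np.abs(dists - shift)) * np.sqrt(n_blocks))
table = np.round((dists - shift) * scale)
# Each entry is bounded by 128 / sqrt(n_blocks) ~= 26, and entries are centered
# around zero, so the sum over n_blocks blocks behaves like a random walk and
# typically stays well within +-128 rather than overflowing.
print(np.abs(table).max())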

    def udistance_table(self, q):
        # Experimental
        """
        Computes the unsigned distance table for the given query vector q.
        Currently experimental.

        Parameters
        ----------
        q : array-like
            The query vector.

        Returns
        -------
        _FastDistanceTable
            The FastDistanceTable object containing the computed unsigned distance table.
        """
        q = pad(q, (2 * self.dims_per_block,))
        dpb = self.dims_per_block
        n_blocks = q.size / dpb
@@ -158,52 +247,38 @@ def top(self, transformed_data, data, k=1, rescore=None, out=None):

    def ctop(self, transformed_data, data, k=1, rescore=None):
        """
        Like top, but uses cython.
        Like top, but uses the query_pq_sse Cython method to directly retrieve the
        bottom-k indices from the transformed_data, rather than estimating all distances
        and computing the bottom_k in numpy. Should generally be faster than top(...).
        """
        true_n, transformed_data = transformed_data
        k = min(k, true_n)
        # In the first pass we collect `rescore` many rows
        if not rescore:
            rescore = min(2 * k + 10, true_n)
        assert true_n >= rescore >= k

        indices = np.zeros((rescore,), dtype=np.int32)
        values = np.zeros((rescore,), dtype=np.int32)
        query_pq_sse(transformed_data, self.tables, indices, values, True)
        good_indices = indices < true_n  # TODO: We remove padding in a kinda dumb way

        # The transformed_data has been padded with 0-rows to a multiple of 16.
        # We remove those "fake positives" here.
        good_indices = indices < true_n
        indices = indices[good_indices]
        if rescore > k:
            diff = data[indices] - self.q[: data.shape[1]]  # Remove padding from q
            dists = np.einsum("ij,ij->i", diff, diff)
            best = bottom_k(dists, k=k)
            return indices[best], dists[best]
        values = values[good_indices]
        return indices, values
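
A usage sketch for the two-pass search (names are illustrative; pq is a fitted FastPQ):

import numpy as np
from fast_pq import FastPQ

X = np.random.randn(1000, 10).astype(np.float32)
q = np.random.randn(10).astype(np.float32)

pq = FastPQ(dims_per_block=2)
packed = pq.fit_transform(X)            # (true_n, codes) tuple
dtable = pq.distance_table(q)
# Pass 1 scans the quantized codes; pass 2 rescores ~2k+10 candidates exactly.
indices, dists = dtable.ctop(packed, X, k=10)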

    def ctops(self, transformed_datas, data, k=1, rescore=None):
        """
        Like ctop, but takes multiple datasets at the same time, which might be
        useful when doing multi-probing IVF.
        """
        true_n, transformed_data = transformed_data
        k = min(k, true_n)
        if not rescore:
            rescore = min(2 * k + 10, true_n)
        assert true_n >= rescore >= k

        # Basically, these two arrays define a heap?
        indices = np.zeros((rescore,), dtype=np.int32)
        values = np.zeros((rescore,), dtype=np.int32)
        for true_n, transformed_data in transformed_datas:
            query_pq_sse(transformed_data, self.tables, indices, values, True)
        # In a second pass we compute the true distances and return the actually
        # closest points. If we got fewer or exactly k outputs, there is no need
        # to compute the true distances.
        if rescore <= k:
            values = values[good_indices]
            return indices, values

        good_indices = indices < true_n  # TODO: We remove padding in a kinda dumb way
        indices = indices[good_indices]
        if rescore > k:
            diff = data[indices] - self.q[: data.shape[1]]  # Remove padding from q
            dists = np.einsum("ij,ij->i", diff, diff)
            best = bottom_k(dists, k=k)
            return indices[best], dists[best]
        values = values[good_indices]
        return indices, values

        # Remove padding from q
        diff = data[indices] - self.q[: data.shape[1]]
        dists = np.einsum("ij,ij->i", diff, diff)
        best = bottom_k(dists, k=k)
        return indices[best], dists[best]
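
The multiprobe counterpart merges candidates from several probed clusters into the same indices/values buffers before rescoring. A hedged sketch of the intended call shape, where cluster_data is a hypothetical list of per-cluster point arrays (not an attribute the diff defines):

transformed_datas = [pq.transform(part) for part in cluster_data]
dtable = pq.distance_table(q)
indices, dists = dtable.ctops(transformed_datas, X, k=10)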


class DummyPQ:
