In [None]:
import pytest
import numpy as np
import itertools
from matplotlib import pyplot as plt
from sklearn.neighbors import NearestNeighbors
from sklearnex import patch_sklearn, unpatch_sklearn

# Large, high-resolution figures for the visual checks further down.
plt.rcParams.update({
    'figure.figsize': [12, 12],
    'figure.dpi': 100,
})

In [None]:
%load_ext cython

In [None]:
%%cython -f --compile-args=-fopenmp --link-args=-fopenmp --annotate
# cython: profile=False
# cython: cdivision=True
# cython: boundscheck=False
# cython: wraparound=False
import numpy as np
cimport numpy as np
from libc.stdlib cimport malloc, calloc, free
from libc.stdio cimport printf
from libc.math cimport sqrt
from cython.parallel cimport prange, parallel
from cython cimport floating, integral

# Compile-time constants (Cython DEF).

# Set to True to compile in the printf diagnostics of _pairwise_argmin.
DEF DEBUG = False

# Upper bound on the number of OpenMP threads used by _pairwise_argmin.
DEF N_THREADS = 4

# Default chunk_size of pairwise_argmin. It counts scalar elements, not
# samples: _pairwise_argmin divides it by the number of features.
# This is the best value found by tuning so far.
DEF CHUNK_SIZE = 4096

# Large sentinel distance for initialising running minima.
DEF INF = 1e19

from sklearn.utils._cython_blas cimport _gemm, BLAS_Order, BLAS_Trans
from sklearn.utils._cython_blas cimport ColMajor, RowMajor, Trans, NoTrans
    
cdef void _argmin_chunk(
    floating[:, ::1] X_c,            # IN
    floating[:, ::1] Y_c,            # IN
    floating[::1] Y_sq_norms,        # IN
    integral[::1] Z_c_idx,           # OUT
    floating[::1] Z_c_r_dist,        # OUT
    floating *pairwise_red_distances,# OUT
    integral idx_offset,
) nogil:
    """Update the running argmin of every row of X_c over the rows of Y_c.

    For each row i of X_c and each row j of Y_c, the reduced squared distance
    ||Y_c^{j}||² - 2 X_c^{i}.Y_c^{j} is compared with Z_c_r_dist[i]. On a
    strict improvement it is stored in Z_c_r_dist[i], and j + idx_offset
    (the index of that row in the full Y) is stored in Z_c_idx[i].

    pairwise_red_distances is caller-provided scratch space. It must hold at
    least X_c.shape[0] * Y_c.shape[0] values and is overwritten.
    """
    cdef:
        Py_ssize_t i, j
        floating r_dist_ij
        floating *row
        # Row-major C = alpha * A . B^T + beta * C with A = X_c (M x K),
        # B = Y_c (N x K) and C = pairwise_red_distances (M x N).
        # Careful, LDA, LDB and LDC are given for F-ordered arrays in the
        # BLAS reference. Here, we use their counterpart values as indicated
        # in the documentation:
        # https://www.netlib.org/lapack/explore-html/db/dc9/group__single__blas__level3_gafe51bacb54592ff5de056acabd83c260.html
        int M = X_c.shape[0]
        int N = Y_c.shape[0]
        int K = X_c.shape[1]
        int LDA = K
        int LDB = K
        int LDC = N
        floating alpha = -2.0
        floating beta = 1.0

    # Nothing to compare (this also avoids forming &X_c[0, 0] on an empty view).
    if M == 0 or N == 0:
        return

    # Instead of computing the full pairwise squared distances matrix,
    # ||X_c - Y_c||² = ||X_c||² - 2 X_c.Y_c^T + ||Y_c||², we only need
    # to store the - 2 X_c.Y_c^T + ||Y_c||² term since the argmin
    # for a given sample X_c^{i} does not depend on ||X_c^{i}||²
    for i in range(M):
        row = pairwise_red_distances + i * N
        for j in range(N):
            row[j] = Y_sq_norms[j]

    # pairwise_red_distances += -2 * X_c.dot(Y_c.T)
    # With zero features the dot product term is 0, and BLAS would reject
    # LDA = LDB = 0, so skip the call.
    if K > 0:
        _gemm(RowMajor, NoTrans, Trans,
              M, N, K,
              alpha,
              &X_c[0, 0], LDA,
              &Y_c[0, 0], LDB, beta,
              pairwise_red_distances, LDC)

    # Computing argmins here. The strict '<' means a later j with an equal
    # distance does not replace the current best.
    for i in range(M):
        row = pairwise_red_distances + i * N
        for j in range(N):
            r_dist_ij = row[j]
            if r_dist_ij < Z_c_r_dist[i]:
                Z_c_r_dist[i] = r_dist_ij
                Z_c_idx[i] = <integral> (j + idx_offset)
                
                
cdef int _pairwise_argmin(
    floating[:, ::1] X,              # IN
    floating[:, ::1] Y,              # IN
    floating[::1] Y_sq_norms,        # IN
    integral[::1] Z_idx,             # OUT
    floating[::1] Z_r_dist,          # OUT
    integral chunk_size,
) nogil except -1:
    """Chunked, OpenMP-parallel argmin of squared Euclidean distances.

    X and Y are split into row chunks of chunk_size // n_features samples
    (at least one). Chunks of X are distributed over threads with prange.
    Each thread loops over all chunks of Y and updates Z_idx / Z_r_dist only
    for the rows of its own X chunk, so no two threads write the same entry.

    Z_r_dist must be initialised by the caller to a value larger than any
    reduced distance: only strict improvements are recorded.

    Returns 0. Empty X or Y leaves the outputs untouched.
    """
    cdef:
        Py_ssize_t n_samples_X = X.shape[0]
        Py_ssize_t n_samples_Y = Y.shape[0]
        Py_ssize_t n_samples_chunk

        Py_ssize_t X_n_samples_chunk, X_n_samples_rem, X_n_chunks
        Py_ssize_t Y_n_samples_chunk, Y_n_samples_rem, Y_n_chunks
        int num_threads

        Py_ssize_t X_start, X_end, Y_start, Y_end
        Py_ssize_t X_chunk_idx, Y_chunk_idx

        floating *pairwise_red_distances

    # Nothing to do. Returning early also keeps the divisions below away
    # from zero divisors; with cdivision=True, a division by zero is
    # undefined behaviour rather than an exception.
    if n_samples_X == 0 or n_samples_Y == 0:
        return 0

    # chunk_size is a budget in scalar elements. Keep at least one sample
    # per chunk (chunk_size < n_features would otherwise give 0).
    n_samples_chunk = max(1, chunk_size // max(1, X.shape[1]))

    X_n_samples_chunk = min(n_samples_X, n_samples_chunk)
    X_n_samples_rem = n_samples_X % X_n_samples_chunk
    # Counting remainder chunk in total number of chunks
    X_n_chunks = n_samples_X // X_n_samples_chunk + (X_n_samples_rem != 0)

    Y_n_samples_chunk = min(n_samples_Y, n_samples_chunk)
    Y_n_samples_rem = n_samples_Y % Y_n_samples_chunk
    Y_n_chunks = n_samples_Y // Y_n_samples_chunk + (Y_n_samples_rem != 0)

    # Work is only split over chunks of X. Threads beyond X_n_chunks would
    # sit idle while still allocating a scratch buffer.
    num_threads = <int> min(X_n_chunks, N_THREADS)

    IF DEBUG:
        printf("Using %ld chunk pairs\nUsing %d threads\n",
               <long> (X_n_chunks * Y_n_chunks), num_threads)

    with nogil, parallel(num_threads=num_threads):
        # Per-thread scratch buffer (assigned inside the parallel block, hence
        # thread-private), sized for the largest chunk pair.
        # NOTE(review): the malloc result is not checked for NULL.
        pairwise_red_distances = <floating*> malloc(
            X_n_samples_chunk * Y_n_samples_chunk * sizeof(floating))

        for X_chunk_idx in prange(X_n_chunks, schedule='static'):
            X_start = X_chunk_idx * X_n_samples_chunk
            if X_chunk_idx == X_n_chunks - 1 and X_n_samples_rem > 0:
                X_end = X_start + X_n_samples_rem
            else:
                X_end = X_start + X_n_samples_chunk

            for Y_chunk_idx in range(Y_n_chunks):
                Y_start = Y_chunk_idx * Y_n_samples_chunk
                if Y_chunk_idx == Y_n_chunks - 1 and Y_n_samples_rem > 0:
                    Y_end = Y_start + Y_n_samples_rem
                else:
                    Y_end = Y_start + Y_n_samples_chunk

                _argmin_chunk(X[X_start:X_end, :],
                              Y[Y_start:Y_end, :],
                              Y_sq_norms[Y_start:Y_end],
                              Z_idx[X_start:X_end],
                              Z_r_dist[X_start:X_end],
                              pairwise_red_distances,
                              <integral> Y_start)

        free(pairwise_red_distances)

    IF DEBUG:
        printf("Done cdef _pairwise_argmin\n\n")

    return 0
        
# Python-accessible interfaces
def pairwise_argmin(
    floating[:, ::1] X,
    floating[:, ::1] Y,
    integral chunk_size = CHUNK_SIZE,
):
    """Return, for each row of X, the index of its nearest row of Y.

    Distances are squared Euclidean distances computed in chunks by
    _pairwise_argmin.

    Parameters
    ----------
    X : C-contiguous float32 or float64 array of shape (n_samples_X, n_features)
    Y : C-contiguous array of shape (n_samples_Y, n_features), same dtype as X
    chunk_size : int, default=CHUNK_SIZE
        Budget in scalar elements per chunk. Each chunk holds
        chunk_size // n_features samples (at least one).

    Returns
    -------
    ndarray of shape (n_samples_X,)
        Row indices into Y.

    Raises
    ------
    ValueError
        If X and Y have different numbers of features, if chunk_size is not
        positive, or if Y is empty while X is not.
    """
    # With boundscheck=False, a feature mismatch would make BLAS read out of
    # bounds, so reject it here.
    if X.shape[1] != Y.shape[1]:
        raise ValueError(
            f"X and Y must have the same number of features, "
            f"got {X.shape[1]} and {Y.shape[1]}.")
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}.")

    # The index dtype must match the C type of this fused specialisation,
    # otherwise the typed memoryview assignment below fails.
    if integral is short:
        int_dtype = np.dtype('h')   # C short
    elif integral is int:
        int_dtype = np.dtype('i')   # C int
    else:
        int_dtype = np.dtype('l')   # C long
    float_dtype = np.float32 if floating is float else np.float64

    n_samples_X = X.shape[0]
    if n_samples_X == 0:
        return np.zeros((0,), dtype=int_dtype)
    if Y.shape[0] == 0:
        raise ValueError("Y must contain at least one sample.")

    cdef:
        integral[::1] Z_idx = np.zeros((n_samples_X,), dtype=int_dtype)
        # +inf (rather than a finite sentinel) guarantees the first computed
        # distance replaces the initial value, whatever the scale of the data.
        floating[::1] Z_r_dist = np.full((n_samples_X,), np.inf, dtype=float_dtype)
        floating[::1] Y_sq_norms = np.einsum('ij,ij->i', Y, Y)

    _pairwise_argmin(X, Y, Y_sq_norms, Z_idx, Z_r_dist, chunk_size)

    return np.asarray(Z_idx)

### Visual screening for 2D

In [None]:
# Small seeded 2D toy problem: 20 training points; the queries are every other
# training point shifted by 0.02 in both coordinates.
rng = np.random.default_rng(0)
X_train = rng.random((20, 2))
X_test = X_train[::2, :] + 0.02  # new C-contiguous array (no copy needed)

In [None]:
# Memory footprint of the toy arrays, in GB.
X_train.nbytes / 1e9, X_test.nbytes /1e9

In [None]:
# (n_samples, n_features) of the query and reference sets.
X_train.shape, X_test.shape

In [None]:
# With 2 features, chunk_size=100 elements means 100 / 2 = 50 samples per
# chunk, which is capped to the number of samples on each side.
argmins = pairwise_argmin(X_test, X_train, chunk_size = 100)

In [None]:
# Index into X_train of the nearest training point for each row of X_test.
argmins

In [None]:
# One index per query row.
argmins.shape

In [None]:
# Visual check: each query point in X_test should be linked to a nearby
# point of X_train.
fig, ax = plt.subplots()
ax.scatter(X_test[:, 0], X_test[:, 1], label="X_test")
ax.scatter(X_train[:, 0], X_train[:, 1], label="X_train")

# Dashed green segment from each query to the training point returned by
# pairwise_argmin (unlabelled, so the legend only lists the two point sets).
for i, j in enumerate(argmins):
    ax.plot((X_test[i, 0], X_train[j, 0]),
            (X_test[i, 1], X_train[j, 1]),
            'g--', linewidth=0.5)

ax.legend()
ax.set(title="Argmin for X_test on X_train",
       xlabel="feature 0", ylabel="feature 1")

plt.show()

### Testing for correctness against scikit-learn

In [None]:
@pytest.mark.parametrize("n", [10 ** i for i in [2, 3, 4]])
@pytest.mark.parametrize("d", [2, 5, 10, 100])
def test_correctness(n, d):
    """Check pairwise_argmin against scikit-learn's brute-force 1-NN search.

    Y has n samples and X has n // 2 samples, both with d features drawn
    uniformly from [0, 1) with a fixed seed. The nearest-neighbour indices
    must match exactly.
    """
    rng = np.random.default_rng(0)
    # Explicit shapes: reshaping a flat array of n * d // 2 values only works
    # when that count is a multiple of d.
    Y = rng.random((n, d))
    X = rng.random((n // 2, d))

    neigh = NearestNeighbors(n_neighbors=1, algorithm='brute').fit(Y)
    argmins_sk = neigh.kneighbors(X, return_distance=False).ravel()

    argmins = pairwise_argmin(X, Y)

    np.testing.assert_array_equal(argmins, argmins_sk)


In [None]:
# Run the correctness check over the same (n, d) grid as the pytest
# parametrisation, outer loop on n and inner loop on d.
for n in [10 ** i for i in [2, 3, 4]]:
    for d in [2, 5, 10, 100]:
        print(n, d)
        test_correctness(n, d)

### Comparison against scikit-learn

In [None]:
# Seeded benchmark data: Y has 2_000 samples with 500 features (1e6 float64
# values); X is every other sample of Y shifted by 0.02 (1_000 samples).
rng = np.random.default_rng(0)
Y = rng.random((2_000, 500))
X = Y[::2, :] + 0.02  # new C-contiguous array (no copy needed)

In [None]:
# Memory footprint of the benchmark arrays, in GB.
X.nbytes / 1e9, Y.nbytes /1e9

In [None]:
# (n_samples, n_features) of the query and reference sets.
X.shape, Y.shape

In [None]:
# Baseline: stock scikit-learn brute-force 1-NN, fitted on Y
# (patch_sklearn() has not been called yet at this point).
neigh = NearestNeighbors(n_neighbors=1, algorithm='brute').fit(Y)

In [None]:
# Baseline timing: index of the nearest row of Y for each row of X.
%timeit neigh.kneighbors(X, return_distance=False)

In [None]:
%timeit pairwise_argmin(X, Y, chunk_size=512)

In [None]:
%timeit pairwise_argmin(X, Y, chunk_size=1024)

In [None]:
%timeit pairwise_argmin(X, Y, chunk_size=2048)

In [None]:
%timeit pairwise_argmin(X, Y, chunk_size=4096)

In [None]:
%timeit pairwise_argmin(X, Y, chunk_size=4096 * 2)

In [None]:
patch_sklearn()
# patch_sklearn() only affects estimators imported *after* it is called. The
# NearestNeighbors name imported at the top of the notebook still refers to
# stock scikit-learn, so import the patched class explicitly here.
from sklearn.neighbors import NearestNeighbors as PatchedNearestNeighbors
neigh = PatchedNearestNeighbors(n_neighbors=1, algorithm='brute').fit(Y)

In [None]:
# Intended to time the estimator fitted after patch_sklearn() in the previous
# cell; compare with the stock scikit-learn baseline timed earlier.
%timeit neigh.kneighbors(X, return_distance=False)

In [None]:
# Revert the patching enabled by patch_sklearn() above.
unpatch_sklearn()