In [None]:
import pytest
import numpy as np
import itertools
from matplotlib import pyplot as plt
from sklearn.neighbors import NearestNeighbors
from sklearnex import patch_sklearn, unpatch_sklearn

# Default figure geometry for every plot in this notebook.
plt.rcParams.update({
    'figure.figsize': [12, 12],
    'figure.dpi': 100,
})

In [None]:
# Register the Cython IPython extension, which provides the %%cython cell magic.
%load_ext cython

In [None]:
%%cython -f --compile-args=-fopenmp --link-args=-fopenmp --annotate
# cython: profile=False
# cython: cdivision=True
# cython: boundscheck=False
# cython: wraparound=False
import numpy as np
cimport numpy as np
from libc.stdlib cimport malloc, calloc, free
from libc.string cimport memset
from libc.stdio cimport printf 
from cython.parallel cimport prange, parallel
from cython cimport floating, integral

# Compile-time constants (Cython DEF).
#
# CHUNK_SIZE: default `chunk_size` of `pairwise_argmin`; it is divided by the
# number of features to obtain the number of X rows handled per chunk.
# As of now the best size found by tuning
DEF CHUNK_SIZE = 4096
# Upper bound on the number of threads used by the parallel section.
DEF N_THREADS = 4
# When True, the IF DEBUG blocks (printf traces) are compiled in.
DEF DEBUG = False
# Large sentinel used when resetting the per-chunk heaps of reduced distances.
DEF INF = 1e19

from sklearn.utils._cython_blas cimport _gemm, BLAS_Order, BLAS_Trans
from sklearn.utils._cython_blas cimport ColMajor, RowMajor, Trans, NoTrans

cdef int _push(
    floating* dist,
    integral* idx,
    integral size,
    floating val,
    integral i_val,
) nogil except -1:
    """Offer (val, i_val) to the max-heap stored in (dist, idx).

    `dist` holds `size` values arranged as a binary max-heap (largest at
    position 0) and `idx` holds the payload attached to each value.  If `val`
    is larger than the current maximum nothing happens; otherwise the maximum
    is evicted and (val, i_val) is sifted down to its place.

    Always returns 0.
    """
    cdef:
        integral hole, left, right, child

    # A candidate above the current maximum does not belong in the heap.
    if val > dist[0]:
        return 0

    # The root is evicted; the candidate provisionally takes its slot.
    dist[0] = val
    idx[0] = i_val

    # Sift down: `hole` is the slot the candidate currently occupies.
    hole = 0
    left = 1
    while left < size:
        right = left + 1

        # Pick the child holding the larger value; the left child wins ties
        # and is the only option when the right child does not exist.
        if right < size and not (dist[left] >= dist[right]):
            child = right
        else:
            child = left

        # Stop as soon as the candidate is not smaller than that child.
        if not (val < dist[child]):
            break

        # Otherwise move the child up and continue from its former slot.
        dist[hole] = dist[child]
        idx[hole] = idx[child]
        hole = child
        left = 2 * hole + 1

    dist[hole] = val
    idx[hole] = i_val

    return 0

cdef inline void dual_swap(
    floating* dist,
    integral* idx,
    integral i1,
    integral i2
) nogil:
    """Exchange the entries at positions i1 and i2 in both dist and idx."""
    cdef:
        floating saved_dist
        integral saved_idx

    # Swap the payloads ...
    saved_idx = idx[i1]
    idx[i1] = idx[i2]
    idx[i2] = saved_idx

    # ... and the matching values, keeping both arrays aligned.
    saved_dist = dist[i1]
    dist[i1] = dist[i2]
    dist[i2] = saved_dist
    
cdef int _simultaneous_sort(
    floating* dist,
    integral* idx,
    integral size
) nogil except -1:
    """
    Perform a recursive quicksort on the dist array, simultaneously
    performing the same swaps on the idx array.

    The first `size` entries of `dist` are sorted in place in increasing
    order; every swap is mirrored on `idx` through `dual_swap`, so each index
    stays attached to its value.  Always returns 0.
    """
    cdef:
        integral pivot_idx, i, store_idx
        floating pivot_val

    # in the small-array case, do things efficiently
    if size <= 1:
        pass
    elif size == 2:
        if dist[0] > dist[1]:
            dual_swap(dist, idx, 0, 1)
    elif size == 3:
        if dist[0] > dist[1]:
            dual_swap(dist, idx, 0, 1)
        if dist[1] > dist[2]:
            dual_swap(dist, idx, 1, 2)
            if dist[0] > dist[1]:
                dual_swap(dist, idx, 0, 1)
    else:
        # Determine the pivot using the median-of-three rule.
        # The smallest of the three is moved to the beginning of the array,
        # the middle (the pivot value) is moved to the end, and the largest
        # is moved to the pivot index.
        #
        # Floor division keeps the midpoint an integer whatever the
        # language_level: with Python-3 semantics `/` is a true division.
        pivot_idx = size // 2
        if dist[0] > dist[size - 1]:
            dual_swap(dist, idx, 0, size - 1)
        if dist[size - 1] > dist[pivot_idx]:
            dual_swap(dist, idx, size - 1, pivot_idx)
            if dist[0] > dist[size - 1]:
                dual_swap(dist, idx, 0, size - 1)
        pivot_val = dist[size - 1]

        # partition indices about pivot.  At the end of this operation,
        # pivot_idx will contain the pivot value, everything to the left
        # will be smaller, and everything to the right will be larger.
        store_idx = 0
        for i in range(size - 1):
            if dist[i] < pivot_val:
                dual_swap(dist, idx, i, store_idx)
                store_idx += 1
        dual_swap(dist, idx, store_idx, size - 1)
        pivot_idx = store_idx

        # recursively sort each side of the pivot; sides with fewer than
        # two elements are already sorted and are skipped
        if pivot_idx > 1:
            _simultaneous_sort(dist, idx, pivot_idx)
        if pivot_idx + 2 < size:
            _simultaneous_sort(dist + pivot_idx + 1,
                               idx + pivot_idx + 1,
                               size - pivot_idx - 1)
    return 0

### 
    
cdef void _k_closest_chunk(
    floating[:, ::1] X_c,            # IN
    floating[:, ::1] Y,              # IN
    floating[::1] Y_sq_norms,        # IN
    floating *dist_middle_terms,     # IN
    floating *heap_red_distances,    # IN/OUT
    integral *heap_indices,          # IN/OUT
    integral k,                      # IN
) nogil:
    # Feed every (row of X_c, row of Y) pair to the per-row heaps.
    #
    # Row `r` of X_c owns the k-sized heap starting at
    # `heap_red_distances + r * k` (values) and `heap_indices + r * k`
    # (row numbers of Y).  `dist_middle_terms` is scratch space that receives
    # the X_c.shape[0] x Y.shape[0] product computed below.
    cdef:
        integral row, col
        floating *row_middle_terms
        floating *row_heap_values
        integral *row_heap_indices

    # The full squared distance is
    #   ||X_c - Y||² = ||X_c||² - 2 X_c.Y^T + ||Y||²
    # but ranking the rows of Y for a fixed sample X_c^{i} does not depend on
    # ||X_c^{i}||², so only the "reduced distance" - 2 X_c.Y^T + ||Y||² is
    # evaluated.

    # Careful: LDA, LDB and LDC are documented for F-ordered arrays; the
    # values passed here are their counterparts for row-major storage.
    # See the documentation of parameters here:
    # https://www.netlib.org/lapack/explore-html/db/dc9/group__single__blas__level3_gafe51bacb54592ff5de056acabd83c260.html
    #
    # dist_middle_terms = -2 * X_c.dot(Y.T)
    _gemm(
        RowMajor, NoTrans, Trans,
        X_c.shape[0], Y.shape[0], X_c.shape[1],
        -2.0,
        &X_c[0, 0], X_c.shape[1],
        &Y[0, 0], X_c.shape[1],
        0.0,
        dist_middle_terms, Y.shape[0],
    )

    # Push each reduced distance into the heap of the row it belongs to.
    for row in range(X_c.shape[0]):
        row_middle_terms = dist_middle_terms + row * Y.shape[0]
        row_heap_values = heap_red_distances + row * k
        row_heap_indices = heap_indices + row * k

        for col in range(Y.shape[0]):
            # reduced distance: - 2 X_c_i.Y_j^T + ||Y_j||²
            _push(row_heap_values,
                  row_heap_indices,
                  k,
                  row_middle_terms[col] + Y_sq_norms[col],
                  col)
            
                
                
cdef int _pairwise_k_closest(
    floating[:, ::1] X,              # IN
    floating[:, ::1] Y,              # IN
    floating[::1] Y_sq_norms,        # IN
    integral chunk_size,
    integral[:, ::1] Z_idx,          # OUT
) nogil except -1:
    """Fill Z_idx[i] with the rows of Y selected for X[i], closest first.

    X is cut into chunks of `chunk_size // X.shape[1]` rows (at least one row
    per chunk).  Chunks are distributed over at most N_THREADS threads; each
    thread owns two scratch buffers and, for every chunk, runs
    `_k_closest_chunk` followed by a per-row `_simultaneous_sort`.

    The number of neighbours is `Z_idx.shape[1]`.  Z_idx is left untouched
    when there is nothing to compute (no row in X or Y, no feature, or no
    neighbour requested).

    Returns 0.
    """
    cdef:
        integral k = Z_idx.shape[1]

        integral n_samples_chunk
        integral X_n_samples_chunk
        integral X_n_full_chunks
        integral X_n_samples_rem
        integral X_n_chunks

        integral n_chunks
        integral num_threads

        integral idx = 0
        integral j = 0

        integral X_start, X_end
        integral X_chunk_idx
        floating *dist_middle_terms
        floating *heap_red_distances

    # Degenerate shapes would otherwise lead to divisions by zero below or to
    # reads from zero-sized heaps in _push.
    if X.shape[0] == 0 or X.shape[1] == 0 or Y.shape[0] == 0 or k <= 0:
        return 0

    # Floor division: `/` is a true division under Python-3 semantics.
    # Always process at least one row per chunk, even when chunk_size is
    # smaller than the number of features.
    n_samples_chunk = chunk_size // X.shape[1]
    if n_samples_chunk < 1:
        n_samples_chunk = 1

    X_n_samples_chunk = min(X.shape[0], n_samples_chunk)
    X_n_full_chunks = X.shape[0] // X_n_samples_chunk
    X_n_samples_rem = X.shape[0] % X_n_samples_chunk

    # Counting remainder chunk in total number of chunks
    X_n_chunks = X_n_full_chunks + (X.shape[0] != (X_n_full_chunks * X_n_samples_chunk))

    n_chunks = X_n_chunks
    num_threads = min(n_chunks, N_THREADS)

    IF DEBUG:
        printf("Using %ld chunk pairs\nUsing %ld threads\n", n_chunks, num_threads)

    with nogil, parallel(num_threads=num_threads):
        # Thread local buffers
        # NOTE(review): the malloc results are not checked for NULL.

        # Temporary buffer for the -2 * X_c.dot(Y.T) term
        dist_middle_terms = <floating*> malloc(Y.shape[0] * X_n_samples_chunk * sizeof(floating))
        # One k-sized heap of reduced distances per row of the chunk
        heap_red_distances = <floating*> malloc(X_n_samples_chunk * k * sizeof(floating))

        for X_chunk_idx in prange(X_n_chunks, schedule='static'):
            # We reset the heap between X chunks.  memset only replicates a
            # single byte and cannot encode a floating point value, so each
            # slot is assigned explicitly.
            for j in range(X_n_samples_chunk * k):
                heap_red_distances[j] = INF

            X_start = X_chunk_idx * X_n_samples_chunk
            if X_chunk_idx == X_n_chunks - 1 and X_n_samples_rem > 0:
                # The last chunk only holds the remaining rows
                X_end = X_start + X_n_samples_rem
            else:
                X_end = X_start + X_n_samples_chunk

            _k_closest_chunk(
                X[X_start:X_end, :],
                Y,
                Y_sq_norms,
                dist_middle_terms,
                heap_red_distances,
                &Z_idx[X_start, 0],
                k
            )

            # Getting the indices of the k-closest points in
            # the sorted order
            for idx in range(X_end - X_start):
                _simultaneous_sort(
                    heap_red_distances + idx * k,
                    &Z_idx[X_start + idx, 0],
                    k
                )

        # end: for X_chunk_idx
        free(dist_middle_terms)
        free(heap_red_distances)

    # end: with nogil, parallel
    IF DEBUG:
        printf("Done cdef _pairwise_k_closest \n\n")

    return 0
        
# Python-accessible interfaces
def pairwise_argmin(
    floating[:, ::1] X,
    floating[:, ::1] Y,
    integral k,
    integral chunk_size = CHUNK_SIZE,
):
    """Return, for each row of X, the indices of k rows of Y, closest first.

    Parameters
    ----------
    X : C-contiguous 2D array of shape (n_samples_X, n_features)
        Query points.
    Y : C-contiguous 2D array of shape (n_samples_Y, n_features)
        Candidate points; must have the same dtype as X.
    k : int
        Number of indices returned per row of X, 1 <= k <= n_samples_Y.
    chunk_size : int, default=CHUNK_SIZE
        Budget used to cut X into chunks; values smaller than n_features are
        raised to n_features so that a chunk holds at least one row.

    Returns
    -------
    ndarray of shape (n_samples_X, k)
        Row indices into Y, ordered by increasing distance.

    Raises
    ------
    ValueError
        If X and Y disagree on the number of features, have no feature, or
        if k or chunk_size is out of range.
    """
    # The computation below works on raw pointers with bounds checking
    # disabled, so inconsistent inputs must be rejected up front.
    if X.shape[1] != Y.shape[1]:
        raise ValueError(
            "X and Y must have the same number of features, got "
            f"{X.shape[1]} and {Y.shape[1]}"
        )
    if X.shape[1] == 0:
        raise ValueError("X and Y must have at least one feature")
    if k < 1 or k > Y.shape[0]:
        raise ValueError(
            f"k must be between 1 and Y.shape[0]={Y.shape[0]}, got {k}"
        )
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    # NumPy dtype matching the C type the `integral` fused type resolved to.
    # C long is not 64-bit on every platform, hence the lookup by type code.
    if integral is short:
        int_dtype = np.short
    elif integral is int:
        int_dtype = np.intc
    else:
        int_dtype = np.dtype('l')

    cdef:
        integral[:, ::1] Z_idx = np.full((X.shape[0], k), -1, dtype=int_dtype)
        floating[::1] Y_sq_norms = np.einsum('ij,ij->i', Y, Y)

    # Nothing to search for: return the empty (0, k) result directly.
    if X.shape[0] == 0:
        return np.asarray(Z_idx)

    # Guarantee at least one row of X per chunk.
    if chunk_size < X.shape[1]:
        chunk_size = <integral> X.shape[1]

    _pairwise_k_closest(X, Y, Y_sq_norms, chunk_size, Z_idx)

    return np.asarray(Z_idx)

In [None]:
# Small 2-D toy problem: 20 random training points, and every other one of
# them shifted by 0.02 along both axes as test points.
# A seeded generator keeps the figure below reproducible across runs.
rng = np.random.default_rng(0)
X_train = rng.random((20, 2))
X_test = np.copy(X_train[::2, :]) + 0.02

In [None]:
# Size of both arrays, in units of 1e9 bytes.
X_train.nbytes / 1e9, X_test.nbytes /1e9

In [None]:
# Single neighbour of each test point among the training points.
argmins = pairwise_argmin(
    X_test,
    X_train,
    k=1,
    chunk_size=100,
)

In [None]:
# Display the indices returned by pairwise_argmin.
argmins

In [None]:
# Visual sanity check: link every test point to the training point that
# pairwise_argmin returned for it.
fig, ax = plt.subplots()
ax.scatter(X_test[:, 0], X_test[:, 1])
ax.scatter(X_train[:, 0], X_train[:, 1])
ax.legend(["X_test", "X_train"])

# argmins has one column per neighbour (here k=1).  Iterating over that
# column yields scalar indices; iterating over the rows would yield length-1
# arrays and build ragged (scalar, array) coordinate pairs.
for i, j in enumerate(argmins[:, 0]):
    ax.plot((X_test[i, 0], X_train[j, 0]),
            (X_test[i, 1], X_train[j, 1]),
            'g--', linewidth=0.5)

ax.set_title("Argmin for X_test on X_train")

plt.show()

### Testing for correctness against scikit-learn

In [None]:
@pytest.mark.parametrize("n", [10 ** i for i in [2, 3, 4]])
@pytest.mark.parametrize("d", [2, 5, 10, 100])
def test_correctness(n, d):
    """Check pairwise_argmin (k=1) against scikit-learn's brute-force search.

    Y holds n random points with d features and X holds n // 2 of them; the
    neighbour index found for every row of X must match the one returned by
    NearestNeighbors.kneighbors.
    """
    # Fixed seed so that a failure can be reproduced.
    np.random.seed(1)
    Y = np.random.rand(int(n * d)).reshape((-1, d))
    X = np.random.rand(int(n * d // 2)).reshape((-1, d))

    neigh = NearestNeighbors(n_neighbors=1, algorithm='brute')
    neigh.fit(Y)

    argmins_sk = neigh.kneighbors(X, return_distance=False)
    print("Done with neigh.kneighbors")
    argmins = pairwise_argmin(X, Y, k=1)
    print("Done with pairwise_argmin")

    # Both results have shape (n_samples_X, 1); compare them as flat vectors.
    np.testing.assert_array_equal(argmins.ravel(), argmins_sk.ravel())


In [None]:
# Run the correctness check by hand over the same grid as the pytest marks.
sample_counts = [10 ** exponent for exponent in (2, 3, 4)]
feature_counts = [2, 5, 10, 100]

for n, d in itertools.product(sample_counts, feature_counts):
    print(n, d)
    test_correctness(n, d)

### Comparison against scikit-learn

In [None]:
# Benchmark data: 1e7 random values arranged as rows of 500 features, every
# other row shifted by 0.02 as queries, and 100 neighbours per query.
Y = np.random.rand(20_000, 500)
X = Y[::2, :].copy() + 0.02
k = 100

In [None]:
# Size of the benchmark arrays, in units of 1e9 bytes.
X.nbytes / 1e9, Y.nbytes /1e9

In [None]:
# (n_samples, n_features) of the queries and of the candidates.
X.shape, Y.shape

In [None]:
# Reference: scikit-learn brute-force k-nearest neighbours fitted on Y.
neigh = NearestNeighbors(n_neighbors=k, algorithm='brute').fit(Y)

In [None]:
# Baseline timing: scikit-learn query of the k neighbours of X (indices only).
%timeit neigh.kneighbors(X, return_distance=False)

In [None]:
# Timing of pairwise_argmin with chunk_size = 2048.
%timeit pairwise_argmin(X, Y, k=k, chunk_size=4096 // 2)

In [None]:
# Timing of pairwise_argmin with chunk_size = 4096.
%timeit pairwise_argmin(X, Y, k=k, chunk_size=4096)

In [None]:
# Timing of pairwise_argmin with chunk_size = 8192.
%timeit pairwise_argmin(X, Y, k=k, chunk_size=4096 * 2)

In [None]:
# Timing of pairwise_argmin with chunk_size = 12288.
%timeit pairwise_argmin(X, Y, k=k, chunk_size=4096 * 3)

In [None]:
# Timing of pairwise_argmin with chunk_size = 16384.
%timeit pairwise_argmin(X, Y, k=k, chunk_size=4096 * 4)

In [None]:
# Apply the sklearnex patch, then rebuild and refit the estimator after patching.
patch_sklearn()
neigh = NearestNeighbors(n_neighbors=k, algorithm='brute').fit(Y)

In [None]:
# Same query as the baseline timing, on the estimator created after patch_sklearn().
%timeit neigh.kneighbors(X, return_distance=False)

In [None]:
# Undo patch_sklearn() once the comparison is done.
unpatch_sklearn()