In [2]:
import pytest
import numpy as np
import itertools
from matplotlib import pyplot as plt
from sklearn.neighbors import NearestNeighbors
from sklearnex import patch_sklearn

plt.rcParams['figure.figsize'] = [12, 12]
plt.rcParams['figure.dpi'] = 100

In [3]:
%load_ext cython

In [19]:
%%cython -f --compile-args=-fopenmp --link-args=-fopenmp --annotate
# cython: profile=True, cdivision=True, boundscheck=False, wraparound=False
import numpy as np
cimport numpy as np
from libc.stdlib cimport malloc, calloc, free
from libc.stdio cimport printf
from libc.math cimport sqrt
from cython.parallel cimport prange, parallel
from cython cimport floating, integral

DEF CHUNK_SIZE = 256
DEF N_THREADS = 4
DEF DEBUG = False
DEF INF = 1e19

from sklearn.utils._cython_blas cimport _gemm, BLAS_Order, BLAS_Trans
from sklearn.utils._cython_blas cimport ColMajor, RowMajor, Trans, NoTrans

cdef inline floating r_euclidean_dist(
    floating* x1,
    floating* x2,
    integral size,
) nogil:
    cdef floating d=0
    cdef np.intp_t j
    for j in range(size):
        d += (x1[j] - x2[j]) * (x1[j] - x2[j])
    return d

    
cdef inline integral _argmin_chunk(
    floating[:, ::1] X_c,        # IN
    floating[:, ::1] Y_c,        # IN
    floating[::1] Y_sq_norms,    # IN
    integral[::1] Z_c_idx,       # OUT  
    floating[::1] Z_c_r_dist,    # OUT
    floating *pairwise_distances,# OUT
    integral idx_offset,
) nogil:
    cdef:
        integral i, j
        floating r_dist_j
        
    # Instead of computing the full pairwise squared distances matrix,
    # ||X_c - Y_c||² = ||X_c||² - 2 X_c.Y_c^T + ||Y_c||², we only need
    # to store the - 2 X_c.Y_c^T + ||Y_c||² term since the argmin
    # for a given sample X_c^{i} does not depend on ||X_c^{i}||²
    for i in range(X_c.shape[0]):
        for j in range(Y_c.shape[0]):
            pairwise_distances[i * Y_c.shape[0] + j] = Y_sq_norms[j]

    # pairwise_distances += -2 * X_c.dot(Y_c.T)
    _gemm(RowMajor, NoTrans, Trans, X_c.shape[0], Y_c.shape[0],
          X_c.shape[1], -2.0, &X_c[0, 0], X_c.shape[1],
          &Y_c[0, 0], X_c.shape[1], 1.0, pairwise_distances,
          Y_c.shape[1])
    
    # Computing argmins here
    for i in range(X_c.shape[0]):
        for j in range(Y_c.shape[0]):
            if pairwise_distances[i * Y_c.shape[0] + j] < Z_c_r_dist[i]:
                Z_c_r_dist[i] = pairwise_distances[i * Y_c.shape[0] + j] 
                Z_c_idx[i] = j + idx_offset
                
                
cdef integral _pairwise_argmin(
    floating[:, ::1] X,              # IN
    floating[:, ::1] Y,              # IN
    floating[::1] Y_sq_norms,        # IN
    integral[::1] Z_idx,             # OUT
    floating[::1] Z_r_dist,          # OUT
    integral chunk_size,

) nogil:
    cdef:
        integral n_samples_chunk = chunk_size / X.shape[1]
        
        integral X_n_samples_chunk = min(X.shape[0], n_samples_chunk)
        integral X_n_full_chunks = X.shape[0] // X_n_samples_chunk
        integral X_n_samples_rem = X.shape[0] % X_n_samples_chunk

        integral Y_n_samples_chunk = min(Y.shape[0], n_samples_chunk)
        integral Y_n_full_chunks = Y.shape[0] / Y_n_samples_chunk
        integral Y_n_samples_rem = Y.shape[0] % Y_n_samples_chunk
        
        # Counting remainder chunk in total number of chunks
        integral X_n_chunks = X_n_full_chunks + (X.shape[0] != (X_n_full_chunks * X_n_samples_chunk))
        integral Y_n_chunks = Y_n_full_chunks + (Y.shape[0] != (Y_n_full_chunks * Y_n_samples_chunk))
        
        integral n_chunks = X_n_chunks * Y_n_chunks
        integral num_threads = min(n_chunks, N_THREADS)

        integral idx
        integral i_chunk = 0
        
        integral X_start, X_end, Y_start, Y_end
        integral X_chunk_idx, Y_chunk_idx
                
        floating *argmins_r_dist
    
    with nogil, parallel(num_threads=num_threads):
        argmins_r_dist = <floating*> malloc(Y_n_samples_chunk * X_n_samples_chunk * sizeof(int))
    
        for X_chunk_idx in prange(X_n_chunks, schedule='static'):
            X_start = X_chunk_idx * X_n_samples_chunk
            if X_chunk_idx == X_n_chunks - 1 and X_n_samples_rem > 0:
                X_end = X_start + X_n_samples_rem
            else:
                X_end = X_start + X_n_samples_chunk

            for Y_chunk_idx in range(Y_n_chunks):
                Y_start = Y_chunk_idx * Y_n_samples_chunk
                if Y_chunk_idx == Y_n_chunks - 1 and Y_n_samples_rem > 0:
                    Y_end = Y_start + Y_n_samples_rem
                else:
                    Y_end = Y_start + Y_n_samples_chunk

                _argmin_chunk(X[X_start:X_end, :],
                              Y[Y_start:Y_end, :],
                              Y_sq_norms[Y_start:Y_end],
                              Z_idx[X_start:X_end],
                              Z_r_dist[X_start:X_end],
                              argmins_r_dist,
                              Y_start)
                
        free(argmins_r_dist)
        
# Python-accessible interfaces
        
cpdef np.ndarray[int, ndim=2, mode='c'] pairwise_argmin(
    np.ndarray[floating, ndim=2, mode='c'] X,
    np.ndarray[floating, ndim=2, mode='c'] Y,
    int chunk_size = CHUNK_SIZE,
):
    cdef:
        int[::1] Z_idx = np.zeros((X.shape[0],), dtype=int)
        floating[::1] Z_r_dist = np.full((X.shape[0],), INF, dtype=float)
        floating[::1] Y_sq_norms = np.einsum('ij,ij->i', Y, Y)
    
    _pairwise_argmin(X, Y, Y_sq_norms, Z_idx, Z_r_dist, chunk_size)
    
    return np.asarray(Z_idx)


Error compiling Cython file:
------------------------------------------------------------
...
    cdef:
        int[::1] Z_idx = np.zeros((X.shape[0],), dtype=int)
        floating[::1] Z_r_dist = np.full((X.shape[0],), INF, dtype=float)
        floating[::1] Y_sq_norms = np.einsum('ij,ij->i', Y, Y)
    
    _pairwise_argmin(X, Y, Y_sq_norms, Z_idx, Z_r_dist, chunk_size)
                   ^
------------------------------------------------------------

/home/jsquared/.cache/ipython/cython/_cython_magic_e1fcb0783bb87b85180a79447b5590bb.pyx:139:20: no suitable method found

Error compiling Cython file:
------------------------------------------------------------
...
    cdef:
        int[::1] Z_idx = np.zeros((X.shape[0],), dtype=int)
        floating[::1] Z_r_dist = np.full((X.shape[0],), INF, dtype=float)
        floating[::1] Y_sq_norms = np.einsum('ij,ij->i', Y, Y)
    
    _pairwise_argmin(X, Y, Y_sq_norms, Z_idx, Z_r_dist, chunk_size)
                   ^
-------------------------

### Visual screening for 2D

In [None]:
Y = np.random.rand(int(200 * 2)).reshape((-1, 2))
X = np.copy(Y[::2, :]) + 0.02

In [None]:
X.nbytes / 1e9, Y.nbytes /1e9

In [None]:
X.shape, Y.shape

In [None]:
argmins = pairwise_argmin(X,Y, chunk_size = 10)

In [None]:
argmins

In [None]:
argmins.shape

In [None]:
_ = plt.scatter(X[:,0], X[:,1])
_ = plt.scatter(Y[:,0], Y[:,1])
_ = plt.legend(["X", "Y"])

for i, j in enumerate(argmins):
     _ = plt.plot((X[i, 0], Y[j, 0]),
             (X[i, 1], Y[j, 1]),
             'g--', linewidth=0.5)
    
_ = plt.title("Argmin for X on Y")

plt.show()

### Testing for correctness against scikit-learn

In [None]:
@pytest.mark.parametrise("n", [10 ** i for i in [2, 3, 4]])
@pytest.mark.parametrise("d", [2, 5, 10, 100])
def test_correctness(n, d):
    Y = np.random.rand(int(n * d)).reshape((-1, d))
    X = np.random.rand(int(n * d // 2)).reshape((-1, d))
        
    neigh = NearestNeighbors(n_neighbors=1, algorithm='brute')
    neigh.fit(Y)
    
    argmins_sk = neigh.kneighbors(X, return_distance=False)
    print("Done with neigh.kneighbors")
    argmins = pairwise_argmin(X, Y)
    print("Done with pairwise_argmin")

    np.testing.assert_array_equal(argmins, np.ndarray.flatten(argmins_sk))


In [None]:
for n, d in itertools.product([10 ** i for i in [2, 3, 4]], [2, 5, 10, 100]):
    print(n, d)
    test_correctness(n, d)

### Comparison against scikit-learn

In [None]:
Y = np.random.rand(int(1e5)).reshape((-1, 100))
X = np.copy(Y[::2, :]) + 0.02

In [None]:
X.nbytes / 1e9, Y.nbytes /1e9

In [None]:
X.shape, Y.shape

In [None]:
neigh = NearestNeighbors(n_neighbors=1, algorithm='brute').fit(Y)

In [None]:
%timeit neigh.kneighbors(X, return_distance=False)

In [None]:
%timeit pairwise_argmin(X, Y, chunk_size=1024)

In [None]:
patch_sklearn()

In [None]:
%timeit neigh.kneighbors(X, return_distance=False)

### Experiments with chunks

In [None]:
X_chunk = np.random.rand(50).reshape((-1, 2))
Y_chunk = np.copy(X_chunk[::2]) + 0.02
X_chunk.shape, Y_chunk.shape

In [None]:
argmins

In [None]:
argmins = pairwise_distance_chunk(X_chunk, Y_chunk)

plt.scatter(X_chunk[:,0], X_chunk[:,1])
plt.scatter(Y_chunk[:,0], Y_chunk[:,1])
plt.legend(["X_c","Y_c"])

for i, j in enumerate(argmins):
    plt.plot((X_chunk[i, 0], Y_chunk[j, 0]),
             (X_chunk[i, 1], Y_chunk[j, 1]),
             'g--', linewidth=1);
    
plt.title("Argmin for X_c on Y_c")

plt.show();

In [None]:
def create_dummy(dtype, working_memory, d=2):
    f_size = {
        np.float64: 8,
        np.float32: 4,
    }[dtype]
    
    i_size = 8
        
    n = working_memory // (2 * d * f_size + i_size)
    
    X = np.random.randn(n, d).astype(dtype)
    Y = X + 0.05
    return X, Y

In [None]:
@pytest.mark.parametrise("dtype", [np.float64])
@pytest.mark.parametrise("working_memory", [2 ** i for i in range(5, 12)])
@pytest.mark.parametrise("d", [2, 5, 10, 100])
def test_pairwise_chunk_on_dummy(dtype, working_memory, d):
    
    X_chunk, Y_chunk = create_dummy(dtype, working_memory=working_memory, d=d)
    
    Z_chunk = pairwise_distance_chunk(X_chunk, Y_chunk)
    
    memory = (X_chunk.nbytes + Y_chunk.nbytes + Z_chunk.nbytes)
    assert memory <= working_memory, (memory, working_memory)

In [None]:
for dtype, working_memory, d in itertools.product([np.float64],
                                              [2 ** i for i in range(5, 12)],
                                              [2, 5, 10, 100]):
    test_pairwise_chunk_on_dummy(dtype, working_memory, d)

In [None]:
create_dummy(np.float64, 1024, d=50)[0].shape