In [1]:
import numpy as np
import numba
import time
import scipy.spatial.distance

In [2]:

def pdist_numpy_naive(As: np.ndarray, Bs: np.ndarray) -> np.ndarray:
    """Pairwise Euclidean distances using a single broadcast subtraction.

    Args:
        As: 2-D array of shape (n, k); one point per row.
        Bs: 2-D array of shape (m, k); one point per row.

    Returns:
        Array of shape (n, m) whose entry [i, j] is the distance between
        As[i] and Bs[j].

    Raises:
        AssertionError: if the two inputs differ in their number of columns.
    """
    _, cols_a = As.shape
    _, cols_b = Bs.shape
    assert cols_a == cols_b

    # Broadcasting materialises the full (n, m, k) difference tensor at once.
    diff = As[:, np.newaxis, :] - Bs[np.newaxis, :, :]
    return np.linalg.norm(diff, axis=-1)


def pdist_numpy_hybrid(As: np.ndarray, Bs: np.ndarray) -> np.ndarray:
    """Pairwise Euclidean distances, one row of ``As`` at a time.

    A Python loop walks the rows of ``As`` while NumPy vectorises the
    distance from that row to every row of ``Bs``.

    Args:
        As: 2-D array of shape (n, k); one point per row.
        Bs: 2-D array of shape (m, k); one point per row.

    Returns:
        Array of shape (n, m) whose entry [i, j] is the distance between
        As[i] and Bs[j].

    Raises:
        AssertionError: if the two inputs differ in their number of columns.
    """
    n_rows, cols_a = As.shape
    m_rows, cols_b = Bs.shape
    assert cols_a == cols_b

    distances = np.empty((n_rows, m_rows))
    for idx in range(n_rows):
        # A (k,) row broadcasts against the (m, k) matrix.
        distances[idx] = np.linalg.norm(Bs - As[idx], axis=1)
    return distances


@numba.njit
def pdist_numba(As: np.ndarray, Bs: np.ndarray) -> np.ndarray:
    """Pairwise Euclidean distances with explicit loops, compiled by ``numba.njit``.

    Args:
        As: 2-D array of shape (n, k); one point per row.
        Bs: 2-D array of shape (m, k); one point per row.

    Returns:
        float64 array of shape (n, m) whose entry [i, j] is the norm of
        ``As[i] - Bs[j]``.
    """
    (n, k) = As.shape
    (m, k2) = Bs.shape
    # Both inputs must have the same number of columns.
    assert k == k2

    res = np.empty((n, m), dtype=np.float64)
    # One norm of a length-k difference vector per (i, j) pair.
    for i in range(n):
        for j in range(m):
            res[i, j] = np.linalg.norm(As[i] - Bs[j])
    return res

In [3]:
# Problem size: n points in `a`, m points in `b`, k coordinates per point.
n, m, k = 2001, 1001, 300
iterations = 10

# Seeded generator so that every "Restart & Run All" benchmarks the same data.
rng = np.random.default_rng(0)
a = rng.random((n, k))
b = rng.random((m, k))


def _benchmark(label, fn, *args, repeats=iterations):
    """Call ``fn(*args)`` ``repeats`` times, print the mean time, return the last result.

    ``time.perf_counter`` is used because it is monotonic and has the highest
    available resolution, unlike the wall clock returned by ``time.time``.
    """
    start = time.perf_counter()
    for _ in range(repeats):
        result = fn(*args)
    elapsed = time.perf_counter() - start
    print(f'{label} {elapsed/repeats:.2f} s')
    return result


r0 = _benchmark('naive numpy', pdist_numpy_naive, a, b)
r1 = _benchmark('hybrid numpy', pdist_numpy_hybrid, a, b)

pdist_numba(a, b)  # run once to ensure compilation is not part of the timing
r2 = _benchmark('numba', pdist_numba, a, b)

r4 = _benchmark('scipy', scipy.spatial.distance.cdist, a, b, 'euclidean')

# The timings are only comparable if every implementation computes the same thing.
for label, result in (('naive numpy', r0), ('hybrid numpy', r1), ('numba', r2)):
    np.testing.assert_allclose(result, r4, err_msg=f'{label} disagrees with scipy')

naive numpy 4.18 s
hybrid numpy 3.91 s
numba 1.10 s
scipy 0.69 s


# cython

In [4]:

# Cython source of the distance kernel. It is written to pdist_cython.pyx so
# that the build cell can compile it into an importable extension module.
# Two fixes relative to the earlier version of this source:
#   * the misspelled "cythhon" directive (silently ignored by Cython) is now
#     spelled "cython", so binding=False actually takes effect;
#   * the column-count check is restored: bounds checking is disabled, so a
#     mismatch would otherwise read past the end of a row of Bs.
cython_str = '''
# cython: language_level=3
# distutils: language=c
# cython: cpp_locals=True
# cython: binding=False
# cython: infer_types=False
# cython: wraparound=False
# cython: boundscheck=False
# cython: cdivision=True
# cython: overflowcheck=False
# cython: nonecheck=False
# cython: initializedcheck=False
# cython: always_allow_keywords=False
# cython: c_api_binop_methods=True
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION


# pdist_cython.pyx

import numpy as np
cimport numpy as npc
from libc.math cimport sqrt
#from cython.parallel import prange

import cython
cimport cython


cdef double[:,:,] pdist_cython(const double[:,:,] As, const double[:,:,] Bs) except *:

    cdef unsigned int n,m,k,k2,i,j,l
    cdef double tmp
    cdef npc.ndarray[double, ndim=2] res

    n = As.shape[0]
    k = As.shape[1]
    m = Bs.shape[0]
    k2 = Bs.shape[1]
    # boundscheck is disabled, so reject a column mismatch explicitly
    if k != k2:
        raise ValueError("As and Bs must have the same number of columns")

    res = np.empty((n, m))

    for i in range(0, n):
        for j in range(0, m):
            #res[i, j] = np.linalg.norm(As[i] - Bs[j])
            tmp = 0
            for l in range(0, k):
                tmp += (As[i,l] - Bs[j,l])**2
            res[i,j] = sqrt(tmp)
    return res

cpdef double[:,:,] main(double[:,:,] a, double[:,:,] b) except *:
    return pdist_cython(a, b)
'''
with open('pdist_cython.pyx','w') as f:
    f.write(cython_str)


In [5]:

# setup.py

# Build script for the Cython extension. setuptools is imported first and is
# the only source of `setup`; the former `from distutils.core import setup`
# was immediately shadowed and distutils is no longer part of the standard
# library on recent Python versions.
setup_py_cython_str = '''from setuptools import setup, Extension
from Cython.Build import cythonize
import numpy
import platform
include = [numpy.get_include()]

if platform.system() == 'Windows':
    args = ['/O2', '/fp:fast', '/Qfast_transcendentals']
    # above args are for MSVC compiler (windows)
elif platform.system() == 'Linux':
    # below args are for GCC
    args = ['-O3', '-ffast-math']# , '-fast-transcendentals' <- requires intel compilers
else:
    raise OSError('unsupported platform: ' + platform.system())

setup(
    ext_modules = cythonize([Extension('pdist_cython', sources=['pdist_cython.pyx'], include_dirs=include, extra_compile_args=args)],
        compiler_directives={'language_level' : "3"},
        ),
    zip_safe=False,
    )
'''

with open('cython_setup.py','w') as f:
    f.write(setup_py_cython_str)

import subprocess
import sys
# check=True: stop here if the build fails instead of importing a missing or
# stale extension module below.
subprocess.run([sys.executable, "cython_setup.py", "build_ext", "--inplace"], check=True)

# NOTE: a compiled extension cannot be reloaded in a running kernel; restart
# the kernel after changing pdist_cython.pyx.
import pdist_cython

In [6]:
# Time the compiled Cython kernel. perf_counter is monotonic and has a higher
# resolution than the wall clock returned by time.time().
t0 = time.perf_counter()
for _ in range(iterations):
    r3 = pdist_cython.main(a,b)
t1 = time.perf_counter()
print(f'cython {(t1-t0)/iterations:.2f} s')

# main() is declared to return a typed memoryview; view it as an ndarray and
# confirm it agrees with the scipy reference result.
np.testing.assert_allclose(np.asarray(r3), r4)

cython 0.75 


# C

In [7]:
import platform 
import ctypes
import numpy as np
import setuptools._distutils.ccompiler
def compile_and_link_c_string(s, c_file_name = 'test.c'):
    """Compile a C source string into a shared library and load it with ctypes.

    The placeholder ``PREDEF`` in ``s`` is replaced by the platform's export
    prefix (``extern __declspec(dllexport)`` on Windows, nothing on Linux)
    before the source is written to ``c_file_name`` and compiled.

    Args:
        s: C source code; exported functions should be prefixed with ``PREDEF``.
        c_file_name: path the source is written to. The shared library is
            created next to it with the same base name.

    Returns:
        The loaded library as a ``ctypes.CDLL``.

    Raises:
        OSError: on platforms other than Windows and Linux, or if the
            library cannot be loaded.
    """
    import os

    system = platform.system()
    if system == 'Windows':
        # flags for the MSVC compiler
        source = s.replace('PREDEF', 'extern __declspec(dllexport) ')
        args = ['/O2', '/fp:fast', '/Qfast_transcendentals']
        lib_suffix = '.dll'
    elif system == 'Linux':
        # flags for GCC; '-fast-transcendentals' would require intel compilers.
        # -fPIC: code that goes into a shared object must be position independent.
        source = s.replace('PREDEF ', '')
        args = ['-O3', '-ffast-math', '-fPIC']
        lib_suffix = '.so'
    else:
        raise OSError(f'unsupported platform: {system!r}')

    # Derive the library name from the actual source name (any extension).
    base_name = os.path.splitext(c_file_name)[0]
    with open(c_file_name,'w') as f:
        f.write(source)

    compiler = setuptools._distutils.ccompiler.new_compiler()
    objects = compiler.compile([c_file_name], extra_postargs = args)
    if system == 'Windows':
        # link_shared_lib appends the platform's shared-library suffix itself
        compiler.link_shared_lib(objects, base_name)
    else:
        compiler.link(compiler.SHARED_LIBRARY, objects, base_name + lib_suffix)

    # Always load by absolute path: Windows requires it, and on Linux a bare
    # file name is searched on the loader path rather than the working directory.
    return ctypes.CDLL(os.path.abspath(base_name + lib_suffix))
    

In [8]:
# Plain C kernel, loaded through ctypes. Row offsets are computed in size_t:
# with int arithmetic, i*k and i*m+j overflow once a matrix holds more than
# INT_MAX elements.
c_code_string = '''
// test.c
#include <stddef.h>
#include <stdio.h>
#include <math.h>


void main(){
}


PREDEF void pdist(int n, int m, int k, const double * A, const double * B, double * outdata) {


    int i, j, l;
    size_t p, q, r;
    double tmp, tmp2;
    for (i = 0; i < n ; ++i) {
        p = (size_t)i * (size_t)k;
        r = (size_t)i * (size_t)m;
        for (j = 0; j < m; ++j){
            q = (size_t)j * (size_t)k;
            tmp = 0;
            for (l = 0; l < k; l++){
                 tmp2 = A[p+l]-B[q+l];
                 tmp += tmp2*tmp2;
            }
            outdata[r+j] = sqrt(tmp);
        }
    }
}

'''

c_lib = compile_and_link_c_string(c_code_string, c_file_name = 'pdist.c')

# Declare the C signature so that ctypes checks the arguments: three ints
# (n, m, k) followed by three C-contiguous float64 arrays (A, B, outdata).
pdist_c = c_lib.pdist
pdist_c.restype = None
pdist_c.argtypes = [ctypes.c_int,
                    ctypes.c_int,
                    ctypes.c_int,
                    np.ctypeslib.ndpointer(ctypes.c_double, flags="C_CONTIGUOUS"),
                    np.ctypeslib.ndpointer(ctypes.c_double, flags="C_CONTIGUOUS"),
                    np.ctypeslib.ndpointer(ctypes.c_double, flags="C_CONTIGUOUS")]

   

In [9]:

# Time the ctypes-loaded C kernel. The output buffer is allocated inside the
# loop so that allocation is part of the measurement. perf_counter is
# monotonic and has a higher resolution than time.time().
t0 = time.perf_counter()
for _ in range(iterations):
    r5 = np.empty((a.shape[0], b.shape[0]))
    pdist_c(a.shape[0], b.shape[0], a.shape[1], a, b, r5)
t1 = time.perf_counter()
print(f'c {(t1-t0)/iterations:.2f} s')

# The timing is only meaningful if the result matches the scipy reference.
np.testing.assert_allclose(r5, r4)


c 0.23 s


# C extension

In [10]:
# Source of the hand-written CPython extension module `pdist_cext`; written to
# pdist_cext.c so that the build cell can compile it.
# main(As, Bs) now raises ValueError when the two arrays differ in their
# number of columns; previously the kernel would read past the end of Bs.
c_extension_code_string = '''
#define PY_SSIZE_T_CLEAN
#include "Python.h"

#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include "numpy/ndarraytypes.h"
#include "numpy/npy_3kcompat.h"

#define NPY_DTYPE(descr) ((PyArray_DTypeMeta *)Py_TYPE(descr))

#if defined(__OPTIMIZE__)

/* LIKELY() definition */
/* Checks taken from
    https://github.com/python/cpython/blob/main/Objects/obmalloc.c */
#if defined(__GNUC__) && (__GNUC__ > 2)
#  define UNLIKELY(value) __builtin_expect((value), 0)
#else
#  define UNLIKELY(value) (value)
#endif

#else

#define UNLIKELY(value) (value)

#endif /* __OPTIMIZE__ */

PyArray_Descr *DOUBLE_Descr;

Py_LOCAL_INLINE(void)
pdist_cext_impl(const char *A,
                const char *B,
                char *out,
                const size_t n,
                const size_t k,
                const size_t m,
                const size_t sk,
                const size_t sm)
{
    size_t i, j, l;
    double tmp, tmp2;
    const double *Adata, *Bdata;
    double *outdata;

    for (i = 0; i < n; ++i) {
        Adata = (const double *)(A + i * sk);
        outdata = (double *)(out + i * sm);
        for (j = 0; j < m; ++j) {
            Bdata = (double *)(B + j * sk);
            tmp = 0;
            for (l = 0; l < k; ++l) {
                tmp2 = Adata[l] - Bdata[l];
                tmp += tmp2 * tmp2;
            }
            outdata[j] = sqrt(tmp);
        }
    }
}

static PyObject *
pdist_cext_main(PyObject *Py_UNUSED(_),
                PyObject *const *args,
                const Py_ssize_t nargs)
{
    PyArrayObject *As = NULL, *Bs = NULL, *res;
    size_t n, m, k;

    if (UNLIKELY(nargs != 2)) {
        PyErr_Format(PyExc_TypeError,
                     "pdist_cext.main() expected exactly 2 positional "
                     "arguments, got %zd",
                     nargs);
        return NULL;
    }

    Py_INCREF(DOUBLE_Descr);
    As = (PyArrayObject *)PyArray_CheckFromAny(
        args[0], DOUBLE_Descr,
        2, 2,
        NPY_ARRAY_CARRAY_RO | NPY_ARRAY_NOTSWAPPED,
        NULL);
    if (UNLIKELY(As == NULL)) return NULL;
    Py_INCREF(DOUBLE_Descr);
    Bs = (PyArrayObject *)PyArray_CheckFromAny(
        args[1], DOUBLE_Descr,
        2, 2,
        NPY_ARRAY_CARRAY_RO | NPY_ARRAY_NOTSWAPPED,
        NULL);
    if (UNLIKELY(Bs == NULL)) goto error;

    /* The kernel strides both inputs by the column count of As, so Bs must
       have the same number of columns or it would be read out of bounds. */
    if (UNLIKELY(PyArray_DIM(As, 1) != PyArray_DIM(Bs, 1))) {
        PyErr_Format(PyExc_ValueError,
                     "pdist_cext.main() expected both arrays to have the "
                     "same number of columns, got %zd and %zd",
                     (Py_ssize_t)PyArray_DIM(As, 1),
                     (Py_ssize_t)PyArray_DIM(Bs, 1));
        goto error;
    }

    n = PyArray_DIM(As, 0);
    m = PyArray_DIM(Bs, 0);
    const npy_intp dims[2] = {n, m};
    Py_INCREF(DOUBLE_Descr);
    res = (PyArrayObject *)PyArray_NewFromDescr(
        &PyArray_Type, DOUBLE_Descr, 2,
        dims, NULL, NULL, 0, NULL
    );
    if (UNLIKELY(res == NULL)) goto error;

    k = PyArray_DIM(As, 1);

    pdist_cext_impl(PyArray_BYTES(As),
                    PyArray_BYTES(Bs),
                    PyArray_BYTES(res),
                    n, k, m,
                    sizeof(double) * k,
                    sizeof(double) * m);
    Py_DECREF(As);
    Py_DECREF(Bs);

    return (PyObject *)res;
  error:
    Py_XDECREF(As);
    Py_XDECREF(Bs);
    return NULL;
}

static PyMethodDef pdist_cext_methods[] = {
    {"main", (PyCFunction)pdist_cext_main, METH_FASTCALL, ""},
    {NULL}
};

static struct PyModuleDef pdist_cextmodule = {
    PyModuleDef_HEAD_INIT,
    "pdist_cext",
    "",
    0,
    pdist_cext_methods,
    NULL,
    NULL,
    NULL,
    NULL
};

PyMODINIT_FUNC
PyInit_pdist_cext(void)
{
    import_array();

    DOUBLE_Descr = PyArray_DescrFromType(NPY_DOUBLE);

    return PyModule_Create(&pdist_cextmodule);
}
'''

with open('pdist_cext.c','w') as f:
    f.write(c_extension_code_string)

In [11]:

# Build script for the plain C extension. The variable keeps its historical
# name (setup_py_cython_str) although no Cython is involved here: the
# generated script only needs setuptools and numpy, so the unused Cython and
# distutils imports are no longer emitted.
setup_py_cython_str = '''from setuptools import setup, Extension
import numpy
import platform
include = [numpy.get_include()]

if platform.system() == 'Windows':
    args = ['/O2', '/fp:fast', '/Qfast_transcendentals']
    print(args)
    # above args are for MSVC compiler (windows)
elif platform.system() == 'Linux':
    # below args are for GCC
    args = ['-O3', '-ffast-math']# , '-fast-transcendentals' <- requires intel compilers
else:
    raise OSError('unsupported platform: ' + platform.system())

setup(
    ext_modules = [Extension('pdist_cext', sources=['pdist_cext.c'], include_dirs=include, extra_compile_args=args)],
    zip_safe=False,
    )
'''
with open('cext_setup.py','w') as f:
    f.write(setup_py_cython_str)

import subprocess
import sys
# check=True: stop here if the build fails instead of importing a missing or
# stale extension module below.
subprocess.run([sys.executable, "cext_setup.py", "build_ext", "--inplace"], check=True)
import pdist_cext

In [12]:
# Time the hand-written C extension. perf_counter is monotonic and has a
# higher resolution than the wall clock returned by time.time().
t0 = time.perf_counter()
for _ in range(iterations):
    r6 =pdist_cext.main(a,b)
t1 = time.perf_counter()
print(f'c extension {(t1-t0)/iterations:.2f} s')

# The timing is only meaningful if the result matches the scipy reference.
np.testing.assert_allclose(r6, r4)

c extension 0.26 s
