Simple Python implementation of all-nearest-neighbours 

In [1]:
from scipy.spatial import Delaunay, cKDTree
import pandas as pd
import numpy as np

In [2]:
%load_ext Cython

In [3]:
# Load two example neurons as point clouds
x = pd.read_csv(
    "/Users/philipps/Downloads/720575940638426064.swc",
    comment="#",
    header=None,
    sep=" ",
)[[2, 3, 4]].values
y = pd.read_csv(
    "/Users/philipps/Downloads/720575940613656978.swc",
    comment="#",
    header=None,
    sep=" ",
)[[2, 3, 4]].values

x.shape, y.shape

((127454, 3), (22992, 3))

In [4]:
%%time

# Calculate the delaunay triangulation
x = Delaunay(x)
y = Delaunay(y)

CPU times: user 6.24 s, sys: 377 ms, total: 6.61 s
Wall time: 6.69 s


In [5]:
# Check if any vertices were dropped
x.coplanar, y.coplanar

(array([], shape=(0, 3), dtype=int32), array([], shape=(0, 3), dtype=int32))

In [6]:
y.vertex_neighbor_vertices[0].shape, y.vertex_neighbor_vertices[1].shape

((22993,), (379246,))

In [5]:
# Neighbors of vertex k
k = 5132
ix1 = y.vertex_neighbor_vertices[0][k]
ix2 = y.vertex_neighbor_vertices[0][k + 1]
y.vertex_neighbor_vertices[1][ix1:ix2]

array([5205, 5543, 5560, 5129, 4873, 5071, 5136, 5140, 5135, 5124, 5126,
       5202, 5141, 5122, 5484, 5137, 5128, 5159, 5156, 5158, 5495, 5482,
       5476], dtype=int32)

%%cython --force --annotate

In [42]:
!export CFLAGS="-O3"

In [50]:
%%cython --force --annotate

# distutils: language = c++

from libcpp cimport bool as bool_t
from cython.parallel import prange
from libcpp.vector cimport vector
import numpy as np
cimport numpy as cnp
cimport cython
import ctypes


@cython.boundscheck(False)
@cython.wraparound(False)
def all_nearest_neighbours(x, y):
    """Find all nearest neighbors between two delaunay triangulations.

    Parameters
    ----------
    x : scipy.spatial.Delaunay
        Delaunay triangulation of the first point cloud.
    y : scipy.spatial.Delaunay

    Returns
    -------
    dists : np.ndarray
        Array of distances between the nearest neighbors.
    ind : np.ndarray
        Array of indices of the nearest neighbors in y for each vertex in x.
    """
    # Run a depth first search starting from each vertex in x
    # to find the nearest neighbor in y
    dists = np.zeros(len(x.points), dtype='double')
    ind = np.zeros(len(x.points), dtype='long')
    seen = np.zeros(len(x.points), dtype="long")

    cdef double[::1] dists_view = dists
    cdef long[::1] ind_view = ind
    cdef long[::1] seen_view = seen

    # Convert Delaunays to structs so we can pass them to cython functions
    cdef double[:, ::1] xpoints = x.points
    cdef double[:, ::1] ypoints = y.points
    cdef long[::1] xindices = x.vertex_neighbor_vertices[0].astype('long', copy=False)
    cdef long[::1] yindices = y.vertex_neighbor_vertices[0].astype('long', copy=False)
    cdef long[::1] xneighbors = x.vertex_neighbor_vertices[1].astype('long', copy=False)
    cdef long[::1] yneighbors = y.vertex_neighbor_vertices[1].astype('long', copy=False)

    # Start with the first vertex in the first simplex
    cdef long vertex = 0
    cdef double d
    cdef long i, ix, n
    cdef double[::1] p_start = xpoints[vertex]
    cdef long[::1] visited = np.zeros(len(ypoints), dtype="long")
    cdef long nvisited = len(visited)
    d, ix = _find_nearest_neighbor(ypoints, yindices, yneighbors, p_start, 0, visited, nvisited)
    dists_view[vertex] = d
    ind_view[vertex] = ix
    seen_view[vertex] = 1

    # Make a stack with vertices to visit and with the closest vertex in y
    cdef vector[(long, long)] stack
    cdef long start = xindices[vertex]
    cdef long end = xindices[vertex + 1]
    for i in range(start, end):
        n = xneighbors[i]
        stack.push_back((ix, n))

    while stack.size():
        ix, n = stack.back()
        stack.pop_back()
        # Skip if we already found the nearest neighbor for this vertex
        if seen_view[n] > 0:
            continue
        d, ix = _find_nearest_neighbor(ypoints, yindices, yneighbors, xpoints[n], ix, visited, nvisited)
        # print(n, ix, d)
        dists_view[n] = d
        ind_view[n] = ix
        seen_view[n] = 1

        start = xindices[n]
        end = xindices[n + 1]
        for i in range(start, end):
            n = xneighbors[i]
            stack.push_back((ix, n))

    return dists, ind


@cython.boundscheck(False)
@cython.wraparound(False)
cdef long[::1] get_neighbors(long[::1] indices, long[::1] neighbors, long vertex):
    """Find the neighbours of vertex in delaunay triangulation x."""
    return neighbors[indices[vertex]:indices[vertex+1]]


@cython.boundscheck(False)
@cython.wraparound(False)
cdef (double, long) _find_nearest_neighbor(double[:, ::1] points, long[::1] indices, long[::1] neighbors, double[::1] p, long vstart, long[::1] visited, long nvisited):
    """Find the approximate nearest neighbor of point p among the points in y."""
    # If no start defined, use the first point in the first simplex
    cdef long vert = vstart

    # Track which points we have already visited
    cdef long i
    for i in range(nvisited):
        visited[i] = 0
    visited[vert] = 1

    # Get the distance between p and the starting point
    # d = ((y.points[vert] - p) ** 2).sum()
    # Oddly enough, this is slower than the following:
    cdef double[::1] p2 = points[vert]
    cdef double d2
    cdef double d = (p2[0] - p[0]) * (p2[0] - p[0]) + (p2[1] - p[1]) * (p2[1] - p[1]) + (p2[2] - p[2]) * (p2[2] - p[2])
    cdef long vert_new, n
    cdef long start, end
    while True:
        # Figure out if any of the neighbors is closer to p than the current vertex
        # If so, update the distance and the current vertex
        vert_new = -1
        start = indices[vert]
        end = indices[vert + 1]
        for i in range(start, end):
            n = neighbors[i]
            if visited[n] == 0:
                p2 = points[n]
                d2 = (p2[0] - p[0]) * (p2[0] - p[0]) + (p2[1] - p[1]) * (p2[1] - p[1]) + (p2[2] - p[2]) * (p2[2] - p[2])
                visited[n] = 1
                if d2 < d:
                    d = d2
                    vert_new = n
        # If none of the neighbours is closer, we are done
        if vert_new == -1:
            break

        # Otherwise, update the current vertex
        vert = vert_new

    return d**(1/2), vert

Content of stderr:
In file included from /Users/philipps/.cache/ipython/cython/_cython_magic_16f2f3f05801cd62597a6fef7742e2a95b1b9074.cpp:1243:
In file included from /Users/philipps/.pyenv/versions/3.9.9/lib/python3.9/site-packages/numpy/core/include/numpy/arrayobject.h:5:
In file included from /Users/philipps/.pyenv/versions/3.9.9/lib/python3.9/site-packages/numpy/core/include/numpy/ndarrayobject.h:12:
In file included from /Users/philipps/.pyenv/versions/3.9.9/lib/python3.9/site-packages/numpy/core/include/numpy/ndarraytypes.h:1929:
 ^
static __Pyx_memviewslice __pyx_f_54_cython_magic_16f2f3f05801cd62597a6fef7742e2a95b1b9074_get_neighbors(__Pyx_memviewslice __pyx_v_indices, __Pyx_memviewslice __pyx_v_neighbors, long __pyx_v_vertex) {
                          ^

In [7]:
# Check against scipy's KDTree
tree = cKDTree(y.points)

In [51]:
%%timeit

# Run full search
dist, ind = all_nearest_neighbours(x, y)
dist, ind

581 ms ± 7.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [40]:
%%timeit

# Check against scipy's KDTree
dist2, ind2 = tree.query(x.points)
dist2, ind2

1.48 s ± 133 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [41]:
# How often do the results differ?
(ind != ind2).sum()

0