Simple Python implementation of all-nearest-neighbours 

In [81]:
from scipy.spatial import Delaunay, cKDTree
import pandas as pd
import numpy as np

In [9]:
# Load two example neurons as point clouds
x = pd.read_csv(
    "/Users/philipps/Downloads/720575940638426064.swc",
    comment="#",
    header=None,
    sep=" ",
)[[2, 3, 4]].values
y = pd.read_csv(
    "/Users/philipps/Downloads/720575940613656978.swc",
    comment="#",
    header=None,
    sep=" ",
)[[2, 3, 4]].values

x.shape, y.shape

((127454, 3), (22992, 3))

In [10]:
%%time

# Calculate the delaunay triangulation
x = Delaunay(x)
y = Delaunay(y)

CPU times: user 6.55 s, sys: 324 ms, total: 6.87 s
Wall time: 6.97 s


In [14]:
# Check if any vertices were dropped
x.coplanar, y.coplanar

(array([], shape=(0, 3), dtype=int32), array([], shape=(0, 3), dtype=int32))

In [15]:
x.simplices

array([[43494, 43858, 38206, 38015],
       [ 6756,  6796,  6654,  6653],
       [ 6756,  6242,  6654,  6653],
       ...,
       [35960, 35800, 35801, 33228],
       [42396, 40406, 39176, 39058],
       [42396, 40406, 39176, 39287]], dtype=int32)

In [16]:
y.simplices

array([[ 5131,  5118, 21107,  6117],
       [ 5100,  5118, 21107,  6117],
       [ 5100,  5131,  5118,  6117],
       ...,
       [ 8284,  8304,  8485,  8432],
       [ 8284,  8273,  8391,  8432],
       [ 8284,  8304,  8391,  8432]], dtype=int32)

In [23]:
y.vertex_neighbor_vertices[0].shape, y.vertex_neighbor_vertices[1].shape

((22993,), (379246,))

In [33]:
# Neighbors of vertex k
k = 5132
ix1 = y.vertex_neighbor_vertices[0][k]
ix2 = y.vertex_neighbor_vertices[0][k + 1]
y.vertex_neighbor_vertices[1][ix1:ix2]

array([5205, 5543, 5560, 5129, 4873, 5071, 5136, 5140, 5135, 5124, 5126,
       5202, 5141, 5122, 5484, 5137, 5128, 5159, 5156, 5158, 5495, 5482,
       5476], dtype=int32)

In [120]:
i = 0


def all_nearest_neighbours(x, y):
    """
    Find all nearest neighbors between two delaunay triangulations.
    """
    # Run a depth first search starting from each vertex in x
    # to find the nearest neighbor in y
    dists = np.zeros(len(x.points))
    ind = np.zeros(len(x.points), dtype=int)
    n_checked = np.zeros(len(x.points))

    # Start with the first vertex in the first simplex
    vertex = x.simplices[0][0]
    d, ix, i = _find_nearest_neighbour(y, x.points[vertex], start=None)
    dists[vertex] = d
    dists[vertex] = ix
    n_checked[vertex] = i

    # Make a stack with vertices to visit and with the closest vertex in y
    stack = [(ix, get_neighbours(x, vertex))]
    while stack:
        ix, neighbors = stack.pop()
        for n in neighbors:
            # Skip if we already found the nearest neighbor for this vertex
            if dists[n] > 0:
                continue
            d, ix, i = _find_nearest_neighbour(y, x.points[n], start=ix)
            dists[n] = d
            ind[n] = ix
            n_checked[n] = i
            stack.append((ix, get_neighbours(x, n)))

    return dists, ind, n_checked


def get_neighbours(x, vertex):
    """Find the neighbours of vertex in delaunay triangulation x."""
    ix1 = x.vertex_neighbor_vertices[0][vertex]
    ix2 = x.vertex_neighbor_vertices[0][vertex + 1]
    return x.vertex_neighbor_vertices[1][ix1:ix2]


def _find_nearest_neighbour(y, p, start=None):
    """Find the approximate nearest neighbor of point p among the points in y."""
    # If no start defined, use the first point in the first simplex
    if start is None:
        vert = y.simplices[0][0]
    else:
        vert = start

    # Track which points we have already visited
    visited = np.zeros(len(y.points), dtype=bool)
    visited[vert] = True

    # Get the distance between p and the starting point
    # d = ((y.points[vert] - p) ** 2).sum()
    # Oddly enough, this is slower than the following:
    yp = y.points[vert]
    d = sum(((yp[0] - p[0]) ** 2, (yp[1] - p[1]) ** 2, (yp[2] - p[2]) ** 2))
    i = 0
    while True:
        # Get the neighbours of the current vertex
        neighbors = get_neighbours(y, vert)

        # Figure out if any of the neighbors is closer to p than the current vertex
        # If so, update the distance and the current vertex
        vert_new = None
        for n in neighbors:
            if not visited[n]:
                # d2 = ((y.points[n] - p) ** 2).sum()
                yp = y.points[n]
                d2 = sum(
                    ((yp[0] - p[0]) ** 2, (yp[1] - p[1]) ** 2, (yp[2] - p[2]) ** 2)
                )
                visited[n] = True
                if d2 < d:
                    d = d2
                    vert_new = n
        # If none of the neighbours is closer, we are done
        if vert_new is None:
            break

        # Otherwise, update the current vertex and mark it as visited
        vert = vert_new
        # visited[vert] = True

        i += 1

    return np.sqrt(d), vert, i

In [121]:
%%time

# Test the A* search for a single point
_find_nearest_neighbour(y, x.points[100000])

CPU times: user 533 µs, sys: 8 µs, total: 541 µs
Wall time: 541 µs


(52422.675732930686, 12455, 5)

In [105]:
# Check against scipy's KDTree
tree = cKDTree(y.points)

In [112]:
%%time

tree.query(x.points[100000])

CPU times: user 570 µs, sys: 142 µs, total: 712 µs
Wall time: 670 µs


(52422.675732930686, 12455)

In [122]:
%%time

# Run full search
dist, ind, i = all_nearest_neighbours(x, y)
dist, ind, i

CPU times: user 17.3 s, sys: 102 ms, total: 17.4 s
Wall time: 17.4 s


(array([105146.42337236,  40647.4728485 ,  41286.3492668 , ...,
         90605.82305969,  90369.69973375,  90338.68600993]),
 array([12687,  6089,  6089, ..., 17557, 17557, 17557]),
 array([0., 0., 1., ..., 1., 0., 0.]))

In [115]:
%%time

# Check against scipy's KDTree
dist2, ind2 = tree.query(x.points)
dist2, ind2

CPU times: user 2.07 s, sys: 25.1 ms, total: 2.1 s
Wall time: 2.12 s


(array([105146.42337236,  40647.4728485 ,  41286.3492668 , ...,
         90605.82305969,  90369.69973375,  90338.68600993]),
 array([12687,  6089,  6089, ..., 17557, 17557, 17557]))

In [117]:
# How often do the results differ?
(ind != ind2).sum()

1

In [79]:
%load_ext line_profiler

In [119]:
%lprun -f _find_nearest_neighbour all_nearest_neighbours(x, y)

Timer unit: 1e-09 s

Total time: 53.5217 s
File: /var/folders/b1/1fbq04gx1vg344_ctkmv52d00000gn/T/ipykernel_34341/3180897509.py
Function: _find_nearest_neighbour at line 45

Line #      Hits         Time  Per Hit   % Time  Line Contents
    45                                           def _find_nearest_neighbour(y, p, start=None):
    46                                               """Find the approximate nearest neighbor of point p among the points in y."""
    47                                               # If no start defined, use the first point in the first simplex
    48    127454   69093000.0    542.1      0.1      if start is None:
    49         1       3000.0   3000.0      0.0          vert = y.simplices[0][0]
    50                                               else:
    51    127453   66722000.0    523.5      0.1          vert = start
    52                                           
    53                                               # Track which points we have alrea