# k-Nearest Neighbors

Let's quickly see how we might use the np.argsort function along multiple axes to find the nearest neighbors of each point in a set. We'll start by creating a random set of 10 points on a two-dimensional plane. Using the standard convention, we'll arrange these in a 10×2 array:

In [None]:
import numpy as np

In [None]:
rand = np.random.RandomState(666)
X = rand.rand(10, 2)

To get an idea of how these points look, let's make a quick scatter plot of them:

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn; seaborn.set_theme()  # Plot styling (set_theme is the current name for seaborn.set)
plt.scatter(X[:, 0], X[:, 1], s=100);

Now we'll compute the distance between each pair of points. Recall that the squared distance between two points is the sum of the squared differences in each dimension; using the efficient broadcasting and aggregation routines provided by NumPy, we can compute the matrix of squared distances in a single line of code:

In [None]:
dist_sq = np.sum((X[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2, axis=-1)
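For reference, here is what that line computes, written in index notation. The symbols are introduced only for this illustration: $X_{ik}$ is coordinate $k$ of point $i$, and $D$ is the number of dimensions ($D = 2$ here). Axis 0 of the broadcast result plays the role of $i$, axis 1 the role of $j$, and axis=-1 is the sum over $k$:

$$
d_{ij}^2 = \sum_{k=1}^{D} \left( X_{ik} - X_{jk} \right)^2
$$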

This operation has a lot packed into it, and it might be a bit confusing if you're unfamiliar with NumPy's broadcasting rules. When you come across code like this, it can be useful to break it down into its component steps:

In [None]:
# for each pair of points, compute differences in their coordinates
differences = X[:, np.newaxis, :] - X[np.newaxis, :, :]
differences.shape

In [None]:
# square the coordinate differences
sq_differences = differences ** 2
sq_differences.shape

In [None]:
# sum the squared coordinate differences to get the squared distance
dist_sq = sq_differences.sum(-1)
dist_sq.shape
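If the broadcasting is still hard to follow, it can help to compare it with the explicit double loop it replaces. The sketch below is an illustration written for this comparison, not part of the workflow above: it recomputes the same matrix with plain Python loops and checks that both versions agree. The loop is fine for 10 points but becomes much slower as the number of points grows.

```python
import numpy as np

# Recreate the same 10 random points used above
rand = np.random.RandomState(666)
X = rand.rand(10, 2)

# Vectorized version: (10, 1, 2) - (1, 10, 2) broadcasts to (10, 10, 2),
# then summing over the last axis gives a (10, 10) matrix
dist_sq = np.sum((X[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2, axis=-1)

# Explicit-loop version of the same computation, for comparison only
n = X.shape[0]
dist_sq_loop = np.zeros((n, n))
for i in range(n):
    for j in range(n):
        diff = X[i] - X[j]                      # coordinate differences, shape (2,)
        dist_sq_loop[i, j] = np.sum(diff ** 2)  # squared distance between i and j

print(np.allclose(dist_sq, dist_sq_loop))  # True
```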

Just to double-check what we are doing, we should see that the diagonal of this matrix (i.e., the set of distances between each point and itself) is all zero:

In [None]:
dist_sq.diagonal()
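The same sanity check can be written as assertions, which fail loudly instead of relying on visual inspection. The sketch below also checks that the matrix is symmetric (the distance from i to j equals the distance from j to i) and, assuming no two random points coincide, that every off-diagonal entry is strictly positive:

```python
import numpy as np

rand = np.random.RandomState(666)
X = rand.rand(10, 2)
dist_sq = np.sum((X[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2, axis=-1)

# Every point is at distance zero from itself
assert np.allclose(dist_sq.diagonal(), 0)

# Distance is symmetric: d(i, j) == d(j, i)
assert np.allclose(dist_sq, dist_sq.T)

# Off-diagonal entries are positive as long as no two points coincide
off_diag = dist_sq[~np.eye(len(X), dtype=bool)]
assert (off_diag > 0).all()
```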

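It checks out! With the pairwise squared distances computed, we can now use np.argsort to sort along each row. The leftmost columns will then give the indices of the nearest neighbors; since every point is at distance zero from itself, the first column simply lists each point's own index, and the actual neighbors begin in the second column: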

In [None]:
nearest = np.argsort(dist_sq, axis=1)
print(nearest)
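To see what sorting "along each row" means, here is np.argsort with axis=1 applied to a tiny hand-made matrix (the values are arbitrary, chosen only for illustration), followed by a check of the first column of the neighbor matrix:

```python
import numpy as np

# A tiny matrix with arbitrary values, for illustration only
a = np.array([[3, 1, 2],
              [0, 5, 4]])

# Each row is sorted independently; the result holds column indices
print(np.argsort(a, axis=1))
# [[1 2 0]
#  [0 2 1]]

# On the distance matrix, column 0 of the row-wise argsort is each point itself
rand = np.random.RandomState(666)
X = rand.rand(10, 2)
dist_sq = np.sum((X[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2, axis=-1)
nearest = np.argsort(dist_sq, axis=1)
print(nearest[:, 0])  # [0 1 2 3 4 5 6 7 8 9]
```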

In [None]:
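K = 2
# argpartition with kth = K + 1 guarantees that, in each row, the first K + 1
# columns hold the indices of the K + 1 smallest distances (the point itself
# plus its K nearest neighbors), though not necessarily in sorted order
nearest_partition = np.argpartition(dist_sq, K + 1, axis=1)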
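Because only the K nearest neighbors are needed, fully sorting each row does more work than necessary, and np.argpartition performs a partial sort instead. The sketch below checks, on the same random data, that the first K + 1 columns of the full sort and of the partition contain the same indices, possibly in a different order. (A kth of K would also be enough to isolate the K + 1 smallest entries.) Exact agreement assumes no ties among distances, which holds for generic random points.

```python
import numpy as np

rand = np.random.RandomState(666)
X = rand.rand(10, 2)
dist_sq = np.sum((X[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2, axis=-1)

K = 2
nearest = np.argsort(dist_sq, axis=1)                        # full sort of each row
nearest_partition = np.argpartition(dist_sq, K + 1, axis=1)  # partial sort of each row

# Same set of indices (self plus K neighbors), order within the block may differ
for i in range(X.shape[0]):
    assert set(nearest[i, :K + 1]) == set(nearest_partition[i, :K + 1])
```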

In [None]:
plt.scatter(X[:, 0], X[:, 1], s=100)

# draw lines from each point to its K = 2 nearest neighbors
for i in range(X.shape[0]):
    for j in nearest_partition[i, :K + 1]:
        if j == i:
            continue  # skip the point itself (distance zero)
        # zip pairs up the coordinates as ((x_j, x_i), (y_j, y_i)),
        # i.e. the x-values and y-values that plt.plot expects
        plt.plot(*zip(X[j], X[i]), color='black')

Each point in the plot has lines drawn to its two nearest neighbors. At first glance, it might seem strange that some of the points have more than two lines coming out of them: this is because if point A is one of the two nearest neighbors of point B, point B is not necessarily one of the two nearest neighbors of point A.
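This asymmetry can be counted directly. The sketch below builds a boolean neighbor matrix from the row-wise argsort, finds pairs where j is a neighbor of i but not the other way around, and counts how many distinct lines touch each point. The names is_nbr, one_way, and lines_per_point are hypothetical, introduced only for this illustration.

```python
import numpy as np

rand = np.random.RandomState(666)
X = rand.rand(10, 2)
dist_sq = np.sum((X[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2, axis=-1)

K = 2
# Column 0 of the argsort is the point itself, so columns 1..K are its K neighbors
neighbors = np.argsort(dist_sq, axis=1)[:, 1:K + 1]

# is_nbr[i, j] is True when j is among the K nearest neighbors of i
is_nbr = np.zeros(dist_sq.shape, dtype=bool)
rows = np.arange(X.shape[0])[:, np.newaxis]  # shape (10, 1), broadcasts against (10, K)
is_nbr[rows, neighbors] = True

# One-directional relations: j is a neighbor of i, but i is not a neighbor of j
one_way = is_nbr & ~is_nbr.T
print(one_way.sum(), "one-directional neighbor relations")

# Distinct lines touching each point: its own K neighbors, plus every point
# that picked it without being picked back
lines_per_point = is_nbr.sum(axis=1) + one_way.sum(axis=0)
print(lines_per_point)
```

Any entry of lines_per_point larger than K corresponds to a point with more than two lines in the plot.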

Although the broadcasting and row-wise sorting of this approach might seem less straightforward than writing a loop, it turns out to be a very efficient way of operating on this data in Python. You might be tempted to do the same type of operation by manually looping through the data and sorting each set of neighbors individually, but this would almost certainly be slower than the vectorized version we used, because the loop runs in the Python interpreter rather than in NumPy's compiled routines. The beauty of this approach is that it's written in a way that's agnostic to the size of the input data: we could just as easily compute the neighbors among 100 or 1,000,000 points in any number of dimensions, and the code would look the same. The limit is memory: the broadcasting step builds an N×N×D intermediate array (N points in D dimensions), so memory use grows as N², and for 1,000,000 points the distance matrix alone would have 10^12 entries.
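One common way to shrink that intermediate, not used in the text above, is to expand the square as $\lVert a - b\rVert^2 = \lVert a\rVert^2 + \lVert b\rVert^2 - 2\,a \cdot b$, so the largest temporary is N×N instead of N×N×D. The function name pairwise_sq_dists is hypothetical, and this is a sketch: the subtraction can lose precision for nearly coincident points, so it clamps tiny negative results to zero.

```python
import numpy as np

def pairwise_sq_dists(X):
    """Squared Euclidean distances via |a|^2 + |b|^2 - 2 a.b (hypothetical helper).

    Avoids the (N, N, D) intermediate of the broadcasting version;
    the largest temporary array has shape (N, N).
    """
    sq_norms = (X ** 2).sum(axis=1)               # |x_i|^2 for each row, shape (N,)
    d = sq_norms[:, np.newaxis] + sq_norms[np.newaxis, :] - 2 * X @ X.T
    np.maximum(d, 0, out=d)                       # clamp round-off negatives to zero
    return d

rand = np.random.RandomState(666)
X = rand.rand(10, 2)
broadcast = np.sum((X[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2, axis=-1)
print(np.allclose(pairwise_sq_dists(X), broadcast))  # True
```

For very large N even an N×N matrix is too big, which is where the tree-based methods discussed next come in.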

Finally, I'll note that when doing very large nearest neighbor searches, there are tree-based and/or approximate algorithms that can scale as $\mathcal{O}[N\log N]$ or better rather than the $\mathcal{O}[N^2]$ of the brute-force algorithm. One example of this is the KD-Tree, implemented in Scikit-learn as sklearn.neighbors.KDTree.

In [None]:
import sklearn.neighbors

In [None]:
tree = sklearn.neighbors.KDTree(X, leaf_size=2)
dist, ind = tree.query([X[0]], k=3)

print(ind)   # indices of the 3 closest points to X[0]; the first is X[0] itself
print(dist)  # Euclidean distances to those 3 points, in increasing order
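As a cross-check, the tree and the brute-force argsort should find the same neighbors. The sketch below queries all ten points at once (KDTree.query accepts a 2-D array of query points) and compares the two results. KDTree returns Euclidean distances rather than squared ones, so the brute-force values are square-rooted first; exact agreement of the indices assumes there are no ties in distance, which holds for generic random data.

```python
import numpy as np
from sklearn.neighbors import KDTree

rand = np.random.RandomState(666)
X = rand.rand(10, 2)
dist_sq = np.sum((X[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2, axis=-1)

k = 3  # the point itself plus its 2 nearest neighbors
tree = KDTree(X, leaf_size=2)
dist, ind = tree.query(X, k=k)  # both have shape (10, 3), sorted by distance

# Brute force: row-wise argsort of squared distances, then square roots
brute_ind = np.argsort(dist_sq, axis=1)[:, :k]
brute_dist = np.sqrt(np.take_along_axis(dist_sq, brute_ind, axis=1))

print(np.array_equal(ind, brute_ind), np.allclose(dist, brute_dist))
```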

In [None]:
plt.scatter(X[:, 0], X[:, 1], s=100)

# query all points at once; column 0 of ind is each point itself,
# so columns 1 and 2 are its two nearest neighbors
dist, ind = tree.query(X, k=3)
for i in range(X.shape[0]):
    for j in ind[i, 1:]:
        plt.plot(*zip(X[j], X[i]), color='black')