In [None]:
import numpy as np
from matplotlib import pyplot as plt

**Choose K (number of nearest neighbours per point) and N (number of points)**

In [None]:
# K: how many nearest neighbours to look up for each point.
# N: how many random points to generate.
K, N = 3, 20

In [None]:
# Calling `np.random.RandomState(seed=42)` on its own builds a generator and
# throws it away -- it does NOT seed the global RNG used by `np.random.random`.
# Seed the global state explicitly so Restart & Run All reproduces the points.
np.random.seed(42)
points = np.random.random((2, N))  # row 0: x coordinates, row 1: y coordinates

fig, ax = plt.subplots()
ax.scatter(points[0], points[1], alpha=.5)
for i, (x, y) in enumerate(points.T):
    ax.text(x, y, str(i))  # label each point with its column index
ax.set(title=f"{N} random points", xlabel="x", ylabel="y")
plt.show()

**Find pairwise distances**

In [None]:
# Method 1: `np.subtract.outer` builds the N x N matrix of coordinate
# differences separately for each axis; square, add, then take the root.
dist_x = np.subtract.outer(points[0], points[0]) ** 2
dist_y = np.subtract.outer(points[1], points[1]) ** 2
dist_outer = np.sqrt(dist_x + dist_y)

# Method 2 (`Python Data Science Handbook`): broadcasting a (2, 1, N) view
# against a (2, N, 1) view gives a (2, N, N) array of differences; summing the
# squares over axis 0 adds the x and y contributions.
dist = np.sqrt(np.sum((points[:, np.newaxis, :] - points[:, :, np.newaxis]) ** 2, axis=0))

# Keep both results so method 1 is not dead code, and check that they agree and
# form a valid distance matrix (symmetric, zero diagonal).
np.testing.assert_allclose(dist, dist_outer)
np.testing.assert_allclose(dist, dist.T)
np.testing.assert_allclose(np.diag(dist), 0)

dist.shape

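**Find the K nearest neighbours of each point**

Each row of the result holds K + 1 indices, because every point is at distance 0 from itself.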

In [None]:
# kth=K is enough to move the K + 1 smallest distances of each row into the
# first K + 1 slots: the point itself (distance 0) plus its K neighbours.
# (kth=K + 1 does extra work and raises ValueError when K == N - 1.)
# `dist` is symmetric, so partitioning along rows (axis=1) needs no transpose.
nearest_unsorted = np.argpartition(dist, kth=K, axis=1)[:, :K + 1]

# argpartition leaves those K + 1 indices in arbitrary order; sort each row by
# distance so neighbours read nearest-first. Column 0 is then the point itself,
# unless two points coincide exactly.
order = np.argsort(np.take_along_axis(dist, nearest_unsorted, axis=1), axis=1)
nearest = np.take_along_axis(nearest_unsorted, order, axis=1)
nearest


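**Mini benchmark: `np.subtract.outer` vs. broadcasting for pairwise distances**

Both approaches build full n × n matrices, so memory use grows with the square of the number of points.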

In [None]:
def pairwise_dist(points):
    """Euclidean distance between every pair of points.

    Parameters
    ----------
    points : array_like, shape (d, n)
        One row per coordinate axis and one column per point. Any number of
        axes ``d`` is supported (2-D, 3-D, ...).

    Returns
    -------
    numpy.ndarray, shape (n, n)
        ``result[i, j]`` is the distance between points ``i`` and ``j``.

    Raises
    ------
    ValueError
        If ``points`` is not a 2-D array.
    """
    points = np.asarray(points)
    if points.ndim != 2:
        raise ValueError(f"points must have shape (d, n), got {points.shape}")
    # Accumulate squared per-axis differences; for d == 3 this adds the terms
    # in the same order as x + y + z, so results are bit-identical.
    sq_dist = sum(np.subtract.outer(coords, coords) ** 2 for coords in points)
    return np.sqrt(sq_dist)


p = np.random.random((3, 10000))

%time dist2 = pairwise_dist(p)
%time dist = np.sqrt(np.sum((p[:, :, np.newaxis] - p[:, np.newaxis, :]) ** 2, axis=0))
