In [1]:
import numpy as np

In [None]:
def k_nearest_neighbors(x: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
    """Return the ``k`` nearest (Euclidean) neighbours of every point, excluding itself.

    Args:
        x: Points as an (N, D) array; each row is one point. Array-likes are accepted.
        k: Neighbours to return per point; must satisfy ``0 <= k <= N - 1``.

    Returns:
        ``(values, idx)``, each of shape (N, k). ``idx[i]`` holds the row indices of
        the neighbours of ``x[i]`` ordered nearest first, and ``values[i]`` holds the
        matching distances.

    Raises:
        ValueError: if ``k`` is negative or larger than the number of other points.
    """
    x = np.asarray(x)
    n_points = x.shape[0]
    max_k = max(n_points - 1, 0)
    # Each point has only N - 1 other points. A larger k would return the point
    # itself (distance inf) as a "neighbour", and a negative k would silently
    # slice columns from the end.
    if not 0 <= k <= max_k:
        raise ValueError(f"k must be between 0 and {max_k} for {n_points} points, got {k}")

    # differences[i, j] = x[j] - x[i], shape (N, N, D): O(N**2 * D) memory.
    differences = np.expand_dims(x, 0) - np.expand_dims(x, 1)
    # Norm over the feature axis; same as (differences ** 2).sum(axis=2) ** 0.5.
    distances = np.linalg.norm(differences, axis=2)
    # (Bonus) O(N**2)-memory alternative via |a - b|^2 = |a|^2 - 2 a.b + |b|^2.
    # Note it yields SQUARED distances that rounding can push slightly below 0:
    #   squares = (x ** 2).sum(axis=1, keepdims=True)               # (N, 1)
    #   sq_dist = squares - 2 * np.matmul(x, x.T) + squares.T       # (N, N)
    #   distances = np.sqrt(np.maximum(sq_dist, 0.0))
    # Exclude each point from its own neighbour list. Without this the slice below
    # would be [:, 1:k + 1], which assumes self sorts first (not guaranteed when
    # duplicate points tie at distance 0).
    np.fill_diagonal(distances, np.inf)
    idx = np.argsort(distances, axis=1)[:, :k]
    # Gather the distances for the chosen indices instead of sorting a second time
    # (numpy's counterpart of torch.gather; torch's sort/topk return both at once).
    values = np.take_along_axis(distances, idx, axis=1)
    return values, idx

In [None]:
def k_nearest_neighbors(x: "torch.Tensor", k: int) -> "tuple[torch.Tensor, torch.Tensor]":
    """Torch version: ``k`` nearest (Euclidean) neighbours of each row of ``x``, excluding itself.

    Args:
        x: Points as an (N, D) tensor; each row is one point.
        k: Neighbours to return per point.

    Returns:
        ``(values, indices)`` from ``Tensor.topk``, each of shape (N, k), nearest first.
    """
    # Local import: the notebook's import cell only loads numpy. String annotations
    # keep the ``def`` itself from needing torch at definition time.
    import torch

    distances = torch.cdist(x, x)
    # Exclude each point from its own neighbour list.
    distances.fill_diagonal_(float("infinity"))
    # Use the caller's k (previously hard-coded to 3); largest=False -> smallest distances.
    return distances.topk(k, dim=1, largest=False)

In [None]:
def k_nearest_neighbors(x, k):
    """Exercise stub: find each point's ``k`` nearest neighbours.

    Args:
        x: Points as an (N, D) array; each row is one point.
        k: Neighbours to return per point (the point itself is excluded).

    Returns:
        ``(values, idx)``: neighbour distances and row indices, each (N, k),
        nearest first.

    Raises:
        NotImplementedError: always, until the exercise is filled in. A bare
            ``pass`` would return ``None`` and the caller's ``values, idx = ...``
            would then fail with an opaque unpacking TypeError.
    """
    raise NotImplementedError("k_nearest_neighbors is left as an exercise")


# Toy input: four 2-D points; ask for each point's two nearest neighbours.
x = np.array([[0, 1], [0, 2], [5, 3], [6, 1]], dtype=np.float64)


values, idx = k_nearest_neighbors(x, k=2)
print(f"Indices:\n{idx}\nDistances:\n{values}")

In [3]:
x = np.array([[0, 1], [0, 2], [5, 3], [6, 1]], dtype=float)  # toy (4, 2) point cloud


In [6]:
x  # display the (4, 2) toy point cloud

array([[0., 1.],
       [0., 2.],
       [5., 3.],
       [6., 1.]])

In [9]:
x[1:, :]  # every point except the first


array([[0., 2.],
       [5., 3.],
       [6., 1.]])

In [11]:
np.zeros(shape=(0, 4), dtype=np.float64)  # an empty array: zero rows, four columns

array([], shape=(0, 4), dtype=float64)

In [19]:
x[:, np.newaxis].shape  # insert a length-1 axis at position 1: (N, D) -> (N, 1, D)


(4, 1, 2)

In [20]:
x  # re-display the toy points before building pairwise differences

array([[0., 1.],
       [0., 2.],
       [5., 3.],
       [6., 1.]])

In [22]:
# Broadcast (1, N, D) against (N, 1, D): d[i, j] = x[j] - x[i], shape (N, N, D).
d = x[np.newaxis, :, :] - x[:, np.newaxis, :]
d

array([[[ 0.,  0.],
        [ 0.,  1.],
        [ 5.,  2.],
        [ 6.,  0.]],

       [[ 0., -1.],
        [ 0.,  0.],
        [ 5.,  1.],
        [ 6., -1.]],

       [[-5., -2.],
        [-5., -1.],
        [ 0.,  0.],
        [ 1., -2.]],

       [[-6.,  0.],
        [-6.,  1.],
        [-1.,  2.],
        [ 0.,  0.]]])

In [45]:
# NOTE: ord=-1 is not a distance metric. numpy computes (sum_i |d_i|**-1)**-1, which
# is 0 whenever any component is 0 (e.g. d[0, 1] = [0, 1] gives 0) and emits a
# divide-by-zero RuntimeWarning. Use the default ord (Euclidean) for nearest neighbours.
np.linalg.norm(d, axis=2, ord=-1)

  absx **= ord


array([[0.        , 0.        , 1.42857143, 0.        ],
       [0.        , 0.        , 0.83333333, 0.85714286],
       [1.42857143, 0.83333333, 0.        , 0.66666667],
       [0.        , 0.85714286, 0.66666667, 0.        ]])

In [41]:
np.sqrt(np.sum(d * d, axis=2, keepdims=True))  # Euclidean length per pair, shape (N, N, 1)

array([[[0.        ],
        [1.        ],
        [5.38516481],
        [6.        ]],

       [[1.        ],
        [0.        ],
        [5.09901951],
        [6.08276253]],

       [[5.38516481],
        [5.09901951],
        [0.        ],
        [2.23606798]],

       [[6.        ],
        [6.08276253],
        [2.23606798],
        [0.        ]]])

In [39]:
d[0, 1]  # difference vector x[1] - x[0]

array([0., 1.])