In [2]:
import numpy as np

class KDNode:
    """A single node of a KD-tree: one stored point plus its two subtrees."""

    def __init__(self, point, axis, left=None, right=None):
        """
        Create a KD-tree node.

        point : numpy array of shape (d,)
            Data point held by this node
        axis : int
            Index of the coordinate this node splits on
        left, right : KDNode or None
            Child subtrees (None when a child is absent)
        """
        self.point, self.axis = point, axis
        self.left, self.right = left, right


In [3]:
def build_kdtree(points):
    """
    Build a KD-tree using variance-based axis selection.

    At every level the coordinate with the largest variance is chosen as
    the split axis, the points are sorted along it, and the median point
    becomes the node. Points before the median go to the left subtree and
    points after it go to the right subtree.

    Parameters
    ----------
    points : array-like of shape (n, d)
        Lists/tuples of rows are accepted as well as numpy arrays.

    Returns
    -------
    KDNode or None
        Root of the (sub)tree, or None when there are no points.

    Raises
    ------
    ValueError
        If a non-empty input is not two-dimensional, or has zero columns.
    """

    # Accept any array-like; this is a no-op for inputs that are already arrays
    points = np.asarray(points)

    # Base case: no points → empty subtree
    if points.ndim >= 1 and len(points) == 0:
        return None

    # Fail early with a clear message instead of an opaque indexing error
    if points.ndim != 2 or points.shape[1] == 0:
        raise ValueError(
            "points must have shape (n, d) with d >= 1, got shape "
            f"{points.shape}"
        )

    # -----------------------------
    # 1. Choose split axis
    # -----------------------------
    # Dimension with maximum variance; stored as a plain Python int
    axis = int(np.argmax(np.var(points, axis=0)))

    # -----------------------------
    # 2. Split by median on that axis
    # -----------------------------
    # Stable sort keeps the tree shape deterministic when coordinates tie
    order = np.argsort(points[:, axis], kind="stable")
    points = points[order]

    # Choose median for balanced tree
    median = len(points) // 2

    # -----------------------------
    # 3. Recursively build subtrees
    # -----------------------------
    return KDNode(
        point=points[median],                     # store median point
        axis=axis,                                # store chosen axis
        left=build_kdtree(points[:median]),       # points before the median
        right=build_kdtree(points[median + 1:])   # points after the median
    )


In [4]:
def euclidean_distance(a, b):
    """
    Compute Euclidean distance between two points.

    Parameters
    ----------
    a, b : array-like of shape (d,)
        Point coordinates; lists and tuples are accepted as well as arrays.

    Returns
    -------
    numpy float
        The L2 norm of ``a - b``.
    """
    # Convert to float first: plain sequences do not support ``a - b``, and
    # unsigned integer arrays would wrap around on negative differences.
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return np.linalg.norm(a - b)


In [5]:
def nearest_neighbor(root, query):
    """
    Find the exact nearest neighbor of ``query`` in a KD-tree.

    Parameters
    ----------
    root : KDNode
        Root of KD-tree (None is treated as an empty tree)
    query : np.ndarray of shape (d,)

    Returns
    -------
    (best_point, best_distance)
        The closest stored point and its distance to ``query``;
        (None, np.inf) when the tree is empty.
    """

    # Running champion, rebound by the nested function via ``nonlocal``
    best_point = None
    best_distance = np.inf

    def visit(node):
        nonlocal best_point, best_distance

        # Empty subtree: nothing to examine
        if node is None:
            return

        # The node's own point may beat the current champion
        here = euclidean_distance(query, node.point)
        if here < best_distance:
            best_point, best_distance = node.point, here

        # Signed offset of the query from this node's splitting plane
        split = node.axis
        offset = query[split] - node.point[split]

        # Descend first into the side that contains the query
        if offset < 0:
            first, second = node.left, node.right
        else:
            first, second = node.right, node.left
        visit(first)

        # The other side can only hold a closer point when the splitting
        # plane is nearer than the best distance found so far
        if abs(offset) < best_distance:
            visit(second)

    visit(root)
    return best_point, best_distance


In [None]:
# Small 2-D example: six points and one query
points = np.array([
    [2, 3],
    [5, 4],
    [9, 6],
    [4, 7],
    [8, 1],
    [7, 2]
])

tree = build_kdtree(points)

query = np.array([6, 3])

nn_point, nn_dist = nearest_neighbor(tree, query)

# Sanity check: the tree search must agree with an exhaustive scan
assert np.isclose(nn_dist, np.linalg.norm(points - query, axis=1).min()), \
    "KD-tree nearest neighbor disagrees with brute force"

print("Nearest neighbor:", nn_point)
print("Distance:", nn_dist)


In [8]:
# Fixed seed so the random points and the query are reproducible
np.random.seed(0)

points_50d = np.random.randn(2000, 50)
tree_50d = build_kdtree(points_50d)

q = np.random.randn(50)
p, d = nearest_neighbor(tree_50d, q)

# Sanity check: the tree search must agree with an exhaustive scan
assert np.isclose(d, np.linalg.norm(points_50d - q, axis=1).min()), \
    "KD-tree nearest neighbor disagrees with brute force"

print(p, d)


[ 0.87367517  0.11120359  0.02508699 -0.37722032 -1.62224404 -0.58414951
 -0.46991077 -1.36653859 -0.20264593  1.06346497 -1.04884323  0.27999557
  0.52283453  0.5676972  -0.71771419  0.83436432 -1.14480875  0.43938418
  0.51037672  1.51041356  0.02951737 -0.82433441 -0.78282304 -0.65967711
  0.615348   -0.27087908 -1.95870751 -0.37965558  0.24155096  0.5919257
 -1.10467752  0.18456624  0.21222466  0.61195845 -0.32595815 -0.3889916
 -0.76392508  0.60780984 -0.06674408 -0.30670771  0.53243852 -0.44536562
 -1.13434114 -0.57432996 -0.21513013 -0.24562675  1.03011492  0.48165894
  0.27012592  0.092487  ] 6.816577142965922


In [9]:
def k_nearest_neighbors(root, query, k):
    """
    Exact k-nearest neighbors search in a KD-tree.

    Parameters
    ----------
    root : KDNode
        Root of the KD-tree (None is treated as an empty tree)
    query : ndarray (d,)
        Query point
    k : int
        Number of neighbors to return

    Returns
    -------
    List of (distance, point), sorted by distance.
    Contains fewer than k entries when the tree holds fewer than k points,
    and is empty when k == 0.

    Raises
    ------
    ValueError
        If k is negative.
    """

    # Guard degenerate k up front: with k <= 0 the search below would
    # index into an empty result list.
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    if k == 0:
        return []

    # List of current best neighbors:
    # stored as (distance, point), kept sorted by distance
    best = []

    def search(node):
        if node is None:
            return

        # ----------------------------------
        # 1. Compute distance to current node
        # ----------------------------------
        dist = euclidean_distance(query, node.point)

        # ----------------------------------
        # 2. Insert into best list if relevant
        # ----------------------------------
        if len(best) < k:
            # Still fewer than k neighbors → always insert
            best.append((dist, node.point))
            best.sort(key=lambda pair: pair[0])  # sort by distance only
        elif dist < best[-1][0]:
            # Replace the worst neighbor
            best[-1] = (dist, node.point)
            best.sort(key=lambda pair: pair[0])

        # ----------------------------------
        # 3. Choose which subtree to explore
        # ----------------------------------
        axis = node.axis
        diff = query[axis] - node.point[axis]

        if diff < 0:
            near_branch, far_branch = node.left, node.right
        else:
            near_branch, far_branch = node.right, node.left

        # ----------------------------------
        # 4. Explore near subtree first
        # ----------------------------------
        search(near_branch)

        # ----------------------------------
        # 5. Pruning condition (CRITICAL)
        # ----------------------------------
        # The far side must be visited while fewer than k neighbors are
        # known, or when the splitting plane is closer than the current
        # worst neighbor (the far side may then hold a closer point).
        # Otherwise it is pruned safely.
        if len(best) < k or abs(diff) < best[-1][0]:
            search(far_branch)

    search(root)
    return best


In [10]:
# Small 2-D example: six points and one query
points = np.array([
    [2, 3],
    [5, 4],
    [9, 6],
    [4, 7],
    [8, 1],
    [7, 2]
])

tree = build_kdtree(points)

query = np.array([6, 3])

neighbors = k_nearest_neighbors(tree, query, k=3)

# Sanity check: returned distances must match the 3 smallest brute-force ones
assert np.allclose(
    [dist for dist, _ in neighbors],
    np.sort(np.linalg.norm(points - query, axis=1))[:3],
), "KD-tree k-NN disagrees with brute force"

for d, p in neighbors:
    print("Point:", p, "Distance:", d)


Point: [7 2] Distance: 1.4142135623730951
Point: [5 4] Distance: 1.4142135623730951
Point: [8 1] Distance: 2.8284271247461903


In [12]:
# Fixed seed so the random points and the query are reproducible
np.random.seed(0)

points_20d = np.random.randn(3000, 20)
tree_20d = build_kdtree(points_20d)

q = np.random.randn(20)

neighbors = k_nearest_neighbors(tree_20d, q, k=1)

# Sanity check: the single neighbor must match the brute-force minimum
assert np.allclose(
    [dist for dist, _ in neighbors],
    np.sort(np.linalg.norm(points_20d - q, axis=1))[:1],
), "KD-tree k-NN disagrees with brute force"

neighbors

[(np.float64(3.2776133983743128),
  array([-0.39693289, -0.78583212, -1.01966376,  0.12692571,  1.06215911,
          1.53068776, -1.51898606,  0.92693751,  1.23353109, -0.4745801 ,
         -0.65154747,  0.79365415, -1.3056499 ,  0.28147124,  1.44100002,
          0.80734963, -1.8721453 ,  0.11845893,  0.89329775, -0.53170016]))]