# KD-TREE: K Dimensional Trees

KD trees are data structures that store data by partitioning k-dimensional space with a tree structure, which enables quicker nearest-neighbor searches.


In [None]:
from __future__ import annotations
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

In [None]:
# Problem size: how many sample points to draw, and in how many dimensions.
K = 2
NUM_POINTS = 20

# One column per point, one row per dimension, sampled uniformly from [0, 1).
points_A = np.random.rand(K, NUM_POINTS)

# Quick look at the raw point cloud (first two dimensions).
plt.scatter(points_A[0], points_A[1])

In [None]:
class Node:
    """Interior node of the KD tree: an axis-aligned split with two children.

    Instances are created empty; the attributes below are assigned while the
    tree is being built.
    """
    # Index of the dimension this node splits on.
    axis: int
    # Threshold along ``axis`` separating the left and right children.
    value: float
    # Subtree (or leaf) on the "left" side of the split.
    left_child: Node | LeafNode
    # Subtree (or leaf) on the "right" side of the split.
    right_child: Node | LeafNode

class LeafNode:
    """Terminal node of the KD tree.

    Created empty; ``point`` is assigned once the builder stops splitting.
    """
    # Array holding the point(s) that ended up in this leaf.
    point: np.ndarray

In [None]:
# Root of the tree; it starts out owning every point.
base_node = Node()

# Queue of (node, points) pairs that still have to be split.
work_list = []
work_list.append((base_node, points_A))

## Tree Building Procedure

Tree building is an iterative process of:
- Choosing an *axis aligned* threshold
- Partitioning the points into the two child nodes.

This is continued until all points have been partitioned into leaf nodes.

Building a tree with $N$ points has complexity $ \mathcal{O}(N\log{}N)$

In [None]:
def procss_node(parent_node: Node, points: np.ndarray, plot : bool = True):
    """Split ``points`` at ``parent_node`` and queue or finalize its children.

    The node's ``axis`` and ``value`` are filled in, the points are divided
    into a left and a right group, and each group either becomes a new
    ``Node`` appended to the module-level ``work_list`` (more than one point)
    or a ``LeafNode`` holding the remaining column(s).

    Every input point ends up in exactly one of the two children; points
    whose coordinate equals the split value go to the right child, which
    matches the ``<`` / else comparison used when querying.

    Args:
        parent_node: Node to configure; its ``axis``, ``value``,
            ``left_child`` and ``right_child`` attributes are assigned.
        points: Array with one row per dimension and one column per point.
        plot: When true, draw the split on a new figure. Only the first two
            dimensions are shown.
    """
    # We use the axis with the largest variance as heuristic for choosing
    # which axis to split. argmax returns the first maximum, so ties resolve
    # to the lowest axis index.
    parent_node.axis = int(np.argmax(np.var(points, axis=1)))
    axis_values = points[parent_node.axis, :]

    # For split value, we choose the mean.
    # Something like the median would be more likely to produce a balanced tree.
    parent_node.value = np.mean(axis_values)

    # For us, "left" child is strictly less than the split value; everything
    # else (including values equal to the split) belongs to the right child.
    left_mask = axis_values < parent_node.value
    if not left_mask.any() or left_mask.all():
        # Degenerate split: all coordinates are equal (or the mean rounded
        # onto an extreme), so the threshold separates nothing. Fall back to
        # a rank-based half split so that the build always makes progress.
        order = np.argsort(axis_values, kind="stable")
        left_mask = np.zeros(axis_values.shape, dtype=bool)
        left_mask[order[: axis_values.size // 2]] = True

    left_list = points[:, left_mask]
    right_list = points[:, ~left_mask]

    # Groups with more than one point need further splitting; anything
    # smaller becomes a leaf.
    if left_list.shape[1] > 1:
        parent_node.left_child = Node()
        work_list.append((parent_node.left_child, left_list))
    else:
        parent_node.left_child = LeafNode()
        parent_node.left_child.point = left_list

    # The same process is repeated on the right.
    if right_list.shape[1] > 1:
        parent_node.right_child = Node()
        work_list.append((parent_node.right_child, right_list))
    else:
        parent_node.right_child = LeafNode()
        parent_node.right_child.point = right_list

    # Plotting logic.
    if plot:
        fig = plt.figure()
        ax = fig.add_axes([0, 0, 1, 1])
        # All points in grey as background, then this node's two groups.
        ax.scatter(points_A[0, :], points_A[1, :], c='grey')
        ax.scatter(left_list[0, :], left_list[1, :], c='r')
        ax.scatter(right_list[0, :], right_list[1, :], c='b')
        # The split line can only be drawn for the two plotted dimensions.
        if parent_node.axis == 0:
            ax.plot([parent_node.value, parent_node.value], [0, 1])
        elif parent_node.axis == 1:
            ax.plot([0, 1], [parent_node.value, parent_node.value])


**The following cell is meant to be run iteratively**

In [None]:
# Process the next pending node in FIFO order and plot its split.
# The guard makes the cell safe to re-run once the work list is exhausted,
# instead of raising IndexError from pop() on an empty list.
if work_list:
    parent_node, points = work_list.pop(0)
    procss_node(parent_node, points, True)
# Number of nodes still waiting to be split (0 means the tree is complete).
print(len(work_list))

**This cell will finish off the work list without plotting**

In [None]:
# Keep splitting, oldest entry first, until nothing is left to process.
while work_list:
    parent_node, points = work_list.pop(0)
    procss_node(parent_node, points, plot=False)

## Tree Query

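Searching a tree for query point $p_q$ is a simple procedure of descending the tree by checking whether the query point is on the "left" or "right" side of the axis-aligned partition associated with each node. The nearest-neighbor candidate is found in the leaf node at the end of this procedure. (Note that a single descent is not guaranteed to return the true nearest neighbor when a closer point lies just across a partition boundary; an exact search would additionally backtrack into neighboring partitions.)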

The complexity of searching for a nearest neighbor in a KD tree which contains $N$ points is $\mathcal{O}(\log{}N)$

In [None]:
# Random query location, stored as a single column like the tree's points.
query_point = np.random.rand(K, 1)

# Start the descent at the root with the comparison counter reset.
current_node, comparisons = base_node, 0

**The following cell is meant to be run iteratively.**

In [None]:
# One step of the query descent. The cell can be re-run safely after the leaf
# has been reached: the descent is skipped and the result is shown again.

# Plotting setup: all points in grey, the query point in magenta.
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])
ax.scatter(points_A[0, :], points_A[1, :], c='grey')
ax.scatter(query_point[0, :], query_point[1, :], c='m')

if not isinstance(current_node, LeafNode):
    # Draw the split line of the node being tested.
    if current_node.axis == 0:
        ax.plot([current_node.value, current_node.value], [0, 1])
    else:
        ax.plot([0, 1], [current_node.value, current_node.value])

    # Actual logic of querying, a simple comparison. The query is a single
    # column, so index it as a scalar rather than a one-element array.
    comparisons += 1
    if query_point[current_node.axis, 0] < current_node.value:
        current_node = current_node.left_child
    else:
        current_node = current_node.right_child

# Check if we have reached a leaf node, and therefore, are done.
if isinstance(current_node, LeafNode):
    found_point = current_node.point
    ax.scatter(found_point[0, :], found_point[1, :], c='g')
    ax.set_title('Done!')

In [None]:
# Comparisons actually performed, followed by log2(N) for reference.
print(comparisons, np.log2(NUM_POINTS), sep="\n")