<a href="https://colab.research.google.com/github/yexf308/AdvancedMachineLearning/blob/main/Fast_KNN_kd_tree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

References: [here](https://www.cs.cmu.edu/~awm/) and [here](https://courses.cs.washington.edu/courses/cse547).

In [5]:
from IPython.display import Image, display


$\def\m#1{\mathbf{#1}}$
$\def\mb#1{\mathbb{#1}}$
$\def\c#1{\mathcal{#1}}$
# The Nearest Neighbor Problem
We have $N$ points, $\{\m{x}^{(1)},\dots, \m{x}^{(N)}\} \subset \mathcal{X}$ (possibly in $D$ dimensions, possibly in a more abstract space), together with a
distance function $d(\m{x},\m{y})$ on these points.

**Goal:** Find all pairs of data points $(\m{x}^{(i)}, \m{x}^{(j)})$ whose distance is within a threshold, $d(\m{x}^{(i)},\m{x}^{(j)})\le \epsilon$. Naively, we would have to compute the distance
for every pair, which takes $O(N^2)$. That is too slow! We would like $O(N)$.

Alternatively, given a new point $\m{y}$, we would like to find either 1) an exact nearest
neighbor or 2) a point $\m{x}$ in our dataset that is “close” to $\m{y}$. (The regions of query points that share the same nearest neighbor form a Voronoi diagram.) Naively, finding the nearest neighbor of a new point takes $O(N)$. We would like $O(\log(N))$.
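
As a baseline for the costs quoted above, here is a minimal brute-force sketch. The function names (`euclidean`, `pairs_within`, `nearest_brute`) and the toy points are illustrative assumptions, not part of the original notebook.

```python
import math

def euclidean(a, b):
    """Euclidean distance between two equal-length coordinate sequences."""
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))

def pairs_within(points, eps, dist=euclidean):
    """All index pairs (i, j), i < j, with dist <= eps: N(N-1)/2 distance evaluations."""
    n = len(points)
    return [(i, j) for i in range(n) for j in range(i + 1, n)
            if dist(points[i], points[j]) <= eps]

def nearest_brute(points, y, dist=euclidean):
    """Index of the point nearest to y: N distance evaluations."""
    return min(range(len(points)), key=lambda i: dist(points[i], y))

# Toy data (hypothetical), just to exercise both functions
pts = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (5.0, 5.0)]
close_pairs = pairs_within(pts, eps=1.5)
nn_index = nearest_brute(pts, (0.9, 0.2))
```

Both functions take the distance as a parameter, so the same baseline works for any $d(\m{x},\m{y})$, not only the Euclidean one.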



In [6]:
display(Image(url='https://github.com/yexf308/AdvancedMachineLearning/blob/main/image/Knn_voronoi.png?raw=true', width=400))

# KD-tree
A k-d tree (short for k-dimensional tree) is a space-partitioning data structure for organizing points in a k-dimensional space.

**Main idea of construction:** Recursively partition the points into axis-aligned boxes.

**Main idea of searching:** Enable more efficient pruning of the search space.

- Examine nearby points first.

- Ignore any points that are farther away than the nearest point found so far.


k-d trees work “well” in “low-to-medium” dimensions ($k\le 20$).

<img src="https://github.com/yexf308/AdvancedMachineLearning/blob/main/image/kd_tree/8.png?raw=true" width="500" />

## Algorithm step by step

### Construction

Step 1: Start with a list of $d$-dimensional points.


<img src="https://github.com/yexf308/AdvancedMachineLearning/blob/main/image/kd_tree/2.png?raw=true" width="500" />


Step 2: Split the points into 2 groups by:
 - Choosing a dimension $d_j$ and a value $V$
   - $d_j$: As one moves down the tree, one cycles through the axes used to select the splitting planes. (For example, in a 3-dimensional tree, the root would have an x-aligned plane, the root's children would both have y-aligned planes, the root's grandchildren would all have z-aligned planes, the root's great-grandchildren would all have x-aligned planes, and so on.) Another way is to choose the widest dimension.
  
   - $V$: Take the median of the points being put into the subtree, with respect to their coordinates along the axis used for the splitting plane. Another way is to choose the center of the range.

 - Separating the points into $x_{d_j}^{(i)}\le V$ and $x_{d_j}^{(i)}> V$
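
The choices in Step 2 can be sketched as small helper functions. This is an illustration under assumptions: the names `widest_dimension`, `median_value`, `center_value` and `split` are hypothetical, and the lower median is used so that, with distinct coordinates, the "greater than" side is never empty.

```python
def widest_dimension(points):
    """Index of the coordinate with the largest range (max - min)."""
    dims = range(len(points[0]))
    return max(dims, key=lambda j: max(p[j] for p in points) - min(p[j] for p in points))

def median_value(points, j):
    """Median of coordinate j (lower median for an even number of points)."""
    values = sorted(p[j] for p in points)
    return values[(len(values) - 1) // 2]

def center_value(points, j):
    """Midpoint of the range of coordinate j."""
    values = [p[j] for p in points]
    return (min(values) + max(values)) / 2

def split(points, j, v):
    """Partition into (coordinate j <= v, coordinate j > v)."""
    left = [p for p in points if p[j] <= v]
    right = [p for p in points if p[j] > v]
    return left, right

# Toy data (hypothetical): the x-range is 9, the y-range is 4
pts = [(1, 4), (2, 3), (4, 1), (7, 2), (8, 5), (10, 3)]
j = widest_dimension(pts)
v = median_value(pts, j)
left, right = split(pts, j, v)
```

The median split keeps the two groups balanced, while the center-of-range split keeps the boxes closer to equal width. Which one works better depends on how the data are distributed.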


<img src="https://github.com/yexf308/AdvancedMachineLearning/blob/main/image/kd_tree/3.png?raw=true" width="600" />

<img src="https://github.com/yexf308/AdvancedMachineLearning/blob/main/image/kd_tree/4.png?raw=true" width="600" />

Step 3: Consider each group separately as a subtree and repeat Step 2. Stop when at most $m$ points are left in a group or the box hits a minimum width.

<img src="https://github.com/yexf308/AdvancedMachineLearning/blob/main/image/kd_tree/5.png?raw=true" width="600" />

<img src="https://github.com/yexf308/AdvancedMachineLearning/blob/main/image/kd_tree/6.png?raw=true" width="600" />


At the end: this creates a binary tree structure. Each leaf node contains a list of points, and each node stores the (tight) bounds of the points at or below it.
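
A minimal sketch of this construction, assuming a dictionary-based node layout (the keys `lo`, `hi`, `points`, `left`, `right` are made up for this example) and the cyclic-axis, median-split heuristics. The minimum-width stopping rule is omitted for brevity, and a node also becomes a leaf if ties make a split on its axis impossible.

```python
def bounding_box(points):
    """Tight axis-aligned bounds (lo, hi) of a non-empty list of points."""
    dims = range(len(points[0]))
    lo = [min(p[j] for p in points) for j in dims]
    hi = [max(p[j] for p in points) for j in dims]
    return lo, hi

def build(points, m=2, depth=0):
    """Recursively split until at most m points remain in a node."""
    lo, hi = bounding_box(points)
    node = {"lo": lo, "hi": hi, "points": None, "left": None, "right": None}
    j = depth % len(points[0])              # cycle through the axes
    values = sorted(p[j] for p in points)
    v = values[(len(values) - 1) // 2]      # lower median as the split value
    left = [p for p in points if p[j] <= v]
    right = [p for p in points if p[j] > v]
    if len(points) <= m or not right:       # leaf: small enough, or cannot split here
        node["points"] = list(points)
        return node
    node["left"] = build(left, m, depth + 1)
    node["right"] = build(right, m, depth + 1)
    return node

def count_leaves(node):
    """Number of leaf buckets below (and including) this node."""
    if node["points"] is not None:
        return 1
    return count_leaves(node["left"]) + count_leaves(node["right"])

# Toy data (hypothetical)
pts = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
tree = build(pts, m=2)
```

Storing the tight bounds at every node costs $O(k)$ extra memory per node, but it is what lets the search prune whole subtrees later.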


### Many heuristics



In [7]:
display(Image(url='https://github.com/yexf308/AdvancedMachineLearning/blob/main/image/kd_tree/7.png?raw=true', width=700))

### Nearest neighbour search

Step 1: Starting with the root node, the algorithm moves down the tree recursively, in the same way that it would if the search point were being inserted (i.e. it goes left or right depending on whether the point is less than or greater than the current node in the split dimension).
<img src="https://github.com/yexf308/AdvancedMachineLearning/blob/main/image/kd_tree/9.png?raw=true" width="600" />

<img src="https://github.com/yexf308/AdvancedMachineLearning/blob/main/image/kd_tree/10.png?raw=true" width="600" />



Step 2: Once the algorithm reaches a leaf node, it checks that node point and if the distance is better, that node point is saved as the "current best". Calculate the current distance bound $r$.

<img src="https://github.com/yexf308/AdvancedMachineLearning/blob/main/image/kd_tree/11.png?raw=true" width="600" />

Step 3: Backtrack and try the other branch at each node visited. Each time a new closest point is found, update the distance bound $r$. Use the distance bound and the bounding box of each node to prune parts of the tree that could NOT include the nearest neighbor.

<img src="https://github.com/yexf308/AdvancedMachineLearning/blob/main/image/kd_tree/12.png?raw=true" width="600" />

<img src="https://github.com/yexf308/AdvancedMachineLearning/blob/main/image/kd_tree/13.png?raw=true" width="600" />

<img src="https://github.com/yexf308/AdvancedMachineLearning/blob/main/image/kd_tree/14.png?raw=true" width="600" />

<img src="https://github.com/yexf308/AdvancedMachineLearning/blob/main/image/kd_tree/15.png?raw=true" width="600" />

<img src="https://github.com/yexf308/AdvancedMachineLearning/blob/main/image/kd_tree/16.png?raw=true" width="600" />

<img src="https://github.com/yexf308/AdvancedMachineLearning/blob/main/image/kd_tree/17.png?raw=true" width="600" />
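
The three search steps can be sketched with bounding-box pruning as follows. This is an illustrative variant, not the implementation used later in this notebook: the node layout and the names `build`, `box_dist_sq` and `nearest` are assumptions, distances are squared throughout, and the builder splits by sorted position so that both children are always non-empty.

```python
import random

def build(points, m=4, depth=0):
    """Compact builder: cyclic axis, median split, leaf buckets of at most m points."""
    k = len(points[0])
    node = {"lo": [min(p[j] for p in points) for j in range(k)],
            "hi": [max(p[j] for p in points) for j in range(k)],
            "points": None}
    pts = sorted(points, key=lambda p: p[depth % k])
    half = len(pts) // 2
    if len(pts) <= m:
        node["points"] = pts
    else:
        node["left"] = build(pts[:half], m, depth + 1)
        node["right"] = build(pts[half:], m, depth + 1)
    return node

def dist_sq(a, b):
    """Squared Euclidean distance."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

def box_dist_sq(q, lo, hi):
    """Squared distance from q to the box [lo, hi] (0 if q lies inside it)."""
    return sum(max(l - x, 0.0, x - h) ** 2 for x, l, h in zip(q, lo, hi))

def nearest(node, q, best=None):
    """Return (squared distance, point) of the nearest neighbor of q."""
    if best is None:
        best = (float("inf"), None)
    if box_dist_sq(q, node["lo"], node["hi"]) >= best[0]:
        return best                      # prune: nothing in this box beats the bound
    if node["points"] is not None:       # leaf: scan its bucket
        for p in node["points"]:
            d = dist_sq(q, p)
            if d < best[0]:
                best = (d, p)
        return best
    # Visit the child whose box is closer first, then try the other branch.
    children = sorted((node["left"], node["right"]),
                      key=lambda c: box_dist_sq(q, c["lo"], c["hi"]))
    for child in children:
        best = nearest(child, q, best)
    return best

# Hypothetical test data: 500 random points in 3 dimensions
random.seed(0)
data = [tuple(random.uniform(-1, 1) for _ in range(3)) for _ in range(500)]
tree = build(data)
query = (0.1, -0.2, 0.3)
found = nearest(tree, query)
brute = min((dist_sq(query, p), p) for p in data)
```

Because pruning relies only on the stored bounds, the search stays correct however the split was chosen. The split heuristic affects only how much gets pruned.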

## Approximate kNN with kd-tree
**Before:** Prune when the distance to the bounding box is $> r$.

**Now:** Prune when the distance to the bounding box is $> r/\alpha$, where $\alpha>1$.

This prunes more than exact search allows, but it still guarantees that if we return a neighbor at
distance $r$, then there is no neighbor closer than $r/\alpha$.

It saves a lot of search time at little cost in the quality of the nearest neighbor.
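
A short argument for this guarantee, sketched under the pruning rule above. The symbols $r_t$ (the bound at the moment a box $B$ is pruned) and $r^{*}$ (the true nearest-neighbor distance) are notation introduced here for the argument. The bound only shrinks during the search, so $r \le r_t$, and every point $\mathbf{x}$ inside a pruned box satisfies

$$
d(\mathbf{y},\mathbf{x}) \;\ge\; d(\mathbf{y},B) \;>\; \frac{r_t}{\alpha} \;\ge\; \frac{r}{\alpha}.
$$

Every point that was not pruned was examined, so its distance is at least $r \ge r/\alpha$. Hence

$$
r^{*} \;\ge\; \frac{r}{\alpha}
\quad\Longleftrightarrow\quad
r \;\le\; \alpha\, r^{*}.
$$

For example, with $\alpha = 1.5$ the returned neighbor is at most $50\%$ farther away than the true nearest neighbor.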

<img src="https://github.com/yexf308/AdvancedMachineLearning/blob/main/image/kd_tree/20.png?raw=true" width="600" />


## Complexity 
### Complexity for 1 Query
For a nearly balanced binary tree:
- Construction 
  - Size: $2N-1$ nodes (with one point per leaf), i.e. $O(N)$.

  - Depth: $O(\log(N))$

  - Finding the median and sending points left/right: $O(N)$ at each tree level.

  - Construction time: $O(N\log(N))$

- 1-NN query 
  - Traverse down the tree to the starting point: $O(\log(N))$

  - Maximum backtrack and traverse: $O(N)$ in the worst case

  - Both costs carry a factor exponential in $d$

  - Complexity range: $O(2^d\log(N))$ to $O(2^dN)$
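
The construction bound can be read off a recurrence. This is a sketch assuming exact median splits and a linear-time median-and-partition step with constant $c$; both are idealizations, and $T$ and $c$ are symbols introduced here.

$$
T(N) = 2\,T(N/2) + cN,\qquad T(1) = O(1)
\quad\Longrightarrow\quad
T(N) = cN\log_2 N + O(N) = O(N\log N).
$$

For a sense of scale, $N = 10^6$ gives $\log_2 N \approx 20$. A well-pruned query therefore descends about 20 levels (times the dimension-dependent factor), whereas brute force evaluates $10^6$ distances.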


<img src="https://github.com/yexf308/AdvancedMachineLearning/blob/main/image/kd_tree/18.png?raw=true" width="600" />  

Left: many branches are pruned, so the cost is on the order of $\log(N)$.

Right: few branches are pruned, so the cost is on the order of $N$.

### Complexity for $N$ Queries
Ask for the nearest neighbor of each data point.

- Brute force 1-NN: $O(N^2)$

- k-d tree: $O(N\log(N))$ to $O(N^2)$


### Complexity for kNN

Exactly the same algorithm, but maintain the distance bound as the distance
to the furthest of the current $k$ nearest neighbors.

- k-d tree: $O(k\log(N))$ per query when pruning is effective
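
The bookkeeping for this bound can be sketched with a size-$k$ max-heap. The example below applies it to a plain scan rather than a tree, to isolate the heap logic; the function name `knn_bounded_heap` and the toy points are assumptions.

```python
import heapq

def knn_bounded_heap(points, q, k):
    """k nearest neighbors of q, keeping only the k best candidates seen so far."""
    heap = []                                   # max-heap via negated squared distances
    for p in points:
        d = sum((a - b) ** 2 for a, b in zip(p, q))
        if len(heap) < k:
            heapq.heappush(heap, (-d, p))
        elif d < -heap[0][0]:                   # closer than the current k-th nearest
            heapq.heapreplace(heap, (-d, p))
        # Once the heap is full, -heap[0][0] is the squared bound r^2
        # that a k-d tree search would prune against.
    return sorted((-nd, p) for nd, p in heap)

# Toy data (hypothetical); squared distances from the origin are 0, 9, 1, 8, 50
pts = [(0, 0), (3, 0), (0, 1), (2, 2), (5, 5)]
result = knn_bounded_heap(pts, (0, 0), k=3)
```

Until the heap holds $k$ candidates there is no valid bound, so nothing should be pruned during that phase.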


### Issue
High-dimensional spaces are hard! The number of nodes a kd-tree search visits can be exponential in the dimension $d$. As a rule of thumb, a k-d tree is useful when $N\gg 2^d$. A k-d tree is of little use for very large $d$, since we do not have enough observations.

Distances are sensitive to irrelevant features in high dimensions.

- Most dimensions are just noise and everything is far away.

- We need a technique to learn which features are important for the task.
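
A small experiment can make the "everything is far away" effect concrete. It is a sketch under assumptions: points are drawn uniformly from $[-1,1]^d$, and the function name, sample size and seed are arbitrary choices. As $d$ grows, the nearest and farthest points end up at similar distances, which leaves little room for pruning.

```python
import math
import random

def nearest_to_farthest_ratio(d, n=500, seed=0):
    """Ratio (nearest distance) / (farthest distance) from one random query to n random points."""
    rng = random.Random(seed)
    query = [rng.uniform(-1, 1) for _ in range(d)]
    dists = []
    for _ in range(n):
        p = [rng.uniform(-1, 1) for _ in range(d)]
        dists.append(math.sqrt(sum((a - b) ** 2 for a, b in zip(query, p))))
    return min(dists) / max(dists)

ratios = {d: nearest_to_farthest_ratio(d) for d in (2, 8, 32, 128)}
for d, r in ratios.items():
    # 2^d is printed for comparison with the rule of thumb N >> 2^d
    print(f"d = {d:3d}   nearest/farthest = {r:.3f}   2^d = {2 ** d}")
```

A ratio close to 1 means the bound $r$ is barely smaller than the distance to most boxes, so few branches can be discarded.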


In [19]:
# https://github.com/Vectorized/Python-KD-Tree
import random  # used by rand_point below

# Makes the KD-Tree for fast lookup
def make_kd_tree(points, dim, i=0):
    if len(points) > 1:
        points.sort(key=lambda x: x[i])
        i = (i + 1) % dim
        half = len(points) >> 1
        return [
            make_kd_tree(points[: half], dim, i),
            make_kd_tree(points[half + 1:], dim, i),
            points[half]
        ]
    elif len(points) == 1:
        return [None, None, points[0]]

# Adds a point to the kd-tree
def add_point(kd_node, point, dim, i=0):
    if kd_node is not None:
        dx = kd_node[2][i] - point[i]
        i = (i + 1) % dim
        for j, c in ((0, dx >= 0), (1, dx < 0)):
            if c and kd_node[j] is None:
                kd_node[j] = [None, None, point]
            elif c:
                add_point(kd_node[j], point, dim, i)

# k nearest neighbors
def get_knn(kd_node, point, k, dim, dist_func, return_distances=True, i=0, heap=None):
    import heapq
    is_root = heap is None  # only the top-level call is made without a heap
    if is_root:
        heap = []
    if kd_node is not None:
        dist = dist_func(point, kd_node[2])
        dx = kd_node[2][i] - point[i]
        if len(heap) < k:
            heapq.heappush(heap, (-dist, kd_node[2]))
        elif dist < -heap[0][0]:
            heapq.heappushpop(heap, (-dist, kd_node[2]))
        i = (i + 1) % dim
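        # Visit the near branch first; visit the far branch if the heap is not yet full
        # or the splitting plane is closer than the current k-th best distance
        for b in [dx < 0] + [dx >= 0] * (len(heap) < k or dx * dx < -heap[0][0]):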
            get_knn(kd_node[b], point, k, dim, dist_func, return_distances, i, heap)
    if is_root:
        neighbors = sorted((-h[0], h[1]) for h in heap)
        return neighbors if return_distances else [n[1] for n in neighbors]

# For the closest neighbor
def get_nearest(kd_node, point, dim, dist_func, return_distances=True, i=0, best=None):
    if kd_node is not None:
        dist = dist_func(point, kd_node[2])
        dx = kd_node[2][i] - point[i]
        if not best:
            best = [dist, kd_node[2]]
        elif dist < best[0]:
            best[0], best[1] = dist, kd_node[2]
        i = (i + 1) % dim
        # Goes into the near branch first, and then the far branch if it could still hold a closer point
        for b in [dx < 0] + [dx >= 0] * (dx * dx < best[0]):
            get_nearest(kd_node[b], point, dim, dist_func, return_distances, i, best)
    return best if return_distances else best[1]

def get_knn_naive(points, point, k, dist_func, return_distances=True):
    neighbors = []
    for i, pp in enumerate(points):
        dist = dist_func(point, pp)
        neighbors.append((dist, pp))
    neighbors = sorted(neighbors)[:k]
    return neighbors if return_distances else [n[1] for n in neighbors]

def rand_point(dim):
    return [random.uniform(-1, 1) for d in range(dim)]

def dist_sq(a, b, dim):
    return sum((a[i] - b[i]) ** 2 for i in range(dim))

def dist_sq_dim(a, b):
    # Relies on the global `dim` defined in the testing cell below
    return dist_sq(a, b, dim)

In [25]:

"""
Below is all the testing code
"""

import random, cProfile

dim = 8
points = [rand_point(dim) for x in range(10000)]
additional_points = [rand_point(dim) for x in range(50)]
#points = [rand_point(dim) for x in range(5000)]
test = [rand_point(dim) for x in range(100)]
result1 = []
result2 = []


def bench1():
    kd_tree = make_kd_tree(points, dim)
    for point in additional_points:
        add_point(kd_tree, point, dim)
    result1.append(tuple(get_knn(kd_tree, [0] * dim, 8, dim, dist_sq_dim)))
    for t in test:
        result1.append(tuple(get_knn(kd_tree, t, 8, dim, dist_sq_dim)))


def bench2():
    all_points = points + additional_points
    result2.append(tuple(get_knn_naive(all_points, [0] * dim, 8, dist_sq_dim)))
    for t in test:
        result2.append(tuple(get_knn_naive(all_points, t, 8, dist_sq_dim)))

cProfile.run("bench1()")
cProfile.run("bench2()")


print("Is the result same as naive version?: {}".format(result1 == result2))

print("")
kd_tree = make_kd_tree(points, dim)

print(get_nearest(kd_tree, [0] * dim, dim, dist_sq_dim))



         5077136 function calls (4495456 primitive calls) in 2.853 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   673/50    0.001    0.000    0.001    0.000 <ipython-input-19-c164e2fa9c0b>:18(add_point)
569350/101    0.983    0.000    2.765    0.027 <ipython-input-19-c164e2fa9c0b>:29(get_knn)
  11809/1    0.020    0.000    0.084    0.084 <ipython-input-19-c164e2fa9c0b>:4(make_kd_tree)
      909    0.001    0.000    0.001    0.000 <ipython-input-19-c164e2fa9c0b>:46(<genexpr>)
   119535    0.028    0.000    0.028    0.000 <ipython-input-19-c164e2fa9c0b>:6(<lambda>)
   333993    0.254    0.000    1.627    0.000 <ipython-input-19-c164e2fa9c0b>:75(dist_sq)
  3005937    1.010    0.000    1.010    0.000 <ipython-input-19-c164e2fa9c0b>:76(<genexpr>)
   333993    0.123    0.000    1.750    0.000 <ipython-input-19-c164e2fa9c0b>:78(dist_sq_dim)
        1    0.001    0.001    2.852    2.852 <ipython-input-25-7da0fabfbdcf>:17(bench1