# kd-tree implementation for KNN

Reference: [github](https://github.com/Vectorized/Python-KD-Tree)

Step 1: Import packages.

In [1]:
import random
import heapq
import time
import numpy as np
from operator import itemgetter

Step 2: Prepare data.

In [2]:
# Seed BOTH generators: every array below is drawn from NumPy's global RNG,
# which `random.seed` alone does not touch, so without `np.random.seed` the
# dataset would differ on every run.
random.seed(1)
np.random.seed(1)
num_data = 5000
num_dims = 3
# Two equally sized clouds: class 0 is uniform in [-1, 0)^num_dims,
# class 1 is uniform in [0, 1)^num_dims.
x_train1 = np.random.random((num_data // 2, num_dims)) - 1
x_train2 = np.random.random((num_data // 2, num_dims))
x_train = np.vstack([x_train1, x_train2])
print(x_train.shape)
y_train = np.array([0] * (num_data // 2) + [1] * (num_data // 2))
print(y_train.shape)

# 500 query points spaced evenly along the main diagonal from (-1, ..., -1)
# to (1, ..., 1).
x_test = [[value] * num_dims for value in np.linspace(-1, 1, 500)]

(5000, 3)
(5000,)


Step 3: Build model.

In [3]:
class KNN(object):
    '''
        k-nearest-neighbour classifier offering a brute-force search and a
        kd-tree search over the same training data.
    '''

    def __init__(self, x_train, y_train):
        '''
            x_train: (num_data, num_dims)
            y_train: (num_data, )
            self.xy: (num_data, num_dims + 1), the features with the label
                     appended as the last column (stored as float)
        '''
        self.xy = np.zeros((x_train.shape[0], x_train.shape[1] + 1), dtype=float)
        self.xy[:, :-1] = x_train
        self.xy[:, -1] = y_train
        self.num_dims = x_train.shape[1]
        self.kdtree = self.create_kdtree(self.xy)
        self.num_classes = len(set(y_train))

    def euclidean_distance(self, point1, point2):
        '''
            Return the Euclidean (L2) distance between two points.
        '''
        return np.sqrt(np.sum((point1 - point2) ** 2))

    def _nearest_rows_naive(self, feature, k):
        '''
            Brute-force search for one query point.
            Returns the first k (distance, row) pairs in ascending order,
            where row is the full training row INCLUDING the label.
        '''
        candidates = [(self.euclidean_distance(xy[:-1], feature), list(xy)) for xy in self.xy]
        return sorted(candidates)[:k]

    def get_neighbors_naive(self, x_test, k):
        '''
            Brute-force k nearest neighbours for every point in x_test.
            Returns one list per query of (distance, features) pairs sorted
            by ascending distance; features does NOT contain the label.
        '''
        neighbors = list()
        for feature in x_test:
            # Strip the trailing label so the returned rows hold features only.
            neighbors.append([(dist, row[:-1]) for dist, row in self._nearest_rows_naive(feature, k)])
        return neighbors

    def create_kdtree(self, points, i=0):
        '''
            Recursively build a kd-tree from rows of (features..., label).
            points: rows of length num_dims + 1
            i: index of the feature used to split at this depth
            Returns None for no points, otherwise a tuple
            (left_subtree, right_subtree, row), split at the median row
            along feature i.
        '''
        if len(points) == 0:
            return None
        if len(points) == 1:
            return (None, None, points[0])
        points = sorted(points, key=lambda x: x[i])
        middle = len(points) // 2
        i = (i + 1) % self.num_dims
        return (self.create_kdtree(points[:middle], i), self.create_kdtree(points[middle + 1:], i), points[middle])

    def _search_kdtree(self, node, point, k, i, heap):
        '''
            Recursive kd-tree descent that collects candidates in heap.
            heap is a min-heap of (-distance, row), so heap[0] always holds
            the FARTHEST neighbour kept so far. Requires k >= 1.
        '''
        if node is None:
            return
        left_node, right_node, curr_xy = node
        dist = self.euclidean_distance(point, curr_xy[:-1])
        dx = point[i] - curr_xy[i]

        # Store the row as a tuple: when two distances tie, heapq compares the
        # second tuple element, and comparing numpy arrays raises ValueError.
        entry = (-dist, tuple(curr_xy))
        if len(heap) < k:
            heapq.heappush(heap, entry)
        elif dist < -heap[0][0]:
            heapq.heappushpop(heap, entry)

        i = (i + 1) % self.num_dims

        # Descend first into the side of the splitting plane holding the query.
        if dx < 0:
            near_node, far_node = left_node, right_node
        else:
            near_node, far_node = right_node, left_node
        self._search_kdtree(near_node, point, k, i, heap)

        # The far side must be searched while fewer than k neighbours are known,
        # or when the splitting plane is closer than the current worst neighbour.
        if len(heap) < k or abs(dx) < -heap[0][0]:
            self._search_kdtree(far_node, point, k, i, heap)

    def get_neighbors_kdtree_single(self, node, point, k, i=0, heap=None):
        '''
            k nearest neighbours of one query point using the kd-tree.
            node: (sub)tree to search, as produced by create_kdtree
            point: (num_dims)
            i: index of the feature that node splits on
            heap: optional list in which candidates are accumulated
            Returns None when node is None, [] when k <= 0, otherwise a list
            of (distance, row) pairs sorted by ascending distance, where row
            has num_dims + 1 entries (features followed by the label).
        '''
        if node is None:
            return None
        if k <= 0:
            return []
        if heap is None:
            heap = list()
        self._search_kdtree(node, point, k, i, heap)
        return [(-neg_dist, list(xy)) for (neg_dist, xy) in sorted(heap, key=lambda x: -1 * x[0])]

    def get_neighbors_kdtree(self, x_test, k):
        '''
            kd-tree k nearest neighbours for every point in x_test.
        '''
        return [self.get_neighbors_kdtree_single(self.kdtree, feature, k) for feature in x_test]

    def get_neighbors(self, x_test, k, mode='kdtree'):
        '''
            Dispatch to the 'naive' or 'kdtree' search.
            Returns None for any other mode.
        '''
        if mode == 'naive':
            return self.get_neighbors_naive(x_test, k)
        elif mode == 'kdtree':
            return self.get_neighbors_kdtree(x_test, k)
        return None

    def predict(self, x_test, k, mode='kdtree'):
        '''
            Predict a label for every point in x_test by majority vote among
            its k nearest neighbours.
            Returns a list with one label per query (labels are floats
            because they are read back from self.xy).
        '''
        if mode == 'naive':
            # get_neighbors_naive drops the label column, so voting on its
            # output would count the last FEATURE; use labelled rows instead.
            neighbors = [self._nearest_rows_naive(feature, k) for feature in x_test]
        else:
            neighbors = self.get_neighbors(x_test, k, mode)
        y_predicted = list()
        for index in range(len(x_test)):
            y_count = dict()
            for neighbor in neighbors[index]:
                dist, xy = neighbor
                y = xy[-1]
                y_count[y] = y_count.get(y, 0) + 1
            # Stable sort: among equally frequent labels the first seen wins.
            y_pred = sorted(y_count.items(), key=itemgetter(1), reverse=True)[0][0]
            y_predicted.append(y_pred)

        return y_predicted

Step 4: Get neighbors with naive KNN and kd-tree KNN, then predict labels.

*Test Neighbors*

Define a helper function that prints only the distances of the neighbors.

In [4]:
def print_dist(neighbors):
    """Print, for each query, its position and the list of neighbour distances.

    neighbors: one sequence per query, each holding (distance, point) pairs.
    """
    for position, pairs in enumerate(neighbors):
        distances = []
        for distance, _point in pairs:
            distances.append(distance)
        print(position, distances)

**1. naive knn**

In [5]:
# Build the model first so that construction (including the kd-tree build
# done in KNN.__init__) is excluded from the measured time.
knn = KNN(x_train, y_train)
t1 = time.time()
# Brute-force search: the 8 nearest training points for every query in x_test.
neighbors1 = knn.get_neighbors(x_test, 8, mode='naive')
t2 = time.time()
print('time cost:{}s'.format(t2 - t1))
# Only show every 100th query to keep the output short.
print_dist(neighbors1[::100])

time cost:29.449085474014282s
0 [0.11272374340351567, 0.11613088616183086, 0.12172469698304142, 0.12298341412042026, 0.12443267501483761, 0.13194205980290044, 0.13368433438571858, 0.13993687302237728]
1 [0.047189008786685666, 0.060621485562171604, 0.061875761497820814, 0.06962195367914559, 0.07124976372463021, 0.07299351106055775, 0.076524203308105, 0.07826550819137927]
2 [0.03354048066016671, 0.04300832943077879, 0.05081941607751753, 0.05539906937870222, 0.06896936853703502, 0.0776054868189429, 0.08296892212811463, 0.0855137765483282]
3 [0.033719603230873835, 0.039659908019904884, 0.06439548288966943, 0.06738278839047686, 0.07007277803854058, 0.0810428906688013, 0.08853842102400361, 0.08858609663843374]
4 [0.03806098707251986, 0.04280085784909719, 0.055622069202833885, 0.058207424736827044, 0.06146509585831059, 0.06727589336726607, 0.07137543491534047, 0.07163512837589164]


**2. kdtree knn**

In [6]:
# Build the model first so that the kd-tree construction done in
# KNN.__init__ is excluded from the measured time.
knn = KNN(x_train, y_train)
t1 = time.time()
# kd-tree search: the 8 nearest training points for every query in x_test.
neighbors2 = knn.get_neighbors(x_test, 8, mode='kdtree')
t2 = time.time()
print('time cost:{}s'.format(t2 - t1))
# Only show every 100th query to keep the output short.
print_dist(neighbors2[::100])

time cost:0.9458584785461426s
0 [0.11272374340351567, 0.11613088616183086, 0.12172469698304142, 0.12298341412042026, 0.12443267501483761, 0.13194205980290044, 0.13368433438571858, 0.13993687302237728]
1 [0.047189008786685666, 0.060621485562171604, 0.061875761497820814, 0.06962195367914559, 0.07124976372463021, 0.07299351106055775, 0.076524203308105, 0.07826550819137927]
2 [0.03354048066016671, 0.04300832943077879, 0.05081941607751753, 0.05539906937870222, 0.06896936853703502, 0.0776054868189429, 0.08296892212811463, 0.0855137765483282]
3 [0.033719603230873835, 0.039659908019904884, 0.06439548288966943, 0.06738278839047686, 0.07007277803854058, 0.0810428906688013, 0.08853842102400361, 0.08858609663843374]
4 [0.03806098707251986, 0.04280085784909719, 0.055622069202833885, 0.058207424736827044, 0.06146509585831059, 0.06727589336726607, 0.07137543491534047, 0.07163512837589164]


Compare the two results.

In [7]:
# Check, query by query, that both searches report identical distances.
cmp = []
have_error = False
for i in range(len(x_test)):
    # One boolean per neighbour rank: True when the two distances are equal.
    curr_cmp = []
    for n1, n2 in zip(neighbors1[i], neighbors2[i]):
        curr_cmp.append(n1[0] == n2[0])
    if not all(curr_cmp):
        have_error = True
    cmp.append(curr_cmp)
    print(cmp[i])
if not have_error:
    print()
    print('Congratulations! Two results are the same!')

[True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True]
[True, True, True, T

*Test label prediction*

In [8]:
# Predict a label for every query in x_test by majority vote among its
# 8 nearest neighbours, found with the kd-tree search.
knn = KNN(x_train, y_train)
y_predicted = knn.predict(x_test, 8, 'kdtree')
print(y_predicted)

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,