### kd tree
#### three steps to find k nearest neighbors:
- recursive to the nearest leaf node
- compare and add to neighbor list
- calc the distance between hypersphere and hyperrectangle to know if need to the subling nodes

### need to do:
- fix low performance error on some complicate datasets

In [1]:
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

In [2]:
# load dataset
iris = load_iris()
df = pd.DataFrame(iris.data)
df['label'] = iris.target

In [3]:
df.head()

Unnamed: 0,0,1,2,3,label
0,5.1,3.5,1.4,0.2,0
1,4.9,3.0,1.4,0.2,0
2,4.7,3.2,1.3,0.2,0
3,4.6,3.1,1.5,0.2,0
4,5.0,3.6,1.4,0.2,0


### kd tree implementation

In [4]:
class KDNode:
    def __init__(self,axis=None,val=None,label=None,left=None,right=None):
        self.axis = axis
        self.val = val
        self.label = label
        self.left = left
        self.right = right
    
class KDTree:
    def __init__(self,X,Y):
        data = np.concatenate((X,np.array(Y).reshape(len(Y),1)),axis=1)
        self.root = self.build_tree(data,depth=0)
    
    def build_tree(self,data,depth):
        if len(data) == 0:return None
        data_num,feature_dim = len(data),len(data[0])-1 #label dim
        axis = depth % feature_dim

        # get median node
        sorted_data = sorted(data,key=lambda x:x[axis])
        median_index = data_num // 2
        median_node = sorted_data[median_index]
        node = KDNode(axis,median_node[:-1],median_node[-1])
        node.left = self.build_tree(sorted_data[:median_index],depth+1)
        node.right = self.build_tree(sorted_data[median_index+1:],depth+1)
        return node
    
    def get_KNN_labels(self,x,count=1):
        # save all nearest nodes: dist,node
        nearest = [[-1,None] for _ in range(count)]
        self.nearest = np.array(nearest)
        self.recursive_find(x,self.root)
        nearest_labels = [node.label for dist,node in self.nearest]
        
        (labels,counts) = np.unique(nearest_labels,return_counts=True)
        ind=np.argmax(counts)
        return labels[ind]

    def recursive_find(self,x,node):
        if node is None:
            return
        # find nearest leaf node
        axis = node.axis
        if x[axis] < node.val[axis]:
            self.recursive_find(x,node.left)
        else:
            self.recursive_find(x,node.right)
        
        # find the leaf node, backtrack and calc distance ,only use l2 distance
        dist = np.sqrt(np.sum((np.array(x)-np.array(node.val))**2))
        
        # if not have enough nearest nodes
        distances = [dist for dist,node in self.nearest]
        if -1 in distances:
            for i,d in enumerate(self.nearest):
                if d[0] == -1:
                    self.nearest[i] = [dist,node]
                    # resort to compare distance with hyperrectangle
                    self.nearest = sorted(self.nearest,key=lambda x:-x[0])
                    break

        # replace the biggest distance in nearest list when it have <count> values
        else:
            self.nearest = sorted(self.nearest,key=lambda x:-x[0])
            if self.nearest[0][0] > dist:
                self.nearest[0] = [dist,node]
        
        #if compare subling nodes by the distance of hyperplane
        
        # have union set between the hypersphere and the hyperrectangle
        if self.nearest[0][0] > abs(x[axis] - node.val[axis]):
            if x[axis] - node.val[axis] < 0:
                self.recursive_find(x,node.left)
            else:
                self.recursive_find(x,node.right)

In [5]:
#prepare data
X,Y = df.values[:,:-1],df.values[:,-1]
train_X,test_X,train_Y,test_Y = train_test_split(X,Y,test_size=0.9,random_state=42)

import time

stime = time.time()
model = KDTree(train_X,train_Y)

preds = []
for instance in test_X:
    label = model.get_KNN_labels(instance,count=3)
    preds.append(label)
    
etime = time.time()
print("ACC score:{},cost time [{}]s".format(accuracy_score(test_Y,preds),etime-stime))

ACC score:0.9629629629629629,cost time [0.09417414665222168]s


In [6]:
test_Y

array([ 1.,  0.,  2.,  1.,  1.,  0.,  1.,  2.,  1.,  1.,  2.,  0.,  0.,
        0.,  0.,  1.,  2.,  1.,  1.,  2.,  0.,  2.,  0.,  2.,  2.,  2.,
        2.,  2.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  2.,  1.,  0.,  0.,
        0.,  2.,  1.,  1.,  0.,  0.,  1.,  2.,  2.,  1.,  2.,  1.,  2.,
        1.,  0.,  2.,  1.,  0.,  0.,  0.,  1.,  2.,  0.,  0.,  0.,  1.,
        0.,  1.,  2.,  0.,  1.,  2.,  0.,  2.,  2.,  1.,  1.,  2.,  1.,
        0.,  1.,  2.,  0.,  0.,  1.,  1.,  0.,  2.,  0.,  0.,  1.,  1.,
        2.,  1.,  2.,  2.,  1.,  0.,  0.,  2.,  2.,  0.,  0.,  0.,  1.,
        2.,  0.,  2.,  2.,  0.,  1.,  1.,  2.,  1.,  2.,  0.,  2.,  1.,
        2.,  1.,  1.,  1.,  0.,  1.,  1.,  0.,  1.,  2.,  2.,  0.,  1.,
        2.,  2.,  0.,  2.,  0.])

In [7]:
np.array(preds)

array([ 1.,  0.,  2.,  1.,  1.,  0.,  1.,  2.,  1.,  1.,  2.,  0.,  0.,
        0.,  0.,  1.,  2.,  1.,  1.,  2.,  0.,  2.,  0.,  2.,  2.,  2.,
        2.,  2.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  2.,  1.,  0.,  0.,
        0.,  2.,  2.,  1.,  0.,  0.,  1.,  2.,  2.,  1.,  2.,  1.,  2.,
        1.,  0.,  2.,  1.,  0.,  0.,  0.,  1.,  2.,  0.,  0.,  0.,  1.,
        0.,  1.,  2.,  0.,  1.,  2.,  0.,  2.,  2.,  1.,  1.,  2.,  1.,
        0.,  1.,  2.,  0.,  0.,  1.,  1.,  0.,  2.,  0.,  0.,  2.,  1.,
        2.,  2.,  2.,  2.,  1.,  0.,  0.,  2.,  2.,  0.,  0.,  0.,  1.,
        2.,  0.,  2.,  2.,  0.,  1.,  1.,  1.,  1.,  2.,  0.,  2.,  1.,
        2.,  1.,  1.,  1.,  1.,  1.,  1.,  0.,  1.,  2.,  2.,  0.,  1.,
        2.,  2.,  0.,  2.,  0.])