# Ch3 kd-Tree KNN

- Author: Aiden Yansen Han
- Date: July 14, 2019

**Important Notes of kd-tree KNN:**
- kd-tree KNN only needs O(log n) computational time.
- kd Tree is binary.
- The choice of k is determined by cross validation.

**The code for kd-tree KNN:**
- The code can be divided into two parts. 
- The first part is to construct a kd tree. 
- The second part is to implement KNN algorithm under the kd-tree data structure.

In [1]:
### kd-tree construction
import numpy as np

class kdTreeNode():
    def __init__(self, data, label=0, cut_axis=0, flag=0):
        self.data = data
        self.label = label
        self.left = None
        self.right = None
        self.cut_axis = cut_axis # used for kd-tree KNN and means which axis of data are used to do sorting.
        self.flag = 0 # recorded whether this point has been arrived or not.


class kdTree():
    '''
    This class constructs a kd tree for both unlabled data and labeled data.
    
    Parameter:
    -- samples: an array or a list(can be high-dimension)
    -- lables: an array or a list or None
    '''    
    def __init__(self, samples, labels=None):
        self.root = None
        self.samples = samples
        self.labels = labels
        
    def _construct_tree(self, samples, labels=None, index=0):
        samples_len = len(samples)
        samples_dim = np.array(samples).ndim
        labels = np.squeeze(labels) if labels is not None else None
        samples = samples if type(samples) == list else samples.tolist()
        if labels is not None:
            labels = labels if type(labels) == list else labels.tolist()
        if samples_len == 1:
            return kdTreeNode(np.squeeze(samples).tolist(), labels, index)
        elif samples_len == 0:
            return None

        samples.sort(key = lambda x: x[index])
        median_point = samples_len//2
        root_data = samples[median_point]
        root_label = labels[median_point] if labels is not None else None
        root = kdTreeNode(root_data, root_label, index)

        idx = (index + 1) % samples_dim
        if labels is None:
            root.left = self._construct_tree(samples[:median_point], None, idx)
            root.right = self._construct_tree(samples[median_point+1:], None, idx)
        else:
            root.left = self._construct_tree(samples[:median_point], labels[:median_point], idx)
            root.right = self._construct_tree(samples[median_point+1:], labels[median_point+1:], idx)

        return np.squeeze(root).tolist()
    
    def construct_kdTree(self):
        self.root = self._construct_tree(self.samples, self.labels)
        return self.root

In [2]:
class KNN():
    '''
    This KNN class would contain two methods, simple KNN and kd-tree KNN.
    
    Parameters:
    -- X_training: an array or list; the feature vectors of training data.
    -- Y_training: an array or list; the labels of training data.
    -- X_test: an array or list; the feature vectors of test data.
    -- p: the order of L^p distance. When p=2, the distance is Euclid distance.
    -- k: the count-in number of nearest neighbors.
    '''
    def __init__(self, X_training, Y_training, X_test, p=2, k=3):
        self.X_training = X_training if type(X_training) == list else X_training.tolist()
        self.Y_training = Y_training if type(Y_training) == list else Y_training.tolist()
        self.X_test = X_test if type(X_test) == list else X_test.tolist()
        self.p = p
        self.k = k
    
    def _distance(self, sample, p):
        if p == 1:
            s = 0
            for i in list(sample):
                s += abs(i)
        elif p == "inf":
            s = max(list(sample))
        else:
            s = 0
            for i in list(sample):
                s += i**p
            s = s**(1/p)
        return s
    
    def _get_mode(self, list_data):
        output = []
        for row in list_data:
            counts = np.bincount(row)
            output.append(np.argmax(counts))
        return output
    
    def simple_KNN(self):
        test_data_len = len(self.X_test)
        Y_nearest_neighbor = []
        Y_nearest_neighbor_idx = []
        Y_nearest_neighbor_dist = []
        
        for i in range(len(X_training)):
            for j in range(test_data_len):
                try:
                    Ynn_j_len = len(Y_nearest_neighbor_dist[j])
                    new_distance = self._distance(np.array(X_training[i])-np.array(X_test[j]), self.p)
                    
                    if Ynn_j_len >= self.k:                        
                        Ynn_j = np.array(Y_nearest_neighbor_dist[j])
                        Ynn_j_max = Ynn_j.max()
                        if new_distance < Ynn_j_max:
                            Ynn_j_argmax = Ynn_j.argmax()
                            Y_nearest_neighbor[j][Ynn_j_argmax] = i
                            Y_nearest_neighbor_dist[j][Ynn_j_argmax] = new_distance
                            Y_nearest_neighbor_idx[j][Ynn_j_argmax] =self.Y_training[i]
                    else:
                        Y_nearest_neighbor[j].append(i)
                        Y_nearest_neighbor_dist[j].append(new_distance)
                        Y_nearest_neighbor_idx[j].append(self.Y_training[i])
                        
                except IndexError:
                    Y_nearest_neighbor.append([i])
                    Y_nearest_neighbor_idx.append([self.Y_training[i]])
                    Y_nearest_neighbor_dist.append([self._distance(np.array(X_training[i])-np.array(X_test[j]), self.p)])
        
        print("The nearest neighbors for each test data are:\n")      
        for i in Y_nearest_neighbor:
            for j in i:
                print(X_training[j], '\n')
            print('------')
            
        print("The labels of nearest neighbors are:", Y_nearest_neighbor_idx)
        return self._get_mode(Y_nearest_neighbor_idx)
    
    ########################################################
    ### The following functions are used for kd-tree KNN ###
    ########################################################
    
    def go_to_leaf(self, root, point):
        '''
        Function for Step 1 of kd-tree KNN
        '''
        traversal_record = [root]
        cut_axis = root.cut_axis
        if root.data[cut_axis] > point[cut_axis] and root.left is not None:
            traversal_record.extend(self.go_to_leaf(root.left, point))
        elif root.data[cut_axis] <= point[cut_axis] and root.right is not None:
            traversal_record.extend(self.go_to_leaf(root.right, point))
        return traversal_record
    
    def _add_item(self, leaf, point, nearest_Neighbors):
        '''
        Function for Step 2 of kd-tree KNN
        '''
        d = np.shape(nearest_Neighbors)[1] if np.array(nearest_Neighbors).ndim > 1 else 0
        new_distance = self._distance(np.array(point)-np.array(leaf.data), self.p)
        
        if d == 0:
            nearest_Neighbors = [[new_distance], [leaf]]
        elif d == self.k:
            if np.max(nearest_Neighbors[0]) > new_distance:
                idx = np.argmax(nearest_Neighbors[0])
                nearest_Neighbors[0][idx] = new_distance
                nearest_Neighbors[1][idx] = leaf
        else:
            nearest_Neighbors[0].append(new_distance)
            nearest_Neighbors[1].append(leaf)
        return nearest_Neighbors
    
    def _step3_2_step3(self, leaf, root, point, record, nearest_Neighbors):
        '''
        Function for Step 3 of kd-tree KNN
        '''
        if leaf is root:
            return nearest_Neighbors
        else:
            leaf = record.pop(-1)
            if leaf.flag == 1:
                return self._step3_2_step3(leaf, root, point, record, nearest_Neighbors)
            else:
                leaf.flag = 1
                ### step 3.1 ###
                nearest_Neighbors = self._add_item(leaf, point, nearest_Neighbors)
                ### step 3.2 ###
                cut_axis = leaf.cut_axis
                distance_from_cutline = abs(leaf.data[cut_axis] - point[cut_axis])
                
                if distance_from_cutline > np.max(nearest_Neighbors[0]) or leaf.right is None or leaf.left is None:
                    
                    return self._step3_2_step3(leaf, root, point, record, nearest_Neighbors)
                    
                else:
                    if leaf.left.flag == 1:
                        nearest_Neighbors = self.kdTree_KNN_4_single_point(leaf.right, point, nearest_Neighbors)
                    else:
                        nearest_Neighbors = self.kdTree_KNN_4_single_point(leaf.left, point, nearest_Neighbors)
                    return self._step3_2_step3(leaf, root, point, record, nearest_Neighbors)
    
    def kdTree_KNN_4_single_point(self, root, point, nearest_Neighbors=[]):
        
        ### Step 1 ###
        record = self.go_to_leaf(root, point)
        
        ### Step 2 ###
        leaf = record.pop(-1)
        leaf.flag = 1
        nearest_Neighbors = self._add_item(leaf, point, nearest_Neighbors)
            
        ### Step 3 ###
        nearest_Neighbors = self._step3_2_step3(leaf, root, point, record, nearest_Neighbors)
        
        return nearest_Neighbors
    
    def kdTree_KNN(self, root_copy, X_test, nearest_Neighbors=[]):
        neighbor_labels = []
        import copy
        print("The nearest neighbors for each sample are:\n")
        for sample in X_test:
            root = copy.deepcopy(root_copy)
            nearest_Neighbors_4_single_sample = self.kdTree_KNN_4_single_point(root, sample)
            neighbor_labels_4_single_sample = []
            for node in nearest_Neighbors_4_single_sample[1]:
                neighbor_labels_4_single_sample.append(node.label)
                print(node.data, '\n')
            print("------")
            neighbor_labels.append(neighbor_labels_4_single_sample)
            
        print("The labels of nearest neighbors are:", neighbor_labels)
        return self._get_mode(neighbor_labels)

### Experiments
**1. Synthetic Data For Examples**

In [3]:
# Create some data on a plane. For those y>x, we label it as 1; otherwise, we label it as 0.
np.random.seed(1010)
X_training = np.random.rand(10, 2).tolist()
Y_training = [(1 if i[0] < i[1] else 0) for i in np.random.randn(10, 2).tolist()]
X_test = [[0.7,0.9], [0.3, 0.1], [0.8, 0.5]]

In [4]:
X_training

[[0.3942564861172032, 0.1755924719507771],
 [0.07270586036662141, 0.19188086780745994],
 [0.399804308410452, 0.4181233316734022],
 [0.7625821030216905, 0.5214099038217975],
 [0.41088321525140814, 0.5374442681526889],
 [0.2705623137454708, 0.4333266215698163],
 [0.8272662882194327, 0.27185321763297077],
 [0.7273878130513135, 0.10024027644635503],
 [0.4977860316670397, 0.5803158538106875],
 [0.5766745754374724, 0.20703908355921719]]

In [5]:
Y_training

[0, 0, 0, 1, 1, 1, 1, 0, 0, 1]

**2. Simple KNN**

In [6]:
classifier = KNN(X_training, Y_training, X_test)
classifier.simple_KNN()

The nearest neighbors for each test data are:

[0.41088321525140814, 0.5374442681526889] 

[0.7625821030216905, 0.5214099038217975] 

[0.4977860316670397, 0.5803158538106875] 

------
[0.3942564861172032, 0.1755924719507771] 

[0.07270586036662141, 0.19188086780745994] 

[0.5766745754374724, 0.20703908355921719] 

------
[0.4977860316670397, 0.5803158538106875] 

[0.7625821030216905, 0.5214099038217975] 

[0.8272662882194327, 0.27185321763297077] 

------
The labels of nearest neighbors are: [[1, 1, 0], [0, 0, 1], [0, 1, 1]]


[1, 0, 1]

**3. kd-tree KNN**

In [7]:
### Construct kd-tree
root = kdTree(X_training, Y_training).construct_kdTree()
cls = KNN(X_training, Y_training, X_test)
cls.kdTree_KNN(root, X_test)

The nearest neighbors for each sample are:

[0.7625821030216905, 0.5214099038217975] 

[0.41088321525140814, 0.5374442681526889] 

[0.4977860316670397, 0.5803158538106875] 

------
[0.07270586036662141, 0.19188086780745994] 

[0.3942564861172032, 0.1755924719507771] 

[0.5766745754374724, 0.20703908355921719] 

------
[0.7625821030216905, 0.5214099038217975] 

[0.8272662882194327, 0.27185321763297077] 

[0.4977860316670397, 0.5803158538106875] 

------
The labels of nearest neighbors are: [[1, 1, 1], [0, 0, 1], [1, 0, 1]]


[1, 0, 1]

#### Reference:
- https://zhuanlan.zhihu.com/p/23966698