# 基于KD tree的KNN算法

对于kd tree的主要内容在于kd tree的**创建**和**搜索**

创建树节点：  
每个树节点存储一条数据样本，以哪个维度进行划分，左子树，右子树

In [3]:
class Kdnode():
    def __init__(self,content,dim,left,right):
        self.content = content
        self.dim = dim
        self.left = left
        self.right = right

## kd tree的创建：
1. 开始：对数据集以0维进行划分，并将0维上的位于中间的数据点存储为树的根结点  
2. 递归：依此沿数据集的n维进行划分，将n维上位于中间的数据点存储为树的节点，直到无可以划分的数据点

In [155]:
import numpy as np
class Kdtree():
    def __init__(self,data):
        k=len(data[0])
        
        def CreateNode(split,data_set):#在哪个维度划分数据集
            
            if len(data_set)==0:#如果数据集为空，则返回None
                return None
            
            index = np.argsort(data_set[:,split])
            split_pos = (len(data_set)-1)//2#主要是为了和知乎上的一篇文章相对应
            split_next = (split+1)%k
            left_data = data_set[index[:split_pos]]
            right_data = data_set[index[split_pos+1:]]
            return Kdnode(data_set[index[split_pos]],split,CreateNode(split_next,left_data),CreateNode(split_next,right_data))
        
        self.root = CreateNode(0,data)

## 验证创建的kdtree()

In [213]:
data = np.array([
    [6.27,5.5],
    [1.24,-2.86],
    [17.05,-12.79],
    [-6.88,-5.4],
    [-2.96,-0.5],
    [7.75,-22.68],
    [10.8,-5.03],
    [-4.6,-10.55],
    [-4.96,12.61],
    [1.75,12.26],
    [15.31,-13.16],
    [7.83,15.7],
    [14.63,-0.35],
                ])

In [214]:
tree = Kdtree(data)

In [215]:
#输出树，检验下结构有没有问题
def protravel(tree):
    if tree is None:
        return
    print(tree.content,tree.dim)
    protravel(tree.left)
    protravel(tree.right)

In [216]:
protravel(tree.root)

[6.27 5.5 ] 0
[ 1.24 -2.86] 1
[-6.88 -5.4 ] 0
[ -4.6  -10.55] 1
[-2.96 -0.5 ] 0
[-4.96 12.61] 1
[ 1.75 12.26] 1
[ 17.05 -12.79] 1
[  7.75 -22.68] 0
[ 15.31 -13.16] 1
[10.8  -5.03] 0
[ 7.83 15.7 ] 1
[14.63 -0.35] 1


![](kdtree.jpg)

木有问题～

## 在kdtree中搜索，距离target最近的k个样本

In [227]:
from collections import namedtuple
res = namedtuple('res','n max_index max_dist nearest')
k=2
res.n=0
res.max_index=-1
res.max_dist=float("inf")
res.nearest = np.empty((k,len(data[0])))

In [230]:
def search(kd_node,target):
    if kd_node is None:
        return
    s = kd_node.dim
    pivot = kd_node.content
    if target[s] <= pivot[s]:
        nearest_node = kd_node.left
        further_node = kd_node.right
    else:
        nearest_node = kd_node.right
        further_node = kd_node.left
    
    search(nearest_node,target)
    
    if res.n<k:
        res.nearest[res.n] = kd_node.content
        res.n+=1
    elif res.n == k:
        dist = np.linalg.norm(res.nearest - target,axis=1)
        res.max_index = np.argmax(dist)
        res.max_dist = np.max(dist)

        if np.linalg.norm(kd_node.content - target) < res.max_dist:
            res.nearest[res.max_index] = kd_node.content
    
    dist = np.linalg.norm(res.nearest - target,axis=1)
    res.max_index = np.argmax(dist)
    res.max_dist = np.max(dist)
    if abs(kd_node.content[kd_node.dim]-target[kd_node.dim]) < res.max_dist:
        search(further_node,target)

In [231]:
search(tree.root,np.array([-1,-5]))

## 建立基于kd-tree的knn算法

In [49]:
import numpy as np
from collections import namedtuple
from collections import Counter
class Kdnode():
    def __init__(self,label,content,dim,left,right):
        self.content = content
        self.dim = dim
        self.left = left
        self.right = right
        self.label = label #所属的标签
        
        
class Kdtree():
    def __init__(self,data,y_label):
        k=len(data[0])
        
        def CreateNode(split,data_set,y_label):#在哪个维度划分数据集
            
            if len(data_set)==0:#如果数据集为空，则返回None
                return None
            
            index = np.argsort(data_set[:,split])
            split_pos = (len(data_set)-1)//2#主要是为了和知乎上的一篇文章相对应
            split_next = (split+1)%k
            left_data = data_set[index[:split_pos]]
            left_label = y_label[index[:split_pos]]
            
            right_data = data_set[index[split_pos+1:]]
            right_label = y_label[index[split_pos+1:]]
            return Kdnode(y_label[split_pos],data_set[index[split_pos]],split,CreateNode(split_next,left_data,left_label),CreateNode(split_next,right_data,right_label))
        
        self.root = CreateNode(0,data,y_label)
    
    def _search(self,kd_node,target,k):
        if kd_node is None:
            return
        s = kd_node.dim
        pivot = kd_node.content
        if target[s] <= pivot[s]:
            nearest_node = kd_node.left
            further_node = kd_node.right
        else:
            nearest_node = kd_node.right
            further_node = kd_node.left

        self._search(nearest_node,target,k)

        if res.n<k:
            res.nearest[res.n] = kd_node.content
            res.label[res.n] = kd_node.label
            res.n+=1

        elif res.n == k:
            dist = np.linalg.norm(res.nearest - target,axis=1)
            res.max_index = np.argmax(dist)
            res.max_dist = np.max(dist)

            if np.linalg.norm(kd_node.content - target) < res.max_dist:
                res.nearest[res.max_index] = kd_node.content
                res.label[res.max_index] = kd_node.label

        dist = np.linalg.norm(res.nearest - target,axis=1)
        res.max_index = np.argmax(dist)
        res.max_dist = np.max(dist)

        if abs(kd_node.content[kd_node.dim]-target[kd_node.dim]) < res.max_dist:
            self._search(further_node,target,k)
    
    def search(self,kd_node,target,k):
        global res
        res = namedtuple('res','n max_index max_dist nearest label')
        res.n=0
        res.nearest = np.empty((k,len(target)))
        res.label=np.empty(k)
        self._search(kd_node,target,k)
        return res
    
    

class KNN_kdtree:
    
    def __init__(self,k=3,p=2):
        """初始化KNN"""
        self.k = k
        self.p = p
        self.X_train = None
        self.y_train = None
    
    def fit(self,X_train,y_train):
        """根据训练数据 训练KNN分类器"""
        self.X_train = X_train
        self.y_train = y_train
        self.tree = Kdtree(self.X_train,self.y_train)
    
    def predict(self,X_predict):
        """给定待预测的数据集X_predict，返回表示X_predict的结果向量"""
        y_predict = [self._predict(x) for x in X_predict]
        return np.array(y_predict)
    
    def _predict(self,x):
        """给定单个待预测数据，返回X的预测结果值"""
        result = self.tree.search(self.tree.root, x ,self.k)
        votes = Counter(result.label)
        return votes.most_common(1)[0][0]        
    
    def __repr__(self):
        return "KNN_kdtree(k=%d,p=%d)"%(self.k,self.p)

In [50]:
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
data = load_digits()
X_train,X_test,y_train,y_test = train_test_split(data.data,data.target,random_state=666)

In [51]:
clf_kdtree = KNN_kdtree()
clf_kdtree.fit(X_train,y_train)
predict = clf_kdtree.predict(X_test)

In [45]:
from sklearn.metrics import accuracy_score
accuracy_score(predict,y_test)

0.9577777777777777

In [48]:
from sklearn.neighbors import KNeighborsClassifier
clf_skl = KNeighborsClassifier(n_neighbors=3,algorithm='kd_tree')
clf_skl.fit(X_train,y_train)
clf_skl.score(X_test,y_test)

0.9844444444444445

基本实现了基于KD tree的KNN算法