In [32]:
import numpy as np  
  
class Node:  
    def __init__(self, point=None, split=None, left=None, right=None):  
        self.point = point  
        self.split = split  
        self.left = left  
        self.right = right  
  
  
class KDTree:  
    def __init__(self, data):  
        self.data = data  
        self.k = data.shape[1]  
        self.root = self.build_tree(data, depth=0)  
  
    def build_tree(self, data, depth=0):  
        if len(data) == 0:  
            return None  
  
        axis = depth % self.k  
        sorted_data = data[data[:, axis].argsort()]  
        median = len(sorted_data) // 2  
        print(f'axis={axis}, point={sorted_data[median]}')
        print(f'left array = {sorted_data[:median]}  depth={depth + 1}')
        print(f'right array = {sorted_data[median + 1:]}  depth={depth + 1}')
        return Node(point=sorted_data[median],  
                    split=axis,  
                    left=self.build_tree(sorted_data[:median], depth + 1),  
                    right=self.build_tree(sorted_data[median + 1:], depth + 1))  
  
    def query(self, point, tree=None):  
        if tree is None:  
            tree = self.root  
        if tree is None:  
            return None  
  
        if point[tree.split] < tree.point[tree.split]:  
            next_branch = tree.left  
            opposite_branch = tree.right  
        else:  
            next_branch = tree.right  
            opposite_branch = tree.left  
  
        best = min((tree.point, self.query(point, next_branch)), key=lambda x: np.linalg.norm(point - x))  
        if np.linalg.norm(point - best) > abs(point[tree.split] - tree.point[tree.split]):  
            best = min((best, self.query(point, opposite_branch)), key=lambda x: np.linalg.norm(point - x))  
  
        return best  

In [35]:
# 使用范例：  
# data = np.random.randint(1,10, [10,3])  # 创建一个10x3的随机数据集  
data = np.array([[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]])
print(data)
kdtree = KDTree(data)

[[2 3]
 [5 4]
 [9 6]
 [4 7]
 [8 1]
 [7 2]]
axis=0, point=[7 2]
left array = [[2 3]
 [4 7]
 [5 4]]  depth=1
right array = [[8 1]
 [9 6]]  depth=1
axis=1, point=[5 4]
left array = [[2 3]]  depth=2
right array = [[4 7]]  depth=2
axis=0, point=[2 3]
left array = []  depth=3
right array = []  depth=3
axis=0, point=[4 7]
left array = []  depth=3
right array = []  depth=3
axis=1, point=[9 6]
left array = [[8 1]]  depth=2
right array = []  depth=2
axis=0, point=[8 1]
left array = []  depth=3
right array = []  depth=3


In [36]:
sorted_data = data[data[:, 0].argsort()]
print(sorted_data)

[[2 3]
 [4 7]
 [5 4]
 [7 2]
 [8 1]
 [9 6]]


In [21]:
print(kdtree.root.point, kdtree.root.split)
print(kdtree.root.left.point, kdtree.root.left.split)
print(kdtree.root.left.left.point, kdtree.root.left.left.split)
print(kdtree.root.left.right.point, kdtree.root.left.right.split)
print(kdtree.root.left.left.left.point, kdtree.root.left.left.left.split)
# print(kdtree.root.left.right.right.point, kdtree.root.left.right.right.split)
print(kdtree.root.left.right.left.point, kdtree.root.left.right.left.split)

print(kdtree.root.right.point, kdtree.root.right.split)
print(kdtree.root.right.left.point, kdtree.root.right.left.split)


[6 1 1] 0
[5 5 1] 1
[1 3 6] 2
[4 8 9] 2
[3 5 5] 0
[3 9 8] 0
[8 2 5] 1
[9 1 9] 2


In [26]:
query_point = np.random.randint(1,10,3)  # 创建一个随机查询点 
print(query_point)
# print(kdtree.query(query_point))  # 输出离查询点最近的点

[8 9 5]


In [27]:
# point[tree.split] < tree.point[tree.split]:
query_point[kdtree.root.split]

8

In [28]:
kdtree.root.point[kdtree.root.split]

6

In [None]:
# next_branch = tree.right  
# opposite_branch = tree.left  
best = min((tree.point, self.query(point, next_branch)), key=lambda x: np.linalg.norm(point - x))  