#### 第三章 k近邻法

In [72]:
from __future__ import annotations

import heapq

import numpy as np

class KDTree:
    """
    A single node of a k-d tree.

    Every node stores one sample point. Internal nodes additionally record
    the axis and threshold used to split their samples into two subtrees;
    leaf nodes are created with ``axis`` and ``value`` set to ``None``.

    :param right: right subtree (intended for samples whose value on ``axis``
                  is greater than ``value``), or ``None`` if absent
    :param left: left subtree (intended for samples whose value on ``axis``
                 is less than or equal to ``value``), or ``None`` if absent
    :param point: the sample point stored at this node (a k-dimensional array)
    :param axis: index of the dimension this node splits on
                 (0 = first dimension, 1 = second, ...), or ``None``
    :param value: split threshold along ``axis`` (typically the median of
                  that dimension), or ``None``
    """

    def __init__(
        self,
        right: KDTree | None,
        left: KDTree | None,
        point: np.ndarray,
        axis: int | None,
        value: float | None,
    ):
        # Child links.
        self.right, self.left = right, left
        # Sample stored at this node.
        self.point = point
        # Description of the splitting hyperplane.
        self.axis, self.value = axis, value


def build_kd_tree(arr: np.ndarray, axis: int) -> KDTree:
    """
    Recursively build a k-d tree from a set of samples.

    The samples are sorted along the current axis, the median sample becomes
    the node, and the two halves are built recursively while cycling through
    the axes. Because the split is made by position in the sorted order,
    samples whose coordinate equals the split value may end up in either
    subtree.

    :param arr: samples, array-like of shape (n_samples, n_features) with at
                least one sample and one feature
    :param axis: dimension to split on at this level (taken modulo n_features)
    :return: root node of the constructed tree; a single sample yields a leaf
             whose ``axis`` and ``value`` are ``None``
    :raises ValueError: if ``arr`` is not 2-D or has no samples or no features
    """
    arr = np.asarray(arr)
    if arr.ndim != 2 or arr.shape[0] == 0 or arr.shape[1] == 0:
        # Fail early with a clear message instead of an IndexError or
        # ZeroDivisionError surfacing from deeper in the construction.
        raise ValueError(
            "arr must be a non-empty 2-D array of shape (n_samples, n_features), "
            f"got shape {arr.shape}"
        )

    n, k = arr.shape
    axis = axis % k
    if n == 1:
        return KDTree(None, None, arr[0], None, None)

    arr_sorted_asc = arr[np.argsort(arr[:, axis])]

    median_idx = n // 2
    median_point = arr_sorted_asc[median_idx]
    median_value = median_point[axis]

    # n >= 2 here, so median_idx >= 1 and the left half is never empty; the
    # right half is empty exactly when n == 2.
    left_subtree = build_kd_tree(arr_sorted_asc[:median_idx], axis + 1)
    right_half = arr_sorted_asc[median_idx + 1:]
    right_subtree = build_kd_tree(right_half, axis + 1) if len(right_half) else None
    return KDTree(right_subtree, left_subtree, median_point, axis, median_value)

def knn(x: np.ndarray, kd_tree: KDTree, k: int) -> np.ndarray:
    """
    Find the k nearest neighbours of ``x`` in a k-d tree (Euclidean distance).

    The search descends first into the side of each split that contains
    ``x`` and only visits the opposite side when the splitting hyperplane is
    closer than the current k-th best distance (or fewer than k candidates
    have been collected so far). Every visited node, internal or leaf, is
    considered as a candidate.

    :param x: query point, shape (n_features,)
    :param kd_tree: root node of the tree
    :param k: number of neighbours to return; must be at least 1
    :return: 1-D object array of the matching tree nodes, ordered from nearest
             to farthest (use ``node.point`` to get the coordinates); contains
             fewer than ``k`` nodes when the tree holds fewer than ``k`` points
    :raises ValueError: if ``k`` is less than 1
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")

    query = np.asarray(x)

    # Bounded max-heap of the best candidates found so far. Distances are
    # negated so the worst candidate sits at heap[0]; the running counter
    # breaks ties so that node objects themselves are never compared.
    heap: list[tuple[float, int, KDTree]] = []
    seen = 0

    def consider(node: KDTree) -> None:
        nonlocal seen
        dist = float(np.linalg.norm(query - node.point))
        entry = (-dist, seen, node)
        seen += 1
        if len(heap) < k:
            heapq.heappush(heap, entry)
        elif dist < -heap[0][0]:
            heapq.heapreplace(heap, entry)

    def search(node: KDTree | None) -> None:
        if node is None:
            return
        consider(node)
        if node.axis is None:
            # Leaf node: no split, hence no children to explore.
            return
        offset = query[node.axis] - node.value
        if offset <= 0:
            near, far = node.left, node.right
        else:
            near, far = node.right, node.left
        search(near)
        # The far side can only hold a better candidate if the splitting
        # hyperplane is closer than the current worst kept distance.
        if len(heap) < k or abs(offset) < -heap[0][0]:
            search(far)

    search(kd_tree)

    ordered = sorted(heap, key=lambda entry: (-entry[0], entry[1]))
    return np.array([entry[2] for entry in ordered], dtype=object)
# Demo: build a tree over six 2-D samples and query the three nearest
# neighbours of the point (3, 4.5).
X = np.array(
    [
        [2, 3],
        [5, 4],
        [9, 6],
        [4, 7],
        [8, 1],
        [7, 2],
    ]
)
kd_tree = build_kd_tree(X, axis=0)
k_neighbor = knn(np.array([3, 4.5]), kd_tree, k=3)
# knn returns tree nodes; show only the coordinates stored on each one.
print([neighbor.point for neighbor in k_neighbor])

[array([2, 3]), array([5, 4]), array([4, 7])]
