In [4]:
from typing import List
from collections import namedtuple
import time


class Point(namedtuple("Point", "x y")):
    def __repr__(self) -> str:
        return f'Point{tuple(self)!r}'


class Rectangle(namedtuple("Rectangle", "lower upper")):
    def __repr__(self) -> str:
        return f'Rectangle{tuple(self)!r}'

    def is_contains(self, p: Point) -> bool:
        return self.lower.x <= p.x <= self.upper.x and self.lower.y <= p.y <= self.upper.y


class Node(namedtuple("Node", "location left right")):
    """
    location: Point
    left: Node
    right: Node
    """

    def __repr__(self):
        return f'{tuple(self)!r}'


class KDTree:
    """k-d tree"""

    def __init__(self):
        self._root = None
        self._n = 0

    def insert(self, p: List[Point],d=0):
        """insert a list of points"""

        if not p : 
            return
        if d == 0:      
            p = sorted(p,key=lambda x:x.x)       
        elif d == 1:
            p = sorted(p,key=lambda x:x.y)  
        median=len(p)//2     
        median_value=p[median]
        if median_value==p[-1]:  
            median=len(p)-1
        else:
            for i in range(median,len(p)):   
                if p[i]>median_value:
                    median=i-1
                    break
        if self._n == 0:
            self._n = 1   
            self._root=Node(p[median],self.insert(p[:median],1-d),self.insert(p[median+1:],1-d))  
        else:
            n=Node(p[median],self.insert(p[:median],1-d),self.insert(p[median+1:],1-d))
            return n    

    def range(self, rectangle: Rectangle) -> List[Point]:
        """range query"""
        l=[]
        n=self._root
        if n==None:
            return l
        n_lis=[n]
        coordinate_lis=[0] 
        for ind,i in enumerate(n_lis):
            if i==None:
                continue
            coordinate=coordinate_lis[ind]
            if coordinate==0:
                if rectangle.lower.x<=i.location.x<=rectangle.upper.x:
                    n_lis.append(i.right)
                    n_lis.append(i.left)
                    coordinate_lis+=[1-coordinate]*2
                    if rectangle.lower.y<=i.location.y<=rectangle.upper.y:
                        l.append(i.location)
                elif i.location.x<rectangle.lower.x:
                    n_lis.append(i.right)
                    coordinate_lis.append(1-coordinate)
                else :
                    n_lis.append(i.left)
                    coordinate_lis.append(1 - coordinate)
            else:
                if rectangle.lower.y<=i.location.y<=rectangle.upper.y:
                    n_lis.append(i.right)
                    n_lis.append(i.left)
                    coordinate_lis+=[1-coordinate]*2
                    if rectangle.lower.x<=i.location.x<=rectangle.upper.x:
                        l.append(i.location)
                elif i.location.y<rectangle.lower.y:
                    n_lis.append(i.right)
                    coordinate_lis.append(1-coordinate)
                else :
                    n_lis.append(i.left)
                    coordinate_lis.append(1 - coordinate)
        return l


def range_test():
    points = [Point(7, 2), Point(5, 4), Point(9, 6), Point(4, 7), Point(8, 1), Point(2, 3)]
    kd = KDTree()
    kd.insert(points)
    result = kd.range(Rectangle(Point(0, 0), Point(6, 6)))
    assert sorted(result) == sorted([Point(2, 3), Point(5, 4)])


def performance_test():
    points = [Point(x, y) for x in range(1000) for y in range(1000)]

    lower = Point(500, 500)
    upper = Point(504, 504)
    rectangle = Rectangle(lower, upper)
    #  naive method
    start = int(round(time.time() * 1000))
    result1 = [p for p in points if rectangle.is_contains(p)]
    end = int(round(time.time() * 1000))
    print(f'Naive method: {end - start}ms')

    kd = KDTree()
    kd.insert(points)
    # k-d tree
    start = int(round(time.time() * 1000))
    result2 = kd.range(rectangle)
    end = int(round(time.time() * 1000))
    print(f'K-D tree: {end - start}ms')
    assert sorted(result1) == sorted(result2)


if __name__ == '__main__':
    range_test()
    performance_test()

Naive method: 163ms
K-D tree: 0ms


In [None]:
#用kdtree方法先从根节点搜索到叶节点，然后从叶节点回溯到根节点以完成搜索。因为kdtree方法将点以不同的维度进行切分，所以在搜索时不用访问每一个节点，可以以分割点的数据为参考进行筛选，从可以减少计算量，可以达到快速检索。而普通方法需要访问每一个点，计算量更大，用时更久。