In [45]:
from typing import List
from collections import namedtuple
import time


class Point(namedtuple("Point", 'x y')):
    """Immutable 2-D point with ``x`` and ``y`` fields."""

    def __repr__(self) -> str:
        # Show positional values only, e.g. ``Point(1, 2)``, instead of the
        # keyword form the stock namedtuple repr would produce.
        coords = tuple(self)
        return 'Point' + repr(coords)
        


class Rectangle(namedtuple("Rectangle", "lower upper")):
    """Axis-aligned rectangle described by its ``lower`` and ``upper`` corners."""

    def __repr__(self) -> str:
        # Positional form, e.g. ``Rectangle(Point(0, 0), Point(6, 6))``.
        corners = tuple(self)
        return 'Rectangle' + repr(corners)

    def is_contains(self, p: Point) -> bool:
        """Return True when ``p`` lies inside the rectangle, borders included."""
        lo, hi = self.lower, self.upper
        # Reject on x first; y is only inspected when x is within bounds.
        if not (lo.x <= p.x <= hi.x):
            return False
        return lo.y <= p.y <= hi.y


class Node:
    """Binary-tree node holding one point and optional left/right children."""

    def __init__(self, location: Point, left=None, right=None):
        self.location = location
        self.left = left
        self.right = right

    def __repr__(self) -> str:
        # The previous implementation called ``tuple(self)``, which raises
        # TypeError because Node is not iterable.  Children are summarised by
        # their location only (not rendered recursively) so that printing a
        # node of a very deep tree cannot exhaust the recursion limit.
        left = None if self.left is None else self.left.location
        right = None if self.right is None else self.right.location
        return f'Node({self.location!r}, left={left!r}, right={right!r})'

class KDTree:
    """Two-dimensional k-d tree over point objects exposing ``x`` and ``y``.

    Splitting convention, shared by insertion and range search: a node at even
    depth partitions its subtree on ``y`` and a node at odd depth partitions
    on ``x``.  A point whose coordinate is strictly smaller than the node's
    goes into the left subtree; every other point (ties included) goes right.

    The tree is never rebalanced, so its shape depends on insertion order.
    """

    # Attribute compared at each depth parity: index 0 -> even depth, 1 -> odd.
    _AXES = ('y', 'x')

    def __init__(self):
        self._root = None
        self._n = 0  # number of points stored in the tree
        # Not read by any method of this class; retained so that existing
        # attribute access on KDTree instances keeps working.
        self.left = None
        self.right = None

    def insert_alone(self, p: Point):
        """Insert a single point; equal points are stored as separate nodes."""
        fresh = Node(location=p, left=None, right=None)
        if self._root is None:
            self._root = fresh
            self._n += 1
            return
        curr = self._root
        depth = 0
        while True:
            attr = self._AXES[depth % 2]
            if getattr(p, attr) < getattr(curr.location, attr):
                if curr.left is None:
                    curr.left = fresh
                    break
                curr = curr.left
            else:
                if curr.right is None:
                    curr.right = fresh
                    break
                curr = curr.right
            depth += 1
        self._n += 1

    def insert(self, p: List[Point]):
        """Insert every point of ``p``, in iteration order."""
        for point in p:
            self.insert_alone(point)

    def range(self, rectangle: Rectangle) -> List[Point]:
        """Return all stored points for which ``rectangle.is_contains`` is true.

        Borders are inclusive because ``is_contains`` decides membership.  A
        subtree is skipped when the node's splitting coordinate shows it cannot
        overlap the rectangle.  An empty tree yields an empty list; the order
        of the returned points is unspecified.
        """
        found: List[Point] = []
        if self._root is None:
            return found
        # Explicit stack instead of recursion: the tree is unbalanced, so its
        # depth can exceed Python's recursion limit.
        pending = [(self._root, 0)]
        while pending:
            node, depth = pending.pop()
            location = node.location
            if rectangle.is_contains(location):
                found.append(location)
            attr = self._AXES[depth % 2]
            split = getattr(location, attr)
            # Left subtree holds coordinates strictly below ``split``.
            if node.left is not None and getattr(rectangle.lower, attr) < split:
                pending.append((node.left, depth + 1))
            # Right subtree holds coordinates greater than or equal to ``split``.
            if node.right is not None and getattr(rectangle.upper, attr) >= split:
                pending.append((node.right, depth + 1))
        return found


def range_test():
    """Check a range query on a six-point tree against the expected two hits."""
    sample = [Point(7, 2), Point(5, 4), Point(9, 6), Point(4, 7), Point(8, 1), Point(2, 3)]
    tree = KDTree()
    tree.insert(sample)

    window = Rectangle(Point(0, 0), Point(6, 6))
    found = tree.range(window)

    # Order of the returned points is not asserted, hence the sort on both sides.
    expected = [Point(2, 3), Point(5, 4)]
    assert sorted(found) == sorted(expected)


def performance_test(size: int = 1000):
    """Time a brute-force scan against a ``KDTree`` range query and compare results.

    A ``size`` x ``size`` integer grid of points is generated and queried with
    a 5 x 5 window anchored at the grid midpoint (``(500, 500)``-``(504, 504)``
    for the default ``size``).  Both timings are printed in whole milliseconds.
    Tree construction is not included in the k-d tree timing.

    Args:
        size: Number of grid points along each axis.

    Raises:
        AssertionError: If the two methods return different sets of points.
    """
    points = [Point(x, y) for x in range(size) for y in range(size)]

    mid = size // 2
    rectangle = Rectangle(Point(mid, mid), Point(mid + 4, mid + 4))

    # naive method: test every point against the rectangle.
    # perf_counter is monotonic and high-resolution, unlike the wall clock.
    start = time.perf_counter()
    result1 = [p for p in points if rectangle.is_contains(p)]
    elapsed_ms = round((time.perf_counter() - start) * 1000)
    print(f'Naive method: {elapsed_ms}ms')

    kd = KDTree()
    kd.insert(points)

    # k-d tree: only the query itself is timed
    start = time.perf_counter()
    result2 = kd.range(rectangle)
    elapsed_ms = round((time.perf_counter() - start) * 1000)
    print(f'K-D tree: {elapsed_ms}ms')

    assert sorted(result1) == sorted(result2)


# if __name__ == '__main__':
#     range_test()
#     performance_test()

In [47]:
# `performance_test` builds its own points and tree internally; the two names
# below are not read by it.
kd = KDTree()
points = [
    Point(7, 2), Point(5, 4), Point(9, 6),
    Point(4, 7), Point(8, 1), Point(2, 3),
]
performance_test()

Naive method: 508ms


In [15]:
# Field access by name on the namedtuple-based Point.
Point(1, 2).y

2