In [1]:
from random import randint, seed
from lolviz import *

In [2]:
class Point(object):
    '''
    A 2D point.

    x, y: coordinates of the point

    Points compare equal, and hash alike, when their coordinates match, so
    putting points in a set() removes duplicate coordinates.
    '''
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return f'Point(x={self.x}, y={self.y})'

    def __eq__(self, other):
        # Let Python try the reflected comparison for unrelated types
        if not isinstance(other, Point):
            return NotImplemented
        return (self.x, self.y) == (other.x, other.y)

    def __hash__(self):
        # Must agree with __eq__; do not change x/y while the point is stored
        # in a set or used as a dict key
        return hash((self.x, self.y))

In [3]:
class Rectangle(object):
    '''
    Axis-aligned rectangle described by its center and half-extents.
    '''
    def __init__(self, x, y, w, h):
        '''
        x, y: center of the rectangle
        w: half-width of the rectangle (it spans x - w to x + w)
        h: half-height of the rectangle (it spans y - h to y + h)
        '''
        self.x = x
        self.y = y
        self.w = w
        self.h = h

    def __repr__(self):
        return f'Rectangle(x={self.x}, y={self.y}, w={self.w}, h={self.h})'

    def contains(self, point):
        '''
        Return true if the point is contained within this rectangle

        point: any object with x and y attributes

        The low edges (x - w, y - h) are inclusive and the high edges
        (x + w, y + h) are exclusive.
        '''
        left, right = self.x - self.w, self.x + self.w
        top, bottom = self.y - self.h, self.y + self.h
        return (point.x >= left and
                point.x < right and
                point.y >= top and
                point.y < bottom)

    def intersects(self, area):
        '''
        Return true if the area intersects this rectangle

        area: another rectangle (center x, y and half-extents w, h)

        Rectangles that only touch along an edge count as intersecting.
        '''
        # The rectangles are disjoint only if one lies strictly beyond an
        # edge of the other along either axis.
        return not(area.x - area.w > self.x + self.w or
                   area.x + area.w < self.x - self.w or
                   area.y - area.h > self.y + self.h or
                   area.y + area.h < self.y - self.h)

In [4]:
class Quad(object):
    '''
    Quadtree node: stores up to `capacity` points itself and pushes any
    further points down into four child quadrants.
    '''
    def __init__(self, boundary, capacity):
        '''
        capacity: number of points this node stores directly before further
                  points go to its child quadrants (should be >= 1; with 0
                  no node can ever store a point, so insert() would keep
                  subdividing)
        boundary: rectangle boundary of the quadtree
        '''
        self.boundary = boundary
        self.capacity = capacity
        # points stored directly on this node (not those held by children)
        self.points = []
        self.ne = None
        self.nw = None
        self.se = None
        self.sw = None
        # is the quadtree divided into children quadrants?
        self.divided = False

    def insert(self, point):
        '''
        Add the point to this node or to one of its descendants.

        Return True if the point was stored, False if it lies outside this
        node's boundary or no child quadrant accepted it.
        '''
        if not self.boundary.contains(point):
            return False

        if len(self.points) < self.capacity:
            self.points.append(point)
            return True

        # This node is full: points already stored stay here, new ones go down.
        if not self.divided:
            self.subdivide()

        # Children are tried in the order ne, nw, se, sw; any() stops at the
        # first one that accepts the point and yields False if none does.
        children = (self.ne, self.nw, self.se, self.sw)
        return any(child.insert(point) for child in children)

    def subdivide(self):
        '''
        Create the four child quadrants and mark this node as divided.

        Each child is centered half of this node's half-extents away from the
        center, has half the half-extents, and shares this node's capacity.
        '''
        cx = self.boundary.x
        cy = self.boundary.y
        half_w = self.boundary.w / 2
        half_h = self.boundary.h / 2

        # "north" is the smaller-y side, "east" the larger-x side
        self.ne = Quad(Rectangle(cx + half_w, cy - half_h, half_w, half_h), self.capacity)
        self.nw = Quad(Rectangle(cx - half_w, cy - half_h, half_w, half_h), self.capacity)
        self.se = Quad(Rectangle(cx + half_w, cy + half_h, half_w, half_h), self.capacity)
        self.sw = Quad(Rectangle(cx - half_w, cy + half_h, half_w, half_h), self.capacity)
        # Set the flag here so a direct subdivide() call leaves a consistent node
        self.divided = True

    def query(self, area):
        '''
        Return a new list of every stored point that `area` contains.
        '''
        return self._query(area, [])

    def _query(self, area, found):
        '''
        Append the matching points of this subtree to `found` and return it.
        '''
        # Skip this whole subtree when its boundary cannot overlap the area
        if not self.boundary.intersects(area):
            return found

        found.extend(p for p in self.points if area.contains(p))

        if self.divided:
            for child in (self.ne, self.nw, self.se, self.sw):
                child._query(area, found)

        return found

    def __repr__(self):
        return f'Quad(capacity={self.capacity}, boundary={self.boundary}, points={self.points}, se={self.se}, sw={self.sw}, ne={self.ne}, nw={self.nw}, divided={self.divided})'


In [5]:
# quad = Quad(Rectangle(0, 0, 100, 100), 4)
# num_points = 25
# points = list(set([Point(randint(1, 100), randint(1, 100)) for _ in range(num_points)]))
# for p in points:
#     quad.insert(p)
# g = treeviz(quad)
# g.view()

In [6]:
import unittest
# Fix the random seed so the randint() calls in the test fixture below
# produce the same points on every run.
seed(0)

class TestQuad(unittest.TestCase):
    '''
    Checks Quad insertion and area queries against a small random point set.
    '''
    def setUp(self):
        # Re-seed for every test so each one gets the same points, no matter
        # how many tests ran before it.
        seed(0)
        quad = Quad(Rectangle(0, 0, 100, 100), 4)
        num_points = 25
        coords = [(randint(1, 100), randint(1, 100)) for _ in range(num_points)]
        # dict.fromkeys() drops repeated coordinates and keeps generation
        # order, so the insertion order is the same on every run.
        points = [Point(x, y) for x, y in dict.fromkeys(coords)]
        for p in points:
            quad.insert(p)
        self.quad = quad
        self.points = points

    def test_known(self):
        '''
        A point inserted inside the query area is returned by the query.
        '''
        known = Point(9, 9)
        self.points.append(known)
        self.assertTrue(self.quad.insert(known))
        area = Rectangle(0, 0, 10, 10)
        found = self.quad.query(area)
        # The tree must agree with a brute-force scan over every point
        expected = [p for p in self.points if area.contains(p)]
        self.assertCountEqual(found, expected)
        # Relies on the fixed seed: no random point is expected in this area
        self.assertEqual(len(found), 1)
        self.assertIs(found[0], known)

if __name__ == '__main__':
    # Run the tests inside the notebook: pass a placeholder argv so unittest
    # does not read the process's own sys.argv, and exit=False so
    # unittest.main() does not call sys.exit() when the run finishes.
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

.
----------------------------------------------------------------------
Ran 1 test in 0.001s

OK
