## ADVANCED ALGORITHM PROJECT
### KNN BY KD-TREE COMPARING BRUTE FORCE
### NAME : S M SUTHARSAN RAJ
### SRN    : PES1UG20CS362

In [29]:
# Compute the squared Euclidean distance between X and Y.
def SED(X, Y):
    return sum((i-j)**2 for i, j in zip(X, Y))

SED( (3, 4), (4, 9) )

26

In [30]:
# Use a brute force algorithm to solve the"Nearest Neighbor Problem"
def nearest_neighbor_bf(*, query_points, reference_points):
    return {
        query_p: min(
            reference_points,
            key=lambda X: SED(X, query_p),
        )
        for query_p in query_points
    }

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
query_points = [
    (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3)
]

nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points,
)

{(3, 4): (3, 5),
 (5, 1): (4, 1),
 (7, 3): (4, 1),
 (8, 9): (3, 5),
 (10, 1): (4, 1),
 (3, 3): (3, 2)}

In [31]:
import collections
import operator

BT = collections.namedtuple("BT", ["value", "left", "right"])


def kdtree(points):

    k = len(points[0])
    
    def build(*, points, depth):
        """Build a k-d tree from a set of points at a given
        depth.
        """
        if len(points) == 0:
            return None
        
        points.sort(key=operator.itemgetter(depth % k))
        middle = len(points) // 2
        
        return BT(
            value = points[middle],
            left = build(
                points=points[:middle],
                depth=depth+1,
            ),
            right = build(
                points=points[middle+1:],
                depth=depth+1,
            ),
        )
    
    return build(points=list(points), depth=0)

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
kdtree(reference_points)

BT(value=(3, 5), left=BT(value=(3, 2), left=BT(value=(1, 2), left=None, right=None), right=None), right=BT(value=(4, 1), left=None, right=None))

In [32]:
NNRecord = collections.namedtuple("NNRecord", ["point", "distance"])


def find_nearest_neighbor(*, tree, point):
    """Find the nearest neighbor in a k-d tree for a given
    point.
    """
    k = len(point)
    
    best = None
    # Basic search in KD tree
    def search(*, tree, depth):
        
        nonlocal best
        
        if tree is None:
            return
        
        distance = SED(tree.value, point)
        if best is None or distance < best.distance:
            best = NNRecord(point=tree.value, distance=distance)
        
        axis = depth % k
        diff = point[axis] - tree.value[axis]
        if diff <= 0:
            close, away = tree.left, tree.right
        else:
            close, away = tree.right, tree.left
        
        search(tree=close, depth=depth+1)
        if diff**2 < best.distance:
            search(tree=away, depth=depth+1)
    
    search(tree=tree, depth=0)
    return best.point

#reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
reference_points = [ (5, 4), (2, 6), (13, 3), (8, 7), (3, 1), (10, 2) ]
tree = kdtree(reference_points)
find_nearest_neighbor(tree=tree, point=(9, 4))

(10, 2)

In [33]:
def nearest_neighbor_kdtree(*, query_points, reference_points):
    tree = kdtree(reference_points)
    return {
        query_p: find_nearest_neighbor(tree=tree, point=query_p)
        for query_p in query_points
    }

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
query_points = [(3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3)]

nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points,
)

{(3, 4): (3, 5),
 (5, 1): (4, 1),
 (7, 3): (4, 1),
 (8, 9): (3, 5),
 (10, 1): (4, 1),
 (3, 3): (3, 2)}

In [34]:
nn_kdtree = nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points,
)
nn_bf = nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points,
)
nn_kdtree == nn_bf

True

In [35]:
import random

random_point = lambda: (random.random(), random.random())
reference_points = [ random_point() for _ in range(3000) ]
query_points = [ random_point() for _ in range(3000) ]

solution_bf = nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points
)
solution_kdtree = nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points
)

solution_bf == solution_kdtree

True

In [36]:
import cProfile

reference_points = [ random_point() for _ in range(4000) ]
query_points = [ random_point() for _ in range(4000) ]

cProfile.run("""nearest_neighbor_bf(reference_points=reference_points, query_points=query_points,)""")

         96004005 function calls in 46.228 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
 16000000   10.282    0.000   37.186    0.000 <ipython-input-29-b858ed1712b4>:2(SED)
 48000000   16.196    0.000   16.196    0.000 <ipython-input-29-b858ed1712b4>:3(<genexpr>)
        1    0.000    0.000   46.228   46.228 <ipython-input-30-4ebba6a0f744>:2(nearest_neighbor_bf)
        1    0.013    0.013   46.228   46.228 <ipython-input-30-4ebba6a0f744>:3(<dictcomp>)
 16000000    5.002    0.000   42.189    0.000 <ipython-input-30-4ebba6a0f744>:6(<lambda>)
        1    0.000    0.000   46.228   46.228 <string>:2(<module>)
        1    0.000    0.000   46.228   46.228 {built-in method builtins.exec}
     4000    4.026    0.001   46.214    0.012 {built-in method builtins.min}
 16000000   10.708    0.000   26.904    0.000 {built-in method builtins.sum}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' ob

In [37]:
cProfile.run("""nearest_neighbor_kdtree(reference_points=reference_points, query_points=query_points,)""")

         519530 function calls (425748 primitive calls) in 0.406 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    71823    0.055    0.000    0.190    0.000 <ipython-input-29-b858ed1712b4>:2(SED)
   215469    0.085    0.000    0.085    0.000 <ipython-input-29-b858ed1712b4>:3(<genexpr>)
   8001/1    0.019    0.000    0.034    0.034 <ipython-input-31-8809abb0a4d2>:11(build)
        1    0.000    0.000    0.034    0.034 <ipython-input-31-8809abb0a4d2>:7(kdtree)
89782/4000    0.166    0.000    0.374    0.000 <ipython-input-32-e3b03aa1b367>:11(search)
     4000    0.004    0.000    0.379    0.000 <ipython-input-32-e3b03aa1b367>:4(find_nearest_neighbor)
        1    0.000    0.000    0.417    0.417 <ipython-input-33-ba431d79b981>:1(nearest_neighbor_kdtree)
        1    0.004    0.004    0.383    0.383 <ipython-input-33-ba431d79b981>:3(<dictcomp>)
     4000    0.002    0.000    0.003    0.000 <string>:1(__new__)
        1    0.0