Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions kdtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,13 +387,17 @@ def axis_dist(self, point, axis):
return math.pow(self.data[axis] - point[axis], 2)


def dist(self, point):
def dist(self, point, axis=None):
"""
Squared distance between the current Node
and the given point
"""
r = range(self.dimensions)
return sum([self.axis_dist(point, i) for i in r])
if axis is None:
axes = range(self.dimensions)
else:
axes = [axis]

return sum([self.axis_dist(point, i) for i in axes])


def search_knn(self, point, k, dist=None):
Expand All @@ -406,18 +410,21 @@ def search_knn(self, point, k, dist=None):
distances.

dist is a distance function, expecting two points and returning a
distance value. Distance values can be any comparable type.
distance value. dist should expect an optional `axis` parameter. If
given, the distance on the specified axis should be calculated.
Distance values can be any comparable type.

The result is an ordered list of (node, distance) tuples.
"""

if k < 1:
raise ValueError("k must be greater than 0.")

if dist is None:
get_dist = lambda n: n.dist(point)
else:
get_dist = lambda n: dist(n.data, point)
def get_dist(n, axis=None):
if dist is None:
return n.dist(point, axis=None)
else:
return dist(n.data, point, axis=None)

results = []

Expand Down Expand Up @@ -446,12 +453,10 @@ def _search_node(self, point, k, results, get_dist, counter):
heapq.heapreplace(results, item)
else:
heapq.heappush(results, item)
# get the splitting plane

split_plane = self.data[self.axis]
# get the squared distance between the point and the splitting plane
# (squared since all distances are squared).
plane_dist = point[self.axis] - split_plane
plane_dist2 = plane_dist * plane_dist
pt = KDNode(point, dimensions=self.dimensions)
plane_dist = get_dist(pt, axis=self.axis)

# Search the side of the splitting plane that the point is in
if point[self.axis] < split_plane:
Expand All @@ -463,7 +468,7 @@ def _search_node(self, point, k, results, get_dist, counter):

# Search the other side of the splitting plane if it may contain
# points closer than the farthest point in the current results.
if -plane_dist2 > results[0][0] or len(results) < k:
if -plane_dist > results[0][0] or len(results) < k:
if point[self.axis] < self.data[self.axis]:
if self.right is not None:
self.right._search_node(point, k, results, get_dist,
Expand Down