Skip to content

Commit

Permalink
Merge branch 'betterenvi-speed-up'
Browse files Browse the repository at this point in the history
  • Loading branch information
stefankoegl committed Oct 19, 2017
2 parents 6353437 + 05e1274 commit a149695
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
7 changes: 5 additions & 2 deletions kdtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,9 @@ def search_knn(self, point, k, dist=None):
The result is an ordered list of (node, distance) tuples.
"""

if k < 1:
raise ValueError("k must be greater than 0.")

if dist is None:
get_dist = lambda n: n.dist(point)
else:
Expand Down Expand Up @@ -439,7 +442,7 @@ def _search_node(self, point, k, results, get_dist, counter):
# so, replace it.
item = (-nodeDist, next(counter), self)
if len(results) >= k:
if -nodeDist > min(results)[0]:
if -nodeDist > results[0][0]:
heapq.heapreplace(results, item)
else:
heapq.heappush(results, item)
Expand All @@ -460,7 +463,7 @@ def _search_node(self, point, k, results, get_dist, counter):

# Search the other side of the splitting plane if it may contain
# points closer than the farthest point in the current results.
if plane_dist2 > min(results)[0] or len(results) < k:
if -plane_dist2 > results[0][0] or len(results) < k:
if point[self.axis] < self.data[self.axis]:
if self.right is not None:
self.right._search_node(point, k, results, get_dist,
Expand Down
1 change: 1 addition & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,5 +366,6 @@ def random_points(dimensions=3, minval=0, maxval=100):
while True:
yield random_point(dimensions, minval, maxval)


if __name__ == "__main__":
unittest.main()

0 comments on commit a149695

Please sign in to comment.