Skip to content

Commit

Permalink
Merge 8dd4695 into 2c28166
Browse files Browse the repository at this point in the history
  • Loading branch information
grapemix committed Oct 16, 2017
2 parents 2c28166 + 8dd4695 commit d7675aa
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 16 deletions.
40 changes: 25 additions & 15 deletions kdtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,27 @@ def search_nn(self, point, dist=None):
return next(iter(self.search_knn(point, 1, dist)), None)


def _search_nn_dist(self, point, dist, results, get_dist):
if not self:
return

nodeDist = get_dist(self)

if nodeDist < dist:
results.append(self.data)

# get the splitting plane
split_plane = self.data[self.axis]

# Search the side of the splitting plane that the point is in
if point[self.axis] <= split_plane + dist:
if self.left is not None:
self.left._search_nn_dist(point, dist, results, get_dist)
if point[self.axis] >= split_plane - dist:
if self.right is not None:
self.right._search_nn_dist(point, dist, results, get_dist)


@require_axis
def search_nn_dist(self, point, distance, best=None):
"""
Expand All @@ -499,22 +520,11 @@ def search_nn_dist(self, point, distance, best=None):
nodes to the point within the distance will be returned.
"""

if best is None:
best = []

# consider the current node
if self.dist(point) < distance:
best.append(self)

# sort the children, nearer one first (is this really necessairy?)
children = sorted(self.children, key=lambda c_p1: c_p1[0].dist(point))

for child, p in children:
# check if child node needs to be recursed
if self.axis_dist(point, self.axis) < math.pow(distance, 2):
child.search_nn_dist(point, distance, best)
results = []
get_dist = lambda n: n.dist(point)

return best
self._search_nn_dist(point, distance, results, get_dist)
return results


@require_axis
Expand Down
10 changes: 9 additions & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,12 @@ def test_search_nn_dist(self):
tree = kdtree.create(points)
nn = tree.search_nn_dist((5,5), 2.5)

self.assertEqual(len(nn), 4)
self.assertEqual(len(nn), 9)
self.assertTrue( (4,4) in nn)
self.assertTrue( (4,5) in nn)
self.assertTrue( (4,6) in nn)
self.assertTrue( (5,4) in nn)
self.assertTrue( (6,4) in nn)
self.assertTrue( (6,6) in nn)
self.assertTrue( (5,5) in nn)
self.assertTrue( (5,6) in nn)
Expand Down Expand Up @@ -313,3 +318,6 @@ def random_point(dimensions=3, minval=0, maxval=100):
def random_points(dimensions=3, minval=0, maxval=100):
while True:
yield random_point(dimensions, minval, maxval)

if __name__ == "__main__":
unittest.main()

0 comments on commit d7675aa

Please sign in to comment.