Skip to content

Commit

Permalink
raise error on nonscalar inputs to STRtree::nearest_all (#313)
Browse files Browse the repository at this point in the history
  • Loading branch information
brendan-ward committed Mar 17, 2021
1 parent 086cfc5 commit 3979224
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
8 changes: 6 additions & 2 deletions pygeos/strtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,12 @@ def nearest_all(self, geometry, max_distance=None, return_distance=False):
if geometry.ndim == 0:
geometry = np.expand_dims(geometry, 0)

if max_distance is not None and max_distance <= 0:
raise ValueError("max_distance must be greater than 0")
if max_distance is not None:
if not np.isscalar(max_distance):
raise ValueError("max_distance parameter only accepts scalar values")

if max_distance <= 0:
raise ValueError("max_distance must be greater than 0")

# a distance of 0 means no max_distance is used
max_distance = max_distance or 0
Expand Down
11 changes: 10 additions & 1 deletion pygeos/test/test_strtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,13 +1473,22 @@ def test_nearest_all_max_distance(tree, geometry, max_distance, expected):
@pytest.mark.skipif(pygeos.geos_version < (3, 6, 0), reason="GEOS < 3.6")
@pytest.mark.parametrize(
"geometry,max_distance",
[(pygeos.points(0.5, 0.5), 0), (pygeos.points(0.5, 0.5), -1)],
[
(pygeos.points(0.5, 0.5), 0),
(pygeos.points(0.5, 0.5), -1),
],
)
def test_nearest_all_invalid_max_distance(tree, geometry, max_distance):
with pytest.raises(ValueError, match="max_distance must be greater than 0"):
tree.nearest_all(geometry, max_distance=max_distance)


@pytest.mark.skipif(pygeos.geos_version < (3, 6, 0), reason="GEOS < 3.6")
def test_nearest_all_nonscalar_max_distance(tree):
with pytest.raises(ValueError, match="parameter only accepts scalar values"):
tree.nearest_all(pygeos.points(0.5, 0.5), max_distance=[1])


@pytest.mark.skipif(pygeos.geos_version < (3, 6, 0), reason="GEOS < 3.6")
@pytest.mark.parametrize(
"geometry,expected",
Expand Down

0 comments on commit 3979224

Please sign in to comment.