diff --git a/CHANGES.txt b/CHANGES.txt index fe01b7f01..8b12c78fd 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -10,6 +10,8 @@ Shapely 1.8 will support only Python versions >= 3.6. New features: +- The STRtree nearest*() methods now take an optional argument that + specifies exclusion of the input geometry from results (#1115). - A GeometryTypeError has been added to shapely.errors and is consistently raised instead of TypeError or ValueError as in version 1.7. For backwards compatibility, the new exception will derive from TypeError and Value error diff --git a/shapely/strtree.py b/shapely/strtree.py index 75b0f2ce7..ac1008fba 100644 --- a/shapely/strtree.py +++ b/shapely/strtree.py @@ -21,6 +21,7 @@ import ctypes import logging from typing import Any, ItemsView, Iterable, Iterator, Sequence, Tuple, Union +import sys from warnings import warn from shapely.errors import ShapelyDeprecationWarning @@ -249,7 +250,9 @@ def query(self, geom: BaseGeometry) -> Sequence[BaseGeometry]: """ return self.query_geoms(geom) - def nearest_item(self, geom: BaseGeometry) -> Union[Any, None]: + def nearest_item( + self, geom: BaseGeometry, exclusive: bool = False + ) -> Union[Any, None]: """Query the tree for the node nearest to geom and get the item stored in the node. @@ -259,6 +262,9 @@ def nearest_item(self, geom: BaseGeometry) -> Union[Any, None]: ---------- geom : geometry object The query geometry. + exclusive : bool, optional + Whether to exclude the item corresponding to the given geom + from results or not. Default: False. Returns ------- @@ -289,10 +295,14 @@ def nearest_item(self, geom: BaseGeometry) -> Union[Any, None]: def callback(item1, item2, distance, userdata): try: + callback_userdata = ctypes.cast(userdata, ctypes.py_object).value idx = ctypes.cast(item1, ctypes.py_object).value geom2 = ctypes.cast(item2, ctypes.py_object).value dist = ctypes.cast(distance, ctypes.POINTER(ctypes.c_double)) - lgeos.GEOSDistance(self._rev[idx]._geom, geom2._geom, dist) + if callback_userdata["exclusive"] and self._rev[idx].equals(geom2): + dist[0] = sys.float_info.max + else: + lgeos.GEOSDistance(self._rev[idx]._geom, geom2._geom, dist) return 1 except Exception: log.exception("Caught exception") @@ -303,12 +313,14 @@ def callback(item1, item2, distance, userdata): ctypes.py_object(geom), envelope._geom, lgeos.GEOSDistanceCallback(callback), - None, + ctypes.py_object({"exclusive": exclusive}), ) result = ctypes.cast(item, ctypes.py_object).value return result - def nearest_geom(self, geom: BaseGeometry) -> Union[BaseGeometry, None]: + def nearest_geom( + self, geom: BaseGeometry, exclusive: bool = False + ) -> Union[BaseGeometry, None]: """Query the tree for the node nearest to geom and get the geometry corresponding to the item stored in the node. @@ -316,6 +328,9 @@ def nearest_geom(self, geom: BaseGeometry) -> Union[BaseGeometry, None]: ---------- geom : geometry object The query geometry. + exclusive : bool, optional + Whether to exclude the given geom from results or not. + Default: False. Returns ------- @@ -325,13 +340,15 @@ def nearest_geom(self, geom: BaseGeometry) -> Union[BaseGeometry, None]: version 2.0. """ - item = self.nearest_item(geom) + item = self.nearest_item(geom, exclusive=exclusive) if item is None: return None else: return self._rev[item] - def nearest(self, geom: BaseGeometry) -> Union[BaseGeometry, None]: + def nearest( + self, geom: BaseGeometry, exclusive: bool = False + ) -> Union[BaseGeometry, None]: """Query the tree for the node nearest to geom and get the geometry corresponding to the item stored in the node. @@ -342,6 +359,9 @@ def nearest(self, geom: BaseGeometry) -> Union[BaseGeometry, None]: ---------- geom : geometry object The query geometry. + exclusive : bool, optional + Whether to exclude the given geom from results or not. + Default: False. Returns ------- @@ -351,4 +371,4 @@ def nearest(self, geom: BaseGeometry) -> Union[BaseGeometry, None]: version 2.0. """ - return self.nearest_geom(geom) + return self.nearest_geom(geom, exclusive=exclusive) diff --git a/tests/test_strtree.py b/tests/test_strtree.py index 04d574223..55b61c43d 100644 --- a/tests/test_strtree.py +++ b/tests/test_strtree.py @@ -177,3 +177,22 @@ def test_nearest_items(geoms, items): with pytest.warns(ShapelyDeprecationWarning): tree = STRtree(geoms, items) assert tree.nearest_item(None) is None + + +@pytest.mark.skipif(geos_version < (3, 6, 0), reason="GEOS 3.6.0 required") +@pytest.mark.parametrize( + "geoms", + [ + [ + Point(0, 0.5), + Polygon([(1, 0), (2, 0), (2, 1), (1, 1)]), + Polygon([(0, 2), (1, 2), (1, 3), (0, 3)]), + ] + ], +) +@pytest.mark.parametrize("items", [list(range(1, 4)), list("abc")]) +@pytest.mark.parametrize("query_geom", [Point(0, 0.5)]) +def test_nearest_item_exclusive(geoms, items, query_geom): + with pytest.warns(ShapelyDeprecationWarning): + tree = STRtree(geoms, items) + assert tree.nearest_item(query_geom, exclusive=True) != items[0]