Skip to content

Commit

Permalink
ENH: Consolidate validation of enum values and raise ValueError on in…
Browse files Browse the repository at this point in the history
…valid values (#263)
  • Loading branch information
brendan-ward committed Mar 3, 2021
1 parent 1b89ff7 commit b4fade9
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 39 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ Version 0.10 (unreleased)

* STRtree default leaf size is now 10 instead of 5, for somewhat better performance
under normal conditions (#286)
* Removes ``VALID_PREDICATES`` set from ``pygeos.strtree`` package; these can be constructed
in downstream libraries using the ``pygeos.strtree.BinaryPredicate`` enum.

**Added GEOS functions**

Expand Down
28 changes: 14 additions & 14 deletions pygeos/constructive.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from enum import IntEnum
import numpy as np
from . import Geometry # NOQA
from . import lib
from .decorators import requires_geos, multithreading_enabled
from .enum import ParamEnum


__all__ = [
Expand Down Expand Up @@ -30,16 +30,16 @@
]


class BufferCapStyles(IntEnum):
ROUND = 1
FLAT = 2
SQUARE = 3
class BufferCapStyles(ParamEnum):
round = 1
flat = 2
square = 3


class BufferJoinStyles(IntEnum):
ROUND = 1
MITRE = 2
BEVEL = 3
class BufferJoinStyles(ParamEnum):
round = 1
mitre = 2
bevel = 3


@multithreading_enabled
Expand Down Expand Up @@ -105,7 +105,7 @@ def buffer(
circular line endings (see ``quadsegs``). Both 'square' and 'flat'
result in rectangular line endings, only 'flat' will end at the
original vertex, while 'square' involves adding the buffer width.
join_style : {'round', 'bevel', 'sharp'}
join_style : {'round', 'bevel', 'mitre'}
Specifies the shape of buffered line midpoints. 'round' results in
rounded shapes. 'bevel' results in a beveled edge that touches the
original vertex. 'mitre' results in a single vertex that is beveled
Expand Down Expand Up @@ -149,9 +149,9 @@ def buffer(
True
"""
if isinstance(cap_style, str):
cap_style = BufferCapStyles[cap_style.upper()].value
cap_style = BufferCapStyles.get_value(cap_style)
if isinstance(join_style, str):
join_style = BufferJoinStyles[join_style.upper()].value
join_style = BufferJoinStyles.get_value(join_style)
if not np.isscalar(quadsegs):
raise TypeError("quadsegs only accepts scalar values")
if not np.isscalar(cap_style):
Expand Down Expand Up @@ -196,7 +196,7 @@ def offset_curve(
quadsegs : int
Specifies the number of linear segments in a quarter circle in the
approximation of circular arcs.
join_style : {'round', 'bevel', 'sharp'}
join_style : {'round', 'bevel', 'mitre'}
Specifies the shape of outside corners. 'round' results in
rounded shapes. 'bevel' results in a beveled edge that touches the
original vertex. 'mitre' results in a single vertex that is beveled
Expand All @@ -214,7 +214,7 @@ def offset_curve(
<pygeos.Geometry LINESTRING (2 2, 2 0)>
"""
if isinstance(join_style, str):
join_style = BufferJoinStyles[join_style.upper()].value
join_style = BufferJoinStyles.get_value(join_style)
if not np.isscalar(quadsegs):
raise TypeError("quadsegs only accepts scalar values")
if not np.isscalar(join_style):
Expand Down
23 changes: 23 additions & 0 deletions pygeos/enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from enum import IntEnum


class ParamEnum(IntEnum):
"""Wraps IntEnum to provide validation of a requested item.
Intended for enums used for function parameters.
Use enum.get_value(item) for this behavior instead of builtin enum[item].
"""

@classmethod
def get_value(cls, item):
"""Validate incoming item and raise a ValueError with valid options if not present."""
try:
return cls[item].value
except KeyError:
valid_options = {e.name for e in cls}
raise ValueError(
"'{}' is not a valid option, must be one of '{}'".format(
item, "', '".join(valid_options)
)
)
26 changes: 4 additions & 22 deletions pygeos/strtree.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from enum import IntEnum
import numpy as np
from . import lib
from .decorators import requires_geos

from .enum import ParamEnum

__all__ = ["STRtree"]


class BinaryPredicate(IntEnum):
class BinaryPredicate(ParamEnum):
"""The enumeration of GEOS binary predicates types"""

intersects = 1
Expand All @@ -21,9 +20,6 @@ class BinaryPredicate(IntEnum):
contains_properly = 9


VALID_PREDICATES = {e.name for e in BinaryPredicate}


class STRtree:
"""A query-only R-tree created using the Sort-Tile-Recursive (STR)
algorithm.
Expand Down Expand Up @@ -105,14 +101,7 @@ def query(self, geometry, predicate=None):
predicate = 0

else:
if not predicate in VALID_PREDICATES:
raise ValueError(
"Predicate {} is not valid; must be one of {}".format(
predicate, ", ".join(VALID_PREDICATES)
)
)

predicate = BinaryPredicate[predicate].value
predicate = BinaryPredicate.get_value(predicate)

return self._tree.query(geometry, predicate)

Expand Down Expand Up @@ -179,14 +168,7 @@ def query_bulk(self, geometry, predicate=None):
predicate = 0

else:
if not predicate in VALID_PREDICATES:
raise ValueError(
"Predicate {} is not valid; must be one of {}".format(
predicate, ", ".join(VALID_PREDICATES)
)
)

predicate = BinaryPredicate[predicate].value
predicate = BinaryPredicate.get_value(predicate)

return self._tree.query_bulk(geometry, predicate)

Expand Down
16 changes: 13 additions & 3 deletions pygeos/test/test_constructive.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ def test_float_arg_nan(geometry, func):
assert actual is None


def test_buffer_cap_style_invalid():
with pytest.raises(ValueError, match="'invalid' is not a valid option"):
pygeos.buffer(point, 1, cap_style="invalid")


def test_buffer_join_style_invalid():
with pytest.raises(ValueError, match="'invalid' is not a valid option"):
pygeos.buffer(point, 1, join_style="invalid")


def test_snap_none():
actual = pygeos.snap(None, point, tolerance=1.0)
assert actual is None
Expand Down Expand Up @@ -233,9 +243,9 @@ def test_offset_curve_non_scalar_kwargs():
pygeos.offset_curve([line_string, line_string], 1, mitre_limit=[5.0, 6.0])


def test_offset_curve_join_style():
with pytest.raises(KeyError):
pygeos.offset_curve(line_string, 1.0, join_style="nonsense")
def test_offset_curve_join_style_invalid():
with pytest.raises(ValueError, match="'invalid' is not a valid option"):
pygeos.offset_curve(line_string, 1.0, join_style="invalid")


@pytest.mark.skipif(pygeos.geos_version < (3, 7, 0), reason="GEOS < 3.7")
Expand Down

0 comments on commit b4fade9

Please sign in to comment.