Skip to content

Commit

Permalink
[Done] Reinstate may_segfault for from_geojson (#418)
Browse files Browse the repository at this point in the history
  • Loading branch information
caspervdw committed Nov 5, 2021
1 parent 9dcb996 commit 67fe0ad
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 12 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ Version 0.12 (unreleased)

**Bug fixes**

* ...
* Protect ``pygeos.from_geojson`` against segfaults by running the function in a
subprocess (GEOS 3.10.0 only) (#418).


**Acknowledgments**

Expand Down
16 changes: 14 additions & 2 deletions pygeos/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import numpy as np

from . import Geometry # noqa
from . import geos_capi_version_string, lib
from . import geos_capi_version_string, geos_version_string, lib
from .decorators import requires_geos
from .enum import ParamEnum
from .may_segfault import may_segfault

__all__ = [
"from_geojson",
Expand Down Expand Up @@ -415,6 +416,11 @@ def from_geojson(geometry, on_invalid="raise", **kwargs):
(with type GEOMETRYCOLLECTION). This may be unpacked using the ``pygeos.get_parts``.
Properties are not read.
.. note::
For GEOS 3.10.0, this function is executed in a subprocess. This is because invalid
GeoJSON input may result in a crash. For GEOS 3.10.1 the issue is expected to be fixed.
The GeoJSON format is defined in `RFC 7946 <https://geojson.org/>`__.
The following are currently unsupported:
Expand Down Expand Up @@ -457,7 +463,13 @@ def from_geojson(geometry, on_invalid="raise", **kwargs):
# of array elements)
geometry = np.asarray(geometry, dtype=object)

return lib.from_geojson(geometry, invalid_handler, **kwargs)
# GEOS 3.10.0 may segfault on invalid GeoJSON input. This bug is currently
# solved in main branch, expected fix in (3, 10, 1)
if geos_version_string == "3.10.0": # so not on dev versions!
_from_geojson = may_segfault(lib.from_geojson)
else:
_from_geojson = lib.from_geojson
return _from_geojson(geometry, invalid_handler, **kwargs)


def from_shapely(geometry, **kwargs):
Expand Down
76 changes: 76 additions & 0 deletions pygeos/may_segfault.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import multiprocessing
import warnings


class ReturningProcess(multiprocessing.Process):
"""A Process with an added Pipe for getting the return_value or exception."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._pconn, self._cconn = multiprocessing.Pipe()
self._result = {}

def run(self):
if not self._target:
return
try:
with warnings.catch_warnings(record=True) as w:
return_value = self._target(*self._args, **self._kwargs)
self._cconn.send({"return_value": return_value, "warnings": w})
except Exception as e:
self._cconn.send({"exception": e, "warnings": w})

@property
def result(self):
if not self._result and self._pconn.poll():
self._result = self._pconn.recv()
return self._result

@property
def exception(self):
return self.result.get("exception")

@property
def warnings(self):
return self.result.get("warnings", [])

@property
def return_value(self):
return self.result.get("return_value")


def may_segfault(func):
"""The wrapped function will be called in another process.
If the execution crashes with a segfault or sigabort, an OSError
will be raised.
Note: do not use this to decorate a function at module level, because this
will render the function un-Picklable so that multiprocessing fails on OSX/Windows.
Instead, use it like this:
>>> def some_unstable_func():
... ...
>>> some_func = may_segfault(some_unstable_func)
"""

def wrapper(*args, **kwargs):
process = ReturningProcess(target=func, args=args, kwargs=kwargs)
process.start()
process.join()
for w in process.warnings:
warnings.warn_explicit(
w.message,
w.category,
w.filename,
w.lineno,
)
if process.exception:
raise process.exception
elif process.exitcode != 0:
raise OSError(f"GEOS crashed with exit code {process.exitcode}.")
else:
return process.return_value

return wrapper
12 changes: 3 additions & 9 deletions pygeos/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,12 +775,12 @@ def test_from_geojson_exceptions():
with pytest.raises(pygeos.GEOSException, match="type must be array, but is null"):
pygeos.from_geojson('{"type": "LineString", "coordinates": null}')

# Note: The two below tests make GEOS 3.10.0 crash if it is compiled in Debug mode
# Note: The two below tests may make GEOS 3.10.0 crash
# https://trac.osgeo.org/geos/ticket/1138
with pytest.raises(pygeos.GEOSException, match="ParseException"):
with pytest.raises((pygeos.GEOSException, OSError)):
pygeos.from_geojson('{"geometry": null, "properties": []}')

with pytest.raises(pygeos.GEOSException, match="ParseException"):
with pytest.raises((pygeos.GEOSException, OSError)):
pygeos.from_geojson('{"no": "geojson"}')


Expand All @@ -789,18 +789,12 @@ def test_from_geojson_warn_on_invalid():
with pytest.warns(Warning, match="Invalid GeoJSON"):
assert pygeos.from_geojson("", on_invalid="warn") is None

with pytest.warns(Warning, match="Invalid GeoJSON"):
assert pygeos.from_geojson('{"no": "geojson"}', on_invalid="warn") is None


@pytest.mark.skipif(pygeos.geos_version < (3, 10, 0), reason="GEOS < 3.10")
def test_from_geojson_ignore_on_invalid():
with pytest.warns(None):
assert pygeos.from_geojson("", on_invalid="ignore") is None

with pytest.warns(None):
assert pygeos.from_geojson('{"no": "geojson"}', on_invalid="ignore") is None


@pytest.mark.skipif(pygeos.geos_version < (3, 10, 0), reason="GEOS < 3.10")
def test_from_geojson_on_invalid_unsupported_option():
Expand Down
44 changes: 44 additions & 0 deletions pygeos/tests/test_misc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import ctypes
import os
import sys
import warnings
from itertools import chain
from string import ascii_letters, digits
from unittest import mock
Expand All @@ -9,6 +11,7 @@

import pygeos
from pygeos.decorators import multithreading_enabled, requires_geos
from pygeos.may_segfault import may_segfault


@pytest.fixture
Expand Down Expand Up @@ -184,3 +187,44 @@ def test_multithreading_enabled_preserves_flag():
def test_multithreading_enabled_ok(args, kwargs):
result = set_first_element(42, *args, **kwargs)
assert result[0] == 42


def my_unstable_func(event=None):
if event == "segfault":
ctypes.string_at(0) # segfault
elif event == "exit":
exit(1)
elif event == "raise":
raise ValueError("This is a test")
elif event == "warn":
warnings.warn("This is a test", RuntimeWarning)
elif event == "return":
return "This is a test"


def test_may_segfault():
if os.name == "nt":
match = "access violation"
else:
match = "GEOS crashed"
with pytest.raises(OSError, match=match):
may_segfault(my_unstable_func)("segfault")


def test_may_segfault_exit():
with pytest.raises(OSError, match="GEOS crashed with exit code 1."):
may_segfault(my_unstable_func)("exit")


def test_may_segfault_raises():
with pytest.raises(ValueError, match="This is a test"):
may_segfault(my_unstable_func)("raise")


def test_may_segfault_returns():
assert may_segfault(my_unstable_func)("return") == "This is a test"


def test_may_segfault_warns():
with pytest.warns(RuntimeWarning, match="This is a test"):
may_segfault(my_unstable_func)("warn")

0 comments on commit 67fe0ad

Please sign in to comment.