Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP

Loading…

BUG: signal: lfilter would segfault on object arrays (ticket #1452) #112

Closed
wants to merge 1 commit into from

2 participants

@WarrenWeckesser
Collaborator

See http://projects.scipy.org/scipy/ticket/1452

The fix is to verify that all the objects in the object arrays are numbers. If any are not, a ValueError exception is raised.

@WarrenWeckesser
Collaborator

I just recently discovered matplotlib's is_numlike function, which simply attempts to add 1 to an object, and returns True if that does not raise an exception. Something like that, implemented in the python lfilter function before calling the C code, might be a simpler alternative to the custom check in the C code that I wrote.

@teoliphant
Owner

I don't think the right fix is to add this check. Perhaps there is no other way. I usually prefer to see the error raised rather than pre-check the inputs. I will look at the details again. Likely there is an error condition that is not being correctly handled in the C-code.

@WarrenWeckesser
Collaborator

I usually prefer to see the error raised rather than pre-check the inputs.

If by that you mean the ValueError should occur "naturally" in the code while it is computing the result, then I agree. Validating the input uses CPU time, penalizing the careful user of the routine. My immediate goal was preventing the segfault. If this can be done without an initial validation step for object arrays, I'm all for it.

@teoliphant
Owner

Yes. I found the problem. There were a couple of issues. The main issues was that OBJECT_filt was not handling errors correctly. I will submit a pull request with your tests to replace this one.

@WarrenWeckesser
Collaborator

Replaced by #131

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
This page is out of date. Refresh to see the latest.
View
7 scipy/signal/SConscript
@@ -9,9 +9,10 @@ env = GetNumpyEnvironment(ARGUMENTS)
src = env.FromCTemplate("lfilter.c.src")
src += env.FromCTemplate("correlate_nd.c.src")
env.NumpyPythonExtension('sigtools',
- source = src + ['sigtoolsmodule.c',\
- 'firfilter.c', \
- 'medianfilter.c'])
+ source = src + ['sigtoolsmodule.c',
+ 'firfilter.c',
+ 'medianfilter.c',
+ '_array_util.c'])
env.NumpyPythonExtension('spectral', source='spectral.c')
View
45 scipy/signal/_array_util.c
@@ -0,0 +1,45 @@
+#include <Python.h>
+#define PY_ARRAY_UNIQUE_SYMBOL _scipy_signal_ARRAY_API
+#define NO_IMPORT_ARRAY
+#include <numpy/noprefix.h>
+
+/*
+ * int check_all_numbers(PyObject *x, char *name)
+ *
+ * Determine if all the objects in the object array x are numbers.
+ * x must be a pointer to an ndarray of object dtype.
+ * name must be a null-terminated string that holds the name of the
+ *. variable being checked. It is used in the exception message if any
+ * non-numeric objects are found in x.
+ *
+ * Return values:
+ * 0: all objects pass PyNumber_Check()
+ * 1: at least one object does not satisfy PyNumber_Check()
+ * -1: internal error (PyArray_IterNew(x) failed).
+ *
+ */
+
+int check_all_numbers(PyObject *x, const char *name)
+{
+ PyArrayIterObject *iter;
+ int result;
+
+ iter = (PyArrayIterObject *) PyArray_IterNew(x);
+ if (iter == NULL) {
+ PyErr_SetString(PyExc_RuntimeError, "internal error, possibly out of memory.");
+ return -1;
+ }
+ result = 0;
+ while (iter->index < iter->size) {
+ if (!PyNumber_Check(*((PyObject **) (iter->dataptr)))) {
+ PyErr_Format(PyExc_ValueError,
+ "%s is an object array containing objects that are not numbers.",
+ name);
+ result = 1;
+ break;
+ }
+ PyArray_ITER_NEXT(iter);
+ }
+ Py_DECREF(iter);
+ return result;
+}
View
1  scipy/signal/bento.info
@@ -8,6 +8,7 @@ Library:
sigtoolsmodule.c,
firfilter.c,
medianfilter.c
+ _array_util.c
Extension: spectral
Sources: spectral.c
Extension: spline
View
17 scipy/signal/lfilter.c.src
@@ -129,6 +129,22 @@ scipy_signal_sigtools_linear_filter(PyObject * NPY_UNUSED(dummy), PyObject * arg
typenum);
}
+ if (ara->descr->type_num == NPY_OBJECT) {
+ if (check_all_numbers(ara, "a") != 0) {
+ goto fail;
+ }
+ }
+ if (arb->descr->type_num == NPY_OBJECT) {
+ if (check_all_numbers(arb, "b") != 0) {
+ goto fail;
+ }
+ }
+ if (arX->descr->type_num == NPY_OBJECT) {
+ if (check_all_numbers(arX, "x") != 0) {
+ goto fail;
+ }
+ }
+
if (arX->descr->type_num < 256) {
basic_filter = BasicFilterFunctions[(int) (arX->descr->type_num)];
}
@@ -516,6 +532,7 @@ static void C@NAME@_filt(char *b, char *a, char *x, char *y, char *Z,
}
/**end repeat**/
+
static void OBJECT_filt(char *b, char *a, char *x, char *y, char *Z,
intp len_b, uintp len_x, intp stride_X,
intp stride_Y)
View
2  scipy/signal/setup.py
@@ -11,7 +11,7 @@ def configuration(parent_package='', top_path=None):
config.add_extension('sigtools',
sources=['sigtoolsmodule.c', 'firfilter.c',
'medianfilter.c', 'lfilter.c.src',
- 'correlate_nd.c.src'],
+ 'correlate_nd.c.src', '_array_util.c'],
depends=['sigtools.h'],
include_dirs=['.']
)
View
13 scipy/signal/tests/test_signaltools.py
@@ -379,6 +379,19 @@ class TestLinearFilterComplexxxiExtended28(_TestLinearFilter):
class TestLinearFilterDecimal(_TestLinearFilter):
dt = np.dtype(Decimal)
+class TestLinearFilterObject(_TestLinearFilter):
+ dt = np.object_
+
+
+def test_lfilter_bad_object():
+ """lfilter: object arrays with non-numeric objects raise ValueError.
+
+ Regression test for ticket #1452.
+ """
+ assert_raises(ValueError, lfilter, [1.0], [1.0], [1.0, None, 2.0])
+ assert_raises(ValueError, lfilter, [1.0], [None], [1.0, 2.0, 3.0])
+ assert_raises(ValueError, lfilter, [None], [1.0], [1.0, 2.0, 3.0])
+
class _TestCorrelateReal(TestCase):
Something went wrong with that request. Please try again.