Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

MAINT: handle empty arrays better in fftconvolve. Also some PEP8 chan…

…ges.
  • Loading branch information...
commit 1353ac89aec85a98160b40fb3c2e143305f1ff50 1 parent 1ba07c1
@rgommers rgommers authored
Showing with 54 additions and 57 deletions.
  1. +8 −6 scipy/signal/signaltools.py
  2. +46 −51 scipy/signal/tests/test_signaltools.py
View
14 scipy/signal/signaltools.py
@@ -149,11 +149,11 @@ def _centered(arr, newsize):
def fftconvolve(in1, in2, mode="full"):
"""Convolve two N-dimensional arrays using FFT.
- Convolve `in1` and `in2` using the fast Fourier transform method, with
+ Convolve `in1` and `in2` using the fast Fourier transform method, with
the output size determined by the `mode` argument.
- This is generally much faster than `convolve` for large arrays (n > ~500),
- but can be slower when only a few output values are needed, and can only
+ This is generally much faster than `convolve` for large arrays (n > ~500),
+ but can be slower when only a few output values are needed, and can only
output float arrays (int or object array inputs will be cast to float).
Parameters
@@ -185,11 +185,13 @@ def fftconvolve(in1, in2, mode="full"):
in1 = asarray(in1)
in2 = asarray(in2)
- if rank(in1) == rank(in2) == 0:
+ if rank(in1) == rank(in2) == 0: # scalar inputs
return in1 * in2
elif not in1.ndim == in2.ndim:
raise ValueError("in1 and in2 should have the same rank")
-
+ elif in1.size == 0 or in2.size == 0: # empty arrays
+ return array([])
+
s1 = array(in1.shape)
s2 = array(in2.shape)
complex_result = (np.issubdtype(in1.dtype, np.complex) or
@@ -203,7 +205,7 @@ def fftconvolve(in1, in2, mode="full"):
fsize = 2 ** np.ceil(np.log2(size)).astype(int)
fslice = tuple([slice(0, int(sz)) for sz in size])
if not complex_result:
- ret = irfftn(rfftn(in1, fsize) *
+ ret = irfftn(rfftn(in1, fsize) *
rfftn(in2, fsize), fsize)[fslice].copy()
ret = ret.real
else:
View
97 scipy/signal/tests/test_signaltools.py
@@ -44,9 +44,9 @@ def test_2d_arrays(self):
a = [[1,2,3],[3,4,5]]
b = [[2,3,4],[4,5,6]]
c = convolve(a,b)
- d = array( [[2 ,7 ,16,17,12],\
- [10,30,62,58,38],\
- [12,31,58,49,30]])
+ d = array([[2 ,7 ,16,17,12],
+ [10,30,62,58,38],
+ [12,31,58,49,30]])
assert_array_equal(c,d)
def test_valid_mode(self):
@@ -77,35 +77,35 @@ class _TestConvolve2d(TestCase):
def test_2d_arrays(self):
a = [[1,2,3],[3,4,5]]
b = [[2,3,4],[4,5,6]]
- d = array( [[2 ,7 ,16,17,12],\
- [10,30,62,58,38],\
- [12,31,58,49,30]])
- e = convolve2d(a,b)
- assert_array_equal(e,d)
+ d = array([[2 ,7 ,16,17,12],
+ [10,30,62,58,38],
+ [12,31,58,49,30]])
+ e = convolve2d(a, b)
+ assert_array_equal(e, d)
def test_valid_mode(self):
- e = [[2,3,4,5,6,7,8],[4,5,6,7,8,9,10]]
- f = [[1,2,3],[3,4,5]]
- g = convolve2d(e,f,'valid')
+ e = [[2,3,4,5,6,7,8], [4,5,6,7,8,9,10]]
+ f = [[1,2,3], [3,4,5]]
+ g = convolve2d(e, f, 'valid')
h = array([[62,80,98,116,134]])
- assert_array_equal(g,h)
+ assert_array_equal(g, h)
def test_fillvalue(self):
a = [[1,2,3],[3,4,5]]
b = [[2,3,4],[4,5,6]]
fillval = 1
c = convolve2d(a,b,'full','fill',fillval)
- d = array([[24,26,31,34,32],\
- [28,40,62,64,52],\
+ d = array([[24,26,31,34,32],
+ [28,40,62,64,52],
[32,46,67,62,48]])
- assert_array_equal(c,d)
+ assert_array_equal(c, d)
def test_wrap_boundary(self):
a = [[1,2,3],[3,4,5]]
b = [[2,3,4],[4,5,6]]
c = convolve2d(a,b,'full','wrap')
- d = array([[80,80,74,80,80],\
- [68,68,62,68,68],\
+ d = array([[80,80,74,80,80],
+ [68,68,62,68,68],
[80,80,74,80,80]])
assert_array_equal(c,d)
@@ -113,8 +113,8 @@ def test_sym_boundary(self):
a = [[1,2,3],[3,4,5]]
b = [[2,3,4],[4,5,6]]
c = convolve2d(a,b,'full','symm')
- d = array([[34,30,44, 62, 66],\
- [52,48,62, 80, 84],\
+ d = array([[34,30,44, 62, 66],
+ [52,48,62, 80, 84],
[82,78,92,110,114]])
assert_array_equal(c,d)
@@ -124,7 +124,7 @@ def test_same_mode(self):
e = [[1,2,3],[3,4,5]]
f = [[2,3,4,5,6,7,8],[4,5,6,7,8,9,10]]
g = convolve2d(e,f,'same')
- h = array([[22,28,34],\
+ h = array([[22,28,34],
[80,98,116]])
assert_array_equal(g,h)
@@ -136,6 +136,7 @@ def _test():
convolve2d(e,f,'valid')
self.assertRaises(ValueError, _test)
+
class TestFFTConvolve(TestCase):
def test_real(self):
x = array([1,2,3])
@@ -148,16 +149,16 @@ def test_complex(self):
def test_2d_real_same(self):
a = array([[1,2,3],[4,5,6]])
- assert_array_almost_equal(signal.fftconvolve(a,a),\
- array([[1,4,10,12,9],\
- [8,26,56,54,36],\
- [16,40,73,60,36]]))
+ assert_array_almost_equal(signal.fftconvolve(a,a),
+ array([[1,4,10,12,9],
+ [8,26,56,54,36],
+ [16,40,73,60,36]]))
def test_2d_complex_same(self):
a = array([[1+2j,3+4j,5+6j],[2+1j,4+3j,6+5j]])
c = fftconvolve(a,a)
- d = array([[-3+4j,-10+20j,-21+56j,-18+76j,-11+60j],\
- [10j,44j,118j,156j,122j],\
+ d = array([[-3+4j,-10+20j,-21+56j,-18+76j,-11+60j],
+ [10j,44j,118j,156j,122j],
[3+4j,10+20j,21+56j,18+76j,11+60j]])
assert_array_almost_equal(c,d)
@@ -190,12 +191,10 @@ def test_real_valid_mode2(self):
assert_array_almost_equal(c,d)
def test_empty(self):
- """Regression test for #1745: crashes with 0-length input"""
- a = array([5])
- b = array([])
- def _test():
- fftconvolve(a,b)
- self.assertRaises(ValueError, _test)
+ # Regression test for #1745: crashes with 0-length input.
+ assert_(fftconvolve([], []).size == 0)
+ assert_(fftconvolve([5, 6], []).size == 0)
+ assert_(fftconvolve([], [7]).size == 0)
def test_zero_rank(self):
a = array(4967)
@@ -262,12 +261,12 @@ def test_none(self):
class TestWiener(TestCase):
def test_basic(self):
g = array([[5,6,4,3],[3,5,6,2],[2,3,5,6],[1,6,9,7]],'d')
- correct = array([[2.16374269,3.2222222222, 2.8888888889, 1.6666666667],
- [2.666666667, 4.33333333333, 4.44444444444, 2.8888888888],
- [2.222222222, 4.4444444444, 5.4444444444, 4.801066874837],
- [1.33333333333, 3.92735042735, 6.0712560386, 5.0404040404]])
- h = signal.wiener(g)
- assert_array_almost_equal(h,correct,decimal=6)
+ h = array([[2.16374269,3.2222222222, 2.8888888889, 1.6666666667],
+ [2.666666667, 4.33333333333, 4.44444444444, 2.8888888888],
+ [2.222222222, 4.4444444444, 5.4444444444, 4.801066874837],
+ [1.33333333333, 3.92735042735, 6.0712560386, 5.0404040404]])
+ assert_array_almost_equal(signal.wiener(g), h , decimal=6)
+
class TestCSpline1DEval(TestCase):
def test_basic(self):
@@ -282,6 +281,7 @@ def test_basic(self):
# make sure interpolated values are on knot points
assert_array_almost_equal(y2[::10], y, decimal=5)
+
class TestOrderFilt(TestCase):
def test_basic(self):
assert_array_equal(signal.order_filter([1,2,3],[1,0,1],1),
@@ -390,7 +390,7 @@ def test_rank3(self):
assert_array_almost_equal(y[i, j], lfilter(b, a, x[i, j]))
def test_empty_zi(self):
- """Regression test for #880: empty array for zi crashes."""
+ # Regression test for #880: empty array for zi crashes.
a = np.ones(1).astype(self.dt)
b = np.ones(1).astype(self.dt)
x = np.arange(5).astype(self.dt)
@@ -426,17 +426,14 @@ class TestLinearFilterObject(_TestLinearFilter):
def test_lfilter_bad_object():
- """lfilter: object arrays with non-numeric objects raise TypeError.
-
- Regression test for ticket #1452.
- """
+ # lfilter: object arrays with non-numeric objects raise TypeError.
+ # Regression test for ticket #1452.
assert_raises(TypeError, lfilter, [1.0], [1.0], [1.0, None, 2.0])
assert_raises(TypeError, lfilter, [1.0], [None], [1.0, 2.0, 3.0])
assert_raises(TypeError, lfilter, [None], [1.0], [1.0, 2.0, 3.0])
class _TestCorrelateReal(TestCase):
-
dt = None
def _setup_rank1(self):
@@ -469,7 +466,7 @@ def _setup_rank3(self):
a = np.linspace(0, 39, 40).reshape((2, 4, 5), order='F').astype(self.dt)
b = np.linspace(0, 23, 24).reshape((2, 3, 4), order='F').astype(self.dt)
- y_r = array([[[ 0., 184., 504., 912., 1360., 888., 472., 160.,],
+ y_r = array([[[0., 184., 504., 912., 1360., 888., 472., 160.,],
[ 46., 432., 1062., 1840., 2672., 1698., 864., 266.,],
[ 134., 736., 1662., 2768., 3920., 2418., 1168., 314.,],
[ 260., 952., 1932., 3056., 4208., 2580., 1240., 332.,] ,
@@ -518,6 +515,7 @@ class TestCorrelateX(base):
TestCorrelateX.__name__ = "TestCorrelate%s" % datatype.__name__.title()
return TestCorrelateX
+
for datatype in [np.ubyte, np.byte, np.ushort, np.short, np.uint, np.int,
np.ulonglong, np.ulonglong, np.float32, np.float64, np.longdouble,
Decimal]:
@@ -526,7 +524,6 @@ class TestCorrelateX(base):
class _TestCorrelateComplex(TestCase):
-
# The numpy data type to use.
dt = None
@@ -641,7 +638,7 @@ def test_sine(self):
assert_equal(y2d, y2dt.T)
def test_axis(self):
- """Test the 'axis' keyword on a 3D array."""
+ # Test the 'axis' keyword on a 3D array.
x = np.arange(10.0 * 11.0 * 12.0).reshape(10, 11, 12)
b, a = butter(3, 0.125)
y0 = filtfilt(b, a, x, padlen=0, axis=0)
@@ -658,7 +655,7 @@ def test_basic(self):
assert_array_equal(signal.decimate(x, 2, n=1).round(), x[::2])
def test_shape(self):
- """Regression test for ticket #1480."""
+ # Regression test for ticket #1480.
z = np.zeros((10, 10))
d0 = signal.decimate(z, 2, axis=0)
assert_equal(d0.shape, (5, 10))
@@ -726,8 +723,7 @@ def test_hilbert_axisN(self):
yield assert_equal, hilbert(a.T, N=20, axis=0).shape, [20,3]
#the next test is just a regression test,
#no idea whether numbers make sense
- a0hilb = np.array(
- [ 0.000000000000000e+00-1.72015830311905j ,
+ a0hilb = np.array([ 0.000000000000000e+00-1.72015830311905j ,
1.000000000000000e+00-2.047794505137069j,
1.999999999999999e+00-2.244055555687583j,
3.000000000000000e+00-1.262750302935009j,
@@ -753,7 +749,6 @@ def test_hilbert_axisN(self):
class TestHilbert2(object):
def test_bad_args(self):
-
# x must be real.
x = np.array([[1.0 + 0.0j]])
assert_raises(ValueError, hilbert2, x)
Please sign in to comment.
Something went wrong with that request. Please try again.