Skip to content

Commit

Permalink
fix the previous fix that was breaking cumtrapz when y and x where nd…
Browse files Browse the repository at this point in the history
… and add tests for y nd and x 1d.
  • Loading branch information
francisco-dlp committed Nov 17, 2013
1 parent b84900d commit 060eda8
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 3 deletions.
17 changes: 15 additions & 2 deletions scipy/integrate/quadrature.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,21 @@ def cumtrapz(y, x=None, dx=1.0, axis=-1, initial=None):
if x is None:
d = dx
else:
d = diff(x)

x = asarray(x)
if x.ndim == 1:
d = diff(x)
# reshape to correct shape
shape = [1]*y.ndim
shape[axis] = -1
d = d.reshape(shape)
elif len(x.shape) != len(y.shape):
raise ValueError("If given, shape of x must be 1-d or the "
"same as y.")
else:
d = diff(x, axis=axis)
if d.shape[axis] != y.shape[axis] - 1:
raise ValueError("If given, length of x along axis must be the "
"same as y.")
nd = len(y.shape)
slice1 = tupleset((slice(None),)*nd, axis, slice(1, None))
slice2 = tupleset((slice(None),)*nd, axis, slice(None, -1))
Expand Down
29 changes: 28 additions & 1 deletion scipy/integrate/tests/test_quadrature.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_1d(self):
y_int = cumtrapz(y, x, initial=None)
assert_allclose(y_int, y_expected[1:])

def test_nd(self):
def test_y_nd_x_nd(self):
x = np.arange(3 * 2 * 4).reshape(3, 2, 4)
y = x
y_int = cumtrapz(y, x, initial=0)
Expand All @@ -142,6 +142,33 @@ def test_nd(self):
y_int = cumtrapz(y, x, initial=None, axis=axis)
assert_equal(y_int.shape, shape)

def test_y_nd_x_1d(self):
y = np.arange(3 * 2 * 4).reshape(3, 2, 4)
x = np.arange(4) ** 2
# Try with all axes
ys_expected = (
np.array([[[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]],

[[ 40., 44., 48., 52.],
[ 56., 60., 64., 68.]]]),
np.array([[[ 2., 3., 4., 5.]],

[[ 10., 11., 12., 13.]],

[[ 18., 19., 20., 21.]]]),
np.array([[[ 0.5, 5. , 17.5],
[ 4.5, 21. , 53.5]],

[[ 8.5, 37. , 89.5],
[ 12.5, 53. , 125.5]],

[[ 16.5, 69. , 161.5],
[ 20.5, 85. , 197.5]]]))
for axis, y_expected in zip([0, 1, 2], ys_expected):
y_int = cumtrapz(y, x=x[:y.shape[axis]], axis=axis, initial=None)
assert_allclose(y_int, y_expected)

def test_x_none(self):
y = np.linspace(-2, 2, num=5)

Expand Down

0 comments on commit 060eda8

Please sign in to comment.