Skip to content

Commit

Permalink
Use scipy implementation of trapz if available
Browse files Browse the repository at this point in the history
  • Loading branch information
apdavison committed Dec 8, 2023
1 parent 45dfd60 commit dfdae84
Showing 1 changed file with 35 additions and 28 deletions.
63 changes: 35 additions & 28 deletions quantities/umath.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,37 +200,44 @@ def trapz(y, x=None, dx=1.0, axis=-1):
else:
ret = _trapz(y.magnitude , x.magnitude, dx.magnitude, axis)
return Quantity ( ret, y.units * x.units)

def _trapz(y, x, dx, axis):
"""ported from numpy 1.26 since it will be deprecated and removed"""
from numpy.core.numeric import asanyarray
from numpy.core.umath import add
y = asanyarray(y)
if x is None:
d = dx
else:
x = asanyarray(x)
if x.ndim == 1:
d = diff(x)
# reshape to correct shape
shape = [1]*y.ndim
shape[axis] = d.shape[0]
d = d.reshape(shape)
else:
d = diff(x, axis=axis)
nd = y.ndim
slice1 = [slice(None)]*nd
slice2 = [slice(None)]*nd
slice1[axis] = slice(1, None)
slice2[axis] = slice(None, -1)
try:
ret = (d * (y[tuple(slice1)] + y[tuple(slice2)]) / 2.0).sum(axis)
except ValueError:
# Operations didn't work, cast to ndarray
d = np.asarray(d)
y = np.asarray(y)
ret = add.reduce(d * (y[tuple(slice1)]+y[tuple(slice2)])/2.0, axis)
return ret
# if scipy is available, we use it
from scipy.integrate import trapezoid # type: ignore
except ImportError:
# otherwise we use the implementation ported from numpy 1.26
from numpy.core.numeric import asanyarray
from numpy.core.umath import add
y = asanyarray(y)
if x is None:
d = dx
else:
x = asanyarray(x)
if x.ndim == 1:
d = diff(x)
# reshape to correct shape
shape = [1]*y.ndim
shape[axis] = d.shape[0]
d = d.reshape(shape)
else:
d = diff(x, axis=axis)
nd = y.ndim
slice1 = [slice(None)]*nd
slice2 = [slice(None)]*nd
slice1[axis] = slice(1, None)
slice2[axis] = slice(None, -1)
try:
ret = (d * (y[tuple(slice1)] + y[tuple(slice2)]) / 2.0).sum(axis)
except ValueError:
# Operations didn't work, cast to ndarray
d = np.asarray(d)
y = np.asarray(y)
ret = add.reduce(d * (y[tuple(slice1)]+y[tuple(slice2)])/2.0, axis)
return ret
else:
return trapezoid(y, x=x, dx=dx, axis=axis)

@with_doc(np.sin)
def sin(x, out=None):
Expand Down

0 comments on commit dfdae84

Please sign in to comment.