Skip to content

Commit

Permalink
BUG: fix moments method to support arrays and list (#12197)
Browse files Browse the repository at this point in the history
* BUG: fix moments method to support arrays and list
* MAINT: stats: refactor moment function for array input
* TST, CI: stats: add hypothesis to CI and add tests
* Revert "TST, CI: stats: add hypothesis to CI and add tests"

Co-authored-by: Matt Haberland <mhaberla@calpoly.edu>
Co-authored-by: Pamphile ROY <roy.pamphile@gmail.com>
  • Loading branch information
3 people committed Aug 27, 2021
1 parent e742ae1 commit 9e08b05
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 14 deletions.
61 changes: 47 additions & 14 deletions scipy/stats/_distn_infrastructure.py
Expand Up @@ -1270,9 +1270,17 @@ def moment(self, n, *args, **kwds):
scale parameter (default=1)
"""
args, loc, scale = self._parse_args(*args, **kwds)
if not (self._argcheck(*args) and (scale > 0)):
return nan
shapes, loc, scale = self._parse_args(*args, **kwds)
args = np.broadcast_arrays(*(*shapes, loc, scale))
*shapes, loc, scale = args

i0 = np.logical_and(self._argcheck(*shapes), scale > 0)
i1 = np.logical_and(i0, loc == 0)
i2 = np.logical_and(i0, loc != 0)

args = argsreduce(i0, *shapes, loc, scale)
*shapes, loc, scale = args

if (floor(n) != n):
raise ValueError("Moment must be an integer.")
if (n < 0):
Expand All @@ -1283,21 +1291,46 @@ def moment(self, n, *args, **kwds):
mdict = {'moments': {1: 'm', 2: 'v', 3: 'vs', 4: 'vk'}[n]}
else:
mdict = {}
mu, mu2, g1, g2 = self._stats(*args, **mdict)
val = _moment_from_stats(n, mu, mu2, g1, g2, self._munp, args)
mu, mu2, g1, g2 = self._stats(*shapes, **mdict)
val = np.empty(loc.shape) # val needs to be indexed by loc
val[...] = _moment_from_stats(n, mu, mu2, g1, g2, self._munp, shapes)

# Convert to transformed X = L + S*Y
# E[X^n] = E[(L+S*Y)^n] = L^n sum(comb(n, k)*(S/L)^k E[Y^k], k=0...n)
if loc == 0:
return scale**n * val
else:
result = 0
fac = float(scale) / float(loc)
result = zeros(i0.shape)
place(result, ~i0, self.badvalue)

if i1.any():
res1 = scale[loc == 0]**n * val[loc == 0]
place(result, i1, res1)

if i2.any():
mom = [mu, mu2, g1, g2]
arrs = [i for i in mom if i is not None]
idx = [i for i in range(4) if mom[i] is not None]
if any(idx):
arrs = argsreduce(loc != 0, *arrs)
j = 0
for i in idx:
mom[i] = arrs[j]
j += 1
mu, mu2, g1, g2 = mom
args = argsreduce(loc != 0, *shapes, loc, scale, val)
*shapes, loc, scale, val = args

res2 = zeros(loc.shape, dtype='d')
fac = scale / loc
for k in range(n):
valk = _moment_from_stats(k, mu, mu2, g1, g2, self._munp, args)
result += comb(n, k, exact=True)*(fac**k) * valk
result += fac**n * val
return result * loc**n
valk = _moment_from_stats(k, mu, mu2, g1, g2, self._munp,
shapes)
res2 += comb(n, k, exact=True)*fac**k * valk
res2 += fac**n * val
res2 *= loc**n
place(result, i2, res2)

if result.ndim == 0:
return result.item()
return result

def median(self, *args, **kwds):
"""Median of the distribution.
Expand Down
109 changes: 109 additions & 0 deletions scipy/stats/tests/test_continuous_basic.py
Expand Up @@ -719,3 +719,112 @@ def test_burr_fisk_moment_gh13234_regression():

vals1 = stats.fisk.moment(1, 8)
assert isinstance(vals1, float)


def test_moments_with_array_gh12192_regression():
    """Check ``moment`` with array-like ``loc``, ``scale`` and shape arguments.

    Regression test (gh-12192, per the test name).  It verifies that:

    * array-like arguments yield element-wise results;
    * entries with a non-positive ``scale`` come back as ``nan`` while the
      remaining entries are still computed;
    * all-scalar input returns a scalar of the same class as ``np.nan``
      rather than an array;
    * a fully broadcast call agrees, element by element, with the
      corresponding scalar calls.

    The moment order is always passed positionally: the keyword was spelled
    ``n`` in older SciPy releases and ``order`` in newer ones, and the
    positional form works with either spelling.
    """
    # array loc and scalar scale
    vals0 = stats.norm.moment(1, loc=np.array([1, 2, 3]), scale=1)
    expected0 = np.array([1., 2., 3.])
    npt.assert_equal(vals0, expected0)

    # array loc and invalid scalar scale
    vals1 = stats.norm.moment(1, loc=np.array([1, 2, 3]), scale=-1)
    expected1 = np.array([np.nan, np.nan, np.nan])
    npt.assert_equal(vals1, expected1)

    # array loc and array scale with invalid entries
    vals2 = stats.norm.moment(1, loc=np.array([1, 2, 3]), scale=[-3, 1, 0])
    expected2 = np.array([np.nan, 2., np.nan])
    npt.assert_equal(vals2, expected2)

    # (loc == 0) & (scale < 0)
    vals3 = stats.norm.moment(2, loc=0, scale=-4)
    expected3 = np.nan
    npt.assert_equal(vals3, expected3)
    # scalar in -> scalar out (not a 0-d array)
    assert isinstance(vals3, type(expected3))

    # array loc with 0 entries and scale with invalid entries
    vals4 = stats.norm.moment(2, loc=[1, 0, 2], scale=[3, -4, -5])
    expected4 = np.array([10., np.nan, np.nan])
    npt.assert_equal(vals4, expected4)

    # all(loc == 0) & (array scale with invalid entries)
    vals5 = stats.norm.moment(2, loc=[0, 0, 0], scale=[5., -2, 100.])
    expected5 = np.array([25., np.nan, 10000.])
    npt.assert_equal(vals5, expected5)

    # all( (loc == 0) & (scale < 0) )
    vals6 = stats.norm.moment(2, loc=[0, 0, 0], scale=[-5., -2, -100.])
    expected6 = np.array([np.nan, np.nan, np.nan])
    npt.assert_equal(vals6, expected6)

    # scalar args, loc, and scale
    vals7 = stats.chi.moment(2, df=1, loc=0, scale=0)
    expected7 = np.nan
    npt.assert_equal(vals7, expected7)
    assert isinstance(vals7, type(expected7))

    # array args, scalar loc, and scalar scale
    vals8 = stats.chi.moment(2, df=[1, 2, 3], loc=0, scale=0)
    expected8 = np.array([np.nan, np.nan, np.nan])
    npt.assert_equal(vals8, expected8)

    # array args, array loc, and array scale
    vals9 = stats.chi.moment(2, df=[1, 2, 3], loc=[1., 0., 2.],
                             scale=[1., -3., 0.])
    expected9 = np.array([3.59576912, np.nan, np.nan])
    npt.assert_allclose(vals9, expected9, rtol=1e-8)

    # (n > 4), all(loc != 0), and all(scale != 0)
    vals10 = stats.norm.moment(5, [1., 2.], [1., 2.])
    expected10 = np.array([26., 832.])
    npt.assert_allclose(vals10, expected10, rtol=1e-13)

    # test broadcasting and more
    a = [-1.1, 0, 1, 2.2, np.pi]
    b = [-1.1, 0, 1, 2.2, np.pi]
    loc = [-1.1, 0, np.sqrt(2)]
    scale = [-2.1, 0, 1, 2.2, np.pi]

    # Put every parameter on its own axis so the broadcast result covers
    # every combination of the values listed above.
    a = np.array(a).reshape((-1, 1, 1, 1))
    b = np.array(b).reshape((-1, 1, 1))
    loc = np.array(loc).reshape((-1, 1))
    scale = np.array(scale)

    vals11 = stats.beta.moment(2, a=a, b=b, loc=loc, scale=scale)

    a, b, loc, scale = np.broadcast_arrays(a, b, loc, scale)

    with np.errstate(invalid='ignore', divide='ignore'):
        for idx in np.ndindex(a.shape):
            # check against same function with scalar input
            expected = stats.beta.moment(2, a=a[idx], b=b[idx],
                                         loc=loc[idx], scale=scale[idx])
            npt.assert_equal(vals11[idx], expected)


def test_broadcasting_in_moments_gh12192_regression():
    """Check that ``moment`` broadcasts shape, ``loc`` and ``scale`` arguments.

    Regression test (gh-12192, per the test name).  Each case compares both
    the values and the shape of the result against the expected broadcast
    array; entries with a non-positive ``scale`` are expected to be ``nan``.

    The moment order is passed positionally: the keyword was spelled ``n``
    in older SciPy releases and ``order`` in newer ones, and the positional
    form works with either spelling.
    """
    # 1-D loc against a (1, 1) scale -> (1, 3) result
    vals0 = stats.norm.moment(1, loc=np.array([1, 2, 3]), scale=[[1]])
    expected0 = np.array([[1., 2., 3.]])
    npt.assert_equal(vals0, expected0)
    assert vals0.shape == expected0.shape

    # column loc against row scale -> (3, 3) result
    vals1 = stats.norm.moment(1, loc=np.array([[1], [2], [3]]),
                              scale=[1, 2, 3])
    expected1 = np.array([[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]])
    npt.assert_equal(vals1, expected1)
    assert vals1.shape == expected1.shape

    # array shape parameter with scalar loc and scale
    vals2 = stats.chi.moment(1, df=[1., 2., 3.], loc=0., scale=1.)
    expected2 = np.array([0.79788456, 1.25331414, 1.59576912])
    npt.assert_allclose(vals2, expected2, rtol=1e-8)
    assert vals2.shape == expected2.shape

    # column shape parameter against row loc/scale with invalid scales
    vals3 = stats.chi.moment(1, df=[[1.], [2.], [3.]], loc=[0., 1., 2.],
                             scale=[-1., 0., 3.])
    expected3 = np.array([[np.nan, np.nan, 4.39365368],
                          [np.nan, np.nan, 5.75994241],
                          [np.nan, np.nan, 6.78730736]])
    npt.assert_allclose(vals3, expected3, rtol=1e-8)
    assert vals3.shape == expected3.shape

0 comments on commit 9e08b05

Please sign in to comment.