Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH:stats:Use explicit formula for gamma.fit('mm') #19932

Merged
merged 3 commits into from Jan 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
36 changes: 31 additions & 5 deletions scipy/stats/_continuous_distns.py
Expand Up @@ -3393,8 +3393,8 @@ def fit(self, data, *args, **kwds):
floc = kwds.get('floc', None)
method = kwds.get('method', 'mle')

if (isinstance(data, CensoredData) or floc is None
or method.lower() == 'mm'):
if (isinstance(data, CensoredData) or
floc is None and method.lower() != 'mm'):
# loc is not fixed or we're not doing standard MLE.
# Use the default fit method.
return super().fit(data, *args, **kwds)
Expand All @@ -3407,9 +3407,7 @@ def fit(self, data, *args, **kwds):

_remove_optimizer_parameters(kwds)

# Special case: loc is fixed.

if f0 is not None and fscale is not None:
if f0 is not None and floc is not None and fscale is not None:
# This check is for consistency with `rv_continuous.fit`.
# Without this check, this function would just return the
# parameters that were given.
Expand All @@ -3422,6 +3420,34 @@ def fit(self, data, *args, **kwds):
if not np.isfinite(data).all():
raise ValueError("The data contains non-finite values.")

# Use explicit formulas for mm (gh-19884)
if method.lower() == 'mm':
m1 = np.mean(data)
m2 = np.var(data)
m3 = np.mean((data - m1) ** 3)
a, loc, scale = f0, floc, fscale
# Three unknowns
if a is None and loc is None and scale is None:
scale = m3 / (2 * m2)
# Two unknowns
if loc is None and scale is None:
scale = np.sqrt(m2 / a)
if a is None and scale is None:
scale = m2 / (m1 - loc)
if a is None and loc is None:
a = m2 / (scale ** 2)
# One unknown
if a is None:
a = (m1 - loc) / scale
if loc is None:
loc = m1 - a * scale
if scale is None:
scale = (m1 - loc) / a
return a, loc, scale

# Special case: loc is fixed.

# NB: data == loc is ok if a >= 1; the below check is more strict.
mdhaber marked this conversation as resolved.
Show resolved Hide resolved
if np.any(data <= floc):
raise FitDataError("gamma", lower=floc, upper=np.inf)

Expand Down
35 changes: 35 additions & 0 deletions scipy/stats/tests/test_distributions.py
Expand Up @@ -4672,6 +4672,41 @@ def test_entropy(self, a, ref, rtol):

assert_allclose(stats.gamma.entropy(a), ref, rtol=rtol)

@pytest.mark.parametrize("a", [1e-2, 1, 1e2])
@pytest.mark.parametrize("loc", [1e-2, 0, 1e2])
@pytest.mark.parametrize('scale', [1e-2, 1, 1e2])
@pytest.mark.parametrize('fix_a', [True, False])
@pytest.mark.parametrize('fix_loc', [True, False])
@pytest.mark.parametrize('fix_scale', [True, False])
def test_fit_mm(self, a, loc, scale, fix_a, fix_loc, fix_scale):
rng = np.random.default_rng(6762668991392531563)
data = stats.gamma.rvs(a, loc=loc, scale=scale, size=100,
random_state=rng)

kwds = {}
if fix_a:
kwds['fa'] = a
if fix_loc:
kwds['floc'] = loc
if fix_scale:
kwds['fscale'] = scale
nfree = 3 - len(kwds)

if nfree == 0:
error_msg = "All parameters fixed. There is nothing to optimize."
with pytest.raises(ValueError, match=error_msg):
stats.halfcauchy.fit(data, method='mm', **kwds)
return
Comment on lines +4697 to +4699
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why halfcauchy here if the PR is about gamma? Copy Paste hickup? @fancidev

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why halfcauchy here if the PR is about gamma? Copy Paste hickup? @fancidev

Oops, yes exactly. Let me make a PR to correct it.


theta = stats.gamma.fit(data, method='mm', **kwds)
dist = stats.gamma(*theta)
if nfree >= 1:
assert_allclose(dist.mean(), np.mean(data))
if nfree >= 2:
assert_allclose(dist.moment(2), np.mean(data**2))
if nfree >= 3:
assert_allclose(dist.moment(3), np.mean(data**3))

def test_pdf_overflow_gh19616():
# Confirm that gh19616 (intermediate over/underflows in PDF) is resolved
# Reference value from R GeneralizedHyperbolic library
Expand Down