Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add possibility to fix parameters in multivariate normal fit #18986

Merged
merged 12 commits into from Aug 28, 2023
42 changes: 37 additions & 5 deletions scipy/stats/_multivariate.py
Expand Up @@ -787,7 +787,7 @@ def entropy(self, mean=None, cov=1):
dim, mean, cov_object = self._process_parameters(mean, cov)
return 0.5 * (cov_object.rank * (_LOG_2PI + 1) + cov_object.log_pdet)

def fit(self, x):
def fit(self, x, fix_mean=None, fix_cov=None):
"""Fit a multivariate normal distribution to data.

Parameters
Expand All @@ -797,6 +797,10 @@ def fit(self, x):
The first axis of length `m` represents the number of vectors
the distribution is fitted to. The second axis of length `n`
determines the dimensionality of the fitted distribution.
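fix_mean : ndarray (n, ), optional
Fixed mean vector. Must have length `n`. If not given, the mean is
estimated from the data.
fix_cov : ndarray (n, n), optional
Fixed covariance matrix. Must have shape `(n, n)`. If not given, the
covariance matrix is estimated from the data.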

Returns
-------
Expand All @@ -806,7 +810,7 @@ def fit(self, x):
Maximum likelihood estimate of the covariance matrix

"""
# input validation
# input validation for data to be fitted
x = np.asarray(x)
if x.ndim != 2:
raise ValueError("`x` must be two-dimensional.")
Expand All @@ -815,9 +819,37 @@ def fit(self, x):

# parameter estimation
# reference: https://home.ttic.edu/~shubhendu/Slides/Estimation.pdf
mean = x.mean(axis=0)
centered_data = x - mean
cov = centered_data.T @ centered_data / n_vectors
if fix_mean is not None:
# input validation for `fix_mean`
fix_mean = np.atleast_1d(fix_mean)
if fix_mean.shape != (dim, ):
msg = ("`fix_mean` must be a one-dimensional array the same "
"length as the dimensionality of the vectors `x`.")
raise ValueError(msg)
mean = fix_mean
else:
mean = x.mean(axis=0)

if fix_cov is not None:
dschmitz89 marked this conversation as resolved.
Show resolved Hide resolved
# input validation for `fix_cov`
fix_cov = np.atleast_2d(fix_cov)
# validate shape
if fix_cov.shape != (dim, dim):
msg = ("`fix_cov` must be a two-dimensional square array "
"of same side length as the dimensionality of the "
"vectors `x`.")
raise ValueError(msg)
# validate positive semidefiniteness
# a trimmed down copy from _PSD
s, u = scipy.linalg.eigh(fix_cov, lower=True, check_finite=True)
eps = _eigvalsh_to_eps(s)
if np.min(s) < -eps:
msg = "`fix_cov` must be symmetric positive semidefinite."
raise ValueError(msg)
cov = fix_cov
else:
centered_data = x - mean
cov = centered_data.T @ centered_data / n_vectors
return mean, cov


Expand Down
91 changes: 90 additions & 1 deletion scipy/stats/tests/test_multivariate.py
Expand Up @@ -803,7 +803,7 @@ def test_mean_cov(self):
ref = multivariate_normal.pdf(x, [1, 1, 1], cov_object)
assert_equal(multivariate_normal.pdf(x, 1, cov=cov_object), ref)

def test_fit_error(self):
def test_fit_wrong_fit_data_shape(self):
data = [1, 3]
error_msg = "`x` must be two-dimensional."
with pytest.raises(ValueError, match=error_msg):
Expand All @@ -818,6 +818,95 @@ def test_fit_correctness(self, dim):
assert_allclose(mean_est, mean_ref, atol=1e-15)
assert_allclose(cov_est, cov_ref, rtol=1e-15)

def test_fit_both_parameters_fixed(self):
data = np.full((2, 1), 3)
mean_fixed = 1.
cov_fixed = np.atleast_2d(1.)
mean, cov = multivariate_normal.fit(data, fix_mean=mean_fixed,
fix_cov=cov_fixed)
assert_equal(mean, mean_fixed)
assert_equal(cov, cov_fixed)

@pytest.mark.parametrize('fix_mean', [np.zeros((2, 2)),
                                      np.zeros((3, ))])
def test_fit_fix_mean_input_validation(self, fix_mean):
    """A `fix_mean` of the wrong shape is rejected with a ValueError."""
    # the data are two-dimensional vectors, so neither a (2, 2) array
    # nor a length-3 vector is an acceptable mean
    observations = np.eye(2)
    expected = ("`fix_mean` must be a one-dimensional array the same "
                "length as the dimensionality of the vectors `x`.")
    with pytest.raises(ValueError, match=expected):
        multivariate_normal.fit(observations, fix_mean=fix_mean)

@pytest.mark.parametrize('fix_cov', [np.zeros((2, )),
                                     np.zeros((3, 2)),
                                     np.zeros((4, 4))])
def test_fit_fix_cov_input_validation_dimension(self, fix_cov):
    """A `fix_cov` that is not of shape (3, 3) raises a ValueError."""
    # the data are three-dimensional vectors; every parametrized
    # `fix_cov` has a shape other than (3, 3)
    observations = np.eye(3)
    expected = ("`fix_cov` must be a two-dimensional square array "
                "of same side length as the dimensionality of the "
                "vectors `x`.")
    with pytest.raises(ValueError, match=expected):
        multivariate_normal.fit(observations, fix_cov=fix_cov)

def test_fit_fix_cov_not_positive_semidefinite(self):
    """A `fix_cov` with a negative eigenvalue raises a ValueError."""
    error_msg = "`fix_cov` must be symmetric positive semidefinite."
    # diagonal matrix with eigenvalues 1 and -1, hence indefinite;
    # built outside the `raises` block so that only the call under test
    # can satisfy the expected exception
    fix_cov = np.array([[1., 0.], [0., -1.]])
    with pytest.raises(ValueError, match=error_msg):
        multivariate_normal.fit(np.eye(2), fix_cov=fix_cov)

def test_fit_fix_mean(self):
rng = np.random.default_rng(4385269356937404)
loc = rng.random(3)
A = rng.random((3, 3))
cov = np.dot(A, A.T)
samples = multivariate_normal.rvs(mean=loc, cov=cov, size=100,
random_state=rng)
mean_free, cov_free = multivariate_normal.fit(samples)
logp_free = multivariate_normal.logpdf(samples, mean=mean_free,
cov=cov_free).sum()
mean_fix, cov_fix = multivariate_normal.fit(samples, fix_mean=loc)
assert_equal(mean_fix, loc)
logp_fix = multivariate_normal.logpdf(samples, mean=mean_fix,
cov=cov_fix).sum()
# test that fixed parameters result in lower likelihood than free
# parameters
assert logp_fix < logp_free
mdhaber marked this conversation as resolved.
Show resolved Hide resolved
# test that a small perturbation of the resulting parameters
# has lower likelihood than the estimated parameters
A = rng.random((3, 3))
m = 1e-8 * np.dot(A, A.T)
cov_perturbed = cov_fix + m
logp_perturbed = (multivariate_normal.logpdf(samples,
mean=mean_fix,
cov=cov_perturbed)
).sum()
assert logp_perturbed < logp_fix


def test_fit_fix_cov(self):
rng = np.random.default_rng(4385269356937404)
loc = rng.random(3)
A = rng.random((3, 3))
cov = np.dot(A, A.T)
samples = multivariate_normal.rvs(mean=loc, cov=cov,
size=100, random_state=rng)
mean_free, cov_free = multivariate_normal.fit(samples)
logp_free = multivariate_normal.logpdf(samples, mean=mean_free,
cov=cov_free).sum()
mean_fix, cov_fix = multivariate_normal.fit(samples, fix_cov=cov)
assert_equal(mean_fix, np.mean(samples, axis=0))
assert_equal(cov_fix, cov)
logp_fix = multivariate_normal.logpdf(samples, mean=mean_fix,
cov=cov_fix).sum()
# test that fixed parameters result in lower likelihood than free
# parameters
assert logp_fix < logp_free
# test that a small perturbation of the resulting parameters
# has lower likelihood than the estimated parameters
mean_perturbed = mean_fix + 1e-8 * rng.random(3)
logp_perturbed = (multivariate_normal.logpdf(samples,
mean=mean_perturbed,
cov=cov_fix)
).sum()
assert logp_perturbed < logp_fix


class TestMatrixNormal:

Expand Down