Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: speed up multivariate_normal.logpdf for nonsingular matrices #9973

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 39 additions & 11 deletions scipy/stats/_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,13 @@ def pinv(self):
return self._pinv


class _Cho(object):
def __init__(self, M):
self.cho = scipy.linalg.cho_factor(M)
self.log_pdet = 2 * np.log(self.cho[0].diagonal()).sum()
self.rank = M.shape[0]


class multi_rv_generic(object):
"""
Class which encapsulates common functionality between all multivariate
Expand Down Expand Up @@ -443,7 +450,7 @@ def _process_quantiles(self, x, dim):

return x

def _logpdf(self, x, mean, prec_U, log_det_cov, rank):
def _logpdf_arbitrary(self, x, mean, prec_U, log_det_cov, rank):
"""
Parameters
----------
Expand All @@ -470,6 +477,14 @@ def _logpdf(self, x, mean, prec_U, log_det_cov, rank):
maha = np.sum(np.square(np.dot(dev, prec_U)), axis=-1)
return -0.5 * (rank * _LOG_2PI + log_det_cov + maha)

def _logpdf_nonsingular(self, x, mean, cho, log_det_cov, rank):
dev = x - mean

prec_times_dev_t = scipy.linalg.cho_solve(cho, dev.T)
maha = (dev * prec_times_dev_t.T).sum(axis=-1)

return -0.5 * (rank * _LOG_2PI + log_det_cov + maha)

def logpdf(self, x, mean=None, cov=1, allow_singular=False):
"""
Log of the multivariate normal probability density function.
Expand All @@ -492,8 +507,13 @@ def logpdf(self, x, mean=None, cov=1, allow_singular=False):
"""
dim, mean, cov = self._process_parameters(None, mean, cov)
x = self._process_quantiles(x, dim)
psd = _PSD(cov, allow_singular=allow_singular)
out = self._logpdf(x, mean, psd.U, psd.log_pdet, psd.rank)
if allow_singular:
psd = _PSD(cov, allow_singular=True)
out = self._logpdf_arbitrary(x, mean, psd.U, psd.log_pdet, psd.rank)
else:
cho = _Cho(cov)
out = self._logpdf_nonsingular(x, mean, cho.cho, cho.log_pdet, cho.rank)

return _squeeze_output(out)

def pdf(self, x, mean=None, cov=1, allow_singular=False):
Expand All @@ -516,11 +536,7 @@ def pdf(self, x, mean=None, cov=1, allow_singular=False):
%(_mvn_doc_callparams_note)s

"""
dim, mean, cov = self._process_parameters(None, mean, cov)
x = self._process_quantiles(x, dim)
psd = _PSD(cov, allow_singular=allow_singular)
out = np.exp(self._logpdf(x, mean, psd.U, psd.log_pdet, psd.rank))
return _squeeze_output(out)
return np.exp(self.logpdf(x, mean, cov, allow_singular))

def _cdf(self, x, mean, cov, maxpts, abseps, releps):
"""
Expand Down Expand Up @@ -730,10 +746,14 @@ def __init__(self, mean=None, cov=1, allow_singular=False, seed=None,
array([[1.]])

"""
self.allow_singular = allow_singular
self._dist = multivariate_normal_gen(seed)
self.dim, self.mean, self.cov = self._dist._process_parameters(
None, mean, cov)
self.cov_info = _PSD(self.cov, allow_singular=allow_singular)
if self.allow_singular:
self.cov_info = _PSD(self.cov, allow_singular=allow_singular)
else:
self.cov_info = _Cho(self.cov)
if not maxpts:
maxpts = 1000000 * self.dim
self.maxpts = maxpts
Expand All @@ -742,8 +762,16 @@ def __init__(self, mean=None, cov=1, allow_singular=False, seed=None,

def logpdf(self, x):
x = self._dist._process_quantiles(x, self.dim)
out = self._dist._logpdf(x, self.mean, self.cov_info.U,
self.cov_info.log_pdet, self.cov_info.rank)
if self.allow_singular:
out = self._dist._logpdf_arbitrary(
x, self.mean, self.cov_info.U,
self.cov_info.log_pdet, self.cov_info.rank
)
else:
out = self._dist._logpdf_nonsingular(
x, self.mean, self.cov_info.cho,
self.cov_info.log_pdet, self.cov_info.rank
)
return _squeeze_output(out)

def pdf(self, x):
Expand Down