Cache covariance matrix decomposition in frozen multivariate_normal #11772

Closed
Balandat opened this issue Mar 31, 2020 · 5 comments
Labels
enhancement A new feature or improvement scipy.stats

@Balandat
Contributor

Currently, it seems that a frozen multivariate_normal distribution unnecessarily recomputes the root decomposition of the covariance matrix (and other derived quantities, such as the eigenvalues used by logpdf) on every operation. For instance, when sampling, scipy just calls the sampling method of the underlying distribution with the full covariance matrix:

return self._dist.rvs(self.mean, self.cov, size, random_state)

This is super wasteful, as most of the computation is in fact in computing this decomposition.

What should happen instead is that the frozen object should also have a factor or L attribute such that L @ L.T = cov. It would compute that only once (upon the first call), and in later calls pass it into the sampling method instead of cov, avoiding a bunch of unnecessary compute.
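A minimal sketch of that caching idea, using only NumPy. The class name `CachedMVN`, the `factor` property, and the `n_factorizations` counter are hypothetical names made up for illustration; this is not how scipy implements its frozen distribution. The eigendecomposition is used instead of Cholesky so that singular covariances are still handled.

```python
import numpy as np


class CachedMVN:
    """Hypothetical frozen MVN that factors its covariance lazily, once."""

    def __init__(self, mean, cov):
        self.mean = np.asarray(mean, dtype=float)
        self.cov = np.asarray(cov, dtype=float)
        self._factor = None          # filled in on first use
        self.n_factorizations = 0    # illustration only: counts decompositions

    @property
    def factor(self):
        """Matrix L with L @ L.T == cov, computed on first access and cached."""
        if self._factor is None:
            s, u = np.linalg.eigh(self.cov)
            s = np.clip(s, 0.0, None)      # clamp tiny negative round-off
            self._factor = u * np.sqrt(s)  # scales column j of u by sqrt(s[j])
            self.n_factorizations += 1
        return self._factor

    def rvs(self, size=1, random_state=None):
        rng = np.random.default_rng(random_state)
        z = rng.standard_normal((size, self.mean.size))
        # each row is mean + L @ z_row, so its covariance is L @ L.T
        return self.mean + z @ self.factor.T


dist = CachedMVN(np.zeros(2), [[2.0, 1.0], [1.0, 2.0]])
first = dist.rvs(5, random_state=0)
second = dist.rvs(5, random_state=1)
```

Both calls to `rvs` reuse the same factor, so the decomposition cost is paid once regardless of how many batches are drawn.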

Torch's mvn does this (though it doesn't allow singular matrices at this point): https://github.com/pytorch/pytorch/blob/master/torch/distributions/multivariate_normal.py#L146-L15

@siddhantwahal
Contributor

The rvs method indeed requires recomputing the square root of the covariance.

One approach to fix this could be adding an extra attribute to _PSD, say L, as: self.L = np.multiply(u, np.sqrt(s)), where s and u are the eigenvalues and eigenvectors from the eigendecomposition of the covariance matrix.

Then, the square root of the covariance is available in multivariate_normal_frozen as self.cov_info.L, and can be employed in rvs.

I think only rvs needs modifying. As far as I can tell, logpdf uses self.U, which is a square root of the precision matrix.
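To make the two square roots concrete, here is a small NumPy check of the identities involved. It assumes a full-rank covariance, and the formula used for `U` (eigenvectors scaled by the inverse square roots of the eigenvalues) is my assumption about what a "square root of the precision matrix" looks like in that case, not a quote from the scipy source.

```python
import numpy as np

cov = np.array([[4.0, 1.0],
                [1.0, 3.0]])

# eigendecomposition: cov == u @ diag(s) @ u.T
s, u = np.linalg.eigh(cov)

# proposed attribute: square root of the covariance (used for sampling)
L = np.multiply(u, np.sqrt(s))

# assumed form of the precision square root (used for logpdf), full rank only
U = np.multiply(u, 1.0 / np.sqrt(s))

cov_from_L = L @ L.T    # should reproduce cov
prec_from_U = U @ U.T   # should reproduce inv(cov)
```

Both factors come out of the same eigendecomposition, so storing `L` next to `U` adds no extra decomposition work.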

Happy to work on this if no one else is.

@rlucas7
Member

rlucas7 commented May 28, 2020 via email

@mdhaber
Contributor

mdhaber commented Oct 18, 2022

I think this has been addressed by the Covariance class. Do you agree @tirthasheshpatel @tupui? Shall we close this one?

@tirthasheshpatel
Member

tirthasheshpatel commented Oct 18, 2022

@mdhaber Yup, Covariance should address the issue. Feel free to close this!

@Balandat FYI: Now, we can specify the covariance using the factory methods provided by the stats.Covariance class! Here's an example:

import numpy as np
from scipy import stats

cov = np.eye(3)
mu = np.zeros(3)

# only compute the eigen-decomposition once
w, v = np.linalg.eigh(cov)
cov_object = stats.Covariance.from_eigendecomposition((w, v))

# use the cov_object instead
dist = stats.multivariate_normal(mu, cov_object)
dist.pdf([0, 0, 0])

# also, notice that the covariance matrix here is diagonal,
# so we can use the specialized diagonal representation to
# make computation of pdf, ... more efficient
cov_object = stats.Covariance.from_diagonal(np.diag(cov))
# exploits the diagonal structure of the covariance matrix
stats.multivariate_normal.pdf([0, 0, 0], mu, cov_object)

Feel free to refer to the devdocs here for more examples.
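As a further hedged sketch of what gets reused, the snippet below builds a Covariance object from a Cholesky factor computed once and then reproduces logpdf by hand from the object's `whiten` method and `log_pdet` property. It assumes SciPy 1.10 or later and a full-rank covariance; the sample covariance and the points in `x` are made up for illustration.

```python
import numpy as np
from scipy import stats

cov = np.array([[4.0, 1.0],
                [1.0, 3.0]])
mu = np.zeros(2)

# factor the covariance once; the object keeps the factor around
chol = np.linalg.cholesky(cov)  # lower triangular, chol @ chol.T == cov
cov_object = stats.Covariance.from_cholesky(chol)
dist = stats.multivariate_normal(mu, cov_object)

x = np.array([[0.5, -1.0],
              [2.0, 0.3]])

# logpdf by hand from the stored decomposition:
# whiten(x - mu) has squared norm equal to the Mahalanobis distance
z = cov_object.whiten(x - mu)
d = cov.shape[0]
manual = -0.5 * (d * np.log(2 * np.pi)
                 + cov_object.log_pdet
                 + np.sum(z**2, axis=-1))

from_scipy = dist.logpdf(x)
```

The hand-computed values should agree with `dist.logpdf(x)`, which suggests that everything logpdf needs is available from the factor stored in the Covariance object.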

@Balandat
Contributor Author

Thanks! Let me use the opportunity to make a shameless plug for a related (PyTorch) project for structured linear algebra that can also encode various structured covariance matrices: https://github.com/cornellius-gp/linear_operator

@mdhaber mdhaber closed this as completed Oct 18, 2022
@mdhaber mdhaber added this to the 1.10.0 milestone Nov 29, 2022