Skip to content

Commit

Permalink
DOC: Improved documentation of matrix_normal.
Browse files Browse the repository at this point in the history
  • Loading branch information
drpeteb committed Nov 22, 2015
1 parent 61d5282 commit 1d7dd4a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 26 deletions.
26 changes: 21 additions & 5 deletions scipy/stats/_multivariate.py
Expand Up @@ -669,7 +669,7 @@ class matrix_normal_gen(multi_rv_generic):
A matrix normal random variable.
The `mean` keyword specifies the mean. The `rowcov` keyword specifies the
among-row covariance matrix. The 'colcov' keyword specified the
among-row covariance matrix. The 'colcov' keyword specifies the
among-column covariance matrix.
Methods
Expand Down Expand Up @@ -700,10 +700,10 @@ class matrix_normal_gen(multi_rv_generic):
-----
%(_matnorm_doc_callparams_note)s
The covariance matrices `rowcov` and `colcov` must be (symmetric) positive
definite. If the samples in `X` are :math:`m \times n`, then `rowcov`
must be :math:`m \times m` and `colcov` must be :math:`n \times n`.
`mean` must be the same shape as `X`.
The covariance matrices specified by `rowcov` and `colcov` must be
(symmetric) positive definite. If the samples in `X` are
:math:`m \times n`, then `rowcov` must be :math:`m \times m` and
`colcov` must be :math:`n \times n`. `mean` must be the same shape as `X`.
The probability density function for `matrix_normal` is
Expand All @@ -720,6 +720,12 @@ class matrix_normal_gen(multi_rv_generic):
distribution is not currently supported. Covariance matrices must be
full rank.
The `matrix_normal` distribution is closely related to the
`multivariate_normal` distribution. Specifically, :math:`\mathrm{Vec}(X)`
(the vectorisation of :math:`X`) has a multivariate normal distribution
with mean :math:`\mathrm{Vec}(M)` and covariance :math:`V \otimes U`
(where :math:`\otimes` is the Kronecker product).
.. versionadded:: 0.17.0
Examples
Expand All @@ -744,6 +750,14 @@ class matrix_normal_gen(multi_rv_generic):
[ 4.1, 5.1]])
>>> matrix_normal.pdf(X, mean=M, rowcov=U, colcov=V)
0.023410202050005054
>>> # Equivalent multivariate normal
>>> from scipy.stats import multivariate_normal
>>> vectorised_X = X.T.flatten()
>>> equiv_mean = M.T.flatten()
>>> equiv_cov = np.kron(V,U)
>>> multivariate_normal.pdf(vectorised_X, mean=equiv_mean, cov=equiv_cov)
0.023410202050005054
"""

def __init__(self, seed=None):
Expand Down Expand Up @@ -966,6 +980,7 @@ def rvs(self, mean=None, rowcov=1, colcov=1, size=1, random_state=None):

matrix_normal = matrix_normal_gen()


class matrix_normal_frozen(multi_rv_frozen):
def __init__(self, mean=None, rowcov=1, colcov=1, seed=None):
"""
Expand Down Expand Up @@ -1013,6 +1028,7 @@ def rvs(self, size=1, random_state=None):
return self._dist.rvs(self.mean, self.rowcov, self.colcov, size,
random_state)


# Set frozen generator docstrings from corresponding docstrings in
# matrix_normal_gen and fill in default strings in class docstrings
for name in ['logpdf', 'pdf', 'rvs']:
Expand Down
21 changes: 0 additions & 21 deletions scipy/stats/tests/test_multivariate.py
Expand Up @@ -475,27 +475,6 @@ def test_matches_multivariate(self):
assert_allclose(pdf1, pdf2, rtol=1E-10)
assert_allclose(logpdf1, logpdf2, rtol=1E-10)

def test_equivalent_multivariate(self):
# Check that the equivalent multivariate distribution matches
for i in range(1,5):
for j in range(1,5):
M = 0.3 * np.ones((i,j))
U = 0.5 * np.identity(i) + 0.5 * np.ones((i,i))
V = 0.7 * np.identity(j) + 0.3 * np.ones((j,j))

frozen = matrix_normal(mean=M, rowcov=U, colcov=V)
X = frozen.rvs(random_state=1234)
pdf1 = frozen.pdf(X)
logpdf1 = frozen.logpdf(X)

mvn = frozen.equivalent_multivariate_normal()
vecX = X.T.flatten()
pdf2 = mvn.pdf(vecX)
logpdf2 = mvn.logpdf(vecX)

assert_allclose(pdf1, pdf2, rtol=1E-10)
assert_allclose(logpdf1, logpdf2, rtol=1E-10)

def test_array_input(self):
# Check array of inputs has the same output as the separate entries.
num_rows = 4
Expand Down

0 comments on commit 1d7dd4a

Please sign in to comment.