# Matrix Normal distribution (`matrix_normal`) — Gaussian random matrices with separable covariance

The **matrix normal distribution** is the matrix-valued analogue of the multivariate normal.
It models a random matrix \(X \in \mathbb{R}^{n\times p}\) whose dependence structure **separates** into:

- a **row covariance** matrix \(U \in \mathbb{R}^{n\times n}\), and
- a **column covariance** matrix \(V \in \mathbb{R}^{p\times p}\).

Equivalently, if you stack the columns of \(X\) into a vector (`vec`), then
\(\mathrm{vec}(X)\) is multivariate normal with covariance \(V \otimes U\) (a Kronecker product).

**Goals**
- Understand what the matrix normal models and why the Kronecker structure is useful.
- Work with the PDF, moments, entropy, and key identities.
- Interpret parameters \(M, U, V\) and see how they shape samples.
- Sample from \(\mathrm{MN}(M, U, V)\) with a **NumPy-only** algorithm.
- Use `scipy.stats.matrix_normal` for evaluation and simulation, and fit \(U,V\) via an alternating MLE.

**Prerequisites**
- Multivariate normal distribution (quadratic forms, covariance).
- Basic linear algebra (trace, determinant, Cholesky).
- NumPy + plotting.


## Notebook roadmap

1. Title & Classification
2. Intuition & Motivation
3. Formal Definition
4. Moments & Properties
5. Parameter Interpretation
6. Derivations
7. Sampling & Simulation
8. Visualization
9. SciPy Integration
10. Statistical Use Cases
11. Pitfalls
12. Summary


In [None]:
import platform

import numpy as np
import scipy
from scipy import stats
from scipy.stats import chi2, multivariate_normal, norm

import plotly
import plotly.graph_objects as go
import os
import plotly.io as pio
from plotly.subplots import make_subplots

# Plotly notebook defaults (mirrors patterns used elsewhere in this repo)
pio.templates.default = "plotly_white"
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")

np.set_printoptions(precision=4, suppress=True)

# Reproducibility
rng = np.random.default_rng(0)

# Quick/slow toggle
FAST_RUN = True
N_SAMPLES = 15_000 if FAST_RUN else 120_000
N_EXPERIMENTS = 250 if FAST_RUN else 2_000

print("Python", platform.python_version())
print("NumPy", np.__version__)
print("SciPy", scipy.__version__)
print("Plotly", plotly.__version__)


In [None]:
def vec_f(A: np.ndarray) -> np.ndarray:
    """Column-stacking vectorization (Fortran order).

    With this convention the Kronecker covariance is `np.kron(V, U)`.
    """

    A = np.asarray(A)
    return A.reshape(-1, order="F")


def unvec_f(v: np.ndarray, shape: tuple[int, int]) -> np.ndarray:
    v = np.asarray(v)
    return v.reshape(shape, order="F")


def ar1_cov(n: int, rho: float, sigma2: float = 1.0) -> np.ndarray:
    """AR(1)-style covariance: sigma2 * rho^{|i-j|}. SPD for |rho| < 1."""

    if n <= 0:
        raise ValueError("n must be positive")
    if not (-1.0 < rho < 1.0):
        raise ValueError("rho must be in (-1, 1)")
    idx = np.arange(n)
    return sigma2 * (rho ** np.abs(idx[:, None] - idx[None, :]))


def chol_spd(A: np.ndarray, name: str = "A", jitter: float = 1e-12, max_tries: int = 6) -> np.ndarray:
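    """Cholesky factor of an SPD matrix, with a small diagonal jitter for robustness.

    Tries A + jitter * 10**k * I for k = 0, ..., max_tries - 1 and returns the first
    lower-triangular L that succeeds, so A ≈ L L^T up to that jitter.
    """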

    A = np.asarray(A, dtype=float)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"{name} must be a square matrix")

    A = 0.5 * (A + A.T)  # symmetrize to reduce numerical asymmetry
    eye = np.eye(A.shape[0])

    for k in range(max_tries):
        try:
            return np.linalg.cholesky(A + (jitter * (10**k)) * eye)
        except np.linalg.LinAlgError:
            pass

    raise ValueError(f"{name} must be symmetric positive definite (Cholesky failed)")


def validate_matrix_normal_params(mean: np.ndarray, rowcov: np.ndarray, colcov: np.ndarray) -> tuple[int, int]:
    mean = np.asarray(mean, dtype=float)
    if mean.ndim != 2:
        raise ValueError("mean must be a 2D matrix")

    n, p = mean.shape

    rowcov = np.asarray(rowcov, dtype=float)
    colcov = np.asarray(colcov, dtype=float)

    if rowcov.shape != (n, n):
        raise ValueError(f"rowcov must have shape ({n}, {n})")
    if colcov.shape != (p, p):
        raise ValueError(f"colcov must have shape ({p}, {p})")

    _ = chol_spd(rowcov, name="rowcov")
    _ = chol_spd(colcov, name="colcov")

    return n, p


def matrix_normal_rvs_numpy(
    mean: np.ndarray,
    rowcov: np.ndarray,
    colcov: np.ndarray,
    size: int = 1,
    rng: np.random.Generator | None = None,
) -> np.ndarray:
    """Sample from MN(mean, rowcov, colcov) using only NumPy.

    Algorithm:
    1) Factorize rowcov = L_u L_u^T and colcov = L_v L_v^T.
    2) Draw Z with i.i.d. N(0,1) entries.
    3) Return mean + L_u Z L_v^T.

    Returns:
    - if size == 1: array (n, p)
    - else: array (size, n, p)
    """

    if rng is None:
        rng = np.random.default_rng()

    mean = np.asarray(mean, dtype=float)
    rowcov = np.asarray(rowcov, dtype=float)
    colcov = np.asarray(colcov, dtype=float)
    n, p = validate_matrix_normal_params(mean, rowcov, colcov)

    L_u = chol_spd(rowcov, name="rowcov")
    L_v = chol_spd(colcov, name="colcov")

    Z = rng.standard_normal(size=(size, n, p))
    # Batched matmul: L_u and L_v.T broadcast over the leading sample axis
    out = mean + L_u @ Z @ L_v.T

    return out[0] if size == 1 else out


def matrix_normal_logpdf_numpy(
    x: np.ndarray,
    mean: np.ndarray,
    rowcov: np.ndarray,
    colcov: np.ndarray,
) -> np.ndarray:
    """Log-PDF of X ~ MN(mean, rowcov, colcov) using NumPy linear algebra.

    Supports:
    - x shape (n, p) -> returns scalar
    - x shape (m, n, p) -> returns length-m array
    """

    mean = np.asarray(mean, dtype=float)
    rowcov = np.asarray(rowcov, dtype=float)
    colcov = np.asarray(colcov, dtype=float)

    n, p = validate_matrix_normal_params(mean, rowcov, colcov)

    x = np.asarray(x, dtype=float)
    if x.shape[-2:] != (n, p):
        raise ValueError(f"x must end with shape ({n}, {p})")

    L_u = chol_spd(rowcov, name="rowcov")
    L_v = chol_spd(colcov, name="colcov")

    logdet_u = 2.0 * np.sum(np.log(np.diag(L_u)))
    logdet_v = 2.0 * np.sum(np.log(np.diag(L_v)))

    const = (
        -0.5 * n * p * np.log(2.0 * np.pi)
        -0.5 * p * logdet_u
        -0.5 * n * logdet_v
    )

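    def quad_form(E: np.ndarray) -> float:
        # || L_u^{-1} E L_v^{-T} ||_F^2 computed via two linear solves
        # (np.linalg.solve is a general solver; it does not exploit triangularity)
        A = np.linalg.solve(L_u, E)  # L_u^{-1} E, shape (n, p)
        B = np.linalg.solve(L_v, A.T)  # L_v^{-1} E^T L_u^{-T}, shape (p, n)
        return float(np.sum(B**2))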

    if x.ndim == 2:
        E = x - mean
        return const - 0.5 * quad_form(E)

    if x.ndim == 3:
        out = np.empty(x.shape[0], dtype=float)
        for i in range(x.shape[0]):
            E = x[i] - mean
            out[i] = const - 0.5 * quad_form(E)
        return out

    raise ValueError("x must be 2D or 3D")


## 1) Title & Classification

| Item | Value |
|---|---|
| Name | Matrix Normal (`matrix_normal`) |
| Type | **Continuous** |
| Support | \(X \in \mathbb{R}^{n\times p}\) (all real matrices of a fixed shape) |
| Parameters | mean matrix \(M \in \mathbb{R}^{n\times p}\), row covariance \(U \in \mathbb{S}_{++}^n\), column covariance \(V \in \mathbb{S}_{++}^p\) |
| Parameter space | \(\mathbb{S}_{++}^k\) denotes the set of \(k\times k\) symmetric positive definite matrices |

We write

\[
X \sim \mathrm{MN}_{n\times p}(M,\,U,\,V).
\]


## 2) Intuition & Motivation

### What this distribution models
The matrix normal is a distribution over **random matrices** that is Gaussian in every direction: any linear functional of \(X\), such as a single entry or \(\mathrm{tr}(T^\top X)\), is univariate normal.
The key modeling assumption is **separable covariance**:

- **Rows** are correlated according to \(U\).
- **Columns** are correlated according to \(V\).
- The covariance between entries factorizes:

  \[
  \mathrm{Cov}(X_{ij}, X_{k\ell}) = U_{ik}\,V_{j\ell}.
  \]

This Kronecker structure is useful when a full \((np)\times(np)\) covariance is too expensive to store or estimate: \(U\) and \(V\) together carry \(O(n^2+p^2)\) free parameters instead of \(O(n^2p^2)\).
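For a sense of scale, the sketch below counts free covariance parameters for an unstructured covariance of \(\mathrm{vec}(X)\) versus the separable \(V\otimes U\), subtracting one for the shared scale discussed in Section 5. The sizes \(n=50\), \(p=40\) are arbitrary illustrative values.

```python
def n_params_full(n: int, p: int) -> int:
    """Free parameters in an unstructured symmetric (np x np) covariance."""
    d = n * p
    return d * (d + 1) // 2


def n_params_separable(n: int, p: int) -> int:
    """Free parameters in V ⊗ U: two symmetric matrices minus one shared scale."""
    return n * (n + 1) // 2 + p * (p + 1) // 2 - 1


n, p = 50, 40  # illustrative sizes, not taken from any dataset
full = n_params_full(n, p)
sep = n_params_separable(n, p)
print(f"full: {full:,}  separable: {sep:,}  ratio: {full / sep:.0f}x")
```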

### Typical real-world use cases
- **Multivariate linear regression / MANOVA**: residuals in \(Y\in\mathbb{R}^{n\times p}\) often have correlations across samples (rows) and across responses (columns).
- **Spatiotemporal grids**: rows = time, columns = space (or vice versa) with separable dependence.
- **Images / patches**: covariance that factors into horizontal/vertical components.
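- **Gaussian-process models with Kronecker structure**: a GP with a separable product kernel \(k((s,t),(s',t')) = k_1(s,s')\,k_2(t,t')\), evaluated on the full grid \(\{s_i\}_{i=1}^n\times\{t_j\}_{j=1}^p\), yields a matrix-normal matrix of function values with \(U_{ik}=k_1(s_i,s_k)\) and \(V_{j\ell}=k_2(t_j,t_\ell)\).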

### Relations to other distributions
- If \(X \sim \mathrm{MN}(M,U,V)\), then

  \[
  \mathrm{vec}(X) \sim \mathcal{N}(\mathrm{vec}(M),\, V \otimes U).
  \]

- If \(p=1\) (a single column), matrix normal reduces to an \(n\)-dimensional multivariate normal.
- If \(U = I_n\) and \(V = I_p\), entries of \(X\) are i.i.d. \(\mathcal{N}(M_{ij}, 1)\).
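- In matrix-variate statistics, quadratic forms in matrix-normal samples lead to **Wishart** distributions: if \(X\sim\mathrm{MN}_{n\times p}(0, I_n, V)\), the rows of \(X\) are i.i.d. \(\mathcal{N}(0,V)\) and \(X^\top X \sim W_p(V, n)\).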
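The Wishart connection can be sketched by Monte Carlo, assuming the simplest case \(U = I_n\): the rows of \(X\) are then i.i.d. \(\mathcal{N}(0, V)\), so \(X^\top X\) is Wishart with \(n\) degrees of freedom and mean \(nV\). The matrix \(V\) below is an arbitrary SPD example.

```python
import numpy as np

rng = np.random.default_rng(1)
n, p, reps = 6, 3, 20_000

# Column covariance V (SPD); row covariance is the identity, so rows are i.i.d. N(0, V)
V = np.array([[1.0, 0.3, 0.1],
              [0.3, 0.8, -0.2],
              [0.1, -0.2, 0.5]])
L_v = np.linalg.cholesky(V)

# Draw X = Z L_v^T, i.e. MN(0, I_n, V); shape (reps, n, p)
Z = rng.standard_normal((reps, n, p))
X = Z @ L_v.T

# Scatter matrices S = X^T X should follow Wishart_p(V, n), whose mean is n V
S = np.einsum("rij,rik->rjk", X, X)
S_mean = S.mean(axis=0)
print(np.round(S_mean / n, 3))
print(V)
```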


## 3) Formal Definition

Let \(X \in \mathbb{R}^{n\times p}\). We write

\[
X \sim \mathrm{MN}_{n\times p}(M,\,U,\,V)
\]

with mean \(M\in\mathbb{R}^{n\times p}\), row covariance \(U\in\mathbb{S}_{++}^n\), and column covariance \(V\in\mathbb{S}_{++}^p\).

### PDF
The density with respect to Lebesgue measure on \(\mathbb{R}^{np}\) is

\[
 f(X\mid M,U,V)
= \frac{\exp\left(-\tfrac12\,\mathrm{tr}\left(U^{-1}(X-M)V^{-1}(X-M)^\top\right)\right)}
       {(2\pi)^{\tfrac{np}{2}}\,|U|^{\tfrac{p}{2}}\,|V|^{\tfrac{n}{2}}}.
\]
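As a check on the formula, the sketch below evaluates the log-density three ways for one arbitrary \(X\): the closed form above, an \(np\)-dimensional normal on \(\mathrm{vec}(X)\) with covariance \(V\otimes U\), and `scipy.stats.matrix_normal`. Explicit inverses are used only for readability; the parameter values are illustrative.

```python
import numpy as np
from scipy.stats import matrix_normal, multivariate_normal

rng = np.random.default_rng(2)
n, p = 3, 2
M = rng.normal(size=(n, p))
U = np.array([[2.0, 0.5, 0.0], [0.5, 1.0, 0.3], [0.0, 0.3, 1.5]])
V = np.array([[1.0, -0.4], [-0.4, 0.7]])
X = rng.normal(size=(n, p))


def vec(A):
    """Column-stacking vectorization."""
    return A.reshape(-1, order="F")


# Closed-form matrix-normal log-density (explicit inverses: readable, not stable)
E = X - M
quad = np.trace(np.linalg.inv(U) @ E @ np.linalg.inv(V) @ E.T)
logpdf_formula = (
    -0.5 * n * p * np.log(2 * np.pi)
    - 0.5 * p * np.linalg.slogdet(U)[1]
    - 0.5 * n * np.linalg.slogdet(V)[1]
    - 0.5 * quad
)

logpdf_vec = multivariate_normal(mean=vec(M), cov=np.kron(V, U)).logpdf(vec(X))
logpdf_scipy = matrix_normal(mean=M, rowcov=U, colcov=V).logpdf(X)
print(logpdf_formula, logpdf_vec, logpdf_scipy)
```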

A numerically stable way to compute the quadratic term is

\[
\mathrm{tr}\left(U^{-1}(X-M)V^{-1}(X-M)^\top\right)
= \lVert L_U^{-1}(X-M)L_V^{-\top} \rVert_F^2,
\]

where \(U=L_U L_U^\top\) and \(V=L_V L_V^\top\) are Cholesky factorizations.
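A quick numerical check of this identity, with \(E = X - M\) drawn at random and SPD matrices built as \(AA^\top + cI\) (an arbitrary construction). The two triangular solves use `scipy.linalg.solve_triangular`.

```python
import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(3)
n, p = 4, 3
# Random SPD matrices: A A^T + c I is SPD for c > 0
A = rng.normal(size=(n, n))
U = A @ A.T + n * np.eye(n)
B = rng.normal(size=(p, p))
V = B @ B.T + p * np.eye(p)
E = rng.normal(size=(n, p))  # plays the role of X - M

# Direct trace form (explicit inverses: fine for a check, avoid in production)
trace_form = np.trace(np.linalg.inv(U) @ E @ np.linalg.inv(V) @ E.T)

# Cholesky form: W = L_U^{-1} E L_V^{-T}, computed by two triangular solves
L_U = np.linalg.cholesky(U)
L_V = np.linalg.cholesky(V)
W1 = solve_triangular(L_U, E, lower=True)        # L_U^{-1} E
W = solve_triangular(L_V, W1.T, lower=True).T    # (L_V^{-1} W1^T)^T = W1 L_V^{-T}
frob_form = np.sum(W**2)

print(trace_form, frob_form)
```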

### CDF
The CDF is the joint CDF of all \(np\) entries, with inequalities taken entrywise:

\[
F(X) = \mathbb{P}(X_{11}\le x_{11},\,\dots,\,X_{np}\le x_{np}).
\]

In terms of vectorization,

\[
F(X) = \Phi_{np}\big(\mathrm{vec}(X);\,\mathrm{vec}(M),\,V\otimes U\big),
\]

where \(\Phi_{np}\) is the \(np\)-dimensional multivariate normal CDF. In general there is **no closed-form expression**;
numerical evaluation is feasible only for small \(np\).
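For a tiny case the vec formulation can be sanity-checked. Assuming diagonal \(U\) and \(V\) (an illustrative choice), \(V\otimes U\) is diagonal, the entries are independent, and the joint CDF must equal a product of univariate normal CDFs with variances \(U_{ii}V_{jj}\). SciPy's `multivariate_normal.cdf` integrates numerically, so agreement is only to a few decimal places.

```python
import numpy as np
from scipy.stats import multivariate_normal, norm

M = np.array([[0.0, 0.5], [-0.3, 0.2]])
U = np.diag([1.0, 2.0])
V = np.diag([0.5, 1.5])
x = np.array([[0.4, 0.1], [0.0, 1.0]])


def vec(A):
    """Column-stacking vectorization."""
    return A.reshape(-1, order="F")


# Joint CDF via the np-dimensional MVN on vec(X) with covariance V ⊗ U
cdf_vec = multivariate_normal(mean=vec(M), cov=np.kron(V, U)).cdf(vec(x))

# Diagonal U, V => independent entries with Var(X_ij) = U_ii V_jj
sd = np.sqrt(np.outer(np.diag(U), np.diag(V)))
cdf_prod = np.prod(norm.cdf(x, loc=M, scale=sd))
print(cdf_vec, cdf_prod)
```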

### LaTeX notation
Common shorthands:

\[
X \sim \mathcal{MN}(M,U,V)\quad\text{or}\quad X \sim \mathrm{MN}_{n\times p}(M,U,V).
\]


In [None]:
# Quick numerical check of logpdf against SciPy
n, p = 3, 2
M = rng.normal(size=(n, p))
U = ar1_cov(n, rho=0.4, sigma2=1.5)
V = ar1_cov(p, rho=-0.2, sigma2=0.7)

X = matrix_normal_rvs_numpy(M, U, V, size=1, rng=rng)

rv = stats.matrix_normal(mean=M, rowcov=U, colcov=V)
logpdf_scipy = rv.logpdf(X)
logpdf_numpy = matrix_normal_logpdf_numpy(X, M, U, V)

float(logpdf_scipy), float(logpdf_numpy), float(logpdf_numpy - logpdf_scipy)


## 4) Moments & Properties

A convenient starting point is the vectorized representation:

\[
\mathrm{vec}(X) \sim \mathcal{N}(\mathrm{vec}(M),\, \Sigma),\quad \Sigma = V\otimes U.
\]

### Mean and (co)variance
- **Mean**: \(\mathbb{E}[X]=M\).
- **Covariance of vectorization**: \(\mathrm{Cov}(\mathrm{vec}(X)) = V\otimes U\).
- **Entrywise covariance**:

  \[
  \mathrm{Cov}(X_{ij}, X_{k\ell}) = U_{ik}\,V_{j\ell}.
  \]

  In particular, \(\mathrm{Var}(X_{ij}) = U_{ii}V_{jj}\).

### Skewness and kurtosis
Because \(\mathrm{vec}(X)\) is multivariate normal:
- every centered **third moment** is 0 (zero skewness for any scalar linear functional),
- **fourth moments** follow Isserlis' (Wick's) theorem.

For any scalar projection \(S = \langle T, X\rangle = \mathrm{tr}(T^\top X)\), we have
\(S\sim\mathcal{N}(\mu_S, \sigma_S^2)\), so

- skewness \(=0\)
- kurtosis \(=3\) (excess kurtosis \(=0\)).
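A Monte Carlo sketch of these claims for one arbitrary projection \(T\): the sample skewness and excess kurtosis of \(S\) should be near 0, up to sampling error of roughly \(\sqrt{6/N}\) and \(\sqrt{24/N}\). The covariances are illustrative choices.

```python
import numpy as np
from scipy.stats import kurtosis, skew

rng = np.random.default_rng(4)
n, p, N = 3, 4, 100_000
M = rng.normal(size=(n, p))
U = np.array([[1.0, 0.5, 0.2], [0.5, 1.0, 0.5], [0.2, 0.5, 1.0]])
V = 0.8 * np.eye(p) + 0.2  # unit diagonal, 0.2 off-diagonal (SPD)
T = rng.normal(size=(n, p))  # arbitrary projection direction

L_U, L_V = np.linalg.cholesky(U), np.linalg.cholesky(V)
# Batched X = M + L_U Z L_V^T, shape (N, n, p)
X = M + L_U @ rng.standard_normal((N, n, p)) @ L_V.T

S = np.einsum("ij,nij->n", T, X)  # S = tr(T^T X) for each draw
sample_skew = skew(S)
sample_excess_kurt = kurtosis(S)  # Fisher definition: 0 for a normal
print(sample_skew, sample_excess_kurt)
```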

### MGF / characteristic function
For \(S = \mathrm{tr}(T^\top X)\):

\[
\mathbb{E}[e^{tS}] = \exp\left(t\,\mathrm{tr}(T^\top M) + \tfrac12 t^2\,\mathrm{tr}(U T V T^\top)\right).
\]

Equivalently,

\[
M_X(T) = \mathbb{E}[\exp(\mathrm{tr}(T^\top X))]
       = \exp\left(\mathrm{tr}(T^\top M) + \tfrac12\,\mathrm{tr}(U T V T^\top)\right).
\]

The characteristic function replaces \(T\) with \(iT\).
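The MGF formula can be checked by Monte Carlo. The sketch below scales \(T\) down (an arbitrary choice) so that \(e^{S}\) has modest variance and the sample average converges quickly; \(M\), \(U\), \(V\) are illustrative.

```python
import numpy as np

rng = np.random.default_rng(5)
n, p, N = 3, 2, 400_000
M = np.array([[0.1, -0.2], [0.0, 0.3], [0.2, 0.1]])
U = np.array([[1.0, 0.4, 0.0], [0.4, 1.0, 0.4], [0.0, 0.4, 1.0]])
V = np.array([[0.6, 0.1], [0.1, 0.4]])
T = 0.2 * rng.normal(size=(n, p))  # small T keeps exp(tr(T^T X)) well-behaved

# Closed form: exp(tr(T^T M) + 0.5 tr(U T V T^T))
mgf_theory = np.exp(np.sum(T * M) + 0.5 * np.trace(U @ T @ V @ T.T))

# Monte Carlo estimate of E[exp(tr(T^T X))]
L_U, L_V = np.linalg.cholesky(U), np.linalg.cholesky(V)
X = M + L_U @ rng.standard_normal((N, n, p)) @ L_V.T
mgf_mc = np.exp(np.einsum("ij,nij->n", T, X)).mean()
print(mgf_theory, mgf_mc)
```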

### Entropy
The differential entropy is that of an \(np\)-dimensional normal with covariance \(V\otimes U\):

\[
H(X) = \tfrac12\log\big((2\pi e)^{np}\,|V\otimes U|\big)
     = \tfrac{np}{2}(1+\log 2\pi) + \tfrac{p}{2}\log|U| + \tfrac{n}{2}\log|V|,
\]

using \(|V\otimes U| = |V|^n |U|^p\).
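A deterministic check of the entropy formula and the determinant identity against SciPy's entropy for the equivalent \(np\)-dimensional normal. The matrices are arbitrary SPD examples.

```python
import numpy as np
from scipy.stats import multivariate_normal

n, p = 3, 2
U = np.array([[2.0, 0.3, 0.0], [0.3, 1.0, 0.2], [0.0, 0.2, 0.5]])
V = np.array([[1.0, -0.5], [-0.5, 2.0]])

logdet_U = np.linalg.slogdet(U)[1]
logdet_V = np.linalg.slogdet(V)[1]

# Closed form from the text
H_formula = 0.5 * n * p * (1 + np.log(2 * np.pi)) + 0.5 * p * logdet_U + 0.5 * n * logdet_V

# Entropy of the np-dimensional normal on vec(X)
H_vec = multivariate_normal(cov=np.kron(V, U)).entropy()

# Determinant identity |V ⊗ U| = |V|^n |U|^p, checked on the log scale
logdet_kron = np.linalg.slogdet(np.kron(V, U))[1]
print(H_formula, H_vec, logdet_kron, n * logdet_V + p * logdet_U)
```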


In [None]:
# Empirical check: mean and covariance structure
n, p = 4, 3
M = np.arange(n * p, dtype=float).reshape(n, p) / 10.0
U = ar1_cov(n, rho=0.6, sigma2=2.0)
V = ar1_cov(p, rho=-0.3, sigma2=1.2)

X_samps = matrix_normal_rvs_numpy(M, U, V, size=N_SAMPLES, rng=rng)  # (N, n, p)

# Mean
M_hat = X_samps.mean(axis=0)

# Covariance of vec (column-stacking)
X_vec = np.stack([vec_f(X_samps[i]) for i in range(X_samps.shape[0])], axis=0)
Sigma_hat = np.cov(X_vec, rowvar=False)
Sigma_theory = np.kron(V, U)  # consistent with vec_f

mean_err = np.linalg.norm(M_hat - M) / np.linalg.norm(M)
rel_cov_err = np.linalg.norm(Sigma_hat - Sigma_theory) / np.linalg.norm(Sigma_theory)

mean_err, rel_cov_err


## 5) Parameter Interpretation

### Mean matrix \(M\)
- \(M_{ij}\) is the average value of entry \((i,j)\).
- Changing \(M\) shifts the distribution without affecting dependence.

### Row covariance \(U\)
- Controls dependence **between rows** (across all columns).
- For a fixed column \(j\), the column vector \(X_{:,j}\) satisfies
  \(X_{:,j} \sim \mathcal{N}(M_{:,j},\, V_{jj} U)\).

### Column covariance \(V\)
- Controls dependence **between columns** (across all rows).
- For a fixed row \(i\), the row vector \(X_{i,:}\) satisfies
  \(X_{i,:} \sim \mathcal{N}(M_{i,:},\, U_{ii} V)\).

### Shape changes
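- Increasing \(U_{ii}\) inflates the variance of every entry in row \(i\), since \(\mathrm{Var}(X_{ij}) = U_{ii}V_{jj}\).
- Increasing \(V_{jj}\) inflates the variance of every entry in column \(j\).
- Increasing the off-diagonal correlation \(U_{ik}/\sqrt{U_{ii}U_{kk}}\) makes rows \(i\) and \(k\) “move together”; this correlation is the same in every column.
- Increasing the off-diagonal correlation \(V_{j\ell}/\sqrt{V_{jj}V_{\ell\ell}}\) makes columns \(j\) and \(\ell\) “move together”, identically in every row.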

**Identifiability note**: \(U\) and \(V\) are not separately identifiable from \(V\otimes U\). For any \(c>0\),

\[
\mathrm{MN}(M,\,U,\,V) \equiv \mathrm{MN}(M,\,cU,\,V/c).
\]

Many fitting procedures impose a constraint such as \(\mathrm{tr}(V)=p\) (or \(|V|=1\)) to fix this scale.
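The non-identifiability is easy to see numerically: rescaling \((U, V) \to (cU, V/c)\) leaves both the Kronecker covariance and the log-density unchanged. The last lines sketch a trace normalization of the kind mentioned above; the names `U_norm` and `V_norm` are illustrative.

```python
import numpy as np
from scipy.stats import matrix_normal

rng = np.random.default_rng(6)
n, p, c = 3, 2, 4.0
M = np.zeros((n, p))
U = np.array([[1.0, 0.2, 0.0], [0.2, 1.0, 0.2], [0.0, 0.2, 1.0]])
V = np.array([[2.0, 0.5], [0.5, 1.0]])
X = rng.normal(size=(n, p))

# Same Kronecker covariance ...
same_cov = np.allclose(np.kron(V, U), np.kron(V / c, c * U))

# ... hence identical densities
lp1 = matrix_normal(mean=M, rowcov=U, colcov=V).logpdf(X)
lp2 = matrix_normal(mean=M, rowcov=c * U, colcov=V / c).logpdf(X)

# Trace normalization: rescale so that tr(V) = p, compensating in U
s = np.trace(V) / p
U_norm, V_norm = U * s, V / s
print(same_cov, lp1, lp2, np.trace(V_norm))
```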


In [None]:
# Visual intuition: one draw under different row/column correlations
n, p = 8, 8
M = np.zeros((n, p))

scenarios = [
    {"rho_u": 0.0, "rho_v": 0.0, "title": "independent"},
    {"rho_u": 0.6, "rho_v": 0.0, "title": "row-correlated"},
    {"rho_u": 0.6, "rho_v": 0.7, "title": "row+col correlated"},
]

fig = make_subplots(rows=1, cols=3, subplot_titles=[s["title"] for s in scenarios])

for j, s in enumerate(scenarios, start=1):
    U = ar1_cov(n, rho=s["rho_u"], sigma2=1.0)
    V = ar1_cov(p, rho=s["rho_v"], sigma2=1.0)
    X = matrix_normal_rvs_numpy(M, U, V, size=1, rng=rng)
    fig.add_trace(go.Heatmap(z=X, coloraxis="coloraxis"), row=1, col=j)

fig.update_layout(
    title="One draw from MN(0, U, V) under different correlations",
    coloraxis={"colorscale": "RdBu", "cmin": -3, "cmax": 3},
    height=350,
)
fig.show()


## 6) Derivations

We sketch core derivations using vectorization identities.

### Expectation
Using \(\mathrm{vec}(X) \sim \mathcal{N}(\mathrm{vec}(M), V\otimes U)\), the mean of a multivariate normal is its location:

\[
\mathbb{E}[\mathrm{vec}(X)] = \mathrm{vec}(M)\quad\Rightarrow\quad \mathbb{E}[X]=M.
\]

### Variance / covariance
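Entry \(X_{ij}\) sits at position \((j-1)n+i\) of the column-stacked \(\mathrm{vec}(X)\) (1-based indices), and the Kronecker product satisfies \((V\otimes U)_{(j-1)n+i,\,(\ell-1)n+k} = V_{j\ell}\,U_{ik}\). Reading off the entries of \(\Sigma = V\otimes U\) therefore gives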

\[
\mathrm{Cov}(X_{ij}, X_{k\ell}) = U_{ik} V_{j\ell}.
\]
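The index mapping can be checked in NumPy. With 0-based indices, \(X_{ij}\) is element `j*n + i` of the column-stacked vector, so reshaping `np.kron(V, U)` to shape `(p, n, p, n)` exposes the four indices directly. \(U\) and \(V\) are random SPD matrices for illustration.

```python
import numpy as np

rng = np.random.default_rng(7)
n, p = 3, 4
A = rng.normal(size=(n, n))
U = A @ A.T + np.eye(n)
B = rng.normal(size=(p, p))
V = B @ B.T + np.eye(p)

Sigma = np.kron(V, U)  # covariance of column-stacked vec(X), shape (n*p, n*p)

# Row index j*n + i and column index l*n + k unpack to [j, i, l, k]
Sigma4 = Sigma.reshape(p, n, p, n)

# Cov(X_ij, X_kl) = U_ik V_jl, built directly as a 4-index array [i, j, k, l]
cov4 = np.einsum("ik,jl->ijkl", U, V)

print(np.allclose(Sigma4.transpose(1, 0, 3, 2), cov4))
```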

A useful special case is the scalar projection \(S=\mathrm{tr}(T^\top X)=\mathrm{vec}(T)^\top\mathrm{vec}(X)\):

\[
\mathrm{Var}(S)
= \mathrm{vec}(T)^\top (V\otimes U)\,\mathrm{vec}(T)
= \mathrm{tr}(U T V T^\top).
\]
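Both sides of this identity can be compared directly; the trace form needs only small matrix products, while the left side forms the full \((np)\times(np)\) Kronecker product. The random \(T\), \(U\), \(V\) are illustrative.

```python
import numpy as np

rng = np.random.default_rng(8)
n, p = 4, 3
A = rng.normal(size=(n, n))
U = A @ A.T + np.eye(n)
B = rng.normal(size=(p, p))
V = B @ B.T + np.eye(p)
T = rng.normal(size=(n, p))

t = T.reshape(-1, order="F")             # vec(T), column-stacking
var_kron = t @ np.kron(V, U) @ t         # vec(T)^T (V ⊗ U) vec(T)
var_trace = np.trace(U @ T @ V @ T.T)    # tr(U T V T^T), no Kronecker product needed
print(var_kron, var_trace)
```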

### Likelihood
For one observation \(X\), the log-likelihood is

\[
\ell(M,U,V\mid X)
= -\tfrac{np}{2}\log(2\pi) - \tfrac{p}{2}\log|U| - \tfrac{n}{2}\log|V|
  -\tfrac12\,\mathrm{tr}\left(U^{-1}(X-M)V^{-1}(X-M)^\top\right).
\]

**MLE notes**
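- Given \(m\) i.i.d. observations \(X_1,\dots,X_m\), the MLE of \(M\) is the sample mean \(\bar X\) for any fixed \(U,V\).
- Jointly maximizing over \((U,V)\) has no closed form. A common approach is the **flip-flop** (alternating) algorithm: with \(E_r = X_r-\bar X\), update \(U \leftarrow \tfrac{1}{mp}\sum_r E_r V^{-1}E_r^\top\), then \(V \leftarrow \tfrac{1}{mn}\sum_r E_r^\top U^{-1}E_r\), and repeat.
- Because of the scale non-identifiability \((cU, V/c)\), one typically normalizes after each iteration (e.g. rescaling so that \(\mathrm{tr}(V)=p\)).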
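To illustrate that the sample mean is the MLE of \(M\), the sketch below holds \(U, V\) fixed at arbitrary SPD values and compares the total log-likelihood at \(\bar X\) with its value at random perturbations of \(\bar X\). Because \(\sum_r (X_r-\bar X)=0\), the cross term vanishes and any perturbation \(\Delta\) lowers the log-likelihood by exactly \(\tfrac m2\,\mathrm{tr}(U^{-1}\Delta V^{-1}\Delta^\top)\).

```python
import numpy as np
from scipy.stats import matrix_normal

rng = np.random.default_rng(9)
m, n, p = 20, 3, 2
U = np.array([[1.0, 0.3, 0.0], [0.3, 1.0, 0.3], [0.0, 0.3, 1.0]])
V = np.array([[1.0, 0.4], [0.4, 2.0]])
X = matrix_normal(mean=np.zeros((n, p)), rowcov=U, colcov=V).rvs(size=m, random_state=rng)


def total_loglik(M):
    """Sum of log-densities over the m observations, with U, V held fixed."""
    return matrix_normal(mean=M, rowcov=U, colcov=V).logpdf(X).sum()


M_bar = X.mean(axis=0)
ll_bar = total_loglik(M_bar)

# Every perturbation of the sample mean lowers the likelihood
ll_perturbed = [total_loglik(M_bar + 0.1 * rng.normal(size=(n, p))) for _ in range(50)]
print(ll_bar, max(ll_perturbed))
```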


In [None]:
def flip_flop_mle(
    X: np.ndarray,
    max_iter: int = 50,
    tol: float = 1e-6,
    normalize: str | None = "trace_v",
) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
    """Estimate (M, U, V) for matrix-normal samples using the flip-flop algorithm."""

    X = np.asarray(X, dtype=float)
    if X.ndim != 3:
        raise ValueError("X must have shape (m, n, p)")

    m, n, p = X.shape

    M_hat = X.mean(axis=0)
    E = X - M_hat

    U = np.eye(n)
    V = np.eye(p)

    history = {"loglik": []}
    prev_ll = None

    for _ in range(max_iter):
        # Update U given V
        L_v = chol_spd(V, name="V")
        U_new = np.zeros((n, n), dtype=float)
        for r in range(m):
            B = np.linalg.solve(L_v, E[r].T)  # (p, n)
            U_new += B.T @ B
        U_new /= (m * p)
        U_new = 0.5 * (U_new + U_new.T)

        # Update V given U
        L_u = chol_spd(U_new, name="U")
        V_new = np.zeros((p, p), dtype=float)
        for r in range(m):
            A = np.linalg.solve(L_u, E[r])
            V_new += A.T @ A
        V_new /= (m * n)
        V_new = 0.5 * (V_new + V_new.T)

        if normalize == "trace_v":
            c = float(np.trace(V_new) / p)
            if c <= 0:
                raise ValueError("normalization failed (non-positive trace)")
            V_new = V_new / c
            U_new = U_new * c

        ll = float(matrix_normal_logpdf_numpy(X, M_hat, U_new, V_new).sum())
        history["loglik"].append(ll)

        if prev_ll is not None:
            rel = abs(ll - prev_ll) / (abs(prev_ll) + 1.0)
            if rel < tol:
                U, V = U_new, V_new
                break

        prev_ll = ll
        U, V = U_new, V_new

    return M_hat, U, V, history


# Fit demo on synthetic data
n, p = 4, 3
M_true = rng.normal(size=(n, p))
U_true = ar1_cov(n, rho=0.5, sigma2=1.8)
V_true = ar1_cov(p, rho=-0.4, sigma2=0.9)

X = matrix_normal_rvs_numpy(M_true, U_true, V_true, size=800 if FAST_RUN else 4_000, rng=rng)
M_hat, U_hat, V_hat, info = flip_flop_mle(X, max_iter=60, tol=1e-7)

Sigma_true = np.kron(V_true, U_true)
Sigma_hat = np.kron(V_hat, U_hat)

mean_err = np.linalg.norm(M_hat - M_true) / np.linalg.norm(M_true)
rel_sigma_err = np.linalg.norm(Sigma_hat - Sigma_true) / np.linalg.norm(Sigma_true)

mean_err, rel_sigma_err, len(info["loglik"])


## 7) Sampling & Simulation

A clean NumPy-only sampling recipe follows directly from the Kronecker structure.

If
\(U = L_U L_U^\top\) and \(V = L_V L_V^\top\) are Cholesky factorizations and
\(Z\in\mathbb{R}^{n\times p}\) has i.i.d. \(\mathcal{N}(0,1)\) entries, then

\[
X = M + L_U\,Z\,L_V^\top
\]

satisfies
\(X\sim\mathrm{MN}(M,U,V)\).

**Why it works (sketch)**
Using the identity \(\mathrm{vec}(AZB^\top) = (B\otimes A)\mathrm{vec}(Z)\), we get

\[
\mathrm{vec}(X-M) = (L_V\otimes L_U)\,\mathrm{vec}(Z),
\]

so the covariance is
\((L_VL_V^\top)\otimes(L_UL_U^\top) = V\otimes U\).
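The identity \(\mathrm{vec}(AZB^\top) = (B\otimes A)\,\mathrm{vec}(Z)\) used in the sketch can be checked numerically for arbitrary square \(A\) and \(B\). The left side also shows why the sampler never needs \(V\otimes U\): two small matrix products replace an \((np)\times(np)\) matrix-vector product.

```python
import numpy as np

rng = np.random.default_rng(10)
n, p = 4, 3
A = rng.normal(size=(n, n))  # plays the role of L_U (any n x n matrix works)
B = rng.normal(size=(p, p))  # plays the role of L_V
Z = rng.normal(size=(n, p))


def vec(X):
    """Column-stacking vectorization."""
    return X.reshape(-1, order="F")


lhs = vec(A @ Z @ B.T)        # two small matmuls, no Kronecker product formed
rhs = np.kron(B, A) @ vec(Z)  # explicit (np x np) Kronecker product, for checking only
print(np.allclose(lhs, rhs))
```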


In [None]:
# Basic sampling sanity check
n, p = 5, 4
M = np.zeros((n, p))
U = ar1_cov(n, rho=0.3, sigma2=1.0)
V = ar1_cov(p, rho=0.7, sigma2=0.5)

X_samps = matrix_normal_rvs_numpy(M, U, V, size=3, rng=rng)
X_samps.shape, X_samps[0].mean(), X_samps[0].std()


## 8) Visualization

Matrix-normal densities live on \(\mathbb{R}^{np}\), so direct “PDF over \(X\)” plots are rarely helpful.
Instead, we visualize:

- **Univariate marginals** (e.g. a single entry \(X_{ij}\))
- **Scalar projections** \(S=\mathrm{tr}(T^\top X)\)
- **Bivariate marginals** of a few entries to show correlation

These are all normal distributions, because entries, subvectors, and linear functionals of a Gaussian vector are themselves Gaussian.


In [None]:
# A) Univariate marginal PDF + CDF for one entry
n, p = 4, 4
M = np.zeros((n, p))
U = ar1_cov(n, rho=0.65, sigma2=1.2)
V = ar1_cov(p, rho=-0.3, sigma2=0.8)

i, j = 1, 2
mu_ij = M[i, j]
var_ij = U[i, i] * V[j, j]

xs = np.linspace(mu_ij - 4 * np.sqrt(var_ij), mu_ij + 4 * np.sqrt(var_ij), 600)

pdf = norm(loc=mu_ij, scale=np.sqrt(var_ij)).pdf(xs)
cdf = norm(loc=mu_ij, scale=np.sqrt(var_ij)).cdf(xs)

fig = make_subplots(rows=1, cols=2, subplot_titles=["Marginal PDF", "Marginal CDF"])
fig.add_trace(go.Scatter(x=xs, y=pdf, name="theory pdf"), row=1, col=1)
fig.add_trace(go.Scatter(x=xs, y=cdf, name="theory cdf"), row=1, col=2)

fig.update_xaxes(title_text="x", row=1, col=1)
fig.update_yaxes(title_text="pdf", row=1, col=1)
fig.update_xaxes(title_text="x", row=1, col=2)
fig.update_yaxes(title_text="cdf", row=1, col=2)

fig.update_layout(title=f"Entry X[{i},{j}] ~ N({mu_ij:.2f}, {var_ij:.2f})", height=320)
fig.show()


In [None]:
# B) Monte Carlo vs theory for a scalar projection S = tr(T^T X)

n, p = 6, 5
M = rng.normal(size=(n, p)) * 0.2
U = ar1_cov(n, rho=0.5, sigma2=1.0)
V = ar1_cov(p, rho=0.3, sigma2=0.8)

T = rng.normal(size=(n, p))

mu_S = float(np.sum(T * M))  # tr(T^T M)
var_S = float(np.trace(U @ T @ V @ T.T))

X_samps = matrix_normal_rvs_numpy(M, U, V, size=N_SAMPLES, rng=rng)
S_samps = np.einsum("ij,nij->n", T, X_samps)

xs = np.linspace(mu_S - 4 * np.sqrt(var_S), mu_S + 4 * np.sqrt(var_S), 600)

fig = make_subplots(rows=1, cols=2, subplot_titles=["PDF (hist + theory)", "CDF (empirical + theory)"])

fig.add_trace(
    go.Histogram(x=S_samps, nbinsx=70, histnorm="probability density", name="MC"),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(x=xs, y=norm(mu_S, np.sqrt(var_S)).pdf(xs), name="theory"),
    row=1,
    col=1,
)

s_sorted = np.sort(S_samps)
emp_cdf = np.linspace(1.0 / len(s_sorted), 1.0, len(s_sorted))
fig.add_trace(go.Scatter(x=s_sorted, y=emp_cdf, mode="lines", name="empirical"), row=1, col=2)
fig.add_trace(go.Scatter(x=xs, y=norm(mu_S, np.sqrt(var_S)).cdf(xs), name="theory"), row=1, col=2)

fig.update_xaxes(title_text="s", row=1, col=1)
fig.update_yaxes(title_text="density", row=1, col=1)
fig.update_xaxes(title_text="s", row=1, col=2)
fig.update_yaxes(title_text="cdf", row=1, col=2)

fig.update_layout(
    title=f"S=⟨T,X⟩ ~ N({mu_S:.3f}, {var_S:.3f}) (verified by Monte Carlo)",
    height=350,
)
fig.show()


In [None]:
# C) Bivariate marginal of two entries to show correlation

n, p = 4, 4
M = np.zeros((n, p))
U = ar1_cov(n, rho=0.8, sigma2=1.0)
V = ar1_cov(p, rho=0.0, sigma2=1.0)

(i1, j1), (i2, j2) = (0, 1), (2, 1)

mu = np.array([M[i1, j1], M[i2, j2]])
Sigma2 = np.array(
    [
        [U[i1, i1] * V[j1, j1], U[i1, i2] * V[j1, j2]],
        [U[i2, i1] * V[j2, j1], U[i2, i2] * V[j2, j2]],
    ]
)

X_samps = matrix_normal_rvs_numpy(M, U, V, size=N_SAMPLES, rng=rng)
Y = np.stack([X_samps[:, i1, j1], X_samps[:, i2, j2]], axis=1)

xg = np.linspace(Y[:, 0].min(), Y[:, 0].max(), 120)
yg = np.linspace(Y[:, 1].min(), Y[:, 1].max(), 120)
XG, YG = np.meshgrid(xg, yg)
pts = np.stack([XG.ravel(), YG.ravel()], axis=1)
Z = multivariate_normal(mean=mu, cov=Sigma2).pdf(pts).reshape(XG.shape)

fig = go.Figure()
fig.add_trace(
    go.Histogram2dContour(
        x=Y[:, 0],
        y=Y[:, 1],
        nbinsx=40,
        nbinsy=40,
        contours=dict(coloring="fill", showlines=False),
        colorscale="Blues",
        name="MC density",
    )
)
fig.add_trace(
    go.Contour(
        x=xg,
        y=yg,
        z=Z,
        line=dict(color="black"),
        contours=dict(showlabels=False),
        showscale=False,
        name="theory pdf",
    )
)

fig.update_layout(
    title=f"Bivariate marginal: (X[{i1},{j1}], X[{i2},{j2}])",
    xaxis_title=f"X[{i1},{j1}]",
    yaxis_title=f"X[{i2},{j2}]",
    height=380,
)
fig.show()


## 9) SciPy Integration (`scipy.stats.matrix_normal`)

SciPy provides a matrix normal distribution:

```python
rv = scipy.stats.matrix_normal(mean=M, rowcov=U, colcov=V)
```

Available functionality (SciPy 1.15):
- `pdf`, `logpdf`
- `rvs`
- `mean` (an attribute of the frozen distribution, as used below, rather than a method)

**Notably missing**
- `cdf` (full multivariate CDF is expensive)
- `fit` (Kronecker-structured MLE needs an iterative procedure)

Workarounds:
- For CDF-like quantities, use **univariate/bivariate marginals**, or for very small \(np\) use `scipy.stats.multivariate_normal.cdf` on `vec(X)`.
- For fitting, use a flip-flop / alternating MLE (implemented above).


In [None]:
# SciPy usage: pdf / logpdf / rvs
n, p = 3, 2
M = np.zeros((n, p))
U = ar1_cov(n, rho=0.4, sigma2=1.0)
V = ar1_cov(p, rho=-0.2, sigma2=0.6)

rv = stats.matrix_normal(mean=M, rowcov=U, colcov=V)
X = rv.rvs(size=5, random_state=rng)

print("rvs shape:", X.shape)
print("mean shape:", rv.mean.shape)
print("logpdf[0] scipy:", rv.logpdf(X[0]))
print("logpdf[0] numpy:", matrix_normal_logpdf_numpy(X[0], M, U, V))

# CDF workaround for small np: vectorize and use multivariate_normal.cdf
M_small = np.array([[0.0, 0.0]])  # 1x2
U_small = np.array([[1.0]])
V_small = np.array([[1.0, 0.5], [0.5, 1.0]])

x_small = np.array([[0.2, -0.1]])
Sigma_small = np.kron(V_small, U_small)

cdf_small = multivariate_normal(mean=vec_f(M_small), cov=Sigma_small).cdf(vec_f(x_small))
print("joint CDF for 1x2 example:", float(cdf_small))


## 10) Statistical Use Cases

### Hypothesis testing
With known \(U,V\), testing a mean matrix is a multivariate-normal problem.
Under \(H_0: M=M_0\), for one observation

\[
Q = \mathrm{tr}\left(U^{-1}(X-M_0)V^{-1}(X-M_0)^\top\right)
\]

is chi-square: \(Q \sim \chi^2_{np}\). For \(m\) i.i.d. observations,
\(\sum_{r=1}^m Q_r \sim \chi^2_{mnp}\).

### Bayesian modeling
Matrix normal is a **conjugate** building block in multivariate regression.
A classic model:

\[
Y \mid B,\Sigma \sim \mathrm{MN}(X B,\, I_n,\, \Sigma)
\]

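with a matrix-normal prior on the coefficients \(B\mid\Sigma\), often paired with an inverse-Wishart prior on \(\Sigma\). Here \(X\) denotes a fixed \(n\times q\) design matrix, not the random matrix of earlier sections.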
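The conjugacy can be sketched concretely, assuming the common prior \(B\mid\Sigma \sim \mathrm{MN}(B_0,\Lambda_0^{-1},\Sigma)\) with a \(q\times q\) prior precision \(\Lambda_0\). Completing the square in \(B\) gives the posterior \(\mathrm{MN}(B_n,\Lambda_n^{-1},\Sigma)\) with \(\Lambda_n = X^\top X + \Lambda_0\) and \(B_n = \Lambda_n^{-1}(X^\top Y + \Lambda_0 B_0)\). The helper `mn_regression_posterior` and the design-matrix name `Xd` are hypothetical, and the data are simulated.

```python
import numpy as np


def mn_regression_posterior(Xd, Y, B0, Lambda0):
    """Hypothetical helper: conjugate update for Y ~ MN(Xd B, I_n, Sigma),
    B | Sigma ~ MN(B0, Lambda0^{-1}, Sigma).

    Returns (B_n, Lambda_n); the posterior is MN(B_n, Lambda_n^{-1}, Sigma).
    Sigma drops out of B_n and Lambda_n, so it is not an argument.
    """
    Lambda_n = Xd.T @ Xd + Lambda0
    B_n = np.linalg.solve(Lambda_n, Xd.T @ Y + Lambda0 @ B0)
    return B_n, Lambda_n


rng = np.random.default_rng(11)
n, q, p = 200, 3, 2
Xd = rng.normal(size=(n, q))
B_true = np.array([[1.0, -0.5], [0.0, 2.0], [0.5, 0.5]])
Sigma = np.array([[1.0, 0.3], [0.3, 0.5]])
# Rows of the noise are i.i.d. N(0, Sigma), i.e. noise ~ MN(0, I_n, Sigma)
Y = Xd @ B_true + rng.standard_normal((n, p)) @ np.linalg.cholesky(Sigma).T

B0 = np.zeros((q, p))
# Nearly flat prior: the posterior mean should match ordinary least squares
B_flat, _ = mn_regression_posterior(Xd, Y, B0, 1e-8 * np.eye(q))
B_ols = np.linalg.lstsq(Xd, Y, rcond=None)[0]
# Strong prior centered at zero: the posterior mean shrinks toward B0
B_strong, _ = mn_regression_posterior(Xd, Y, B0, 1e4 * np.eye(q))
print(np.round(B_flat, 3))
print(np.linalg.norm(B_strong), np.linalg.norm(B_flat))
```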

### Generative modeling
The matrix normal provides a lightweight way to generate correlated random matrices
(e.g. noise images or fields) while keeping computations cheap via the Kronecker structure.


In [None]:
# Example: chi-square test for H0: M = 0 with known (U, V)

n, p = 3, 4
m = 10
M0 = np.zeros((n, p))
U = ar1_cov(n, rho=0.4, sigma2=1.0)
V = ar1_cov(p, rho=0.2, sigma2=0.7)

L_u = chol_spd(U, name="U")
L_v = chol_spd(V, name="V")

alpha = 0.05
crit = chi2.ppf(1 - alpha, df=m * n * p)

rejections = 0
for _ in range(N_EXPERIMENTS):
    X = matrix_normal_rvs_numpy(M0, U, V, size=m, rng=rng)

    Q = 0.0
    for r in range(m):
        A = np.linalg.solve(L_u, X[r] - M0)
        B = np.linalg.solve(L_v, A.T)
        Q += float(np.sum(B**2))

    if Q > crit:
        rejections += 1

rejections / N_EXPERIMENTS


## 11) Pitfalls

- **Invalid parameters**: \(U\) and \(V\) must be symmetric positive definite.
- **Scale non-identifiability**: \((U,V)\) and \((cU, V/c)\) describe the same covariance \(V\otimes U\).
  Impose a constraint (e.g. \(\mathrm{tr}(V)=p\)) when fitting.
- **Numerical stability**:
  - avoid explicit inverses; use Cholesky solves;
  - work with `logpdf` rather than `pdf` when \(np\) is moderate/large.
- **Kronecker blow-up**: forming \(V\otimes U\) explicitly costs \(O(n^2 p^2)\) memory.
  Prefer trace / Frobenius formulas and two-sided solves.
- **Vectorization convention**: \(V\otimes U\) corresponds to *column-stacking* `vec`.
  Mixing row-major flattening with the Kronecker formula is a common source of bugs.
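The convention pitfall in a few lines: row-major flattening (NumPy's default `ravel`) pairs with `np.kron(U, V)`, column-stacking pairs with `np.kron(V, U)`, and mixing them silently yields a different density. \(U\) and \(V\) are arbitrary SPD examples.

```python
import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(12)
n, p = 3, 2
U = np.array([[1.0, 0.6, 0.2], [0.6, 1.0, 0.6], [0.2, 0.6, 1.0]])
V = np.array([[1.0, -0.7], [-0.7, 1.0]])
M = np.zeros((n, p))
X = rng.normal(size=(n, p))

# Column-stacking vec (order="F") with V ⊗ U
col_major = multivariate_normal(M.reshape(-1, order="F"), np.kron(V, U)).logpdf(X.reshape(-1, order="F"))
# Row-major flattening (C order) with U ⊗ V: the same density
row_major = multivariate_normal(M.ravel(), np.kron(U, V)).logpdf(X.ravel())
# Bug: row-major flattening combined with V ⊗ U
mismatched = multivariate_normal(M.ravel(), np.kron(V, U)).logpdf(X.ravel())
print(col_major, row_major, mismatched)
```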


## 12) Summary

- Matrix normal is a **continuous distribution over real matrices** with separable row/column covariance.
- \(X\sim\mathrm{MN}(M,U,V)\) iff \(\mathrm{vec}(X)\sim\mathcal{N}(\mathrm{vec}(M), V\otimes U)\).
- Mean is \(M\); entrywise covariance factorizes: \(\mathrm{Cov}(X_{ij},X_{k\ell})=U_{ik}V_{j\ell}\).
- Sampling is efficient via \(X=M+L_U Z L_V^\top\) with i.i.d. standard-normal \(Z\).
- SciPy supports `pdf`/`logpdf`/`rvs` and exposes the mean matrix; full `cdf` and `fit` are not provided, but CDFs for small problems and iterative fitting are straightforward.
