In [1]:
import platform

import numpy as np
import scipy
from scipy import stats
from scipy.stats import chi2, multivariate_normal, norm

import plotly
import plotly.graph_objects as go
import os
import plotly.io as pio
from plotly.subplots import make_subplots

# Notebook-wide configuration: Plotly defaults, NumPy printing, RNG seed and run size.
pio.templates.default = "plotly_white"
# The renderer can be overridden through the PLOTLY_RENDERER environment variable.
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")

np.set_printoptions(precision=4, suppress=True)

# Reproducibility: the sampling calls in this notebook pass this generator explicitly.
SEED = 0
rng = np.random.default_rng(SEED)

# Quick/slow toggle: FAST_RUN trades Monte Carlo accuracy for speed.
FAST_RUN = True
if FAST_RUN:
    N_SAMPLES, N_EXPERIMENTS = 15_000, 250
else:
    N_SAMPLES, N_EXPERIMENTS = 120_000, 2_000

for lib_name, lib_version in (
    ("Python", platform.python_version()),
    ("NumPy", np.__version__),
    ("SciPy", scipy.__version__),
    ("Plotly", plotly.__version__),
):
    print(lib_name, lib_version)


Python 3.12.9
NumPy 1.26.2
SciPy 1.15.0
Plotly 6.5.2


In [2]:
def vec_f(A: np.ndarray) -> np.ndarray:
    """Flatten a matrix by stacking its columns (column-major / Fortran order).

    For a 2-D input of shape (n, p), entry ``A[i, j]`` lands at position
    ``j * n + i``. Under this convention the covariance of ``vec(X)`` for a
    matrix-normal ``X`` with row covariance U and column covariance V is
    ``np.kron(V, U)``.
    """
    arr = np.asarray(A)
    return np.ravel(arr, order="F")


def unvec_f(v: np.ndarray, shape: tuple[int, int]) -> np.ndarray:
    """Inverse of ``vec_f``: refill a matrix of ``shape`` column by column (Fortran order)."""
    return np.reshape(np.asarray(v), shape, order="F")


def ar1_cov(n: int, rho: float, sigma2: float = 1.0) -> np.ndarray:
    """AR(1)-style covariance matrix ``sigma2 * rho**|i - j|``.

    The result is symmetric positive definite exactly when ``|rho| < 1`` and
    ``sigma2 > 0``; both conditions are enforced so an invalid covariance
    cannot be built silently.

    Args:
        n: matrix dimension (must be positive).
        rho: lag-one correlation, strictly inside (-1, 1).
        sigma2: marginal variance (diagonal value), must be positive.

    Returns:
        (n, n) float array.

    Raises:
        ValueError: if ``n <= 0``, ``rho`` is outside (-1, 1) (or NaN), or
            ``sigma2`` is not positive (or NaN).
    """

    if n <= 0:
        raise ValueError("n must be positive")
    if not (-1.0 < rho < 1.0):
        raise ValueError("rho must be in (-1, 1)")
    if not sigma2 > 0.0:  # written this way so NaN is rejected too
        raise ValueError("sigma2 must be positive")
    idx = np.arange(n)
    lags = np.abs(idx[:, None] - idx[None, :])
    return sigma2 * (rho ** lags)


def chol_spd(A: np.ndarray, name: str = "A", jitter: float = 1e-12, max_tries: int = 6) -> np.ndarray:
    """Cholesky factor of an SPD matrix, with diagonal jitter only as a fallback.

    The input is symmetrized and factorized as-is first. Only if that fails is
    a diagonal jitter added, growing 10x per attempt:
    ``jitter * scale * 10**k`` for ``k = 0 .. max_tries - 1``. Here ``scale`` is
    the mean absolute diagonal of ``A`` (or 1 if that is zero). Scaling keeps
    the jitter from dominating small-magnitude matrices and from being
    negligible for large-magnitude ones.

    Args:
        A: square matrix expected to be symmetric positive definite.
        name: label used in error messages.
        jitter: relative jitter for the first fallback attempt.
        max_tries: number of jittered attempts after the exact one.

    Returns:
        Lower-triangular L such that A ≈ L L^T (exactly A when no jitter was needed).

    Raises:
        ValueError: if ``A`` is not square, contains non-finite values, or no
            attempt yields a Cholesky factorization.
    """

    A = np.asarray(A, dtype=float)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"{name} must be a square matrix")
    if not np.all(np.isfinite(A)):
        raise ValueError(f"{name} must contain only finite values")

    A = 0.5 * (A + A.T)  # symmetrize to reduce numerical asymmetry

    # Exact attempt first: a well-conditioned SPD matrix must not be perturbed.
    try:
        return np.linalg.cholesky(A)
    except np.linalg.LinAlgError:
        pass

    scale = float(np.mean(np.abs(np.diag(A)))) if A.size else 1.0
    if scale == 0.0:
        scale = 1.0
    eye = np.eye(A.shape[0])

    for k in range(max_tries):
        try:
            return np.linalg.cholesky(A + (jitter * scale * 10.0**k) * eye)
        except np.linalg.LinAlgError:
            continue

    raise ValueError(f"{name} must be symmetric positive definite (Cholesky failed)")


def validate_matrix_normal_params(mean: np.ndarray, rowcov: np.ndarray, colcov: np.ndarray) -> tuple[int, int]:
    """Validate matrix-normal parameters and return the matrix dimensions (n, p).

    Checks, in order:
    1) ``mean`` is 2-D with shape (n, p);
    2) ``rowcov`` has shape (n, n) and ``colcov`` has shape (p, p);
    3) both covariances admit a Cholesky factorization via ``chol_spd``.

    Raises:
        ValueError: on the first failed check.
    """

    mean_arr = np.asarray(mean, dtype=float)
    if mean_arr.ndim != 2:
        raise ValueError("mean must be a 2D matrix")
    n, p = mean_arr.shape

    # Convert both covariances up front, then check all shapes before any factorization.
    covariances = {
        "rowcov": (np.asarray(rowcov, dtype=float), n),
        "colcov": (np.asarray(colcov, dtype=float), p),
    }
    for label, (cov, dim) in covariances.items():
        if cov.shape != (dim, dim):
            raise ValueError(f"{label} must have shape ({dim}, {dim})")
    for label, (cov, _dim) in covariances.items():
        chol_spd(cov, name=label)

    return n, p


def matrix_normal_rvs_numpy(
    mean: np.ndarray,
    rowcov: np.ndarray,
    colcov: np.ndarray,
    size: int = 1,
    rng: np.random.Generator | None = None,
) -> np.ndarray:
    """Sample from MN(mean, rowcov, colcov) using only NumPy.

    Algorithm:
    1) Factorize rowcov = L_u L_u^T and colcov = L_v L_v^T.
    2) Draw Z with i.i.d. N(0,1) entries.
    3) Return mean + L_u Z L_v^T, whose column-stacked vec has covariance
       kron(colcov, rowcov).

    Args:
        mean: (n, p) mean matrix.
        rowcov: (n, n) SPD row covariance.
        colcov: (p, p) SPD column covariance.
        size: number of draws.
        rng: NumPy Generator; a fresh unseeded one is created when None.

    Returns:
    - if size == 1: array (n, p)
    - else: array (size, n, p)

    Raises:
        ValueError: for mismatched shapes or non-SPD covariances
            (via validate_matrix_normal_params / chol_spd).
    """

    if rng is None:
        rng = np.random.default_rng()

    mean = np.asarray(mean, dtype=float)
    rowcov = np.asarray(rowcov, dtype=float)
    colcov = np.asarray(colcov, dtype=float)
    n, p = validate_matrix_normal_params(mean, rowcov, colcov)

    L_u = chol_spd(rowcov, name="rowcov")
    L_v = chol_spd(colcov, name="colcov")

    Z = rng.standard_normal(size=(size, n, p))

    # Stacked matmul broadcasts the (n, n) and (p, p) factors over the leading
    # sample axis, replacing a per-sample Python loop.
    out = mean + L_u @ Z @ L_v.T

    return out[0] if size == 1 else out


def matrix_normal_logpdf_numpy(
    x: np.ndarray,
    mean: np.ndarray,
    rowcov: np.ndarray,
    colcov: np.ndarray,
) -> np.ndarray:
    """Log-PDF of X ~ MN(mean, rowcov, colcov) using NumPy linear algebra.

    With U = rowcov = L_u L_u^T and V = colcov = L_v L_v^T:

        log p(X) = -(n p / 2) log(2 pi) - (p / 2) log|U| - (n / 2) log|V|
                   - (1 / 2) || L_u^{-1} (X - M) L_v^{-T} ||_F^2

    Supports:
    - x shape (n, p) -> returns scalar
    - x shape (m, n, p) -> returns length-m array
    - more generally x shape (..., n, p) -> returns array of shape (...)

    Raises:
        ValueError: if parameters are invalid (via validate_matrix_normal_params)
            or x does not end with shape (n, p).
    """

    mean = np.asarray(mean, dtype=float)
    rowcov = np.asarray(rowcov, dtype=float)
    colcov = np.asarray(colcov, dtype=float)

    n, p = validate_matrix_normal_params(mean, rowcov, colcov)

    x = np.asarray(x, dtype=float)
    if x.shape[-2:] != (n, p):
        raise ValueError(f"x must end with shape ({n}, {p})")

    L_u = chol_spd(rowcov, name="rowcov")
    L_v = chol_spd(colcov, name="colcov")

    logdet_u = 2.0 * np.sum(np.log(np.diag(L_u)))
    logdet_v = 2.0 * np.sum(np.log(np.diag(L_v)))

    const = (
        -0.5 * n * p * np.log(2.0 * np.pi)
        -0.5 * p * logdet_u
        -0.5 * n * logdet_v
    )

    E = x - mean  # (..., n, p); mean broadcasts over any leading batch axes

    # Whitened residual W = L_u^{-1} E L_v^{-T}. np.linalg.solve is a general
    # (LU-based) solver; it broadcasts the (n, n) / (p, p) factors over the
    # leading axes, so all samples are handled without a Python loop.
    A = np.linalg.solve(L_u, E)  # (..., n, p)
    W = np.linalg.solve(L_v, np.swapaxes(A, -1, -2))  # (..., p, n)
    quad = np.sum(W * W, axis=(-2, -1))  # squared Frobenius norm per sample

    return const - 0.5 * quad


In [3]:
# Quick numerical check of logpdf against SciPy
n, p = 3, 2
M = rng.normal(size=(n, p))
U = ar1_cov(n, rho=0.4, sigma2=1.5)
V = ar1_cov(p, rho=-0.2, sigma2=0.7)

X = matrix_normal_rvs_numpy(M, U, V, size=1, rng=rng)

rv = stats.matrix_normal(mean=M, rowcov=U, colcov=V)
logpdf_scipy = rv.logpdf(X)
logpdf_numpy = matrix_normal_logpdf_numpy(X, M, U, V)

# Fail loudly if the two implementations disagree, instead of relying on eyeballing the output.
np.testing.assert_allclose(logpdf_numpy, logpdf_scipy, rtol=1e-10, atol=1e-10)

float(logpdf_scipy), float(logpdf_numpy), float(logpdf_numpy - logpdf_scipy)


(-7.792107996744354, -7.792107996744427, -7.283063041541027e-14)

In [4]:
# Empirical check: sample mean ≈ M and sample covariance of vec(X) ≈ kron(V, U)
n, p = 4, 3
M = np.arange(n * p, dtype=float).reshape(n, p) / 10.0
U = ar1_cov(n, rho=0.6, sigma2=2.0)
V = ar1_cov(p, rho=-0.3, sigma2=1.2)

X_samps = matrix_normal_rvs_numpy(M, U, V, size=N_SAMPLES, rng=rng)  # (N, n, p)

# Mean
M_hat = X_samps.mean(axis=0)

# Covariance of vec (column-stacking), computed for all samples at once:
# transposing each (n, p) sample to (p, n) and flattening in C order gives the
# same layout as vec_f, without a per-sample Python loop.
X_vec = X_samps.transpose(0, 2, 1).reshape(X_samps.shape[0], -1)
assert np.array_equal(X_vec[0], vec_f(X_samps[0]))  # guard the layout against vec_f
Sigma_hat = np.cov(X_vec, rowvar=False)
Sigma_theory = np.kron(V, U)  # consistent with vec_f

mean_err = np.linalg.norm(M_hat - M) / np.linalg.norm(M)
rel_cov_err = np.linalg.norm(Sigma_hat - Sigma_theory) / np.linalg.norm(Sigma_theory)

mean_err, rel_cov_err


(0.013396987029412295, 0.021002378515706396)

In [5]:
# Visual intuition: one draw under different row/column correlations
n, p = 8, 8
M = np.zeros((n, p))

scenarios = [
    {"rho_u": 0.0, "rho_v": 0.0, "title": "independent"},
    {"rho_u": 0.6, "rho_v": 0.0, "title": "row-correlated"},
    {"rho_u": 0.6, "rho_v": 0.7, "title": "row+col correlated"},
]

fig = make_subplots(rows=1, cols=len(scenarios), subplot_titles=[s["title"] for s in scenarios])

for col, scenario in enumerate(scenarios, start=1):
    U = ar1_cov(n, rho=scenario["rho_u"], sigma2=1.0)
    V = ar1_cov(p, rho=scenario["rho_v"], sigma2=1.0)
    X = matrix_normal_rvs_numpy(M, U, V, size=1, rng=rng)
    # A shared coloraxis puts every panel on the same color scale.
    fig.add_trace(go.Heatmap(z=X, coloraxis="coloraxis"), row=1, col=col)

# Plotly heatmaps draw row 0 at the bottom; reverse so panels read like printed matrices.
fig.update_yaxes(autorange="reversed")
fig.update_layout(
    title="One draw from MN(0, U, V) under different correlations",
    coloraxis={"colorscale": "RdBu", "cmin": -3, "cmax": 3},
    height=350,
)
fig.show()


In [6]:
def flip_flop_mle(
    X: np.ndarray,
    max_iter: int = 50,
    tol: float = 1e-6,
    normalize: str | None = "trace_v",
) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
    """Estimate (M, U, V) for matrix-normal samples using the flip-flop algorithm."""

    X = np.asarray(X, dtype=float)
    if X.ndim != 3:
        raise ValueError("X must have shape (m, n, p)")

    m, n, p = X.shape

    M_hat = X.mean(axis=0)
    E = X - M_hat

    U = np.eye(n)
    V = np.eye(p)

    history = {"loglik": []}
    prev_ll = None

    for _ in range(max_iter):
        # Update U given V
        L_v = chol_spd(V, name="V")
        U_new = np.zeros((n, n), dtype=float)
        for r in range(m):
            B = np.linalg.solve(L_v, E[r].T)  # (p, n)
            U_new += B.T @ B
        U_new /= (m * p)
        U_new = 0.5 * (U_new + U_new.T)

        # Update V given U
        L_u = chol_spd(U_new, name="U")
        V_new = np.zeros((p, p), dtype=float)
        for r in range(m):
            A = np.linalg.solve(L_u, E[r])
            V_new += A.T @ A
        V_new /= (m * n)
        V_new = 0.5 * (V_new + V_new.T)

        if normalize == "trace_v":
            c = float(np.trace(V_new) / p)
            if c <= 0:
                raise ValueError("normalization failed (non-positive trace)")
            V_new = V_new / c
            U_new = U_new * c

        ll = float(matrix_normal_logpdf_numpy(X, M_hat, U_new, V_new).sum())
        history["loglik"].append(ll)

        if prev_ll is not None:
            rel = abs(ll - prev_ll) / (abs(prev_ll) + 1.0)
            if rel < tol:
                U, V = U_new, V_new
                break

        prev_ll = ll
        U, V = U_new, V_new

    return M_hat, U, V, history


# Fit demo on synthetic data: draw from known (M, U, V), then recover them with flip-flop
n, p = 4, 3
M_true = rng.normal(size=(n, p))
U_true = ar1_cov(n, rho=0.5, sigma2=1.8)
V_true = ar1_cov(p, rho=-0.4, sigma2=0.9)
n_fit = 800 if FAST_RUN else 4_000

X = matrix_normal_rvs_numpy(M_true, U_true, V_true, size=n_fit, rng=rng)
M_hat, U_hat, V_hat, info = flip_flop_mle(X, max_iter=60, tol=1e-7)

# U and V are only determined up to a reciprocal scale factor, so compare
# kron(V, U), which is invariant to it.
Sigma_true = np.kron(V_true, U_true)
Sigma_hat = np.kron(V_hat, U_hat)

mean_err, rel_sigma_err = (
    np.linalg.norm(M_hat - M_true) / np.linalg.norm(M_true),
    np.linalg.norm(Sigma_hat - Sigma_true) / np.linalg.norm(Sigma_true),
)

mean_err, rel_sigma_err, len(info["loglik"])


(0.037808598539744685, 0.07246270804981933, 3)

In [7]:
# Basic sampling sanity check: batch shape plus summary stats of one draw
n, p = 5, 4
M = np.zeros((n, p))
U = ar1_cov(n, rho=0.3, sigma2=1.0)
V = ar1_cov(p, rho=0.7, sigma2=0.5)

X_samps = matrix_normal_rvs_numpy(M, U, V, size=3, rng=rng)

# Make the sanity check actually check something on Restart & Run All.
assert X_samps.shape == (3, n, p), X_samps.shape
assert np.all(np.isfinite(X_samps))

X_samps.shape, X_samps[0].mean(), X_samps[0].std()


((3, 5, 4), 0.18895790539751117, 0.5586039053370001)

In [8]:
# A) Univariate marginal of a single entry: X[i, j] ~ N(M[i, j], U[i, i] * V[j, j])
n, p = 4, 4
M = np.zeros((n, p))
U = ar1_cov(n, rho=0.65, sigma2=1.2)
V = ar1_cov(p, rho=-0.3, sigma2=0.8)

i, j = 1, 2
mu_ij = M[i, j]
var_ij = U[i, i] * V[j, j]
sd_ij = np.sqrt(var_ij)
entry_dist = norm(loc=mu_ij, scale=sd_ij)  # frozen once, reused for pdf and cdf

xs = np.linspace(mu_ij - 4 * sd_ij, mu_ij + 4 * sd_ij, 600)
pdf, cdf = entry_dist.pdf(xs), entry_dist.cdf(xs)

fig = make_subplots(rows=1, cols=2, subplot_titles=["Marginal PDF", "Marginal CDF"])
panels = [(pdf, "theory pdf", "pdf"), (cdf, "theory cdf", "cdf")]
for col, (curve, trace_name, y_label) in enumerate(panels, start=1):
    fig.add_trace(go.Scatter(x=xs, y=curve, name=trace_name), row=1, col=col)
    fig.update_xaxes(title_text="x", row=1, col=col)
    fig.update_yaxes(title_text=y_label, row=1, col=col)

fig.update_layout(title=f"Entry X[{i},{j}] ~ N({mu_ij:.2f}, {var_ij:.2f})", height=320)
fig.show()


In [9]:
# B) Monte Carlo vs theory for a scalar projection S = tr(T^T X) = <T, X>.
#    S is Gaussian with mean tr(T^T M) and variance tr(U T V T^T).

n, p = 6, 5
M = rng.normal(size=(n, p)) * 0.2
U = ar1_cov(n, rho=0.5, sigma2=1.0)
V = ar1_cov(p, rho=0.3, sigma2=0.8)

T = rng.normal(size=(n, p))

mu_S = float(np.sum(T * M))  # tr(T^T M)
var_S = float(np.trace(U @ T @ V @ T.T))
theory = norm(mu_S, np.sqrt(var_S))  # frozen once; reused for pdf, cdf and the KS test

X_samps = matrix_normal_rvs_numpy(M, U, V, size=N_SAMPLES, rng=rng)
S_samps = np.einsum("ij,nij->n", T, X_samps)

# Quantitative check to back the visual one: KS test of the draws against the theoretical CDF.
ks = stats.kstest(S_samps, theory.cdf)

xs = np.linspace(mu_S - 4 * np.sqrt(var_S), mu_S + 4 * np.sqrt(var_S), 600)

fig = make_subplots(rows=1, cols=2, subplot_titles=["PDF (hist + theory)", "CDF (empirical + theory)"])

fig.add_trace(
    go.Histogram(x=S_samps, nbinsx=70, histnorm="probability density", name="MC"),
    row=1,
    col=1,
)
fig.add_trace(go.Scatter(x=xs, y=theory.pdf(xs), name="theory"), row=1, col=1)

s_sorted = np.sort(S_samps)
emp_cdf = np.linspace(1.0 / len(s_sorted), 1.0, len(s_sorted))  # i / N for i = 1..N
fig.add_trace(go.Scatter(x=s_sorted, y=emp_cdf, mode="lines", name="empirical"), row=1, col=2)
fig.add_trace(go.Scatter(x=xs, y=theory.cdf(xs), name="theory"), row=1, col=2)

fig.update_xaxes(title_text="s", row=1, col=1)
fig.update_yaxes(title_text="density", row=1, col=1)
fig.update_xaxes(title_text="s", row=1, col=2)
fig.update_yaxes(title_text="cdf", row=1, col=2)

fig.update_layout(
    title=f"S=⟨T,X⟩ ~ N({mu_S:.3f}, {var_S:.3f}); Monte Carlo KS p-value = {ks.pvalue:.3f}",
    height=350,
)
fig.show()


In [10]:
# C) Bivariate marginal of two entries to show correlation:
#    Cov(X[i1, j1], X[i2, j2]) = U[i1, i2] * V[j1, j2]

n, p = 4, 4
M = np.zeros((n, p))
U = ar1_cov(n, rho=0.8, sigma2=1.0)
V = ar1_cov(p, rho=0.0, sigma2=1.0)

(i1, j1), (i2, j2) = (0, 1), (2, 1)
rows, cols = [i1, i2], [j1, j2]

mu = M[rows, cols]
# Entry (a, b) is U[rows[a], rows[b]] * V[cols[a], cols[b]].
Sigma2 = U[np.ix_(rows, rows)] * V[np.ix_(cols, cols)]

X_samps = matrix_normal_rvs_numpy(M, U, V, size=N_SAMPLES, rng=rng)
Y = X_samps[:, rows, cols]  # (N, 2): column k holds X[rows[k], cols[k]]

xg = np.linspace(Y[:, 0].min(), Y[:, 0].max(), 120)
yg = np.linspace(Y[:, 1].min(), Y[:, 1].max(), 120)
XG, YG = np.meshgrid(xg, yg)
pts = np.column_stack([XG.ravel(), YG.ravel()])
Z = multivariate_normal(mean=mu, cov=Sigma2).pdf(pts).reshape(XG.shape)

mc_trace = go.Histogram2dContour(
    x=Y[:, 0],
    y=Y[:, 1],
    nbinsx=40,
    nbinsy=40,
    contours=dict(coloring="fill", showlines=False),
    colorscale="Blues",
    name="MC density",
)
theory_trace = go.Contour(
    x=xg,
    y=yg,
    z=Z,
    line=dict(color="black"),
    contours=dict(showlabels=False),
    showscale=False,
    name="theory pdf",
)

fig = go.Figure(data=[mc_trace, theory_trace])
fig.update_layout(
    title=f"Bivariate marginal: (X[{i1},{j1}], X[{i2},{j2}])",
    xaxis_title=f"X[{i1},{j1}]",
    yaxis_title=f"X[{i2},{j2}]",
    height=380,
)
fig.show()


In [11]:
# SciPy usage: pdf / logpdf / rvs, cross-checked against the NumPy implementation
n, p = 3, 2
M = np.zeros((n, p))
U = ar1_cov(n, rho=0.4, sigma2=1.0)
V = ar1_cov(p, rho=-0.2, sigma2=0.6)

rv = stats.matrix_normal(mean=M, rowcov=U, colcov=V)
X = rv.rvs(size=5, random_state=rng)
first_draw = X[0]

print("rvs shape:", X.shape)
print("mean shape:", rv.mean.shape)
print("logpdf[0] scipy:", rv.logpdf(first_draw))
print("logpdf[0] numpy:", matrix_normal_logpdf_numpy(first_draw, M, U, V))

# SciPy's matrix_normal has no CDF. When n*p is small, use the joint CDF of
# vec(X) instead, which is multivariate normal with covariance kron(V, U).
M_small = np.zeros((1, 2))  # 1x2
U_small = np.eye(1)
V_small = np.array([[1.0, 0.5], [0.5, 1.0]])

x_small = np.array([[0.2, -0.1]])
Sigma_small = np.kron(V_small, U_small)

mvn_small = multivariate_normal(mean=vec_f(M_small), cov=Sigma_small)
cdf_small = mvn_small.cdf(vec_f(x_small))
print("joint CDF for 1x2 example:", float(cdf_small))


rvs shape: (5, 3, 2)
mean shape: (3, 2)
logpdf[0] scipy: -5.61422078234059
logpdf[0] numpy: -5.614220782342528
joint CDF for 1x2 example: 0.3471505869384755


In [12]:
# Example: chi-square test for H0: M = 0 with known (U, V).
# Under H0 the whitened residuals L_u^{-1} (X_r - M0) L_v^{-T} have i.i.d. N(0, 1)
# entries, so Q = sum of their squares ~ chi2(m * n * p).

n, p = 3, 4
m = 10
M0 = np.zeros((n, p))
U = ar1_cov(n, rho=0.4, sigma2=1.0)
V = ar1_cov(p, rho=0.2, sigma2=0.7)

L_u = chol_spd(U, name="U")
L_v = chol_spd(V, name="V")

alpha = 0.05
crit = chi2.ppf(1 - alpha, df=m * n * p)

rejections = 0
for _ in range(N_EXPERIMENTS):
    X = matrix_normal_rvs_numpy(M0, U, V, size=m, rng=rng)

    # Whiten all m samples at once: np.linalg.solve broadcasts over the leading
    # axis. This also stays correct for m == 1, where X is a single (n, p) matrix.
    A = np.linalg.solve(L_u, X - M0)  # (m, n, p)
    B = np.linalg.solve(L_v, np.swapaxes(A, -1, -2))  # (m, p, n)
    Q = float(np.sum(B**2))

    rejections += int(Q > crit)

# Empirical type-I error; its Monte Carlo standard error is about sqrt(alpha * (1 - alpha) / N_EXPERIMENTS).
rejections / N_EXPERIMENTS


0.036