# LKJ Cholesky Covariance Priors for Multivariate Normal Models

A NumPyro replication of the PyMC3 notebook on LKJ Cholesky covariance priors: https://docs.pymc.io/notebooks/LKJ.html

More info:
- [PyMC3 LKJCholeskyCov](https://docs.pymc.io/api/distributions/multivariate.html#pymc3.distributions.multivariate.LKJCholeskyCov)
- [NumPyro LKJCholesky](http://num.pyro.ai/en/stable/distributions.html#lkjcholesky)

In [None]:
%matplotlib inline
%config InlineBackend.figure_format='svg'

In [None]:
import sys
import warnings

import numpy as np
import scipy
import scipy.stats

import jax
import jax.numpy as jnp

import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm  # Colormaps
from matplotlib.patches import Ellipse
import seaborn as sns
import arviz as az

from tqdm import tqdm_notebook as tqdm

In [None]:
# Matplotlib look-and-feel: seaborn grid first, then the ArviZ theme on top.
sns.set_style("darkgrid")
az.style.use("arviz-darkgrid")
# Default probability mass used by ArviZ for HDI intervals.
az.rcParams["stats.hdi_prob"] = 0.90

In [None]:
# Ask for 8 host (CPU) devices and run everything on the CPU backend.
# NOTE(review): both calls presumably need to happen before JAX initialises its backend.
numpyro.set_host_device_count(8)
numpyro.set_platform("cpu")

In [None]:
# Base JAX PRNG key plus a fixed seed for NumPy's global generator.
rng_key = jax.random.PRNGKey(42)
np.random.seed(42)

In [None]:
# First two colours of the "tab10" palette.
blue, orange = sns.color_palette("tab10")[:2]

In [None]:
# Synthetic data: N draws from a bivariate normal with known mean, standard
# deviations and correlation, so the fitted posterior can be checked against them.
RANDOM_SEED = 8924
np.random.seed(3264602)  # from random.org

N = 10000

μ_actual = np.array([1.0, -2.0])
sigmas_actual = np.array([0.7, 1.5])
# Plain ndarray rather than np.matrix (which NumPy no longer recommends);
# matrix products are therefore written explicitly with `@`.
Rho_actual = np.array([[1.0, -0.4], [-0.4, 1.0]])

# Covariance = diag(σ) · Rho · diag(σ)
Σ_actual = np.diag(sigmas_actual) @ Rho_actual @ np.diag(sigmas_actual)

x = np.random.multivariate_normal(μ_actual, Σ_actual, size=N)
Σ_actual

In [None]:
# Eigen-decomposition of the true covariance: `var` holds the variances along
# the principal axes, the columns of `U` hold the matching directions.
var, U = np.linalg.eig(Σ_actual)
# Signed orientation (degrees, anti-clockwise from the x-axis) of the FIRST
# principal axis, i.e. the axis whose variance var[0] sets the ellipse width.
# arctan2 keeps the direction's sign, so the ellipse is oriented correctly
# whatever order or sign `eig` returns the eigenvectors in.
angle = np.degrees(np.arctan2(U[1, 0], U[0, 0]))

fig, ax = plt.subplots(figsize=(8, 6))

blue, _, red, *_ = sns.color_palette()

# 5.991 ≈ 0.95 quantile of a chi-square distribution with 2 degrees of freedom,
# so the axes below span the 95% density region named in the legend.
e = Ellipse(μ_actual, 2 * np.sqrt(5.991 * var[0]), 2 * np.sqrt(5.991 * var[1]), angle=angle)
e.set_alpha(0.5)
e.set_facecolor(blue)
e.set_zorder(10)
ax.add_artist(e)

# Raw samples drawn above the ellipse (higher zorder).
ax.scatter(x[:, 0], x[:, 1], c="k", alpha=0.05, zorder=11)

# Proxy patch: gives the legend a handle with the ellipse's colour and alpha.
rect = plt.Rectangle((0, 0), 1, 1, fc=blue, alpha=0.5)
ax.legend([rect], ["95% density region"], loc=2);

In [None]:
def model(obs):
    """NumPyro model: bivariate normal with an LKJ-Cholesky covariance prior.

    Sample sites: "chol_stds" (Exponential(1) per component), "lkj_chol"
    (LKJCholesky, dimension 2, concentration 2.0), "μ" (Normal(0, 1.5) per
    component) and "obs" (MultivariateNormal, conditioned on ``obs``).
    Deterministic sites: "chol_corr", "chol" and "cov".

    Args:
        obs: value the "obs" site is conditioned on.
    """
    stds = numpyro.sample("chol_stds", dist.Exponential(rate=jnp.ones(2)))
    corr_chol = numpyro.sample("lkj_chol", dist.LKJCholesky(2, concentration=2.0))

    # Correlation matrix rebuilt from its Cholesky factor: L @ L.T
    numpyro.deterministic("chol_corr", jnp.matmul(corr_chol, corr_chol.T))

    # Scale row i of the correlation factor by std i; this equals
    # diag(stds) @ corr_chol, i.e. the Cholesky factor of the covariance.
    scale_tril = numpyro.deterministic("chol", jnp.expand_dims(stds, -1) * corr_chol)

    loc = numpyro.sample("μ", dist.Normal(loc=jnp.zeros(2), scale=jnp.full((2,), 1.5)))

    likelihood = dist.MultivariateNormal(loc=loc, scale_tril=scale_tril)
    # Record the implied covariance matrix alongside the samples.
    numpyro.deterministic("cov", likelihood.covariance_matrix)
    numpyro.sample("obs", likelihood, obs=obs)

In [None]:
# Fresh key so this cell gives the same result however often it is re-run.
rng_key = jax.random.PRNGKey(42)

num_warmup = 1000
num_samples = 1000

# Run NUTS: 4 chains in parallel, conditioning the model on the simulated data.
kernel = NUTS(model)
mcmc = MCMC(
    kernel,
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=4,
    chain_method="parallel",
)
mcmc.run(rng_key, obs=x)

# Posterior draws keyed by site name.
posterior_samples = mcmc.get_samples()

In [None]:
import warnings

# Ignore every warning category from this point on.
# NOTE(review): presumably meant to keep the output of later cells uncluttered.
warnings.simplefilter("ignore")

In [None]:
# Posterior summary table; the "~" prefix leaves the raw Cholesky factors out.
az.summary(data=mcmc, round_to=2, var_names=["~lkj_chol", "~chol"])

In [None]:
rng_key = jax.random.PRNGKey(42)

# Posterior predictive draws: re-run the model once per posterior sample.
posterior_predictive = Predictive(model, posterior_samples=posterior_samples)
# `obs=None` is required here: if the observed data were passed, the "obs"
# site would stay conditioned on it and the returned "obs" values would be
# copies of `x` rather than new draws from the likelihood.
posterior_predictions = posterior_predictive(rng_key, obs=None)

In [None]:
# Bundle the MCMC run and the posterior predictive draws into an ArviZ
# InferenceData object with labelled dimensions.
# The 2x2 matrix variables get two distinct dimension names ("xy" for rows,
# "xy_bis" for columns): xarray does not support using the same dimension
# name twice on one variable.
inference_data = az.from_numpyro(
    posterior=mcmc,
    posterior_predictive=posterior_predictions,
    coords={"xy": jnp.arange(2), "xy_bis": jnp.arange(2)},
    dims={
        "μ": ["xy"],
        "chol_stds": ["xy"],
        "chol": ["xy", "xy_bis"],
        "lkj_chol": ["xy", "xy_bis"],
        "cov": ["xy", "xy_bis"],
        "chol_corr": ["xy", "xy_bis"],
    },
)
display(inference_data)

In [None]:
# True generating values, drawn as reference lines on the matching trace panels.
reference_lines = [
    ("μ", {}, μ_actual),
    ("chol_stds", {}, sigmas_actual),
    ("chol_corr", {}, Rho_actual),
    ("cov", {}, Σ_actual),
]

# Compact trace plots for everything except the raw Cholesky factors.
az.plot_trace(
    inference_data,
    var_names=["~lkj_chol", "~chol"],
    compact=True,
    lines=reference_lines,
)
plt.suptitle("Trace plots", fontsize=18)
plt.show()