# LKJ Cholesky Covariance Priors for Multivariate Normal Models

Based on https://docs.pymc.io/notebooks/LKJ.html, remade with NumPyro
    
More info on LKJ priors: see the links collected in the "LKJ Cholesky Prior" section below.

In [None]:
# Imports
%matplotlib inline
%config InlineBackend.figure_format='retina'

In [None]:
import numpy as np
import scipy
import scipy.stats

import jax
import jax.numpy as jnp

import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist

import matplotlib
import matplotlib.patches
import matplotlib.pyplot as plt

from matplotlib import cm  # Colormaps
import seaborn as sns
import arviz as az

from tqdm import tqdm_notebook as tqdm

In [None]:
# Global plotting defaults: seaborn dark grid first, then the ArviZ
# dark-grid matplotlib style.
sns.set_style('darkgrid')
# Use 90% highest-density intervals in ArviZ statistics.
az.rcParams['stats.hdi_prob'] = 0.90
az.style.use("arviz-darkgrid")

In [None]:
# Run on CPU and expose 8 host devices to JAX (the MCMC run further down
# requests chain_method='parallel').
# NOTE(review): presumably this must run before JAX initialises its backend
# to take effect -- confirm against the NumPyro docs.
numpyro.set_platform('cpu')
numpyro.set_host_device_count(8)
# Turn on NumPyro's runtime validation of distributions.
numpyro.enable_validation()

In [None]:
# Seed NumPy's global RNG and create a base JAX PRNG key for reproducibility.
np.random.seed(42)
rng_key = jax.random.PRNGKey(42)

In [None]:
# Keep the first two colours of the tab10 palette for the plots.
blue, orange, *_ = sns.color_palette("tab10")

## Generate Data

### Plot density as ellipse

- https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.patches.Ellipse.html


#### Visualize the 95% density region.

The value 5.991 is the threshold at which the ellipse captures 95% of the probability mass of the two-dimensional Gaussian. Because a sum of squared independent standard Gaussians follows a Chi-squared distribution, we look for the value $x$ with $P(s < x) = 0.95$ under a Chi-squared distribution with 2 degrees of freedom (one per dimension of the Gaussian). Each ellipse axis then has a half-length of $\sqrt{5.991 \, \lambda}$, where $\lambda$ is the variance along that axis.

More info:
- https://people.richland.edu/james/lecture/m170/tbl-chi.html#:~:text=5.991
- https://cookierobotics.com/007/
- https://www.visiondummy.com/2014/04/draw-error-ellipse-representing-covariance-matrix/

In [None]:
# 95% quantile of the Chi-squared distribution with 2 degrees of freedom (~5.991).
scipy.stats.chi2(df=2).ppf(0.95)

In [None]:
def plot_ellipse(mean, var, angle_deg, color, alpha, name=None, ax=None):
    """Draw the 95% density region of a 2D Gaussian as an ellipse patch.

    Args:
        mean: Centre ``(x, y)`` of the ellipse.
        var: Array-like with the two variances along the ellipse's own
            (unrotated) horizontal and vertical axes, e.g. the eigenvalues
            of the covariance matrix.
        angle_deg: Rotation of the ellipse in degrees, anti-clockwise.
        color: Face colour of the patch.
        alpha: Opacity of the patch.
        name: Optional suffix appended to the legend label so that several
            ellipses can be told apart.
        ax: Matplotlib axes to draw on. Defaults to the current axes.

    Returns:
        The ``matplotlib.patches.Ellipse`` that was added to ``ax``.
    """
    if ax is None:
        ax = plt.gca()
    # Squared Mahalanobis radius that encloses 95% of a 2D Gaussian.
    chi2_ppf = scipy.stats.chi2(df=2).ppf(0.95)
    # Ellipse expects full axis lengths (diameters), hence the factor 2.
    # np.asarray lets callers pass lists or other array types as well.
    horizontal_width, vertical_height = 2. * np.sqrt(
        chi2_ppf * np.asarray(var, dtype=float)
    )
    label = "95% density region"
    if name is not None:
        label = label + " - " + name
    # https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.patches.Ellipse.html
    # width: length of the horizontal axis (before rotation)
    # height: length of the vertical axis (before rotation)
    # angle: rotation in degrees, anti-clockwise
    e = matplotlib.patches.Ellipse(
        xy=mean, width=horizontal_width, height=vertical_height, angle=angle_deg,
        label=label,
    )
    e.set_alpha(alpha)
    e.set_facecolor(color)
    ax.add_artist(e)
    return e

In [None]:
# Quick visual check of plot_ellipse: variances (0.5, 2), rotated by 10 degrees.
fig, ax = plt.subplots(figsize=(5, 5))
ax.set_aspect('equal')
ax.set_xlim(-5, 5)
ax.set_ylim(-5, 5)
e = plot_ellipse(
    mean=[0, 0],
    var=np.array([0.5, 2]),
    angle_deg=10,
    color=blue,
    alpha=0.5,
    ax=ax,
)
# Add the ellipse patch explicitly to the legend entries.
handles, _ = ax.get_legend_handles_labels()
handles.append(e)
ax.legend(handles=handles)
plt.show()

### Covariance matrix
In other words, the largest eigenvector of the covariance matrix always points in the direction of the largest variance of the data, and the magnitude of this vector equals the corresponding eigenvalue. This means that the eigenvectors can be used to find the angle of the direction of the largest variance in the data.

- https://www.visiondummy.com/2014/04/geometric-interpretation-covariance-matrix/
- https://janakiev.com/blog/covariance-matrix/

#### Computing the angle of the covariance matrix

Note that the covariance matrix can be decomposed into a rotation matrix and scales:
- https://github.com/peterroelants/notebooks/blob/master/ml_algorithms/PCA_intuition.ipynb

With the [rotation matrix](https://en.wikipedia.org/wiki/Rotation_matrix) given as:
$$
R = \begin{bmatrix}
\cos \theta &-\sin \theta \\
\sin \theta &\cos \theta \\
\end{bmatrix}
$$


We can recover the original rotation by calling the [atan2](https://en.wikipedia.org/wiki/Atan2) function on the first eigenvector (first component). In Numpy this is [`arctan2`](https://numpy.org/doc/stable/reference/generated/numpy.arctan2.html):

$$
\theta_X = arctan\left(\frac{y}{x}\right) = arctan\left(\frac{r \sin \theta}{r \cos \theta}\right)
$$

The direction of vector rotation is counterclockwise if θ is positive (e.g. 90°).

Since the rotation matrix is ambiguous, we could also use `np.arccos`.

In [None]:
def eig_sorted(mat):
    """Eigendecompose ``mat`` and order the result by descending eigenvalue.

    Returns:
        ``(eig_val, eig_vec)`` where ``eig_vec[:, k]`` is the eigenvector
        belonging to ``eig_val[k]``.
    """
    values, vectors = np.linalg.eig(mat)
    # argsort is ascending; reverse it to put the largest eigenvalue first.
    descending = np.argsort(values)[::-1]
    return values[descending], vectors[:, descending]

In [None]:
# Covariance from rotation and scale matrix
angle_rad = np.deg2rad(20)
R = np.asarray([
    [np.cos(angle_rad), -np.sin(angle_rad)],
    [np.sin(angle_rad),  np.cos(angle_rad)],
])
scale = np.array([0.5, 1.2])
S = np.diag(scale)
# Covariance matrix (S is diagonal, so the element-wise square S**2 holds the variances)
Σ = R @ (S**2) @ R.T

# The eigenvalues don't have to be sorted, but this will keep eigen_values consistent
eig_val, eig_vec = eig_sorted(Σ)

# std and scale should be the same (ignoring order)
assert np.allclose(np.sqrt(eig_val), np.sort(scale)[::-1])

# The reconstructed angle is relative to the eigenvectors: after sorting, the
# first eigenvector belongs to the largest scale (1.2, the second axis of S),
# so this is the direction of that axis, i.e. the 20 degree rotation offset by
# 90 degrees (up to the arbitrary sign of the eigenvector).
angle_rad_reconstruct = np.arctan2(eig_vec[1, 0], eig_vec[0,0])
print('Reconstructed rotation angle = {:.2f} degrees'.format(np.rad2deg(angle_rad_reconstruct)))

In [None]:
# Re-seed so the generated data set is the same no matter which cells ran before.
np.random.seed(42)

# Number of observations to draw.
N = 1000

# Ground-truth mean and covariance of the synthetic data.
μ_actual = np.array([1.0, -2.0])
Σ_actual = Σ
print("Σ_actual: ", Σ_actual.shape)
print(Σ)
# Lower-triangular Cholesky factor of the true covariance; used later as the
# reference value for the inferred "L_cov".
L_actual = np.linalg.cholesky(Σ)
print("L_actual: ", L_actual.shape)

# Draw N samples from the ground-truth multivariate normal.
x = np.random.multivariate_normal(μ_actual, Σ_actual, size=N)
print("x: ", x.shape)

In [None]:
# Ellipse parameters of the true covariance: eigenvalues are the variances
# along the ellipse axes, the first eigenvector gives the rotation angle.
eig_val, eig_vec = np.linalg.eig(Σ_actual)
angle = np.rad2deg(np.arctan2(eig_vec[1,0], eig_vec[0,0]))

fig, ax = plt.subplots(figsize=(8, 6))

# Scatter of the generated samples with the true 95% density region on top.
ax.scatter(x[:, 0], x[:, 1], c="k", alpha=0.05)
plot_ellipse(μ_actual, eig_val, angle_deg=angle, color=blue, alpha=0.5, ax=ax)
ax.set_aspect('equal')
plt.show()

## LKJ Cholesky Prior

LKJ is a distribution over [correlation](https://en.wikipedia.org/wiki/Correlation) matrices. Correlation is the normalized version of covariance.

$$
\rho _{X,Y}=\operatorname {corr} (X,Y)={\operatorname {cov} (X,Y) \over \sigma _{X}\sigma _{Y}}={\operatorname {E} [(X-\mu _{X})(Y-\mu _{Y})] \over \sigma _{X}\sigma _{Y}}
$$

More info:
- https://en.wikipedia.org/wiki/Covariance_and_correlation

LKJ Distribution in NumPyro:
- http://num.pyro.ai/en/stable/distributions.html#lkjcholesky

Cholesky decomposition:
- https://en.wikipedia.org/wiki/Cholesky_decomposition

More info on LKJ:
- https://distribution-explorer.github.io/multivariate_continuous/lkj.html
- https://mc-stan.org/docs/2_18/stan-users-guide/multivariate-hierarchical-priors-section.html
- https://eager-roentgen-523c83.netlify.app/2014/12/27/d-lkj-priors/
- http://srmart.in/is-the-lkj1-prior-uniform-yes/
- https://docs.pymc.io/notebooks/LKJ.html

In [None]:
# LKJ prior over Cholesky factors of 2x2 correlation matrices (concentration 1);
# print its shape attributes for inspection.
_lkj_chol_dist = dist.LKJCholesky(dimension=2, concentration=1)
print("_lkj_chol_dist.batch_shape: ", _lkj_chol_dist.batch_shape)
print("_lkj_chol_dist.event_shape: ", _lkj_chol_dist.event_shape)
print("_lkj_chol_dist.event_dim: ", _lkj_chol_dist.event_dim)
print("_lkj_chol_dist.shape(): ", _lkj_chol_dist.shape())
print('')

In [None]:
rng_key = jax.random.PRNGKey(42)

# Draw one Cholesky factor of a correlation matrix from the LKJ prior.
_rho_sample = _lkj_chol_dist.sample(rng_key)
_sigma = np.array([1., 1.])
# Cholesky factor of the covariance: scale row i of the correlation factor
# by sigma_i, which is diag(sigma) @ L_corr.
_L_sample = _sigma[..., None] * _rho_sample
print("_L_sample: \n", _L_sample)
# Row scaling is a LEFT multiplication by diag(sigma). Right multiplication
# would scale the columns and only agrees when all sigmas are equal.
assert np.allclose(_L_sample, jnp.diag(_sigma) @ _rho_sample)
print("")

# Compute covariance matrix
_cov_sample = _L_sample @ _L_sample.T
print("_cov_sample: \n", _cov_sample)

In [None]:
rng_key = jax.random.PRNGKey(0)

# Zero-mean multivariate normal parameterised by the sampled Cholesky factor.
_mean = np.array([0.0, 0.0])
_mvn_dist = dist.MultivariateNormal(loc=_mean, scale_tril=_L_sample)
_mvn_samples = _mvn_dist.sample(rng_key, (10000,))
print("_mvn_samples.shape: ", _mvn_samples.shape)

# The empirical covariance of the draws should be close to the covariance
# implied by the Cholesky factor.
_emp_cov = np.cov(_mvn_samples, rowvar=False)
print("_emp_cov: \n", _emp_cov)
assert np.allclose(_cov_sample, _emp_cov, atol=1e-1)

In [None]:
# Ellipse parameters (axis variances and rotation angle) of the sampled covariance.
_eig_val, _eig_vec = np.linalg.eig(_cov_sample)
_angle = np.rad2deg(np.arctan2(_eig_vec[1,0], _eig_vec[0,0]))

# Scatter of the draws with the 95% density region of the sampled covariance.
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(_mvn_samples[:, 0], _mvn_samples[:, 1], c="k", alpha=0.05)
plot_ellipse(_mean, _eig_val, _angle, color=blue, alpha=0.5, ax=ax)
ax.set_aspect('equal')
plt.show()

#### Sample from LKJ

In [None]:
# Draw 25 Cholesky factors from the LKJ prior with a fixed key.
_lkj_chol_dist_samples = _lkj_chol_dist.sample(jax.random.PRNGKey(0), (25,))
print("_lkj_chol_dist_samples.shape: ", _lkj_chol_dist_samples.shape)

In [None]:
def get_ellipse_params(cov):
    """Summarise a covariance matrix as ellipse parameters.

    Returns:
        A 1D array holding the eigenvalues of ``cov`` followed by the
        angle, in degrees, of the first eigenvector (from ``arctan2`` of its
        second and first component).
    """
    eigenvalues, eigenvectors = jnp.linalg.eig(cov)
    # Only the real parts of the decomposition are used.
    first_vec = jnp.real(eigenvectors)[:, 0]
    rotation_deg = jnp.degrees(jnp.arctan2(first_vec[1], first_vec[0]))
    return jnp.concatenate(
        [jnp.real(eigenvalues), jnp.reshape(rotation_deg, (1,))],
        axis=0,
    )

# Show the ellipse parameters of the single covariance sample.
get_ellipse_params(_cov_sample)

In [None]:
_sigma = np.array([1., 1.])
# Cholesky factor of the covariance: scale row i of the correlation factor
# by sigma_i, which is diag(sigma) @ L_corr.
_L_sample = _sigma[..., None] * _rho_sample
print("_L_sample: \n", _L_sample)
# Row scaling is a LEFT multiplication by diag(sigma). Right multiplication
# would scale the columns and only agrees when all sigmas are equal.
assert np.allclose(_L_sample, jnp.diag(_sigma) @ _rho_sample)
print("")

# Compute covariance matrix
_cov_sample = _L_sample @ _L_sample.T
print("_cov_sample: \n", _cov_sample)


In [None]:
# Scale row i of every sampled correlation factor by sigma_i, i.e.
# diag(sigma) @ L for each sample. Right-multiplying by diag(sigma) would
# scale columns instead, which only matches when all sigmas are equal.
_L_samples = _sigma[..., None] * _lkj_chol_dist_samples

def L2cov(L):
    """Rebuild a covariance matrix from its Cholesky factor: ``L @ L.T``."""
    L_transposed = L.T
    return L @ L_transposed

# Apply L2cov to every sampled Cholesky factor along the leading (sample) axis.
_cov_samples = jax.vmap(L2cov)(_L_samples)
_cov_samples.shape

In [None]:
# Histogram of the off-diagonal covariance entry [1, 0] across the prior samples.
fig, ax = plt.subplots(figsize=(8, 6))
sns.histplot(_cov_samples[:, 1, 0])

In [None]:
# Distinct values of the diagonal entry [1, 1] across the samples.
# NOTE(review): with unit scales these are presumably all ~1 -- confirm.
np.unique(_cov_samples[:, 1, 1])

In [None]:
# Ellipse parameters (eigenvalues followed by the angle) for every sampled covariance.
_ellipse_params = jax.vmap(get_ellipse_params)(_cov_samples)
_ellipse_params.shape

In [None]:
# Overlay the 95% density region of every covariance sampled from the prior.
fig, ax = plt.subplots(figsize=(8, 6))
for _params in _ellipse_params:
    # First two entries: variances along the ellipse axes; third: angle in degrees.
    var, angle = _params[0:2], _params[2]
    plot_ellipse([0, 0], var, angle_deg=angle, color=blue, alpha=0.1, ax=ax)
ax.set_aspect('equal')
ax.set_xlim(-4, 4)
ax.set_ylim(-4, 4)
plt.show()

## Fit to data

In [None]:
def model(d, obs=None):
    """Multivariate normal model with an LKJ-Cholesky prior on the correlation.

    Args:
        d: Dimension of the multivariate normal.
        obs: Observed data passed to the "obs" site, or ``None`` to sample
            from it instead.
    """
    # Prior on the LKJ concentration parameter.
    concentration = numpyro.sample(
        "lkj_concentration", dist.Gamma(concentration=2., rate=1.0)
    )
    # Cholesky factor of the correlation matrix.
    corr_chol = numpyro.sample(
        "lkj_chol", dist.LKJCholesky(dimension=d, concentration=concentration)
    )
    # Per-dimension scale.
    sigma = numpyro.sample("scale", dist.Exponential(rate=jnp.ones(d)))
    # Cholesky factor of the covariance: each row of the correlation factor
    # is multiplied by the matching scale.
    cov_chol = numpyro.deterministic("L_cov", sigma[..., None] * corr_chol)
    mu = numpyro.sample(
        'loc', dist.Normal(loc=jnp.zeros(d), scale=jnp.ones(d)*1.5)
    )
    numpyro.sample(
        'obs', dist.MultivariateNormal(loc=mu, scale_tril=cov_chol), obs=obs
    )

In [None]:
rng_key = jax.random.PRNGKey(42)

# Warm-up and sampling iterations (per chain).
num_warmup, num_samples = 1000, 2000

# Run NUTS.
kernel = NUTS(model)
mcmc = MCMC(
    kernel,
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=4,
    # NOTE(review): parallel chains presumably need at least 4 visible
    # devices -- confirm the host device count is set before running.
    chain_method='parallel',
)
# Fit the 2-dimensional model to the generated data x.
mcmc.run(rng_key, d=2, obs=x)
mcmc.print_summary()
posterior_samples = mcmc.get_samples()

In [None]:
import warnings
# Silence NumPy division warnings. The `message` argument is a regular
# expression matched against the start of the warning text, so each pattern
# covers both the older ufunc names (true_divide / double_scalars) and the
# newer ones (divide / scalar divide).
warnings.filterwarnings('ignore', message=r"divide by zero encountered in (true_)?divide")
warnings.filterwarnings('ignore', message=r"invalid value encountered in (true_)?divide")
warnings.filterwarnings('ignore', message=r"invalid value encountered in (double_scalars|scalar divide)")

In [None]:
# Posterior summary table, rounded to 2 decimals.
# NOTE(review): the "~" prefix presumably excludes "lkj_chol" from the table -- confirm in the ArviZ docs.
az.summary(mcmc, var_names=["~lkj_chol"], round_to=2)

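NaNs in `n_eff` and `r_hat` for entries of the Cholesky matrix? This is because those entries are constant (for example the zeros above the diagonal of a lower-triangular factor): https://discourse.mc-stan.org/t/in-the-estimation-results-se-mean-n-eff-and-rhat-are-nan-why/22482/7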

In [None]:
rng_key = jax.random.PRNGKey(42)

# Posterior predictive: run the model without observations, conditioned on
# the posterior samples, so that the "obs" site is sampled.
posterior_predictive = Predictive(model, posterior_samples=posterior_samples)
posterior_predictions = posterior_predictive(rng_key, d=2)
print('Posterior predictions: ', posterior_predictions['obs'].shape)

In [None]:



# Bundle the MCMC run and the posterior predictive draws into ArviZ InferenceData.
inference_data = az.from_numpyro(
    posterior=mcmc,
    posterior_predictive=posterior_predictions,
)
display(inference_data)

# Trace plots, with the ground-truth mean and Cholesky factor passed as
# reference lines for "loc" and "L_cov".
az.plot_trace(
    inference_data,
    compact=True,
    lines=[
        ("loc", {}, μ_actual),
        ("L_cov", {}, L_actual),
    ],
)
plt.suptitle('Trace plots', fontsize=18)
plt.show()

In [None]:
# Posterior mean of the location and its difference from the true mean.
μ_post = posterior_samples["loc"].mean(axis=0)
μ_post - μ_actual

In [None]:
# Posterior mean of the covariance Cholesky factor and its difference from the true factor.
L_post = posterior_samples["L_cov"].mean(axis=(0))
L_post - L_actual

In [None]:
# Covariance rebuilt from the posterior-mean Cholesky factor. Note that this
# is not the same as averaging L @ L.T over the posterior samples.
Σ_post = L_post @ L_post.T
Σ_post - Σ_actual

In [None]:
# Ellipse parameters (axis variances and rotation angle) of the posterior covariance estimate.
eig_val_post, eig_vec_post = np.linalg.eig(Σ_post)
angle_post = np.rad2deg(np.arctan2(eig_vec_post[1,0], eig_vec_post[0,0]))


# Same for the ground-truth covariance.
eig_val_actual, eig_vec_actual = np.linalg.eig(Σ_actual)
angle_actual = np.rad2deg(np.arctan2(eig_vec_actual[1,0], eig_vec_actual[0,0]))


# Data scatter with the true (blue) and posterior (orange) 95% density regions.
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(x[:, 0], x[:, 1], c="k", alpha=0.05)
e_actual = plot_ellipse(μ_actual, eig_val_actual, angle_actual, color=blue, alpha=0.5, name='actual', ax=ax)
e_post = plot_ellipse(μ_post, eig_val_post, angle_post, color=orange, alpha=0.5, name='post', ax=ax)
ax.set_aspect('equal')
# Add both ellipse patches explicitly to the legend entries.
handles, _ = ax.get_legend_handles_labels()
handles.extend([e_actual, e_post])
ax.legend(handles=handles)
plt.show()