# 2D GMM

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Use SVG as the format for inline figures.
%config InlineBackend.figure_format = 'svg'

In [None]:
import numpy as np

import jax
import jax.numpy as jnp

import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
import seaborn as sns
import arviz as az

from tqdm import tqdm_notebook as tqdm

In [None]:
# Global plot styling: seaborn's 'darkgrid' is set first, then the ArviZ
# "arviz-darkgrid" matplotlib style sheet is applied afterwards.
sns.set_style('darkgrid')
az.style.use("arviz-darkgrid")

In [None]:
# Run on CPU and request 4 host devices; the MCMC run further down uses 4 chains.
# NOTE(review): per the NumPyro docs these calls should happen before any JAX
# computation runs — confirm this cell is always executed first.
numpyro.set_platform('cpu')
numpyro.set_host_device_count(4)

## Create data

In [None]:
# Fix the global NumPy seed so the synthetic data set is reproducible.
np.random.seed(42)

n = 2500  # Total number of samples
k = 3  # Number of clusters

# Mixture weights: probability of drawing a sample from each cluster
true_mixture_probs = np.array([0.2, 0.5, 0.3])
assert np.isclose(true_mixture_probs.sum(), 1.)

# Cluster centres, one (x, y) row per cluster
true_locs = np.array([
    [-1.2, 1.5],
    [2.0, 2.0],
    [-1.0, 4.0],
])

# Within-cluster correlation between x and y
true_corrs = np.array([-0.85, 0.0, 0.85])

# One 2x2 correlation matrix per cluster
true_corr_mats = np.array([[[1., rho], [rho, 1.]] for rho in true_corrs])
print("true_corr_mats: ", true_corr_mats.shape)

# Per-cluster standard deviations in the x and y directions
true_scales = np.array([
    [0.9, 1.6],
    [1.0, 1.0],
    [1.4, 0.8],
])
print("true_scales: ", true_scales.shape)

# Covariance per cluster: cov[k, i, j] = scale[k, i] * scale[k, j] * corr[k, i, j]
true_cov = np.einsum('ki,kj,kij->kij', true_scales, true_scales, true_corr_mats)

# Draw the generating cluster of every sample
true_mixture_idxs = np.random.choice(np.arange(k), p=true_mixture_probs, size=n)

# Draw one observation per sample from the Gaussian of its cluster.
# The draws stay in a Python loop (one call per sample, in order) so the
# random stream is consumed exactly once per observation.
obs_data = np.array([
    np.random.multivariate_normal(mean, cov)
    for mean, cov in zip(true_locs[true_mixture_idxs], true_cov[true_mixture_idxs])
])
assert obs_data.shape == (n, 2)

# Fixed colour per cluster index. The palette is looked up once instead of
# being rebuilt for every cluster.
palette = sns.color_palette("tab10")
cmap = {i: palette[i] for i in range(k)}

# Show observations, coloured by the cluster that generated them
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
for i, color in cmap.items():
    c_idx = (true_mixture_idxs == i)  # boolean mask of the samples drawn from cluster i
    ax.plot(obs_data[c_idx, 0], obs_data[c_idx, 1], 'o', alpha=0.3, color=color, label=i)
# Equal aspect so the cluster shapes are not distorted by axis scaling
# (previously set twice; once is enough).
ax.set_aspect('equal')
ax.set_title('Observations')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend()
plt.show()

## Mixture distribution

### Gaussian Mixture Model

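TODO:
- Check whether multiplying `scales` with `corr_lower` indeed gives the Cholesky factor of the covariance matrix, with matching shapes (I doubt it).
  - Reasoning to verify: with $\Sigma = \mathrm{diag}(s)\, R\, \mathrm{diag}(s)$ and $R = L L^\top$, we get $\Sigma = (\mathrm{diag}(s) L)(\mathrm{diag}(s) L)^\top$, so scaling row $i$ of $L$ by $s_i$ should yield a lower-triangular factor of $\Sigma$.
  - More info: https://www2.stat.duke.edu/courses/Spring12/sta104.1/Lectures/Lec22.pdf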

In [None]:
@jax.vmap
def create_chol_lower(scale, corr_lower):
    """Turn a correlation Cholesky factor into a covariance Cholesky factor.

    The body is written for a single mixture component; ``jax.vmap`` maps it
    over the leading axis of both arguments.

    With ``S = diag(scale)`` and a correlation matrix ``R = L @ L.T``, the
    covariance is ``S @ R @ S = (S @ L) @ (S @ L).T``. Multiplying row ``i``
    of ``L`` by ``scale[i]`` is exactly ``S @ L``, and row scaling keeps a
    lower-triangular ``L`` lower triangular.

    :param scale: Per-dimension standard deviations (last axis of length d).
    :param corr_lower: Lower Cholesky factor of the correlation matrix (d x d).
    :return: ``corr_lower`` with row ``i`` multiplied by ``scale[i]``.
    """
    # Append a trailing axis so each scale broadcasts along one row.
    row_factors = jnp.expand_dims(scale, axis=-1)
    return row_factors * corr_lower


def gmm_model(d: int, k: int, obs=None):
    """
    Gaussian mixture model with ``k`` full-covariance components in ``d`` dimensions.

    Sample sites: 'mixing_probabilities', 'scales', 'locs', 'corr_lower' and
    'obs'. Deterministic site: 'correlations'.

    :param d: Dimension of Gaussian.
    :param k: Number of mixtures
    :param obs: Observations. When ``None``, the 'obs' site is sampled from the
        mixture instead of being conditioned on data.
    """
    # Prior for cluster probabilities: flat Dirichlet over the k components
    mixing_prob = numpyro.sample('mixing_probabilities', dist.Dirichlet(concentration=jnp.ones((k, ))))
    # Priors on the per-dimension scales (half-Cauchy) and means (Cauchy) of each cluster
    with numpyro.plate('mixture_plate', k, dim=-2):
        scales = numpyro.sample("scales", dist.HalfCauchy(scale=jnp.ones(d)*2))
        locs = numpyro.sample('locs', dist.Cauchy(loc=jnp.zeros(d), scale=jnp.ones(d)*2))
    # Prior on correlation through LKJ prior (placed on the Cholesky factor of
    # each cluster's correlation matrix)
    with numpyro.plate('mixture_plate', k, dim=-1):
        corr_lower = numpyro.sample("corr_lower", dist.LKJCholesky(dimension=d, concentration=1.))
        # Extract correlation for later analysis. The first row of a correlation
        # Cholesky factor is (1, 0, ...), so entry (1, 0) equals the correlation
        # between dimensions 0 and 1. Indexing with `...` (rather than `:`)
        # always addresses the last two axes, so the result stays correct when
        # corr_lower carries extra leading batch dimensions.
        numpyro.deterministic("correlations", corr_lower[..., 1, 0])
    # Mixing distribution
    mixing_dist = dist.Categorical(probs=mixing_prob)
    # Mixture components: scale the correlation factors into covariance
    # Cholesky factors, diag(scales) @ corr_lower
    lower_cholesky = create_chol_lower(scales, corr_lower)
    component_dist = dist.MultivariateNormal(loc=locs, scale_tril=lower_cholesky)
    # Mixture distribution
    gmm_dist = dist.MixtureSameFamily(mixing_distribution=mixing_dist, component_distribution=component_dist)
    numpyro.sample('obs', gmm_dist, obs=obs)

In [None]:
# Sampler configuration
num_warmup = 1000
num_samples = 3000
rng_key = jax.random.PRNGKey(42)

# Run NUTS with 4 chains in parallel.
kernel = NUTS(gmm_model)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=4, chain_method='parallel')
mcmc.run(rng_key, d=2, k=3, obs=obs_data)

# Posterior draws as returned by get_samples() with its default arguments
posterior_samples = mcmc.get_samples()

In [None]:
# Posterior summary table, rounded to 2 decimals. The "~" prefix excludes the
# raw 'corr_lower' Cholesky factors from the table.
az.summary(mcmc, var_names=["~corr_lower"], round_to=2)

In [None]:
# Posterior predictive draws: obs=None makes the model sample the 'obs' site.
# NOTE(review): this is the same PRNG key value that seeded the MCMC run;
# consider deriving a separate key with jax.random.split.
rng_key = jax.random.PRNGKey(42)
predictive_model_kwargs = dict(d=2, k=3, obs=None)

# NOTE(review): batch_ndims=0 presumably feeds the stacked posterior arrays to
# the model in a single call instead of mapping over draws — confirm against
# the NumPyro Predictive docs.
posterior_predictive = Predictive(gmm_model, posterior_samples=posterior_samples, batch_ndims=0)
posterior_predictions = posterior_predictive(rng_key, **predictive_model_kwargs)
print('Posterior predictions: ', posterior_predictions['obs'].shape)

In [None]:
# Coordinate labels: one entry per mixture component and per spatial dimension
idata_coords = {"mixture": np.arange(k), "dim": np.arange(2)}
# Named dimensions for the posterior variables
idata_dims = {
    "locs": ["mixture", "dim"],
    "scales": ["mixture", "dim"],
    "mixing_probabilities": ["mixture"],
}

# Bundle the MCMC run and the posterior predictive draws into one InferenceData
inference_data = az.from_numpyro(
    posterior=mcmc,
    posterior_predictive=posterior_predictions,
    coords=idata_coords,
    dims=idata_dims,
)
display(inference_data)


In [None]:
# Reference lines marking the true generating values in the trace plots
true_value_lines = [
    ("correlations", {}, true_corrs),
    ("locs", {}, true_locs),
    ("scales", {}, true_scales),
    ("mixing_probabilities", {}, true_mixture_probs),
]

# Trace plots for every variable except the raw 'corr_lower' factors
az.plot_trace(inference_data, compact=True, var_names=["~corr_lower"], lines=true_value_lines)
plt.suptitle('Trace plots', fontsize=18)
plt.show()

In [None]:
# Scatter the posterior predictive draws (column 0 against column 1 of 'obs')
predicted_obs = posterior_predictions["obs"]
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.plot(predicted_obs[:, 0], predicted_obs[:, 1], 'o', alpha=0.1)
ax.set_aspect('equal')
ax.set(title='Posterior predicted samples', xlabel='x', ylabel='y')
plt.show()