# 2D GMM

In [None]:
# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

import sys
import warnings

import numpy as np

import jax
import jax.numpy as jnp

import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm  # Colormaps
import seaborn as sns
import arviz as az

from tqdm import tqdm_notebook as tqdm

In [None]:
# Global plot styling: seaborn's dark grid is applied first, then the ArviZ
# style sheet is activated afterwards.
sns.set_style('darkgrid')
# Probability mass used by ArviZ for highest-density intervals (90%).
az.rcParams['stats.hdi_prob'] = 0.90
az.style.use("arviz-darkgrid")

In [None]:
# Run on CPU and ask NumPyro for 8 host devices.
# NOTE(review): presumably these calls have to run before JAX initialises its
# backend, and the device count is what allows the 4 chains below to run with
# chain_method='parallel' — confirm against the NumPyro docs.
numpyro.set_platform('cpu')
numpyro.set_host_device_count(8)

In [None]:
# Seed both random sources used in this notebook: the JAX PRNG key that is
# handed to NumPyro, and NumPy's global generator used for the synthetic data.
rng_key = jax.random.PRNGKey(seed=42)
np.random.seed(seed=42)

## Create data

In [None]:
# Re-seed so that this cell produces the same data no matter what ran before it.
np.random.seed(42)

n = 1000 # Total number of samples
k = 3  # Number of clusters
p_real = np.array([0.2, 0.5, 0.3])  # Probability of choosing each cluster
assert p_real.shape == (k,)
assert np.isclose(p_real.sum(), 1.)
print("p_real: ", p_real.shape)
loc_real = np.array([  # Mean of clusters
    [-1.2, 1.5],
    [2., 2.],
    [-1, 4.]
])
# Shapes are checked against `k` so the assertions follow the configuration above.
assert loc_real.shape == (k, 2)
print("loc_real: ", loc_real.shape)
cov_real = np.array([
    [
        [0.1, -0.2],
        [-0.2, 1.0],
    ],
    [
        [0.75, 0.0],
        [0.0, 0.75],
    ],
    [
        [1.0, 0.5],
        [0.5, 0.27],
    ],
])  # Covariance of clusters
assert cov_real.shape == (k, 2, 2)
print("cov_real: ", cov_real.shape)
for i in range(k):
    # Check symmetry first: `eigvalsh` assumes a symmetric matrix and always
    # returns real eigenvalues, so the positivity test below is well defined.
    assert np.allclose(cov_real[i], cov_real[i].T)
    assert (np.linalg.eigvalsh(cov_real[i]) > 0).all()
# `np.linalg.cholesky` factorises a whole stack of matrices in one call.
L_real = np.linalg.cholesky(cov_real)
# Sanity check: the lower-triangular factors reproduce the covariances.
assert np.allclose(L_real @ np.swapaxes(L_real, -1, -2), cov_real)
print("L_real: ", L_real.shape)


# Per-cluster sample counts: truncate the expected count to an integer and let
# the last cluster absorb the remainder so the counts add up to exactly n.
nb_cluster_samples = (p_real * n).astype(np.int32)
assert nb_cluster_samples.shape[0] == k
nb_cluster_samples[-1] = n - nb_cluster_samples[:-1].sum()
assert nb_cluster_samples.sum() == n

# Ground-truth cluster label of every sample, still grouped by cluster.
clusters = np.concatenate([
    np.full(count, label, dtype=np.int32)
    for label, count in enumerate(nb_cluster_samples)
])
assert clusters.shape[0] == n, clusters.shape
assert (np.unique(clusters, return_counts=True)[1] == nb_cluster_samples).all()

# Draw each cluster's points from its own Gaussian, in cluster order.
obs_data = np.concatenate([
    np.random.multivariate_normal(mean=loc_real[label], cov=cov_real[label], size=count)
    for label, count in enumerate(nb_cluster_samples)
], axis=0)
assert obs_data.shape == (n, 2)

# Shuffle labels and observations with the same permutation.
_idx_permutations = np.random.permutation(n)
clusters, obs_data = clusters[_idx_permutations], obs_data[_idx_permutations]

# One fixed colour per cluster index, taken from the "tab10" palette.
cmap = dict(enumerate(sns.color_palette("tab10")[:k]))

# Scatter plot of the observations, coloured by their true cluster.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 5))
for i in cmap:
    c_idx = clusters == i
    ax.plot(*obs_data[c_idx].T, 'o', alpha=0.3, color=cmap[i], label=i)
ax.set(
    xlim=(-5, 5),
    ylim=(-2, 6),
    aspect='equal',
    title='Observations',
    xlabel='x',
    ylabel='y',
)
ax.legend()
plt.show()

## Test NumPyro Multivariate distributions

In [None]:
# Smoke test: draw 100 points from the third cluster's Gaussian and plot them.
_mvn = dist.MultivariateNormal(covariance_matrix=cov_real[2], loc=loc_real[2])
_mvn_samples = _mvn.sample(key=rng_key, sample_shape=(100,))
print(_mvn_samples.shape)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 5))
ax.plot(_mvn_samples[..., 0], _mvn_samples[..., 1], 'o', label=2, color=cmap[2], alpha=0.3)
ax.set(
    xlim=(-5, 5),
    ylim=(-2, 6),
    aspect='equal',
    title='Test MultivariateNormal',
    xlabel='x',
    ylabel='y',
)
ax.legend()
plt.show()

In [None]:
# Smoke test: a single MultivariateNormal built from all cluster parameters at
# once. The prints show how the distribution reports its shapes.
print("loc_real: ", loc_real.shape)
print("cov_real: ", cov_real.shape)
print("p_real: ", p_real.shape)

_mvn_multi = dist.MultivariateNormal(covariance_matrix=cov_real, loc=loc_real)
print("_mvn_multi.shape(): ", _mvn_multi.shape())
print("_mvn_multi.batch_shape: ", _mvn_multi.batch_shape)
print("_mvn_multi.event_shape: ", _mvn_multi.event_shape)
print("_mvn_multi.event_dim: ", _mvn_multi.event_dim)
_mvn_multi_samples = _mvn_multi.sample(key=rng_key, sample_shape=(1000,))
print("_mvn_multi_samples: ", _mvn_multi_samples.shape)

# Axis 1 of the samples indexes the cluster; plot each cluster in its colour.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 5))
for i in range(k):
    ax.plot(
        _mvn_multi_samples[:, i, 0],
        _mvn_multi_samples[:, i, 1],
        'o',
        label=i,
        color=cmap[i],
        alpha=0.3,
    )
ax.set(
    xlim=(-5, 5),
    ylim=(-2, 6),
    aspect='equal',
    title='Test MultivariateNormal multiple distributions',
    xlabel='x',
    ylabel='y',
)
ax.legend()
plt.show()

In [None]:
# _p_real = np.array([0.2, 0.5, 0.3])  # Probability of choosing each cluster
# assert np.isclose(_p_real.sum(), 1.)
# print("_p_real: ", _p_real.shape)
# _loc_real = np.array([
#     [-1.2, 2., -1],
#     [1.5, 2., 4.]
# ])
# assert _loc_real.shape == (2, 3)
# print("_loc_real: ", _loc_real.shape)
# _cov_real = np.array([
#     [
#         [0.1, 0.75, 1.0],
#         [-0.2, 0.0, 0.5],
#     ],
#     [
#         [-0.2, 0.0, 0.5],
#         [1.0, 0.75, 0.27],
#     ],
# ])
# assert _cov_real.shape == (2, 2, 3)
# print("_cov_real: ", _cov_real.shape)
# print("_loc_real: ", _loc_real.shape)
# print("_cov_real: ", _cov_real.shape)
# print("_p_real: ", _p_real.shape)

# _mvn_multi = dist.MultivariateNormal(loc=_loc_real, covariance_matrix=_cov_real)
# print("_mvn_multi.shape(): ", _mvn_multi.shape())
# print("_mvn_multi.batch_shape: ", _mvn_multi.batch_shape)
# print("_mvn_multi.event_shape: ", _mvn_multi.event_shape)
# print("_mvn_multi.event_dim: ", _mvn_multi.event_dim)
# _mvn_multi_samples = _mvn_multi.sample(rng_key, (1000,))
# print("_mvn_multi_samples: ", _mvn_multi_samples.shape)

# # fig, ax = plt.subplots(1, 1, figsize=(7, 5))
# # for i in range(k):
# #     ax.plot(_mvn_multi_samples[:, i, 0], _mvn_multi_samples[:, i, 1], 'o', alpha=0.3, color=cmap[i], label=i)
# # ax.set_xlim(-5, 5)
# # ax.set_ylim(-2, 6)
# # ax.set_aspect('equal')
# # ax.set_title('Test MultivariateNormal multiple distributions')
# # ax.set_xlabel('x')
# # ax.set_ylabel('y')
# # ax.legend()
# # plt.show()

In [None]:
# Smoke test: draw 100 cluster indices from the true mixing probabilities.
_cat = dist.Categorical(probs=p_real, validate_args=True)
_cat_samples = _cat.sample(key=rng_key, sample_shape=(100,))
print(_cat_samples.shape)

## Mixture distribution

In [None]:
class MixtureMultivariateNormal(dist.Distribution):
    """Finite mixture of multivariate normal components.

    A draw first picks a component index from a categorical distribution over
    the ``k`` components and then returns the draw of that component.

    Shape convention (``...`` are optional leading batch dimensions, which
    must be the same for all arguments):

    * ``mixing_probs``: ``(..., k)`` component probabilities.
    * ``loc``: ``(..., k, d)`` component means.
    * ``covariance_matrix`` / ``scale_tril``: ``(..., k, d, d)``.

    The resulting distribution has ``batch_shape == ...`` and
    ``event_shape == (d,)``.

    Args:
        mixing_probs: Probabilities of the mixture components.
        loc: Means of the components.
        covariance_matrix: Covariance matrices of the components. Mutually
            exclusive with ``scale_tril``.
        scale_tril: Lower-triangular Cholesky factors of the component
            covariances. Mutually exclusive with ``covariance_matrix``.
        validate_args: Forwarded to the underlying distributions and to the
            base class.

    Raises:
        AttributeError: If both or neither of ``covariance_matrix`` and
            ``scale_tril`` are given.
    """

    def __init__(self, mixing_probs, loc, covariance_matrix=None, scale_tril=None, validate_args=False):
        if covariance_matrix is not None and scale_tril is not None:
            raise AttributeError("Only one of `covariance_matrix` or `scale_tril` can be given")
        if covariance_matrix is None and scale_tril is None:
            raise AttributeError("One of `covariance_matrix` or `scale_tril` needs to be given")
        if covariance_matrix is not None:
            self._mvn = dist.MultivariateNormal(loc=loc, covariance_matrix=covariance_matrix, validate_args=validate_args)
        else:
            self._mvn = dist.MultivariateNormal(loc=loc, scale_tril=scale_tril, validate_args=validate_args)
        self._categorical = dist.Categorical(mixing_probs, validate_args=validate_args)
        # The last batch dimension of the component distribution indexes the
        # mixture components; it is marginalised out and therefore not part of
        # the mixture's own batch shape.
        batch_shape = jax.lax.broadcast_shapes(
            tuple(self._mvn.batch_shape[:-1]),
            tuple(self._categorical.batch_shape),
        )
        super().__init__(
            batch_shape=batch_shape,
            event_shape=self._mvn.event_shape,
            validate_args=validate_args,
        )

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape + (d,)``.

        Every component is sampled, then the draw of the component selected by
        the categorical index is kept.
        """
        # Same key usage as a plain two-way split: first key for the component
        # draws, second key for the component indices.
        key, key_idx = jax.random.split(key)
        samples = self._mvn.sample(key, sample_shape)  # (..., k, d)
        ind = self._categorical.sample(key_idx, sample_shape)  # (...)
        # Index the component axis from the end so that any number of leading
        # sample/batch dimensions (including none) is supported.
        picked = jnp.take_along_axis(samples, ind[..., None, None], axis=-2)  # (..., 1, d)
        return picked[..., 0, :]

    def log_prob(self, value):
        """Log density of ``value`` (shape ``(..., d)``) under the mixture.

        Computes ``logsumexp_j(log pi_j + log N(value | loc_j, cov_j))`` over
        the component axis.
        """
        # Insert a singleton component axis so `value` broadcasts against all
        # k components at once.
        component_log_probs = self._mvn.log_prob(jnp.expand_dims(value, -2))  # (..., k)
        weighted = self._categorical.logits + component_log_probs
        return jax.nn.logsumexp(weighted, axis=-1)

    
# Smoke test of the mixture distribution with the ground-truth parameters:
# sample, plot the samples, then evaluate their log-probabilities.
_mvn_mixture = MixtureMultivariateNormal(
    mixing_probs=p_real,
    loc=loc_real,
    covariance_matrix=cov_real,
    validate_args=True,
)
_mvn_mixture_samples = _mvn_mixture.sample(rng_key, sample_shape=(1000,))
print("_mvn_mixture_samples: ", _mvn_mixture_samples.shape)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 5))
ax.plot(_mvn_mixture_samples[..., 0], _mvn_mixture_samples[..., 1], 'o', alpha=0.3)
ax.set(
    xlim=(-5, 5),
    ylim=(-2, 6),
    aspect='equal',
    title='Observations',
    xlabel='x',
    ylabel='y',
)
plt.show()


_lp = _mvn_mixture.log_prob(value=_mvn_mixture_samples)
print("_lp: ", _lp.shape)


### Gaussian Mixture Model

In [None]:
def gmm_model(d, k, obs=None):
    """NumPyro model of a ``k``-component Gaussian mixture in ``d`` dimensions.

    Sample sites:

    * ``prob_cluster`` ``(k,)``: mixing probabilities, flat Dirichlet prior.
    * ``scale`` ``(k, d)``: per-cluster standard deviations, HalfCauchy(2) prior.
    * ``loc`` ``(k, d)``: per-cluster means, Cauchy(0, 2) prior.
    * ``lkj_chol`` ``(k, d, d)``: Cholesky factors of the per-cluster
      correlation matrices, LKJ prior with concentration 1.
    * ``L_cov`` ``(k, d, d)``: deterministic Cholesky factors of the
      covariances, ``scale[..., None] * lkj_chol``.
    * ``obs``: the data, distributed as a ``MixtureMultivariateNormal``.

    Args:
        d: Dimensionality of a single observation.
        k: Number of mixture components.
        obs: Observations to condition on, or ``None`` to sample them.
    """
    # Prior for cluster probabilities
    prob_cluster = numpyro.sample('prob_cluster', dist.Dirichlet(concentration=jnp.ones((k, ))))
    # All per-cluster parameters live in ONE plate over the clusters. The
    # d-dimensional vectors are declared as events (`to_event(1)`) so that the
    # plate dimension is the only batch dimension of every site and the plate
    # name is not reused with conflicting dims.
    with numpyro.plate('k_plate', k, dim=-1):
        scale = numpyro.sample("scale", dist.HalfCauchy(scale=jnp.ones(d)*2).to_event(1))
        loc = numpyro.sample('loc', dist.Cauchy(loc=jnp.zeros(d), scale=jnp.ones(d)*2).to_event(1))
        lkj_chol = numpyro.sample("lkj_chol", dist.LKJCholesky(dimension=d, concentration=1.))
        # Scale row i of the correlation factor by the i-th standard deviation.
        L_cov = numpyro.deterministic("L_cov", scale[..., None] * lkj_chol)
    numpyro.sample('obs', MixtureMultivariateNormal(mixing_probs=prob_cluster, loc=loc, scale_tril=L_cov), obs=obs)

In [None]:
# Fresh key so this cell gives the same result regardless of earlier cells.
rng_key = jax.random.PRNGKey(42)

num_warmup, num_samples = 1000, 2000

# Run NUTS.
kernel = NUTS(gmm_model)
mcmc = MCMC(
    kernel,
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=4,
    chain_method='parallel',
)
# Take the model dimensions from the data and the configured number of
# clusters instead of repeating them as literals, so this cell cannot drift
# out of sync with the data-generation cell.
mcmc.run(rng_key, d=obs_data.shape[-1], k=k, obs=obs_data)
mcmc.print_summary()
posterior_samples = mcmc.get_samples()

In [None]:
import warnings
warnings.filterwarnings('ignore', message="divide by zero encountered in true_divide")
warnings.filterwarnings('ignore', message="invalid value encountered in true_divide")
warnings.filterwarnings('ignore', message="invalid value encountered in double_scalars")

In [None]:
# Posterior summary table; the "~" prefix excludes the Cholesky/scale sites so
# only the remaining variables are listed.
az.summary(mcmc, var_names=["~lkj_chol", "~L_cov", "~scale"], round_to=2)

In [None]:
# Fresh key so this cell gives the same result regardless of earlier cells.
rng_key = jax.random.PRNGKey(42)

# NOTE(review): batch_ndims=0 presumably makes Predictive run the model once
# with the whole stack of posterior draws as batched parameters, yielding one
# predicted observation per draw — confirm against the NumPyro docs.
posterior_predictive = Predictive(gmm_model, posterior_samples=posterior_samples, batch_ndims=0)
# Use the same model dimensions as the data / cluster configuration rather
# than repeating them as literals.
posterior_predictions = posterior_predictive(rng_key, d=obs_data.shape[-1], k=k, obs=None)
print('Posterior predictions: ', posterior_predictions['obs'].shape)

In [None]:
# Bundle the MCMC result and the posterior predictive draws into one ArviZ
# InferenceData object, with named coordinates for the cluster and dimension axes.
inference_data = az.from_numpyro(
    posterior=mcmc,
    posterior_predictive=posterior_predictions,
    coords=dict(cluster=np.arange(k), dim=np.arange(2)),
    dims=dict(
        loc=["cluster", "dim"],
        scale=["cluster", "dim"],
        prob_cluster=["cluster"],
    ),
)
display(inference_data)


In [None]:
# Trace plots of every variable except the raw LKJ factor, with the
# ground-truth values passed as reference lines.
_true_value_lines = [
    ("prob_cluster", {}, p_real),
    ("loc", {}, loc_real),
    ("L_cov", {}, L_real),
]
az.plot_trace(
    inference_data,
    var_names=["~lkj_chol"],
    lines=_true_value_lines,
    compact=True,
)
plt.gcf().suptitle('Trace plots', fontsize=18)
plt.show()

In [None]:

# Scatter plot of the posterior predictive observations, on the same axes
# limits as the plot of the observed data.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 5))
ax.plot(
    posterior_predictions["obs"][..., 0],
    posterior_predictions["obs"][..., 1],
    'o',
    alpha=0.1,
)
ax.set(
    xlim=(-5, 5),
    ylim=(-2, 6),
    aspect='equal',
    title='Posterior predicted observations',
    xlabel='x',
    ylabel='y',
)
plt.show()