# Gaussian Mixture Model (GMM)

Based on:
- https://docs.pymc.io/notebooks/marginalized_gaussian_mixture_model.html

NumPyro Marginalized Mixture model:
- https://forum.pyro.ai/t/sample-from-the-mixture-same-family-distribution/3178/

More info on marginalized mixture models:
- https://www.youtube.com/watch?v=KOIudAB6vJ0
- https://mc-stan.org/users/documentation/case-studies/identifying_mixture_models.html


In [None]:
# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

import sys
import warnings

import numpy as np

import jax
import jax.numpy as jnp

import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm  # Colormaps
import seaborn as sns
import arviz as az

from tqdm import tqdm_notebook as tqdm

In [None]:
# Plot styling: seaborn's dark grid first, then ArviZ's own dark-grid style.
sns.set_style(style='darkgrid')
# Highest-density intervals reported by ArviZ cover 90% of the mass.
az.rcParams['stats.hdi_prob'] = 0.9
az.style.use("arviz-darkgrid")

In [None]:
# Run JAX on the CPU and advertise 8 host devices so that multiple MCMC chains
# could run in parallel. This must happen before JAX initializes its backend.
numpyro.set_platform(platform='cpu')
numpyro.set_host_device_count(8)

In [None]:
# Seed both NumPy's global RNG and a JAX PRNG key from the same constant.
SEED = 42
np.random.seed(SEED)
rng_key = jax.random.PRNGKey(SEED)

In [None]:
# Simulate a 1-D dataset from a known Gaussian mixture, then plot it.
np.random.seed(42)

n = 2500  # Total number of samples
k = 3  # Number of clusters
p_real = np.array([0.2, 0.3, 0.5])  # Probability of choosing each cluster
mus_real = np.array([-1.0, 1.0, 4.0])  # Mean of each cluster
sigmas_real = np.array([0.2, 0.9, 0.5])  # Standard deviation of each cluster

# Pick a cluster per sample, then draw from that cluster's Normal.
clusters = np.random.choice(k, size=n, p=p_real)
x_data = np.random.normal(loc=mus_real[clusters], scale=sigmas_real[clusters], size=n)

print(f'{n} samples in total from {k} clusters. x_data: {x_data.shape}')
fig, ax = plt.subplots(figsize=(6, 3))
sns.histplot(x_data, kde=True, ax=ax)
ax.set(xlabel='x')
plt.show()

A natural parameterization of the Gaussian mixture model is as the latent variable model

$$
\begin{split}\begin{align*}
\mu_1, \ldots, \mu_k        & \sim \mathcal{N}(0, \sigma^2) \\
\sigma_1, \ldots, \sigma_k  & \sim \text{HalfCauchy}(b) \\
w                           & \sim \text{Dirichlet}(\alpha_1, \ldots, \alpha_k) \\
z \mid w                    & \sim \text{Categorical}(w) \\
x \mid z                    & \sim \mathcal{N}(\mu_z, \sigma_z^2).
\end{align*}\end{split}
$$

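The disadvantage of this parameterization is that sampling the posterior relies on sampling the discrete categorical variable $z$, one per data point. This reliance can cause slow mixing and poor exploration of the posterior. Gradient-based samplers such as NUTS also cannot sample discrete variables directly, so the sampler would need additional step methods to avoid getting stuck.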

An alternative is to marginalize out the categorical variable $z$ and model each observation directly with a single [mixture distribution](https://en.wikipedia.org/wiki/Mixture_distribution):

$$
\begin{split}\begin{align*}
\mu_1, \ldots, \mu_k          & \sim \mathcal{N}(0, \sigma^2) \\
\sigma_1, \ldots, \sigma_k    & \sim \text{HalfCauchy}(b) \\
w                             & \sim \text{Dirichlet}(\alpha_1, \ldots, \alpha_k) \\
f(x \mid w, \mu, \sigma) & = \sum_{i = 1}^k w_i \, \mathcal{N}(x \mid \mu_i, \sigma_i^2)
\end{align*}\end{split}
$$

with

$$
\mathcal{N}(x \mid \mu, \sigma^2) = \frac{1}{\sqrt{2 \pi} \sigma} \exp\left(-\frac{1}{2 \sigma^2} (x - \mu)^2\right)
$$

In [None]:
class MixtureGaussian(dist.Distribution):
    """Mixture of Normal components with the component assignment marginalized out.

    ``loc``, ``scale`` and ``mixing_probs`` are broadcast to a common shape.
    The last axis of that shape indexes the mixture components, and the
    leading axes form the distribution's batch shape.

    Args:
        loc: Component means.
        scale: Component standard deviations (passed as ``scale`` to ``dist.Normal``).
        mixing_probs: Component weights (passed as probabilities to ``dist.Categorical``).
        validate_args: Forwarded to ``dist.Distribution``.
    """

    # Draws, and the values scored by ``log_prob``, are real numbers. Declaring
    # the support also lets NumPyro build a transform if this is used as a latent site.
    support = dist.constraints.real

    def __init__(self, loc, scale, mixing_probs, validate_args=None):
        expand_shape = jax.lax.broadcast_shapes(
            jnp.shape(loc), jnp.shape(scale), jnp.shape(mixing_probs)
        )
        # One Normal per component, broadcast to the common shape.
        self._gaussian = dist.Normal(loc=loc, scale=scale).expand(expand_shape)
        # Categorical over the last (component) axis chooses the component of each draw.
        self._categorical = dist.Categorical(jnp.broadcast_to(mixing_probs, expand_shape))
        super().__init__(batch_shape=expand_shape[:-1], validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw from the mixture: sample every component, then keep the chosen one."""
        key_normal, key_idx = jax.random.split(key)
        samples = self._gaussian.sample(key_normal, sample_shape)
        ind = self._categorical.sample(key_idx, sample_shape)
        # Select, per draw, the component picked by the categorical and drop the component axis.
        return jnp.take_along_axis(samples, ind[..., None], -1)[..., 0]

    def log_prob(self, value):
        """Return ``log sum_i w_i N(value | loc_i, scale_i)``, computed with logsumexp."""
        value = jnp.asarray(value)
        # Trailing axis so ``value`` broadcasts against every component.
        component_log_probs = self._gaussian.log_prob(value[..., None])
        # The Categorical's ``logits`` serve as log mixing weights
        # (assumes ``mixing_probs`` sums to 1 along the last axis).
        return jax.nn.logsumexp(self._categorical.logits + component_log_probs, axis=-1)

In [None]:
def gmm_model(k, x=None):
    """Marginalized Gaussian mixture model with ``k`` components.

    Args:
        k: Number of mixture components.
        x: Observed data. When ``None``, the ``'x'`` site is sampled instead
            (as done by ``Predictive``). Because the mixture has an empty batch
            shape here, one scalar draw of ``'x'`` is produced per call.
    """
    # Mixture weights: flat Dirichlet prior over the k-simplex.
    prob_cluster = numpyro.sample('prob_cluster', dist.Dirichlet(concentration=jnp.ones(k)))
    # Independent per-cluster priors on component means and standard deviations.
    with numpyro.plate('k_plate', k):
        loc = numpyro.sample('loc', dist.Normal(loc=0., scale=10.))
        sigma = numpyro.sample('scale', dist.HalfCauchy(scale=10))
    # The discrete cluster assignment is summed out inside MixtureGaussian.log_prob.
    numpyro.sample('x', MixtureGaussian(loc=loc, scale=sigma, mixing_probs=prob_cluster), obs=x)

In [None]:
# Fit gmm_model to the observed data with NUTS and collect posterior draws.
num_warmup, num_samples = 1000, 2000
kernel = NUTS(gmm_model)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)

rng_key = jax.random.PRNGKey(42)
mcmc.run(rng_key, x=x_data, k=k)
mcmc.print_summary()
posterior_samples = mcmc.get_samples()

In [None]:
# Draw posterior-predictive samples of 'x' (one draw per posterior sample).
rng_key = jax.random.PRNGKey(42)
# Split off a fresh key so these draws do not reuse the exact key passed to
# mcmc.run; JAX PRNG keys should never be reused for independent randomness.
rng_key, rng_key_predictive = jax.random.split(rng_key)

# Grid spanning the observed data range, padded by 1 on each side
# (not used by the cells shown here).
x_posterior = np.linspace(x_data.min() - 1, x_data.max() + 1, 100)

posterior_predictive = Predictive(gmm_model, posterior_samples=posterior_samples)
posterior_predictions = posterior_predictive(rng_key_predictive, k=k)
print('Posterior predictions: ', posterior_predictions['x'].shape)

In [None]:
# Bundle posterior and posterior-predictive draws into an ArviZ InferenceData,
# labelling the per-cluster axis of each k-vector parameter as "cluster".
cluster_dims = {name: ["cluster"] for name in ("loc", "scale", "prob_cluster")}
fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(10, 6))
inference_data = az.from_numpyro(
    posterior=mcmc,
    posterior_predictive=posterior_predictions,
    coords={"cluster": np.arange(k)},
    dims=cluster_dims,
)
display(inference_data)

# compact=True draws all clusters of a parameter in the same panel.
az.plot_trace(inference_data, compact=True, axes=axes)
plt.suptitle('Trace plots', fontsize=18)
plt.show()

In [None]:
# Posterior histograms: one row per parameter, one column per cluster.
# Deriving the column count from k keeps the grid valid if k changes.
posterior_vars = ['loc', 'scale', 'prob_cluster']
fig, axes = plt.subplots(len(posterior_vars), k, figsize=(12, 8))
az.plot_posterior(inference_data, var_names=posterior_vars, kind='hist', ax=axes)
plt.suptitle('Posterior plots', fontsize=18)
plt.show()

In [None]:

# Overlay the observed data and the posterior-predictive draws. The two samples
# differ in size (len(x_data) observations vs one predictive draw per posterior
# sample), so both are normalized to densities to be directly comparable.
palette = sns.color_palette("tab10")
fig, ax = plt.subplots(1, 1, figsize=(6, 4))
sns.histplot(x_data, stat='density', kde=True, label='original', ax=ax, color=palette[0])
sns.histplot(posterior_predictions['x'], stat='density', kde=True, label='posterior', ax=ax, color=palette[1])
ax.set_title("Posterior predictive vs Original")
ax.set_xlabel("x")
ax.legend()
plt.show()