# NumPyro Gaussian Mixture Model with discrete sampling

This notebook illustrates how to build a GMM in NumPyro using discrete sampling from the categorical distribution.

Note that in practice you probably want to use a marginalized mixture model, as [illustrated in PyMC3 here](https://docs.pymc.io/notebooks/marginalized_gaussian_mixture_model.html).

In [None]:
import sys
import warnings

import numpy as np

import jax
import jax.numpy as jnp

import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm  # Colormaps
import seaborn as sns
import arviz as az

from tqdm import tqdm_notebook as tqdm

In [None]:
# Log the versions of the imported packages (watermark's --iversions) so the
# results below can be reproduced with the same library versions.
%load_ext watermark
%watermark --iversions

In [None]:
# Seed both random-number systems: JAX uses explicit PRNG keys, while the data
# simulation below relies on NumPy's legacy global generator.
# NOTE(review): `rng_key` is reassigned in the MCMC cell before its first use.
rng_key = jax.random.PRNGKey(42)
np.random.seed(42)

In [None]:
# Simulate 1-D data from a known mixture of k Gaussians so the posterior can be
# compared against the ground truth. Re-seeding here keeps the cell idempotent:
# re-running it regenerates exactly the same data.
np.random.seed(42)

n = 2500  # Total number of samples
k = 3  # Number of clusters
p_real = np.array([0.2, 0.3, 0.5])  # Probability of choosing each cluster
mus_real = np.array([-1., 1., 4.])  # Mean of each cluster
sigmas_real = np.array([0.2, 0.9, 0.5])  # Standard deviation of each cluster
assert np.isclose(p_real.sum(), 1.0), 'cluster probabilities must sum to 1'

# Draw a cluster assignment for every sample, then a value from that cluster's Normal.
clusters = np.random.choice(k, size=n, p=p_real)
data = np.random.normal(mus_real[clusters], sigmas_real[clusters], size=n)
assert data.shape == (n,)

print(f'{n} samples in total from {k} clusters. data: {data.shape}')
fig, ax = plt.subplots()
sns.histplot(data, kde=True, ax=ax)
# Mark the true cluster means for later comparison with the posterior of `mu`.
for mu_true in mus_real:
    ax.axvline(mu_true, color='k', linestyle='--', linewidth=1)
ax.set(title='Simulated GMM data (dashed: true cluster means)', xlabel='x')
plt.show()

In [None]:
def gmm_model(data, k, n=None):
    """Gaussian mixture model with an explicit categorical assignment per point.

    Args:
        data: 1-D array of observations. Pass ``None`` to sample ``x`` from the
            model instead of conditioning on it (e.g. prior-predictive checks);
            ``n`` must then be given.
        k: Number of mixture components.
        n: Number of points to generate when ``data`` is ``None``; ignored
            otherwise.

    Sample sites: ``selection_prob`` (mixture weights, length k), ``mu`` and
    ``scale`` (per-component mean and standard deviation, one per component),
    ``cluster_idx`` (component index of each point) and ``x`` (the points).

    Raises:
        ValueError: if both ``data`` and ``n`` are ``None``.
    """
    if data is None and n is None:
        raise ValueError('Either data or n must be given.')
    num_points = len(data) if data is not None else n
    # Prior for cluster probabilities:
    # Dirichlet(ones(k)) is uniform over the probability simplex.
    selection_prob = numpyro.sample('selection_prob', dist.Dirichlet(concentration=jnp.ones(k)))
    # Priors on the per-cluster means and standard deviations.
    with numpyro.plate('k_plate', k):
        mu = numpyro.sample('mu', dist.Normal(loc=0., scale=10.))
        sigma = numpyro.sample('scale', dist.HalfCauchy(scale=10))
    # The data needs its own plate because every point draws its own
    # categorical cluster assignment.
    with numpyro.plate('data', num_points):
        cluster_idx = numpyro.sample('cluster_idx', dist.Categorical(selection_prob))
        numpyro.sample('x', dist.Normal(loc=mu[cluster_idx], scale=sigma[cluster_idx]), obs=data)

In [None]:
# Split a fixed seed into independent keys: `rng_key_mcmc` drives the sampler and
# `rng_key` is kept for posterior-predictive sampling, so the two stages never
# consume the same random stream. Starting from a fixed PRNGKey keeps the cell
# idempotent when re-run.
rng_key_mcmc, rng_key = jax.random.split(jax.random.PRNGKey(0))

num_warmup, num_samples = 1000, 2000

# Run NUTS.
kernel = NUTS(gmm_model)
mcmc = MCMC(
    kernel,
    num_warmup=num_warmup,
    num_samples=num_samples,
)
mcmc.run(rng_key_mcmc, data=data, k=k)
mcmc.print_summary()
posterior_samples = mcmc.get_samples()

Make sure to pass `infer_discrete=True` to `Predictive`, as [mentioned here](https://github.com/pyro-ppl/numpyro/issues/1121#issuecomment-897363003), so that the discrete `cluster_idx` site is sampled.

In [None]:
# Derive a fresh key instead of reusing one that may already have been consumed
# by MCMC. `rng_key` itself is not rebound, so re-running this cell gives
# identical results.
rng_key_predictive = jax.random.fold_in(rng_key, 1)
# infer_discrete=True asks Predictive to draw the discrete `cluster_idx` site
# for each posterior sample.
# NOTE(review): because `data` is passed, `x` is an observed site here, so
# posterior_predictions['x'] presumably just repeats the observed data rather
# than being a fresh posterior-predictive draw — confirm.
posterior_predictive = Predictive(gmm_model, posterior_samples, infer_discrete=True)
posterior_predictions = posterior_predictive(rng_key_predictive, k=k, data=data)

Make sure to add the selected indices (discrete samples) to the MCMC samples, as [mentioned here](https://github.com/pyro-ppl/numpyro/issues/1121#issuecomment-897363003).

In [None]:
# Attach the sampled discrete assignments to the dict of MCMC samples (in place).
# NOTE(review): az.from_numpyro(posterior=mcmc) below appears to read samples
# from the `mcmc` object, not from this dict — confirm this entry is used.
posterior_samples.update(cluster_idx=posterior_predictions["cluster_idx"])

In [None]:


# Bundle the MCMC posterior together with the posterior-predictive draws into a
# single ArviZ InferenceData object, then render its rich summary.
inference_data = az.from_numpyro(posterior=mcmc, posterior_predictive=posterior_predictions)
display(inference_data)

In [None]:
# Trace plots of the posterior parameters sampled by MCMC. This reuses the
# InferenceData built in the previous cell instead of converting `mcmc` again.
# It deliberately does NOT rebind `data`: that name holds the observed samples
# the model and Predictive cells are called with, and overwriting it would break
# any re-run of those cells.
az.plot_trace(inference_data, compact=True)
plt.show()