In [None]:
import sys
import warnings

import numpy as np

import jax
import jax.numpy as jnp
from jax import lax

import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm  # Colormaps
import seaborn as sns
import arviz as az

from tqdm import tqdm_notebook as tqdm

In [None]:
# Plot styling. Both seaborn and ArviZ modify global matplotlib style settings.
# NOTE(review): the ArviZ style is applied last, so it may override parts of the
# seaborn 'darkgrid' style — keep this order if the ArviZ look is the intended one.
sns.set_style('darkgrid')
# Default HDI probability used by ArviZ stats/plots (90% intervals).
az.rcParams['stats.hdi_prob'] = 0.90
az.style.use("arviz-darkgrid")

In [None]:
# Run JAX/NumPyro on the CPU and expose 8 host (CPU) devices.
# NOTE(review): per the NumPyro docs, set_host_device_count only takes effect when called
# before JAX initializes its backend; keep this cell ahead of the first JAX computation
# (the PRNGKey creation in the next cell).
# NOTE(review): extra devices only help when running chains in parallel; the MCMC cell
# below does not set num_chains — confirm whether 8 devices are still wanted.
numpyro.set_platform('cpu')
numpyro.set_host_device_count(8)

In [None]:
# Seed NumPy's global RNG and create the root JAX PRNG key used by the sampling demos below
# (the MCMC cell creates its own key again).
np.random.seed(42)
rng_key = jax.random.PRNGKey(42)

## Marginalized mixture distribution (`MixtureSameFamily`)

In [None]:
from numpyro.distributions import Distribution, Categorical
from numpyro.distributions.util import (
    validate_sample,
    is_prng_key
)


class MixtureSameFamily(Distribution):
    """
    Marginalized mixture of distributions from a single (vectorized) family.

    The discrete component index is summed out in ``log_prob``, so the mixture only
    exposes the continuous value and can be used as an observed site under
    gradient-based samplers such as NUTS.

    :param mixing_probabilities: The mixing probabilities between the different distributions of the mixture.
        Shape = (*batch_shape, mixture_size); each row is expected to sum to one.
    :param component_distribution: One vectorized ``Distribution`` holding all components.
        Its batch shape must be ``(*batch_shape, mixture_size)``; its event shape becomes
        the event shape of the mixture.
    :param validate_args: Forwarded to ``Distribution.__init__``.
    :raises ValueError: If ``component_distribution`` is not a ``Distribution``.
    :raises AssertionError: If the component batch shape is not ``(*batch_shape, mixture_size)``.
    """
    def __init__(
        self,
        mixing_probabilities,
        component_distribution,
        validate_args=None
    ):
        # Check the type first so a wrong argument fails with this message instead of
        # an AttributeError when its batch/event shapes are read below.
        if not isinstance(component_distribution, Distribution):
            raise ValueError(
                "The component distribution need to be a numpyro.distributions.Distribution. "
                f"However, it is of type {type(component_distribution)}"
            )
        self._mixture_size = mixing_probabilities.shape[-1]
        self._categorical = dist.Categorical(mixing_probabilities)
        self._component_distribution = component_distribution
        # Components live on the last batch axis of the component distribution.
        expected_component_batch_shape = self._categorical.batch_shape + (self.mixture_size,)
        assert self._component_distribution.batch_shape == expected_component_batch_shape, (
            f"Component distribution batch shape does not correspond to expected shape according to the "
            f"mixing probabilities shape {self._component_distribution.batch_shape} != {expected_component_batch_shape}"
        )
        batch_shape = mixing_probabilities.shape[:-1]
        event_shape = component_distribution.event_shape
        super().__init__(batch_shape=batch_shape, event_shape=event_shape, validate_args=validate_args)

    @property
    def mixture_size(self):
        """
        Returns the number of distributions in the mixture

        :return: number of mixtures.
        :rtype: int
        """
        return self._mixture_size

    @property
    def support(self):
        """
        Support of the mixture: the support shared by all components.

        ``@validate_sample`` on ``log_prob`` checks values against ``self.support`` when
        argument validation is enabled, so the mixture must expose one.
        """
        return self._component_distribution.support

    def sample(self, key, sample_shape=()):
        """
        Returns a sample from the mixture distribution having shape given by
        `sample_shape + batch_shape + event_shape`. Note that when `sample_shape` is non-empty,
        leading dimensions (of size `sample_shape`) of the returned sample will
        be filled with iid draws from the distribution instance.

        Every component is sampled and one of them is then selected per draw, so this draws
        ``mixture_size`` times more component samples than it returns.

        :param jax.random.PRNGKey key: the rng_key key to be used for the distribution.
        :param tuple sample_shape: the sample shape for the distribution.
        :return: an array of shape `sample_shape + batch_shape + event_shape`
        :rtype: numpy.ndarray
        """
        assert is_prng_key(key)
        key_comp, key_ind = jax.random.split(key)
        # Draw from every component: shape (*sample_shape, *batch_shape, mixture_size, *event_shape).
        samples = self._component_distribution.sample(key_comp, sample_shape)
        assert samples.shape == (*sample_shape, *self.batch_shape, self.mixture_size, *self.event_shape)
        # One component index per draw: shape (*sample_shape, *batch_shape).
        ind = self._categorical.sample(key_ind, sample_shape)
        # Append singleton axes (mixture axis + one per event dim) so the indices
        # broadcast against `samples` in take_along_axis.
        mixture_axis = -(self.event_dim + 1)
        ind_expanded = ind.reshape(ind.shape + (1,) * (self.event_dim + 1))
        samples_selected = jnp.take_along_axis(samples, indices=ind_expanded, axis=mixture_axis)
        final_samples = jnp.squeeze(samples_selected, axis=mixture_axis)
        assert final_samples.shape == (*sample_shape, *self.batch_shape, *self.event_shape)
        return final_samples

    @validate_sample
    def log_prob(self, value):
        """
        Evaluates the log probability density for a batch of samples given by
        `value`, with the component index marginalized out:
        ``log p(x) = logsumexp_k(log w_k + log p_k(x))``.

        :param value: A batch of samples of shape ``(*sample_shape, *batch_shape, *event_shape)``.
        :return: an array of shape ``value.shape[:value.ndim - event_dim]``,
            i.e. ``(*sample_shape, *batch_shape)``.
        :rtype: jax.numpy.ndarray
        :raises AssertionError: if the dimensions just before the event dimensions of
            `value` do not equal ``batch_shape``.
        """
        value_shape = jnp.shape(value)
        n_leading = len(value_shape) - self.event_dim  # sample + batch dims
        leading_shape = value_shape[:n_leading]
        if self.batch_shape:
            # Batch dims sit immediately before the event dims, not at the very end of `value`.
            n_batch = len(self.batch_shape)
            assert leading_shape[-n_batch:] == self.batch_shape, (
                f"Value shape {value_shape} does not end with batch shape {self.batch_shape} "
                f"followed by event shape {self.event_shape}"
            )
        # Insert a singleton mixture axis so every value is scored under all components:
        # (*leading, 1, *event) broadcasts against the component batch (*batch_shape, mixture_size).
        value_per_component = jnp.reshape(value, leading_shape + (1,) + self.event_shape)
        component_log_probs = self._component_distribution.log_prob(value_per_component)
        # log_softmax turns the categorical logits into normalized log-weights
        # (equal to log(probs) up to rounding when probs already sum to one).
        log_weights = jax.nn.log_softmax(self._categorical.logits, axis=-1)
        log_prob = jax.nn.logsumexp(log_weights + component_log_probs, axis=-1)
        assert log_prob.shape == leading_shape
        return log_prob

In [None]:
# Univariate components, no batch: k=3 standard Normals mixed with uniform weights.
k = 3
mixing_probabilities = jnp.ones(k) / k
print("mixing_probabilities: ", mixing_probabilities.shape)
loc = jnp.zeros(k)
print("loc: ", loc.shape)
scale = jnp.ones(k)
print("scale: ", scale.shape)
print('')
normal = dist.Normal(loc=loc, scale=scale)
print("normal.shape(): ", normal.shape())
print("normal.batch_shape: ", normal.batch_shape)
print("normal.event_shape: ", normal.event_shape)
print("normal.event_dim: ", normal.event_dim)
print('')
mixed = MixtureSameFamily(mixing_probabilities=mixing_probabilities, component_distribution=normal)
print("mixed.shape(): ", mixed.shape())
print("mixed.batch_shape: ", mixed.batch_shape)
print("mixed.event_shape: ", mixed.event_shape)
print("mixed.event_dim: ", mixed.event_dim)
print("mixed.mixture_size: ", mixed.mixture_size)
print('')
samples = mixed.sample(rng_key, (100,))
print('samples: ', samples.shape)
print('')
lp = mixed.log_prob(samples)
print('lp: ', lp.shape)

# Pin the shapes this cell demonstrates instead of relying on reading the prints.
assert normal.batch_shape == (k,) and normal.event_shape == ()
assert mixed.batch_shape == () and mixed.event_shape == () and mixed.mixture_size == k
assert samples.shape == (100,) and lp.shape == (100,)
# Every component is N(0, 1), so the marginal density must equal a single N(0, 1) density.
assert jnp.allclose(lp, dist.Normal(0., 1.).log_prob(samples), atol=1e-4)

In [None]:
# Univariate components, no batch, 2-D sample_shape: k=3 standard Normals with uniform weights.
k = 3
mixing_probabilities = jnp.ones(k) / k
print("mixing_probabilities: ", mixing_probabilities.shape)
loc = jnp.zeros(k)
print("loc: ", loc.shape)
scale = jnp.ones(k)
print("scale: ", scale.shape)
print('')
normal = dist.Normal(loc=loc, scale=scale)
print("normal.shape(): ", normal.shape())
print("normal.batch_shape: ", normal.batch_shape)
print("normal.event_shape: ", normal.event_shape)
print("normal.event_dim: ", normal.event_dim)
print('')
mixed = MixtureSameFamily(mixing_probabilities=mixing_probabilities, component_distribution=normal)
print("mixed.shape(): ", mixed.shape())
print("mixed.batch_shape: ", mixed.batch_shape)
print("mixed.event_shape: ", mixed.event_shape)
print("mixed.event_dim: ", mixed.event_dim)
print('')
samples = mixed.sample(rng_key, (100, 7))
print('samples: ', samples.shape)
print('')
lp = mixed.log_prob(samples)
print('lp: ', lp.shape)

# Pin the shapes this cell demonstrates instead of relying on reading the prints.
assert mixed.batch_shape == () and mixed.event_shape == ()
assert samples.shape == (100, 7) and lp.shape == (100, 7)
# Every component is N(0, 1), so the marginal density must equal a single N(0, 1) density.
assert jnp.allclose(lp, dist.Normal(0., 1.).log_prob(samples), atol=1e-4)

In [None]:
# Batched mixtures: s=4 independent mixtures, each over k=3 univariate Normals.
k = 3
s = 4
mixing_probabilities = jnp.ones((s, k)) / k
print("mixing_probabilities: ", mixing_probabilities.shape)
loc = jnp.zeros((s, k))
print("loc: ", loc.shape)
# NOTE(review): scale is 1/k, which looks copied from the mixing-probabilities line — confirm intended.
scale = jnp.ones((s, k)) / k
print("scale: ", scale.shape)
print('')
normal = dist.Normal(loc=loc, scale=scale)
print("normal.shape(): ", normal.shape())
print("normal.batch_shape: ", normal.batch_shape)
print("normal.event_shape: ", normal.event_shape)
print("normal.event_dim: ", normal.event_dim)
print('')
mixed = MixtureSameFamily(mixing_probabilities=mixing_probabilities, component_distribution=normal)
print("mixed.shape(): ", mixed.shape())
print("mixed.batch_shape: ", mixed.batch_shape)
print("mixed.event_shape: ", mixed.event_shape)
print("mixed.event_dim: ", mixed.event_dim)
print('')
samples = mixed.sample(rng_key, (100,))
print('samples: ', samples.shape)
print('')
lp = mixed.log_prob(samples)
print('lp: ', lp.shape)

# Pin the shapes this cell demonstrates instead of relying on reading the prints.
assert normal.batch_shape == (s, k) and normal.event_shape == ()
assert mixed.batch_shape == (s,) and mixed.event_shape == ()
assert samples.shape == (100, s) and lp.shape == (100, s)
# Every component is N(0, 1/k), so the marginal density must equal that single Normal's density.
assert jnp.allclose(lp, dist.Normal(0., 1. / k).log_prob(samples), atol=1e-4)

In [None]:
# Scratch check: vstack of a single 1-D row yields a 2-D array of shape (1, k),
# i.e. the s=1 case of the batched mixing probabilities in the previous cell.
jnp.vstack([jnp.ones(k) / k for _ in range(1)]).shape

In [None]:
# Multivariate components, no batch: k=3 bivariate standard Normals (event_shape=(d,)).
k = 3
d = 2
mixing_probabilities = jnp.ones(k) / k
print("mixing_probabilities: ", mixing_probabilities.shape)
loc = jnp.zeros((k, d))
print("loc: ", loc.shape)
# One identity covariance per component, shape (k, d, d).
cov_matrix = jnp.broadcast_to(jnp.eye(d), (k, d, d))
print("cov_matrix: ", cov_matrix.shape)
print('')
normal = dist.MultivariateNormal(loc=loc, covariance_matrix=cov_matrix)
print("normal.shape(): ", normal.shape())
print("normal.batch_shape: ", normal.batch_shape)
print("normal.event_shape: ", normal.event_shape)
print("normal.event_dim: ", normal.event_dim)
print('')
mixed = MixtureSameFamily(mixing_probabilities=mixing_probabilities, component_distribution=normal)
print("mixed.shape(): ", mixed.shape())
print("mixed.batch_shape: ", mixed.batch_shape)
print("mixed.event_shape: ", mixed.event_shape)
print("mixed.event_dim: ", mixed.event_dim)
print('')
samples = mixed.sample(rng_key, (100,))
print('samples: ', samples.shape)
print('')
lp = mixed.log_prob(samples)
print('lp: ', lp.shape)

# Pin the shapes this cell demonstrates instead of relying on reading the prints.
assert normal.batch_shape == (k,) and normal.event_shape == (d,)
assert mixed.batch_shape == () and mixed.event_shape == (d,)
assert samples.shape == (100, d) and lp.shape == (100,)
# Every component is N(0, I), so the marginal density must equal a single N(0, I) density.
assert jnp.allclose(
    lp,
    dist.MultivariateNormal(loc=jnp.zeros(d), covariance_matrix=jnp.eye(d)).log_prob(samples),
    atol=1e-4,
)

In [None]:
# Multivariate components with d=1 (event_shape=(1,)): checks the event-dim handling at its edge.
k = 3
d = 1
mixing_probabilities = jnp.ones(k) / k
print("mixing_probabilities: ", mixing_probabilities.shape)
loc = jnp.zeros((k, d))
print("loc: ", loc.shape)
# One identity covariance per component, shape (k, d, d).
cov_matrix = jnp.broadcast_to(jnp.eye(d), (k, d, d))
print("cov_matrix: ", cov_matrix.shape)
print('')
normal = dist.MultivariateNormal(loc=loc, covariance_matrix=cov_matrix)
print("normal.shape(): ", normal.shape())
print("normal.batch_shape: ", normal.batch_shape)
print("normal.event_shape: ", normal.event_shape)
print("normal.event_dim: ", normal.event_dim)
print('')
mixed = MixtureSameFamily(mixing_probabilities=mixing_probabilities, component_distribution=normal)
print("mixed.shape(): ", mixed.shape())
print("mixed.batch_shape: ", mixed.batch_shape)
print("mixed.event_shape: ", mixed.event_shape)
print("mixed.event_dim: ", mixed.event_dim)
print('')
samples = mixed.sample(rng_key, (100,))
print('samples: ', samples.shape)
print('')
lp = mixed.log_prob(samples)
print('lp: ', lp.shape)

# Pin the shapes this cell demonstrates instead of relying on reading the prints.
assert normal.batch_shape == (k,) and normal.event_shape == (d,)
assert mixed.batch_shape == () and mixed.event_shape == (d,)
assert samples.shape == (100, d) and lp.shape == (100,)
# Every component is N(0, I), so the marginal density must equal a single N(0, I) density.
assert jnp.allclose(
    lp,
    dist.MultivariateNormal(loc=jnp.zeros(d), covariance_matrix=jnp.eye(d)).log_prob(samples),
    atol=1e-4,
)

In [None]:
# Multivariate components with a 2-D sample_shape: k=3 bivariate standard Normals.
k = 3
d = 2
mixing_probabilities = jnp.ones(k) / k
print("mixing_probabilities: ", mixing_probabilities.shape)
loc = jnp.zeros((k, d))
print("loc: ", loc.shape)
# One identity covariance per component, shape (k, d, d).
cov_matrix = jnp.broadcast_to(jnp.eye(d), (k, d, d))
print("cov_matrix: ", cov_matrix.shape)
print('')
normal = dist.MultivariateNormal(loc=loc, covariance_matrix=cov_matrix)
print("normal.shape(): ", normal.shape())
print("normal.batch_shape: ", normal.batch_shape)
print("normal.event_shape: ", normal.event_shape)
print("normal.event_dim: ", normal.event_dim)
print('')
mixed = MixtureSameFamily(mixing_probabilities=mixing_probabilities, component_distribution=normal)
print("mixed.shape(): ", mixed.shape())
print("mixed.batch_shape: ", mixed.batch_shape)
print("mixed.event_shape: ", mixed.event_shape)
print("mixed.event_dim: ", mixed.event_dim)
print('')
samples = mixed.sample(rng_key, (100, 7))
print('samples: ', samples.shape)
print('')
lp = mixed.log_prob(samples)
print('lp: ', lp.shape)

# Pin the shapes this cell demonstrates instead of relying on reading the prints.
assert mixed.batch_shape == () and mixed.event_shape == (d,)
assert samples.shape == (100, 7, d) and lp.shape == (100, 7)
# Every component is N(0, I), so the marginal density must equal a single N(0, I) density.
assert jnp.allclose(
    lp,
    dist.MultivariateNormal(loc=jnp.zeros(d), covariance_matrix=jnp.eye(d)).log_prob(samples),
    atol=1e-4,
)

In [None]:
# Scratch check: how the mixture's component distribution (the MultivariateNormal built in the
# previous cell) flattens into a JAX pytree.
# NOTE(review): reaches into the private `_component_distribution` attribute — exploration only.
mixed._component_distribution.tree_flatten()

## Fitting a Gaussian mixture model (GMM) with NUTS

In [None]:
# Simulate n draws from a known 3-component 1-D Gaussian mixture. The GMM fitted below is
# meant to recover p_real, mus_real and sigmas_real.
np.random.seed(42)

n = 2500  # Total number of samples
k = 3  # Number of clusters
p_real = np.array([0.2, 0.3, 0.5])  # Probability of choosing each cluster
mus_real = np.array([-1., 1., 4.])  # Mu of clusters
sigmas_real = np.array([0.2, 0.9, 0.5])  # Sigma of clusters
clusters = np.random.choice(k, size=n, p=p_real)  # Latent cluster label of each sample
x_data = np.random.normal(mus_real[clusters], sigmas_real[clusters], size=n)
assert x_data.shape == (n,)

print(f'{n} samples in total from {k} clusters. x_data: {x_data.shape}')
fig, ax = plt.subplots(1, 1, figsize=(6, 3))
sns.histplot(x_data, kde=True, ax=ax)
ax.set_xlabel('x')
ax.set_title(f'Simulated data: {k}-component Gaussian mixture (n={n})')
plt.show()

In [None]:
def gmm_model(k, x=None):
    """
    Gaussian mixture model with ``k`` univariate Normal components and the
    cluster assignment marginalized out (via ``MixtureSameFamily``).

    Priors:
      * ``prob_cluster ~ Dirichlet(1, ..., 1)``: mixing weights.
      * ``loc[j] ~ Normal(0, 10)`` and ``scale[j] ~ HalfCauchy(10)`` for each component ``j``.

    NOTE(review): the components have identical priors and no ordering constraint on
    ``loc``, so the posterior is symmetric under relabeling; with several chains,
    per-component summaries may mix labels.

    :param int k: number of mixture components.
    :param x: observed data for site ``'x'``, or ``None`` to sample ``x`` instead
        (a single value per model execution, since no data plate is declared).
    """
    prob_cluster = numpyro.sample('prob_cluster', dist.Dirichlet(concentration=jnp.ones(k)))
    # Independent location/scale per component.
    with numpyro.plate('k_plate', k):
        loc = numpyro.sample('loc', dist.Normal(loc=0., scale=10.))
        scale = numpyro.sample('scale', dist.HalfCauchy(scale=10))
    component_distribution = dist.Normal(loc=loc, scale=scale)
    mixture = MixtureSameFamily(
        mixing_probabilities=prob_cluster,
        component_distribution=component_distribution,
    )
    numpyro.sample('x', mixture, obs=x)

In [None]:
# Posterior inference for the GMM with NUTS; num_chains is left at MCMC's default.
rng_key = jax.random.PRNGKey(42)

num_warmup = 1000
num_samples = 2000

kernel = NUTS(gmm_model)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
# `k` and `x_data` come from the data-simulation cell.
mcmc.run(rng_key, k=k, x=x_data)
mcmc.print_summary()
posterior_samples = mcmc.get_samples()

In [None]:
# Removed a leftover `%debug` post-mortem magic. In a clean top-to-bottom run there is no
# traceback for it to inspect; after an error it opens an interactive pdb prompt.