In [5]:
# Use probml_utils if it is already importable; otherwise install it from GitHub
# via the IPython `%pip` magic (so this cell only runs under Jupyter/IPython)
# and retry the import.
try:
    import probml_utils as pml
except ModuleNotFoundError:
    # NOTE(review): installs the unpinned HEAD of the repository; consider pinning
    # a tag/commit (git+...@<ref>) so re-runs are reproducible.
    %pip install -qq git+https://github.com/probml/probml-utils.git
    import probml_utils as pml

In [None]:
from probml_utils.dp_mixgauss_utils import dp_mixgauss_sample, NormalInverseWishart

In [2]:
import jax.numpy as jnp
from jax import random, vmap
from scipy.linalg import sqrtm
import matplotlib.pyplot as plt

In [None]:
# Example: draw samples from a Dirichlet-process mixture of Gaussians for two
# values of the concentration argument. For every cluster visited by the first
# s samples, overlay the curve mu_k + Sigma_k^{1/2} @ (unit circle), i.e. the
# cluster's one-standard-deviation (Mahalanobis distance 1) contour.
dim = 2
# Set the hyperparameters of the Normal-Inverse-Wishart (NIW) base measure
hyper_params = dict(loc=jnp.zeros(dim), mean_precision=0.05, df=dim + 5, scale=jnp.eye(dim))

# Generate the NIW object used as the DP base measure
dp_base_measure = NormalInverseWishart(**hyper_params)
key = random.PRNGKey(0)
num_of_samples = 1000
# Third argument of dp_mixgauss_sample; presumably the DP concentration alpha
# (confirm against probml_utils.dp_mixgauss_utils).
dp_concentrations = (1.0, 2.0)
# One fresh subkey per run. JAX keys must never be reused: passing the same key
# to both calls would make the two runs share their underlying random draws.
key, *subkeys = random.split(key, len(dp_concentrations) + 1)

# Sampling
output1 = dp_mixgauss_sample(subkeys[0], num_of_samples, dp_concentrations[0], dp_base_measure)
cluster_params1, cluster_indices1, samples1 = output1
output2 = dp_mixgauss_sample(subkeys[1], num_of_samples, dp_concentrations[1], dp_base_measure)
cluster_params2, cluster_indices2, samples2 = output2

# Plotting: one column per concentration value, one row per number of revealed samples
cluster_parameters = (cluster_params1, cluster_params2)
cluster_indices = (cluster_indices1, cluster_indices2)
dp_samples = (samples1, samples2)
bb = jnp.arange(0, 2 * jnp.pi, 0.02)
# Points on the unit circle, shape (2, len(bb)); loop-invariant, so built once.
unit_circle = jnp.vstack([jnp.sin(bb), jnp.cos(bb)])
# Last entry tracks num_of_samples so the bottom row always shows every sample.
sample_size = [50, 500, num_of_samples]
fig, axes = plt.subplots(len(sample_size), len(dp_concentrations))
plt.setp(axes, xticks=[], yticks=[])
for j, s in enumerate(sample_size):
    axes[j, 0].set_ylabel(f"N = {s}")
for i, alpha in enumerate(dp_concentrations):
    Mu = cluster_parameters[i]["mu"]
    Sigma = cluster_parameters[i]["Sigma"]
    Z = cluster_indices[i]
    X = dp_samples[i]
    # Matrix square root of each covariance maps the unit circle onto that
    # cluster's one-standard-deviation contour.
    Sig_root = jnp.array([sqrtm(sigma) for sigma in Sigma])
    axes[0, i].set_title(f"alpha = {alpha}")
    for j, s in enumerate(sample_size):
        ax = axes[j, i]
        ax.plot(X[:s, 0], X[:s, 1], ".", markersize=5)
        # Only draw the clusters actually used by the first s samples.
        for k in jnp.unique(Z[:s]):
            # Mu[k][:, None] is a (dim, 1) column that broadcasts across the
            # len(bb) circle points (same as the original mu.dot(ones) trick).
            circ = Mu[k][:, None] + Sig_root[k] @ unit_circle
            ax.plot(circ[0, :], circ[1, :], linewidth=2, color="k")
plt.show()