141 changes: 141 additions & 0 deletions probml_utils/dp_mixgauss_truncatated_utils.py
@@ -0,0 +1,141 @@
# Adapted from
# https://github.com/ericmjl/dl-workshop/blob/master/src/dl_workshop/gaussian_mixture.py
# Modified to work with the current version of JAX.

from functools import partial

import jax.numpy as jnp
import matplotlib.pyplot as plt
from celluloid import Camera
from jax import lax, random, vmap
from jax.scipy import stats
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def loglike_one_component(component_weight, component_mu, log_component_scale, datum):
"""Log likelihood of datum under one component of the mixture.
Defined as the log likelihood of observing that datum from the component
(i.e. log of component probability)
added to the log likelihood of observing that datum
under the Gaussian that belongs to that component.
:param component_weight: Component weight, a scalar value between 0 and 1.
:param component_mu: A scalar value.
:param log_component_scale: A scalar value.
Gets exponentiated before being passed into norm.logpdf.
:returns: A scalar.
"""
component_scale = jnp.exp(log_component_scale)
return jnp.log(component_weight) + stats.norm.logpdf(datum, loc=component_mu, scale=component_scale)


def normalize_weights(weights):
"""Normalize a weights vector to sum to 1."""
return weights / jnp.sum(weights)


def loglike_across_components(log_component_weights, component_mus,
log_component_scales, datum):
"""Log likelihood of datum under all components of the mixture."""
component_weights = normalize_weights(jnp.exp(log_component_weights))
loglike_components = vmap(partial(loglike_one_component, datum=datum))(
component_weights, component_mus, log_component_scales)
return logsumexp(loglike_components)


def mixture_loglike(log_component_weights, component_mus,
log_component_scales, data):
"""Log likelihood of data (not datum!) under all components of the mixture."""
ll_per_data = vmap(partial(loglike_across_components, log_component_weights,
component_mus, log_component_scales,))(data)
return jnp.sum(ll_per_data)
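

# Hypothetical sketch (not in the original file): evaluate the mixture log
# likelihood of a 2-component model on a few data points, and its gradient.
def _demo_mixture_loglike():
    from jax import grad
    log_w = jnp.log(jnp.array([0.3, 0.7]))
    mus = jnp.array([-2.0, 2.0])
    log_scales = jnp.zeros(2)
    data = jnp.array([-1.5, 0.0, 2.5])
    ll = mixture_loglike(log_w, mus, log_scales, data)
    grads = grad(mixture_loglike, argnums=(0, 1, 2))(log_w, mus, log_scales, data)
    return ll, grads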


def plot_component_norm_pdfs(log_component_weights, component_mus,
log_component_scales, xmin, xmax, ax, title):
component_weights = normalize_weights(jnp.exp(log_component_weights))
component_scales = jnp.exp(log_component_scales)
x = jnp.linspace(xmin, xmax, 1000).reshape(-1, 1)
pdfs = component_weights * norm.pdf(x, loc=component_mus, scale=component_scales)
for component in range(pdfs.shape[1]):
ax.plot(x, pdfs[:, component])
ax.set_title(title)


def get_loss(state, get_params_func, loss_func, data):
params = get_params_func(state)
loss_score = loss_func(params, data)
return loss_score
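

# Hypothetical sketch (not in the original file): get_loss paired with
# jax.example_libraries.optimizers, using the negative mixture log likelihood
# as the loss. Parameter shapes assume a 2-component model and 1-D data.
def _demo_get_loss(data):
    from jax.example_libraries import optimizers

    def loss_func(params, data):
        log_w, mus, log_scales = params
        return -mixture_loglike(log_w, mus, log_scales, data)

    init_params = (jnp.zeros(2), jnp.array([-1.0, 1.0]), jnp.zeros(2))
    opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)
    state = opt_init(init_params)
    return get_loss(state, get_params, loss_func, data)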


def animate_training(params_for_plotting, interval, data_mixture):
"""Animation function for mixture likelihood."""
log_component_weights_history = params_for_plotting['log_component_weight']
component_mus_history = params_for_plotting['component_mus']
log_component_scales_history = params_for_plotting['log_component_scale']
fig, ax = plt.subplots()
cam = Camera(fig)
for w, m, s in zip(log_component_weights_history[::interval],
component_mus_history[::interval],
log_component_scales_history[::interval]):
ax.hist(data_mixture, bins=40, density=True, color="blue")
plot_component_norm_pdfs(w, m, s, xmin=-20, xmax=20, ax=ax, title=None)
cam.snap()
animation = cam.animate()
return animation


def stick_breaking_weights(beta_draws):
"""Return weights from a stick breaking process.
:param beta_draws: i.i.d draws from a Beta distribution.
This should be a row vector.
"""
def weighting(occupied_probability, beta_i):
"""
:param occupied_probability: The cumulative occupied probability taken up.
:param beta_i: Current value of beta to consider.
"""
weight = (1 - occupied_probability) * beta_i
return occupied_probability + weight, weight
occupied_probability, weights = lax.scan(weighting, jnp.array(0.0), beta_draws)
weights = weights / jnp.sum(weights)
return occupied_probability, weights
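

# Hypothetical sketch (not in the original file): draw i.i.d. Beta(1, alpha)
# sticks and check that the resulting truncated, renormalized weights sum to 1.
def _demo_stick_breaking(key=random.PRNGKey(0), num_components=5, alpha=3.0):
    beta_draws = random.beta(key, a=1.0, b=alpha, shape=(num_components,))
    _, weights = stick_breaking_weights(beta_draws)
    assert jnp.allclose(jnp.sum(weights), 1.0)
    return weights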


def beta_draw_from_weights(weights):
    """Recover the Beta draws that would generate the given stick-breaking weights."""
    def beta_from_w(accounted_probability, weights_i):
        """
        :param accounted_probability: The cumulative probability accounted for.
:param weights_i: Current value of weights to consider.
"""
denominator = 1 - accounted_probability
log_denominator = jnp.log(denominator)
log_beta_i = jnp.log(weights_i) - log_denominator
newly_accounted_probability = accounted_probability + weights_i
return newly_accounted_probability, jnp.exp(log_beta_i)
final, betas = lax.scan(beta_from_w, jnp.array(0.0), weights)
return final, betas
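

# Hypothetical worked example (not in the original file): for weights
# (0.5, 0.25, 0.25), each beta_i = w_i / (remaining stick length), giving
# Beta draws (0.5, 0.5, 1.0).
def _demo_beta_draw_from_weights():
    _, betas = beta_draw_from_weights(jnp.array([0.5, 0.25, 0.25]))
    assert jnp.allclose(betas, jnp.array([0.5, 0.5, 1.0]))
    return betas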


def component_probs_loglike(log_component_probs, log_concentration, num_components):
"""Evaluate log likelihood of probability vector under Dirichlet process.
:param log_component_probs: A vector.
:param log_concentration: Real-valued scalar.
:param num_compnents: Scalar integer.
"""
concentration = jnp.exp(log_concentration)
component_probs = normalize_weights(jnp.exp(log_component_probs))
_, beta_draws = beta_draw_from_weights(component_probs)
eval_draws = beta_draws[:num_components]
return jnp.sum(stats.beta.logpdf(x=eval_draws, a=1, b=concentration))
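

# Hypothetical sketch (not in the original file): evaluate the log likelihood
# of a 3-component probability vector with concentration 2. The recovered Beta
# draws are (0.5, 0.6, 1.0); num_components=2 excludes the final draw, which is
# exactly 1 for any normalized weight vector.
def _demo_component_probs_loglike():
    log_probs = jnp.log(jnp.array([0.5, 0.3, 0.2]))
    return component_probs_loglike(log_probs, jnp.log(2.0), num_components=2)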
110 changes: 110 additions & 0 deletions probml_utils/dp_mixgauss_utils.py
@@ -0,0 +1,110 @@
import jax.numpy as jnp
from jax import random
from collections import namedtuple
from multivariate_t_utils import log_predic_t


def dp_mixture_simu(N, alpha, H, key):
[Review comment] Rename to dp_mixgauss_ancestral_sample

"""
    Generate samples from the Gaussian Dirichlet process mixture model.
    We set the base measure of the DP to be the Normal Inverse Wishart (NIW)
    distribution and the likelihood to be a multivariate normal distribution.
------------------------------------------------------
N: int
Number of samples to be generated from the mixture model
alpha: float
Concentration parameter of the Dirichlet process
H: object of NormalInverseWishart
[Review comment] It is better to avoid short, ambiguous variable names. Replace H with niw_prior.

Base measure of the Dirichlet process
key: jax.random.PRNGKey
        Random key used for the simulation
--------------------------------------------
* array(N):
[Review comment] Specify that these are return parameters, and give the variable names (eg Z: array(N): ...)

Simulation of cluster assignment
* array(N, dimension):
Simulation of samples from the DP mixture model
* array(K, dimension):
Simulation of mean of each cluster
* array(K, dimension, dimension):
Simulation of covariance of each cluster
"""
Z = jnp.full(N, 0)
# Sample cluster assignment from the Chinese restaurant process prior
CR = []
for i in range(N):
[Review comment] A fun (optional!) exercise would be to figure out how to vectorize this (eg with lax.scan). Might be tricky because the shapes need to be of fixed size. I think you could pre-allocate CR to a fixed sized vector and then use a binary mask to select the 'valid' prefix.

p = jnp.array(CR + [alpha])
key, subkey = random.split(key)
k = random.categorical(subkey, logits=jnp.log(p))
# Add new cluster to the mixture
if k == len(CR):
CR = CR + [1]
# Increase the size of corresponding cluster by 1
else:
CR[k] += 1
Z = Z.at[i].set(k)
# Sample the parameters for each component of the mixture distribution, from the base measure
key, subkey = random.split(key)
params = H.sample(seed=subkey, sample_shape=(len(CR),))
Sigma = params['Sigma']
Mu = params['mu']
    # Sample from the mixture distribution
subkeys = random.split(key, N)
X = [random.multivariate_normal(subkeys[i], Mu[Z[i]], Sigma[Z[i]]) for i in range(N)]
return Z, jnp.array(X), Mu, Sigma
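

# Hypothetical usage sketch (not in the original file), assuming the
# NormalInverseWishart class from gauss_inv_wishart_utils.py in this package.
def _demo_dp_mixture_simu(key=random.PRNGKey(0)):
    from gauss_inv_wishart_utils import NormalInverseWishart
    niw_prior = NormalInverseWishart(loc=jnp.zeros(2), mean_precision=1.0,
                                     df=4.0, scale=jnp.eye(2))
    Z, X, Mu, Sigma = dp_mixture_simu(N=50, alpha=1.0, H=niw_prior, key=key)
    return Z, X, Mu, Sigma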


def dp_cluster(T, X, alpha, hyper_params, key):
[Review comment] Give this function a more descriptive name, eg dp_mixgauss_gibbs_sample

"""
    Implementation of Algorithm 3 of R. M. Neal (2000):
    https://www.tandfonline.com/doi/abs/10.1080/10618600.2000.10474879
    Clustering analysis using the Gaussian Dirichlet process (DP) mixture model
---------------------------------------------------------------------------
T: int
Number of iterations of the MCMC sampling
X: array(size_of_data, dimension)
The array of observations
alpha: float
Concentration parameter of the DP
hyper_params: object of NormalInverseWishart
Base measure of the Dirichlet process
key: jax.random.PRNGKey
        Random key used for the sampler
----------------------------------
* array(T, size_of_data):
Simulation of cluster assignment
"""
n, dim = X.shape
Zs = []
Cluster = namedtuple('Cluster', ["label", "members"])
# Initialize by setting all observations to cluster0
cluster0 = Cluster(label=0, members=list(range(n)))
    # CR is the set of clusters
CR = [cluster0]
Z = jnp.full(n, 0)
new_label = 1
for t in range(T):
# Update the cluster assignment for every observation
for i in range(n):
[Review comment] Can this be vectorized?

labels = [cluster.label for cluster in CR]
j = labels.index(Z[i])
CR[j].members.remove(i)
if len(CR[j].members) == 0:
del CR[j]
lp0 = [jnp.log(len(cluster.members)) + log_predic_t(X[i,], jnp.atleast_2d(X[cluster.members[:],]), hyper_params) for cluster in CR]
lp1 = [jnp.log(alpha) + log_predic_t(X[i,], jnp.empty((0, dim)), hyper_params)]
logits = jnp.array(lp0 + lp1)
key, subkey = random.split(key)
k = random.categorical(subkey, logits=logits)
if k==len(logits)-1:
new_cluster = Cluster(label=new_label, members=[i])
new_label += 1
CR.append(new_cluster)
Z = Z.at[i].set(new_cluster.label)
else:
CR[k].members.append(i)
Z = Z.at[i].set(CR[k].label)
Zs.append(Z)
return jnp.array(Zs)
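

# Hypothetical usage sketch (not in the original file): a few Gibbs sweeps on
# data X, using the same NormalInverseWishart prior as the base measure.
def _demo_dp_cluster(X, niw_prior, key=random.PRNGKey(1)):
    Zs = dp_cluster(T=10, X=X, alpha=1.0, hyper_params=niw_prior, key=key)
    return Zs[-1]  # cluster assignments after the final sweep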



123 changes: 123 additions & 0 deletions probml_utils/gauss_inv_wishart_utils.py
@@ -0,0 +1,123 @@
"""
This implementation of the Normal Inverse Wishart distribution is copied directly from the notebooks of Scott Linderman:
'Implementing a Normal Inverse Wishart Distribution in Tensorflow Probability'
https://github.com/lindermanlab/hackathons/blob/master/notebooks/TFP_Normal_Inverse_Wishart.ipynb
and
https://github.com/lindermanlab/hackathons/blob/master/notebooks/TFP_Normal_Inverse_Wishart_(Part_2).ipynb
"""
import jax.numpy as np
import jax.random as jr
from jax import vmap
from jax.tree_util import tree_map
import tensorflow_probability.substrates.jax as tfp
import matplotlib.pyplot as plt
from functools import partial

tfd = tfp.distributions
tfb = tfp.bijectors


class NormalInverseWishart(tfd.JointDistributionNamed):
def __init__(self, loc, mean_precision, df, scale, **kwargs):
"""
A normal inverse Wishart (NIW) distribution with

Args:
loc: \mu_0 in math above
mean_precision: \kappa_0
df: \nu
scale: \Psi

Returns:
A tfp.JointDistribution object.
"""
# Store hyperparameters.
self._loc = loc
self._mean_precision = mean_precision
self._df = df
self._scale = scale

# Convert the inverse Wishart scale to the scale_tril of a Wishart.
# Note: this could be done more efficiently.
self.wishart_scale_tril = np.linalg.cholesky(np.linalg.inv(scale))

super(NormalInverseWishart, self).__init__(dict(
Sigma=lambda: tfd.TransformedDistribution(
tfd.WishartTriL(df, scale_tril=self.wishart_scale_tril),
tfb.Chain([tfb.CholeskyOuterProduct(),
tfb.CholeskyToInvCholesky(),
tfb.Invert(tfb.CholeskyOuterProduct())
])),
mu=lambda Sigma: tfd.MultivariateNormalFullCovariance(
loc, Sigma / mean_precision)
))

# Replace the default JointDistributionNamed parameters with the NIW ones
# because the JointDistributionNamed parameters contain lambda functions,
# which are not jittable.
self._parameters = dict(
loc=loc,
mean_precision=mean_precision,
df=df,
scale=scale
)

    # These functions compute the pseudo-observations implied by the NIW prior
    # and convert sufficient statistics to a NIW posterior. They are described
    # in more detail in the Linderman notebooks linked above.
@property
def natural_parameters(self):
"""Compute pseudo-observations from standard NIW parameters."""
dim = self._loc.shape[-1]
chi_1 = self._df + dim + 2
chi_2 = np.einsum('...,...i->...i', self._mean_precision, self._loc)
chi_3 = self._scale + self._mean_precision * \
np.einsum("...i,...j->...ij", self._loc, self._loc)
chi_4 = self._mean_precision
return chi_1, chi_2, chi_3, chi_4

@classmethod
def from_natural_parameters(cls, natural_params):
"""Convert natural parameters into standard parameters and construct."""
chi_1, chi_2, chi_3, chi_4 = natural_params
dim = chi_2.shape[-1]
df = chi_1 - dim - 2
mean_precision = chi_4
loc = np.einsum('..., ...i->...i', 1 / mean_precision, chi_2)
scale = chi_3 - mean_precision * np.einsum('...i,...j->...ij', loc, loc)
return cls(loc, mean_precision, df, scale)

def _mode(self):
r"""Solve for the mode. Recall,
.. math::
p(\mu, \Sigma) \propto
\mathrm{N}(\mu | \mu_0, \Sigma / \kappa_0) \times
\mathrm{IW}(\Sigma | \nu_0, \Psi_0)
The optimal mean is :math:`\mu^* = \mu_0`. Substituting this in,
.. math::
p(\mu^*, \Sigma) \propto IW(\Sigma | \nu_0 + 1, \Psi_0)
and the mode of this inverse Wishart distribution is at
.. math::
\Sigma^* = \Psi_0 / (\nu_0 + d + 2)
"""
dim = self._loc.shape[-1]
covariance = np.einsum("...,...ij->...ij",
1 / (self._df + dim + 2), self._scale)
return self._loc, covariance
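

# Hypothetical sketch (not in the original notebooks): converting to natural
# parameters and back should reproduce the same NIW distribution.
def _demo_niw_natural_param_round_trip():
    niw = NormalInverseWishart(loc=np.zeros(2), mean_precision=1.0,
                               df=4.0, scale=np.eye(2))
    niw2 = NormalInverseWishart.from_natural_parameters(niw.natural_parameters)
    for chi, chi2 in zip(niw.natural_parameters, niw2.natural_parameters):
        assert np.allclose(chi, chi2)
    return niw2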


class MultivariateNormalFullCovariance(tfd.MultivariateNormalFullCovariance):
"""
This wrapper adds simple functions to get sufficient statistics and
construct a MultivariateNormalFullCovariance from parameters drawn
from the normal inverse Wishart distribution.
"""
@classmethod
def from_parameters(cls, params, **kwargs):
return cls(*params, **kwargs)

@staticmethod
def sufficient_statistics(datapoint):
return (1.0, datapoint, np.outer(datapoint, datapoint), 1.0)
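

# Hypothetical sketch (not in the original notebooks): a conjugate NIW posterior
# obtained by adding the summed sufficient statistics of the data (shape (N, dim))
# to the prior's pseudo-observations (natural parameters).
def _demo_niw_posterior_update(prior, data):
    suff_stats = [MultivariateNormalFullCovariance.sufficient_statistics(x) for x in data]
    summed_stats = tree_map(lambda *leaves: sum(leaves), *suff_stats)
    posterior_natural = tree_map(np.add, prior.natural_parameters, summed_stats)
    return NormalInverseWishart.from_natural_parameters(posterior_natural)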

