add dp_mixgauss related functions #19
@@ -0,0 +1,141 @@
# Transformed from
# https://github.com/ericmjl/dl-workshop/blob/master/src/dl_workshop/gaussian_mixture.py
# The code has been adapted to work with the current version of jax.

import jax.numpy as jnp
from jax.scipy import stats


def loglike_one_component(component_weight, component_mu, log_component_scale, datum):
    """Log likelihood of datum under one component of the mixture.

    Defined as the log likelihood of observing that datum from the component
    (i.e. log of component probability) added to the log likelihood of
    observing that datum under the Gaussian that belongs to that component.

    :param component_weight: Component weight, a scalar value between 0 and 1.
    :param component_mu: A scalar value.
    :param log_component_scale: A scalar value.
        Gets exponentiated before being passed into norm.logpdf.
    :returns: A scalar.
    """
    component_scale = jnp.exp(log_component_scale)
    return jnp.log(component_weight) + stats.norm.logpdf(datum, loc=component_mu, scale=component_scale)


def normalize_weights(weights):
    """Normalize a weights vector to sum to 1."""
    return weights / jnp.sum(weights)


from functools import partial
from jax.scipy.special import logsumexp
from jax import vmap


def loglike_across_components(log_component_weights, component_mus,
                              log_component_scales, datum):
    """Log likelihood of datum under all components of the mixture."""
    component_weights = normalize_weights(jnp.exp(log_component_weights))
    loglike_components = vmap(partial(loglike_one_component, datum=datum))(
        component_weights, component_mus, log_component_scales)
    return logsumexp(loglike_components)


def mixture_loglike(log_component_weights, component_mus,
                    log_component_scales, data):
    """Log likelihood of data (not datum!) under all components of the mixture."""
    ll_per_data = vmap(partial(loglike_across_components, log_component_weights,
                               component_mus, log_component_scales))(data)
    return jnp.sum(ll_per_data)

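A quick usage sketch of the mixture log likelihood (reviewer example, not part of the committed file); the two-component parameter values below are made up purely for illustration:

log_w = jnp.log(jnp.array([0.3, 0.7]))         # unnormalized log component weights
mus = jnp.array([-2.0, 3.0])                   # component means
log_scales = jnp.log(jnp.array([1.0, 0.5]))    # log of component scales
data = jnp.array([-1.9, 3.1, 2.8])
total_ll = mixture_loglike(log_w, mus, log_scales, data)    # scalar log likelihood of all data
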
from jax.scipy.stats import norm


def plot_component_norm_pdfs(log_component_weights, component_mus,
                             log_component_scales, xmin, xmax, ax, title):
    """Plot the weighted pdf of each mixture component on the given axes."""
    component_weights = normalize_weights(jnp.exp(log_component_weights))
    component_scales = jnp.exp(log_component_scales)
    x = jnp.linspace(xmin, xmax, 1000).reshape(-1, 1)
    pdfs = component_weights * norm.pdf(x, loc=component_mus, scale=component_scales)
    for component in range(pdfs.shape[1]):
        ax.plot(x, pdfs[:, component])
    ax.set_title(title)


def get_loss(state, get_params_func, loss_func, data):
    """Return the loss of the parameters held in an optimizer state."""
    params = get_params_func(state)
    loss_score = loss_func(params, data)
    return loss_score

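get_loss appears to assume the (opt_init, opt_update, get_params) optimizer triple used in the original workshop code; a minimal sketch under that assumption, using the current jax.example_libraries.optimizers API and reusing the toy parameters from the example above:

from jax.example_libraries import optimizers

def negative_loglike(params, data):
    # params is a tuple of (log weights, means, log scales)
    log_w, mus, log_scales = params
    return -mixture_loglike(log_w, mus, log_scales, data)

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)
opt_state = opt_init((log_w, mus, log_scales))
loss_now = get_loss(opt_state, get_params, negative_loglike, data)
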
import matplotlib.pyplot as plt
from celluloid import Camera


def animate_training(params_for_plotting, interval, data_mixture):
    """Animation function for mixture likelihood."""
    log_component_weights_history = params_for_plotting['log_component_weight']
    component_mus_history = params_for_plotting['component_mus']
    log_component_scales_history = params_for_plotting['log_component_scale']
    fig, ax = plt.subplots()
    cam = Camera(fig)
    for w, m, s in zip(log_component_weights_history[::interval],
                       component_mus_history[::interval],
                       log_component_scales_history[::interval]):
        ax.hist(data_mixture, bins=40, density=True, color="blue")
        plot_component_norm_pdfs(w, m, s, xmin=-20, xmax=20, ax=ax, title=None)
        cam.snap()
    animation = cam.animate()
    return animation

from jax import lax


def stick_breaking_weights(beta_draws):
    """Return weights from a stick-breaking process.

    :param beta_draws: i.i.d draws from a Beta distribution.
        This should be a row vector.
    """
    def weighting(occupied_probability, beta_i):
        """
        :param occupied_probability: The cumulative occupied probability taken up.
        :param beta_i: Current value of beta to consider.
        """
        weight = (1 - occupied_probability) * beta_i
        return occupied_probability + weight, weight

    occupied_probability, weights = lax.scan(weighting, jnp.array(0.0), beta_draws)
    weights = weights / jnp.sum(weights)
    return occupied_probability, weights

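A minimal usage sketch for stick_breaking_weights (reviewer example, not part of the committed file); the beta values are hand-picked stand-ins for i.i.d. Beta draws:

beta_draws = jnp.array([0.3, 0.5, 0.9, 1.0])
occupied, weights = stick_breaking_weights(beta_draws)
# weights is non-negative and sums to 1, thanks to the explicit normalization in the function
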
from jax import random


def beta_draw_from_weights(weights):
    """Recover the beta draws that would produce the given stick-breaking weights."""
    def beta_from_w(accounted_probability, weights_i):
        """
        :param accounted_probability: The cumulative probability accounted for.
        :param weights_i: Current value of weights to consider.
        """
        denominator = 1 - accounted_probability
        log_denominator = jnp.log(denominator)
        log_beta_i = jnp.log(weights_i) - log_denominator
        newly_accounted_probability = accounted_probability + weights_i
        return newly_accounted_probability, jnp.exp(log_beta_i)

    final, betas = lax.scan(beta_from_w, jnp.array(0.0), weights)
    return final, betas


def component_probs_loglike(log_component_probs, log_concentration, num_components):
    """Evaluate the log likelihood of a probability vector under a Dirichlet process.

    :param log_component_probs: A vector.
    :param log_concentration: Real-valued scalar.
    :param num_components: Scalar integer.
    """
    concentration = jnp.exp(log_concentration)
    component_probs = normalize_weights(jnp.exp(log_component_probs))
    _, beta_draws = beta_draw_from_weights(component_probs)
    eval_draws = beta_draws[:num_components]
    return jnp.sum(stats.beta.logpdf(x=eval_draws, a=1, b=concentration))
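
A round-trip sketch (reviewer example) tying the two helpers together. Evaluating only the first num_components betas presumably avoids the final one, which comes out numerically equal to 1 whenever the weights sum to 1:

probs = jnp.array([0.4, 0.3, 0.2, 0.1])
_, betas = beta_draw_from_weights(probs)    # the last recovered beta is (numerically) 1
ll = component_probs_loglike(jnp.log(probs), jnp.log(2.0), num_components=3)
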
@@ -0,0 +1,110 @@
import jax.numpy as jnp
from jax import random
from collections import namedtuple
from multivariate_t_utils import log_predic_t


def dp_mixture_simu(N, alpha, H, key):
    """
    Generate samples from the Dirichlet process Gaussian mixture model.
    The base measure of the DP is a Normal Inverse Wishart (NIW) distribution
    and the likelihood is a multivariate normal distribution.
    ------------------------------------------------------
    N: int
        Number of samples to be generated from the mixture model
    alpha: float
        Concentration parameter of the Dirichlet process
    H: object of NormalInverseWishart
        Base measure of the Dirichlet process
    key: jax.random.PRNGKey
        Seed of initial random cluster
    --------------------------------------------
    Returns:
    * array(N):
        Simulation of cluster assignment
    * array(N, dimension):
        Simulation of samples from the DP mixture model
    * array(K, dimension):
        Simulation of mean of each cluster
    * array(K, dimension, dimension):
        Simulation of covariance of each cluster
    """
    Z = jnp.full(N, 0)
    # Sample cluster assignments from the Chinese restaurant process prior
    CR = []
    for i in range(N):
        p = jnp.array(CR + [alpha])
        key, subkey = random.split(key)
        k = random.categorical(subkey, logits=jnp.log(p))
        if k == len(CR):
            # Add a new cluster to the mixture
            CR = CR + [1]
        else:
            # Increase the size of the corresponding cluster by 1
            CR[k] += 1
        Z = Z.at[i].set(k)
    # Sample the parameters of each mixture component from the base measure
    key, subkey = random.split(key)
    params = H.sample(seed=subkey, sample_shape=(len(CR),))
    Sigma = params['Sigma']
    Mu = params['mu']
    # Sample from the mixture distribution
    subkeys = random.split(key, N)
    X = [random.multivariate_normal(subkeys[i], Mu[Z[i]], Sigma[Z[i]]) for i in range(N)]
    return Z, jnp.array(X), Mu, Sigma

Review comments on dp_mixture_simu:
- It is better to avoid short, ambiguous variable names. Replace ...
- Specify that these are return parameters, and give the variable names (e.g. ...).
- A fun (optional!) exercise would be to figure out how to vectorize this loop (e.g. with ...).

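A usage sketch (reviewer example, not part of the committed file): simulate 50 two-dimensional points from the DP mixture. This assumes the NormalInverseWishart class added elsewhere in this pull request has been imported; the module name is not shown in the diff.

niw = NormalInverseWishart(loc=jnp.zeros(2), mean_precision=1.0, df=4.0, scale=jnp.eye(2))
Z, X, Mu, Sigma = dp_mixture_simu(N=50, alpha=1.0, H=niw, key=random.PRNGKey(0))
# Z: cluster labels, X: (50, 2) samples, Mu and Sigma: per-cluster parameters
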
def dp_cluster(T, X, alpha, hyper_params, key):
    """
    Implementation of Algorithm 3 of R. M. Neal (2000),
    https://www.tandfonline.com/doi/abs/10.1080/10618600.2000.10474879
    Clustering analysis using the Dirichlet process (DP) Gaussian mixture model.
    ---------------------------------------------------------------------------
    T: int
        Number of iterations of the MCMC sampling
    X: array(size_of_data, dimension)
        The array of observations
    alpha: float
        Concentration parameter of the DP
    hyper_params: object of NormalInverseWishart
        Base measure of the Dirichlet process
    key: jax.random.PRNGKey
        Seed of initial random cluster
    ----------------------------------
    Returns:
    * array(T, size_of_data):
        Simulation of cluster assignment
    """
    n, dim = X.shape
    Zs = []
    Cluster = namedtuple('Cluster', ["label", "members"])
    # Initialize by assigning all observations to cluster 0
    cluster0 = Cluster(label=0, members=list(range(n)))
    # CR is the current set of clusters
    CR = [cluster0]
    Z = jnp.full(n, 0)
    new_label = 1
    for t in range(T):
        # Update the cluster assignment of every observation
        for i in range(n):
            # Remove observation i from its current cluster
            labels = [cluster.label for cluster in CR]
            j = labels.index(Z[i])
            CR[j].members.remove(i)
            if len(CR[j].members) == 0:
                del CR[j]
            # Log probabilities of joining each existing cluster or opening a new one
            lp0 = [jnp.log(len(cluster.members)) + log_predic_t(X[i,], jnp.atleast_2d(X[cluster.members[:],]), hyper_params) for cluster in CR]
            lp1 = [jnp.log(alpha) + log_predic_t(X[i,], jnp.empty((0, dim)), hyper_params)]
            logits = jnp.array(lp0 + lp1)
            key, subkey = random.split(key)
            k = random.categorical(subkey, logits=logits)
            if k == len(logits) - 1:
                new_cluster = Cluster(label=new_label, members=[i])
                new_label += 1
                CR.append(new_cluster)
                Z = Z.at[i].set(new_cluster.label)
            else:
                CR[k].members.append(i)
                Z = Z.at[i].set(CR[k].label)
        Zs.append(Z)
    return jnp.array(Zs)

Review comments on dp_cluster:
- Give this function a more descriptive name, e.g. ...
- Can this be vectorized?

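Continuing the sketch above (reviewer example): run a few Gibbs sweeps over the simulated data, assuming the same niw object is a valid hyper_params argument for log_predic_t.

Zs = dp_cluster(T=10, X=X, alpha=1.0, hyper_params=niw, key=random.PRNGKey(1))
# Zs[-1] holds the cluster assignment after the final sweep
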
@@ -0,0 +1,123 @@
""" | ||
This implementation of Normal Inverse Wishart distribution is directly copied from the note book of Scott Linderman: | ||
'Implementing a Normal Inverse Wishart Distribution in Tensorflow Probability' | ||
https://github.com/lindermanlab/hackathons/blob/master/notebooks/TFP_Normal_Inverse_Wishart.ipynb | ||
and | ||
https://github.com/lindermanlab/hackathons/blob/master/notebooks/TFP_Normal_Inverse_Wishart_(Part_2).ipynb | ||
""" | ||
import jax.numpy as np | ||
import jax.random as jr | ||
from jax import vmap | ||
from jax.tree_util import tree_map | ||
import tensorflow_probability.substrates.jax as tfp | ||
import matplotlib.pyplot as plt | ||
from functools import partial | ||
|
||
tfd = tfp.distributions | ||
tfb = tfp.bijectors | ||
|
||
|
||
class NormalInverseWishart(tfd.JointDistributionNamed): | ||
def __init__(self, loc, mean_precision, df, scale, **kwargs): | ||
""" | ||
A normal inverse Wishart (NIW) distribution with | ||
|
||
Args: | ||
loc: \mu_0 in math above | ||
mean_precision: \kappa_0 | ||
df: \nu | ||
scale: \Psi | ||
|
||
Returns: | ||
A tfp.JointDistribution object. | ||
""" | ||
# Store hyperparameters. | ||
self._loc = loc | ||
self._mean_precision = mean_precision | ||
self._df = df | ||
self._scale = scale | ||
|
||
# Convert the inverse Wishart scale to the scale_tril of a Wishart. | ||
# Note: this could be done more efficiently. | ||
self.wishart_scale_tril = np.linalg.cholesky(np.linalg.inv(scale)) | ||
|
||
super(NormalInverseWishart, self).__init__(dict( | ||
Sigma=lambda: tfd.TransformedDistribution( | ||
tfd.WishartTriL(df, scale_tril=self.wishart_scale_tril), | ||
tfb.Chain([tfb.CholeskyOuterProduct(), | ||
tfb.CholeskyToInvCholesky(), | ||
tfb.Invert(tfb.CholeskyOuterProduct()) | ||
])), | ||
mu=lambda Sigma: tfd.MultivariateNormalFullCovariance( | ||
loc, Sigma / mean_precision) | ||
)) | ||
|
||
# Replace the default JointDistributionNamed parameters with the NIW ones | ||
# because the JointDistributionNamed parameters contain lambda functions, | ||
# which are not jittable. | ||
self._parameters = dict( | ||
loc=loc, | ||
mean_precision=mean_precision, | ||
df=df, | ||
scale=scale | ||
) | ||
|
||
# These functions compute the pseudo-observations implied by the NIW prior | ||
# and convert sufficient statistics to a NIW posterior. We'll describe them | ||
# in more detail below. | ||
@property | ||
def natural_parameters(self): | ||
"""Compute pseudo-observations from standard NIW parameters.""" | ||
dim = self._loc.shape[-1] | ||
chi_1 = self._df + dim + 2 | ||
chi_2 = np.einsum('...,...i->...i', self._mean_precision, self._loc) | ||
chi_3 = self._scale + self._mean_precision * \ | ||
np.einsum("...i,...j->...ij", self._loc, self._loc) | ||
chi_4 = self._mean_precision | ||
return chi_1, chi_2, chi_3, chi_4 | ||
|
||
@classmethod | ||
def from_natural_parameters(cls, natural_params): | ||
"""Convert natural parameters into standard parameters and construct.""" | ||
chi_1, chi_2, chi_3, chi_4 = natural_params | ||
dim = chi_2.shape[-1] | ||
df = chi_1 - dim - 2 | ||
mean_precision = chi_4 | ||
loc = np.einsum('..., ...i->...i', 1 / mean_precision, chi_2) | ||
scale = chi_3 - mean_precision * np.einsum('...i,...j->...ij', loc, loc) | ||
return cls(loc, mean_precision, df, scale) | ||
|
||
def _mode(self): | ||
r"""Solve for the mode. Recall, | ||
.. math:: | ||
p(\mu, \Sigma) \propto | ||
\mathrm{N}(\mu | \mu_0, \Sigma / \kappa_0) \times | ||
\mathrm{IW}(\Sigma | \nu_0, \Psi_0) | ||
The optimal mean is :math:`\mu^* = \mu_0`. Substituting this in, | ||
.. math:: | ||
p(\mu^*, \Sigma) \propto IW(\Sigma | \nu_0 + 1, \Psi_0) | ||
and the mode of this inverse Wishart distribution is at | ||
.. math:: | ||
\Sigma^* = \Psi_0 / (\nu_0 + d + 2) | ||
""" | ||
dim = self._loc.shape[-1] | ||
covariance = np.einsum("...,...ij->...ij", | ||
1 / (self._df + dim + 2), self._scale) | ||
return self._loc, covariance | ||
|
||
|
||
class MultivariateNormalFullCovariance(tfd.MultivariateNormalFullCovariance): | ||
""" | ||
This wrapper adds simple functions to get sufficient statistics and | ||
construct a MultivariateNormalFullCovariance from parameters drawn | ||
from the normal inverse Wishart distribution. | ||
""" | ||
@classmethod | ||
def from_parameters(cls, params, **kwargs): | ||
return cls(*params, **kwargs) | ||
|
||
@staticmethod | ||
def sufficient_statistics(datapoint): | ||
return (1.0, datapoint, np.outer(datapoint, datapoint), 1.0) | ||
|
||
|
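A short sanity check (reviewer sketch, not part of the committed file) of the NIW class and the natural-parameter round trip, reusing the np and jr aliases imported at the top of this file:

niw = NormalInverseWishart(loc=np.zeros(2), mean_precision=1.0, df=4.0, scale=np.eye(2))
draw = niw.sample(seed=jr.PRNGKey(0))    # dict with 'Sigma' and 'mu' draws from the prior
niw2 = NormalInverseWishart.from_natural_parameters(niw.natural_parameters)
# niw2 is reconstructed with (numerically) the same loc, mean_precision, df and scale as niw
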
Review comment: Rename to dp_mixgauss_ancestral_sample.