Commit 2572852
* First pass at mixture modelling
* No longer necessary to reference self.comp_dists directly in logp
* Add dimension internally (when necessary)
* Import get_tau_sd
* Misc bugfixes
* Add sampling to Mixtures
* Differentiate between Discrete and Continuous mixtures when possible
* Add support for 2D weights
* Gracefully try to calculate mean and mode defaults
* Add docstrings for Mixture classes
* Export mixture models
* Reference self.comp_dists
* Remove unnecessary pm.
* Add Mixture tests
* Add missing imports
* Add marginalized Gaussian mixture model example
* Calculate the mode of the mixture distribution correctly
1 parent (f1e622f). Showing 5 changed files with 609 additions and 1 deletion.
docs/source/notebooks/marginalized_gaussian_mixture_model.ipynb (319 additions, 0 deletions)
Large diffs are not rendered by default.
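Since the notebook's content is not shown in this view, the following is a rough sketch of the kind of marginalized Gaussian mixture model it presumably builds with the new NormalMixture class; the synthetic data and parameter values here are illustrative assumptions, not taken from the notebook.

import numpy as np
import pymc3 as pm

# Synthetic data from a two-component Gaussian mixture (illustrative values).
np.random.seed(42)
component = np.random.choice(2, size=1000, p=[0.75, 0.25])
x = np.random.normal(np.array([0., 5.])[component], 1.)

with pm.Model() as model:
    w = pm.Dirichlet('w', np.ones(2))           # mixture weights
    mu = pm.Normal('mu', 0., sd=10., shape=2)   # component means
    tau = pm.Gamma('tau', 1., 1., shape=2)      # component precisions
    # NormalMixture marginalizes out the discrete component assignments,
    # so no latent indicator variables are needed.
    x_obs = pm.NormalMixture('x_obs', w, mu, tau=tau, observed=x)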
New file with 169 additions (the file path is not shown in this view; the module defines the Mixture and NormalMixture distributions):

@@ -0,0 +1,169 @@
import numpy as np
import theano.tensor as tt

from ..math import logsumexp
from .dist_math import bound
from .distribution import Discrete, Distribution, draw_values, generate_samples
from .continuous import get_tau_sd, Normal


def all_discrete(comp_dists):
    """
    Determine if all distributions in comp_dists are discrete
    """
    if isinstance(comp_dists, Distribution):
        return isinstance(comp_dists, Discrete)
    else:
        return all(isinstance(comp_dist, Discrete) for comp_dist in comp_dists)


class Mixture(Distribution):
    R"""
    Mixture log-likelihood

    Often used to model subpopulation heterogeneity

    .. math:: f(x \mid w, \theta) = \sum_{i = 1}^n w_i f_i(x \mid \theta_i)

    ========  ============================================
    Support   :math:`\cap_{i = 1}^n \textrm{support}(f_i)`
    Mean      :math:`\sum_{i = 1}^n w_i \mu_i`
    ========  ============================================

    Parameters
    ----------
    w : array of floats
        w >= 0 and w <= 1
        the mixture weights
    comp_dists : multidimensional PyMC3 distribution or iterable of one-dimensional PyMC3 distributions
        the component distributions :math:`f_1, \ldots, f_n`
    """
    def __init__(self, w, comp_dists, *args, **kwargs):
        shape = kwargs.pop('shape', ())

        self.w = w
        self.comp_dists = comp_dists

        defaults = kwargs.pop('defaults', [])

        if all_discrete(comp_dists):
            dtype = kwargs.pop('dtype', 'int64')
        else:
            dtype = kwargs.pop('dtype', 'float64')

        try:
            self.mean = (w * self._comp_means()).sum(axis=-1)

            if 'mean' not in defaults:
                defaults.append('mean')
        except AttributeError:
            pass

        try:
            comp_modes = self._comp_modes()
            comp_mode_logps = self.logp(comp_modes)
            self.mode = comp_modes[tt.argmax(w * comp_mode_logps, axis=-1)]

            if 'mode' not in defaults:
                defaults.append('mode')
        except AttributeError:
            pass

        super(Mixture, self).__init__(shape, dtype, defaults=defaults,
                                      *args, **kwargs)

    def _comp_logp(self, value):
        comp_dists = self.comp_dists

        try:
            value_ = value if value.ndim > 1 else tt.shape_padright(value)

            return comp_dists.logp(value_)
        except AttributeError:
            return tt.stack([comp_dist.logp(value) for comp_dist in comp_dists],
                            axis=1)

    def _comp_means(self):
        try:
            return self.comp_dists.mean
        except AttributeError:
            return tt.stack([comp_dist.mean for comp_dist in self.comp_dists],
                            axis=1)

    def _comp_modes(self):
        try:
            return self.comp_dists.mode
        except AttributeError:
            return tt.stack([comp_dist.mode for comp_dist in self.comp_dists],
                            axis=1)

    def _comp_samples(self, point=None, size=None, repeat=None):
        try:
            samples = self.comp_dists.random(point=point, size=size, repeat=repeat)
        except AttributeError:
            samples = np.column_stack([comp_dist.random(point=point, size=size, repeat=repeat)
                                       for comp_dist in self.comp_dists])

        return np.squeeze(samples)

    def logp(self, value):
        w = self.w

        return bound(logsumexp(tt.log(w) + self._comp_logp(value), axis=-1).sum(),
                     w >= 0, w <= 1, tt.allclose(w.sum(axis=-1), 1))

    def random(self, point=None, size=None, repeat=None):
        def random_choice(*args, **kwargs):
            w = kwargs.pop('w')
            w /= w.sum(axis=-1, keepdims=True)
            k = w.shape[-1]

            if w.ndim > 1:
                return np.row_stack([np.random.choice(k, p=w_) for w_ in w])
            else:
                return np.random.choice(k, p=w, *args, **kwargs)

        w = draw_values([self.w], point=point)

        w_samples = generate_samples(random_choice,
                                     w=w,
                                     broadcast_shape=w.shape[:-1] or (1,),
                                     dist_shape=self.shape,
                                     size=size).squeeze()
        comp_samples = self._comp_samples(point=point, size=size, repeat=repeat)

        if comp_samples.ndim > 1:
            return np.squeeze(comp_samples[np.arange(w_samples.size), w_samples])
        else:
            return np.squeeze(comp_samples[w_samples])


class NormalMixture(Mixture):
    R"""
    Normal mixture log-likelihood

    .. math:: f(x \mid w, \mu, \sigma^2) = \sum_{i = 1}^n w_i N(x \mid \mu_i, \sigma^2_i)

    ========  ==========================================================================
    Support   :math:`x \in \mathbb{R}`
    Mean      :math:`\sum_{i = 1}^n w_i \mu_i`
    Variance  :math:`\sum_{i = 1}^n w_i (\sigma^2_i + \mu_i^2) - \left(\sum_{i = 1}^n w_i \mu_i\right)^2`
    ========  ==========================================================================

    Parameters
    ----------
    w : array of floats
        w >= 0 and w <= 1
        the mixture weights
    mu : array of floats
        the component means
    sd : array of floats
        the component standard deviations
    tau : array of floats
        the component precisions
    """
    def __init__(self, w, mu, *args, **kwargs):
        _, sd = get_tau_sd(tau=kwargs.pop('tau', None),
                           sd=kwargs.pop('sd', None))

        super(NormalMixture, self).__init__(w, Normal.dist(mu, sd=sd),
                                            *args, **kwargs)
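A minimal usage sketch of the Mixture class itself, mirroring the API exercised in the tests below; the count data and rates are invented for illustration.

import numpy as np
import pymc3 as pm

# Illustrative data: two Poisson subpopulations with rates 2 and 15.
counts = np.concatenate([np.random.poisson(2., size=600),
                         np.random.poisson(15., size=400)])

with pm.Model() as model:
    w = pm.Dirichlet('w', np.ones(2))       # mixture weights on the simplex
    mu = pm.Gamma('mu', 1., 1., shape=2)    # component rates
    # Mixture accepts either a list of component distributions or a single
    # multidimensional distribution such as Poisson.dist(mu).
    obs = pm.Mixture('obs', w,
                     [pm.Poisson.dist(mu[0]), pm.Poisson.dist(mu[1])],
                     observed=counts)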
New file with 114 additions (the file path is not shown in this view; it contains the Mixture test suite):

@@ -0,0 +1,114 @@
import numpy as np
from numpy.testing import assert_allclose

from .helpers import SeededTest
from pymc3 import Dirichlet, Gamma, Metropolis, Mixture, Model, Normal, NormalMixture, Poisson, sample


# Generate data
def generate_normal_mixture_data(w, mu, sd, size=1000):
    component = np.random.choice(w.size, size=size, p=w)
    x = np.random.normal(mu[component], sd[component], size=size)

    return x


def generate_poisson_mixture_data(w, mu, size=1000):
    component = np.random.choice(w.size, size=size, p=w)
    x = np.random.poisson(mu[component], size=size)

    return x


class TestMixture(SeededTest):
    @classmethod
    def setUpClass(cls):
        super(TestMixture, cls).setUpClass()

        cls.norm_w = np.array([0.75, 0.25])
        cls.norm_mu = np.array([0., 5.])
        cls.norm_sd = np.ones_like(cls.norm_mu)
        cls.norm_x = generate_normal_mixture_data(cls.norm_w, cls.norm_mu, cls.norm_sd, size=1000)

        cls.pois_w = np.array([0.4, 0.6])
        cls.pois_mu = np.array([5., 20.])
        cls.pois_x = generate_poisson_mixture_data(cls.pois_w, cls.pois_mu, size=1000)

    def test_mixture_list_of_normals(self):
        with Model() as model:
            w = Dirichlet('w', np.ones_like(self.norm_w))

            mu = Normal('mu', 0., 10., shape=self.norm_w.size)
            tau = Gamma('tau', 1., 1., shape=self.norm_w.size)

            x_obs = Mixture('x_obs', w,
                            [Normal.dist(mu[0], tau=tau[0]),
                             Normal.dist(mu[1], tau=tau[1])],
                            observed=self.norm_x)

            step = Metropolis()
            trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

        assert_allclose(np.sort(trace['w'].mean(axis=0)),
                        np.sort(self.norm_w),
                        rtol=0.1, atol=0.1)
        assert_allclose(np.sort(trace['mu'].mean(axis=0)),
                        np.sort(self.norm_mu),
                        rtol=0.1, atol=0.1)

    def test_normal_mixture(self):
        with Model() as model:
            w = Dirichlet('w', np.ones_like(self.norm_w))

            mu = Normal('mu', 0., 10., shape=self.norm_w.size)
            tau = Gamma('tau', 1., 1., shape=self.norm_w.size)

            x_obs = NormalMixture('x_obs', w, mu, tau=tau, observed=self.norm_x)

            step = Metropolis()
            trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

        assert_allclose(np.sort(trace['w'].mean(axis=0)),
                        np.sort(self.norm_w),
                        rtol=0.1, atol=0.1)
        assert_allclose(np.sort(trace['mu'].mean(axis=0)),
                        np.sort(self.norm_mu),
                        rtol=0.1, atol=0.1)

    def test_poisson_mixture(self):
        with Model() as model:
            w = Dirichlet('w', np.ones_like(self.pois_w))

            mu = Gamma('mu', 1., 1., shape=self.pois_w.size)

            x_obs = Mixture('x_obs', w, Poisson.dist(mu), observed=self.pois_x)

            step = Metropolis()
            trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

        assert_allclose(np.sort(trace['w'].mean(axis=0)),
                        np.sort(self.pois_w),
                        rtol=0.1, atol=0.1)
        assert_allclose(np.sort(trace['mu'].mean(axis=0)),
                        np.sort(self.pois_mu),
                        rtol=0.1, atol=0.1)

    def test_mixture_list_of_poissons(self):
        with Model() as model:
            w = Dirichlet('w', np.ones_like(self.pois_w))

            mu = Gamma('mu', 1., 1., shape=self.pois_w.size)

            x_obs = Mixture('x_obs', w,
                            [Poisson.dist(mu[0]), Poisson.dist(mu[1])],
                            observed=self.pois_x)

            step = Metropolis()
            trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

        assert_allclose(np.sort(trace['w'].mean(axis=0)),
                        np.sort(self.pois_w),
                        rtol=0.1, atol=0.1)
        assert_allclose(np.sort(trace['mu'].mean(axis=0)),
                        np.sort(self.pois_mu),
                        rtol=0.1, atol=0.1)
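Mixture.logp above computes log f(x) = logsumexp_i(log w_i + log f_i(x)) rather than summing densities and then taking a log, which avoids underflow when individual component densities are tiny. A standalone NumPy check of that identity (using scipy.special.logsumexp in place of the Theano logsumexp imported in the diff):

import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

w = np.array([0.75, 0.25])   # mixture weights
mu = np.array([0., 5.])      # component means, unit standard deviation
x = np.array([-1., 0.5, 4.2])

# Direct mixture density: f(x) = sum_i w_i * N(x | mu_i, 1)
direct = np.log(np.dot(norm.pdf(x[:, None], loc=mu), w))

# Log-space version, as in Mixture.logp: logsumexp(log w + component logp)
stable = logsumexp(np.log(w) + norm.logpdf(x[:, None], loc=mu), axis=-1)

assert np.allclose(direct, stable)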