ENH Mixture Models (#1437)
* First pass at mixture modelling

* No longer necessary to reference self.comp_dists directly in logp

* Add dimension internally (when necessary)

* Import get_tau_sd

* Misc bugfixes

* Add sampling to Mixtures

* Differentiate between Discrete and Continuous mixtures when possible

* Add support for 2D weights

* Gracefully try to calculate mean and mode defaults

* Add docstrings for Mixture classes

* Export mixture models

* Reference self.comp_dists

* Remove unnecessary pm.

* Add Mixture tests

* Add missing imports

* Add marginalized Gaussian mixture model example

* Calculate the mode of the mixture distribution correctly
AustinRochford authored and twiecki committed Oct 18, 2016
1 parent f1e622f commit 2572852
Showing 5 changed files with 609 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/examples.rst
@@ -42,6 +42,7 @@ Mixture Models

.. toctree::
notebooks/gaussian_mixture_model.ipynb
notebooks/marginalized_gaussian_mixture_model.ipynb
notebooks/gaussian-mixture-model-advi.ipynb
notebooks/dp_mix.ipynb

319 changes: 319 additions & 0 deletions docs/source/notebooks/marginalized_gaussian_mixture_model.ipynb

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion pymc3/distributions/__init__.py
@@ -45,6 +45,9 @@
from .distribution import TensorType
from .distribution import draw_values

from .mixture import Mixture
from .mixture import NormalMixture

from .multivariate import MvNormal
from .multivariate import MvStudentT
from .multivariate import Dirichlet
@@ -112,5 +115,7 @@
'AR1',
'GaussianRandomWalk',
'GARCH11',
'SkewNormal'
'SkewNormal',
'Mixture',
'NormalMixture'
]
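With these exports, the new classes are available from the package's top level. A minimal usage sketch (not part of the commit, mirroring the tests below; assumes the PyMC3 API at this revision):

import numpy as np
import pymc3 as pm

# Placeholder data: draws from a two-component normal mixture.
data = np.concatenate([np.random.normal(0., 1., 750),
                       np.random.normal(5., 1., 250)])

with pm.Model():
    w = pm.Dirichlet('w', np.ones(2))        # mixture weights
    mu = pm.Normal('mu', 0., 10., shape=2)   # component means
    tau = pm.Gamma('tau', 1., 1., shape=2)   # component precisions
    x_obs = pm.NormalMixture('x_obs', w, mu, tau=tau, observed=data)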
169 changes: 169 additions & 0 deletions pymc3/distributions/mixture.py
@@ -0,0 +1,169 @@
import numpy as np
import theano.tensor as tt

from ..math import logsumexp
from .dist_math import bound
from .distribution import Discrete, Distribution, draw_values, generate_samples
from .continuous import get_tau_sd, Normal


def all_discrete(comp_dists):
"""
Determine if all distributions in comp_dists are discrete
"""
if isinstance(comp_dists, Distribution):
return isinstance(comp_dists, Discrete)
else:
return all(isinstance(comp_dist, Discrete) for comp_dist in comp_dists)


class Mixture(Distribution):
R"""
Mixture log-likelihood
Often used to model subpopulation heterogeneity
.. math:: f(x \mid w, \theta) = \sum_{i = 1}^n w_i f_i(x \mid \theta_i)
======== ============================================
Support :math:`\cup_{i = 1}^n \textrm{support}(f_i)`
Mean :math:`\sum_{i = 1}^n w_i \mu_i`
======== ============================================
Parameters
----------
w : array of floats
w >= 0 and w <= 1
the mixture weights
comp_dists : multidimensional PyMC3 distribution or iterable of one-dimensional PyMC3 distributions
the component distributions :math:`f_1, \ldots, f_n`
"""
def __init__(self, w, comp_dists, *args, **kwargs):
shape = kwargs.pop('shape', ())

self.w = w
self.comp_dists = comp_dists

defaults = kwargs.pop('defaults', [])

if all_discrete(comp_dists):
dtype = kwargs.pop('dtype', 'int64')
else:
dtype = kwargs.pop('dtype', 'float64')

try:
self.mean = (w * self._comp_means()).sum(axis=-1)

if 'mean' not in defaults:
defaults.append('mean')
except AttributeError:
pass

try:
comp_modes = self._comp_modes()
comp_mode_logps = self.logp(comp_modes)
self.mode = comp_modes[tt.argmax(w * comp_mode_logps, axis=-1)]

if 'mode' not in defaults:
defaults.append('mode')
except AttributeError:
pass

super(Mixture, self).__init__(shape, dtype, defaults=defaults,
*args, **kwargs)

def _comp_logp(self, value):
comp_dists = self.comp_dists

try:
value_ = value if value.ndim > 1 else tt.shape_padright(value)

return comp_dists.logp(value_)
except AttributeError:
return tt.stack([comp_dist.logp(value) for comp_dist in comp_dists],
axis=1)

def _comp_means(self):
try:
return self.comp_dists.mean
except AttributeError:
return tt.stack([comp_dist.mean for comp_dist in self.comp_dists],
axis=1)

def _comp_modes(self):
try:
return self.comp_dists.mode
except AttributeError:
return tt.stack([comp_dist.mode for comp_dist in self.comp_dists],
axis=1)

def _comp_samples(self, point=None, size=None, repeat=None):
try:
samples = self.comp_dists.random(point=point, size=size, repeat=repeat)
except AttributeError:
samples = np.column_stack([comp_dist.random(point=point, size=size, repeat=repeat)
for comp_dist in self.comp_dists])

return np.squeeze(samples)

def logp(self, value):
w = self.w

return bound(logsumexp(tt.log(w) + self._comp_logp(value), axis=-1).sum(),
w >= 0, w <= 1, tt.allclose(w.sum(axis=-1), 1))

def random(self, point=None, size=None, repeat=None):
def random_choice(*args, **kwargs):
w = kwargs.pop('w')
w /= w.sum(axis=-1, keepdims=True)
k = w.shape[-1]

if w.ndim > 1:
return np.row_stack([np.random.choice(k, p=w_) for w_ in w])
else:
return np.random.choice(k, p=w, *args, **kwargs)

w = draw_values([self.w], point=point)

w_samples = generate_samples(random_choice,
w=w,
broadcast_shape=w.shape[:-1] or (1,),
dist_shape=self.shape,
size=size).squeeze()
comp_samples = self._comp_samples(point=point, size=size, repeat=repeat)

if comp_samples.ndim > 1:
return np.squeeze(comp_samples[np.arange(w_samples.size), w_samples])
else:
return np.squeeze(comp_samples[w_samples])


class NormalMixture(Mixture):
R"""
Normal mixture log-likelihood
.. math:: f(x \mid w, \mu, \sigma^2) = \sum_{i = 1}^n w_i N(x \mid \mu_i, \sigma^2_i)
======== =======================================
Support :math:`x \in \mathbb{R}`
Mean :math:`\sum_{i = 1}^n w_i \mu_i`
Variance :math:`\sum_{i = 1}^n w_i \left(\sigma^2_i + \mu_i^2\right) - \left(\sum_{i = 1}^n w_i \mu_i\right)^2`
======== =======================================
Parameters
----------
w : array of floats
w >= 0 and w <= 1
the mixture weights
mu : array of floats
the component means
sd : array of floats
the component standard deviations
tau : array of floats
the component precisions
"""
def __init__(self, w, mu, *args, **kwargs):
_, sd = get_tau_sd(tau=kwargs.pop('tau', None),
sd=kwargs.pop('sd', None))

super(NormalMixture, self).__init__(w, Normal.dist(mu, sd=sd),
*args, **kwargs)
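For reference, Mixture.logp marginalizes over the component indicator: log f(x) = log sum_i w_i f_i(x), evaluated stably as logsumexp(log w_i + log f_i(x)). A quick NumPy/SciPy check of that identity (a sketch, not part of the commit; assumes scipy.special.logsumexp and scipy.stats.norm are available):

import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

w = np.array([0.75, 0.25])   # mixture weights
mu = np.array([0., 5.])      # component means
sd = np.array([1., 1.])      # component standard deviations
x = 1.3                      # evaluation point

# Naive form: exponentiate, sum, then log (prone to underflow in the tails).
naive = np.log((w * norm.pdf(x, mu, sd)).sum())

# Stable form used by Mixture.logp: logsumexp(log w_i + log f_i(x)).
stable = logsumexp(np.log(w) + norm.logpdf(x, mu, sd))

print(naive, stable)  # agree to floating-point precision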
114 changes: 114 additions & 0 deletions pymc3/tests/test_mixture.py
@@ -0,0 +1,114 @@
import numpy as np
from numpy.testing import assert_allclose

from .helpers import SeededTest
from pymc3 import Dirichlet, Gamma, Metropolis, Mixture, Model, Normal, NormalMixture, Poisson, sample


# Generate data
def generate_normal_mixture_data(w, mu, sd, size=1000):
component = np.random.choice(w.size, size=size, p=w)
x = np.random.normal(mu[component], sd[component], size=size)

return x


def generate_poisson_mixture_data(w, mu, size=1000):
component = np.random.choice(w.size, size=size, p=w)
x = np.random.poisson(mu[component], size=size)

return x


class TestMixture(SeededTest):
@classmethod
def setUpClass(cls):
super(TestMixture, cls).setUpClass()

cls.norm_w = np.array([0.75, 0.25])
cls.norm_mu = np.array([0., 5.])
cls.norm_sd = np.ones_like(cls.norm_mu)
cls.norm_x = generate_normal_mixture_data(cls.norm_w, cls.norm_mu, cls.norm_sd, size=1000)

cls.pois_w = np.array([0.4, 0.6])
cls.pois_mu = np.array([5., 20.])
cls.pois_x = generate_poisson_mixture_data(cls.pois_w, cls.pois_mu, size=1000)

def test_mixture_list_of_normals(self):
with Model() as model:
w = Dirichlet('w', np.ones_like(self.norm_w))

mu = Normal('mu', 0., 10., shape=self.norm_w.size)
tau = Gamma('tau', 1., 1., shape=self.norm_w.size)

x_obs = Mixture('x_obs', w,
[Normal.dist(mu[0], tau=tau[0]),
Normal.dist(mu[1], tau=tau[1])],
observed=self.norm_x)

step = Metropolis()
trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

assert_allclose(np.sort(trace['w'].mean(axis=0)),
np.sort(self.norm_w),
rtol=0.1, atol=0.1)
assert_allclose(np.sort(trace['mu'].mean(axis=0)),
np.sort(self.norm_mu),
rtol=0.1, atol=0.1)

def test_normal_mixture(self):
with Model() as model:
w = Dirichlet('w', np.ones_like(self.norm_w))

mu = Normal('mu', 0., 10., shape=self.norm_w.size)
tau = Gamma('tau', 1., 1., shape=self.norm_w.size)

x_obs = NormalMixture('x_obs', w, mu, tau=tau, observed=self.norm_x)

step = Metropolis()
trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

assert_allclose(np.sort(trace['w'].mean(axis=0)),
np.sort(self.norm_w),
rtol=0.1, atol=0.1)
assert_allclose(np.sort(trace['mu'].mean(axis=0)),
np.sort(self.norm_mu),
rtol=0.1, atol=0.1)

def test_poisson_mixture(self):
with Model() as model:
w = Dirichlet('w', np.ones_like(self.pois_w))

mu = Gamma('mu', 1., 1., shape=self.pois_w.size)

x_obs = Mixture('x_obs', w, Poisson.dist(mu), observed=self.pois_x)

step = Metropolis()
trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

assert_allclose(np.sort(trace['w'].mean(axis=0)),
np.sort(self.pois_w),
rtol=0.1, atol=0.1)
assert_allclose(np.sort(trace['mu'].mean(axis=0)),
np.sort(self.pois_mu),
rtol=0.1, atol=0.1)

def test_mixture_list_of_poissons(self):
with Model() as model:
w = Dirichlet('w', np.ones_like(self.pois_w))

mu = Gamma('mu', 1., 1., shape=self.pois_w.size)

x_obs = Mixture('x_obs', w,
[Poisson.dist(mu[0]), Poisson.dist(mu[1])],
observed=self.pois_x)

step = Metropolis()
trace = sample(5000, step, random_seed=self.random_seed, progressbar=False)

assert_allclose(np.sort(trace['w'].mean(axis=0)),
np.sort(self.pois_w),
rtol=0.1, atol=0.1)
assert_allclose(np.sort(trace['mu'].mean(axis=0)),
np.sort(self.pois_mu),
rtol=0.1, atol=0.1)
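The tests check posterior recovery; the moment formulas in the NormalMixture docstring can also be verified directly by Monte Carlo (a sketch, not part of the commit, reusing the ancestral-sampling scheme of generate_normal_mixture_data above):

import numpy as np

w = np.array([0.75, 0.25])
mu = np.array([0., 5.])
sd = np.array([1., 2.])

# Draw a component index per sample, then sample from that component.
component = np.random.choice(w.size, size=200000, p=w)
x = np.random.normal(mu[component], sd[component])

mean = (w * mu).sum()                              # sum_i w_i mu_i = 1.25
var = (w * (sd ** 2 + mu ** 2)).sum() - mean ** 2  # mixture variance ~ 6.44

print(x.mean(), mean)  # both close to 1.25
print(x.var(), var)    # both close to 6.44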
