[MRG+1] Bayesian Gaussian Mixture (Integration of GSoC2015 -- second step) #6651

Merged
merged 33 commits on Aug 30, 2016
37660ed
Modification of GaussianMixture class.
tguillemot Aug 1, 2016
c773d61
Fix comments.
tguillemot Aug 2, 2016
bd03020
Modification of the Docstring.
tguillemot Aug 4, 2016
a068ca4
Add license and author.
tguillemot Aug 4, 2016
fc6e422
Add the new BayesianGaussianMixture class.
tguillemot Apr 6, 2016
5d53d02
Add the use of the cholesky decomposition of the precision matrix.
tguillemot May 20, 2016
18f9e9a
Fix some bugs.
tguillemot May 25, 2016
0751194
Fix pb typo of eq 10.64 and 10.62.
tguillemot Jun 16, 2016
3d9a2c9
Correct VBGMM bugs.
tguillemot Jul 16, 2016
c935fa6
Fix full version.
tguillemot Jul 25, 2016
37dceb3
Fix the precision normalisation pb.
tguillemot Jul 25, 2016
8fc4ea1
Fix all cov_type algo for BayesianGaussianMixture.
tguillemot Jul 27, 2016
00199e7
Optimisation of spherical and diag computation.
tguillemot Jul 27, 2016
7d7d803
Code simplification.
tguillemot Jul 28, 2016
68922f2
Check the Gaussian Mixture tests are ok.
tguillemot Jul 28, 2016
27cd957
Add test.
tguillemot Jul 28, 2016
3e66a16
Add new tests for BayesianGaussianMixture and GaussianMixture.
tguillemot Jul 29, 2016
d029632
Add the bayesian_gaussian_example and the doc.
tguillemot Aug 3, 2016
9436885
Fix comments.
tguillemot Aug 4, 2016
09831b4
Fix review comments and add license and author.
tguillemot Aug 4, 2016
fbeb957
Fix test compare covar type.
tguillemot Aug 4, 2016
0ae0e5f
Fix reviews.
tguillemot Aug 5, 2016
9b116a7
Fix tests.
tguillemot Aug 8, 2016
a170d28
Fix review comments.
tguillemot Aug 10, 2016
4b25b39
Correct reviews.
tguillemot Aug 16, 2016
9897f8a
Fix travis pb.
tguillemot Aug 17, 2016
3964cd2
Fix circleci pb.
tguillemot Aug 17, 2016
0b07ca8
Fix review comments.
tguillemot Aug 18, 2016
3079b06
Fix typo.
tguillemot Aug 19, 2016
ebc2242
Fix comments.
tguillemot Aug 29, 2016
d2804fa
Fix comments.
tguillemot Aug 29, 2016
a5dcd7f
Fix comments.
tguillemot Aug 29, 2016
9c7ca50
[ci skip] Correct legend.
tguillemot Aug 30, 2016
2 changes: 1 addition & 1 deletion doc/modules/classes.rst
@@ -954,8 +954,8 @@ See the :ref:`metrics` section of the user guide for further details.
:template: class.rst

mixture.GaussianMixture
mixture.BayesianGaussianMixture
mixture.DPGMM
mixture.VBGMM


.. _multiclass_ref:
80 changes: 47 additions & 33 deletions doc/modules/mixture.rst
@@ -133,40 +133,13 @@ parameters to maximize the likelihood of the data given those
assignments. Repeating this process is guaranteed to always converge
to a local optimum.

.. _vbgmm:
.. _bgmm:

VBGMM: variational Gaussian mixtures
====================================
Bayesian Gaussian Mixture
=========================

The :class:`VBGMM` object implements a variant of the Gaussian mixture
model with :ref:`variational inference <variational_inference>` algorithms.

Pros and cons of class :class:`VBGMM`: variational inference
------------------------------------------------------------

Pros
.....

:Regularization: due to the incorporation of prior information,
variational solutions have less pathological special cases than
expectation-maximization solutions. One can then use full
covariance matrices in high dimensions or in cases where some
components might be centered around a single point without
risking divergence.

Cons
.....

:Bias: to regularize a model one has to add biases. The
variational algorithm will bias all the means towards the origin
(part of the prior information adds a "ghost point" in the origin
to every mixture component) and it will bias the covariances to
be more spherical. It will also, depending on the concentration
parameter, bias the cluster structure either towards uniformity
or towards a rich-get-richer scenario.

:Hyperparameters: this algorithm needs an extra hyperparameter
that might need experimental tuning via cross-validation.
The :class:`BayesianGaussianMixture` object implements a variant of the Gaussian
mixture model with variational inference algorithms.
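
A minimal usage sketch (hypothetical toy data; only the standard scikit-learn
``fit``/``predict`` API of the new estimator is assumed)::

    import numpy as np
    from sklearn.mixture import BayesianGaussianMixture

    # Two well-separated blobs in two dimensions.
    X = np.array([[1., 2.], [1., 4.], [1., 0.],
                  [10., 2.], [10., 4.], [10., 0.]])

    # n_components is only an upper bound on the number of components
    # effectively used by the variational inference.
    bgm = BayesianGaussianMixture(n_components=3, random_state=0).fit(X)
    labels = bgm.predict(X)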

.. _variational_inference:

@@ -175,7 +148,7 @@ Estimation algorithm: variational inference

Variational inference is an extension of expectation-maximization that
maximizes a lower bound on model evidence (including
priors) instead of data likelihood. The principle behind
priors) instead of data likelihood. The principle behind
variational methods is the same as expectation-maximization (that is
both are iterative algorithms that alternate between finding the
probabilities for each point to be generated by each mixture and
@@ -195,6 +168,47 @@ to some mixture components getting almost all the points while most
mixture components will be centered on just a few of the remaining
points.
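
The lower bound mentioned above can be sketched as follows (a standard
variational treatment of mixture models, with :math:`Z` the latent component
assignments and :math:`\theta` the mixture parameters):

.. math::

    \log p(X) \geq \mathcal{L}(q) =
    \mathbb{E}_{q}\left[\log p(X, Z, \theta)\right]
    - \mathbb{E}_{q}\left[\log q(Z, \theta)\right],

with equality when the factorized :math:`q(Z, \theta)` matches the true
posterior. Each variational update does not decrease :math:`\mathcal{L}(q)`.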

.. figure:: ../auto_examples/mixture/images/sphx_glr_plot_bayesian_gaussian_mixture_001.png
:target: ../auto_examples/mixture/plot_bayesian_gaussian_mixture.html
:align: center
:scale: 50%

.. topic:: Examples:

* See :ref:`plot_bayesian_gaussian_mixture.py` for a comparison of
the results of the ``BayesianGaussianMixture`` for different values
of the parameter ``alpha``.
Member

The alpha parameter has been renamed to dirichlet_concentration_prior.


Pros and cons of variational inference with :class:`BayesianGaussianMixture`
------------------------------------------------------------------------------

Pros
.....

:Regularization: due to the incorporation of prior information,
variational solutions have less pathological special cases than
expectation-maximization solutions.

:Automatic selection: when `dirichlet_concentration_prior` is small enough and
`n_components` is larger than what is found necessary by the model, the
Variational Bayesian mixture model has a natural tendency to set some mixture
weights to values close to zero, as the sketch below illustrates. This makes it
possible to let the model choose a suitable number of effective components
automatically.
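
A rough sketch of this behaviour (hypothetical data; the parameter names are
the ones introduced in this pull request)::

    import numpy as np
    from sklearn.mixture import BayesianGaussianMixture

    rng = np.random.RandomState(0)
    # Two true clusters, but five components are requested on purpose.
    X = np.vstack([rng.randn(300, 2), rng.randn(300, 2) + [5., 5.]])

    bgm = BayesianGaussianMixture(n_components=5,
                                  dirichlet_concentration_prior=1e-3,
                                  max_iter=500, random_state=0).fit(X)

    # Most of the mass should concentrate on about two components, the
    # remaining weights being driven close to zero.
    print(bgm.weights_.round(3))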

@ogrisel ogrisel Aug 19, 2016

I would also add that the Variational Bayesian mixture model has a natural tendency to set some mixture weights to values close to zero when alpha is small enough and n_components is larger than what is found necessary by the model. This makes it possible to let the model choose a suitable number of effective components automatically.

Member

Actually alpha is now named dirichlet_concentration_prior but you get what I meant.

Cons
.....

:Bias: to regularize a model one has to add biases. The
variational algorithm will bias all the means towards the origin
(part of the prior information adds a "ghost point" in the origin
to every mixture component) and it will bias the covariances to
be more spherical. It will also, depending on the concentration
parameter, bias the cluster structure either towards uniformity
or towards a rich-get-richer scenario.

:Hyperparameters: this algorithm needs an extra hyperparameter
that might need experimental tuning via cross-validation.

.. _dpgmm:

DPGMM: Infinite Gaussian mixtures
27 changes: 18 additions & 9 deletions doc/whats_new.rst
@@ -64,13 +64,13 @@ Model Selection Enhancements and API Changes

- **Parameters ``n_folds`` and ``n_iter`` renamed to ``n_splits``**

Some parameter names have changed:
The ``n_folds`` parameter in :class:`model_selection.KFold`,
:class:`model_selection.LabelKFold`, and
Some parameter names have changed:
The ``n_folds`` parameter in :class:`model_selection.KFold`,
:class:`model_selection.LabelKFold`, and
:class:`model_selection.StratifiedKFold` is now renamed to ``n_splits``.
The ``n_iter`` parameter in :class:`model_selection.ShuffleSplit`,
:class:`model_selection.LabelShuffleSplit`,
and :class:`model_selection.StratifiedShuffleSplit` is now renamed
:class:`model_selection.LabelShuffleSplit`,
and :class:`model_selection.StratifiedShuffleSplit` is now renamed
to ``n_splits``.


@@ -141,8 +141,8 @@ New features
<https://github.com/scikit-learn/scikit-learn/pull/6954>`_) by `Nelson
Liu`_

- Added new cross-validation splitter
:class:`model_selection.TimeSeriesSplit` to handle time series data.
- Added new cross-validation splitter
:class:`model_selection.TimeSeriesSplit` to handle time series data.
(`#6586
<https://github.com/scikit-learn/scikit-learn/pull/6586>`_) by `YenChen
Lin`_
@@ -396,10 +396,19 @@ API changes summary
- Access to public attributes ``.X_`` and ``.y_`` has been deprecated in
:class:`isotonic.IsotonicRegression`. By `Jonathan Arfa`_.

- The old :class:`VBGMM` is deprecated in favor of the new
:class:`BayesianGaussianMixture`. The new class solves the computational
problems of the old class and computes the Variational Bayesian Gaussian
mixture faster than before.
Ref :ref:`b` for more information.
Member

@tguillemot what's b supposed to reference? It's a dead link.

Contributor Author

This will be fixed with #7295.

(`#6651 <https://github.com/scikit-learn/scikit-learn/pull/6651>`_) by
`Wei Xue`_ and `Thierry Guillemot`_.
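
A rough, hypothetical migration sketch (the mapping between the old ``alpha``
parameter and the new concentration prior is approximate and only meant as an
illustration)::

    # Before (deprecated):
    from sklearn.mixture import VBGMM
    model = VBGMM(n_components=5, alpha=1.0).fit(X)

    # After:
    from sklearn.mixture import BayesianGaussianMixture
    model = BayesianGaussianMixture(
        n_components=5, dirichlet_concentration_prior=1.0).fit(X)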

- The old :class:`GMM` is deprecated in favor of the new
:class:`GaussianMixture`. The new class computes the Gaussian mixture
faster than before and some of its computational problems have been solved.
(`#6666 <https://github.com/scikit-learn/scikit-learn/pull/6666>`_) by
`Wei Xue`_ and `Thierry Guillemot`_.

- The ``grid_scores_`` attribute of :class:`model_selection.GridSearchCV`
and :class:`model_selection.RandomizedSearchCV` is deprecated in favor of
@@ -409,7 +418,7 @@ API changes summary
`Raghav R V`_.

- The parameters ``n_iter`` or ``n_folds`` in old CV splitters are replaced
by the new parameter ``n_splits`` since it can provide a consistent
by the new parameter ``n_splits`` since it can provide a consistent
and unambiguous interface to represent the number of train-test splits.
(`#7187 <https://github.com/scikit-learn/scikit-learn/pull/7187>`_)
by `YenChen Lin`_.
115 changes: 115 additions & 0 deletions examples/mixture/plot_bayesian_gaussian_mixture.py
@@ -0,0 +1,115 @@
"""
======================================================
Bayesian Gaussian Mixture Concentration Prior Analysis
======================================================

Plot the resulting ellipsoids of a mixture of three Gaussians with the
variational Bayesian Gaussian Mixture for three different values of the prior
on the Dirichlet concentration.
Member

the -> on the


For all models, the Variational Bayesian Gaussian Mixture adapts its number of
mixture components automatically. The parameter `dirichlet_concentration_prior`
has a direct link with the resulting number of components. Specifying a high
value of `dirichlet_concentration_prior` leads more often to uniformly-sized
mixture components, while specifying small (under 0.1) values will lead to some
mixture components getting almost all the points while most mixture components
will be centered on just a few of the remaining points.
"""
# Author: Wei Xue <xuewei4d@gmail.com>
# Thierry Guillemot <thierry.guillemot.work@gmail.com>
# License: BSD 3 clause

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from sklearn.mixture import BayesianGaussianMixture

print(__doc__)


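# Geometry note for plot_ellipses below: the eigen-decomposition of each 2x2
# covariance gives the ellipse orientation and, via 2 * sqrt(2 * eigenvalue),
# its full axis lengths (a contour at Mahalanobis distance sqrt(2)); the
# mixture weight of each component is used as the transparency of its patch.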
def plot_ellipses(ax, weights, means, covars):
for n in range(means.shape[0]):
v, w = np.linalg.eigh(covars[n][:2, :2])
u = w[0] / np.linalg.norm(w[0])
angle = np.arctan2(u[1], u[0])
angle = 180 * angle / np.pi # convert to degrees
v = 2 * np.sqrt(2) * np.sqrt(v)
ell = mpl.patches.Ellipse(means[n, :2], v[0], v[1], 180 + angle)
ell.set_clip_box(ax.bbox)
ell.set_alpha(weights[n])
ax.add_artist(ell)


def plot_results(ax1, ax2, estimator, dirichlet_concentration_prior, X, y, plot_title=False):
estimator.dirichlet_concentration_prior = dirichlet_concentration_prior
estimator.fit(X)
ax1.set_title("Bayesian Gaussian Mixture for "
r"$dc_0=%.1e$" % dirichlet_concentration_prior)
# ax1.axis('equal')
ax1.scatter(X[:, 0], X[:, 1], s=5, marker='o', color=colors[y], alpha=0.8)
ax1.set_xlim(-2., 2.)
ax1.set_ylim(-3., 3.)
ax1.set_xticks(())
ax1.set_yticks(())
plot_ellipses(ax1, estimator.weights_, estimator.means_,
estimator.covariances_)

ax2.get_xaxis().set_tick_params(direction='out')
ax2.yaxis.grid(True, alpha=0.7)
for k, w in enumerate(estimator.weights_):
ax2.bar(k - .45, w, width=0.9, color='royalblue', zorder=3)
ax2.text(k, w + 0.007, "%.1f%%" % (w * 100.),
horizontalalignment='center')
ax2.set_xlim(-.6, 2 * n_components - .4)
ax2.set_ylim(0., 1.1)
ax2.tick_params(axis='y', which='both', left='off',
right='off', labelleft='off')
ax2.tick_params(axis='x', which='both', top='off')

if plot_title:
ax1.set_ylabel('Estimated Mixtures')
ax2.set_ylabel('Weight of each component')

# Parameters
random_state = 2
n_components, n_features = 3, 2
colors = np.array(['mediumseagreen', 'royalblue', 'r', 'gold',
'orchid', 'indigo', 'darkcyan', 'tomato'])
dirichlet_concentration_prior = np.logspace(-3, 3, 3)
covars = np.array([[[.7, .0], [.0, .1]],
[[.5, .0], [.0, .1]],
[[.5, .0], [.0, .1]]])
samples = np.array([200, 500, 200])
means = np.array([[.0, -.70],
[.0, .0],
[.0, .70]])


# Here we put mean_precision_prior to 0.8 to minimize the influence of the
# prior for this dataset
estimator = BayesianGaussianMixture(n_components=2 * n_components,
init_params='random', max_iter=1500,
mean_precision_prior=.8, tol=1e-9,
random_state=random_state)

# Generate data
rng = np.random.RandomState(random_state)
X = np.vstack([
rng.multivariate_normal(means[j], covars[j], samples[j])
for j in range(n_components)])
y = np.concatenate([j * np.ones(samples[j], dtype=int)
for j in range(n_components)])

# Plot Results
plt.figure(figsize=(4.7 * 3, 8))
plt.subplots_adjust(bottom=.04, top=0.95, hspace=.05, wspace=.05,
left=.03, right=.97)

gs = gridspec.GridSpec(3, len(dirichlet_concentration_prior))
for k, dc in enumerate(dirichlet_concentration_prior):
plot_results(plt.subplot(gs[0:2, k]), plt.subplot(gs[2, k]),
estimator, dc, X, y, plot_title=k == 0)

plt.show()
4 changes: 3 additions & 1 deletion sklearn/mixture/__init__.py
@@ -8,6 +8,7 @@
from .dpgmm import DPGMM, VBGMM

from .gaussian_mixture import GaussianMixture
from .bayesian_mixture import BayesianGaussianMixture


__all__ = ['DPGMM',
@@ -17,4 +18,5 @@
'distribute_covar_matrix_to_match_covariance_type',
'log_multivariate_normal_density',
'sample_gaussian',
'GaussianMixture']
'GaussianMixture',
'BayesianGaussianMixture']
11 changes: 7 additions & 4 deletions sklearn/mixture/base.py
@@ -237,7 +237,6 @@ def fit(self, X, y=None):

return self

@abstractmethod
def _e_step(self, X):
"""E step.

@@ -248,12 +247,14 @@ def _e_step(self, X):
Returns
-------
log_prob_norm : array, shape (n_samples,)
log p(X)
Logarithm of the probability of X.
Member

Logarithm of the probability of each sample in X.


log_responsibility : array, shape (n_samples, n_components)
logarithm of the responsibilities
Logarithm of the posterior probabilities (or responsibilities) of
the point of X.
Member

of each sample in X.

"""
pass
log_prob_norm, log_resp = self._estimate_log_prob_resp(X)
return np.mean(log_prob_norm), log_resp
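# _estimate_log_prob_resp (defined on the base class) returns the per-sample
# log p(x) (a logsumexp over the weighted per-component log densities)
# together with the log responsibilities, so the E step reports the mean
# log-likelihood and the responsibilities consumed by the following M step.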

@abstractmethod
def _m_step(self, X, log_resp):
@@ -264,6 +265,8 @@ def _m_step(self, X, log_resp):
X : array-like, shape (n_samples, n_features)

log_resp : array-like, shape (n_samples, n_components)
Logarithm of the posterior probabilities (or responsibilities) of
the point of X.
Member

of each sample in X.

"""
pass
