[MRG] Generative Classification #2468

Closed
wants to merge 20 commits into from
Changes from 18 commits
2 changes: 1 addition & 1 deletion doc/modules/classes.rst
@@ -926,7 +926,7 @@ Pairwise metrics
naive_bayes.GaussianNB
naive_bayes.MultinomialNB
naive_bayes.BernoulliNB

naive_bayes.GenerativeBayes

.. _neighbors_ref:

116 changes: 116 additions & 0 deletions doc/modules/naive_bayes.rst
@@ -199,3 +199,119 @@ note::
The ``partial_fit`` method call of naive Bayes models introduces some
computational overhead. It is recommended to use data chunk sizes that are
as large as possible, that is, as large as the available RAM allows.


Non-naive Bayes
---------------

As mentioned above, naive Bayesian methods are generally very fast, but often
inaccurate estimators. This can be addressed by relaxing the assumptions that
make the models naive, so that more accurate classifications are possible.

If we return to the general formalism outlined above, we can see that the
generic model for generative Bayesian classification is:

.. math::
   \hat{y} = \arg\max_y P(y) P(\mathbf{x} \mid y).

This model only becomes "naive" when we assume that the features are
conditionally independent given the class,
:math:`P(\mathbf{x} \mid y) = \prod_{i=1}^{n} P(x_i \mid y)`,
and when we additionally choose a simple form for each
:math:`P(x_i \mid y)`, e.g. an axis-aligned normal distribution (the choice
made by Gaussian Naive Bayes).
Member: What makes the model naive is that you assume conditional
independence of the features. I find this paragraph unclear.

Contributor: I find this paragraph erroneous.

Member: Yes, it's wrong, as I suggested in 2016 ;)

However, assumptions like these are in no way required by the generative
Bayesian classification formalism: we can equally well fit any suitable
density model to each class to estimate :math:`P(\mathbf{x} \mid y)`. Some
Member: This gives the impression that your code estimates a KDE/GMM for
each feature, but you actually estimate P(x \mid y).

Note that this can be problematic in high dimensions (KDE has issues in high
dimensions). A middle ground could be to also support a KDE/GMM for each
feature, i.e. keep the naive independence assumption. This could be done
with an option.

examples of more flexible density models are:

- :class:`sklearn.neighbors.KernelDensity`: discussed in :ref:`kernel_density`
- :class:`sklearn.mixture.GMM`: discussed in :ref:`clustering`

Though it can be much more computationally intensive, using one of these
models rather than a naive Gaussian model can lead to much better generative
classifiers, and can be especially useful in cases of unbalanced data where
accurate posterior classification probabilities are desired.
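
To make the decision rule concrete, the same idea can be written out by hand
with one :class:`sklearn.neighbors.KernelDensity` model per class. The
following is a minimal sketch for illustration only; the helper
``generative_predict`` and its signature are not part of scikit-learn::

    import numpy as np
    from sklearn.neighbors import KernelDensity

    def generative_predict(X_train, y_train, X_test, bandwidth=0.5):
        # arg-max over classes of log P(y) + log P(x | y),
        # with P(x | y) estimated by one KDE per class
        classes = np.unique(y_train)
        log_priors, log_densities = [], []
        for c in classes:
            X_c = X_train[y_train == c]
            log_priors.append(np.log(float(X_c.shape[0]) / X_train.shape[0]))
            kde = KernelDensity(bandwidth=bandwidth).fit(X_c)
            log_densities.append(kde.score_samples(X_test))
        joint = np.asarray(log_densities).T + np.asarray(log_priors)
        return classes[np.argmax(joint, axis=1)]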

.. figure:: ../auto_examples/images/plot_1d_generative_classification_1.png
:target: ../auto_examples/plot_1d_generative_classification.html
:align: center
:scale: 50%

Here we have a 1-dimensional, two-class distribution of data which is not
well modeled by a normal distribution. The two classes have a small amount
of overlap, and by more accurately modeling the density of each class, we
are able to increase the classification accuracy by a few percent. This may
seem like a small change, but often it is these marginal cases which are
most important in practice! That is, any basic classification algorithm
will correctly classify the bulk of the data in this situation, but by
accurately modeling the density we recover an accurate Bayesian
probabilistic classification of the most interesting cases.

This type of classification can be performed with the :class:`GenerativeBayes`
estimator. The estimator can be used very easily:

>>> from sklearn.naive_bayes import GenerativeBayes
>>> from sklearn.datasets import make_blobs
>>> X, y = make_blobs(100, centers=2, random_state=0)
>>> clf = GenerativeBayes(density_estimator='kde')
>>> clf.fit(X[:-10], y[:-10])
GenerativeBayes(density_estimator='kde', model_kwds=None)
>>> clf.predict(X[-10:])
array([1, 1, 1, 1, 0, 0, 1, 1, 0, 1])
>>> y[-10:]
array([1, 1, 1, 1, 0, 0, 1, 1, 0, 1])

The KDE-based generative classifier attains 100% accuracy on this small
subset of test data.
The specified density estimator can be ``'kde'``, ``'gmm'``,
``'normal_approximation'``, or any class or estimator
Member: "any class or estimator" => "any estimator" if we drop the class
support.

which has the same semantics as
:class:`sklearn.neighbors.KernelDensity` (see the documentation of
:class:`GenerativeBayes` for details).
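
For example, a pre-configured density estimator instance can be passed in
directly (a minimal sketch, reusing ``X`` and ``y`` from the snippet above;
the bandwidth value is illustrative)::

    from sklearn.neighbors import KernelDensity

    clf = GenerativeBayes(density_estimator=KernelDensity(bandwidth=0.5))
    clf.fit(X[:-10], y[:-10])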

Member: Your explanation is clear, but I think it would be great if you
could find a good online reference from the literature for people who want
to dig further.

Note that care should be taken to ensure that the density estimator for
each class is neither over-fitting nor under-fitting the data.
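
One way to guard against this is to tune the density model's
hyper-parameters by cross-validation before handing them to the classifier,
as is done for the KDE bandwidth in the sampling example below. A sketch
along those lines (the parameter grid is illustrative)::

    import numpy as np
    from sklearn.grid_search import GridSearchCV
    from sklearn.naive_bayes import GenerativeBayes
    from sklearn.neighbors import KernelDensity

    params = {'bandwidth': np.logspace(-1, 1, 20)}
    grid = GridSearchCV(KernelDensity(), params)
    grid.fit(X)

    bandwidth = grid.best_estimator_.bandwidth
    clf = GenerativeBayes('kde', model_kwds={'bandwidth': bandwidth})
    clf.fit(X, y)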

.. topic:: References:

* George John and Pat Langley (1995). Estimating Continuous
Distributions in Bayesian Classifiers. Proceedings of the
Eleventh Conference on Uncertainty in Artificial Intelligence.


Random Samples
~~~~~~~~~~~~~~

Another advantage of non-naive Bayesian classification models is that they
provide an accurate generative model of each individual training class. This
means that new random datasets can be drawn which have the same characteristics
as the training data.
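
In code, this is exposed through the ``sample`` method used in the examples
below; a minimal sketch::

    clf = GenerativeBayes('normal_approximation')
    clf.fit(X, y)
    # new points drawn from the fitted class-wise density models
    X_new, y_new = clf.sample(200)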

Here is an example of a multi-class dataset in two dimensions. The
light-colored points are the training data, and the dark-colored points are
random data drawn from the multi-class generative model:

.. figure:: ../auto_examples/images/plot_generative_sampling_1.png
:target: ../auto_examples/plot_generative_sampling.html
:align: center
:scale: 50%

The red and yellow clusters have four times as many points as the
blue and cyan clusters; this is accurately reflected in the number of "new"
points drawn from the model.

This type of generative model can be used in higher dimensions to do some
very interesting analysis. For example, here is a generative Bayes model
based on kernel density estimation, trained on the digits dataset. The
top panel shows a selection of the input digits, while the bottom panel
shows draws from the class-wise probability distributions. These give an
intuitive feel for what the model "thinks" each digit looks like:

.. figure:: ../auto_examples/images/plot_generative_sampling_2.png
:target: ../auto_examples/plot_generative_sampling.html
:align: center
:scale: 50%

This result can be compared to the
`similar figure <../auto_examples/neighbors/plot_digits_kde_sampling.html>`_
Member: Missing ">" before the "`".

drawn from a distribution which does not utilize class information.
77 changes: 77 additions & 0 deletions examples/plot_1d_generative_classification.py
@@ -0,0 +1,77 @@
"""
Generative Bayesian Classification
==================================
This example shows a 1-dimensional, two-class generative classification
using a Gaussian naive Bayes classifier, and some extensions which drop
the naive Gaussian assumption.

In generative Bayesian classification, each class is separately modeled,
and the class yielding the highest posterior probability is selected in
the classification.
"""

# Author: Jake Vanderplas <jakevdp@cs.washington.edu>
# License: BSD 3 Clause

import matplotlib.pyplot as plt
import numpy as np
from scipy import stats

from sklearn.cross_validation import cross_val_score
from sklearn.naive_bayes import GenerativeBayes
from sklearn.neighbors.kde import KernelDensity
from sklearn.mixture import GMM

# Generate some two-class data with slight overlap
np.random.seed(0)
X1 = np.vstack([stats.laplace.rvs(2.0, 1, size=(1000, 1)),
                stats.laplace.rvs(0.3, 0.2, size=(300, 1))])
X2 = np.vstack([stats.laplace.rvs(-2.5, 1, size=(300, 1)),
                stats.laplace.rvs(-1.0, 0.5, size=(200, 1))])
X = np.vstack([X1, X2])
y = np.hstack([np.ones(X1.size), np.zeros(X2.size)])
x_plot = np.linspace(-6, 6, 200)

# Test three density estimators
density_estimators = ['normal_approximation',
                      GMM(3),
                      KernelDensity(0.25)]
names = ['Normal Approximation',
         'Gaussian Mixture Model',
         'Kernel Density Estimation']
linestyles = [':', '--', '-']
colors = []

# Plot histograms of the two input distributions
fig, ax = plt.subplots()
for j in range(2):
    h = ax.hist(X[y == j, 0], bins=np.linspace(-6, 6, 80),
                histtype='stepfilled', normed=False,
                alpha=0.3)
    colors.append(h[2][0].get_facecolor())
    binsize = h[1][1] - h[1][0]


for i in range(3):
    clf = GenerativeBayes(density_estimator=density_estimators[i])
    clf.fit(X, y)
    L = np.exp(clf._joint_log_likelihood(x_plot[:, None]))

    for j in range(2):
        ax.plot(x_plot,
                L[:, j] * np.sum(y == j) * binsize / clf.class_prior_[j],
                linestyle=linestyles[i],
                color=colors[j],
                alpha=1)

    # Trick the legend into showing what we want
    scores = cross_val_score(clf, X, y, scoring="accuracy", cv=10)
    ax.plot([], [], linestyle=linestyles[i], color='black',
            label="{0}:\n {1:.1f}% accuracy.".format(names[i],
                                                     100 * scores.mean()))

ax.set_xlabel('$x$')
ax.set_ylabel('$N(x)$')
ax.legend(loc='upper left', fontsize=12)

plt.show()
101 changes: 101 additions & 0 deletions examples/plot_generative_sampling.py
@@ -0,0 +1,101 @@
"""
Multiclass Generative Sampling
==============================
This example shows the use of the Generative Bayesian classifier for sampling
from a multi-class distribution.

The first figure shows a simple 2D distribution, overlaying the input points
and new points generated from the class-wise model.

The second figure extends this to a higher dimension. A generative Bayes
classifier based on kernel density estimation is fit to the handwritten digits
data, and a new sample is drawn from each of the class-wise generative
models.
"""
import matplotlib.pyplot as plt
import numpy as np
from sklearn.naive_bayes import GenerativeBayes
from sklearn.decomposition import PCA
from sklearn.grid_search import GridSearchCV
from sklearn.neighbors import KernelDensity
from sklearn.datasets import make_blobs, load_digits

#----------------------------------------------------------------------
# First figure: two-dimensional blobs

# Make 4 blobs with different numbers of points
np.random.seed(0)
X1, y1 = make_blobs(50, 2, centers=2)
X2, y2 = make_blobs(200, 2, centers=2)

X = np.vstack([X1, X2])
y = np.concatenate([y1, y2 + 2])

# Fit a generative Bayesian model to the data
clf = GenerativeBayes('normal_approximation')
clf.fit(X, y)

# Sample new data from the generative Bayesian model
X_new, y_new = clf.sample(200)

# Plot the input data and the sampled data
fig, ax = plt.subplots()
ax.scatter(X[:, 0], X[:, 1], c=y, alpha=0.2)
ax.scatter(X_new[:, 0], X_new[:, 1], c=y_new)

# Create the legend by plotting some empty data
ax.scatter([], [], c='w', alpha=0.2, label="Training (input) data")
ax.scatter([], [], c='w', label="Samples from Model")
ax.legend()

ax.set_xlim(-4, 10)
ax.set_ylim(-8, 8)


#----------------------------------------------------------------------
# Second figure: sampling from the digits data

# load the digits data
digits = load_digits()
data = digits.data
labels = digits.target

# project the 64-dimensional data to a lower dimension
pca = PCA(n_components=15, whiten=False)
data = pca.fit_transform(digits.data)

# use grid search cross-validation to optimize the bandwidth
params = {'bandwidth': np.logspace(-1, 1, 20)}
grid = GridSearchCV(KernelDensity(), params)
grid.fit(data)

print "best bandwidth: {0}".format(grid.best_estimator_.bandwidth)

# train the model with this bandwidth
clf = GenerativeBayes('kde',
                      model_kwds={'bandwidth': grid.best_estimator_.bandwidth})
clf.fit(data, labels)

new_data, new_labels = clf.sample(44, random_state=0)
new_data = pca.inverse_transform(new_data)

# turn data into a 4x11 grid
new_data = new_data.reshape((4, 11, -1))
real_data = digits.data[:44].reshape((4, 11, -1))

# plot real digits and resampled digits
fig, ax = plt.subplots(9, 11, subplot_kw=dict(xticks=[], yticks=[]))
for j in range(11):
    ax[4, j].set_visible(False)
    for i in range(4):
        im = ax[i, j].imshow(real_data[i, j].reshape((8, 8)),
                             cmap=plt.cm.binary, interpolation='nearest')
        im.set_clim(0, 16)
        im = ax[i + 5, j].imshow(new_data[i, j].reshape((8, 8)),
                                 cmap=plt.cm.binary, interpolation='nearest')
        im.set_clim(0, 16)

ax[0, 5].set_title('Selection from the input data')
ax[5, 5].set_title('"New" digits drawn from the class-wise kernel density model')

plt.show()