[MRG] Generative Classification #2468

Closed · wants to merge 20 commits · Changes from 5 commits
258 changes: 258 additions & 0 deletions sklearn/generative.py
@@ -0,0 +1,258 @@
"""
Bayesian Generative Classification
==================================
This module contains routines for general Bayesian generative classification.
Perhaps the best-known instance of generative classification is the Naive
Bayes Classifier, in which the distribution of each training class is
approximated by an axis-aligned multi-dimensional normal distribution, and
unknown points are evaluated by comparing their posterior probability under
each model.
Member:
There's no such thing as "the" Naive Bayes classifier. You're thinking of Gaussian NB, but NLP people will variously think of Bernoulli or multinomial NB. (The first time I encountered the Gaussian variant was while reading sklearn source code :)

Member Author:
Interesting! The first time I encountered any version other than Gaussian NB was reading the sklearn source code 😄


This idea can be straightforwardly extended by using a more sophisticated
distribution model for each class: for example, instead of modeling each
class as a single normal distribution, you might use a mixture of Gaussians
or a kernel density estimate.

Mathematical Background
-----------------------
Bayesian generative classification relies on Bayes' formula,

.. math::

P(C|D) = \frac{P(D|C)P(C)}{P(D)}

where :math:`D` refers to the observed data of an unknown sample, and
:math:`C` refers to its classification. The :math:`P(D|C)` term represents
the per-class likelihood (e.g. using the normal approximation to the given
class in the case of naive Bayes), and :math:`P(C)` gives the class prior,
often inferred from the training sample. :math:`P(D)` acts as a normalization
constant. The final classification is the class :math:`C` which gives the
largest posterior probability :math:`P(C|D)`.
"""
# Author: Jake Vanderplas <jakevdp@cs.washington.edu>

__all__ = ['GenerativeBayes']

import numpy as np
from .neighbors import KernelDensity
from .mixture import GMM
from .base import BaseEstimator, clone
from .utils import array2d, check_random_state, check_arrays
from .utils.extmath import logsumexp
from .naive_bayes import BaseNB


class _NormalApproximation(BaseEstimator):
"""Normal Approximation Density Estimator"""
def __init__(self):
pass

def fit(self, X):
"""Fit the Normal Approximation to data

Parameters
----------
X: array_like, shape (n_samples, n_features)
List of n_features-dimensional data points. Each row
corresponds to a single data point.
"""
X = array2d(X)
        epsilon = 1e-9  # regularize zero-variance (constant) features
self.mean = X.mean(0)
self.var = X.var(0) + epsilon
return self

def score_samples(self, X):
"""Evaluate the model on the data

Parameters
----------
X : array_like
An array of points to query. Last dimension should match dimension
of training data (n_features)

Returns
-------
density : ndarray
The array of density evaluations. This has shape X.shape[:-1]
"""
X = array2d(X)
if X.shape[-1] != self.mean.shape[0]:
raise ValueError("dimension of X must match that of training data")
        # log of an axis-aligned (diagonal-covariance) normal density
        log_norm = -0.5 * (X.shape[-1] * np.log(2 * np.pi)
                           + np.log(self.var).sum())
        res = log_norm - 0.5 * ((X - self.mean) ** 2
                                / self.var).sum(1)
return res

def score(self, X):
"""Compute the log probability under the model.

Parameters
----------
X : array_like, shape (n_samples, n_features)
List of n_features-dimensional data points. Each row
corresponds to a single data point.

Returns
-------
        logprob : float
            Total log-likelihood of the data in X.
        """
        return np.sum(self.score_samples(X))

def sample(self, n_samples=1, random_state=None):
"""Generate random samples from the model.

Parameters
----------
n_samples : int, optional
Number of samples to generate. Defaults to 1.

        random_state : RandomState or an int seed, optional (default=None)
            A random number generator instance.

Returns
-------
X : array_like, shape (n_samples, n_features)
List of samples
"""
rng = check_random_state(random_state)
return rng.normal(self.mean, np.sqrt(self.var),
size=(n_samples, len(self.mean)))
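``score_samples`` above is just the log of an axis-aligned normal density, so for one-dimensional data it should agree with ``scipy.stats.norm``. A quick illustrative check (assuming scipy is available and this module is importable as sklearn.generative):

import numpy as np
from scipy.stats import norm
from sklearn.generative import _NormalApproximation

rng = np.random.RandomState(0)
X_check = rng.normal(2.0, 3.0, size=(100, 1))     # 1-d training data
est = _NormalApproximation().fit(X_check)
expected = norm(est.mean[0], np.sqrt(est.var[0])).logpdf(X_check[:, 0])
assert np.allclose(est.score_samples(X_check), expected)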


MODEL_TYPES = {'norm_approx': _NormalApproximation,
'gmm': GMM,
'kde': KernelDensity}


class GenerativeBayes(BaseNB):
"""
Generative Bayes Classifier

This is a meta-estimator which performs generative Bayesian classification.

Parameters
----------
density_estimator : str, class, or instance
The density estimator to use for each class. Options are
'norm_approx' : Normal Approximation
'gmm' : Gaussian Mixture Model
'kde' : Kernel Density Estimate
        Alternatively, a class or class instance can be specified. The
        instantiated class should be a scikit-learn estimator with a
        ``score_samples`` method whose semantics match those of
        :class:`sklearn.neighbors.KernelDensity` or :class:`sklearn.mixture.GMM`.
**kwargs :
additional keyword arguments to be passed to the constructor
specified by density_estimator.
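
    Examples
    --------
    A minimal usage sketch (the data and the expected output shown are
    illustrative only):

    >>> import numpy as np
    >>> X = np.array([[-1.], [-2.], [1.], [2.]])
    >>> y = np.array([0, 0, 1, 1])
    >>> clf = GenerativeBayes('norm_approx').fit(X, y)
    >>> clf.predict([[-1.5], [1.5]])
    array([0, 1])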
"""
def __init__(self, density_estimator, **kwargs):
self.density_estimator = density_estimator
self.kwargs = kwargs

# run this here to check for any exceptions; we avoid assigning
# the result here so that the estimator can be cloned.
self._choose_estimator(density_estimator, **kwargs)

def _choose_estimator(self, density_estimator, **kwargs):
if isinstance(density_estimator, str):
            dclass = MODEL_TYPES.get(density_estimator)
            if dclass is None:
                raise ValueError("Unrecognized density estimator: %r"
                                 % density_estimator)
            return dclass(**kwargs)
elif isinstance(density_estimator, type):
Member:
Looks like type is undefined.

Member:
it's a builtin

return density_estimator(**kwargs)
else:
return density_estimator

def fit(self, X, y):
"""Fit the model using X as training data and y as target values

Parameters
----------
X : array-like
Training data. shape = [n_samples, n_features]

y : array-like
            Target values (class labels), shape = [n_samples]
"""
X, y = check_arrays(X, y, sparse_format='dense')
estimator = self._choose_estimator(self.density_estimator,
**self.kwargs)

self.classes_ = np.sort(np.unique(y))
n_classes = len(self.classes_)
n_samples, self.n_features_ = X.shape

masks = [(y == c) for c in self.classes_]

self.class_prior_ = np.array([np.float(mask.sum()) / n_samples
for mask in masks])
self.estimators_ = [clone(estimator).fit(X[mask])
for mask in masks]
return self

def _joint_log_likelihood(self, X):
"""Compute the per-class log likelihood of each sample

Parameters
----------
X : array_like
Array of samples on which to compute likelihoods. Shape is
(n_samples, n_features)

Returns
-------
logL : array_like
The log likelihood under each class.
Shape is (n_samples, n_classes). logL[i, j] gives the log
likelihood of X[i] within the model representing the class
self.classes_[j].
"""
X = array2d(X)

        # The GMM API, in particular score() and score_samples(), is
        # not consistent with the rest of the package. This needs
        # to be addressed eventually. Check the fitted per-class
        # estimators rather than self.density_estimator, which may be
        # the string 'gmm' rather than a GMM instance.
        if isinstance(self.estimators_[0], GMM):
return np.array([np.log(prior) + dens.score(X)
for (prior, dens)
in zip(self.class_prior_,
self.estimators_)]).T
else:
return np.array([np.log(prior) + dens.score_samples(X)
for (prior, dens)
in zip(self.class_prior_,
self.estimators_)]).T
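For context, the ``BaseNB`` parent class consumes this joint log-likelihood: ``predict`` takes the argmax over classes and ``predict_log_proba`` normalizes by P(D) with ``logsumexp``. Roughly, as an illustrative paraphrase rather than code from this patch:

import numpy as np
from sklearn.generative import GenerativeBayes
from sklearn.utils.extmath import logsumexp

X_demo = np.array([[-2.], [-1.], [1.], [2.]])
y_demo = np.array([0, 0, 1, 1])
clf = GenerativeBayes('norm_approx').fit(X_demo, y_demo)

jll = clf._joint_log_likelihood(X_demo)                  # shape (n_samples, n_classes)
y_pred = clf.classes_[np.argmax(jll, axis=1)]            # what predict() returns
log_prob = jll - logsumexp(jll, axis=1)[:, np.newaxis]   # what predict_log_proba() returns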

def sample(self, n_samples=1, random_state=None):
"""Generate random samples from the model.

Parameters
----------
n_samples : int, optional
            Number of samples to generate. Defaults to 1.

        random_state : RandomState or an int seed, optional (default=None)
            A random number generator instance.

Returns
-------
X : array_like, shape (n_samples, n_features)
List of samples
y : array_like, shape (n_samples,)
List of class labels for the generated samples
"""
random_state = check_random_state(random_state)
X = np.empty((n_samples, self.n_features_))
rand = random_state.rand(n_samples)

# split samples by class
prior_cdf = np.cumsum(self.class_prior_)
labels = prior_cdf.searchsorted(rand)

# for each class, generate all needed samples
for i, model in enumerate(self.estimators_):
model_mask = (labels == i)
N_model = model_mask.sum()
if N_model > 0:
X[model_mask] = model.sample(N_model,
random_state=random_state)

return X, self.classes_[labels]
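The label draw in ``sample`` above inverts the cumulative class prior with ``searchsorted``; the same trick in isolation, with arbitrary illustrative priors:

import numpy as np

class_prior = np.array([0.2, 0.3, 0.5])      # per-class priors, sum to 1
prior_cdf = np.cumsum(class_prior)           # [0.2, 0.5, 1.0]
rand = np.random.RandomState(0).rand(8)      # uniform draws in [0, 1)
labels = prior_cdf.searchsorted(rand)        # class index for each draw; class i
                                             # is chosen with probability class_prior[i]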
115 changes: 115 additions & 0 deletions sklearn/tests/test_generative.py
@@ -0,0 +1,115 @@
"""Tests for Generative Classification"""
import numpy as np
from sklearn.generative import GenerativeBayes
from sklearn.naive_bayes import GaussianNB
from sklearn.utils.testing import (assert_array_almost_equal,
                                   assert_greater, assert_allclose)
from sklearn.cross_validation import cross_val_score
from sklearn.datasets import load_digits, make_blobs

# Data is just 6 separable points in the plane
X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]])
y = np.array([1, 1, 1, 2, 2, 2])

# Some additional random test data
rng = np.random.RandomState(0)
X1 = rng.normal(size=(10, 3))
y1 = (rng.normal(size=(10)) > 0).astype(np.int)

# Data is 6 random integer points in a 100 dimensional space classified to
# three classes.
X2 = rng.randint(5, size=(6, 100))
y2 = np.array([1, 1, 2, 2, 3, 3])


MODEL_KWARGS = {'norm_approx': {},
                'gmm': {'n_components': 3,
                        'covariance_type': 'diag'},
                'kde': {'bandwidth': 4.0,
                        'kernel': 'gaussian',
                        'metric': 'euclidean'}}


def test_compare_generative_gnb():
"""Compare GenerativeBayes to GaussianNB"""
# using norm_approx, the two should yield an identical model.
clf1 = GenerativeBayes('norm_approx')
clf2 = GaussianNB()

p1 = clf1.fit(X, y).predict_proba(X)
p2 = clf2.fit(X, y).predict_proba(X)

assert_array_almost_equal(p1, p2)


def test_generative_model_prior():
"""Test whether class priors are properly set."""
# class priors should sum to 1.
    for model, kwargs in MODEL_KWARGS.items():
clf = GenerativeBayes(model, **kwargs)

clf.fit(X, y)
assert_array_almost_equal(np.array([3, 3]) / 6.0,
clf.class_prior_, 8)

clf.fit(X1, y1)
assert_array_almost_equal(clf.class_prior_.sum(), 1)


def test_check_accuracy_on_digits():
digits = load_digits()
X, y = digits.data, digits.target

binary_3v8 = np.logical_or(digits.target == 3, digits.target == 8)
X_3v8, y_3v8 = X[binary_3v8], y[binary_3v8]

scores_cmp = {'kde': (0.98, 0.99),
'norm_approx': (0.89, 0.93),
'gmm': (0.92, 0.93)}

    for model, kwargs in MODEL_KWARGS.items():
scores = cross_val_score(GenerativeBayes(model, **kwargs),
X, y, cv=4)
assert_greater(scores.mean(), scores_cmp[model][0])

scores = cross_val_score(GenerativeBayes(model, **kwargs),
X_3v8, y_3v8, cv=4)
assert_greater(scores.mean(), scores_cmp[model][1])


def test_generate_samples():
# create a simple unbalanced dataset with 4 classes
np.random.seed(0)
X1, y1 = make_blobs(50, 2, centers=2)
X2, y2 = make_blobs(100, 2, centers=2)
X = np.vstack([X1, X2])
y = np.concatenate([y1, y2 + 2])

# test with norm_approx; other models have their sample() method
# tested independently
clf = GenerativeBayes('norm_approx')
clf.fit(X, y)

X_new, y_new = clf.sample(2000)

for i in range(4):
Xnew_i = X_new[y_new == i]
X_i = X[y == i]

# check the means
assert_array_almost_equal(Xnew_i.mean(0),
X_i.mean(0), decimal=1)

# check the standard deviations
assert_array_almost_equal(Xnew_i.std(0),
X_i.std(0), decimal=1)

# check the number of points
assert_allclose(X_i.shape[0] * 1. / X.shape[0],
Xnew_i.shape[0] * 1. / X_new.shape[0],
rtol=0.1)


if __name__ == '__main__':
import nose
nose.runmodule()