diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 6a5794e0214a4..d6ecda9914286 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -926,7 +926,7 @@ Pairwise metrics naive_bayes.GaussianNB naive_bayes.MultinomialNB naive_bayes.BernoulliNB - + naive_bayes.GenerativeBayes .. _neighbors_ref: diff --git a/doc/modules/naive_bayes.rst b/doc/modules/naive_bayes.rst index 8cd76e5803713..1c6526a9097fc 100644 --- a/doc/modules/naive_bayes.rst +++ b/doc/modules/naive_bayes.rst @@ -199,3 +199,118 @@ note:: The ``partial_fit`` method call of naive Bayes models introduces some computational overhead. It is recommended to use data chunk sizes that are as large as possible, that is as the available RAM allows. + + +Non-naive Bayes +--------------- + +As mentioned above, naive Bayesian methods are generally very fast, but often +inaccurate estimators. This can be addressed by relaxing the assumptions that +make the models naive, so that more accurate classifications are possible. + +If we return to the general formalism outlined above, we can see that the +generic model for Bayesian classification is: + +.. math:: + \hat{y} = \arg\max_y P(y) \prod_{i=1}^{n} P(x_i \mid y). + +This model only becomes "naive" when we introduce certain assumptions about +the form of :math:`P(x_i \mid y)`, e.g. that each class is drawn from an +axis-aligned normal distribution (the assumption for Gaussian Naive Bayes). + +However, assumptions like these are in no way required for generative +Bayesian classification formalism: we can equally well fit any suitable +density model to each category to estimate :math:`P(x_i \mid y)`. Some +examples of more flexible density models are: + +- :class:`sklearn.neighbors.KernelDensity`: discussed in :ref:`kernel_density` +- :class:`sklearn.mixture.GMM`: discussed in :ref:`clustering` + +Though it can be much more computationally intense, +using one of these models rather than a naive Gaussian model can lead to much +better generative classifiers, and can be especially applicable in cases of +unbalanced data where accurate posterior classification probabilities are +desired. + +.. figure:: ../auto_examples/images/plot_1d_generative_classification_1.png + :target: ../auto_examples/plot_1d_generative_classification.html + :align: center + :scale: 50% + +Here we have a 1 dimensional, two-class distribution of data which is not +well-modeled by a normal distribution. The two classes have a small amount +of overlap, and by more accurately modeling the density of each class, we are +able to increase the accuracy by a few percent. This may seem like a small +change, but often it is these marginal cases which are most important in +practice! That is, any basic classification algorithm will correctly +classify the bulk of the data in this situation, but by accurately modeling +the density, we recover an accurate Bayesian probabilistic classification of +the most interesting cases. + +This type of classification can be performed with the :class:`GenerativeBayes` +estimator. The estimator can be used very easily: + + >>> from sklearn.naive_bayes import GenerativeBayes + >>> from sklearn.datasets import make_blobs + >>> X, y = make_blobs(100, centers=2, random_state=0) + >>> clf = GenerativeBayes(density_estimator='kde') + >>> clf.fit(X[:-10], y[:-10]) + GenerativeBayes(density_estimator='kde', model_kwds=None) + >>> clf.predict(X[-10:]) + array([1, 1, 1, 1, 0, 0, 1, 1, 0, 1]) + >>> y[-10:] + array([1, 1, 1, 1, 0, 0, 1, 1, 0, 1]) + +The KDE-based Generative classifier for this problem has 100% accuracy on +this small subset of test data. +The specified density estimator can be ``'kde'``, ``'gmm'``, +``'normal_approximation'``, or any estimator which has the same semantics as +:class:`sklearn.neighbors.KernelDensity` (see the documentation of +:class:`GenerativeBayes` for details). + +Note that care should be taken to make sure that the density estimator for +each class is not over-fitting or under-fitting the data. + +.. topic:: References: + + * George John and Pat Langley (1995). Estimating Continuous + Distributions in Bayesian Classifiers. Proceedings of the + Eleventh Conference on Uncertainty in Artificial Intelligence. + + +Random Samples +~~~~~~~~~~~~~~ + +Another advantage of non-naive Bayesian classification models is that they +provide an accurate generative model of each individual training class. This +means that new random datasets can be drawn which have the same characteristics +as the training data. + +Here is an example of a multi-class dataset in two dimensions. The +light-colored points are the training data, and the dark-colored points are +random data drawn from the multi-class generative model: + +.. figure:: ../auto_examples/images/plot_generative_sampling_1.png + :target: ../auto_examples/plot_generative_sampling.html + :align: center + :scale: 50% + +The red and yellow clusters have four times the number of points as the +blue and cyan clusters; this is accurately reflected in the number of "new" +points drawn from the model. + +This type of generative model can be used in higher dimensions to do some +very interesting analysis. For example, here's a generative bayes model +which uses kernel density estimation trained on the digits dataset. The +top panel shows a selection of the input digits, while the bottom panel +shows draws from the class-wise probability distributions. These give an +intuitive feel to what the model "thinks" each digit looks like: + +.. figure:: ../auto_examples/images/plot_generative_sampling_2.png + :target: ../auto_examples/plot_generative_sampling.html + :align: center + :scale: 50% + +This result can be compared to the +`similar figure <../auto_examples/neighbors/plot_digits_kde_sampling.html>`_ +drawn from a distribution which does not utilize class information. diff --git a/examples/plot_1d_generative_classification.py b/examples/plot_1d_generative_classification.py new file mode 100644 index 0000000000000..52f723669dc21 --- /dev/null +++ b/examples/plot_1d_generative_classification.py @@ -0,0 +1,77 @@ +""" +Generative Bayesian Classification +================================== +This example shows a 1-dimensional, two-class generative classification +using a Gaussian naive Bayes classifier, and some extensions which drop +the naive Gaussian assumption. + +In generative Bayesian classification, each class is separately modeled, +and the class yielding the highest posterior probability is selected in +the classification. +""" + +# Author: Jake Vanderplas +# License: BSD 3 Clause + +import matplotlib.pyplot as plt +import numpy as np +from scipy import stats + +from sklearn.cross_validation import cross_val_score +from sklearn.naive_bayes import GenerativeBayes +from sklearn.neighbors.kde import KernelDensity +from sklearn.mixture import GMM + +# Generate some two-class data with slight overlap +np.random.seed(0) +X1 = np.vstack([stats.laplace.rvs(2.0, 1, size=(1000, 1)), + stats.laplace.rvs(0.3, 0.2, size=(300,1))]) +X2 = np.vstack([stats.laplace.rvs(-2.5, 1, size=(300, 1)), + stats.laplace.rvs(-1.0, 0.5, size=(200, 1))]) +X = np.vstack([X1, X2]) +y = np.hstack([np.ones(X1.size), np.zeros(X2.size)]) +x_plot = np.linspace(-6, 6, 200) + +# Test three density estimators +density_estimators = ['normal_approximation', + GMM(3), + KernelDensity(0.25)] +names = ['Normal Approximation', + 'Gaussian Mixture Model', + 'Kernel Density Estimation'] +linestyles = [':', '--', '-'] +colors = [] + +# Plot histograms of the two input distributions +fig, ax = plt.subplots() +for j in range(2): + h = ax.hist(X[y == j, 0], bins=np.linspace(-6, 6, 80), + histtype='stepfilled', normed=False, + alpha=0.3) + colors.append(h[2][0].get_facecolor()) +binsize = h[1][1] - h[1][0] + + +for i in range(3): + clf = GenerativeBayes(density_estimator=density_estimators[i]) + clf.fit(X, y) + L = np.exp(clf._joint_log_likelihood(x_plot[:, None])) + + for j in range(2): + ax.plot(x_plot, + L[:, j] * np.sum(y == j) * binsize / clf.class_prior_[j], + linestyle=linestyles[i], + color=colors[j], + alpha=1) + + # Trick the legend into showing what we want + scores = cross_val_score(clf, X, y, scoring="accuracy", cv=10) + ax.plot([], [], linestyle=linestyles[i], color='black', + label="{0}:\n {1:.1f}% accuracy.".format(names[i], + 100 * scores.mean())) + +ax.set_xlabel('$x$') +ax.set_ylabel('$N(x)$') +ax.legend(loc='upper left', fontsize=12) + +plt.show() diff --git a/examples/plot_generative_sampling.py b/examples/plot_generative_sampling.py new file mode 100644 index 0000000000000..f86b0448ba6b4 --- /dev/null +++ b/examples/plot_generative_sampling.py @@ -0,0 +1,101 @@ +""" +Multiclass Generative Sampling +============================== +This example shows the use of the Generative Bayesian classifier for sampling +from a multi-class distribution. + +The first figure shows a simple 2D distribution, overlaying the input points +and new points generated from the class-wise model. + +The second figure extends this to a higher dimension. A generative Bayes +classifier based on kernel density estimation is fit to the handwritten digits +data, and a new sample is drawn from each of the class-wise generative +models. +""" +import matplotlib.pyplot as plt +import numpy as np +from sklearn.naive_bayes import GenerativeBayes +from sklearn.decomposition import PCA +from sklearn.grid_search import GridSearchCV +from sklearn.neighbors import KernelDensity +from sklearn.datasets import make_blobs, load_digits + +#---------------------------------------------------------------------- +# First figure: two-dimensional blobs + +# Make 4 blobs with different numbers of points +np.random.seed(0) +X1, y1 = make_blobs(50, 2, centers=2) +X2, y2 = make_blobs(200, 2, centers=2) + +X = np.vstack([X1, X2]) +y = np.concatenate([y1, y2 + 2]) + +# Fit a generative Bayesian model to the data +clf = GenerativeBayes('normal_approximation') +clf.fit(X, y) + +# Sample new data from the generative Bayesian model +X_new, y_new = clf.sample(200) + +# Plot the input data and the sampled data +fig, ax = plt.subplots() +ax.scatter(X[:, 0], X[:, 1], c=y, alpha=0.2) +ax.scatter(X_new[:, 0], X_new[:, 1], c=y_new) + +# Create the legend by plotting some empty data +ax.scatter([], [], c='w', alpha=0.2, label="Training (input) data") +ax.scatter([], [], c='w', label="Samples from Model") +ax.legend() + +ax.set_xlim(-4, 10) +ax.set_ylim(-8, 8) + + +#---------------------------------------------------------------------- +# Second figure: sampling from digits digits + +# load the digits data +digits = load_digits() +data = digits.data +labels = digits.target + +# project the 64-dimensional data to a lower dimension +pca = PCA(n_components=15, whiten=False) +data = pca.fit_transform(digits.data) + +# use grid search cross-validation to optimize the bandwidth +params = {'bandwidth': np.logspace(-1, 1, 20)} +grid = GridSearchCV(KernelDensity(), params) +grid.fit(data) + +print "best bandwidth: {0}".format(grid.best_estimator_.bandwidth) + +# train the model with this bandwidth +clf = GenerativeBayes('kde', + model_kwds={'bandwidth':grid.best_estimator_.bandwidth}) +clf.fit(data, labels) + +new_data, new_labels = clf.sample(44, random_state=0) +new_data = pca.inverse_transform(new_data) + +# turn data into a 4x11 grid +new_data = new_data.reshape((4, 11, -1)) +real_data = digits.data[:44].reshape((4, 11, -1)) + +# plot real digits and resampled digits +fig, ax = plt.subplots(9, 11, subplot_kw=dict(xticks=[], yticks=[])) +for j in range(11): + ax[4, j].set_visible(False) + for i in range(4): + im = ax[i, j].imshow(real_data[i, j].reshape((8, 8)), + cmap=plt.cm.binary, interpolation='nearest') + im.set_clim(0, 16) + im = ax[i + 5, j].imshow(new_data[i, j].reshape((8, 8)), + cmap=plt.cm.binary, interpolation='nearest') + im.set_clim(0, 16) + +ax[0, 5].set_title('Selection from the input data') +ax[5, 5].set_title('"New" digits drawn from the class-wise kernel density model') + +plt.show() diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 7092753cbb7e0..6902bdbf6f900 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -4,6 +4,9 @@ The :mod:`sklearn.naive_bayes` module implements Naive Bayes algorithms. These are supervised learning methods based on applying Bayes' theorem with strong (naive) feature independence assumptions. + +It also implements routines for generative classification with fewer +assumptions, i.e. non-naive Bayesian classification. """ # Author: Vincent Michel @@ -12,6 +15,8 @@ # Yehuda Finkelstein # Lars Buitinck # (parts based on earlier work by Mathieu Blondel) +# Generative Classification by +# Jake Vanderplas # # License: BSD 3 clause @@ -21,16 +26,19 @@ from scipy.sparse import issparse import warnings -from .base import BaseEstimator, ClassifierMixin +from .base import BaseEstimator, ClassifierMixin, clone from .preprocessing import binarize from .preprocessing import LabelBinarizer from .preprocessing import label_binarize -from .utils import array2d, atleast2d_or_csr, column_or_1d, check_arrays +from .utils import (array2d, atleast2d_or_csr, column_or_1d, check_arrays, + check_random_state, check_arrays) from .utils.extmath import safe_sparse_dot, logsumexp from .utils.multiclass import _check_partial_fit_first_call from .externals import six +from .neighbors import KernelDensity +from .mixture import GMM -__all__ = ['BernoulliNB', 'GaussianNB', 'MultinomialNB'] +__all__ = ['BernoulliNB', 'GaussianNB', 'MultinomialNB', 'GenerativeBayes'] class BaseNB(six.with_metaclass(ABCMeta, BaseEstimator, ClassifierMixin)): @@ -568,3 +576,254 @@ def _joint_log_likelihood(self, X): jll += self.class_log_prior_ + neg_prob.sum(axis=1) return jll + + +class _NormalApproximation(BaseEstimator): + """Fit an axis-aligned normal approximation to the data. + + Parameters + ---------- + epsilon : float + The minimum variance along any dimension. Default = 1E-9. + """ + def __init__(self, epsilon=1E-9): + self.epsilon = epsilon + + def fit(self, X): + """Fit the Normal Approximation to data + + Parameters + ---------- + X: array_like, shape (n_samples, n_features) + List of n_features-dimensional data points. Each row + corresponds to a single data point. + """ + X = array2d(X) + self.mean = X.mean(0) + self.var = X.var(0) + self.epsilon + return self + + def score_samples(self, X): + """Evaluate the model on the data + + Parameters + ---------- + X : array_like + An array of points to query. Last dimension should match dimension + of training data (n_features) + + Returns + ------- + sample scores : ndarray + The array of log density. This has shape X.shape[:-1] + """ + X = array2d(X) + if X.shape[-1] != self.mean.shape[0]: + raise ValueError("dimension of X must match that of training data") + log_norm = -0.5 * (X.shape[-1] * np.log(2 * np.pi) + + np.log(self.var).sum()) + return log_norm - 0.5 * ((X - self.mean) ** 2 / self.var).sum(1) + + def score(self, X): + """Compute the log probability under the model. + + Parameters + ---------- + X : array_like, shape (n_samples, n_features) + List of n_features-dimensional data points. Each row + corresponds to a single data point. + + Returns + ------- + logprob : float + log-likelihood of the data X under the model. + """ + return np.sum(self.score_samples(X)) + + def sample(self, n_samples=1, random_state=None): + """Generate random samples from the model. + + Parameters + ---------- + n_samples : int, optional + Number of samples to generate. Defaults to 1. + + random_state: RandomState or an int seed (0 by default) + A random number generator instance + + Returns + ------- + X : array_like, shape (n_samples, n_features) + List of samples + """ + rng = check_random_state(random_state) + return rng.normal(self.mean, np.sqrt(self.var), + size=(n_samples, len(self.mean))) + + +DENSITY_MODELS = {'normal_approximation': _NormalApproximation, + 'gmm': GMM, + 'kde': KernelDensity} + + +class GenerativeBayes(BaseNB): + """ + Generative Bayes Classifier + + This is a meta-estimator which performs generative Bayesian classification + using flexible underlying density models. + + Parameters + ---------- + density_estimator : str or instance + The density estimator to use for each class. Options are + - 'normal_approximation' : Axis-aligned Normal Approximation + (i.e. Gaussian Naive Bayes) + - 'gmm' : Gaussian Mixture Model + - 'kde' : Kernel Density Estimate + The default is 'normal_approximation'. + Alternatively, a scikit-learn estimator instance can be specified. + The estimator should contain a ``score_samples`` method with semantics + similar to those in :class:`sklearn.neighbors.KDE`. + model_kwds : dict or None + Additional keyword arguments to be passed to the constructor + specified by density_estimator. Ignored if density_estimator is + a class instance. Default=None. + + Attributes + ---------- + `classes_` : array, shape = [n_classes] + the sorted list of classes + + `class_prior_` : array, shape = [n_classes] + probability of each class. + + `estimators_` : list, length = [n_classes] + the density estimator associated with each class + """ + def __init__(self, density_estimator='normal_approximation', + model_kwds=None): + self.density_estimator = density_estimator + self.model_kwds = model_kwds + + # run this here to check for any exceptions; we avoid assigning + # the result here so that the estimator can be cloned. + self._choose_estimator(density_estimator, self.model_kwds) + + def _choose_estimator(self, density_estimator, kwargs=None): + """Choose the estimator based on the input""" + if kwargs is None: + kwargs = {} + + if isinstance(density_estimator, str): + dclass = DENSITY_MODELS.get(density_estimator) + if dclass is None: + raise ValueError("Invalid density_estimator: '%s'" + % density_estimator) + density_estimator = dclass(**kwargs) + + if isinstance(density_estimator, type): + raise TypeError('Invalid density_estimator: %s. ' + 'Expected class instance, not class.') + + if not hasattr(density_estimator, 'score_samples'): + raise TypeError('Invalid density_estimator: %s. ' + 'Missing required score_samples method.' + % density_estimator) + + return density_estimator + + def fit(self, X, y): + """Fit the model using X as training data and y as target values + + Parameters + ---------- + X : array-like + Training data. shape = [n_samples, n_features] + + y : array-like + Target values, array of float values, shape = [n_samples] + """ + X, y = check_arrays(X, y, sparse_format='dense') + y = column_or_1d(y, warn=True) + + estimator = self._choose_estimator(self.density_estimator, + self.model_kwds) + + self.classes_ = np.sort(np.unique(y)) + n_classes = len(self.classes_) + n_samples, self.n_features_ = X.shape + + class_membership_masks = [(y == c) for c in self.classes_] + + self.class_prior_ = np.array([np.float(mask.sum()) / n_samples + for mask in class_membership_masks]) + self.estimators_ = [clone(estimator).fit(X[mask]) + for mask in class_membership_masks] + return self + + def _joint_log_likelihood(self, X): + """Compute the per-class log likelihood of each sample + + Parameters + ---------- + X : array_like + Array of samples on which to compute likelihoods. Shape is + (n_samples, n_features) + + Returns + ------- + logL : array_like + The log likelihood under each class. + Shape is (n_samples, n_classes). logL[i, j] gives the log + likelihood of X[i] within the model representing the class + self.classes_[j]. + """ + X = array2d(X) + + # GMM API, in particular score() and score_samples(), is + # not consistent with the rest of the package. This needs + # to be addressed eventually... + if isinstance(self.estimators_[0], GMM): + return np.array([np.log(prior) + dens.score(X) + for (prior, dens) + in zip(self.class_prior_, + self.estimators_)]).T + else: + return np.array([np.log(prior) + dens.score_samples(X) + for (prior, dens) + in zip(self.class_prior_, + self.estimators_)]).T + + def sample(self, n_samples=1, random_state=None): + """Generate random samples from the model. + + Parameters + ---------- + n_samples : int, optional + Number of samples to generate. Defaults to 1. + + Returns + ------- + X : array_like, shape (n_samples, n_features) + List of samples + y : array_like, shape (n_samples,) + List of class labels for the generated samples + """ + random_state = check_random_state(random_state) + X = np.empty((n_samples, self.n_features_)) + rand = random_state.rand(n_samples) + + # split samples by class + prior_cdf = np.cumsum(self.class_prior_) + labels = prior_cdf.searchsorted(rand) + + # for each class, generate all needed samples + for i, model in enumerate(self.estimators_): + model_mask = (labels == i) + N_model = model_mask.sum() + if N_model > 0: + X[model_mask] = model.sample(N_model, + random_state=random_state) + + return X, self.classes_[labels] diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 5ff2953896091..078d554272a29 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -5,7 +5,7 @@ import warnings -from sklearn.datasets import load_digits +from sklearn.datasets import load_digits, make_blobs from sklearn.cross_validation import cross_val_score from sklearn.utils.testing import assert_almost_equal @@ -14,8 +14,10 @@ from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_raises from sklearn.utils.testing import assert_greater +from sklearn.utils.testing import assert_allclose -from sklearn.naive_bayes import GaussianNB, BernoulliNB, MultinomialNB +from sklearn.naive_bayes import (GaussianNB, BernoulliNB, MultinomialNB, + GenerativeBayes) # Data is just 6 separable points in the plane X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]) @@ -31,6 +33,14 @@ X2 = rng.randint(5, size=(6, 100)) y2 = np.array([1, 1, 2, 2, 3, 3]) +# model keywords for GenerativeBayes classification +MODEL_KWARGS = {'normal_approximation': {}, + 'gmm': {'n_components': 3, + 'covariance_type': 'diag'}, + 'kde': {'bandwidth': 4.0, + 'kernel': 'gaussian', + 'metric': 'euclidean'}} + def test_gnb(): """ @@ -356,3 +366,90 @@ def test_check_accuracy_on_digits(): scores = cross_val_score(GaussianNB(), X_3v8, y_3v8, cv=10) assert_greater(scores.mean(), 0.86) + + # Generative Bayes + scores_cmp = {'kde': (0.96, 0.98), + 'normal_approximation': (0.79, 0.79), + 'gmm': (0.84, 0.92)} + + for model, kwargs in MODEL_KWARGS.iteritems(): + scores = cross_val_score(GenerativeBayes(model, kwargs), + X, y, cv=3) + assert_greater(scores.mean(), scores_cmp[model][0]) + + scores = cross_val_score(GenerativeBayes(model, kwargs), + X_3v8, y_3v8, cv=3) + assert_greater(scores.mean(), scores_cmp[model][1]) + + +def test_compare_generative_gnb(): + """Compare GenerativeBayes to GaussianNB""" + # using normal_approximation, the two should yield an identical model. + clf1 = GenerativeBayes('normal_approximation') + clf2 = GaussianNB() + + p1 = clf1.fit(X, y).predict_proba(X) + p2 = clf2.fit(X, y).predict_proba(X) + + assert_array_almost_equal(p1, p2) + + +def test_generative_model_prior(): + """Test whether class priors are properly set.""" + # class priors should sum to 1. + for model, kwargs in MODEL_KWARGS.iteritems(): + clf = GenerativeBayes(model, kwargs) + + clf.fit(X, y) + assert_array_almost_equal(np.array([3, 3]) / 6.0, + clf.class_prior_, 8) + + clf.fit(X1, y1) + assert_array_almost_equal(clf.class_prior_.sum(), 1) + + +def test_generate_samples(): + # create a simple unbalanced dataset with 4 classes + np.random.seed(0) + X1, y1 = make_blobs(50, 2, centers=2) + X2, y2 = make_blobs(100, 2, centers=2) + X = np.vstack([X1, X2]) + y = np.concatenate([y1, y2 + 2]) + + # test with normal_approximation; other models have their sample() method + # tested independently + clf = GenerativeBayes('normal_approximation') + clf.fit(X, y) + + X_new, y_new = clf.sample(2000) + + for i in range(4): + Xnew_i = X_new[y_new == i] + X_i = X[y == i] + + # check the means + assert_array_almost_equal(Xnew_i.mean(0), + X_i.mean(0), decimal=1) + + # check the standard deviations + assert_array_almost_equal(Xnew_i.std(0), + X_i.std(0), decimal=1) + + # check the number of points + assert_allclose(X_i.shape[0] * 1. / X.shape[0], + Xnew_i.shape[0] * 1. / X_new.shape[0], + rtol=0.1) + + +def test_generative_bayes_invalid(): + # invalid string + assert_raises(ValueError, GenerativeBayes, 'not_a_valid_arg') + + # passing class rather than instance + from sklearn.mixture import GMM + assert_raises(TypeError, GenerativeBayes, GMM) + + # passing a non-density estimator + from sklearn.svm import SVC + assert_raises(TypeError, GenerativeBayes, SVC()) +