From a571b01ac8a92504c522a860993428c0f46d7b29 Mon Sep 17 00:00:00 2001
From: "Michael A. Alcorn"
Date: Mon, 28 Aug 2017 08:31:45 -0500
Subject: [PATCH] ENH Implement Complement Naive Bayes (#8190)

---
 doc/modules/classes.rst                     |  1 +
 doc/modules/naive_bayes.rst                 | 40 ++++++++
 doc/whats_new.rst                           |  4 +
 .../document_classification_20newsgroups.py |  3 +-
 sklearn/naive_bayes.py                      | 93 ++++++++++++++++++-
 sklearn/tests/test_naive_bayes.py           | 67 ++++++++++++-
 sklearn/utils/estimator_checks.py           | 18 ++--
 7 files changed, 214 insertions(+), 12 deletions(-)

diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 128f1c85f13e2..0fd3d6e82b180 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -1081,6 +1081,7 @@ Model validation
    naive_bayes.BernoulliNB
    naive_bayes.GaussianNB
    naive_bayes.MultinomialNB
+   naive_bayes.ComplementNB
 
 .. _neighbors_ref:
 
diff --git a/doc/modules/naive_bayes.rst b/doc/modules/naive_bayes.rst
index 7d83ba38d1e71..bbf8e31571ade 100644
--- a/doc/modules/naive_bayes.rst
+++ b/doc/modules/naive_bayes.rst
@@ -133,6 +133,46 @@ in further computations.
 Setting :math:`\alpha = 1` is called Laplace smoothing,
 while :math:`\alpha < 1` is called Lidstone smoothing.
 
+.. _complement_naive_bayes:
+
+Complement Naive Bayes
+----------------------
+
+:class:`ComplementNB` implements the complement naive Bayes (CNB) algorithm.
+CNB is an adaptation of the standard multinomial naive Bayes (MNB) algorithm
+that is particularly suited for imbalanced data sets. Specifically, CNB uses
+statistics from the *complement* of each class to compute the model's weights.
+The inventors of CNB show empirically that the parameter estimates for CNB are
+more stable than those for MNB. Further, CNB regularly outperforms MNB (often
+by a considerable margin) on text classification tasks. The procedure for
+calculating the weights is as follows:
+
+.. math::
+
+    \hat{\theta}_{ci} = \frac{\sum_{j:y_j \neq c} d_{ij} + \alpha_i}
+                             {\sum_{j:y_j \neq c} \sum_{k} d_{kj} + \alpha}
+
+    w_{ci} = \log \hat{\theta}_{ci}
+
+    w_{ci} = \frac{w_{ci}}{\sum_{j} w_{cj}}
+
+where the summations are over all documents :math:`j` not in class :math:`c`,
+:math:`d_{ij}` is either the count or tf-idf value of term :math:`i` in
+document :math:`j`, :math:`\alpha_i` is a smoothing hyperparameter like that
+found in MNB, and :math:`\alpha = \sum_{i} \alpha_i`. The second normalization
+addresses the tendency of longer documents to dominate parameter estimates in
+MNB. The classification rule is:
+
+.. math::
+
+    \hat{c} = \arg\min_c \sum_{i} t_i w_{ci}
+
+i.e., a document is assigned to the class that is the *poorest* complement
+match.
+
+.. topic:: References:
+
+ * Rennie, J. D., Shih, L., Teevan, J., & Karger, D. R. (2003).
+   `Tackling the poor assumptions of naive bayes text classifiers.
+   <http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf>`_
+   In ICML (Vol. 3, pp. 616-623).
 
 .. _bernoulli_naive_bayes:
 
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 2bc793bfbd459..01e3c06fd17e0 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -145,6 +145,10 @@ Classifiers and regressors
   during the first epochs of ridge and logistic regression.
   :issue:`8446` by `Arthur Mensch`_.
 
+- Added :class:`naive_bayes.ComplementNB`, which implements the Complement
+  Naive Bayes classifier described in Rennie et al. (2003).
+  By :user:`Michael A. Alcorn `.
+
 Other estimators
 
 - Added the :class:`neighbors.LocalOutlierFactor` class for anomaly
diff --git a/examples/text/document_classification_20newsgroups.py b/examples/text/document_classification_20newsgroups.py
index 22b559e56e7fd..8876dd776481a 100644
--- a/examples/text/document_classification_20newsgroups.py
+++ b/examples/text/document_classification_20newsgroups.py
@@ -42,7 +42,7 @@ from sklearn.linear_model import SGDClassifier
 from sklearn.linear_model import Perceptron
 from sklearn.linear_model import PassiveAggressiveClassifier
-from sklearn.naive_bayes import BernoulliNB, MultinomialNB
+from sklearn.naive_bayes import BernoulliNB, ComplementNB, MultinomialNB
 from sklearn.neighbors import KNeighborsClassifier
 from sklearn.neighbors import NearestCentroid
 from sklearn.ensemble import RandomForestClassifier
@@ -283,6 +283,7 @@ def benchmark(clf):
 print("Naive Bayes")
 results.append(benchmark(MultinomialNB(alpha=.01)))
 results.append(benchmark(BernoulliNB(alpha=.01)))
+results.append(benchmark(ComplementNB(alpha=.1)))
 
 print('=' * 80)
 print("LinearSVC with L1-based feature selection")
diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py
index c324a98083e51..8e4bda8a9fabc 100644
--- a/sklearn/naive_bayes.py
+++ b/sklearn/naive_bayes.py
@@ -33,7 +33,7 @@
 from .utils.validation import check_is_fitted
 from .externals import six
 
-__all__ = ['BernoulliNB', 'GaussianNB', 'MultinomialNB']
+__all__ = ['BernoulliNB', 'GaussianNB', 'MultinomialNB', 'ComplementNB']
 
 
 class BaseNB(six.with_metaclass(ABCMeta, BaseEstimator, ClassifierMixin)):
@@ -726,6 +726,97 @@ def _joint_log_likelihood(self, X):
                 self.class_log_prior_)
 
 
+class ComplementNB(BaseDiscreteNB):
+    """The Complement Naive Bayes classifier described in Rennie et al. (2003).
+
+    The Complement Naive Bayes classifier was designed to correct the "severe
+    assumptions" made by the standard Multinomial Naive Bayes classifier. It
+    is particularly suited for imbalanced data sets.
+
+    Read more in the :ref:`User Guide <complement_naive_bayes>`.
+
+    Parameters
+    ----------
+    alpha : float, optional (default=1.0)
+        Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
+
+    fit_prior : boolean, optional (default=True)
+        Only used in edge case with a single class in the training set.
+
+    class_prior : array-like, size (n_classes,), optional (default=None)
+        Prior probabilities of the classes. Not used.
+
+    Attributes
+    ----------
+    class_log_prior_ : array, shape (n_classes,)
+        Smoothed empirical log probability for each class. Only used in edge
+        case with a single class in the training set.
+
+    feature_log_prob_ : array, shape (n_classes, n_features)
+        Empirical weights for class complements.
+
+    class_count_ : array, shape (n_classes,)
+        Number of samples encountered for each class during fitting. This
+        value is weighted by the sample weight when provided.
+
+    feature_count_ : array, shape (n_classes, n_features)
+        Number of samples encountered for each (class, feature) during
+        fitting. This value is weighted by the sample weight when provided.
+
+    feature_all_ : array, shape (n_features,)
+        Number of samples encountered for each feature during fitting. This
+        value is weighted by the sample weight when provided.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> X = np.random.randint(5, size=(6, 100))
+    >>> y = np.array([1, 2, 3, 4, 5, 6])
+    >>> from sklearn.naive_bayes import ComplementNB
+    >>> clf = ComplementNB()
+    >>> clf.fit(X, y)
+    ComplementNB(alpha=1.0, class_prior=None, fit_prior=True)
+    >>> print(clf.predict(X[2:3]))
+    [3]
+
+    References
+    ----------
+    Rennie, J. D., Shih, L., Teevan, J., & Karger, D. R. (2003).
+    Tackling the poor assumptions of naive bayes text classifiers. In ICML
+    (Vol. 3, pp. 616-623).
+    http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf
+    """
+
+    def __init__(self, alpha=1.0, fit_prior=True, class_prior=None):
+        self.alpha = alpha
+        self.fit_prior = fit_prior
+        self.class_prior = class_prior
+
+    def _count(self, X, Y):
+        """Count feature occurrences."""
+        if np.any((X.data if issparse(X) else X) < 0):
+            raise ValueError("Input X must be non-negative")
+        self.feature_count_ += safe_sparse_dot(Y.T, X)
+        self.class_count_ += Y.sum(axis=0)
+        self.feature_all_ = self.feature_count_.sum(axis=0)
+
+    def _update_feature_log_prob(self, alpha):
+        """Apply smoothing to raw counts and compute the weights."""
+        comp_count = self.feature_all_ + alpha - self.feature_count_
+        logged = np.log(comp_count / comp_count.sum(axis=1, keepdims=True))
+        self.feature_log_prob_ = logged / logged.sum(axis=1, keepdims=True)
+
+    def _joint_log_likelihood(self, X):
+        """Calculate the class scores for the samples in X."""
+        check_is_fitted(self, "classes_")
+
+        X = check_array(X, accept_sparse="csr")
+        jll = safe_sparse_dot(X, self.feature_log_prob_.T)
+        if len(self.classes_) == 1:
+            jll += self.class_log_prior_
+        return jll
+
+
 class BernoulliNB(BaseDiscreteNB):
     """Naive Bayes classifier for multivariate Bernoulli models.
diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py
index f43ddf0a0c553..e5b0a0b3eae6a 100644
--- a/sklearn/tests/test_naive_bayes.py
+++ b/sklearn/tests/test_naive_bayes.py
@@ -1,3 +1,5 @@
+from __future__ import division
+
 import pickle
 from io import BytesIO
 import numpy as np
@@ -18,7 +20,8 @@ from sklearn.utils.testing import assert_greater
 from sklearn.utils.testing import assert_warns
 
-from sklearn.naive_bayes import GaussianNB, BernoulliNB, MultinomialNB
+from sklearn.naive_bayes import GaussianNB, BernoulliNB
+from sklearn.naive_bayes import MultinomialNB, ComplementNB
 
 # Data is just 6 separable points in the plane
 X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]])
@@ -530,6 +533,68 @@ def test_bnb():
     assert_array_almost_equal(clf.predict_proba(X_test), predict_proba)
 
 
+def test_cnb():
+    # Tests ComplementNB when alpha=1.0 for the toy example in Manning,
+    # Raghavan, and Schuetze's "Introduction to Information Retrieval" book:
+    # http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html
+
+    # Training data points are:
+    # Chinese Beijing Chinese (class: China)
+    # Chinese Chinese Shanghai (class: China)
+    # Chinese Macao (class: China)
+    # Tokyo Japan Chinese (class: Japan)
+
+    # Features are Beijing, Chinese, Japan, Macao, Shanghai, and Tokyo.
+    X = np.array([[1, 1, 0, 0, 0, 0],
+                  [0, 1, 0, 0, 1, 0],
+                  [0, 1, 0, 1, 0, 0],
+                  [0, 1, 1, 0, 0, 1]])
+
+    # Classes are China (0), Japan (1).
+    Y = np.array([0, 0, 0, 1])
+
+    # Verify inputs are nonnegative.
+    clf = ComplementNB(alpha=1.0)
+    assert_raises(ValueError, clf.fit, -X, Y)
+
+    clf.fit(X, Y)
+
+    # Check that counts are correct.
+    feature_count = np.array([[1, 3, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1]])
+    assert_array_equal(clf.feature_count_, feature_count)
+    class_count = np.array([3, 1])
+    assert_array_equal(clf.class_count_, class_count)
+    feature_all = np.array([1, 4, 1, 1, 1, 1])
+    assert_array_equal(clf.feature_all_, feature_all)
+
+    # Check that weights are correct. See steps 4-6 in Table 4 of
+    # Rennie et al. (2003).
+    theta = np.array([
+        [
+            (0 + 1) / (3 + 6),
+            (1 + 1) / (3 + 6),
+            (1 + 1) / (3 + 6),
+            (0 + 1) / (3 + 6),
+            (0 + 1) / (3 + 6),
+            (1 + 1) / (3 + 6)
+        ],
+        [
+            (1 + 1) / (6 + 6),
+            (3 + 1) / (6 + 6),
+            (0 + 1) / (6 + 6),
+            (1 + 1) / (6 + 6),
+            (1 + 1) / (6 + 6),
+            (0 + 1) / (6 + 6)
+        ]])
+
+    weights = np.zeros(theta.shape)
+    for i in range(2):
+        weights[i] = np.log(theta[i])
+        weights[i] /= weights[i].sum()
+
+    assert_array_equal(clf.feature_log_prob_, weights)
+
+
 def test_naive_bayes_scale_invariance():
     # Scaling the data should not change the prediction results
     iris = load_iris()
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index c3b066e5e31be..99faee5737818 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -115,12 +115,12 @@ def _yield_classifier_checks(name, classifier):
     # basic consistency testing
     yield check_classifiers_train
     yield check_classifiers_regression_target
-    if (name not in
-        ["MultinomialNB", "LabelPropagation", "LabelSpreading"] and
+    if (name not in ["MultinomialNB", "ComplementNB", "LabelPropagation",
+                     "LabelSpreading"] and
         # TODO some complication with -1 label
-        name not in ["DecisionTreeClassifier", "ExtraTreeClassifier"]):
-        # We don't raise a warning in these classifiers, as
-        # the column y interface is used by the forests.
+            name not in ["DecisionTreeClassifier", "ExtraTreeClassifier"]):
+            # We don't raise a warning in these classifiers, as
+            # the column y interface is used by the forests.
         yield check_supervised_y_2d
 
     # test if NotFittedError is raised
@@ -1088,7 +1088,7 @@ def check_classifiers_train(name, classifier_orig):
     n_classes = len(classes)
     n_samples, n_features = X.shape
     classifier = clone(classifier_orig)
-    if name in ['BernoulliNB', 'MultinomialNB']:
+    if name in ['BernoulliNB', 'MultinomialNB', 'ComplementNB']:
         X -= X.min()
     set_random_state(classifier)
     # raises error on malformed input for fit
@@ -1102,7 +1102,7 @@ def check_classifiers_train(name, classifier_orig):
     y_pred = classifier.predict(X)
     assert_equal(y_pred.shape, (n_samples,))
     # training set performance
-    if name not in ['BernoulliNB', 'MultinomialNB']:
+    if name not in ['BernoulliNB', 'MultinomialNB', 'ComplementNB']:
         assert_greater(accuracy_score(y, y_pred), 0.83)
 
     # raises error on malformed input for predict
@@ -1245,8 +1245,8 @@ def check_classifiers_classes(name, classifier_orig):
     classes = np.unique(y_)
 
     classifier = clone(classifier_orig)
-    if name == 'BernoulliNB':
-        classifier.set_params(binarize=X.mean())
+    if name in ['BernoulliNB', 'ComplementNB']:
+        X = X > X.mean()
     set_random_state(classifier)
     # fit
     classifier.fit(X, y_)
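
For readers who want to trace the arithmetic behind _count and _update_feature_log_prob (and the weight equations added to doc/modules/naive_bayes.rst), the following plain-NumPy sketch reproduces the computation on the toy corpus used in test_cnb. It is illustrative only and not part of the patch; the variable names mirror the estimator's attributes but are otherwise arbitrary.

import numpy as np

# Toy corpus from test_cnb: columns are Beijing, Chinese, Japan, Macao,
# Shanghai, Tokyo; classes are China (0) and Japan (1).
X = np.array([[1., 1., 0., 0., 0., 0.],
              [0., 1., 0., 0., 1., 0.],
              [0., 1., 0., 1., 0., 0.],
              [0., 1., 1., 0., 0., 1.]])
y = np.array([0, 0, 0, 1])
alpha = 1.0
classes = np.unique(y)

# Per-class and total feature counts, as accumulated in _count().
feature_count = np.array([X[y == c].sum(axis=0) for c in classes])
feature_all = feature_count.sum(axis=0)

# Smoothed complement counts and normalized log weights, as in
# _update_feature_log_prob().
comp_count = feature_all + alpha - feature_count
logged = np.log(comp_count / comp_count.sum(axis=1, keepdims=True))
weights = logged / logged.sum(axis=1, keepdims=True)

# _joint_log_likelihood() returns X.dot(weights.T) and BaseNB.predict takes
# its argmax over classes; on this toy data that recovers the training labels.
print(X.dot(weights.T).argmax(axis=1))  # [0 0 0 1]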
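
And a minimal end-to-end usage sketch, assuming a scikit-learn build with this patch applied so that ComplementNB is importable from sklearn.naive_bayes; the five-document corpus and its string labels are made up for illustration, and alpha is left at its default of 1.0.

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import ComplementNB  # available once this patch is applied

# Deliberately imbalanced toy corpus (made up; it echoes the IR-book example
# used in test_cnb): four "china" documents, one "japan" document.
docs = ["Chinese Beijing Chinese",
        "Chinese Chinese Shanghai",
        "Chinese Macao",
        "Beijing Chinese Chinese",
        "Tokyo Japan Chinese"]
labels = ["china", "china", "china", "china", "japan"]

vectorizer = CountVectorizer()
counts = vectorizer.fit_transform(docs)  # sparse term-count matrix

clf = ComplementNB()
clf.fit(counts, labels)

new_docs = ["Tokyo Tokyo Japan", "Beijing Shanghai Macao"]
print(clf.predict(vectorizer.transform(new_docs)))  # expected: ['japan' 'china']

Because the weights for each class are estimated from that class's complement, the single minority-class ("japan") document borrows its statistics from the four majority-class documents, which is exactly the imbalanced-data situation CNB is designed to handle better than MNB.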