From 07e95e41b60f2752fa66d6439795456e1e950d51 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 10 Aug 2017 12:28:15 +0200 Subject: [PATCH 01/11] EHN add BalancedBaggingClassifier --- imblearn/ensemble/__init__.py | 4 +- imblearn/ensemble/easy_ensemble.py | 203 +++++++- imblearn/ensemble/tests/test_easy_ensemble.py | 442 +++++++++++++++++- 3 files changed, 644 insertions(+), 5 deletions(-) diff --git a/imblearn/ensemble/__init__.py b/imblearn/ensemble/__init__.py index 6c17409e5..13e148991 100644 --- a/imblearn/ensemble/__init__.py +++ b/imblearn/ensemble/__init__.py @@ -3,7 +3,7 @@ under-sampled subsets combined inside an ensemble. """ -from .easy_ensemble import EasyEnsemble +from .easy_ensemble import EasyEnsemble, BalancedBaggingClassifier from .balance_cascade import BalanceCascade -__all__ = ['EasyEnsemble', 'BalanceCascade'] +__all__ = ['EasyEnsemble', 'BalancedBaggingClassifier', 'BalanceCascade'] diff --git a/imblearn/ensemble/easy_ensemble.py b/imblearn/ensemble/easy_ensemble.py index ffa69ded3..2f80220fa 100644 --- a/imblearn/ensemble/easy_ensemble.py +++ b/imblearn/ensemble/easy_ensemble.py @@ -4,16 +4,44 @@ # Christos Aridas # License: MIT +import numbers + import numpy as np -from sklearn.utils import check_random_state +import sklearn +from sklearn.base import clone +from sklearn.ensemble import BaggingClassifier +from sklearn.ensemble.bagging import _generate_bagging_indices +from sklearn.tree import DecisionTreeClassifier +from sklearn.utils import check_random_state, indices_to_mask from .base import BaseEnsembleSampler +from ..pipeline import Pipeline from ..under_sampling import RandomUnderSampler MAX_INT = np.iinfo(np.int32).max +old_generate = _generate_bagging_indices + + +def _masked_bagging_indices(random_state, bootstrap_features, + bootstrap_samples, n_features, n_samples, + max_features, max_samples): + """Monkey-patch to always get a mask instead of indices""" + feature_indices, sample_indices = old_generate(random_state, + bootstrap_features, + bootstrap_samples, + n_features, n_samples, + max_features, max_samples) + sample_indices = indices_to_mask(sample_indices, n_samples) + + return feature_indices, sample_indices + + +sklearn.ensemble.bagging._generate_bagging_indices = _masked_bagging_indices + + class EasyEnsemble(BaseEnsembleSampler): """Create an ensemble sets by iteratively applying random under-sampling. @@ -147,3 +175,176 @@ def _sample(self, X, y): np.array(idx_under)) else: return np.array(X_resampled), np.array(y_resampled) + + +class BalancedBaggingClassifier(BaggingClassifier): + """A Bagging classifier with additional balancing. + + This implementation of Bagging is similar to the scikit-learn + implementation. It includes an additional step to balance the training set + at fit time using a ``RandomUnderSampler``. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + base_estimator : object or None, optional (default=None) + The base estimator to fit on random subsets of the dataset. + If None, then the base estimator is a decision tree. + + n_estimators : int, optional (default=10) + The number of base estimators in the ensemble. + + max_samples : int or float, optional (default=1.0) + The number of samples to draw from X to train each base estimator. + - If int, then draw `max_samples` samples. + - If float, then draw `max_samples * X.shape[0]` samples. + + max_features : int or float, optional (default=1.0) + The number of features to draw from X to train each base estimator. 
+ - If int, then draw `max_features` features. + - If float, then draw `max_features * X.shape[1]` features. + + bootstrap : boolean, optional (default=True) + Whether samples are drawn with replacement. + + bootstrap_features : boolean, optional (default=False) + Whether features are drawn with replacement. + + oob_score : bool + Whether to use out-of-bag samples to estimate + the generalization error. + + warm_start : bool, optional (default=False) + When set to True, reuse the solution of the previous call to fit + and add more estimators to the ensemble, otherwise, just fit + a whole new ensemble. + .. versionadded:: 0.17 + *warm_start* constructor parameter. + + ratio : str, dict, or callable, optional (default='auto') + Ratio to use for resampling the data set. + + - If ``str``, has to be one of: (i) ``'minority'``: resample the + minority class; (ii) ``'majority'``: resample the majority class, + (iii) ``'not minority'``: resample all classes apart of the minority + class, (iv) ``'all'``: resample all classes, and (v) ``'auto'``: + correspond to ``'all'`` with for over-sampling methods and ``'not + minority'`` for under-sampling methods. The classes targeted will be + over-sampled or under-sampled to achieve an equal number of sample + with the majority or minority class. + - If ``dict``, the keys correspond to the targeted classes. The values + correspond to the desired number of samples. + - If callable, function taking ``y`` and returns a ``dict``. The keys + correspond to the targeted classes. The values correspond to the + desired number of samples. + + replacement : bool, optional (default=False) + Whether or not to sample randomly with replacement or not. + + n_jobs : int, optional (default=1) + The number of jobs to run in parallel for both `fit` and `predict`. + If -1, then the number of jobs is set to the number of cores. + + random_state : int, RandomState instance or None, optional (default=None) + If int, random_state is the seed used by the random number generator; + If RandomState instance, random_state is the random number generator; + If None, the random number generator is the RandomState instance used + by `np.random`. + + verbose : int, optional (default=0) + Controls the verbosity of the building process. + + Attributes + ---------- + base_estimator_ : estimator + The base estimator from which the ensemble is grown. + + estimators_ : list of estimators + The collection of fitted base estimators. + + estimators_samples_ : list of arrays + The subset of drawn samples (i.e., the in-bag samples) for each base + estimator. Each subset is defined by a boolean mask. + + estimators_features_ : list of arrays + The subset of drawn features for each base estimator. + + classes_ : array of shape = [n_classes] + The classes labels. + + n_classes_ : int or list + The number of classes. + + oob_score_ : float + Score of the training dataset obtained using an out-of-bag estimate. + + oob_decision_function_ : array of shape = [n_samples, n_classes] + Decision function computed with out-of-bag estimate on the training + set. If n_estimators is small it might be possible that a data point + was never left out during the bootstrap. In this case, + `oob_decision_function_` might contain NaN. + + References + ---------- + .. [1] L. Breiman, "Pasting small votes for classification in large + databases and on-line", Machine Learning, 36(1), 85-103, 1999. + .. [2] L. Breiman, "Bagging predictors", Machine Learning, 24(2), 123-140, + 1996. + .. [3] T. 
Ho, "The random subspace method for constructing decision + forests", Pattern Analysis and Machine Intelligence, 20(8), 832-844, + 1998. + .. [4] G. Louppe and P. Geurts, "Ensembles on Random Patches", Machine + Learning and Knowledge Discovery in Databases, 346-361, 2012. + + """ + def __init__(self, + base_estimator=None, + n_estimators=10, + max_samples=1.0, + max_features=1.0, + bootstrap=True, + bootstrap_features=False, + oob_score=False, + warm_start=False, + ratio='auto', + replacement=False, + n_jobs=1, + random_state=None, + verbose=0): + + super(BaggingClassifier, self).__init__( + base_estimator, + n_estimators=n_estimators, + max_samples=max_samples, + max_features=max_features, + bootstrap=bootstrap, + bootstrap_features=bootstrap_features, + oob_score=oob_score, + warm_start=warm_start, + n_jobs=n_jobs, + random_state=random_state, + verbose=verbose) + self.ratio = ratio + self.replacement = replacement + + def _validate_estimator(self, default=DecisionTreeClassifier()): + """Check the estimator and the n_estimator attribute, set the + `base_estimator_` attribute.""" + if not isinstance(self.n_estimators, (numbers.Integral, np.integer)): + raise ValueError("n_estimators must be an integer, " + "got {0}.".format(type(self.n_estimators))) + + if self.n_estimators <= 0: + raise ValueError("n_estimators must be greater than zero, " + "got {0}.".format(self.n_estimators)) + + if self.base_estimator is not None: + base_estimator = clone(self.base_estimator) + else: + base_estimator = clone(default) + + self.base_estimator_ = Pipeline( + [('sampler', RandomUnderSampler(ratio=self.ratio, + replacement=self.replacement)), + ('classifier', base_estimator)]) diff --git a/imblearn/ensemble/tests/test_easy_ensemble.py b/imblearn/ensemble/tests/test_easy_ensemble.py index 565bd7b83..6b3de9e44 100644 --- a/imblearn/ensemble/tests/test_easy_ensemble.py +++ b/imblearn/ensemble/tests/test_easy_ensemble.py @@ -6,9 +6,28 @@ from __future__ import print_function import numpy as np -from sklearn.utils.testing import assert_array_equal, assert_equal -from imblearn.ensemble import EasyEnsemble +from sklearn.datasets import load_iris, make_hastie_10_2 +from sklearn.model_selection import (GridSearchCV, ParameterGrid, + train_test_split) +from sklearn.dummy import DummyClassifier +from sklearn.linear_model import Perceptron, LogisticRegression +from sklearn.tree import DecisionTreeClassifier +from sklearn.neighbors import KNeighborsClassifier +from sklearn.svm import SVC +from sklearn.feature_selection import SelectKBest +from sklearn.utils.testing import (assert_array_equal, assert_equal, + assert_array_almost_equal, assert_less, + assert_warns, assert_raises, assert_false, + assert_greater, assert_true, + assert_warns_message) + +from imblearn.datasets import make_imbalance +from imblearn.ensemble import EasyEnsemble, BalancedBaggingClassifier +from imblearn.pipeline import make_pipeline +from imblearn.under_sampling import RandomUnderSampler + +iris = load_iris() # Generate a global dataset to use RND_SEED = 0 @@ -97,3 +116,422 @@ def test_random_state_none(): # Get the different subset X_resampled, y_resampled = ee.fit_sample(X, Y) + + +def test_balanced_bagging_classifier(): + # Check classification for various parameter settings. 
+ X, y = make_imbalance(iris.data, iris.target, ratio={0: 20, 1: 25, 2: 50}, + random_state=0) + X_train, X_test, y_train, y_test = train_test_split(X, y, + random_state=0) + grid = ParameterGrid({"max_samples": [0.5, 1.0], + "max_features": [1, 2, 4], + "bootstrap": [True, False], + "bootstrap_features": [True, False]}) + + for base_estimator in [None, + DummyClassifier(), + Perceptron(), + DecisionTreeClassifier(), + KNeighborsClassifier(), + SVC()]: + for params in grid: + BalancedBaggingClassifier( + base_estimator=base_estimator, + random_state=0, + **params).fit(X_train, y_train).predict(X_test) + + +def test_bootstrap_samples(): + # Test that bootstrapping samples generate non-perfect base estimators. + X, y = make_imbalance(iris.data, iris.target, ratio={0: 20, 1: 25, 2: 50}, + random_state=0) + X_train, X_test, y_train, y_test = train_test_split(X, y, + random_state=0) + + base_estimator = DecisionTreeClassifier().fit(X_train, y_train) + + # without bootstrap, all trees are perfect on the training set + # disable the resampling by passing an empty dictionary. + ensemble = BalancedBaggingClassifier( + base_estimator=DecisionTreeClassifier(), + max_samples=1.0, + bootstrap=False, + n_estimators=10, + ratio={}, + random_state=0).fit(X_train, y_train) + + assert_equal(base_estimator.score(X_train, y_train), + ensemble.score(X_train, y_train)) + + # with bootstrap, trees are no longer perfect on the training set + ensemble = BalancedBaggingClassifier( + base_estimator=DecisionTreeClassifier(), + max_samples=1.0, + bootstrap=True, + random_state=0).fit(X_train, y_train) + + assert_greater(base_estimator.score(X_train, y_train), + ensemble.score(X_train, y_train)) + + +def test_bootstrap_features(): + # Test that bootstrapping features may generate duplicate features. + X, y = make_imbalance(iris.data, iris.target, ratio={0: 20, 1: 25, 2: 50}, + random_state=0) + X_train, X_test, y_train, y_test = train_test_split(X, y, + random_state=0) + + ensemble = BalancedBaggingClassifier( + base_estimator=DecisionTreeClassifier(), + max_features=1.0, + bootstrap_features=False, + random_state=0).fit(X_train, y_train) + + for features in ensemble.estimators_features_: + assert_equal(X.shape[1], np.unique(features).shape[0]) + + ensemble = BalancedBaggingClassifier( + base_estimator=DecisionTreeClassifier(), + max_features=1.0, + bootstrap_features=True, + random_state=0).fit(X_train, y_train) + + unique_features = [np.unique(features).shape[0] + for features in ensemble.estimators_features_] + assert_greater(X.shape[1], np.median(unique_features)) + + +def test_probability(): + # Predict probabilities. 
+ X, y = make_imbalance(iris.data, iris.target, ratio={0: 20, 1: 25, 2: 50}, + random_state=0) + X_train, X_test, y_train, y_test = train_test_split(X, y, + random_state=0) + + with np.errstate(divide="ignore", invalid="ignore"): + # Normal case + ensemble = BalancedBaggingClassifier( + base_estimator=DecisionTreeClassifier(), + random_state=0).fit(X_train, y_train) + + assert_array_almost_equal(np.sum(ensemble.predict_proba(X_test), + axis=1), + np.ones(len(X_test))) + + assert_array_almost_equal(ensemble.predict_proba(X_test), + np.exp(ensemble.predict_log_proba(X_test))) + + # Degenerate case, where some classes are missing + ensemble = BalancedBaggingClassifier( + base_estimator=LogisticRegression(), + random_state=0, + max_samples=5).fit(X_train, y_train) + + assert_array_almost_equal(np.sum(ensemble.predict_proba(X_test), + axis=1), + np.ones(len(X_test))) + + assert_array_almost_equal(ensemble.predict_proba(X_test), + np.exp(ensemble.predict_log_proba(X_test))) + + +def test_oob_score_classification(): + # Check that oob prediction is a good estimation of the generalization + # error. + X, y = make_imbalance(iris.data, iris.target, ratio={0: 20, 1: 25, 2: 50}, + random_state=0) + X_train, X_test, y_train, y_test = train_test_split(X, y, + random_state=0) + + for base_estimator in [DecisionTreeClassifier(), SVC()]: + clf = BalancedBaggingClassifier( + base_estimator=base_estimator, + n_estimators=100, + bootstrap=True, + oob_score=True, + random_state=0).fit(X_train, y_train) + + test_score = clf.score(X_test, y_test) + + assert_less(abs(test_score - clf.oob_score_), 0.1) + + # Test with few estimators + assert_warns(UserWarning, + BalancedBaggingClassifier( + base_estimator=base_estimator, + n_estimators=1, + bootstrap=True, + oob_score=True, + random_state=0).fit, + X_train, + y_train) + + +def test_single_estimator(): + # Check singleton ensembles. + X, y = make_imbalance(iris.data, iris.target, ratio={0: 20, 1: 25, 2: 50}, + random_state=0) + X_train, X_test, y_train, y_test = train_test_split(X, y, + random_state=0) + + clf1 = BalancedBaggingClassifier( + base_estimator=KNeighborsClassifier(), + n_estimators=1, + bootstrap=False, + bootstrap_features=False, + random_state=0).fit(X_train, y_train) + + clf2 = make_pipeline(RandomUnderSampler( + random_state=clf1.estimators_[0].steps[0][1].random_state), + KNeighborsClassifier()).fit(X_train, y_train) + + assert_array_equal(clf1.predict(X_test), clf2.predict(X_test)) + + +def test_error(): + # Test that it gives proper exception on deficient input. 
+ X, y = make_imbalance(iris.data, iris.target, ratio={0: 20, 1: 25, 2: 50}) + base = DecisionTreeClassifier() + + # Test max_samples + assert_raises(ValueError, + BalancedBaggingClassifier(base, max_samples=-1).fit, X, y) + assert_raises(ValueError, + BalancedBaggingClassifier(base, max_samples=0.0).fit, X, y) + assert_raises(ValueError, + BalancedBaggingClassifier(base, max_samples=2.0).fit, X, y) + assert_raises(ValueError, + BalancedBaggingClassifier(base, max_samples=1000).fit, X, y) + assert_raises(ValueError, + BalancedBaggingClassifier(base, max_samples="foobar").fit, + X, y) + + # Test max_features + assert_raises(ValueError, + BalancedBaggingClassifier(base, max_features=-1).fit, X, y) + assert_raises(ValueError, + BalancedBaggingClassifier(base, max_features=0.0).fit, X, y) + assert_raises(ValueError, + BalancedBaggingClassifier(base, max_features=2.0).fit, X, y) + assert_raises(ValueError, + BalancedBaggingClassifier(base, max_features=5).fit, X, y) + assert_raises(ValueError, + BalancedBaggingClassifier(base, max_features="foobar").fit, + X, y) + + # Test support of decision_function + assert_false(hasattr(BalancedBaggingClassifier(base).fit(X, y), + 'decision_function')) + + +def test_gridsearch(): + # Check that bagging ensembles can be grid-searched. + # Transform iris into a binary classification task + X, y = iris.data, iris.target.copy() + y[y == 2] = 1 + + # Grid search with scoring based on decision_function + parameters = {'n_estimators': (1, 2), + 'base_estimator__C': (1, 2)} + + GridSearchCV(BalancedBaggingClassifier(SVC()), + parameters, + scoring="roc_auc").fit(X, y) + + +def test_base_estimator(): + # Check base_estimator and its default values. + X, y = make_imbalance(iris.data, iris.target, ratio={0: 20, 1: 25, 2: 50}, + random_state=0) + X_train, X_test, y_train, y_test = train_test_split(X, y, + random_state=0) + + ensemble = BalancedBaggingClassifier(None, + n_jobs=3, + random_state=0).fit(X_train, y_train) + + assert_true(isinstance(ensemble.base_estimator_.steps[-1][1], + DecisionTreeClassifier)) + + ensemble = BalancedBaggingClassifier(DecisionTreeClassifier(), + n_jobs=3, + random_state=0).fit(X_train, y_train) + + assert_true(isinstance(ensemble.base_estimator_.steps[-1][1], + DecisionTreeClassifier)) + + ensemble = BalancedBaggingClassifier(Perceptron(), + n_jobs=3, + random_state=0).fit(X_train, y_train) + + assert_true(isinstance(ensemble.base_estimator_.steps[-1][1], + Perceptron)) + + +def test_bagging_with_pipeline(): + X, y = make_imbalance(iris.data, iris.target, ratio={0: 20, 1: 25, 2: 50}, + random_state=0) + estimator = BalancedBaggingClassifier( + make_pipeline(SelectKBest(k=1), + DecisionTreeClassifier()), + max_features=2) + estimator.fit(X, y).predict(X) + + +def test_warm_start(random_state=42): + # Test if fitting incrementally with warm start gives a forest of the + # right size and the same results as a normal fit. 
+ X, y = make_hastie_10_2(n_samples=20, random_state=1) + + clf_ws = None + for n_estimators in [5, 10]: + if clf_ws is None: + clf_ws = BalancedBaggingClassifier(n_estimators=n_estimators, + random_state=random_state, + warm_start=True) + else: + clf_ws.set_params(n_estimators=n_estimators) + clf_ws.fit(X, y) + assert_equal(len(clf_ws), n_estimators) + + clf_no_ws = BalancedBaggingClassifier(n_estimators=10, + random_state=random_state, + warm_start=False) + clf_no_ws.fit(X, y) + + assert_equal(set([pipe.steps[-1][1].random_state + for pipe in clf_ws]), + set([pipe.steps[-1][1].random_state + for pipe in clf_no_ws])) + + +def test_warm_start_smaller_n_estimators(): + # Test if warm start'ed second fit with smaller n_estimators raises error. + X, y = make_hastie_10_2(n_samples=20, random_state=1) + clf = BalancedBaggingClassifier(n_estimators=5, warm_start=True) + clf.fit(X, y) + clf.set_params(n_estimators=4) + assert_raises(ValueError, clf.fit, X, y) + + +def test_warm_start_equal_n_estimators(): + # Test that nothing happens when fitting without increasing n_estimators + X, y = make_hastie_10_2(n_samples=20, random_state=1) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=43) + + clf = BalancedBaggingClassifier(n_estimators=5, warm_start=True, + random_state=83) + clf.fit(X_train, y_train) + + y_pred = clf.predict(X_test) + # modify X to nonsense values, this should not change anything + X_train += 1. + + assert_warns_message(UserWarning, + "Warm-start fitting without increasing n_estimators" + " does not", clf.fit, X_train, y_train) + assert_array_equal(y_pred, clf.predict(X_test)) + + +def test_warm_start_equivalence(): + # warm started classifier with 5+5 estimators should be equivalent to + # one classifier with 10 estimators + X, y = make_hastie_10_2(n_samples=20, random_state=1) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=43) + + clf_ws = BalancedBaggingClassifier(n_estimators=5, warm_start=True, + random_state=3141) + clf_ws.fit(X_train, y_train) + clf_ws.set_params(n_estimators=10) + clf_ws.fit(X_train, y_train) + y1 = clf_ws.predict(X_test) + + clf = BalancedBaggingClassifier(n_estimators=10, warm_start=False, + random_state=3141) + clf.fit(X_train, y_train) + y2 = clf.predict(X_test) + + assert_array_almost_equal(y1, y2) + + +def test_warm_start_with_oob_score_fails(): + # Check using oob_score and warm_start simultaneously fails + X, y = make_hastie_10_2(n_samples=20, random_state=1) + clf = BalancedBaggingClassifier(n_estimators=5, warm_start=True, + oob_score=True) + assert_raises(ValueError, clf.fit, X, y) + + +def test_oob_score_removed_on_warm_start(): + X, y = make_hastie_10_2(n_samples=2000, random_state=1) + + clf = BalancedBaggingClassifier(n_estimators=50, oob_score=True) + clf.fit(X, y) + + clf.set_params(warm_start=True, oob_score=False, n_estimators=100) + clf.fit(X, y) + + assert_raises(AttributeError, getattr, clf, "oob_score_") + + +def test_oob_score_consistency(): + # Make sure OOB scores are identical when random_state, estimator, and + # training data are fixed and fitting is done twice + X, y = make_hastie_10_2(n_samples=200, random_state=1) + bagging = BalancedBaggingClassifier(KNeighborsClassifier(), + max_samples=0.5, + max_features=0.5, oob_score=True, + random_state=1) + assert_equal(bagging.fit(X, y).oob_score_, bagging.fit(X, y).oob_score_) + + +def test_estimators_samples(): + # Check that format of estimators_samples_ is correct and that results + # generated at fit time can be identically 
reproduced at a later time + # using data saved in object attributes. + X, y = make_hastie_10_2(n_samples=200, random_state=1) + + # remap the y outside of the BalancedBaggingclassifier + # _, y = np.unique(y, return_inverse=True) + bagging = BalancedBaggingClassifier(LogisticRegression(), max_samples=0.5, + max_features=0.5, random_state=1, + bootstrap=False) + bagging.fit(X, y) + + # Get relevant attributes + estimators_samples = bagging.estimators_samples_ + estimators_features = bagging.estimators_features_ + estimators = bagging.estimators_ + + # Test for correct formatting + assert_equal(len(estimators_samples), len(estimators)) + assert_equal(len(estimators_samples[0]), len(X)) + assert_equal(estimators_samples[0].dtype.kind, 'b') + + # Re-fit single estimator to test for consistent sampling + estimator_index = 0 + estimator_samples = estimators_samples[estimator_index] + estimator_features = estimators_features[estimator_index] + estimator = estimators[estimator_index] + + X_train = (X[estimator_samples])[:, estimator_features] + y_train = y[estimator_samples] + + orig_coefs = estimator.steps[-1][1].coef_ + estimator.fit(X_train, y_train) + new_coefs = estimator.steps[-1][1].coef_ + + assert_array_almost_equal(orig_coefs, new_coefs) + + +def test_max_samples_consistency(): + # Make sure validated max_samples and original max_samples are identical + # when valid integer max_samples supplied by user + max_samples = 100 + X, y = make_hastie_10_2(n_samples=2*max_samples, random_state=1) + bagging = BalancedBaggingClassifier(KNeighborsClassifier(), + max_samples=max_samples, + max_features=0.5, random_state=1) + bagging.fit(X, y) + assert_equal(bagging._max_samples, max_samples) From 0e06de7f00aff90307c8d2ce054d96a8f8797569 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 11 Aug 2017 17:12:00 +0200 Subject: [PATCH 02/11] TST add two missing test --- imblearn/ensemble/tests/test_easy_ensemble.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/imblearn/ensemble/tests/test_easy_ensemble.py b/imblearn/ensemble/tests/test_easy_ensemble.py index 6b3de9e44..8749807c3 100644 --- a/imblearn/ensemble/tests/test_easy_ensemble.py +++ b/imblearn/ensemble/tests/test_easy_ensemble.py @@ -294,6 +294,12 @@ def test_error(): X, y = make_imbalance(iris.data, iris.target, ratio={0: 20, 1: 25, 2: 50}) base = DecisionTreeClassifier() + # Test n_estimators + assert_raises(ValueError, + BalancedBaggingClassifier(base, n_estimators=1.5).fit, X, y) + assert_raises(ValueError, + BalancedBaggingClassifier(base, n_estimators=-1).fit, X, y) + # Test max_samples assert_raises(ValueError, BalancedBaggingClassifier(base, max_samples=-1).fit, X, y) From f484631e059f6057bd3c4434d6d1556bd1b4b832 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 11 Aug 2017 18:35:43 +0200 Subject: [PATCH 03/11] DOC add examples --- doc/whats_new.rst | 4 + .../plot_comparison_bagging_classifier.py | 104 ++++++++++++++++++ imblearn/ensemble/easy_ensemble.py | 23 +++- 3 files changed, 130 insertions(+), 1 deletion(-) create mode 100644 examples/ensemble/plot_comparison_bagging_classifier.py diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 656fb04e9..0f50cc5ad 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -33,6 +33,10 @@ New features Enhancement ~~~~~~~~~~~ +- Add :class:`ensemble.BalancedBaggingClassifier` which is a meta estimator to + directly use the :class:`ensemble.EasyEnsemble` chained with a classifier. By + `Guillaume Lemaitre`_. 
+ - :func:`datasets.make_imbalance` take a ratio similarly to other samplers. It supports multiclass. By `Guillaume Lemaitre`_. diff --git a/examples/ensemble/plot_comparison_bagging_classifier.py b/examples/ensemble/plot_comparison_bagging_classifier.py new file mode 100644 index 000000000..62176c0a0 --- /dev/null +++ b/examples/ensemble/plot_comparison_bagging_classifier.py @@ -0,0 +1,104 @@ +""" +========================================================= +Comparison of balanced and imbalanced bagging classifiers +========================================================= + +This example shows the benefit of balancing the training set when using a +bagging classifier. ``BalancedBaggingClassifier`` chains a +``RandomUnderSampler`` and a given classifier while ``BaggingClassifier`` is +using directly the imbalanced data. + +Balancing the data set before training the classifier improve the +classification performance. In addition, it avoids the ensemble to focus on the +majority class which would be a known drawback of the decision tree +classifiers. + +""" + +# Authors: Guillaume Lemaitre +# License: MIT + +from collections import Counter +import itertools + +import matplotlib.pyplot as plt +import numpy as np + +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from sklearn.ensemble import BaggingClassifier +from sklearn.metrics import confusion_matrix + +from imblearn.datasets import make_imbalance +from imblearn.ensemble import BalancedBaggingClassifier + +from imblearn.metrics import classification_report_imbalanced + + +def plot_confusion_matrix(cm, classes, + normalize=False, + title='Confusion matrix', + cmap=plt.cm.Blues): + """ + This function prints and plots the confusion matrix. + Normalization can be applied by setting `normalize=True`. + """ + if normalize: + cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] + print("Normalized confusion matrix") + else: + print('Confusion matrix, without normalization') + + print(cm) + + plt.imshow(cm, interpolation='nearest', cmap=cmap) + plt.title(title) + plt.colorbar() + tick_marks = np.arange(len(classes)) + plt.xticks(tick_marks, classes, rotation=45) + plt.yticks(tick_marks, classes) + + fmt = '.2f' if normalize else 'd' + thresh = cm.max() / 2. 
+ for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): + plt.text(j, i, format(cm[i, j], fmt), + horizontalalignment="center", + color="white" if cm[i, j] > thresh else "black") + + plt.tight_layout() + plt.ylabel('True label') + plt.xlabel('Predicted label') + + +iris = load_iris() +X, y = make_imbalance(iris.data, iris.target, ratio={0: 25, 1: 40, 2: 50}, + random_state=0) +X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + +bagging = BaggingClassifier(random_state=0) +balanced_bagging = BalancedBaggingClassifier(random_state=0) + +print('Class distribution of the training set: {}'.format(Counter(y_train))) + +bagging.fit(X_train, y_train) +balanced_bagging.fit(X_train, y_train) + +print('Class distribution of the test set: {}'.format(Counter(y_test))) + +print('Classification results using a bagging classifier on imbalanced data') +y_pred_bagging = bagging.predict(X_test) +print(classification_report_imbalanced(y_test, y_pred_bagging)) +cm_bagging = confusion_matrix(y_test, y_pred_bagging) +plt.figure() +plot_confusion_matrix(cm_bagging, classes=iris.target_names, + title='Confusion matrix using BaggingClassifier') + +print('Classification results using a bagging classifier on balanced data') +y_pred_balanced_bagging = balanced_bagging.predict(X_test) +print(classification_report_imbalanced(y_test, y_pred_balanced_bagging)) +cm_balanced_bagging = confusion_matrix(y_test, y_pred_balanced_bagging) +plt.figure() +plot_confusion_matrix(cm_balanced_bagging, classes=iris.target_names, + title='Confusion matrix using BalancedBaggingClassifier') + +plt.show() diff --git a/imblearn/ensemble/easy_ensemble.py b/imblearn/ensemble/easy_ensemble.py index 2f80220fa..7d15eb4f5 100644 --- a/imblearn/ensemble/easy_ensemble.py +++ b/imblearn/ensemble/easy_ensemble.py @@ -285,9 +285,30 @@ class BalancedBaggingClassifier(BaggingClassifier): was never left out during the bootstrap. In this case, `oob_decision_function_` might contain NaN. + >>> from collections import Counter + >>> from sklearn.datasets import make_classification + >>> from sklearn.model_selection import train_test_split + >>> from sklearn.metrics import confusion_matrix + >>> from imblearn.ensemble import \ +BalancedBaggingClassifier # doctest: +NORMALIZE_WHITESPACE + >>> X, y = make_classification(n_classes=2, class_sep=2, + ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, + ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) + >>> print('Original dataset shape {}'.format(Counter(y))) + Original dataset shape Counter({1: 900, 0: 100}) + >>> X_train, X_test, y_train, y_test = train_test_split(X, y, + ... random_state=0) + >>> bbc = BalancedBaggingClassifier(random_state=42) + >>> bbc.fit(X_train, y_train) # doctest: +ELLIPSIS + BalancedBaggingClassifier(...) + >>> y_pred = bbc.predict(X_test) + >>> print(confusion_matrix(y_test, y_pred)) + [[ 23 0] + [ 2 225]] + References ---------- - .. [1] L. Breiman, "Pasting small votes for classification in large + .. [1] L". Breiman, Pasting small votes for classification in large databases and on-line", Machine Learning, 36(1), 85-103, 1999. .. [2] L. Breiman, "Bagging predictors", Machine Learning, 24(2), 123-140, 1996. 
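[Editor's note between PATCH 03 and PATCH 04] The monkey-patch introduced in PATCH 01 (`_masked_bagging_indices`) wraps scikit-learn's private `_generate_bagging_indices` so that the in-bag sample indices come back as a boolean mask instead of an integer array; this is what makes `estimators_samples_` directly usable to re-slice the exact training subset later (see `test_estimators_samples`). The following sketch is illustrative only and not part of the patch; it re-implements the conversion with plain NumPy rather than calling `sklearn.utils.indices_to_mask`::

    import numpy as np

    def indices_to_mask(indices, mask_length):
        # Same idea as sklearn.utils.indices_to_mask: mark every drawn
        # sample as in-bag (bootstrap duplicates collapse to a single True).
        mask = np.zeros(mask_length, dtype=bool)
        mask[indices] = True
        return mask

    rng = np.random.RandomState(0)
    n_samples = 10
    sample_indices = rng.randint(0, n_samples, n_samples)  # one bootstrap draw
    mask = indices_to_mask(sample_indices, n_samples)

    X = rng.uniform(size=(n_samples, 3))
    X_in_bag = X[mask]   # the subset a single ensemble member is fitted on

Note that the mask records membership only, not multiplicity, which is all the refit-consistency test needs.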
From 9c79de271aec1fd19309f72d2b7fdef7837ee2f7 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 11 Aug 2017 19:59:08 +0200 Subject: [PATCH 04/11] FIX not passing sample_weight at fit --- imblearn/ensemble/easy_ensemble.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/imblearn/ensemble/easy_ensemble.py b/imblearn/ensemble/easy_ensemble.py index 7d15eb4f5..192848358 100644 --- a/imblearn/ensemble/easy_ensemble.py +++ b/imblearn/ensemble/easy_ensemble.py @@ -369,3 +369,24 @@ def _validate_estimator(self, default=DecisionTreeClassifier()): [('sampler', RandomUnderSampler(ratio=self.ratio, replacement=self.replacement)), ('classifier', base_estimator)]) + + def fit(self, X, y): + """Build a Bagging ensemble of estimators from the training + set (X, y). + + Parameters + ---------- + X : array-like of shape = [n_samples, n_features] + The training input samples. + + y : array-like, shape = [n_samples] + The target values. + + Returns + ------- + self : object + Returns self. + """ + # RandomUnderSampler is not supporting sample_weight. We need to pass + # None. + return self._fit(X, y, self.max_samples, sample_weight=None) From df86a1bca6be3e5ab4d398a239394b3699298b4e Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 11 Aug 2017 20:08:28 +0200 Subject: [PATCH 05/11] DOC add api documentation --- doc/api.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/api.rst b/doc/api.rst index 4d3bc55f4..477bc060f 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -104,6 +104,7 @@ Ensemble methods :toctree: generated/ ensemble.BalanceCascade + ensemble.BalancedBaggingClassifier ensemble.EasyEnsemble From 943a06c0138a5034e23aaf9bd78caff74d47695d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 11 Aug 2017 22:00:55 +0200 Subject: [PATCH 06/11] DOC fix docstring --- imblearn/ensemble/easy_ensemble.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/imblearn/ensemble/easy_ensemble.py b/imblearn/ensemble/easy_ensemble.py index 192848358..e3cd8a6dd 100644 --- a/imblearn/ensemble/easy_ensemble.py +++ b/imblearn/ensemble/easy_ensemble.py @@ -197,13 +197,13 @@ class BalancedBaggingClassifier(BaggingClassifier): max_samples : int or float, optional (default=1.0) The number of samples to draw from X to train each base estimator. - - If int, then draw `max_samples` samples. - - If float, then draw `max_samples * X.shape[0]` samples. + If int, then draw `max_samples` samples. + If float, then draw `max_samples * X.shape[0]` samples. max_features : int or float, optional (default=1.0) The number of features to draw from X to train each base estimator. - - If int, then draw `max_features` features. - - If float, then draw `max_features * X.shape[1]` features. + If int, then draw `max_features` features. + If float, then draw `max_features * X.shape[1]` features. bootstrap : boolean, optional (default=True) Whether samples are drawn with replacement. @@ -219,6 +219,7 @@ class BalancedBaggingClassifier(BaggingClassifier): When set to True, reuse the solution of the previous call to fit and add more estimators to the ensemble, otherwise, just fit a whole new ensemble. + .. versionadded:: 0.17 *warm_start* constructor parameter. @@ -250,7 +251,7 @@ class BalancedBaggingClassifier(BaggingClassifier): If int, random_state is the seed used by the random number generator; If RandomState instance, random_state is the random number generator; If None, the random number generator is the RandomState instance used - by `np.random`. 
+ by ```np.random```. verbose : int, optional (default=0) Controls the verbosity of the building process. @@ -285,6 +286,9 @@ class BalancedBaggingClassifier(BaggingClassifier): was never left out during the bootstrap. In this case, `oob_decision_function_` might contain NaN. + Examples + -------- + >>> from collections import Counter >>> from sklearn.datasets import make_classification >>> from sklearn.model_selection import train_test_split @@ -308,7 +312,7 @@ class BalancedBaggingClassifier(BaggingClassifier): References ---------- - .. [1] L". Breiman, Pasting small votes for classification in large + .. [1] L. Breiman, "Pasting small votes for classification in large databases and on-line", Machine Learning, 36(1), 85-103, 1999. .. [2] L. Breiman, "Bagging predictors", Machine Learning, 24(2), 123-140, 1996. From 6d2ab3c96d81aab85dea92d1ca60af65392859ea Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 11 Aug 2017 22:59:25 +0200 Subject: [PATCH 07/11] iter --- imblearn/ensemble/easy_ensemble.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/imblearn/ensemble/easy_ensemble.py b/imblearn/ensemble/easy_ensemble.py index e3cd8a6dd..aa5cae36d 100644 --- a/imblearn/ensemble/easy_ensemble.py +++ b/imblearn/ensemble/easy_ensemble.py @@ -197,13 +197,13 @@ class BalancedBaggingClassifier(BaggingClassifier): max_samples : int or float, optional (default=1.0) The number of samples to draw from X to train each base estimator. - If int, then draw `max_samples` samples. - If float, then draw `max_samples * X.shape[0]` samples. + - If int, then draw ``max_samples`` samples. + - If float, then draw ``max_samples * X.shape[0]`` samples. max_features : int or float, optional (default=1.0) The number of features to draw from X to train each base estimator. - If int, then draw `max_features` features. - If float, then draw `max_features * X.shape[1]` features. + - If int, then draw ``max_features`` features. + - If float, then draw ``max_features * X.shape[1]`` features. bootstrap : boolean, optional (default=True) Whether samples are drawn with replacement. @@ -248,10 +248,12 @@ class BalancedBaggingClassifier(BaggingClassifier): If -1, then the number of jobs is set to the number of cores. random_state : int, RandomState instance or None, optional (default=None) - If int, random_state is the seed used by the random number generator; - If RandomState instance, random_state is the random number generator; - If None, the random number generator is the RandomState instance used - by ```np.random```. + - If int, ``random_state`` is the seed used by the random number + generator; + - If ``RandomState`` instance, random_state is the random + number generator; + - If ``None``, the random number generator is the + ``RandomState`` instance used by ``np.random``. verbose : int, optional (default=0) Controls the verbosity of the building process. @@ -284,7 +286,7 @@ class BalancedBaggingClassifier(BaggingClassifier): Decision function computed with out-of-bag estimate on the training set. If n_estimators is small it might be possible that a data point was never left out during the bootstrap. In this case, - `oob_decision_function_` might contain NaN. + ``oob_decision_function_`` might contain NaN. 
Examples -------- From 7912de83ad7dc616e04005f617fc7fbb40ef5623 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sat, 12 Aug 2017 00:09:11 +0200 Subject: [PATCH 08/11] DOC fix docstring --- imblearn/ensemble/easy_ensemble.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/imblearn/ensemble/easy_ensemble.py b/imblearn/ensemble/easy_ensemble.py index aa5cae36d..a694d55f6 100644 --- a/imblearn/ensemble/easy_ensemble.py +++ b/imblearn/ensemble/easy_ensemble.py @@ -197,11 +197,13 @@ class BalancedBaggingClassifier(BaggingClassifier): max_samples : int or float, optional (default=1.0) The number of samples to draw from X to train each base estimator. + - If int, then draw ``max_samples`` samples. - If float, then draw ``max_samples * X.shape[0]`` samples. max_features : int or float, optional (default=1.0) The number of features to draw from X to train each base estimator. + - If int, then draw ``max_features`` features. - If float, then draw ``max_features * X.shape[1]`` features. From 1390f96bd19e799894632d76694bb7940df5cb49 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sat, 12 Aug 2017 15:47:39 +0200 Subject: [PATCH 09/11] DOC add user guide entry and cross referencing --- doc/ensemble.rst | 56 ++++++++++++++++++++++++++++ imblearn/ensemble/balance_cascade.py | 4 +- imblearn/ensemble/easy_ensemble.py | 40 ++++++++++++-------- 3 files changed, 83 insertions(+), 17 deletions(-) diff --git a/doc/ensemble.rst b/doc/ensemble.rst index cf934d73f..772257ae5 100644 --- a/doc/ensemble.rst +++ b/doc/ensemble.rst @@ -6,6 +6,11 @@ Ensemble of samplers .. currentmodule:: imblearn.ensemble +.. _ensemble_samplers: + +Samplers +-------- + An imbalanced data set can be balanced by creating several balanced subsets. The module :mod:`imblearn.ensemble` allows to create such sets. @@ -54,3 +59,54 @@ parameter ``n_max_subset`` and an additional bootstraping can be activated with See :ref:`sphx_glr_auto_examples_ensemble_plot_easy_ensemble.py` and :ref:`sphx_glr_auto_examples_ensemble_plot_balance_cascade.py`. + +.. _ensemble_meta_estimators: + +Chaining ensemble of samplers and estimators +-------------------------------------------- + +In ensemble classifiers, bagging methods build several estimators on different +randomly selected subset of data. In scikit-learn, this classifier is named +``BaggingClassifier``. However, this classifier does not allow to balance each +subset of data. Therefore, when training on imbalanced data set, this +classifier will favor the majority classes:: + + >>> from sklearn.model_selection import train_test_split + >>> from sklearn.metrics import confusion_matrix + >>> from sklearn.ensemble import BaggingClassifier + >>> from sklearn.tree import DecisionTreeClassifier + >>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + >>> bc = BaggingClassifier(base_estimator=DecisionTreeClassifier(), + ... random_state=0) + >>> bc.fit(X_train, y_train) #doctest: +ELLIPSIS + BaggingClassifier(...) + >>> y_pred = bc.predict(X_test) + >>> confusion_matrix(y_test, y_pred) + array([[ 0, 0, 12], + [ 0, 0, 59], + [ 0, 0, 1179]]) + +:class:`BalancedBaggingClassifier` allows to resample each subset of data +before to train each estimator of the ensemble. In short, it combines the +output of an :class:`EasyEnsemble` sampler with an ensemble of classifiers +(i.e. ``BaggingClassifier``). Therefore, :class:`BalancedBaggingClassifier` +takes the same parameters than the scikit-learn +``BaggingClassifier``. 
Additionally, there is two additional parameters, +``ratio`` and ``replacement``, as in the :class:`EasyEnsemble` sampler:: + + + >>> from imblearn.ensemble import BalancedBaggingClassifier + >>> bbc = BalancedBaggingClassifier(base_estimator=DecisionTreeClassifier(), + ... ratio='auto', + ... replacement=False, + ... random_state=0) + >>> bbc.fit(X, y) # doctest: +ELLIPSIS + BalancedBaggingClassifier(...) + >>> y_pred = bbc.predict(X_test) + >>> confusion_matrix(y_test, y_pred) + array([[ 12, 0, 0], + [ 0, 55, 4], + [ 68, 53, 1058]]) + +See +:ref:`sphx_glr_auto_examples_ensemble_plot_comparison_bagging_classifier.py`. diff --git a/imblearn/ensemble/balance_cascade.py b/imblearn/ensemble/balance_cascade.py index f88c873ed..b9afba54c 100644 --- a/imblearn/ensemble/balance_cascade.py +++ b/imblearn/ensemble/balance_cascade.py @@ -27,7 +27,7 @@ class BalanceCascade(BaseEnsembleSampler): This method iteratively select subset and make an ensemble of the different sets. The selection is performed using a specific classifier. - Read more in the :ref:`User Guide `. + Read more in the :ref:`User Guide `. Parameters ---------- @@ -99,7 +99,7 @@ class BalanceCascade(BaseEnsembleSampler): See also -------- - EasyEnsemble + BalancedBaggingClassifier, EasyEnsemble References ---------- diff --git a/imblearn/ensemble/easy_ensemble.py b/imblearn/ensemble/easy_ensemble.py index 38a94f798..506f6d802 100644 --- a/imblearn/ensemble/easy_ensemble.py +++ b/imblearn/ensemble/easy_ensemble.py @@ -48,7 +48,7 @@ class EasyEnsemble(BaseEnsembleSampler): This method iteratively select a random subset and make an ensemble of the different sets. - Read more in the :ref:`User Guide `. + Read more in the :ref:`User Guide `. Parameters ---------- @@ -95,7 +95,7 @@ class EasyEnsemble(BaseEnsembleSampler): See also -------- - BalanceCascade + BalanceCascade, BalancedBaggingClassifier References ---------- @@ -192,7 +192,7 @@ class BalancedBaggingClassifier(BaggingClassifier): implementation. It includes an additional step to balance the training set at fit time using a ``RandomUnderSampler``. - Read more in the :ref:`User Guide `. + Read more in the :ref:`User Guide `. Parameters ---------- @@ -298,6 +298,28 @@ class BalancedBaggingClassifier(BaggingClassifier): was never left out during the bootstrap. In this case, ``oob_decision_function_`` might contain NaN. + Notes + ----- + + See + :ref:`sphx_glr_auto_examples_ensemble_plot_comparison_bagging_classifier.py`. + + See also + -------- + BalanceCascade, EasyEnsemble + + References + ---------- + .. [1] L. Breiman, "Pasting small votes for classification in large + databases and on-line", Machine Learning, 36(1), 85-103, 1999. + .. [2] L. Breiman, "Bagging predictors", Machine Learning, 24(2), 123-140, + 1996. + .. [3] T. Ho, "The random subspace method for constructing decision + forests", Pattern Analysis and Machine Intelligence, 20(8), 832-844, + 1998. + .. [4] G. Louppe and P. Geurts, "Ensembles on Random Patches", Machine + Learning and Knowledge Discovery in Databases, 346-361, 2012. + Examples -------- @@ -322,18 +344,6 @@ class BalancedBaggingClassifier(BaggingClassifier): [[ 23 0] [ 2 225]] - References - ---------- - .. [1] L. Breiman, "Pasting small votes for classification in large - databases and on-line", Machine Learning, 36(1), 85-103, 1999. - .. [2] L. Breiman, "Bagging predictors", Machine Learning, 24(2), 123-140, - 1996. - .. [3] T. 
Ho, "The random subspace method for constructing decision - forests", Pattern Analysis and Machine Intelligence, 20(8), 832-844, - 1998. - .. [4] G. Louppe and P. Geurts, "Ensembles on Random Patches", Machine - Learning and Knowledge Discovery in Databases, 346-361, 2012. - """ def __init__(self, base_estimator=None, From 58fa5a1f3f098177ea10e67c755bd4b64f2a214f Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 14 Aug 2017 21:05:34 +0200 Subject: [PATCH 10/11] FIX mv into a new module --- imblearn/ensemble/__init__.py | 4 +- imblearn/ensemble/classifier.py | 269 +++++++++++++++++++++++++++++ imblearn/ensemble/easy_ensemble.py | 263 +--------------------------- 3 files changed, 273 insertions(+), 263 deletions(-) create mode 100644 imblearn/ensemble/classifier.py diff --git a/imblearn/ensemble/__init__.py b/imblearn/ensemble/__init__.py index 13e148991..35cbd24eb 100644 --- a/imblearn/ensemble/__init__.py +++ b/imblearn/ensemble/__init__.py @@ -3,7 +3,9 @@ under-sampled subsets combined inside an ensemble. """ -from .easy_ensemble import EasyEnsemble, BalancedBaggingClassifier +from .easy_ensemble import EasyEnsemble from .balance_cascade import BalanceCascade +from .classifier import BalancedBaggingClassifier + __all__ = ['EasyEnsemble', 'BalancedBaggingClassifier', 'BalanceCascade'] diff --git a/imblearn/ensemble/classifier.py b/imblearn/ensemble/classifier.py new file mode 100644 index 000000000..dca5be0a3 --- /dev/null +++ b/imblearn/ensemble/classifier.py @@ -0,0 +1,269 @@ +"""Ensemble predictors combining a sampler and a classifier.""" + +# Authors: Guillaume Lemaitre +# Christos Aridas +# License: MIT + +import numbers + +import sklearn +from sklearn.base import clone +from sklearn.ensemble import BaggingClassifier +from sklearn.tree import DecisionTreeClassifier +from sklearn.ensemble.bagging import _generate_bagging_indices +from sklearn.utils import indices_to_mask + +from ..pipeline import Pipeline +from ..under_sampling import RandomUnderSampler + +old_generate = _generate_bagging_indices + + +def _masked_bagging_indices(random_state, bootstrap_features, + bootstrap_samples, n_features, n_samples, + max_features, max_samples): + """Monkey-patch to always get a mask instead of indices""" + feature_indices, sample_indices = old_generate(random_state, + bootstrap_features, + bootstrap_samples, + n_features, n_samples, + max_features, max_samples) + sample_indices = indices_to_mask(sample_indices, n_samples) + + return feature_indices, sample_indices + + +sklearn.ensemble.bagging._generate_bagging_indices = _masked_bagging_indices + + +class BalancedBaggingClassifier(BaggingClassifier): + """A Bagging classifier with additional balancing. + + This implementation of Bagging is similar to the scikit-learn + implementation. It includes an additional step to balance the training set + at fit time using a ``RandomUnderSampler``. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + base_estimator : object or None, optional (default=None) + The base estimator to fit on random subsets of the dataset. + If None, then the base estimator is a decision tree. + + n_estimators : int, optional (default=10) + The number of base estimators in the ensemble. + + max_samples : int or float, optional (default=1.0) + The number of samples to draw from X to train each base estimator. + + - If int, then draw ``max_samples`` samples. + - If float, then draw ``max_samples * X.shape[0]`` samples. 
+ + max_features : int or float, optional (default=1.0) + The number of features to draw from X to train each base estimator. + + - If int, then draw ``max_features`` features. + - If float, then draw ``max_features * X.shape[1]`` features. + + bootstrap : boolean, optional (default=True) + Whether samples are drawn with replacement. + + bootstrap_features : boolean, optional (default=False) + Whether features are drawn with replacement. + + oob_score : bool + Whether to use out-of-bag samples to estimate + the generalization error. + + warm_start : bool, optional (default=False) + When set to True, reuse the solution of the previous call to fit + and add more estimators to the ensemble, otherwise, just fit + a whole new ensemble. + + .. versionadded:: 0.17 + *warm_start* constructor parameter. + + ratio : str, dict, or callable, optional (default='auto') + Ratio to use for resampling the data set. + + - If ``str``, has to be one of: (i) ``'minority'``: resample the + minority class; (ii) ``'majority'``: resample the majority class, + (iii) ``'not minority'``: resample all classes apart of the minority + class, (iv) ``'all'``: resample all classes, and (v) ``'auto'``: + correspond to ``'all'`` with for over-sampling methods and ``'not + minority'`` for under-sampling methods. The classes targeted will be + over-sampled or under-sampled to achieve an equal number of sample + with the majority or minority class. + - If ``dict``, the keys correspond to the targeted classes. The values + correspond to the desired number of samples. + - If callable, function taking ``y`` and returns a ``dict``. The keys + correspond to the targeted classes. The values correspond to the + desired number of samples. + + replacement : bool, optional (default=False) + Whether or not to sample randomly with replacement or not. + + n_jobs : int, optional (default=1) + The number of jobs to run in parallel for both `fit` and `predict`. + If -1, then the number of jobs is set to the number of cores. + + random_state : int, RandomState instance or None, optional (default=None) + - If int, ``random_state`` is the seed used by the random number + generator; + - If ``RandomState`` instance, random_state is the random + number generator; + - If ``None``, the random number generator is the + ``RandomState`` instance used by ``np.random``. + + verbose : int, optional (default=0) + Controls the verbosity of the building process. + + Attributes + ---------- + base_estimator_ : estimator + The base estimator from which the ensemble is grown. + + estimators_ : list of estimators + The collection of fitted base estimators. + + estimators_samples_ : list of arrays + The subset of drawn samples (i.e., the in-bag samples) for each base + estimator. Each subset is defined by a boolean mask. + + estimators_features_ : list of arrays + The subset of drawn features for each base estimator. + + classes_ : array of shape = [n_classes] + The classes labels. + + n_classes_ : int or list + The number of classes. + + oob_score_ : float + Score of the training dataset obtained using an out-of-bag estimate. + + oob_decision_function_ : array of shape = [n_samples, n_classes] + Decision function computed with out-of-bag estimate on the training + set. If n_estimators is small it might be possible that a data point + was never left out during the bootstrap. In this case, + ``oob_decision_function_`` might contain NaN. + + Notes + ----- + + See + :ref:`sphx_glr_auto_examples_ensemble_plot_comparison_bagging_classifier.py`. 
+ + See also + -------- + BalanceCascade, EasyEnsemble + + References + ---------- + .. [1] L. Breiman, "Pasting small votes for classification in large + databases and on-line", Machine Learning, 36(1), 85-103, 1999. + .. [2] L. Breiman, "Bagging predictors", Machine Learning, 24(2), 123-140, + 1996. + .. [3] T. Ho, "The random subspace method for constructing decision + forests", Pattern Analysis and Machine Intelligence, 20(8), 832-844, + 1998. + .. [4] G. Louppe and P. Geurts, "Ensembles on Random Patches", Machine + Learning and Knowledge Discovery in Databases, 346-361, 2012. + + Examples + -------- + + >>> from collections import Counter + >>> from sklearn.datasets import make_classification + >>> from sklearn.model_selection import train_test_split + >>> from sklearn.metrics import confusion_matrix + >>> from imblearn.ensemble import \ +BalancedBaggingClassifier # doctest: +NORMALIZE_WHITESPACE + >>> X, y = make_classification(n_classes=2, class_sep=2, + ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, + ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) + >>> print('Original dataset shape {}'.format(Counter(y))) + Original dataset shape Counter({1: 900, 0: 100}) + >>> X_train, X_test, y_train, y_test = train_test_split(X, y, + ... random_state=0) + >>> bbc = BalancedBaggingClassifier(random_state=42) + >>> bbc.fit(X_train, y_train) # doctest: +ELLIPSIS + BalancedBaggingClassifier(...) + >>> y_pred = bbc.predict(X_test) + >>> print(confusion_matrix(y_test, y_pred)) + [[ 23 0] + [ 2 225]] + + """ + def __init__(self, + base_estimator=None, + n_estimators=10, + max_samples=1.0, + max_features=1.0, + bootstrap=True, + bootstrap_features=False, + oob_score=False, + warm_start=False, + ratio='auto', + replacement=False, + n_jobs=1, + random_state=None, + verbose=0): + + super(BaggingClassifier, self).__init__( + base_estimator, + n_estimators=n_estimators, + max_samples=max_samples, + max_features=max_features, + bootstrap=bootstrap, + bootstrap_features=bootstrap_features, + oob_score=oob_score, + warm_start=warm_start, + n_jobs=n_jobs, + random_state=random_state, + verbose=verbose) + self.ratio = ratio + self.replacement = replacement + + def _validate_estimator(self, default=DecisionTreeClassifier()): + """Check the estimator and the n_estimator attribute, set the + `base_estimator_` attribute.""" + if not isinstance(self.n_estimators, (numbers.Integral, np.integer)): + raise ValueError("n_estimators must be an integer, " + "got {0}.".format(type(self.n_estimators))) + + if self.n_estimators <= 0: + raise ValueError("n_estimators must be greater than zero, " + "got {0}.".format(self.n_estimators)) + + if self.base_estimator is not None: + base_estimator = clone(self.base_estimator) + else: + base_estimator = clone(default) + + self.base_estimator_ = Pipeline( + [('sampler', RandomUnderSampler(ratio=self.ratio, + replacement=self.replacement)), + ('classifier', base_estimator)]) + + def fit(self, X, y): + """Build a Bagging ensemble of estimators from the training + set (X, y). + + Parameters + ---------- + X : array-like of shape = [n_samples, n_features] + The training input samples. + + y : array-like, shape = [n_samples] + The target values. + + Returns + ------- + self : object + Returns self. + """ + # RandomUnderSampler is not supporting sample_weight. We need to pass + # None. 
+ return self._fit(X, y, self.max_samples, sample_weight=None) diff --git a/imblearn/ensemble/easy_ensemble.py b/imblearn/ensemble/easy_ensemble.py index 506f6d802..07d771d85 100644 --- a/imblearn/ensemble/easy_ensemble.py +++ b/imblearn/ensemble/easy_ensemble.py @@ -4,44 +4,16 @@ # Christos Aridas # License: MIT -import numbers - import numpy as np -import sklearn -from sklearn.base import clone -from sklearn.ensemble import BaggingClassifier -from sklearn.ensemble.bagging import _generate_bagging_indices -from sklearn.tree import DecisionTreeClassifier -from sklearn.utils import check_random_state, indices_to_mask +from sklearn.utils import check_random_state from .base import BaseEnsembleSampler -from ..pipeline import Pipeline from ..under_sampling import RandomUnderSampler MAX_INT = np.iinfo(np.int32).max -old_generate = _generate_bagging_indices - - -def _masked_bagging_indices(random_state, bootstrap_features, - bootstrap_samples, n_features, n_samples, - max_features, max_samples): - """Monkey-patch to always get a mask instead of indices""" - feature_indices, sample_indices = old_generate(random_state, - bootstrap_features, - bootstrap_samples, - n_features, n_samples, - max_features, max_samples) - sample_indices = indices_to_mask(sample_indices, n_samples) - - return feature_indices, sample_indices - - -sklearn.ensemble.bagging._generate_bagging_indices = _masked_bagging_indices - - class EasyEnsemble(BaseEnsembleSampler): """Create an ensemble sets by iteratively applying random under-sampling. @@ -183,236 +155,3 @@ def _sample(self, X, y): np.array(idx_under)) else: return np.array(X_resampled), np.array(y_resampled) - - -class BalancedBaggingClassifier(BaggingClassifier): - """A Bagging classifier with additional balancing. - - This implementation of Bagging is similar to the scikit-learn - implementation. It includes an additional step to balance the training set - at fit time using a ``RandomUnderSampler``. - - Read more in the :ref:`User Guide `. - - Parameters - ---------- - base_estimator : object or None, optional (default=None) - The base estimator to fit on random subsets of the dataset. - If None, then the base estimator is a decision tree. - - n_estimators : int, optional (default=10) - The number of base estimators in the ensemble. - - max_samples : int or float, optional (default=1.0) - The number of samples to draw from X to train each base estimator. - - - If int, then draw ``max_samples`` samples. - - If float, then draw ``max_samples * X.shape[0]`` samples. - - max_features : int or float, optional (default=1.0) - The number of features to draw from X to train each base estimator. - - - If int, then draw ``max_features`` features. - - If float, then draw ``max_features * X.shape[1]`` features. - - bootstrap : boolean, optional (default=True) - Whether samples are drawn with replacement. - - bootstrap_features : boolean, optional (default=False) - Whether features are drawn with replacement. - - oob_score : bool - Whether to use out-of-bag samples to estimate - the generalization error. - - warm_start : bool, optional (default=False) - When set to True, reuse the solution of the previous call to fit - and add more estimators to the ensemble, otherwise, just fit - a whole new ensemble. - - .. versionadded:: 0.17 - *warm_start* constructor parameter. - - ratio : str, dict, or callable, optional (default='auto') - Ratio to use for resampling the data set. 
- - - If ``str``, has to be one of: (i) ``'minority'``: resample the - minority class; (ii) ``'majority'``: resample the majority class, - (iii) ``'not minority'``: resample all classes apart of the minority - class, (iv) ``'all'``: resample all classes, and (v) ``'auto'``: - correspond to ``'all'`` with for over-sampling methods and ``'not - minority'`` for under-sampling methods. The classes targeted will be - over-sampled or under-sampled to achieve an equal number of sample - with the majority or minority class. - - If ``dict``, the keys correspond to the targeted classes. The values - correspond to the desired number of samples. - - If callable, function taking ``y`` and returns a ``dict``. The keys - correspond to the targeted classes. The values correspond to the - desired number of samples. - - replacement : bool, optional (default=False) - Whether or not to sample randomly with replacement or not. - - n_jobs : int, optional (default=1) - The number of jobs to run in parallel for both `fit` and `predict`. - If -1, then the number of jobs is set to the number of cores. - - random_state : int, RandomState instance or None, optional (default=None) - - If int, ``random_state`` is the seed used by the random number - generator; - - If ``RandomState`` instance, random_state is the random - number generator; - - If ``None``, the random number generator is the - ``RandomState`` instance used by ``np.random``. - - verbose : int, optional (default=0) - Controls the verbosity of the building process. - - Attributes - ---------- - base_estimator_ : estimator - The base estimator from which the ensemble is grown. - - estimators_ : list of estimators - The collection of fitted base estimators. - - estimators_samples_ : list of arrays - The subset of drawn samples (i.e., the in-bag samples) for each base - estimator. Each subset is defined by a boolean mask. - - estimators_features_ : list of arrays - The subset of drawn features for each base estimator. - - classes_ : array of shape = [n_classes] - The classes labels. - - n_classes_ : int or list - The number of classes. - - oob_score_ : float - Score of the training dataset obtained using an out-of-bag estimate. - - oob_decision_function_ : array of shape = [n_samples, n_classes] - Decision function computed with out-of-bag estimate on the training - set. If n_estimators is small it might be possible that a data point - was never left out during the bootstrap. In this case, - ``oob_decision_function_`` might contain NaN. - - Notes - ----- - - See - :ref:`sphx_glr_auto_examples_ensemble_plot_comparison_bagging_classifier.py`. - - See also - -------- - BalanceCascade, EasyEnsemble - - References - ---------- - .. [1] L. Breiman, "Pasting small votes for classification in large - databases and on-line", Machine Learning, 36(1), 85-103, 1999. - .. [2] L. Breiman, "Bagging predictors", Machine Learning, 24(2), 123-140, - 1996. - .. [3] T. Ho, "The random subspace method for constructing decision - forests", Pattern Analysis and Machine Intelligence, 20(8), 832-844, - 1998. - .. [4] G. Louppe and P. Geurts, "Ensembles on Random Patches", Machine - Learning and Knowledge Discovery in Databases, 346-361, 2012. 
- - Examples - -------- - - >>> from collections import Counter - >>> from sklearn.datasets import make_classification - >>> from sklearn.model_selection import train_test_split - >>> from sklearn.metrics import confusion_matrix - >>> from imblearn.ensemble import \ -BalancedBaggingClassifier # doctest: +NORMALIZE_WHITESPACE - >>> X, y = make_classification(n_classes=2, class_sep=2, - ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, - ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) - >>> print('Original dataset shape {}'.format(Counter(y))) - Original dataset shape Counter({1: 900, 0: 100}) - >>> X_train, X_test, y_train, y_test = train_test_split(X, y, - ... random_state=0) - >>> bbc = BalancedBaggingClassifier(random_state=42) - >>> bbc.fit(X_train, y_train) # doctest: +ELLIPSIS - BalancedBaggingClassifier(...) - >>> y_pred = bbc.predict(X_test) - >>> print(confusion_matrix(y_test, y_pred)) - [[ 23 0] - [ 2 225]] - - """ - def __init__(self, - base_estimator=None, - n_estimators=10, - max_samples=1.0, - max_features=1.0, - bootstrap=True, - bootstrap_features=False, - oob_score=False, - warm_start=False, - ratio='auto', - replacement=False, - n_jobs=1, - random_state=None, - verbose=0): - - super(BaggingClassifier, self).__init__( - base_estimator, - n_estimators=n_estimators, - max_samples=max_samples, - max_features=max_features, - bootstrap=bootstrap, - bootstrap_features=bootstrap_features, - oob_score=oob_score, - warm_start=warm_start, - n_jobs=n_jobs, - random_state=random_state, - verbose=verbose) - self.ratio = ratio - self.replacement = replacement - - def _validate_estimator(self, default=DecisionTreeClassifier()): - """Check the estimator and the n_estimator attribute, set the - `base_estimator_` attribute.""" - if not isinstance(self.n_estimators, (numbers.Integral, np.integer)): - raise ValueError("n_estimators must be an integer, " - "got {0}.".format(type(self.n_estimators))) - - if self.n_estimators <= 0: - raise ValueError("n_estimators must be greater than zero, " - "got {0}.".format(self.n_estimators)) - - if self.base_estimator is not None: - base_estimator = clone(self.base_estimator) - else: - base_estimator = clone(default) - - self.base_estimator_ = Pipeline( - [('sampler', RandomUnderSampler(ratio=self.ratio, - replacement=self.replacement)), - ('classifier', base_estimator)]) - - def fit(self, X, y): - """Build a Bagging ensemble of estimators from the training - set (X, y). - - Parameters - ---------- - X : array-like of shape = [n_samples, n_features] - The training input samples. - - y : array-like, shape = [n_samples] - The target values. - - Returns - ------- - self : object - Returns self. - """ - # RandomUnderSampler is not supporting sample_weight. We need to pass - # None. - return self._fit(X, y, self.max_samples, sample_weight=None) From e0e3988285b66b027805fd68e63b47ffcb2b25f6 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 14 Aug 2017 21:18:09 +0200 Subject: [PATCH 11/11] FIX add missing dependency --- imblearn/ensemble/classifier.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/imblearn/ensemble/classifier.py b/imblearn/ensemble/classifier.py index dca5be0a3..81d528ac9 100644 --- a/imblearn/ensemble/classifier.py +++ b/imblearn/ensemble/classifier.py @@ -6,6 +6,8 @@ import numbers +import numpy as np + import sklearn from sklearn.base import clone from sklearn.ensemble import BaggingClassifier
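
Reviewer aid, separate from the commits themselves: the estimator this series introduces is, in effect, a stock BaggingClassifier whose base estimator is wrapped in an imblearn Pipeline that re-balances each bootstrap subset before fitting. The sketch below approximates that composition with public API only, using the keyword names that appear in the docstrings above (`ratio`, `replacement`, `base_estimator`); treat it as an illustration under those assumptions, not the committed implementation, and note that later library releases may spell these keywords differently.

    # Rough approximation of what BalancedBaggingClassifier assembles:
    # each bagging member under-samples its training subset, then fits a tree.
    from sklearn.datasets import make_classification
    from sklearn.ensemble import BaggingClassifier
    from sklearn.tree import DecisionTreeClassifier

    from imblearn.pipeline import Pipeline
    from imblearn.under_sampling import RandomUnderSampler

    # Imbalanced toy problem (9:1 class ratio).
    X, y = make_classification(n_classes=2, weights=[0.1, 0.9],
                               n_samples=1000, random_state=10)

    # Balanced base estimator: sampler followed by the classifier, mirroring
    # the Pipeline built in _validate_estimator in the patch above.
    balanced_base = Pipeline([
        ('sampler', RandomUnderSampler(ratio='auto', replacement=False)),
        ('classifier', DecisionTreeClassifier()),
    ])

    # Plain bagging over that pipeline behaves roughly like
    # BalancedBaggingClassifier(n_estimators=10, random_state=42).
    clf = BaggingClassifier(base_estimator=balanced_base, n_estimators=10,
                            random_state=42)
    clf.fit(X, y)
    print(clf.score(X, y))

One practical difference from this simple composition: the dedicated class also relies on the monkey-patched bagging index generation shown earlier in the series, so that `estimators_samples_` is exposed as boolean masks, which the sketch above does not reproduce.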