diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst
index 90e4e66452af3..d841667d97e8c 100644
--- a/doc/whats_new/v0.20.rst
+++ b/doc/whats_new/v0.20.rst
@@ -654,6 +654,13 @@ Metrics
   :issue:`9515` by :user:`Alan Liddell ` and
   :user:`Manh Dao `.
 
+Ensemble
+
+- Fixed a bug in :class:`ensemble.BaseBagging` where one could not
+  deterministically reproduce ``fit`` results using the object attributes
+  when ``random_state`` is set. :issue:`9723` by
+  :user:`Guillaume Lemaitre <glemaitre>`.
+
 Neighbors
 
 - Fixed a bug so ``predict`` in :class:`neighbors.RadiusNeighborsRegressor` can
@@ -828,6 +835,15 @@ Cluster
   :class:`cluster.AgglomerativeClustering`. :issue:`9875` by :user:`Kumar
   Ashutosh `.
 
+Ensemble
+
+- In classes derived from :class:`ensemble.BaseBagging`, the attribute
+  ``estimators_samples_`` now returns a list of arrays of the indices
+  selected for each bootstrap instead of a list of boolean masks of the
+  selected samples. Unlike a mask, indices can encode samples drawn more
+  than once by a bootstrap. :issue:`9524` by
+  :user:`Guillaume Lemaitre <glemaitre>`.
+
 Imputer
 
 - Deprecate :class:`preprocessing.Imputer` and move the corresponding module to
diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py
index 777e25edec064..f82221c959b21 100644
--- a/sklearn/ensemble/bagging.py
+++ b/sklearn/ensemble/bagging.py
@@ -110,7 +110,6 @@ def _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight,
             estimator.fit(X[:, features], y,
                           sample_weight=curr_sample_weight)
 
-        # Draw samples, using a mask, and then fit
         else:
             estimator.fit((X[indices])[:, features], y[indices])
 
@@ -412,7 +411,7 @@ def _get_estimators_indices(self):
     def estimators_samples_(self):
         """The subset of drawn samples for each base estimator.
 
-        Returns a dynamically generated list of boolean masks identifying
+        Returns a dynamically generated list of indices identifying
         the samples used for fitting each member of the ensemble, i.e.,
         the in-bag samples.
 
@@ -420,12 +419,8 @@ def estimators_samples_(self):
         to reduce the object memory footprint by not storing the sampling
         data. Thus fetching the property may be slower than expected.
         """
-        sample_masks = []
-        for _, sample_indices in self._get_estimators_indices():
-            mask = indices_to_mask(sample_indices, self._n_samples)
-            sample_masks.append(mask)
-
-        return sample_masks
+        return [sample_indices
+                for _, sample_indices in self._get_estimators_indices()]
 
 
 class BaggingClassifier(BaseBagging, ClassifierMixin):
@@ -512,7 +507,7 @@ class BaggingClassifier(BaseBagging, ClassifierMixin):
 
     estimators_samples_ : list of arrays
         The subset of drawn samples (i.e., the in-bag samples) for each base
-        estimator. Each subset is defined by a boolean mask.
+        estimator. Each subset is defined by an array of the selected indices.
 
     estimators_features_ : list of arrays
         The subset of drawn features for each base estimator.
@@ -590,7 +585,7 @@ def _set_oob_score(self, X, y):
                                              self.estimators_samples_,
                                              self.estimators_features_):
             # Create mask for OOB samples
-            mask = ~samples
+            mask = ~indices_to_mask(samples, n_samples)
 
             if hasattr(estimator, "predict_proba"):
                 predictions[mask, :] += estimator.predict_proba(
@@ -885,7 +880,7 @@ class BaggingRegressor(BaseBagging, RegressorMixin):
 
     estimators_samples_ : list of arrays
         The subset of drawn samples (i.e., the in-bag samples) for each base
-        estimator. Each subset is defined by a boolean mask.
+        estimator. Each subset is defined by an array of the selected indices.
 
     estimators_features_ : list of arrays
         The subset of drawn features for each base estimator.
 
@@ -996,7 +991,7 @@ def _set_oob_score(self, X, y):
                                              self.estimators_samples_,
                                              self.estimators_features_):
             # Create mask for OOB samples
-            mask = ~samples
+            mask = ~indices_to_mask(samples, n_samples)
 
             predictions[mask] += estimator.predict((X[mask, :])[:, features])
             n_predictions[mask] += 1
diff --git a/sklearn/ensemble/tests/test_bagging.py b/sklearn/ensemble/tests/test_bagging.py
index 23884839e8fa5..505ec2f17b248 100644
--- a/sklearn/ensemble/tests/test_bagging.py
+++ b/sklearn/ensemble/tests/test_bagging.py
@@ -29,11 +29,12 @@
 from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
 from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
 from sklearn.svm import SVC, SVR
+from sklearn.random_projection import SparseRandomProjection
 from sklearn.pipeline import make_pipeline
 from sklearn.feature_selection import SelectKBest
 from sklearn.model_selection import train_test_split
 from sklearn.datasets import load_boston, load_iris, make_hastie_10_2
-from sklearn.utils import check_random_state
+from sklearn.utils import check_random_state, hash
 from sklearn.preprocessing import FunctionTransformer
 
 from scipy.sparse import csc_matrix, csr_matrix
@@ -222,6 +223,13 @@ def fit(self, X, y):
     assert_array_almost_equal(sparse_results, dense_results)
 
 
+class DummySizeEstimator(BaseEstimator):
+
+    def fit(self, X, y):
+        self.training_size_ = X.shape[0]
+        self.training_hash_ = hash(X)
+
+
 def test_bootstrap_samples():
     # Test that bootstrapping samples generate non-perfect base estimators.
     rng = check_random_state(0)
@@ -249,6 +257,17 @@ def test_bootstrap_samples():
     assert_greater(base_estimator.score(X_train, y_train),
                    ensemble.score(X_train, y_train))
 
+    # Check that each sampling corresponds to a complete bootstrap resample:
+    # each bootstrap sample should have the same size as the input data, but
+    # the resampled data should differ (checked using the hash of the data).
+    ensemble = BaggingRegressor(base_estimator=DummySizeEstimator(),
+                                bootstrap=True).fit(X_train, y_train)
+    training_hash = []
+    for estimator in ensemble.estimators_:
+        assert estimator.training_size_ == X_train.shape[0]
+        training_hash.append(estimator.training_hash_)
+    assert len(set(training_hash)) == len(training_hash)
+
 
 def test_bootstrap_features():
     # Test that bootstrapping features may generate duplicate features.
@@ -710,8 +729,8 @@ def test_estimators_samples():
 
     # Test for correct formatting
     assert_equal(len(estimators_samples), len(estimators))
-    assert_equal(len(estimators_samples[0]), len(X))
-    assert_equal(estimators_samples[0].dtype.kind, 'b')
+    assert_equal(len(estimators_samples[0]), len(X) // 2)
+    assert_equal(estimators_samples[0].dtype.kind, 'i')
 
     # Re-fit single estimator to test for consistent sampling
     estimator_index = 0
@@ -729,6 +748,34 @@ def test_estimators_samples():
     assert_array_almost_equal(orig_coefs, new_coefs)
 
 
+def test_estimators_samples_deterministic():
+    # Regression test checking that, with a random step
+    # (e.g. SparseRandomProjection) and a given random state, the results
+    # generated at fit time can be identically reproduced later using
+    # data saved in object attributes. See issue #9524 for full discussion.
+
+    iris = load_iris()
+    X, y = iris.data, iris.target
+
+    base_pipeline = make_pipeline(SparseRandomProjection(n_components=2),
+                                  LogisticRegression())
+    clf = BaggingClassifier(base_estimator=base_pipeline,
+                            max_samples=0.5,
+                            random_state=0)
+    clf.fit(X, y)
+    pipeline_estimator_coef = clf.estimators_[0].steps[-1][1].coef_.copy()
+
+    estimator = clf.estimators_[0]
+    estimator_sample = clf.estimators_samples_[0]
+    estimator_feature = clf.estimators_features_[0]
+
+    X_train = (X[estimator_sample])[:, estimator_feature]
+    y_train = y[estimator_sample]
+
+    estimator.fit(X_train, y_train)
+    assert_array_equal(estimator.steps[-1][1].coef_, pipeline_estimator_coef)
+
+
 def test_max_samples_consistency():
     # Make sure validated max_samples and original max_samples are identical
     # when valid integer max_samples supplied by user
diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py
index 50963341f8edf..56177fca7a693 100644
--- a/sklearn/utils/__init__.py
+++ b/sklearn/utils/__init__.py
@@ -15,7 +15,7 @@
                          check_consistent_length, check_X_y, indexable,
                          check_symmetric)
 from .class_weight import compute_class_weight, compute_sample_weight
-from ._joblib import cpu_count, Parallel, Memory, delayed
+from ._joblib import cpu_count, Parallel, Memory, delayed, hash
 from ._joblib import parallel_backend
 from ..exceptions import DataConversionWarning
 from ..utils.fixes import _Sequence as Sequence
@@ -29,7 +29,8 @@
            "column_or_1d", "safe_indexing",
            "check_consistent_length", "check_X_y", 'indexable',
            "check_symmetric", "indices_to_mask", "deprecated",
-           "cpu_count", "Parallel", "Memory", "delayed", "parallel_backend"]
+           "cpu_count", "Parallel", "Memory", "delayed", "parallel_backend",
+           "hash"]
 
 
 class Bunch(dict):
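
A minimal sketch of the behavioral change, assuming a scikit-learn build with this patch applied (the dataset and base estimator below are arbitrary illustrations, not part of the patch): ``estimators_samples_`` now yields integer index arrays rather than boolean masks, so the exact bootstrap draw of each member, repetitions included, can be rebuilt from the fitted object.

    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn.ensemble import BaggingClassifier
    from sklearn.tree import DecisionTreeClassifier

    X, y = load_iris(return_X_y=True)
    clf = BaggingClassifier(base_estimator=DecisionTreeClassifier(),
                            random_state=0).fit(X, y)

    # Integer indices (dtype kind 'i') of the bootstrap draw for the first
    # member, one entry per drawn sample; repeated indices are preserved.
    indices = clf.estimators_samples_[0]
    features = clf.estimators_features_[0]
    assert indices.dtype.kind == 'i'

    # Rebuild the exact in-bag data the first member saw at fit time.
    X_inbag = (X[indices])[:, features]
    y_inbag = y[indices]

    # A bootstrap of 150 points almost surely draws duplicates, which the
    # previous boolean-mask representation could not express.
    assert len(np.unique(indices)) < len(indices)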
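The ``_set_oob_score`` hunks above translate the stored indices back into a boolean mask with ``sklearn.utils.indices_to_mask`` before negating it. A sketch of that out-of-bag bookkeeping, under the same assumptions as the previous snippet:

    from sklearn.datasets import load_iris
    from sklearn.ensemble import BaggingClassifier
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.utils import indices_to_mask

    X, y = load_iris(return_X_y=True)
    clf = BaggingClassifier(base_estimator=DecisionTreeClassifier(),
                            random_state=0).fit(X, y)
    n_samples = X.shape[0]

    for estimator, samples, features in zip(clf.estimators_,
                                            clf.estimators_samples_,
                                            clf.estimators_features_):
        # True for the rows this member never saw during fit (out-of-bag).
        oob_mask = ~indices_to_mask(samples, n_samples)
        # Each member can then be evaluated on its own out-of-bag rows.
        print(estimator.score((X[oob_mask])[:, features], y[oob_mask]))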
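The new bootstrap test relies on ``hash``, which this patch re-exports from joblib through ``sklearn.utils``; it is a content-based hash, so it can tell two bootstrap resamples apart. A small sketch (the alias is only to avoid shadowing the ``hash`` builtin):

    import numpy as np
    from sklearn.utils import hash as data_hash  # re-exported joblib.hash

    a = np.arange(6.0)
    b = a.copy()
    c = a[::-1].copy()

    # Equal contents hash equal, regardless of object identity...
    assert data_hash(a) == data_hash(b)
    # ...while different contents hash differently.
    assert data_hash(a) != data_hash(c)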