From b2339500d156153ebc7263ee10ad7ea479701095 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sun, 10 Sep 2017 00:26:35 +0200 Subject: [PATCH 01/14] FIX enforce deterministic behaviour in BaseBagging --- doc/whats_new/v0.20.rst | 9 +++++++- sklearn/ensemble/bagging.py | 7 ++++++- sklearn/ensemble/tests/test_bagging.py | 29 ++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 2 deletions(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 51d2fab65be81..f4ae2c3158e71 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -40,7 +40,7 @@ Classifiers and regressors - Added :class:`naive_bayes.ComplementNB`, which implements the Complement Naive Bayes classifier described in Rennie et al. (2003). By :user:`Michael A. Alcorn `. - + Model evaluation - Added the :func:`metrics.balanced_accuracy` metric and a corresponding @@ -122,6 +122,13 @@ Decomposition, manifold learning and clustering with large datasets when ``n_components='mle'`` on Python 3 versions. :issue:`9886` by :user:`Hanmin Qin `. +Ensemble + +- Fix allowing to obtain deterministic with :class:`BaseBagging` estimator, + when comparing results generated at fit time with the one using the object + attributes when ``random_state`` is set. :issue:`9723` by :user:`Guillaume + Lemaitre `. + Metrics - Fixed a bug due to floating point error in :func:`metrics.roc_auc_score` with diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py index 7ea3030bdf120..bc039cf124269 100644 --- a/sklearn/ensemble/bagging.py +++ b/sklearn/ensemble/bagging.py @@ -95,6 +95,11 @@ def _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight, n_samples, max_features, max_samples) + # Convert sample indices to a masked array; it is the same behavior + # than estimators_samples_ and necessary to obtain deterministic + # results when the random state is set. + sample_mask = indices_to_mask(indices, n_samples) + # Draw samples, using sample weights, and then fit if support_sample_weight: if sample_weight is None: @@ -113,7 +118,7 @@ def _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight, # Draw samples, using a mask, and then fit else: - estimator.fit((X[indices])[:, features], y[indices]) + estimator.fit((X[sample_mask])[:, features], y[sample_mask]) estimators.append(estimator) estimators_features.append(features) diff --git a/sklearn/ensemble/tests/test_bagging.py b/sklearn/ensemble/tests/test_bagging.py index 50820d4512b5b..e42879e2c74f4 100644 --- a/sklearn/ensemble/tests/test_bagging.py +++ b/sklearn/ensemble/tests/test_bagging.py @@ -28,6 +28,7 @@ from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from sklearn.svm import SVC, SVR +from sklearn.random_projection import SparseRandomProjection from sklearn.pipeline import make_pipeline from sklearn.feature_selection import SelectKBest from sklearn.model_selection import train_test_split @@ -721,6 +722,34 @@ def test_estimators_samples(): assert_array_almost_equal(orig_coefs, new_coefs) +def test_estimators_samples_deterministic(): + # This test is a regression test to check that with a random step + # (e.g. SparseRandomProjection) and a given random state, the results + # generated at fit time can be identically reproduced at a later time using + # data saved in object attributes. 
+ + iris = load_iris() + X, y = iris.data, iris.target + + base_pipeline = make_pipeline(SparseRandomProjection(n_components=2), + LogisticRegression()) + clf = BaggingClassifier(base_estimator=base_pipeline, + max_samples=0.5, + random_state=0) + clf.fit(X, y) + pipeline_estimator_coef = clf.estimators_[0].steps[-1][1].coef_.copy() + + estimator = clf.estimators_[0] + estimator_sample = clf.estimators_samples_[0] + estimator_feature = clf.estimators_features_[0] + + X_train = (X[estimator_sample])[:, estimator_feature] + y_train = y[estimator_sample] + + estimator.fit(X_train, y_train) + assert_array_equal(estimator.steps[-1][1].coef_, pipeline_estimator_coef) + + def test_max_samples_consistency(): # Make sure validated max_samples and original max_samples are identical # when valid integer max_samples supplied by user From 1f62a49c8347a2ad79b8add1966c0b15b83fb100 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sun, 10 Sep 2017 12:36:36 +0200 Subject: [PATCH 02/14] DOC add entry in whats new for model changes --- doc/whats_new/v0.20.rst | 4 ++++ sklearn/ensemble/bagging.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index f4ae2c3158e71..3ab5693817843 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -19,6 +19,10 @@ random sampling procedures. - :class:`isotonic.IsotonicRegression` (bug fix) - :class:`metrics.roc_auc_score` (bug fix) +- Classes dervied from :class:`ensemble.BaseBagging`. The base estimator may be + presented with the same samples in a different order in previous versions; + now the samples will always be in input order. + Details are listed in the changelog below. (While we are trying to better inform users by providing this information, we diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py index bc039cf124269..587f2456c3a0a 100644 --- a/sklearn/ensemble/bagging.py +++ b/sklearn/ensemble/bagging.py @@ -96,7 +96,7 @@ def _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight, max_samples) # Convert sample indices to a masked array; it is the same behavior - # than estimators_samples_ and necessary to obtain deterministic + # as estimators_samples_ and necessary to obtain deterministic # results when the random state is set. sample_mask = indices_to_mask(indices, n_samples) From 9ea85d88b01b35706934535dd0780a171a8780a5 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 20 Oct 2017 13:03:26 +0200 Subject: [PATCH 03/14] add comments --- sklearn/ensemble/tests/test_bagging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/tests/test_bagging.py b/sklearn/ensemble/tests/test_bagging.py index e42879e2c74f4..39fcbb632f2e1 100644 --- a/sklearn/ensemble/tests/test_bagging.py +++ b/sklearn/ensemble/tests/test_bagging.py @@ -726,7 +726,7 @@ def test_estimators_samples_deterministic(): # This test is a regression test to check that with a random step # (e.g. SparseRandomProjection) and a given random state, the results # generated at fit time can be identically reproduced at a later time using - # data saved in object attributes. + # data saved in object attributes. Check issue #9524 for full discussion. 
iris = load_iris() X, y = iris.data, iris.target From c3f2471f75668d40178d2bb8a4d74e6cca91009a Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 20 Oct 2017 13:05:55 +0200 Subject: [PATCH 04/14] revert space --- doc/whats_new/v0.20.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 3ab5693817843..5380b1a882f4f 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -40,7 +40,7 @@ Classifiers and regressors :class:`ensemble.GradientBoostingRegressor` now support early stopping via ``n_iter_no_change``, ``validation_fraction`` and ``tol``. :issue:`7071` by `Raghav RV`_ - + - Added :class:`naive_bayes.ComplementNB`, which implements the Complement Naive Bayes classifier described in Rennie et al. (2003). By :user:`Michael A. Alcorn `. From 77be464333d99eb6acf61e82866c97056290fb81 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 20 Oct 2017 13:06:28 +0200 Subject: [PATCH 05/14] revert space right place --- doc/whats_new/v0.20.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 5380b1a882f4f..799b777def2ee 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -40,11 +40,11 @@ Classifiers and regressors :class:`ensemble.GradientBoostingRegressor` now support early stopping via ``n_iter_no_change``, ``validation_fraction`` and ``tol``. :issue:`7071` by `Raghav RV`_ - + - Added :class:`naive_bayes.ComplementNB`, which implements the Complement Naive Bayes classifier described in Rennie et al. (2003). By :user:`Michael A. Alcorn `. - + Model evaluation - Added the :func:`metrics.balanced_accuracy` metric and a corresponding From d9bf029178262cdc60c2a26f1671a11f571e9040 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 20 Oct 2017 16:58:03 +0200 Subject: [PATCH 06/14] Typo --- doc/whats_new/v0.20.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 799b777def2ee..dba79820ae90d 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -19,7 +19,7 @@ random sampling procedures. - :class:`isotonic.IsotonicRegression` (bug fix) - :class:`metrics.roc_auc_score` (bug fix) -- Classes dervied from :class:`ensemble.BaseBagging`. The base estimator may be +- Classes derived from :class:`ensemble.BaseBagging`. The base estimator may be presented with the same samples in a different order in previous versions; now the samples will always be in input order. From f3a612ea1a171140bed0b7d80fa98fe002b73ac5 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 23 Oct 2017 10:44:52 +0200 Subject: [PATCH 07/14] FIX change estimators_samples_ --- doc/whats_new/v0.20.rst | 6 +----- sklearn/ensemble/bagging.py | 22 ++++++---------------- sklearn/ensemble/tests/test_bagging.py | 4 ++-- 3 files changed, 9 insertions(+), 23 deletions(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index dba79820ae90d..f4ae2c3158e71 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -19,10 +19,6 @@ random sampling procedures. - :class:`isotonic.IsotonicRegression` (bug fix) - :class:`metrics.roc_auc_score` (bug fix) -- Classes derived from :class:`ensemble.BaseBagging`. The base estimator may be - presented with the same samples in a different order in previous versions; - now the samples will always be in input order. - Details are listed in the changelog below. 
(While we are trying to better inform users by providing this information, we @@ -44,7 +40,7 @@ Classifiers and regressors - Added :class:`naive_bayes.ComplementNB`, which implements the Complement Naive Bayes classifier described in Rennie et al. (2003). By :user:`Michael A. Alcorn `. - + Model evaluation - Added the :func:`metrics.balanced_accuracy` metric and a corresponding diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py index 587f2456c3a0a..305dbde904bac 100644 --- a/sklearn/ensemble/bagging.py +++ b/sklearn/ensemble/bagging.py @@ -95,11 +95,6 @@ def _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight, n_samples, max_features, max_samples) - # Convert sample indices to a masked array; it is the same behavior - # as estimators_samples_ and necessary to obtain deterministic - # results when the random state is set. - sample_mask = indices_to_mask(indices, n_samples) - # Draw samples, using sample weights, and then fit if support_sample_weight: if sample_weight is None: @@ -116,9 +111,8 @@ def _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight, estimator.fit(X[:, features], y, sample_weight=curr_sample_weight) - # Draw samples, using a mask, and then fit else: - estimator.fit((X[sample_mask])[:, features], y[sample_mask]) + estimator.fit((X[indices])[:, features], y[indices]) estimators.append(estimator) estimators_features.append(features) @@ -415,7 +409,7 @@ def _get_estimators_indices(self): def estimators_samples_(self): """The subset of drawn samples for each base estimator. - Returns a dynamically generated list of boolean masks identifying + Returns a dynamically generated list of indices identifying the samples used for fitting each member of the ensemble, i.e., the in-bag samples. @@ -423,12 +417,8 @@ def estimators_samples_(self): to reduce the object memory footprint by not storing the sampling data. Thus fetching the property may be slower than expected. 
""" - sample_masks = [] - for _, sample_indices in self._get_estimators_indices(): - mask = indices_to_mask(sample_indices, self._n_samples) - sample_masks.append(mask) - - return sample_masks + return [sample_indices + for _, sample_indices in self._get_estimators_indices()] class BaggingClassifier(BaseBagging, ClassifierMixin): @@ -591,7 +581,7 @@ def _set_oob_score(self, X, y): self.estimators_samples_, self.estimators_features_): # Create mask for OOB samples - mask = ~samples + mask = ~indices_to_mask(samples, n_samples) if hasattr(estimator, "predict_proba"): predictions[mask, :] += estimator.predict_proba( @@ -983,7 +973,7 @@ def _set_oob_score(self, X, y): self.estimators_samples_, self.estimators_features_): # Create mask for OOB samples - mask = ~samples + mask = ~indices_to_mask(samples, n_samples) predictions[mask] += estimator.predict((X[mask, :])[:, features]) n_predictions[mask] += 1 diff --git a/sklearn/ensemble/tests/test_bagging.py b/sklearn/ensemble/tests/test_bagging.py index 39fcbb632f2e1..3290e5792951c 100644 --- a/sklearn/ensemble/tests/test_bagging.py +++ b/sklearn/ensemble/tests/test_bagging.py @@ -703,8 +703,8 @@ def test_estimators_samples(): # Test for correct formatting assert_equal(len(estimators_samples), len(estimators)) - assert_equal(len(estimators_samples[0]), len(X)) - assert_equal(estimators_samples[0].dtype.kind, 'b') + assert_equal(len(estimators_samples[0]), len(X) // 2) + assert_equal(estimators_samples[0].dtype.kind, 'i') # Re-fit single estimator to test for consistent sampling estimator_index = 0 From 146a9b1ed88dfafb8e250dc8951848f1d1505cfe Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 25 Oct 2017 14:17:39 +0200 Subject: [PATCH 08/14] TST check size training set and hash during bootstrapping --- sklearn/ensemble/tests/test_bagging.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/sklearn/ensemble/tests/test_bagging.py b/sklearn/ensemble/tests/test_bagging.py index 3290e5792951c..b9d5a66eb9465 100644 --- a/sklearn/ensemble/tests/test_bagging.py +++ b/sklearn/ensemble/tests/test_bagging.py @@ -34,6 +34,7 @@ from sklearn.model_selection import train_test_split from sklearn.datasets import load_boston, load_iris, make_hastie_10_2 from sklearn.utils import check_random_state +from sklearn.externals.joblib import hash from scipy.sparse import csc_matrix, csr_matrix @@ -219,6 +220,13 @@ def fit(self, X, y): assert_array_almost_equal(sparse_results, dense_results) +class DummySizeEstimator(BaseEstimator): + + def fit(self, X, y): + self.training_size = X.shape[0] + self.training_hash = hash(X) + + def test_bootstrap_samples(): # Test that bootstrapping samples generate non-perfect base estimators. rng = check_random_state(0) @@ -246,6 +254,16 @@ def test_bootstrap_samples(): assert_greater(base_estimator.score(X_train, y_train), ensemble.score(X_train, y_train)) + # check as well the size of the training set to be sure that only + # under-sampling is performed + ensemble = BaggingRegressor(base_estimator=DummySizeEstimator(), + bootstrap=True).fit(X_train, y_train) + training_hash = [] + for estimator in ensemble.estimators_: + assert estimator.training_size == X_train.shape[0] + training_hash.append(estimator.training_hash) + assert len(set(training_hash)) == len(training_hash) + def test_bootstrap_features(): # Test that bootstrapping features may generate duplicate features. 
From 02edd0e5d9dcd2ad81ecb67c546d82377d0ee265 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 30 Oct 2017 11:52:22 +0100 Subject: [PATCH 09/14] DOC/TST API entry in doc and style correction --- sklearn/ensemble/bagging.py | 4 ++-- sklearn/ensemble/tests/test_bagging.py | 13 +++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py index 305dbde904bac..2bc09b1990145 100644 --- a/sklearn/ensemble/bagging.py +++ b/sklearn/ensemble/bagging.py @@ -503,7 +503,7 @@ class BaggingClassifier(BaseBagging, ClassifierMixin): estimators_samples_ : list of arrays The subset of drawn samples (i.e., the in-bag samples) for each base - estimator. Each subset is defined by a boolean mask. + estimator. Each subset is defined by an array of the indices selected. estimators_features_ : list of arrays The subset of drawn features for each base estimator. @@ -865,7 +865,7 @@ class BaggingRegressor(BaseBagging, RegressorMixin): estimators_samples_ : list of arrays The subset of drawn samples (i.e., the in-bag samples) for each base - estimator. Each subset is defined by a boolean mask. + estimator. Each subset is defined by an array of the indices selected. estimators_features_ : list of arrays The subset of drawn features for each base estimator. diff --git a/sklearn/ensemble/tests/test_bagging.py b/sklearn/ensemble/tests/test_bagging.py index b9d5a66eb9465..833cb465b89f2 100644 --- a/sklearn/ensemble/tests/test_bagging.py +++ b/sklearn/ensemble/tests/test_bagging.py @@ -223,8 +223,8 @@ def fit(self, X, y): class DummySizeEstimator(BaseEstimator): def fit(self, X, y): - self.training_size = X.shape[0] - self.training_hash = hash(X) + self.training_size_ = X.shape[0] + self.training_hash_ = hash(X) def test_bootstrap_samples(): @@ -254,14 +254,15 @@ def test_bootstrap_samples(): assert_greater(base_estimator.score(X_train, y_train), ensemble.score(X_train, y_train)) - # check as well the size of the training set to be sure that only - # under-sampling is performed + # check that each sampling correspond to a complete bootstrap resample. + # the size of each bootstrap should be the same as the input data but + # the data should be different (checked using the hash of the data). ensemble = BaggingRegressor(base_estimator=DummySizeEstimator(), bootstrap=True).fit(X_train, y_train) training_hash = [] for estimator in ensemble.estimators_: - assert estimator.training_size == X_train.shape[0] - training_hash.append(estimator.training_hash) + assert estimator.training_size_ == X_train.shape[0] + training_hash.append(estimator.training_hash_) assert len(set(training_hash)) == len(training_hash) From 453a78a6769850e6f16b09d97106bab3c325b0d5 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 31 Oct 2017 13:58:54 +0100 Subject: [PATCH 10/14] DOC move entry in the right section --- doc/whats_new/v0.20.rst | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 3ff4ce9701e43..3d1e0634b7127 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -22,11 +22,6 @@ random sampling procedures. - :class:`neural_network.MLPRegressor` (bug fix) - :class:`neural_network.MLPClassifier` (bug fix) -- Classes derived from :class:`ensemble.BaseBagging`. 
The attribute - ``estimators_samples_`` will return a list of arrays containing the indices - selected for each bootstrap instead of a list of arrays containing the mask - of the samples selected for each bootstrap. - Details are listed in the changelog below. (While we are trying to better inform users by providing this information, we @@ -193,3 +188,10 @@ Cluster - Deprecate ``pooling_func`` unused parameter in :class:`cluster.AgglomerativeClustering`. :issue:`9875` by :user:`Kumar Ashutosh `. + +Ensemble + +- Classes derived from :class:`ensemble.BaseBagging`. The attribute + ``estimators_samples_`` will return a list of arrays containing the indices + selected for each bootstrap instead of a list of arrays containing the mask + of the samples selected for each bootstrap. From 1da1ee4a36031233d2681563fe219582a0e91257 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 31 Oct 2017 21:23:44 +0100 Subject: [PATCH 11/14] DOC add missing author and issue --- doc/whats_new/v0.20.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 3d1e0634b7127..1767f82ea0d17 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -194,4 +194,5 @@ Ensemble - Classes derived from :class:`ensemble.BaseBagging`. The attribute ``estimators_samples_`` will return a list of arrays containing the indices selected for each bootstrap instead of a list of arrays containing the mask - of the samples selected for each bootstrap. + of the samples selected for each bootstrap. :issue:`9524` by + :user:`Guillaume Lemaitre `. From 043b44e58d20ff34f5084c39e968f89afbf32a11 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 31 Oct 2017 22:57:01 +0100 Subject: [PATCH 12/14] DOC add details in whats new entry --- doc/whats_new/v0.20.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 1767f82ea0d17..3617067b74f83 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -194,5 +194,6 @@ Ensemble - Classes derived from :class:`ensemble.BaseBagging`. The attribute ``estimators_samples_`` will return a list of arrays containing the indices selected for each bootstrap instead of a list of arrays containing the mask - of the samples selected for each bootstrap. :issue:`9524` by + of the samples selected for each bootstrap. Indices allow samples to be + repeated, which a boolean mask cannot represent. :issue:`9524` by :user:`Guillaume Lemaitre `.
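The changelog entry above is the user-facing contract of this change: ``estimators_samples_`` now yields integer index arrays (possibly containing repeated entries when ``bootstrap=True``) rather than boolean masks. Below is a minimal sketch of how calling code can adapt, assuming a scikit-learn build that includes these patches; the helper ``sklearn.utils.indices_to_mask`` used here is the same one ``_set_oob_score`` now relies on to recover the out-of-bag mask::

    from sklearn.datasets import load_iris
    from sklearn.ensemble import BaggingClassifier
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.utils import indices_to_mask

    iris = load_iris()
    X, y = iris.data, iris.target

    clf = BaggingClassifier(base_estimator=DecisionTreeClassifier(),
                            n_estimators=5, max_samples=0.5,
                            bootstrap=True, random_state=0).fit(X, y)

    # Each entry is now an array of drawn sample indices (possibly with
    # repeats), not a boolean mask over the training set.
    indices = clf.estimators_samples_[0]
    assert indices.dtype.kind == 'i'
    assert len(indices) == len(X) // 2    # max_samples=0.5 of 150 samples

    # The out-of-bag samples of that ensemble member can still be
    # recovered by converting the indices back into a mask.
    oob_mask = ~indices_to_mask(indices, len(X))
    X_oob = X[oob_mask]
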
From 1e79e229e3deca9c7015a1affd033b6335ccb9b3 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 17 Jul 2018 23:00:35 +0200 Subject: [PATCH 13/14] FIX add hash from joblib in utils --- sklearn/ensemble/tests/test_bagging.py | 3 +-- sklearn/utils/__init__.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/ensemble/tests/test_bagging.py b/sklearn/ensemble/tests/test_bagging.py index 3781e0b5c2d0e..56d56c69163ff 100644 --- a/sklearn/ensemble/tests/test_bagging.py +++ b/sklearn/ensemble/tests/test_bagging.py @@ -34,8 +34,7 @@ from sklearn.feature_selection import SelectKBest from sklearn.model_selection import train_test_split from sklearn.datasets import load_boston, load_iris, make_hastie_10_2 -from sklearn.utils import check_random_state -from sklearn.externals.joblib import hash +from sklearn.utils import check_random_state, hash from sklearn.preprocessing import FunctionTransformer from scipy.sparse import csc_matrix, csr_matrix diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 50963341f8edf..a5c76dd97780d 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -15,7 +15,7 @@ check_consistent_length, check_X_y, indexable, check_symmetric) from .class_weight import compute_class_weight, compute_sample_weight -from ._joblib import cpu_count, Parallel, Memory, delayed +from ._joblib import cpu_count, Parallel, Memory, delayed, hash from ._joblib import parallel_backend from ..exceptions import DataConversionWarning from ..utils.fixes import _Sequence as Sequence From ca56b861fb4604c25de32c09c0508ad054bf4947 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 17 Jul 2018 23:05:23 +0200 Subject: [PATCH 14/14] iter --- sklearn/utils/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index a5c76dd97780d..56177fca7a693 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -29,7 +29,8 @@ "column_or_1d", "safe_indexing", "check_consistent_length", "check_X_y", 'indexable', "check_symmetric", "indices_to_mask", "deprecated", - "cpu_count", "Parallel", "Memory", "delayed", "parallel_backend"] + "cpu_count", "Parallel", "Memory", "delayed", "parallel_backend", + "hash"] class Bunch(dict):
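
Taken together, the series makes a bagging ensemble reproducible from its stored attributes: when ``random_state`` is set, re-fitting ``estimators_[i]`` on ``X[estimators_samples_[i]][:, estimators_features_[i]]`` recovers exactly what was learned at fit time, even when the base estimator contains a random step of its own. The following sketch is adapted from the regression test added in the first patch and assumes a build that includes this series::

    import numpy as np

    from sklearn.datasets import load_iris
    from sklearn.ensemble import BaggingClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import make_pipeline
    from sklearn.random_projection import SparseRandomProjection

    iris = load_iris()
    X, y = iris.data, iris.target

    # The base estimator carries its own random step (SparseRandomProjection).
    base = make_pipeline(SparseRandomProjection(n_components=2),
                         LogisticRegression())
    clf = BaggingClassifier(base_estimator=base, max_samples=0.5,
                            random_state=0).fit(X, y)

    # Coefficients learned by the first ensemble member at fit time.
    fit_time_coef = clf.estimators_[0].steps[-1][1].coef_.copy()

    # Rebuild that member's training set from the stored attributes and refit.
    samples = clf.estimators_samples_[0]
    features = clf.estimators_features_[0]
    clf.estimators_[0].fit(X[samples][:, features], y[samples])

    # With the fix, the re-fit reproduces the fit-time coefficients exactly.
    assert np.array_equal(clf.estimators_[0].steps[-1][1].coef_, fit_time_coef)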