[MRG+2] FIX enforce deterministic behaviour in BaseBagging #9723

Merged (20 commits) on Jul 19, 2018
16 changes: 16 additions & 0 deletions doc/whats_new/v0.20.rst
@@ -654,6 +654,13 @@ Metrics
:issue:`9515` by :user:`Alan Liddell <aliddell>` and
:user:`Manh Dao <manhdao>`.

Ensemble

- Fixed :class:`BaseBagging` so that, when ``random_state`` is set, the
  results generated at fit time can be identically reproduced later from the
  fitted object's attributes (``estimators_samples_`` and
  ``estimators_features_``).
  :issue:`9723` by :user:`Guillaume Lemaitre <glemaitre>`.

Neighbors

- Fixed a bug so ``predict`` in :class:`neighbors.RadiusNeighborsRegressor` can
@@ -828,6 +835,15 @@ Cluster
:class:`cluster.AgglomerativeClustering`. :issue:`9875` by :user:`Kumar Ashutosh
<thechargedneutron>`.

Ensemble

- Classes derived from :class:`ensemble.BaseBagging` now expose
  ``estimators_samples_`` as a list of arrays of the indices selected for each
  bootstrap, instead of a list of boolean masks. Unlike a mask, indices can
  represent a sample drawn more than once, as a bootstrap requires (a short
  illustration follows this file's diff).
  :issue:`9524` by :user:`Guillaume Lemaitre <glemaitre>`.

Imputer

- Deprecate :class:`preprocessing.Imputer` and move the corresponding module to
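The two ``Ensemble`` entries in this changelog diff describe a behavioural change as well as a bug fix. A minimal, hedged sketch (not part of the PR; the dataset, estimator and parameters are arbitrary) of what the new return type of ``estimators_samples_`` looks like, and of how code that relied on the old boolean masks can recover them with ``sklearn.utils.indices_to_mask``:

from sklearn.datasets import load_iris
from sklearn.ensemble import BaggingClassifier
from sklearn.utils import indices_to_mask

X, y = load_iris(return_X_y=True)
clf = BaggingClassifier(n_estimators=3, random_state=0).fit(X, y)

# New behaviour: integer indices, possibly containing repeated samples.
indices = clf.estimators_samples_[0]
assert indices.dtype.kind == 'i'

# The old-style boolean mask (which cannot express repeats) can still be derived.
mask = indices_to_mask(indices, X.shape[0])
assert mask.dtype == bool

Because repeats are preserved, the stored indices describe each bootstrap draw exactly, which is what makes a fit reproducible from the object attributes alone.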
19 changes: 7 additions & 12 deletions sklearn/ensemble/bagging.py
@@ -110,7 +110,6 @@ def _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight,

estimator.fit(X[:, features], y, sample_weight=curr_sample_weight)

# Draw samples, using a mask, and then fit
else:
estimator.fit((X[indices])[:, features], y[indices])

@@ -412,20 +411,16 @@ def _get_estimators_indices(self):
def estimators_samples_(self):
"""The subset of drawn samples for each base estimator.

Returns a dynamically generated list of boolean masks identifying
Returns a dynamically generated list of indices identifying
the samples used for fitting each member of the ensemble, i.e.,
the in-bag samples.

Note: the list is re-created at each call to the property in order
to reduce the object memory footprint by not storing the sampling
data. Thus fetching the property may be slower than expected.
"""
sample_masks = []
for _, sample_indices in self._get_estimators_indices():
mask = indices_to_mask(sample_indices, self._n_samples)
sample_masks.append(mask)

return sample_masks
return [sample_indices
for _, sample_indices in self._get_estimators_indices()]


class BaggingClassifier(BaseBagging, ClassifierMixin):
@@ -512,7 +507,7 @@ class BaggingClassifier(BaseBagging, ClassifierMixin):

estimators_samples_ : list of arrays
The subset of drawn samples (i.e., the in-bag samples) for each base
estimator. Each subset is defined by a boolean mask.
estimator. Each subset is defined by an array of the indices selected.

estimators_features_ : list of arrays
The subset of drawn features for each base estimator.
@@ -590,7 +585,7 @@ def _set_oob_score(self, X, y):
self.estimators_samples_,
self.estimators_features_):
# Create mask for OOB samples
mask = ~samples
mask = ~indices_to_mask(samples, n_samples)

if hasattr(estimator, "predict_proba"):
predictions[mask, :] += estimator.predict_proba(
@@ -885,7 +880,7 @@ class BaggingRegressor(BaseBagging, RegressorMixin):

estimators_samples_ : list of arrays
The subset of drawn samples (i.e., the in-bag samples) for each base
estimator. Each subset is defined by a boolean mask.
estimator. Each subset is defined by an array of the indices selected.

estimators_features_ : list of arrays
The subset of drawn features for each base estimator.
@@ -996,7 +991,7 @@ def _set_oob_score(self, X, y):
self.estimators_samples_,
self.estimators_features_):
# Create mask for OOB samples
mask = ~samples
mask = ~indices_to_mask(samples, n_samples)

predictions[mask] += estimator.predict((X[mask, :])[:, features])
n_predictions[mask] += 1
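Read together, the rewritten ``estimators_samples_`` property and the updated ``_set_oob_score`` methods above let a user rebuild the exact in-bag data for any ensemble member and derive its out-of-bag (OOB) samples. A hedged sketch, not taken from the PR (estimator and parameters are illustrative):

from sklearn.datasets import load_iris
from sklearn.ensemble import BaggingClassifier
from sklearn.utils import indices_to_mask

X, y = load_iris(return_X_y=True)
clf = BaggingClassifier(max_samples=0.5, random_state=0).fit(X, y)

samples = clf.estimators_samples_[0]    # in-bag indices, repeats allowed
features = clf.estimators_features_[0]  # feature indices for this member
X_inbag = (X[samples])[:, features]     # the data the first estimator saw

# OOB samples are those never drawn, computed as in _set_oob_score above.
oob_mask = ~indices_to_mask(samples, X.shape[0])
print(X_inbag.shape, int(oob_mask.sum()))

Refitting ``clf.estimators_[0]`` on this reconstructed data reproduces the original fit when ``random_state`` is set, which is exactly what the new deterministic-behaviour test in the next file checks (there with a pipeline that includes a ``SparseRandomProjection`` step).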
53 changes: 50 additions & 3 deletions sklearn/ensemble/tests/test_bagging.py
@@ -29,11 +29,12 @@
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.svm import SVC, SVR
from sklearn.random_projection import SparseRandomProjection
from sklearn.pipeline import make_pipeline
from sklearn.feature_selection import SelectKBest
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_boston, load_iris, make_hastie_10_2
from sklearn.utils import check_random_state
from sklearn.utils import check_random_state, hash
from sklearn.preprocessing import FunctionTransformer

from scipy.sparse import csc_matrix, csr_matrix
@@ -222,6 +223,13 @@ def fit(self, X, y):
assert_array_almost_equal(sparse_results, dense_results)


class DummySizeEstimator(BaseEstimator):

    def fit(self, X, y):
        self.training_size_ = X.shape[0]
        self.training_hash_ = hash(X)


def test_bootstrap_samples():
# Test that bootstrapping samples generate non-perfect base estimators.
rng = check_random_state(0)
@@ -249,6 +257,17 @@ def test_bootstrap_samples():
assert_greater(base_estimator.score(X_train, y_train),
ensemble.score(X_train, y_train))

    # Check that each sampling corresponds to a complete bootstrap resample:
    # each bootstrap has the same size as the input data, but the resampled
    # data differ (checked using the hash of the data).
    ensemble = BaggingRegressor(base_estimator=DummySizeEstimator(),
                                bootstrap=True).fit(X_train, y_train)
    training_hash = []
    for estimator in ensemble.estimators_:
        assert estimator.training_size_ == X_train.shape[0]
        training_hash.append(estimator.training_hash_)
    assert len(set(training_hash)) == len(training_hash)


def test_bootstrap_features():
# Test that bootstrapping features may generate duplicate features.
@@ -710,8 +729,8 @@ def test_estimators_samples():

# Test for correct formatting
assert_equal(len(estimators_samples), len(estimators))
assert_equal(len(estimators_samples[0]), len(X))
assert_equal(estimators_samples[0].dtype.kind, 'b')
assert_equal(len(estimators_samples[0]), len(X) // 2)
assert_equal(estimators_samples[0].dtype.kind, 'i')

# Re-fit single estimator to test for consistent sampling
estimator_index = 0
@@ -729,6 +748,34 @@ def test_estimators_samples():
assert_array_almost_equal(orig_coefs, new_coefs)


def test_estimators_samples_deterministic():
    # This test is a regression test to check that with a random step
    # (e.g. SparseRandomProjection) and a given random state, the results
    # generated at fit time can be identically reproduced at a later time using
    # data saved in object attributes. Check issue #9524 for full discussion.

    iris = load_iris()
    X, y = iris.data, iris.target

    base_pipeline = make_pipeline(SparseRandomProjection(n_components=2),
                                  LogisticRegression())
    clf = BaggingClassifier(base_estimator=base_pipeline,
                            max_samples=0.5,
                            random_state=0)
    clf.fit(X, y)
    pipeline_estimator_coef = clf.estimators_[0].steps[-1][1].coef_.copy()

    estimator = clf.estimators_[0]
    estimator_sample = clf.estimators_samples_[0]
    estimator_feature = clf.estimators_features_[0]

    X_train = (X[estimator_sample])[:, estimator_feature]
    y_train = y[estimator_sample]

    estimator.fit(X_train, y_train)
    assert_array_equal(estimator.steps[-1][1].coef_, pipeline_estimator_coef)


def test_max_samples_consistency():
# Make sure validated max_samples and original max_samples are identical
# when valid integer max_samples supplied by user
5 changes: 3 additions & 2 deletions sklearn/utils/__init__.py
@@ -15,7 +15,7 @@
check_consistent_length, check_X_y, indexable,
check_symmetric)
from .class_weight import compute_class_weight, compute_sample_weight
from ._joblib import cpu_count, Parallel, Memory, delayed
from ._joblib import cpu_count, Parallel, Memory, delayed, hash
from ._joblib import parallel_backend
from ..exceptions import DataConversionWarning
from ..utils.fixes import _Sequence as Sequence
@@ -29,7 +29,8 @@
"column_or_1d", "safe_indexing",
"check_consistent_length", "check_X_y", 'indexable',
"check_symmetric", "indices_to_mask", "deprecated",
"cpu_count", "Parallel", "Memory", "delayed", "parallel_backend"]
"cpu_count", "Parallel", "Memory", "delayed", "parallel_backend",
"hash"]


class Bunch(dict):
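Finally, a small, hedged illustration (not part of the PR) of the ``hash`` helper re-exported above from ``sklearn.utils._joblib``: it digests array contents, which is how ``DummySizeEstimator`` in the new test can confirm that different bootstrap resamples really contain different data.

import numpy as np
from sklearn.utils import hash  # re-exported joblib.hash

a = np.arange(10)
b = a.copy()
c = a[::-1].copy()

assert hash(a) == hash(b)  # identical contents -> identical digest
assert hash(a) != hash(c)  # different contents -> different digest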