[MRG+1] ENH Add bootstrap sample size limit to forest ensembles #14682

Merged: 31 commits, Sep 20, 2019
Changes shown are from 21 of the 31 commits.

Commits
9b66811
Add max_samples bootstrap size kwarg
notmatthancock Aug 18, 2019
b845709
Refactor unit tests
notmatthancock Aug 18, 2019
6db3384
Add one more assert check in index helper test
notmatthancock Aug 18, 2019
2507d5b
Move validation and bootstrap size get to helper
notmatthancock Aug 18, 2019
4fa10a9
Compute bootstrap size just once
notmatthancock Aug 18, 2019
3eb299c
n_bootstrap_samples -> n_samples_bootstrap; update docstring
notmatthancock Aug 20, 2019
081b7b7
Refactor exception tests for max_samples
notmatthancock Aug 20, 2019
cf8b0cb
n_bootstrap_samples -> n_samples_bootstrap in signatures
notmatthancock Aug 20, 2019
49a949a
Add max_samples kwarg to RandomForestRegressor
notmatthancock Aug 20, 2019
0e2b386
Revert doc default
notmatthancock Aug 20, 2019
ef5efe1
Move n_samples_bootstrap out of loop
notmatthancock Aug 21, 2019
07865fd
Merge branch 'master' into feature/rf-subsample
notmatthancock Aug 21, 2019
5bc7796
Add max_samples to RandomTreesEmbedding
notmatthancock Aug 21, 2019
75ce338
Change docstring style for default
notmatthancock Aug 27, 2019
37104d3
Update grammar in docstring
notmatthancock Aug 27, 2019
93e1840
Refactor conditional structures
notmatthancock Aug 27, 2019
005d6ca
Add version added tag; change call style
notmatthancock Aug 27, 2019
eac5ad6
Remove docstring from unit test
notmatthancock Aug 27, 2019
d9de541
Add exception message checks to unit test
notmatthancock Aug 27, 2019
93e5799
Refactor toy data unit test
notmatthancock Aug 27, 2019
9cdd5c2
Add node count check unit test
notmatthancock Aug 27, 2019
c42db3b
Merge remote-tracking branch 'origin/master' into pr/notmatthancock/1…
glemaitre Sep 9, 2019
9d72d36
Limit unit test to RandomForest*
notmatthancock Sep 10, 2019
1622ad6
Merge branch 'feature/rf-subsample' of github.com:notmatthancock/scik…
notmatthancock Sep 10, 2019
c95cecf
Add whats new entry
notmatthancock Sep 11, 2019
fa2be88
Merge branch 'master' into feature/rf-subsample
notmatthancock Sep 11, 2019
4548191
Include bootstrap condition in docstring comments
notmatthancock Sep 12, 2019
58f005c
Rename forest_class -> ForestClass
notmatthancock Sep 12, 2019
7869f87
Escape all the special chars
notmatthancock Sep 13, 2019
e8535d4
Remove extraneous test
notmatthancock Sep 14, 2019
5b72dcc
comments adrin
glemaitre Sep 20, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
161 changes: 137 additions & 24 deletions sklearn/ensemble/forest.py
@@ -40,6 +40,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
# License: BSD 3 clause


import numbers
from warnings import catch_warnings, simplefilter, warn
import threading

@@ -72,17 +73,56 @@ class calls the ``fit`` method of each sub-estimator on random samples
MAX_INT = np.iinfo(np.int32).max


def _generate_sample_indices(random_state, n_samples):
def _get_n_samples_bootstrap(n_samples, max_samples):
"""Get the number of samples in a bootstrap sample.

Parameters
----------
n_samples : int
Number of samples in the dataset.
max_samples : int or float
The maximum number of samples to draw from the total available:
- if float, this indicates a fraction of the total;
- if int, this indicates the exact number of samples;
- if None, this indicates the total number of samples.

Returns
-------
n_samples_bootstrap : int
The total number of samples to draw for the bootstrap sample.
"""
if max_samples is None:
return n_samples

if isinstance(max_samples, numbers.Integral):

Review comment (Member):
I think we should be consistent w.r.t. how we treat these fractions. For instance, in OPTICS we have:

    def _validate_size(size, n_samples, param_name):
        if size <= 0 or (size != int(size) and size > 1):
            raise ValueError('%s must be a positive integer '
                             'or a float between 0 and 1. Got %r' %
                             (param_name, size))
        elif size > n_samples:
            raise ValueError('%s must be no greater than the'
                             ' number of samples (%d). Got %d' %
                             (param_name, n_samples, size))

and there 1 always means 100% of the data, at least in OPTICS. Do we have similar semantics in other places?

Review comment (Member):
With PCA, n_components=1 means 1 component, while n_components < 1 means a fraction.

Review comment (Member):
Excluding 1 from the float range avoids issues with float comparison as well.
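
To make the competing conventions concrete, here is an illustrative sketch of what `_get_n_samples_bootstrap`, as defined in this diff, returns (the values are hypothetical):

    _get_n_samples_bootstrap(n_samples=100, max_samples=None)  # -> 100, use all samples
    _get_n_samples_bootstrap(n_samples=100, max_samples=10)    # -> 10, exact count
    _get_n_samples_bootstrap(n_samples=100, max_samples=0.5)   # -> 50, fraction of the data
    _get_n_samples_bootstrap(n_samples=100, max_samples=1)     # -> 1, the int 1 means one sample...
    _get_n_samples_bootstrap(n_samples=100, max_samples=1.0)   # ...while the float 1.0 raises ValueError

So unlike OPTICS, an integer 1 here means one sample, not 100% of the data.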

if not (1 <= max_samples <= n_samples):
msg = "`max_samples` must be in range 1 to {} but got value {}"
raise ValueError(msg.format(n_samples, max_samples))
return max_samples

if isinstance(max_samples, numbers.Real):
if not (0 < max_samples < 1):
msg = "`max_samples` must be in range (0, 1) but got value {}"
raise ValueError(msg.format(max_samples))
return int(round(n_samples * max_samples))

msg = "`max_samples` should be int or float, but got type '{}'"
raise TypeError(msg.format(type(max_samples)))


def _generate_sample_indices(random_state, n_samples, n_samples_bootstrap):
"""Private function used to _parallel_build_trees function."""

random_instance = check_random_state(random_state)
sample_indices = random_instance.randint(0, n_samples, n_samples)
sample_indices = random_instance.randint(0, n_samples, n_samples_bootstrap)

return sample_indices


def _generate_unsampled_indices(random_state, n_samples):
def _generate_unsampled_indices(random_state, n_samples, n_samples_bootstrap):
"""Private function used to forest._set_oob_score function."""
sample_indices = _generate_sample_indices(random_state, n_samples)
sample_indices = _generate_sample_indices(random_state, n_samples,
n_samples_bootstrap)
sample_counts = np.bincount(sample_indices, minlength=n_samples)
unsampled_mask = sample_counts == 0
indices_range = np.arange(n_samples)
@@ -92,7 +132,8 @@ def _generate_unsampled_indices(random_state, n_samples):


def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
verbose=0, class_weight=None):
verbose=0, class_weight=None,
n_samples_bootstrap=None):
"""Private function used to fit a single tree in parallel."""
if verbose > 1:
print("building tree %d of %d" % (tree_idx + 1, n_trees))
@@ -104,7 +145,8 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
else:
curr_sample_weight = sample_weight.copy()

indices = _generate_sample_indices(tree.random_state, n_samples)
indices = _generate_sample_indices(tree.random_state, n_samples,
n_samples_bootstrap)
sample_counts = np.bincount(indices, minlength=n_samples)
curr_sample_weight *= sample_counts

@@ -140,7 +182,8 @@ def __init__(self,
random_state=None,
verbose=0,
warm_start=False,
class_weight=None):
class_weight=None,
max_samples=None):
super().__init__(
base_estimator=base_estimator,
n_estimators=n_estimators,
@@ -153,6 +196,7 @@ def __init__(self,
self.verbose = verbose
self.warm_start = warm_start
self.class_weight = class_weight
self.max_samples = max_samples

def apply(self, X):
"""Apply trees in the forest to X, return leaf indices.
@@ -277,6 +321,12 @@ def fit(self, X, y, sample_weight=None):
else:
sample_weight = expanded_class_weight

# Get bootstrap sample size
n_samples_bootstrap = _get_n_samples_bootstrap(
n_samples=X.shape[0],
max_samples=self.max_samples
)

# Check parameters
self._validate_estimator()

@@ -320,7 +370,8 @@ def fit(self, X, y, sample_weight=None):
**_joblib_parallel_args(prefer='threads'))(
delayed(_parallel_build_trees)(
t, self, X, y, sample_weight, i, len(trees),
verbose=self.verbose, class_weight=self.class_weight)
verbose=self.verbose, class_weight=self.class_weight,
n_samples_bootstrap=n_samples_bootstrap)
for i, t in enumerate(trees))

# Collect newly grown trees
@@ -410,7 +461,8 @@ def __init__(self,
random_state=None,
verbose=0,
warm_start=False,
class_weight=None):
class_weight=None,
max_samples=None):
super().__init__(
base_estimator,
n_estimators=n_estimators,
@@ -421,7 +473,8 @@ def __init__(self,
random_state=random_state,
verbose=verbose,
warm_start=warm_start,
class_weight=class_weight)
class_weight=class_weight,
max_samples=max_samples)

def _set_oob_score(self, X, y):
"""Compute out-of-bag score"""
@@ -435,9 +488,13 @@ def _set_oob_score(self, X, y):
predictions = [np.zeros((n_samples, n_classes_[k]))
for k in range(self.n_outputs_)]

n_samples_bootstrap = _get_n_samples_bootstrap(
n_samples, self.max_samples
)

for estimator in self.estimators_:
unsampled_indices = _generate_unsampled_indices(
estimator.random_state, n_samples)
estimator.random_state, n_samples, n_samples_bootstrap)
p_estimator = estimator.predict_proba(X[unsampled_indices, :],
check_input=False)

@@ -650,7 +707,8 @@ def __init__(self,
n_jobs=None,
random_state=None,
verbose=0,
warm_start=False):
warm_start=False,
max_samples=None):
super().__init__(
base_estimator,
n_estimators=n_estimators,
@@ -660,7 +718,8 @@ def __init__(self,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
warm_start=warm_start)
warm_start=warm_start,
max_samples=max_samples)

def predict(self, X):
"""Predict regression target for X.
@@ -713,9 +772,13 @@ def _set_oob_score(self, X, y):
predictions = np.zeros((n_samples, self.n_outputs_))
n_predictions = np.zeros((n_samples, self.n_outputs_))

n_samples_bootstrap = _get_n_samples_bootstrap(
n_samples, self.max_samples
)

for estimator in self.estimators_:
unsampled_indices = _generate_unsampled_indices(
estimator.random_state, n_samples)
estimator.random_state, n_samples, n_samples_bootstrap)
p_estimator = estimator.predict(
X[unsampled_indices, :], check_input=False)

@@ -923,6 +986,14 @@ class RandomForestClassifier(ForestClassifier):

.. versionadded:: 0.22

max_samples : int or float, default=None
The number of samples to draw from X to train each base estimator.
- If None (default), then draw `X.shape[0]` samples.
- If int, then draw `max_samples` samples.
- If float, then draw `max_samples * X.shape[0]` samples.

.. versionadded:: 0.22

Attributes
----------
base_estimator_ : DecisionTreeClassifier
@@ -1016,7 +1087,8 @@ def __init__(self,
verbose=0,
warm_start=False,
class_weight=None,
ccp_alpha=0.0):
ccp_alpha=0.0,
max_samples=None):
super().__init__(
base_estimator=DecisionTreeClassifier(),
n_estimators=n_estimators,
@@ -1031,7 +1103,8 @@ def __init__(self,
random_state=random_state,
verbose=verbose,
warm_start=warm_start,
class_weight=class_weight)
class_weight=class_weight,
max_samples=max_samples)

self.criterion = criterion
self.max_depth = max_depth
@@ -1198,6 +1271,14 @@ class RandomForestRegressor(ForestRegressor):

.. versionadded:: 0.22

max_samples : int or float, default=None
The number of samples to draw from X to train each base estimator.
- If None (default), then draw `X.shape[0]` samples.
- If int, then draw `max_samples` samples.
- If float, then draw `max_samples * X.shape[0]` samples.

.. versionadded:: 0.22

Attributes
----------
base_estimator_ : DecisionTreeRegressor
@@ -1285,7 +1366,8 @@ def __init__(self,
random_state=None,
verbose=0,
warm_start=False,
ccp_alpha=0.0):
ccp_alpha=0.0,
max_samples=None):
super().__init__(
base_estimator=DecisionTreeRegressor(),
n_estimators=n_estimators,
@@ -1299,7 +1381,8 @@ def __init__(self,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
warm_start=warm_start)
warm_start=warm_start,
max_samples=max_samples)

self.criterion = criterion
self.max_depth = max_depth
@@ -1484,6 +1567,14 @@ class ExtraTreesClassifier(ForestClassifier):

.. versionadded:: 0.22

max_samples : int or float, default=None
The number of samples to draw from X to train each base estimator.
- If None (default), then draw `X.shape[0]` samples.
- If int, then draw `max_samples` samples.
- If float, then draw `max_samples * X.shape[0]` samples.

Review comment (Member):
I'm okay with the behavior of the float, but we need to document it here, i.e. explicitly say that if float, it must be in (0, 1).
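
A docstring amendment along those lines might read (suggested wording, not part of the diff at this point):

    - If float, then draw `max_samples * X.shape[0]` samples. Thus,
      `max_samples` should be in the interval `(0, 1)`.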


.. versionadded:: 0.22

Attributes
----------
base_estimator_ : ExtraTreeClassifier
@@ -1557,7 +1648,8 @@ def __init__(self,
verbose=0,
warm_start=False,
class_weight=None,
ccp_alpha=0.0):
ccp_alpha=0.0,
max_samples=None):
super().__init__(
base_estimator=ExtraTreeClassifier(),
n_estimators=n_estimators,
@@ -1572,7 +1664,8 @@ def __init__(self,
random_state=random_state,
verbose=verbose,
warm_start=warm_start,
class_weight=class_weight)
class_weight=class_weight,
max_samples=max_samples)

self.criterion = criterion
self.max_depth = max_depth
@@ -1736,6 +1829,14 @@ class ExtraTreesRegressor(ForestRegressor):

.. versionadded:: 0.22

max_samples : int or float, default=None
The number of samples to draw from X to train each base estimator.
- If None (default), then draw `X.shape[0]` samples.
- If int, then draw `max_samples` samples.
- If float, then draw `max_samples * X.shape[0]` samples.

.. versionadded:: 0.22

Attributes
----------
base_estimator_ : ExtraTreeRegressor
@@ -1796,7 +1897,8 @@ def __init__(self,
random_state=None,
verbose=0,
warm_start=False,
ccp_alpha=0.0):
ccp_alpha=0.0,
max_samples=None):
super().__init__(
base_estimator=ExtraTreeRegressor(),
n_estimators=n_estimators,
@@ -1810,7 +1912,8 @@ def __init__(self,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
warm_start=warm_start)
warm_start=warm_start,
max_samples=max_samples)

self.criterion = criterion
self.max_depth = max_depth
@@ -1951,6 +2054,14 @@ class RandomTreesEmbedding(BaseForest):

.. versionadded:: 0.22

max_samples : int or float, default=None
The number of samples to draw from X to train each base estimator.
- If None (default), then draw `X.shape[0]` samples.
- If int, then draw `max_samples` samples.
- If float, then draw `max_samples * X.shape[0]` samples.

.. versionadded:: 0.22

Attributes
----------
estimators_ : list of DecisionTreeClassifier
@@ -1983,7 +2094,8 @@ def __init__(self,
random_state=None,
verbose=0,
warm_start=False,
ccp_alpha=0.0):
ccp_alpha=0.0,
max_samples=None):
super().__init__(
base_estimator=ExtraTreeRegressor(),
n_estimators=n_estimators,
@@ -1997,7 +2109,8 @@ def __init__(self,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
warm_start=warm_start)
warm_start=warm_start,
max_samples=max_samples)

self.max_depth = max_depth
self.min_samples_split = min_samples_split
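
For reference, a minimal usage sketch of the `max_samples` parameter added by this PR (assumes a scikit-learn build that includes this change, i.e. 0.22 or later; data and values are illustrative):

    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier

    X, y = make_classification(n_samples=1000, random_state=0)

    # With max_samples=0.5, each tree's bootstrap sample contains
    # int(round(0.5 * 1000)) = 500 draws instead of the full 1000.
    clf = RandomForestClassifier(n_estimators=100, bootstrap=True,
                                 max_samples=0.5, random_state=0)
    clf.fit(X, y)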