Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] FIX raise error for max_samples if no bootstrap & optimize forest tests #21295

Merged
merged 11 commits into from Nov 25, 2021
7 changes: 7 additions & 0 deletions doc/whats_new/v1.1.rst
Expand Up @@ -53,6 +53,13 @@ Changelog
:class:`ensemble.HistGradientBoostingRegressor`.
:pr:`21130` :user:`Christian Lorentzen <lorentzenchr>`.

- |Fix| :class:`ensemble.RandomForestClassifier`,
:class:`ensemble.RandomForestRegressor`,
:class:`ensemble.ExtraTreesClassifier`, :class:`ensemble.ExtraTreesRegressor`,
and :class:`ensemble.RandomTreesEmbedding` now raise a ``ValueError`` when
``bootstrap=False`` and ``max_samples`` is not ``None``.
:pr:`21295` :user:`Haoyin Xu <PSSF23>`.

:mod:`sklearn.linear_model`
...........................

Expand Down
3 changes: 3 additions & 0 deletions sklearn/ensemble/_forest.py
Expand Up @@ -374,6 +374,9 @@ def fit(self, X, y, sample_weight=None):
else:
sample_weight = expanded_class_weight

if not self.bootstrap and self.max_samples is not None:
raise ValueError("Sub-sample size only available if bootstrap=True")
PSSF23 marked this conversation as resolved.
Show resolved Hide resolved

# Get bootstrap sample size
n_samples_bootstrap = _get_n_samples_bootstrap(
n_samples=X.shape[0], max_samples=self.max_samples
Expand Down
28 changes: 23 additions & 5 deletions sklearn/ensemble/tests/test_forest.py
Expand Up @@ -1613,6 +1613,16 @@ def test_forest_degenerate_feature_importances():
assert_array_equal(gbr.feature_importances_, np.zeros(10, dtype=np.float64))


@pytest.mark.parametrize("name", FOREST_CLASSIFIERS_REGRESSORS)
def test_max_samples_bootstrap(name):
    # Setting `max_samples` while `bootstrap=False` must raise, since
    # sub-sampling only makes sense when bootstrapping is enabled.
    ForestEstimator = FOREST_CLASSIFIERS_REGRESSORS[name]
    est = ForestEstimator(bootstrap=False, max_samples=0.5)
    err_msg = r"Sub-sample size only available if bootstrap=True"
    with pytest.raises(ValueError, match=err_msg):
        est.fit(X, y)


@pytest.mark.parametrize("name", FOREST_CLASSIFIERS_REGRESSORS)
@pytest.mark.parametrize(
"max_samples, exc_type, exc_msg",
Expand Down Expand Up @@ -1657,7 +1667,7 @@ def test_forest_degenerate_feature_importances():
)
def test_max_samples_exceptions(name, max_samples, exc_type, exc_msg):
# Check invalid `max_samples` values
est = FOREST_CLASSIFIERS_REGRESSORS[name](max_samples=max_samples)
est = FOREST_CLASSIFIERS_REGRESSORS[name](bootstrap=True, max_samples=max_samples)
with pytest.raises(exc_type, match=exc_msg):
est.fit(X, y)

Expand All @@ -1668,10 +1678,14 @@ def test_max_samples_boundary_regressors(name):
X_reg, y_reg, train_size=0.7, test_size=0.3, random_state=0
)

ms_1_model = FOREST_REGRESSORS[name](max_samples=1.0, random_state=0)
ms_1_model = FOREST_REGRESSORS[name](
bootstrap=True, max_samples=1.0, random_state=0
)
ms_1_predict = ms_1_model.fit(X_train, y_train).predict(X_test)

ms_None_model = FOREST_REGRESSORS[name](max_samples=None, random_state=0)
ms_None_model = FOREST_REGRESSORS[name](
bootstrap=True, max_samples=None, random_state=0
PSSF23 marked this conversation as resolved.
Show resolved Hide resolved
)
ms_None_predict = ms_None_model.fit(X_train, y_train).predict(X_test)

ms_1_ms = mean_squared_error(ms_1_predict, y_test)
Expand All @@ -1686,10 +1700,14 @@ def test_max_samples_boundary_classifiers(name):
X_large, y_large, random_state=0, stratify=y_large
)

ms_1_model = FOREST_CLASSIFIERS[name](max_samples=1.0, random_state=0)
ms_1_model = FOREST_CLASSIFIERS[name](
bootstrap=True, max_samples=1.0, random_state=0
)
ms_1_proba = ms_1_model.fit(X_train, y_train).predict_proba(X_test)

ms_None_model = FOREST_CLASSIFIERS[name](max_samples=None, random_state=0)
ms_None_model = FOREST_CLASSIFIERS[name](
bootstrap=True, max_samples=None, random_state=0
)
ms_None_proba = ms_None_model.fit(X_train, y_train).predict_proba(X_test)

np.testing.assert_allclose(ms_1_proba, ms_None_proba)
Expand Down