MNT Use check_scalar in AdaBoostRegressor #21605

Merged
40 changes: 20 additions & 20 deletions sklearn/ensemble/_weight_boosting.py
@@ -111,9 +111,22 @@ def fit(self, X, y, sample_weight=None):
         -------
         self : object
         """
-        # Check parameters
-        if self.learning_rate <= 0:
-            raise ValueError("learning_rate must be greater than zero")
+        # Validate scalar parameters
+        check_scalar(
+            self.n_estimators,
+            "n_estimators",
+            target_type=numbers.Integral,
+            min_val=1,
+            include_boundaries="left",
+        )
+
+        check_scalar(
+            self.learning_rate,
+            "learning_rate",
+            target_type=numbers.Real,
+            min_val=0,
+            include_boundaries="neither",
+        )
 
         X, y = self._validate_data(
             X,
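For reference, `sklearn.utils.check_scalar` raises a `TypeError` when the value is not an instance of `target_type` and a `ValueError` when it falls outside the allowed range; `include_boundaries="left"` keeps `min_val` itself valid (so `n_estimators=1` passes), while `"neither"` makes the bound strict (so `learning_rate=0` fails). A minimal sketch of the calls above, outside the estimator:

```python
import numbers

from sklearn.utils import check_scalar

# Passes silently: 0.5 is a Real strictly greater than 0.
check_scalar(
    0.5, "learning_rate", target_type=numbers.Real,
    min_val=0, include_boundaries="neither",
)

try:
    # The boundary itself is excluded by include_boundaries="neither".
    check_scalar(
        0, "learning_rate", target_type=numbers.Real,
        min_val=0, include_boundaries="neither",
    )
except ValueError as exc:
    print(exc)  # learning_rate == 0, must be > 0.
```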
@@ -480,22 +493,6 @@ def fit(self, X, y, sample_weight=None):
         self : object
             Fitted estimator.
         """
-        check_scalar(
-            self.n_estimators,
-            "n_estimators",
-            target_type=numbers.Integral,
-            min_val=1,
-            include_boundaries="left",
-        )
-
-        check_scalar(
-            self.learning_rate,
-            "learning_rate",
-            target_type=numbers.Real,
-            min_val=0,
-            include_boundaries="neither",
-        )
-
         # Check that algorithm is supported
         if self.algorithm not in ("SAMME", "SAMME.R"):
             raise ValueError(
@@ -1080,7 +1077,10 @@ def fit(self, X, y, sample_weight=None):
         """
         # Check loss
         if self.loss not in ("linear", "square", "exponential"):
-            raise ValueError("loss must be 'linear', 'square', or 'exponential'")
+            raise ValueError(
+                "loss must be 'linear', 'square', or 'exponential'"
+                f" Got {self.loss!r} instead."
+            )
 
         # Fit
         return super().fit(X, y, sample_weight)
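A quick way to see the improved message (a hypothetical snippet; the arrays are only there to satisfy `fit`, since the loss check runs before any data validation):

```python
import numpy as np

from sklearn.ensemble import AdaBoostRegressor

X = np.arange(20, dtype=float).reshape(10, 2)
y = np.arange(10, dtype=float)

try:
    AdaBoostRegressor(loss="foo").fit(X, y)
except ValueError as exc:
    print(exc)  # loss must be 'linear', 'square', or 'exponential' Got 'foo' instead.
```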
27 changes: 19 additions & 8 deletions sklearn/ensemble/tests/test_weight_boosting.py
@@ -273,6 +273,14 @@ def test_importances():
 def test_error():
     # Test that it gives proper exception on deficient input.
 
+    reg = AdaBoostRegressor(loss="foo")
+    with pytest.raises(ValueError):
+        reg.fit(X, y_class)
+
+    clf = AdaBoostClassifier(algorithm="foo")
+    with pytest.raises(ValueError):
+        clf.fit(X, y_class)
Comment on lines +276 to +282

Member: @reshamas: I should have been clearer.

Here for instance, there's a check for ValueErrors being raised but their error messages aren't checked.

Member: Should we re-open this one? Error messages are important and helpful, correct?

Member: Here we go: #22144


+
     with pytest.raises(ValueError):
         AdaBoostClassifier().fit(X, y_class, sample_weight=np.asarray([-1]))
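What the reviewer is asking for would look roughly like this: pass `match=` to `pytest.raises` so the message itself is asserted, not just the exception type (a sketch; the test name and data are illustrative, and the expected strings come from the messages in this PR):

```python
import numpy as np
import pytest

from sklearn.ensemble import AdaBoostClassifier, AdaBoostRegressor

X = np.arange(20, dtype=float).reshape(10, 2)
y_class = np.array([0, 1] * 5)


def test_error_messages():
    # match is a regex searched against str(exception).
    with pytest.raises(ValueError, match="loss must be 'linear', 'square', or 'exponential'"):
        AdaBoostRegressor(loss="foo").fit(X, y_class)

    with pytest.raises(ValueError, match="Algorithm must be 'SAMME' or 'SAMME.R'"):
        AdaBoostClassifier(algorithm="foo").fit(X, y_class)
```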
@@ -556,17 +564,20 @@ def test_adaboostregressor_sample_weight():
         ),
         ({"learning_rate": -1}, ValueError, "learning_rate == -1, must be > 0."),
         ({"learning_rate": 0}, ValueError, "learning_rate == 0, must be > 0."),
-        (
-            {"algorithm": "unknown"},
-            ValueError,
-            "Algorithm must be 'SAMME' or 'SAMME.R'.",
-        ),
     ],
 )
-def test_adaboost_classifier_params_validation(params, err_type, err_msg):
-    """Check the parameters validation in `AdaBoostClassifier`."""
+@pytest.mark.parametrize(
+    "model, X, y",
+    [
+        (AdaBoostClassifier, X, y_class),
+        (AdaBoostRegressor, X, y_regr),
+    ],
+)
+def test_adaboost_params_validation(model, X, y, params, err_type, err_msg):
+    """Check input parameter validation in weight boosting."""
+    est = model(**params)
     with pytest.raises(err_type, match=err_msg):
-        AdaBoostClassifier(**params).fit(X, y_class)
+        est.fit(X, y)
 
 
 @pytest.mark.parametrize("algorithm", ["SAMME", "SAMME.R"])
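The entries truncated at the top of this parametrize list presumably cover `n_estimators`; for completeness, the type-check half of `check_scalar` rejects a non-integral `n_estimators` with a `TypeError` before any range check. A small sketch, with the message paraphrased in the comment:

```python
import numbers

from sklearn.utils import check_scalar

try:
    # 1.5 is a Real but not an Integral, so the target_type check fails first.
    check_scalar(1.5, "n_estimators", target_type=numbers.Integral,
                 min_val=1, include_boundaries="left")
except TypeError as exc:
    print(exc)  # e.g. "n_estimators must be an instance of ..., not <class 'float'>."
```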