From 3564f20be9b46c6fb358289d887fda0fef9be6b7 Mon Sep 17 00:00:00 2001 From: Stefanie Senger <91849487+StefanieSenger@users.noreply.github.com> Date: Tue, 13 Feb 2024 13:08:17 +0100 Subject: [PATCH] MNT Checking function `_estimator_has` also raises `AttributeError` (#28167) Co-authored-by: Guillaume Lemaitre Co-authored-by: Adrin Jalali --- sklearn/ensemble/_stacking.py | 20 ++++++++----- sklearn/ensemble/tests/test_stacking.py | 29 +++++++++++++++++++ sklearn/feature_selection/_from_model.py | 20 ++++++++----- sklearn/feature_selection/_rfe.py | 20 ++++++++----- .../tests/test_from_model.py | 21 ++++++++++++++ sklearn/feature_selection/tests/test_rfe.py | 24 ++++++++++++++- sklearn/multiclass.py | 23 ++++++++------- sklearn/semi_supervised/_self_training.py | 22 ++++++++++---- .../tests/test_self_training.py | 25 ++++++++++++++-- sklearn/tests/test_multiclass.py | 22 ++++++++++++++ 10 files changed, 185 insertions(+), 41 deletions(-) diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index 8e27facda11df..c028e85895b14 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -45,14 +45,20 @@ def _estimator_has(attr): """Check if we can delegate a method to the underlying estimator. - First, we check the first fitted final estimator if available, otherwise we - check the unfitted final estimator. + First, we check the fitted `final_estimator_` if available, otherwise we check the + unfitted `final_estimator`. We raise the original `AttributeError` if `attr` does + not exist. This function is used together with `available_if`. """ - return lambda self: ( - hasattr(self.final_estimator_, attr) - if hasattr(self, "final_estimator_") - else hasattr(self.final_estimator, attr) - ) + + def check(self): + if hasattr(self, "final_estimator_"): + getattr(self.final_estimator_, attr) + else: + getattr(self.final_estimator, attr) + + return True + + return check class _BaseStacking(TransformerMixin, _BaseHeterogeneousEnsemble, metaclass=ABCMeta): diff --git a/sklearn/ensemble/tests/test_stacking.py b/sklearn/ensemble/tests/test_stacking.py index d15aa32077689..0d1493529e318 100644 --- a/sklearn/ensemble/tests/test_stacking.py +++ b/sklearn/ensemble/tests/test_stacking.py @@ -859,3 +859,32 @@ def test_stacking_classifier_base_regressor(): clf.predict(X_test) clf.predict_proba(X_test) assert clf.score(X_test, y_test) > 0.8 + + +def test_stacking_final_estimator_attribute_error(): + """Check that we raise the proper AttributeError when the final estimator + does not implement the `decision_function` method, which is decorated with + `available_if`. + + Non-regression test for: + https://github.com/scikit-learn/scikit-learn/issues/28108 + """ + X, y = make_classification(random_state=42) + + estimators = [ + ("lr", LogisticRegression()), + ("rf", RandomForestClassifier(n_estimators=2, random_state=42)), + ] + # RandomForestClassifier does not implement 'decision_function' and should raise + # an AttributeError + final_estimator = RandomForestClassifier(n_estimators=2, random_state=42) + clf = StackingClassifier( + estimators=estimators, final_estimator=final_estimator, cv=3 + ) + + outer_msg = "This 'StackingClassifier' has no attribute 'decision_function'" + inner_msg = "'RandomForestClassifier' object has no attribute 'decision_function'" + with pytest.raises(AttributeError, match=outer_msg) as exec_info: + clf.fit(X, y).decision_function(X) + assert isinstance(exec_info.value.__cause__, AttributeError) + assert inner_msg in str(exec_info.value.__cause__) diff --git a/sklearn/feature_selection/_from_model.py b/sklearn/feature_selection/_from_model.py index 45785cf29fce7..61addedd2de78 100644 --- a/sklearn/feature_selection/_from_model.py +++ b/sklearn/feature_selection/_from_model.py @@ -74,14 +74,20 @@ def _calculate_threshold(estimator, importances, threshold): def _estimator_has(attr): """Check if we can delegate a method to the underlying estimator. - First, we check the fitted estimator if available, otherwise we - check the unfitted estimator. + First, we check the fitted `estimator_` if available, otherwise we check the + unfitted `estimator`. We raise the original `AttributeError` if `attr` does + not exist. This function is used together with `available_if`. """ - return lambda self: ( - hasattr(self.estimator_, attr) - if hasattr(self, "estimator_") - else hasattr(self.estimator, attr) - ) + + def check(self): + if hasattr(self, "estimator_"): + getattr(self.estimator_, attr) + else: + getattr(self.estimator, attr) + + return True + + return check class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator): diff --git a/sklearn/feature_selection/_rfe.py b/sklearn/feature_selection/_rfe.py index 76c43669b610a..d6d1b71e08609 100644 --- a/sklearn/feature_selection/_rfe.py +++ b/sklearn/feature_selection/_rfe.py @@ -49,14 +49,20 @@ def _rfe_single_fit(rfe, estimator, X, y, train, test, scorer): def _estimator_has(attr): """Check if we can delegate a method to the underlying estimator. - First, we check the first fitted estimator if available, otherwise we - check the unfitted estimator. + First, we check the fitted `estimator_` if available, otherwise we check the + unfitted `estimator`. We raise the original `AttributeError` if `attr` does + not exist. This function is used together with `available_if`. """ - return lambda self: ( - hasattr(self.estimator_, attr) - if hasattr(self, "estimator_") - else hasattr(self.estimator, attr) - ) + + def check(self): + if hasattr(self, "estimator_"): + getattr(self.estimator_, attr) + else: + getattr(self.estimator, attr) + + return True + + return check class RFE(_RoutingNotSupportedMixin, SelectorMixin, MetaEstimatorMixin, BaseEstimator): diff --git a/sklearn/feature_selection/tests/test_from_model.py b/sklearn/feature_selection/tests/test_from_model.py index aa802136c2f39..3573b7a078294 100644 --- a/sklearn/feature_selection/tests/test_from_model.py +++ b/sklearn/feature_selection/tests/test_from_model.py @@ -18,6 +18,7 @@ ElasticNetCV, Lasso, LassoCV, + LinearRegression, LogisticRegression, PassiveAggressiveClassifier, SGDClassifier, @@ -661,3 +662,23 @@ def test_partial_fit_validate_feature_names(as_frame): assert_array_equal(selector.feature_names_in_, X.columns) else: assert not hasattr(selector, "feature_names_in_") + + +def test_from_model_estimator_attribute_error(): + """Check that we raise the proper AttributeError when the estimator + does not implement the `partial_fit` method, which is decorated with + `available_if`. + + Non-regression test for: + https://github.com/scikit-learn/scikit-learn/issues/28108 + """ + # `LinearRegression` does not implement 'partial_fit' and should raise an + # AttributeError + from_model = SelectFromModel(estimator=LinearRegression()) + + outer_msg = "This 'SelectFromModel' has no attribute 'partial_fit'" + inner_msg = "'LinearRegression' object has no attribute 'partial_fit'" + with pytest.raises(AttributeError, match=outer_msg) as exec_info: + from_model.fit(data, y).partial_fit(data) + assert isinstance(exec_info.value.__cause__, AttributeError) + assert inner_msg in str(exec_info.value.__cause__) diff --git a/sklearn/feature_selection/tests/test_rfe.py b/sklearn/feature_selection/tests/test_rfe.py index 234245ed276ec..e3edb0e7b5d21 100644 --- a/sklearn/feature_selection/tests/test_rfe.py +++ b/sklearn/feature_selection/tests/test_rfe.py @@ -15,7 +15,7 @@ from sklearn.ensemble import RandomForestClassifier from sklearn.feature_selection import RFE, RFECV from sklearn.impute import SimpleImputer -from sklearn.linear_model import LogisticRegression +from sklearn.linear_model import LinearRegression, LogisticRegression from sklearn.metrics import get_scorer, make_scorer, zero_one_loss from sklearn.model_selection import GroupKFold, cross_val_score from sklearn.pipeline import make_pipeline @@ -591,3 +591,25 @@ def test_rfe_pls(ClsRFE, PLSEstimator): estimator = PLSEstimator(n_components=1) selector = ClsRFE(estimator, step=1).fit(X, y) assert selector.score(X, y) > 0.5 + + +def test_rfe_estimator_attribute_error(): + """Check that we raise the proper AttributeError when the estimator + does not implement the `decision_function` method, which is decorated with + `available_if`. + + Non-regression test for: + https://github.com/scikit-learn/scikit-learn/issues/28108 + """ + iris = load_iris() + + # `LinearRegression` does not implement 'decision_function' and should raise an + # AttributeError + rfe = RFE(estimator=LinearRegression()) + + outer_msg = "This 'RFE' has no attribute 'decision_function'" + inner_msg = "'LinearRegression' object has no attribute 'decision_function'" + with pytest.raises(AttributeError, match=outer_msg) as exec_info: + rfe.fit(iris.data, iris.target).decision_function(iris.data) + assert isinstance(exec_info.value.__cause__, AttributeError) + assert inner_msg in str(exec_info.value.__cause__) diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index 07b19115f0912..914aac99d82b5 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -181,12 +181,19 @@ def _estimators_has(attr): """Check if self.estimator or self.estimators_[0] has attr. If `self.estimators_[0]` has the attr, then its safe to assume that other - values has it too. This function is used together with `avaliable_if`. + estimators have it too. We raise the original `AttributeError` if `attr` + does not exist. This function is used together with `available_if`. """ - return lambda self: ( - hasattr(self.estimator, attr) - or (hasattr(self, "estimators_") and hasattr(self.estimators_[0], attr)) - ) + + def check(self): + if hasattr(self, "estimators_"): + getattr(self.estimators_[0], attr) + else: + getattr(self.estimator, attr) + + return True + + return check class OneVsRestClassifier( @@ -434,12 +441,6 @@ def partial_fit(self, X, y, classes=None, **partial_fit_params): ) if _check_partial_fit_first_call(self, classes): - if not hasattr(self.estimator, "partial_fit"): - raise ValueError( - ("Base estimator {0}, doesn't have partial_fit method").format( - self.estimator - ) - ) self.estimators_ = [clone(self.estimator) for _ in range(self.n_classes_)] # A sparse LabelBinarizer, with sparse_output=True, has been diff --git a/sklearn/semi_supervised/_self_training.py b/sklearn/semi_supervised/_self_training.py index 41a913d485263..810447c1e6f46 100644 --- a/sklearn/semi_supervised/_self_training.py +++ b/sklearn/semi_supervised/_self_training.py @@ -18,12 +18,22 @@ def _estimator_has(attr): - """Check if `self.base_estimator_ `or `self.base_estimator_` has `attr`.""" - return lambda self: ( - hasattr(self.base_estimator_, attr) - if hasattr(self, "base_estimator_") - else hasattr(self.base_estimator, attr) - ) + """Check if we can delegate a method to the underlying estimator. + + First, we check the fitted `base_estimator_` if available, otherwise we check + the unfitted `base_estimator`. We raise the original `AttributeError` if + `attr` does not exist. This function is used together with `available_if`. + """ + + def check(self): + if hasattr(self, "base_estimator_"): + getattr(self.base_estimator_, attr) + else: + getattr(self.base_estimator, attr) + + return True + + return check class SelfTrainingClassifier( diff --git a/sklearn/semi_supervised/tests/test_self_training.py b/sklearn/semi_supervised/tests/test_self_training.py index 4b2865abf7a8c..2efeb32446f89 100644 --- a/sklearn/semi_supervised/tests/test_self_training.py +++ b/sklearn/semi_supervised/tests/test_self_training.py @@ -12,6 +12,7 @@ from sklearn.neighbors import KNeighborsClassifier from sklearn.semi_supervised import SelfTrainingClassifier from sklearn.svm import SVC +from sklearn.tree import DecisionTreeClassifier # Author: Oliver Rausch # License: BSD 3 clause @@ -315,10 +316,30 @@ def test_base_estimator_meta_estimator(): clf.fit(X_train, y_train_missing_labels) -def test_missing_predict_proba(): - # Check that an error is thrown if predict_proba is not implemented +def test_self_training_estimator_attribute_error(): + """Check that we raise the proper AttributeErrors when the `base_estimator` + does not implement the `predict_proba` method, which is called from within + `fit`, or `decision_function`, which is decorated with `available_if`. + + Non-regression test for: + https://github.com/scikit-learn/scikit-learn/issues/28108 + """ + # `SVC` with `probability=False` does not implement 'predict_proba' that + # is required internally in `fit` of `SelfTrainingClassifier`. We expect + # an AttributeError to be raised. base_estimator = SVC(probability=False, gamma="scale") self_training = SelfTrainingClassifier(base_estimator) with pytest.raises(AttributeError, match="has no attribute 'predict_proba'"): self_training.fit(X_train, y_train_missing_labels) + + # `DecisionTreeClassifier` does not implement 'decision_function' and + # should raise an AttributeError + self_training = SelfTrainingClassifier(base_estimator=DecisionTreeClassifier()) + + outer_msg = "This 'SelfTrainingClassifier' has no attribute 'decision_function'" + inner_msg = "'DecisionTreeClassifier' object has no attribute 'decision_function'" + with pytest.raises(AttributeError, match=outer_msg) as exec_info: + self_training.fit(X_train, y_train_missing_labels).decision_function(X_train) + assert isinstance(exec_info.value.__cause__, AttributeError) + assert inner_msg in str(exec_info.value.__cause__) diff --git a/sklearn/tests/test_multiclass.py b/sklearn/tests/test_multiclass.py index caf7f5ae2fb49..b57d681d7ebfa 100644 --- a/sklearn/tests/test_multiclass.py +++ b/sklearn/tests/test_multiclass.py @@ -926,3 +926,25 @@ def test_ovo_consistent_binary_classification(): ovo.fit(X, y) assert_array_equal(clf.predict(X), ovo.predict(X)) + + +def test_multiclass_estimator_attribute_error(): + """Check that we raise the proper AttributeError when the final estimator + does not implement the `partial_fit` method, which is decorated with + `available_if`. + + Non-regression test for: + https://github.com/scikit-learn/scikit-learn/issues/28108 + """ + iris = datasets.load_iris() + + # LogisticRegression does not implement 'partial_fit' and should raise an + # AttributeError + clf = OneVsRestClassifier(estimator=LogisticRegression(random_state=42)) + + outer_msg = "This 'OneVsRestClassifier' has no attribute 'partial_fit'" + inner_msg = "'LogisticRegression' object has no attribute 'partial_fit'" + with pytest.raises(AttributeError, match=outer_msg) as exec_info: + clf.partial_fit(iris.data, iris.target) + assert isinstance(exec_info.value.__cause__, AttributeError) + assert inner_msg in str(exec_info.value.__cause__)