Skip to content

Commit

Permalink
MNT Checking function _estimator_has also raises AttributeError (#…
Browse files Browse the repository at this point in the history
…28167)

Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
  • Loading branch information
3 people committed Feb 13, 2024
1 parent b2e231e commit 3564f20
Show file tree
Hide file tree
Showing 10 changed files with 185 additions and 41 deletions.
20 changes: 13 additions & 7 deletions sklearn/ensemble/_stacking.py
Expand Up @@ -45,14 +45,20 @@
def _estimator_has(attr):
"""Check if we can delegate a method to the underlying estimator.
First, we check the first fitted final estimator if available, otherwise we
check the unfitted final estimator.
First, we check the fitted `final_estimator_` if available, otherwise we check the
unfitted `final_estimator`. We raise the original `AttributeError` if `attr` does
not exist. This function is used together with `available_if`.
"""
return lambda self: (
hasattr(self.final_estimator_, attr)
if hasattr(self, "final_estimator_")
else hasattr(self.final_estimator, attr)
)

def check(self):
if hasattr(self, "final_estimator_"):
getattr(self.final_estimator_, attr)
else:
getattr(self.final_estimator, attr)

return True

return check


class _BaseStacking(TransformerMixin, _BaseHeterogeneousEnsemble, metaclass=ABCMeta):
Expand Down
29 changes: 29 additions & 0 deletions sklearn/ensemble/tests/test_stacking.py
Expand Up @@ -859,3 +859,32 @@ def test_stacking_classifier_base_regressor():
clf.predict(X_test)
clf.predict_proba(X_test)
assert clf.score(X_test, y_test) > 0.8


def test_stacking_final_estimator_attribute_error():
"""Check that we raise the proper AttributeError when the final estimator
does not implement the `decision_function` method, which is decorated with
`available_if`.
Non-regression test for:
https://github.com/scikit-learn/scikit-learn/issues/28108
"""
X, y = make_classification(random_state=42)

estimators = [
("lr", LogisticRegression()),
("rf", RandomForestClassifier(n_estimators=2, random_state=42)),
]
# RandomForestClassifier does not implement 'decision_function' and should raise
# an AttributeError
final_estimator = RandomForestClassifier(n_estimators=2, random_state=42)
clf = StackingClassifier(
estimators=estimators, final_estimator=final_estimator, cv=3
)

outer_msg = "This 'StackingClassifier' has no attribute 'decision_function'"
inner_msg = "'RandomForestClassifier' object has no attribute 'decision_function'"
with pytest.raises(AttributeError, match=outer_msg) as exec_info:
clf.fit(X, y).decision_function(X)
assert isinstance(exec_info.value.__cause__, AttributeError)
assert inner_msg in str(exec_info.value.__cause__)
20 changes: 13 additions & 7 deletions sklearn/feature_selection/_from_model.py
Expand Up @@ -74,14 +74,20 @@ def _calculate_threshold(estimator, importances, threshold):
def _estimator_has(attr):
"""Check if we can delegate a method to the underlying estimator.
First, we check the fitted estimator if available, otherwise we
check the unfitted estimator.
First, we check the fitted `estimator_` if available, otherwise we check the
unfitted `estimator`. We raise the original `AttributeError` if `attr` does
not exist. This function is used together with `available_if`.
"""
return lambda self: (
hasattr(self.estimator_, attr)
if hasattr(self, "estimator_")
else hasattr(self.estimator, attr)
)

def check(self):
if hasattr(self, "estimator_"):
getattr(self.estimator_, attr)
else:
getattr(self.estimator, attr)

return True

return check


class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
Expand Down
20 changes: 13 additions & 7 deletions sklearn/feature_selection/_rfe.py
Expand Up @@ -49,14 +49,20 @@ def _rfe_single_fit(rfe, estimator, X, y, train, test, scorer):
def _estimator_has(attr):
"""Check if we can delegate a method to the underlying estimator.
First, we check the first fitted estimator if available, otherwise we
check the unfitted estimator.
First, we check the fitted `estimator_` if available, otherwise we check the
unfitted `estimator`. We raise the original `AttributeError` if `attr` does
not exist. This function is used together with `available_if`.
"""
return lambda self: (
hasattr(self.estimator_, attr)
if hasattr(self, "estimator_")
else hasattr(self.estimator, attr)
)

def check(self):
if hasattr(self, "estimator_"):
getattr(self.estimator_, attr)
else:
getattr(self.estimator, attr)

return True

return check


class RFE(_RoutingNotSupportedMixin, SelectorMixin, MetaEstimatorMixin, BaseEstimator):
Expand Down
21 changes: 21 additions & 0 deletions sklearn/feature_selection/tests/test_from_model.py
Expand Up @@ -18,6 +18,7 @@
ElasticNetCV,
Lasso,
LassoCV,
LinearRegression,
LogisticRegression,
PassiveAggressiveClassifier,
SGDClassifier,
Expand Down Expand Up @@ -661,3 +662,23 @@ def test_partial_fit_validate_feature_names(as_frame):
assert_array_equal(selector.feature_names_in_, X.columns)
else:
assert not hasattr(selector, "feature_names_in_")


def test_from_model_estimator_attribute_error():
"""Check that we raise the proper AttributeError when the estimator
does not implement the `partial_fit` method, which is decorated with
`available_if`.
Non-regression test for:
https://github.com/scikit-learn/scikit-learn/issues/28108
"""
# `LinearRegression` does not implement 'partial_fit' and should raise an
# AttributeError
from_model = SelectFromModel(estimator=LinearRegression())

outer_msg = "This 'SelectFromModel' has no attribute 'partial_fit'"
inner_msg = "'LinearRegression' object has no attribute 'partial_fit'"
with pytest.raises(AttributeError, match=outer_msg) as exec_info:
from_model.fit(data, y).partial_fit(data)
assert isinstance(exec_info.value.__cause__, AttributeError)
assert inner_msg in str(exec_info.value.__cause__)
24 changes: 23 additions & 1 deletion sklearn/feature_selection/tests/test_rfe.py
Expand Up @@ -15,7 +15,7 @@
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import RFE, RFECV
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import get_scorer, make_scorer, zero_one_loss
from sklearn.model_selection import GroupKFold, cross_val_score
from sklearn.pipeline import make_pipeline
Expand Down Expand Up @@ -591,3 +591,25 @@ def test_rfe_pls(ClsRFE, PLSEstimator):
estimator = PLSEstimator(n_components=1)
selector = ClsRFE(estimator, step=1).fit(X, y)
assert selector.score(X, y) > 0.5


def test_rfe_estimator_attribute_error():
"""Check that we raise the proper AttributeError when the estimator
does not implement the `decision_function` method, which is decorated with
`available_if`.
Non-regression test for:
https://github.com/scikit-learn/scikit-learn/issues/28108
"""
iris = load_iris()

# `LinearRegression` does not implement 'decision_function' and should raise an
# AttributeError
rfe = RFE(estimator=LinearRegression())

outer_msg = "This 'RFE' has no attribute 'decision_function'"
inner_msg = "'LinearRegression' object has no attribute 'decision_function'"
with pytest.raises(AttributeError, match=outer_msg) as exec_info:
rfe.fit(iris.data, iris.target).decision_function(iris.data)
assert isinstance(exec_info.value.__cause__, AttributeError)
assert inner_msg in str(exec_info.value.__cause__)
23 changes: 12 additions & 11 deletions sklearn/multiclass.py
Expand Up @@ -181,12 +181,19 @@ def _estimators_has(attr):
"""Check if self.estimator or self.estimators_[0] has attr.
If `self.estimators_[0]` has the attr, then its safe to assume that other
values has it too. This function is used together with `avaliable_if`.
estimators have it too. We raise the original `AttributeError` if `attr`
does not exist. This function is used together with `available_if`.
"""
return lambda self: (
hasattr(self.estimator, attr)
or (hasattr(self, "estimators_") and hasattr(self.estimators_[0], attr))
)

def check(self):
if hasattr(self, "estimators_"):
getattr(self.estimators_[0], attr)
else:
getattr(self.estimator, attr)

return True

return check


class OneVsRestClassifier(
Expand Down Expand Up @@ -434,12 +441,6 @@ def partial_fit(self, X, y, classes=None, **partial_fit_params):
)

if _check_partial_fit_first_call(self, classes):
if not hasattr(self.estimator, "partial_fit"):
raise ValueError(
("Base estimator {0}, doesn't have partial_fit method").format(
self.estimator
)
)
self.estimators_ = [clone(self.estimator) for _ in range(self.n_classes_)]

# A sparse LabelBinarizer, with sparse_output=True, has been
Expand Down
22 changes: 16 additions & 6 deletions sklearn/semi_supervised/_self_training.py
Expand Up @@ -18,12 +18,22 @@


def _estimator_has(attr):
"""Check if `self.base_estimator_ `or `self.base_estimator_` has `attr`."""
return lambda self: (
hasattr(self.base_estimator_, attr)
if hasattr(self, "base_estimator_")
else hasattr(self.base_estimator, attr)
)
"""Check if we can delegate a method to the underlying estimator.
First, we check the fitted `base_estimator_` if available, otherwise we check
the unfitted `base_estimator`. We raise the original `AttributeError` if
`attr` does not exist. This function is used together with `available_if`.
"""

def check(self):
if hasattr(self, "base_estimator_"):
getattr(self.base_estimator_, attr)
else:
getattr(self.base_estimator, attr)

return True

return check


class SelfTrainingClassifier(
Expand Down
25 changes: 23 additions & 2 deletions sklearn/semi_supervised/tests/test_self_training.py
Expand Up @@ -12,6 +12,7 @@
from sklearn.neighbors import KNeighborsClassifier
from sklearn.semi_supervised import SelfTrainingClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

# Author: Oliver Rausch <rauscho@ethz.ch>
# License: BSD 3 clause
Expand Down Expand Up @@ -315,10 +316,30 @@ def test_base_estimator_meta_estimator():
clf.fit(X_train, y_train_missing_labels)


def test_missing_predict_proba():
# Check that an error is thrown if predict_proba is not implemented
def test_self_training_estimator_attribute_error():
"""Check that we raise the proper AttributeErrors when the `base_estimator`
does not implement the `predict_proba` method, which is called from within
`fit`, or `decision_function`, which is decorated with `available_if`.
Non-regression test for:
https://github.com/scikit-learn/scikit-learn/issues/28108
"""
# `SVC` with `probability=False` does not implement 'predict_proba' that
# is required internally in `fit` of `SelfTrainingClassifier`. We expect
# an AttributeError to be raised.
base_estimator = SVC(probability=False, gamma="scale")
self_training = SelfTrainingClassifier(base_estimator)

with pytest.raises(AttributeError, match="has no attribute 'predict_proba'"):
self_training.fit(X_train, y_train_missing_labels)

# `DecisionTreeClassifier` does not implement 'decision_function' and
# should raise an AttributeError
self_training = SelfTrainingClassifier(base_estimator=DecisionTreeClassifier())

outer_msg = "This 'SelfTrainingClassifier' has no attribute 'decision_function'"
inner_msg = "'DecisionTreeClassifier' object has no attribute 'decision_function'"
with pytest.raises(AttributeError, match=outer_msg) as exec_info:
self_training.fit(X_train, y_train_missing_labels).decision_function(X_train)
assert isinstance(exec_info.value.__cause__, AttributeError)
assert inner_msg in str(exec_info.value.__cause__)
22 changes: 22 additions & 0 deletions sklearn/tests/test_multiclass.py
Expand Up @@ -926,3 +926,25 @@ def test_ovo_consistent_binary_classification():
ovo.fit(X, y)

assert_array_equal(clf.predict(X), ovo.predict(X))


def test_multiclass_estimator_attribute_error():
"""Check that we raise the proper AttributeError when the final estimator
does not implement the `partial_fit` method, which is decorated with
`available_if`.
Non-regression test for:
https://github.com/scikit-learn/scikit-learn/issues/28108
"""
iris = datasets.load_iris()

# LogisticRegression does not implement 'partial_fit' and should raise an
# AttributeError
clf = OneVsRestClassifier(estimator=LogisticRegression(random_state=42))

outer_msg = "This 'OneVsRestClassifier' has no attribute 'partial_fit'"
inner_msg = "'LogisticRegression' object has no attribute 'partial_fit'"
with pytest.raises(AttributeError, match=outer_msg) as exec_info:
clf.partial_fit(iris.data, iris.target)
assert isinstance(exec_info.value.__cause__, AttributeError)
assert inner_msg in str(exec_info.value.__cause__)

0 comments on commit 3564f20

Please sign in to comment.