FIX pipeline now checks if it's fitted#29868
Conversation
|
|
||
| def _check_is_fitted(pipeline): | ||
| try: | ||
| check_is_fitted(pipeline) |
There was a problem hiding this comment.
I'm wondering if we should not also modify check_is_fitted to be lenient with stateless estimator. Right now, one could expect to implement __sklearn_is_fitted__ but I don't think this is part of the API per se. So I'm wondering if check_is_fitted should look at the tag requires_fit?
glemaitre
left a comment
There was a problem hiding this comment.
A couple of comments. But it looks good.
| Result of calling `predict` on the final estimator. | ||
| """ | ||
| Xt = X | ||
| with _handle_warnings(self): |
There was a problem hiding this comment.
It might be worth to also add some TODO next to each context manager to have more occurences.
There was a problem hiding this comment.
Maybe having a decorator instead of a context manager would avoid the extra indentation?
There was a problem hiding this comment.
I would also like it to be a decorator, since this is concerning whole methods.
There was a problem hiding this comment.
these methods are complicated, they already have decorators, and adding another decorator might complicate things. So I rather do the dirty-ish thing here and keep it as is.
StefanieSenger
left a comment
There was a problem hiding this comment.
Thanks, @adrinjalali, I have worked through this and added some comments, which I hope will be helpful.
| Result of calling `predict` on the final estimator. | ||
| """ | ||
| Xt = X | ||
| with _handle_warnings(self): |
There was a problem hiding this comment.
I would also like it to be a decorator, since this is concerning whole methods.
| def __sklearn_is_fitted__(self): | ||
| return True |
There was a problem hiding this comment.
What is the purpose of distributing self.fitted_ = True and def __sklearn_is_fitted__(self) across some of the mocking classes in this test file and in model_selection/test/test_validation?
I believe this is not needed and I think it's blurring the boundaries between the test cases and makes them difficult to read without knowing this PR or searching for it in the future.
I find it neater if test classes are very cleanly only serving their own purpose.
There was a problem hiding this comment.
This is added since w/o it the test would fail.
Strictly speaking, adding a self.fitted_ = True is used when you want check_is_fitted to be okay after calling fit, but you don't have anything else to set in fit. __sklearn_check_is_fitted__ is added when you don't need the user to call fit and the estimator is always considered fitted.
There was a problem hiding this comment.
Now this PR is merged, but still: I had run all the concerned tests files without the additions before making this comment. They all passed.
There was a problem hiding this comment.
No they fail:
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[decision_function] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[inverse_transform] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[predict] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[predict_log_proba] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[predict_proba] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[score] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[transform] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
There was a problem hiding this comment.
Running the tests with the -Werror::FutureWarning flag or the SKLEARN_WARNINGS_AS_ERRORS=1 environmental variable shows the errors.
Thanks for the hint, @adrinjalali.
| # TODO(1.8): remove this test | ||
| def test_pipeline_warns_not_fitted(): | ||
| class StatelessEstimator(BaseEstimator): | ||
| def fit(self, X, y): |
There was a problem hiding this comment.
Here mentioning explicitly what is lacking:
| def fit(self, X, y): | |
| def fit(self, X, y): | |
| """Doesn't create learned attributes.""" |
There was a problem hiding this comment.
the docstring now does enough explanation I think.
| for _, estimator in reversed(self.steps): | ||
| if estimator != "passthrough": | ||
| last_step = estimator | ||
| break |
There was a problem hiding this comment.
I thought using break is considered bad practice, isn't it? Not sure about its actual downsides though. Alternatively a while loop with "last step that is not 'passthrough'" as a stopping criterion, but that would look very complicated compared to the break.
There was a problem hiding this comment.
there's nothing wrong with using break.
There was a problem hiding this comment.
One alternative might be something like
last_step = next(
(
estimator for _, estimator in reversed(self.steps)
if estimator != "passthrough"
),
None,
)but not everyone thinks this is cleaner.
There was a problem hiding this comment.
yeah that's hard to read 😁 but nice!
| """A context manager to make sure a NotFittedError is raised, if a subestimator | ||
| raises the error. | ||
|
|
||
| Otherwise, we raise a warning if the pipeline is not fitted, with the deprecation. |
There was a problem hiding this comment.
This context manager raises a warning instead of a NotFittedError, which differs from what is written here.
Maybe like this:
| """A context manager to make sure a NotFittedError is raised, if a subestimator | |
| raises the error. | |
| Otherwise, we raise a warning if the pipeline is not fitted, with the deprecation. | |
| """A context manager to raise a FutureWarning during the deprecation period, | |
| if the last step of a pipeline raises a NotFittedError when it is not fitted. |
There was a problem hiding this comment.
Now I see I got this wrong.
Would still be good to explain better what is supposed to happen.
There was a problem hiding this comment.
I think for a helper method which is only here for two versions during the deprecation cycle, it doesn't really matter.
| P.shape[0], | ||
| P.shape[1], | ||
| ) | ||
| self.fitted_ = True |
There was a problem hiding this comment.
cause check_is_fitted checks for an attribute with a trailing underscore. This makes check_is_fitted(self) to pass.
|
@Charlie-XIAO or @adam2392 might wanna have a look? |
Charlie-XIAO
left a comment
There was a problem hiding this comment.
This overall LGTM! Just a small suggestion:
|
|
||
|
|
||
| @contextmanager | ||
| def _handle_warnings(estimator): |
There was a problem hiding this comment.
I was confused by this name before reading its docstring. Would something like _raise_or_warn_if_not_fitted, or _ensure_fitted_or_warn, or _handle_fit_status be better?
Fixes #27014
This PR makes
Pipelineto check if it's fitted in methods other thanfit*, with a deprecation.cc @glemaitre @betatim @StefanieSenger