FEA metadata routing for StackingClassifier and StackingRegressor #28701

Merged
merged 21 commits into from May 13, 2024

Contributor

@StefanieSenger StefanieSenger commented Mar 26, 2024

Reference Issues/PRs

towards #22893
closes #18028

What does this implement/fix? Explain your changes.

Adds metadata routing to StackingClassifier and StackingRegressor.

Any other comments?

  1. I wasn't sure if we want to route metadata within the predict method. It's already implemented in _BaseStacking.predict(self, X, **predict_params), but without needing to set a request for it. What is the preferable way here?

  2. There is also an issue whenever RidgeCV (the default) is the final_estimator: we get a RecursionError, because it gets stuck in _metadata_requests.py. I will continue to investigate this. Update: This is a bug in the routing mechanism of RidgeCV and is unrelated to this PR. I've opened a PR to fix it: FIX RecursionError bug with metadata routing in metaestimators with scoring #28712

@adrinjalali @OmarManzoor @glemaitre, do you want to have a look?

github-actions bot commented Mar 26, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 01b8efe. Link to the linter CI: here

@glemaitre glemaitre self-requested a review April 8, 2024 13:07
Comment on lines 199 to 200
.. deprecated:: 1.5
    `sample_weight` is deprecated in 1.5 and will be removed in 1.7.
Member

in this class it's not deprecated though, it's simply removed.

Contributor Author

Oh yes, true.

Comment on lines 989 to 1066
- def fit_transform(self, X, y, sample_weight=None):
+ def fit_transform(self, X, y, sample_weight=None, **fit_params):
Member

I think we should deprecate positional sample_weight here as well.

Contributor Author

Yes, we should. Otherwise a sample_weight passed as a fit_param into the routing would mistakenly be passed through the old way outside of the routing.
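
The binding behaviour behind this reply can be sketched in plain Python. The two functions below are hypothetical stand-ins, not scikit-learn code: they only show that while `sample_weight` remains a named parameter, a caller's `sample_weight=...` is captured by it and never shows up in `**fit_params`, which is the dictionary a router would inspect.

```python
def fit_transform_named(X, y, sample_weight=None, **fit_params):
    """Hypothetical stand-in: sample_weight is still a named parameter."""
    # Report which channel each argument arrived through.
    return {"legacy": sample_weight, "routed": dict(fit_params)}


def fit_transform_routed(X, y, **fit_params):
    """Hypothetical stand-in: everything beyond X and y is metadata."""
    return {"legacy": None, "routed": dict(fit_params)}


X = [[0.0], [1.0]]
y = [0, 1]
weights = [1.0, 2.0]

# The named parameter captures the weights; fit_params stays empty.
named_result = fit_transform_named(X, y, sample_weight=weights)

# Without the named parameter, the weights land in fit_params.
routed_result = fit_transform_routed(X, y, sample_weight=weights)

print(named_result)
print(routed_result)
```

In the first call the weights travel the old way, outside of whatever is done with `fit_params`, which is the situation the reply above warns about.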

Comment on lines +262 to +265
record_metadata_not_default(
    self, "predict_proba", sample_weight=sample_weight, metadata=metadata
)
return np.asarray([[0.0, 1.0]] * len(X))
Member

where is this used?

Contributor Author

In StackingClassifier, there's a stack_method param whose value can be set to predict_proba as well. In StackingRegressor it's just hardcoded as predict.
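
As a rough illustration of that difference, the sketch below fits both stacking estimators and inspects the fitted `stack_method_` attribute. The datasets and base estimators are arbitrary choices for the example, and metadata routing is not involved here.

```python
from sklearn.datasets import make_classification, make_regression
from sklearn.ensemble import StackingClassifier, StackingRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

X_clf, y_clf = make_classification(n_samples=100, random_state=0)
clf = StackingClassifier(
    estimators=[
        ("tree", DecisionTreeClassifier(random_state=0)),
        ("lr", LogisticRegression()),
    ],
    # The classifier lets the user choose which method of the base
    # estimators feeds the final estimator.
    stack_method="predict_proba",
).fit(X_clf, y_clf)

X_reg, y_reg = make_regression(n_samples=100, random_state=0)
# The regressor exposes no stack_method parameter.
reg = StackingRegressor(
    estimators=[
        ("tree", DecisionTreeRegressor(random_state=0)),
        ("lin", LinearRegression()),
    ],
).fit(X_reg, y_reg)

# One entry per base estimator, naming the method that was used.
print(clf.stack_method_)
print(reg.stack_method_)
```

This is why a test helper that records metadata needs a `predict_proba` for the classifier case, while `predict` is enough for the regressor.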

Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
Contributor Author

@StefanieSenger StefanieSenger left a comment

@adrinjalali
Thanks for reviewing. I went through your comments and addressed them.

@StefanieSenger
Contributor Author

So, the RecursionError bug that is about to be fixed in #28712 is also appearing here.

@adrinjalali
Member

@StefanieSenger since #28712 is merged, wanna merge with main?

@StefanieSenger
Contributor Author

I have merged main into it and all tests now pass, @adrinjalali

Contributor

@OmarManzoor OmarManzoor left a comment

Thanks for the PR @StefanieSenger. Just a few minor comments otherwise this looks good.

Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
Contributor Author

@StefanieSenger StefanieSenger left a comment

Thanks for reviewing @OmarManzoor. I have committed those changes. :)

Contributor

@OmarManzoor OmarManzoor left a comment

A further suggestion

StefanieSenger and others added 2 commits May 8, 2024 14:43
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
@StefanieSenger
Contributor Author

A further suggestion

Oh, that's wonderful, thank you, @OmarManzoor :)

Contributor

@OmarManzoor OmarManzoor left a comment

LGTM. Thanks @StefanieSenger. The versions will just need to be changed to 1.6 now that the branch for 1.5 has already been separated.

@@ -139,6 +139,11 @@ more details.
transformers' ``fit`` and ``fit_transform``. :pr:`28205` by :user:`Stefanie
Senger <StefanieSenger>`.

- |Feature| :class:`ensemble.StackingClassifier` and
Contributor

I think this might need to move to 1.6 now.

Contributor Author

I've moved it :)

@adrinjalali adrinjalali merged commit 61281cf into scikit-learn:main May 13, 2024
30 checks passed
Development

Successfully merging this pull request may close these issues.

Support fit_params in stacking
3 participants