Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH add metadata routing to cross_val* #26896

Merged
merged 29 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
c9fcd6e
ENH add metadata routing to cross_val*
adrinjalali Jul 18, 2023
9a0cfba
existing tests pass
adrinjalali Jul 18, 2023
5277212
handle groups and start tests
adrinjalali Jul 18, 2023
9d40ca8
Merge remote-tracking branch 'upstream/main' into slep6/cross_val
adrinjalali Jul 25, 2023
0e64c4a
...
adrinjalali Jul 25, 2023
8bf6ce5
pass
adrinjalali Jul 25, 2023
185cc2d
Merge remote-tracking branch 'upstream/main' into slep6/cross_val
adrinjalali Jul 25, 2023
f8528d1
fix imports
adrinjalali Jul 25, 2023
5d633d8
...
adrinjalali Jul 27, 2023
787b196
fixes and tests
adrinjalali Jul 31, 2023
feacd04
Merge remote-tracking branch 'upstream/main' into slep6/cross_val
adrinjalali Jul 31, 2023
7bc37c7
fix test to fit first
adrinjalali Jul 31, 2023
a62174f
add cross_val_predict and tests
adrinjalali Aug 1, 2023
f948747
Merge remote-tracking branch 'upstream/main' into slep6/cross_val
adrinjalali Aug 1, 2023
813a63e
add changelog
adrinjalali Aug 1, 2023
d4cc661
minimal touch
adrinjalali Aug 1, 2023
10c94b0
[doc build] fix internal deprecated usage
adrinjalali Aug 1, 2023
ee03848
Update sklearn/model_selection/tests/test_validation.py
adrinjalali Aug 3, 2023
bf70ad0
Update sklearn/model_selection/tests/test_validation.py
adrinjalali Aug 3, 2023
66acee2
Update sklearn/model_selection/tests/test_validation.py
adrinjalali Aug 3, 2023
7ea0c58
Update sklearn/model_selection/tests/test_validation.py
adrinjalali Aug 3, 2023
b0bb9bb
Update sklearn/model_selection/tests/test_validation.py
adrinjalali Aug 3, 2023
047510b
comment on IDs
adrinjalali Aug 3, 2023
77b0b02
Merge branch 'slep6/cross_val' of github.com:adrinjalali/scikit-learn…
adrinjalali Aug 3, 2023
85e0a98
Merge remote-tracking branch 'upstream/main' into slep6/cross_val
adrinjalali Aug 3, 2023
df62460
apply Omar's suggestions
adrinjalali Aug 7, 2023
58b0ab1
Merge remote-tracking branch 'upstream/main' into slep6/cross_val
adrinjalali Aug 7, 2023
72a0537
more fixes by Omar
adrinjalali Aug 8, 2023
19c1193
Merge remote-tracking branch 'upstream/main' into slep6/cross_val
adrinjalali Aug 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 10 additions & 0 deletions doc/whats_new/v1.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,16 @@ Changelog
object in the parameter grid if it's an estimator. :pr:`26786` by `Adrin
Jalali`_.

- |Feature| :func:`~model_selection.cross_validate`,
:func:`~model_selection.cross_val_score`, and
:func:`~model_selection.cross_val_predict` now support metadata routing. The
metadata are routed to the estimator's `fit`, the scorer, and the CV
splitter's `split`. The metadata is accepted via the new `params` parameter.
`fit_params` is deprecated and will be removed in version 1.6. `groups`
parameter is also not accepted as a separate argument when metadata routing
is enabled and should be passed via the `params` parameter. :pr:`26896` by
`Adrin Jalali`_.

:mod:`sklearn.neighbors`
........................

Expand Down
2 changes: 1 addition & 1 deletion sklearn/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def fit(self, X, y, sample_weight=None, **fit_params):
cv=cv,
method=method_name,
n_jobs=self.n_jobs,
fit_params=routed_params.estimator.fit,
params=routed_params.estimator.fit,
)
predictions = _compute_predictions(
pred_method, method_name, X, n_classes
Expand Down
2 changes: 1 addition & 1 deletion sklearn/ensemble/_stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def fit(self, X, y, sample_weight=None):
cv=deepcopy(cv),
method=meth,
n_jobs=self.n_jobs,
fit_params=fit_params,
params=fit_params,
verbose=self.verbose,
)
for est, meth in zip(all_estimators, self.stack_method_)
Expand Down
7 changes: 6 additions & 1 deletion sklearn/feature_selection/_rfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ def _rfe_single_fit(rfe, estimator, X, y, train, test, scorer):
X_train,
y_train,
lambda estimator, features: _score(
estimator, X_test[:, features], y_test, scorer
# TODO(SLEP6): pass score_params here
estimator,
X_test[:, features],
y_test,
scorer,
score_params=None,
),
).scores_

Expand Down
2 changes: 1 addition & 1 deletion sklearn/metrics/tests/test_score_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.svm import LinearSVC
from sklearn.tests.test_metadata_routing import assert_request_is_empty
from sklearn.tests.metadata_routing_common import assert_request_is_empty
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.utils._testing import (
assert_almost_equal,
Expand Down
2 changes: 2 additions & 0 deletions sklearn/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,8 @@ def fit(self, X, y=None, *, groups=None, **fit_params):
fit_and_score_kwargs = dict(
scorer=scorers,
fit_params=fit_params,
# TODO(SLEP6): pass score params along
score_params=None,
return_train_score=self.return_train_score,
return_n_test_samples=True,
return_times=True,
Expand Down