Skip to content

Commit

Permalink
FIX fix regression in gridsearchcv when parameter grids have estimato…
Browse files Browse the repository at this point in the history
…rs as values (#29179)
  • Loading branch information
MarcoGorelli committed Jun 5, 2024
1 parent 6340609 commit b375b7b
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 6 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new/v1.5.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ Changelog
grids that have heterogeneous parameter values.
:pr:`29078` by :user:`Loïc Estève <lesteve>`.

- |Fix| Fix a regression in :class:`model_selection.GridSearchCV` for parameter
grids that have estimators as parameter values.
:pr:`29179` by :user:`Marco Gorelli <MarcoGorelli>`.


.. _changes_1_5:

Expand Down
19 changes: 17 additions & 2 deletions sklearn/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,9 +1089,24 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
for key, param_result in param_results.items():
param_list = list(param_result.values())
try:
arr_dtype = np.result_type(*param_list)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="in the future the `.dtype` attribute",
category=DeprecationWarning,
)
# Warning raised by NumPy 1.20+
arr_dtype = np.result_type(*param_list)
except (TypeError, ValueError):
arr_dtype = object
arr_dtype = np.dtype(object)
else:
if any(np.min_scalar_type(x) == object for x in param_list):
# `np.result_type` might get thrown off by `.dtype` properties
# (which some estimators have).
# If finding the result dtype this way would give object,
# then we use object.
# https://github.com/scikit-learn/scikit-learn/issues/29157
arr_dtype = np.dtype(object)
if len(param_list) == n_candidates and arr_dtype != object:
# Exclude `object` else the numpy constructor might infer a list of
# tuples to be a 2d array.
Expand Down
40 changes: 36 additions & 4 deletions sklearn/model_selection/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sklearn import config_context
from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier
from sklearn.cluster import KMeans
from sklearn.compose import ColumnTransformer
from sklearn.datasets import (
make_blobs,
make_classification,
Expand Down Expand Up @@ -64,7 +65,7 @@
from sklearn.naive_bayes import ComplementNB
from sklearn.neighbors import KernelDensity, KNeighborsClassifier, LocalOutlierFactor
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler
from sklearn.svm import SVC, LinearSVC
from sklearn.tests.metadata_routing_common import (
ConsumingScorer,
Expand Down Expand Up @@ -1403,9 +1404,7 @@ def test_search_cv_results_none_param():
est_parameters,
cv=cv,
).fit(X, y)
assert_array_equal(
grid_search.cv_results_["param_random_state"], [0, float("nan")]
)
assert_array_equal(grid_search.cv_results_["param_random_state"], [0, None])


@ignore_warnings()
Expand Down Expand Up @@ -2686,3 +2685,36 @@ def score(self, X, y):
grid_search.fit(X, y)
for param in param_grid:
assert grid_search.cv_results_[f"param_{param}"].dtype == object


def test_search_with_estimators_issue_29157():
    """Non-regression test: estimators used as parameter-grid values.

    The grid swaps the encoder step of a pipeline between two estimator
    instances (e.g. OneHotEncoder, which has a `dtype` parameter). The
    corresponding `cv_results_` column must be stored with object dtype.
    See https://github.com/scikit-learn/scikit-learn/issues/29157
    """
    pd = pytest.importorskip("pandas")

    # Tiny frame: one numeric column, one constant categorical column and a
    # float target.
    frame = pd.DataFrame(
        {
            "numeric_1": [1, 2, 3, 4, 5],
            "object_1": ["a", "a", "a", "a", "a"],
            "target": [1.0, 4.1, 2.0, 3.0, 1.0],
        }
    )
    features = frame.drop("target", axis=1)
    target = frame["target"]

    # Encode the categorical column; remaining columns are passed through.
    column_encoder = ColumnTransformer(
        [("enc", OneHotEncoder(sparse_output=False), ["object_1"])],
        remainder="passthrough",
    )
    steps = [
        ("enc", column_encoder),
        ("regressor", LinearRegression()),
    ]

    # The searched parameter values are themselves estimator instances.
    candidate_encoders = [
        OneHotEncoder(sparse_output=False),
        OrdinalEncoder(),
    ]
    search = GridSearchCV(
        Pipeline(steps),
        {"enc__enc": candidate_encoders},
        cv=2,
    ).fit(features, target)

    assert search.cv_results_["param_enc__enc"].dtype == object

0 comments on commit b375b7b

Please sign in to comment.