FIX support scalar values in fit_params in SearchCV #15863

Merged
merged 35 commits into from Dec 31, 2019
Changes from 8 commits

Commits (35)
0492836
support a scalar fit param
adrinjalali Dec 11, 2019
1ae9d03
pep8
adrinjalali Dec 11, 2019
341f8e0
TST add test for desired behavior
glemaitre Dec 23, 2019
53d7a91
FIX introduce _check_fit_params to validate parameters
glemaitre Dec 23, 2019
ef64f0b
DOC update whats new
glemaitre Dec 23, 2019
d5f0ba0
Merge remote-tracking branch 'origin/master' into pr/adrinjalali/15863
glemaitre Dec 23, 2019
9567b44
TST tests both grid-search and randomize-search
glemaitre Dec 23, 2019
f340ab6
PEP8
glemaitre Dec 23, 2019
ffb7ce5
DOC revert unecessary change
glemaitre Dec 23, 2019
2b5b1db
TST add test for _check_fit_params
glemaitre Dec 23, 2019
ffbac6f
olivier comments
glemaitre Dec 23, 2019
c0216dc
TST fixes
glemaitre Dec 23, 2019
4693729
DOC whats new
glemaitre Dec 23, 2019
d7a2c19
DOC whats new
glemaitre Dec 23, 2019
52ecee4
TST revert type of error
glemaitre Dec 23, 2019
be69ce0
add olivier suggestions
glemaitre Dec 23, 2019
9b71a9c
address olivier comments
glemaitre Dec 23, 2019
46c4b9f
address thomas comments
glemaitre Dec 23, 2019
71fab3f
PEP8
glemaitre Dec 23, 2019
9a85162
comments olivier
glemaitre Dec 23, 2019
f41c808
TST fix test by passing X
glemaitre Dec 23, 2019
c989c70
avoid to call twice tocsr
glemaitre Dec 23, 2019
570dfa8
add case column/row sparse in check_fit_param
glemaitre Dec 23, 2019
444c947
provide optional indices
glemaitre Dec 23, 2019
9f47b58
TST check content when indexing params
glemaitre Dec 23, 2019
75bd0a9
PEP8
glemaitre Dec 23, 2019
c24f39d
TST update tests to check identity
glemaitre Dec 23, 2019
63679fd
stupid fix
glemaitre Dec 23, 2019
849615b
use a distribution in RandomizedSearchCV
glemaitre Dec 24, 2019
7837cdf
MNT add lightgbm to one of the CI build
glemaitre Dec 24, 2019
b98e194
move to another build
glemaitre Dec 24, 2019
3127d2b
do not install dependencies lightgbm
glemaitre Dec 24, 2019
a096a7d
MNT comments on the CI setup
glemaitre Dec 24, 2019
18b1207
address some comments
glemaitre Dec 27, 2019
74d70e7
Test fit_params compat without dependency on lightgbm
ogrisel Dec 31, 2019
20 changes: 13 additions & 7 deletions doc/whats_new/v0.22.rst
@@ -35,6 +35,10 @@ Changelog
Follow-up of :pr:`15898` by :user:`Shivam Gargsya <shivamgargsya>`.
:pr:`15933` by :user:`Guillaume Lemaitre <glemaitre>` and `Olivier Grisel`_.

- |Fix| :func:`inspection.plot_partial_dependence` and
:meth:`inspection.PartialDependenceDisplay.plot` now consistently checks
the number of axes passed in. :pr:`15760` by `Thomas Fan`_.

:mod:`sklearn.metrics`
......................

@@ -50,19 +54,21 @@ Changelog
value of the ``zero_division`` keyword argument. :pr:`15879`
by :user:`Bibhash Chandra Mitra <Bibyutatsu>`.

:mod:`sklearn.model_selection`
..............................

- |Fix| :class:`model_selection.GridSearchCV` and
:class:`model_selection.RandomizedSearchCV` will accept scalar values
provided in `fit_params`. The change in 0.22 broke backward compatibility.
:pr:`15863` by :user:`Adrin Jalali <adrinjalali>` and
:user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.utils`
....................

- |Fix| :func:`utils.check_array` now correctly converts pandas DataFrame with
boolean columns to floats. :pr:`15797` by `Thomas Fan`_.

:mod:`sklearn.inspection`
.........................

- |Fix| :func:`inspection.plot_partial_dependence` and
:meth:`inspection.PartialDependenceDisplay.plot` now consistently checks
the number of axes passed in. :pr:`15760` by `Thomas Fan`_.

.. _changes_0_22:

Version 0.22.0
6 changes: 2 additions & 4 deletions sklearn/model_selection/_search.py
@@ -33,7 +33,7 @@
from ..utils import check_random_state
from ..utils.fixes import MaskedArray
from ..utils.random import sample_without_replacement
from ..utils.validation import indexable, check_is_fitted
from ..utils.validation import indexable, check_is_fitted, _check_fit_params
from ..utils.metaestimators import if_delegate_has_method
from ..metrics._scorer import _check_multimetric_scoring
from ..metrics import check_scoring
@@ -648,9 +648,7 @@ def fit(self, X, y=None, groups=None, **fit_params):
refit_metric = 'score'

X, y, groups = indexable(X, y, groups)
# make sure fit_params are sliceable
fit_params_values = indexable(*fit_params.values())
fit_params = dict(zip(fit_params.keys(), fit_params_values))
fit_params = _check_fit_params(fit_params)

n_splits = cv.get_n_splits(X, y, groups)

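The one-line replacement above is the core of the fix: previously every `fit_params` value was forced through `indexable`, whose fallback `np.array()` turns a scalar into a 0-d array that cannot be sliced per CV fold. A minimal standalone sketch of that failure mode (illustrative, not code from the PR):

```python
import numpy as np

# Before this PR, SearchCV.fit ran every fit_params value through
# indexable(), which falls back to np.array() for objects that are not
# already sliceable. A scalar then becomes a 0-d array:
wrapped = np.array(42)
assert wrapped.ndim == 0

# A 0-d array cannot be indexed per cross-validation fold, which is why
# estimators expecting a plain scalar (e.g. fit(X, y, r=42)) broke in 0.22:
try:
    wrapped[0]
except IndexError as exc:
    print("0-d array cannot be sliced:", exc)
```

With `_check_fit_params`, scalar values bypass this wrapping entirely and reach the estimator unchanged.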
61 changes: 60 additions & 1 deletion sklearn/model_selection/tests/test_search.py
@@ -27,7 +27,7 @@

from scipy.stats import bernoulli, expon, uniform

from sklearn.base import BaseEstimator
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.base import clone
from sklearn.exceptions import NotFittedError
from sklearn.datasets import make_classification
@@ -36,6 +36,7 @@

from sklearn.model_selection import fit_grid_point
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import StratifiedShuffleSplit
@@ -1846,3 +1847,61 @@ def test_search_cv__pairwise_property_equivalence_of_precomputed():

attr_message = "GridSearchCV not identical with precomputed metric"
assert (preds_original == preds_precomputed).all(), attr_message


@pytest.mark.parametrize(
"SearchCV, param_search",
[(GridSearchCV, {'a': [0.1, 0.01]}),
(RandomizedSearchCV, {'a': np.random.randint(1, 3, size=2)})]
)
def test_scalar_fit_param(SearchCV, param_search):
# check general support for scalar in fit_params
# non-regression test for:
# https://github.com/scikit-learn/scikit-learn/issues/15805
class TestEstimator(BaseEstimator, ClassifierMixin):
def __init__(self, a=None):
self.a = a

def fit(self, X, y, r):
assert r == 42

def predict(self, X):
return np.zeros(shape=(len(X)))

model = SearchCV(TestEstimator(), param_search)
X, y = make_classification(random_state=42)
model.fit(X, y, r=42)


def _custom_lgbm_metric(y_test, y_pred):
# y_pred contains probabilities which need to be thresholded
y_pred = (y_pred > 0.5).astype(int)
acc = accuracy_score(y_test, y_pred)
# required output of format: (eval_name, eval_result, is_higher_better)
return ('accuracy', acc, True)


@pytest.mark.parametrize("metric", ['auc', _custom_lgbm_metric])
@pytest.mark.parametrize(
"SearchCV, param_search",
[(GridSearchCV, {'learning_rate': [0.1, 0.01]}),
(RandomizedSearchCV, {'learning_rate': uniform(0.01, 0.1)})]
)
def test_scalar_fit_param_lgbm(metric, SearchCV, param_search):
# check support for scalar in fit_params in LightGBM
# non-regression test for:
# https://github.com/scikit-learn/scikit-learn/issues/15805
lgbm = pytest.importorskip("lightgbm")
X_train, X_valid, y_train, y_valid = train_test_split(
*make_classification(random_state=42), random_state=42
)
model = SearchCV(
lgbm.LGBMClassifier(n_estimators=5), param_search
)
fit_params = {
'eval_set': [(X_valid, y_valid)],
'eval_metric': metric,
'early_stopping_rounds': 5,
'verbose': False
}
model.fit(X_train, y_train, **fit_params)
64 changes: 53 additions & 11 deletions sklearn/utils/validation.py
@@ -16,7 +16,8 @@
import numpy as np
import scipy.sparse as sp
from distutils.version import LooseVersion
from inspect import signature, isclass, Parameter
from inspect import signature, Parameter
from inspect import isclass, isfunction, ismethod, ismodule

from numpy.core.numeric import ComplexWarning
import joblib
@@ -212,6 +213,26 @@ def check_consistent_length(*arrays):
" samples: %r" % [int(l) for l in lengths])


def _convert_iterable(iterable):
"""Helper to convert an iterable to an array or a sparse matrix.

Convert sparse matrices to CSR and non-indexable objects to arrays.
`None` is passed through unchanged.

Parameters
----------
iterable : {list, dataframe, array, sparse matrix} or None
Object to be converted to a sliceable iterable.
"""
if sp.issparse(iterable):
return iterable.tocsr()
Review comment (Member Author):

do we have to convert to csr? I think if the estimator needs to convert the param, they'll do it themselves.

Reply (Member):

For backward-compatibility only (we were doing it before). I assume that csr would be a good default since we are converting to csr when the number of samples in the arrays are the same than in X meaning that we should be efficient taking rows.

elif hasattr(iterable, "__getitem__") or hasattr(iterable, "iloc"):
Review comment (Member Author):

Sorry if I've missed something, but if the point is not to pass anything which implements __array_function__, shouldn't we test for that instead? An object may implement that protocol and implement __getitem__, can it not?

Reply (Member):

So you mean something like:

if sp.issparse(iterable):
    # efficient indexing per rows
    return iterable.tocsr()
elif hasattr(iterable, "iloc"):
    # pandas series or dataframe
    return iterable
elif hasattr(iterable, "__array_function__"):
    # do not rely on array protocol
    return np.asarray(iterable)
elif hasattr(iterable, "__getitem__"):
    return iterable
return np.asarray(iterable)

Reply (Member Author):

yeah this looks better. I'd put the __array_function__ condition on top or right after sp.issparse though.

return iterable
elif iterable is None:
return iterable
return np.array(iterable)


def indexable(*iterables):
"""Make arrays indexable for cross-validation.

@@ -224,16 +245,7 @@ def indexable(*iterables):
*iterables : lists, dataframes, arrays, sparse matrices
List of objects to ensure sliceability.
"""
result = []
for X in iterables:
if sp.issparse(X):
result.append(X.tocsr())
elif hasattr(X, "__getitem__") or hasattr(X, "iloc"):
result.append(X)
elif X is None:
result.append(X)
else:
result.append(np.array(X))
result = [_convert_iterable(X) for X in iterables]
check_consistent_length(*result)
return result
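`indexable` now delegates the per-object conversion to the new helper. As a standalone sketch (my own simplified mirror of `_convert_iterable`, not the PR's exact code), the dispatch order looks like:

```python
import numpy as np
import scipy.sparse as sp

def convert_iterable_sketch(obj):
    # Simplified mirror of _convert_iterable's dispatch order.
    if sp.issparse(obj):
        return obj.tocsr()          # CSR gives efficient row slicing
    if hasattr(obj, "__getitem__") or hasattr(obj, "iloc"):
        return obj                  # already sliceable: list, ndarray, pandas
    if obj is None:
        return None                 # None passes through unchanged
    return np.array(obj)            # last resort: wrap in an array

assert convert_iterable_sketch(None) is None
assert convert_iterable_sketch([1, 2, 3]) == [1, 2, 3]
assert convert_iterable_sketch(sp.eye(3, format="coo")).format == "csr"
```

The review thread above debates whether objects implementing `__array_function__` should be special-cased before the `__getitem__` check; the sketch reflects only the order actually merged here.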

@@ -1259,3 +1271,33 @@ def inner_f(*args, **kwargs):
kwargs.update({k: arg for k, arg in zip(all_args, args)})
return f(**kwargs)
return inner_f


def _check_fit_params(fit_params):
"""Check and validate the parameters passed during `fit`.

Parameters
----------
fit_params : dict
Dictionary containing the parameters passed at fit.

Returns
-------
fit_params_validated : dict
Validated parameters. We ensure that the values are iterable.
"""
fit_params_validated = {}
for param_key, param_value in fit_params.items():
is_scalar = [
check(param_value)
for check in [np.isscalar, ismodule, isclass, ismethod, isfunction]
]
if any(is_scalar):
# keep scalar as is for backward-compatibility
# https://github.com/scikit-learn/scikit-learn/issues/15805
fit_params_validated[param_key] = param_value
else:
# ensure iterable will be sliceable
fit_params_validated[param_key] = _convert_iterable(param_value)

return fit_params_validated
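As a quick illustration, the scalar test in `_check_fit_params` can be isolated into a tiny predicate (the name `is_scalar_fit_param` is mine, not part of the PR):

```python
import numpy as np
from inspect import isclass, isfunction, ismethod, ismodule

def is_scalar_fit_param(value):
    # Same battery of checks _check_fit_params uses to keep a value as-is
    # instead of converting it to something sliceable.
    checks = (np.isscalar, ismodule, isclass, ismethod, isfunction)
    return any(check(value) for check in checks)

assert is_scalar_fit_param(42)            # plain numbers are kept untouched
assert is_scalar_fit_param("auc")         # strings count as scalars too
assert is_scalar_fit_param(int)           # classes are never sliced
assert not is_scalar_fit_param([1, 2])    # lists are made sliceable instead
```

Anything the predicate accepts is forwarded to the estimator verbatim, restoring the pre-0.22 behavior for calls like `fit(X, y, r=42)`.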