diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index e2ff71802ce72..b029a2fd18574 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -71,10 +71,9 @@ jobs:
         JOBLIB_VERSION: '0.12.3'
         COVERAGE: 'true'
       # Linux environment to test the latest available dependencies and MKL.
-      # It runs tests requiring pandas and PyAMG.
+      # It runs tests requiring lightgbm, pandas and PyAMG.
       pylatest_pip_openblas_pandas:
         DISTRIB: 'conda-pip-latest'
-        # FIXME: pinned until SciPy wheels are available for Python 3.8
         PYTHON_VERSION: '3.8'
         PYTEST_VERSION: '4.6.2'
         COVERAGE: 'true'
diff --git a/build_tools/azure/install.sh b/build_tools/azure/install.sh
index aa2707bb03837..1ef981b5dd6e8 100755
--- a/build_tools/azure/install.sh
+++ b/build_tools/azure/install.sh
@@ -92,6 +92,8 @@ elif [[ "$DISTRIB" == "conda-pip-latest" ]]; then
     python -m pip install numpy scipy cython joblib
     python -m pip install pytest==$PYTEST_VERSION pytest-cov pytest-xdist
     python -m pip install pandas matplotlib pyamg
+    # do not install dependencies for lightgbm since it requires scikit-learn
+    python -m pip install lightgbm --no-deps
 fi

 if [[ "$COVERAGE" == "true" ]]; then
diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst
index e5cf5a29a8d52..44c03adc8d649 100644
--- a/doc/whats_new/v0.22.rst
+++ b/doc/whats_new/v0.22.rst
@@ -50,6 +50,15 @@ Changelog
   value of the ``zero_division`` keyword argument.
   :pr:`15879` by :user:`Bibhash Chandra Mitra`.

+:mod:`sklearn.model_selection`
+..............................
+
+- |Fix| :class:`model_selection.GridSearchCV` and
+  :class:`model_selection.RandomizedSearchCV` accept scalar values provided in
+  `fit_params`. Change in 0.22 was breaking backward compatibility.
+  :pr:`15863` by :user:`Adrin Jalali` and
+  :user:`Guillaume Lemaitre`.
+
 :mod:`sklearn.utils`
 ....................
diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index e6a8493ef6250..a70bdd7a2f9dc 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -33,7 +33,7 @@
 from ..utils import check_random_state
 from ..utils.fixes import MaskedArray
 from ..utils.random import sample_without_replacement
-from ..utils.validation import indexable, check_is_fitted
+from ..utils.validation import indexable, check_is_fitted, _check_fit_params
 from ..utils.metaestimators import if_delegate_has_method
 from ..metrics._scorer import _check_multimetric_scoring
 from ..metrics import check_scoring
@@ -648,9 +648,7 @@ def fit(self, X, y=None, groups=None, **fit_params):
             refit_metric = 'score'

         X, y, groups = indexable(X, y, groups)
-        # make sure fit_params are sliceable
-        fit_params_values = indexable(*fit_params.values())
-        fit_params = dict(zip(fit_params.keys(), fit_params_values))
+        fit_params = _check_fit_params(X, fit_params)

         n_splits = cv.get_n_splits(X, y, groups)
diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py
index 67a30c6416031..1e48a5a610e3c 100644
--- a/sklearn/model_selection/_validation.py
+++ b/sklearn/model_selection/_validation.py
@@ -23,6 +23,7 @@
 from ..base import is_classifier, clone
 from ..utils import (indexable, check_random_state, _safe_indexing,
                      _message_with_time)
+from ..utils.validation import _check_fit_params
 from ..utils.validation import _is_arraylike, _num_samples
 from ..utils.metaestimators import _safe_split
 from ..metrics import check_scoring
@@ -489,8 +490,7 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,

     # Adjust length of sample weights
     fit_params = fit_params if fit_params is not None else {}
-    fit_params = {k: _index_param_value(X, v, train)
-                  for k, v in fit_params.items()}
+    fit_params = _check_fit_params(X, fit_params, train)

     train_scores = {}
     if parameters is not None:
@@ -830,8 +830,7 @@ def _fit_and_predict(estimator, X, y, train, test, verbose, fit_params,
     """
     # Adjust length of sample weights
     fit_params = fit_params if fit_params is not None else {}
-    fit_params = {k: _index_param_value(X, v, train)
-                  for k, v in fit_params.items()}
+    fit_params = _check_fit_params(X, fit_params, train)

     X_train, y_train = _safe_split(estimator, X, y, train)
     X_test, _ = _safe_split(estimator, X, y, test, train)
@@ -937,16 +936,6 @@ def _check_is_permutation(indices, n_samples):
     return True


-def _index_param_value(X, v, indices):
-    """Private helper function for parameter value indexing."""
-    if not _is_arraylike(v) or _num_samples(v) != _num_samples(X):
-        # pass through: skip indexing
-        return v
-    if sp.issparse(v):
-        v = v.tocsr()
-    return _safe_indexing(v, indices)
-
-
 def permutation_test_score(estimator, X, y, groups=None, cv=None,
                            n_permutations=100, n_jobs=None, random_state=0,
                            verbose=0, scoring=None):
diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py
index 056927bee75d0..bacfc20eb1fc1 100644
--- a/sklearn/model_selection/tests/test_search.py
+++ b/sklearn/model_selection/tests/test_search.py
@@ -27,7 +27,7 @@

 from scipy.stats import bernoulli, expon, uniform

-from sklearn.base import BaseEstimator
+from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.base import clone
 from sklearn.exceptions import NotFittedError
 from sklearn.datasets import make_classification
@@ -36,6 +36,7 @@
 from sklearn.model_selection import fit_grid_point
 from sklearn.model_selection import cross_val_score
+from sklearn.model_selection import train_test_split
 from sklearn.model_selection import KFold
 from sklearn.model_selection import StratifiedKFold
 from sklearn.model_selection import StratifiedShuffleSplit
@@ -218,33 +219,25 @@ def test_grid_search_pipeline_steps():
     assert not hasattr(param_grid['regressor'][1], 'coef_')


-def check_hyperparameter_searcher_with_fit_params(klass, **klass_kwargs):
+@pytest.mark.parametrize("SearchCV", [GridSearchCV, RandomizedSearchCV])
+def test_SearchCV_with_fit_params(SearchCV):
     X = np.arange(100).reshape(10, 10)
     y = np.array([0] * 5 + [1] * 5)
     clf = CheckingClassifier(expected_fit_params=['spam', 'eggs'])
-    searcher = klass(clf, {'foo_param': [1, 2, 3]}, cv=2, **klass_kwargs)
+    searcher = SearchCV(
+        clf, {'foo_param': [1, 2, 3]}, cv=2, error_score="raise"
+    )

     # The CheckingClassifier generates an assertion error if
     # a parameter is missing or has length != len(X).
-    assert_raise_message(AssertionError,
-                         "Expected fit parameter(s) ['eggs'] not seen.",
-                         searcher.fit, X, y, spam=np.ones(10))
-    assert_raise_message(
-        ValueError,
-        "Found input variables with inconsistent numbers of samples: [",
-        searcher.fit, X, y, spam=np.ones(1),
-        eggs=np.zeros(10))
-    searcher.fit(X, y, spam=np.ones(10), eggs=np.zeros(10))
+    err_msg = r"Expected fit parameter\(s\) \['eggs'\] not seen."
+    with pytest.raises(AssertionError, match=err_msg):
+        searcher.fit(X, y, spam=np.ones(10))
-
-def test_grid_search_with_fit_params():
-    check_hyperparameter_searcher_with_fit_params(GridSearchCV,
-                                                  error_score='raise')
-
-
-def test_random_search_with_fit_params():
-    check_hyperparameter_searcher_with_fit_params(RandomizedSearchCV, n_iter=1,
-                                                  error_score='raise')
+    err_msg = "Fit parameter spam has length 1; expected"
+    with pytest.raises(AssertionError, match=err_msg):
+        searcher.fit(X, y, spam=np.ones(1), eggs=np.zeros(10))
+    searcher.fit(X, y, spam=np.ones(10), eggs=np.zeros(10))


 @ignore_warnings
@@ -1846,3 +1839,78 @@ def test_search_cv__pairwise_property_equivalence_of_precomputed():
     attr_message = "GridSearchCV not identical with precomputed metric"
     assert (preds_original == preds_precomputed).all(), attr_message
+
+
+@pytest.mark.parametrize(
+    "SearchCV, param_search",
+    [(GridSearchCV, {'a': [0.1, 0.01]}),
+     (RandomizedSearchCV, {'a': uniform(1, 3)})]
+)
+def test_scalar_fit_param(SearchCV, param_search):
+    # unofficially sanctioned tolerance for scalar values in fit_params
+    # non-regression test for:
+    # https://github.com/scikit-learn/scikit-learn/issues/15805
+    class TestEstimator(BaseEstimator, ClassifierMixin):
+        def __init__(self, a=None):
+            self.a = a
+
+        def fit(self, X, y, r=None):
+            self.r_ = r
+
+        def predict(self, X):
+            return np.zeros(shape=(len(X)))
+
+    model = SearchCV(TestEstimator(), param_search)
+    X, y = make_classification(random_state=42)
+    model.fit(X, y, r=42)
+    assert model.best_estimator_.r_ == 42
+
+
+@pytest.mark.parametrize(
+    "SearchCV, param_search",
+    [(GridSearchCV, {'alpha': [0.1, 0.01]}),
+     (RandomizedSearchCV, {'alpha': uniform(0.01, 0.1)})]
+)
+def test_scalar_fit_param_compat(SearchCV, param_search):
+    # check support for scalar values in fit_params, for instance in LightGBM
+    # that do not exactly respect the scikit-learn API contract but that we do
+    # not want to break without an explicit deprecation cycle and API
+    # recommendations for implementing early stopping with a user provided
+    # validation set. non-regression test for:
+    # https://github.com/scikit-learn/scikit-learn/issues/15805
+    X_train, X_valid, y_train, y_valid = train_test_split(
+        *make_classification(random_state=42), random_state=42
+    )
+
+    class _FitParamClassifier(SGDClassifier):
+
+        def fit(self, X, y, sample_weight=None, tuple_of_arrays=None,
+                scalar_param=None, callable_param=None):
+            super().fit(X, y, sample_weight=sample_weight)
+            assert scalar_param > 0
+            assert callable(callable_param)
+
+            # The tuple of arrays should be preserved as tuple.
+            assert isinstance(tuple_of_arrays, tuple)
+            assert tuple_of_arrays[0].ndim == 2
+            assert tuple_of_arrays[1].ndim == 1
+            return self
+
+    def _fit_param_callable():
+        pass
+
+    model = SearchCV(
+        _FitParamClassifier(), param_search
+    )
+
+    # NOTE: `fit_params` should be data dependent (e.g. `sample_weight`) which
+    # is not the case for the following parameters. But this abuse is common in
+    # popular third-party libraries and we should tolerate this behavior for
+    # now and be careful not to break support for those without following
+    # proper deprecation cycle.
+    fit_params = {
+        'tuple_of_arrays': (X_valid, y_valid),
+        'callable_param': _fit_param_callable,
+        'scalar_param': 42,
+    }
+    model.fit(X_train, y_train, **fit_params)
diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py
index f121f11658051..a3bd7b566fbbf 100644
--- a/sklearn/utils/tests/test_validation.py
+++ b/sklearn/utils/tests/test_validation.py
@@ -32,6 +32,7 @@
 from sklearn.ensemble import RandomForestRegressor
 from sklearn.svm import SVR
 from sklearn.datasets import make_blobs
+from sklearn.utils import _safe_indexing
 from sklearn.utils.validation import (
     has_fit_parameter,
     check_is_fitted,
@@ -46,6 +47,7 @@
     _check_sample_weight,
     _allclose_dense_sparse,
     FLOAT_DTYPES)
+from sklearn.utils.validation import _check_fit_params

 import sklearn

@@ -1053,3 +1055,31 @@ def __init__(self, a=1, b=1, *, c=1, d=1):

     with pytest.warns(FutureWarning, match=r"Pass c=3, d=4 as keyword args"):
         A2(1, 2, 3, 4)
+
+
+@pytest.mark.parametrize("indices", [None, [1, 3]])
+def test_check_fit_params(indices):
+    X = np.random.randn(4, 2)
+    fit_params = {
+        'list': [1, 2, 3, 4],
+        'array': np.array([1, 2, 3, 4]),
+        'sparse-col': sp.csc_matrix([1, 2, 3, 4]).T,
+        'sparse-row': sp.csc_matrix([1, 2, 3, 4]),
+        'scalar-int': 1,
+        'scalar-str': 'xxx',
+        'None': None,
+    }
+    result = _check_fit_params(X, fit_params, indices)
+    indices_ = indices if indices is not None else list(range(X.shape[0]))
+
+    for key in ['sparse-row', 'scalar-int', 'scalar-str', 'None']:
+        assert result[key] is fit_params[key]
+
+    assert result['list'] == _safe_indexing(fit_params['list'], indices_)
+    assert_array_equal(
+        result['array'], _safe_indexing(fit_params['array'], indices_)
+    )
+    assert_allclose_dense_sparse(
+        result['sparse-col'],
+        _safe_indexing(fit_params['sparse-col'], indices_)
+    )
diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py
index e08495de30af5..810ae44a5d60e 100644
--- a/sklearn/utils/validation.py
+++ b/sklearn/utils/validation.py
@@ -212,6 +212,26 @@ def check_consistent_length(*arrays):
                          " samples: %r" % [int(l) for l in lengths])


+def _make_indexable(iterable):
+    """Ensure iterable supports indexing or convert to an indexable variant.
+
+    Convert sparse matrices to csr and other non-indexable iterable to arrays.
+    Let `None` and indexable objects (e.g. pandas dataframes) pass unchanged.
+
+    Parameters
+    ----------
+    iterable : {list, dataframe, array, sparse} or None
+        Object to be converted to an indexable iterable.
+    """
+    if sp.issparse(iterable):
+        return iterable.tocsr()
+    elif hasattr(iterable, "__getitem__") or hasattr(iterable, "iloc"):
+        return iterable
+    elif iterable is None:
+        return iterable
+    return np.array(iterable)
+
+
 def indexable(*iterables):
     """Make arrays indexable for cross-validation.

@@ -224,16 +244,7 @@ def indexable(*iterables):
     *iterables : lists, dataframes, arrays, sparse matrices
         List of objects to ensure sliceability.
     """
-    result = []
-    for X in iterables:
-        if sp.issparse(X):
-            result.append(X.tocsr())
-        elif hasattr(X, "__getitem__") or hasattr(X, "iloc"):
-            result.append(X)
-        elif X is None:
-            result.append(X)
-        else:
-            result.append(np.array(X))
+    result = [_make_indexable(X) for X in iterables]
     check_consistent_length(*result)
     return result

@@ -1259,3 +1270,41 @@ def inner_f(*args, **kwargs):
         kwargs.update({k: arg for k, arg in zip(all_args, args)})
         return f(**kwargs)
     return inner_f
+
+
+def _check_fit_params(X, fit_params, indices=None):
+    """Check and validate the parameters passed during `fit`.
+
+    Parameters
+    ----------
+    X : array-like of shape (n_samples, n_features)
+        Data array.
+
+    fit_params : dict
+        Dictionary containing the parameters passed at fit.
+
+    indices : array-like of shape (n_samples,), default=None
+        Indices to be selected if the parameter has the same size as `X`.
+
+    Returns
+    -------
+    fit_params_validated : dict
+        Validated parameters. We ensure that the values support indexing.
+    """
+    from . import _safe_indexing
+    fit_params_validated = {}
+    for param_key, param_value in fit_params.items():
+        if (not _is_arraylike(param_value) or
+                _num_samples(param_value) != _num_samples(X)):
+            # Non-indexable pass-through (for now for backward-compatibility).
+            # https://github.com/scikit-learn/scikit-learn/issues/15805
+            fit_params_validated[param_key] = param_value
+        else:
+            # Any other fit_params should support indexing
+            # (e.g. for cross-validation).
+            fit_params_validated[param_key] = _make_indexable(param_value)
+            fit_params_validated[param_key] = _safe_indexing(
+                fit_params_validated[param_key], indices
+            )
+
+    return fit_params_validated
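
A minimal usage sketch (not part of the patch) of the private helper introduced above, assuming a scikit-learn build that already contains this change. The dictionary keys below (`sample_weight`, `scalar_param`, `eval_set`) are arbitrary illustrations: only parameters whose length matches `X` are made indexable and sliced to the given indices; scalars, `None`, and other non-aligned values are passed through untouched. Note that `_check_fit_params` is a private utility and may change without deprecation.

    import numpy as np
    from sklearn.utils.validation import _check_fit_params

    X = np.random.randn(4, 2)
    fit_params = {
        'sample_weight': np.array([1., 2., 3., 4.]),  # aligned with X -> sliced
        'scalar_param': 42,                           # scalar -> passed through
        'eval_set': None,                             # None -> passed through
    }
    # slice the sample-aligned parameter to a (hypothetical) train split [0, 2]
    result = _check_fit_params(X, fit_params, indices=[0, 2])

    assert np.array_equal(result['sample_weight'], [1., 3.])
    assert result['scalar_param'] == 42
    assert result['eval_set'] is None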