From 58be184f5f91029628b38cb07882b241a18f8b11 Mon Sep 17 00:00:00 2001 From: Michal Romaniuk Date: Fri, 24 Jan 2014 18:40:17 +0000 Subject: [PATCH] Enable grid search with classifiers that may throw an error on individual fits. --- doc/whats_new.rst | 8 ++++ sklearn/cross_validation.py | 48 ++++++++++++++++++----- sklearn/grid_search.py | 50 +++++++++++++++++------- sklearn/tests/test_grid_search.py | 63 ++++++++++++++++++++++++++++++- 4 files changed, 146 insertions(+), 23 deletions(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 9c33b25be2b77..834fa76f5fe00 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -69,6 +69,12 @@ Enhancements - ``DictVectorizer`` can now perform ``fit_transform`` on an iterable in a single pass, when giving the option ``sort=False``. By Dan Blanchard. + - :class:`GridSearchCV` and :class:`RandomizedSearchCV` can now be + configured to work with estimators that may fail and raise errors on + individual folds. This option is controlled by the `error_score` + parameter. This does not affect errors raised on re-fit. By + `Michal Romaniuk`_. + Documentation improvements .......................... @@ -2985,3 +2991,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson. .. _Jatin Shah: http://jatinshah.org/ .. _Dougal Sutherland: https://github.com/dougalsutherland + +.. _Michal Romaniuk: https://github.com/romaniukm diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py index ebcf4f934f043..a1246013c8bed 100644 --- a/sklearn/cross_validation.py +++ b/sklearn/cross_validation.py @@ -1150,9 +1150,13 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1, return np.array(scores)[:, 0] -def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters, - fit_params, return_train_score=False, - return_parameters=False): +class FitFailedWarning(RuntimeWarning): + pass + + +def _fit_and_score(estimator, X, y, scorer, train, test, verbose, + parameters, fit_params, return_train_score=False, + return_parameters=False, error_score='raise'): """Fit estimator and compute scores for a given dataset split. Parameters @@ -1180,6 +1184,12 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters, verbose : integer The verbosity level. + error_score : 'raise' (default) or numeric + Value to assign to the score if an error occurs in estimator fitting. + If set to 'raise', the error is raised. If a numeric value is given, + FitFailedWarning is raised. This parameter does not affect the refit + step, which will always raise the error. + parameters : dict or None Parameters to be set on the estimator. @@ -1231,13 +1241,33 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters, X_train, y_train = _safe_split(estimator, X, y, train) X_test, y_test = _safe_split(estimator, X, y, test, train) - if y_train is None: - estimator.fit(X_train, **fit_params) + + try: + if y_train is None: + estimator.fit(X_train, **fit_params) + else: + estimator.fit(X_train, y_train, **fit_params) + + except Exception as e: + if error_score == 'raise': + raise + elif isinstance(error_score, numbers.Number): + test_score = error_score + if return_train_score: + train_score = error_score + warnings.warn("Classifier fit failed. The score on this train-test" + " partition for these parameters will be set to %f. " + "Details: \n%r" % (error_score, e), FitFailedWarning) + else: + raise ValueError("error_score must be the string 'raise' or a" + " numeric value. (Hint: if using 'raise', please" + " make sure that it has been spelled correctly.)" + ) + else: - estimator.fit(X_train, y_train, **fit_params) - test_score = _score(estimator, X_test, y_test, scorer) - if return_train_score: - train_score = _score(estimator, X_train, y_train, scorer) + test_score = _score(estimator, X_test, y_test, scorer) + if return_train_score: + train_score = _score(estimator, X_train, y_train, scorer) scoring_time = time.time() - start_time diff --git a/sklearn/grid_search.py b/sklearn/grid_search.py index 5bc3fec1e318e..b41467c7b9a92 100644 --- a/sklearn/grid_search.py +++ b/sklearn/grid_search.py @@ -182,7 +182,7 @@ def __len__(self): def fit_grid_point(X, y, estimator, parameters, train, test, scorer, - verbose, **fit_params): + verbose, error_score='raise', **fit_params): """Run fit on one set of parameters. Parameters @@ -215,6 +215,11 @@ def fit_grid_point(X, y, estimator, parameters, train, test, scorer, **fit_params : kwargs Additional parameter passed to the fit function of the estimator. + error_score : 'raise' (default) or numeric + Value to assign to the score if an error occurs in estimator fitting. + If set to 'raise', the error is raised. If a numeric value is given, + FitFailedWarning is raised. This parameter does not affect the refit + step, which will always raise the error. Returns ------- @@ -229,7 +234,7 @@ def fit_grid_point(X, y, estimator, parameters, train, test, scorer, """ score, n_samples_test, _ = _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters, - fit_params) + fit_params, error_score) return score, parameters, n_samples_test @@ -284,7 +289,8 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, @abstractmethod def __init__(self, estimator, scoring=None, fit_params=None, n_jobs=1, iid=True, - refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs'): + refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs', + error_score='raise'): self.scoring = scoring self.estimator = estimator @@ -295,6 +301,7 @@ def __init__(self, estimator, scoring=None, self.cv = cv self.verbose = verbose self.pre_dispatch = pre_dispatch + self.error_score = error_score def score(self, X, y=None): """Returns the score on the given data, if the estimator has been refit @@ -385,9 +392,10 @@ def _fit(self, X, y, parameter_iterable): )( delayed(_fit_and_score)(clone(base_estimator), X, y, self.scorer_, train, test, self.verbose, parameters, - self.fit_params, return_parameters=True) - for parameters in parameter_iterable - for train, test in cv) + self.fit_params, return_parameters=True, + error_score=self.error_score) + for parameters in parameter_iterable + for train, test in cv) # Out is a list of triplet: score, estimator, n_test_samples n_fits = len(out) @@ -506,6 +514,13 @@ class GridSearchCV(BaseSearchCV): verbose : integer Controls the verbosity: the higher, the more messages. + error_score : 'raise' (default) or numeric + Value to assign to the score if an error occurs in estimator fitting. + If set to 'raise', the error is raised. If a numeric value is given, + FitFailedWarning is raised. This parameter does not affect the refit + step, which will always raise the error. + + Examples -------- >>> from sklearn import svm, grid_search, datasets @@ -515,7 +530,7 @@ class GridSearchCV(BaseSearchCV): >>> clf = grid_search.GridSearchCV(svr, parameters) >>> clf.fit(iris.data, iris.target) ... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS - GridSearchCV(cv=None, + GridSearchCV(cv=None, error_score=..., estimator=SVC(C=1.0, cache_size=..., class_weight=..., coef0=..., degree=..., gamma=..., kernel='rbf', max_iter=-1, probability=False, random_state=None, shrinking=True, @@ -580,12 +595,14 @@ class GridSearchCV(BaseSearchCV): """ - def __init__(self, estimator, param_grid, scoring=None, - fit_params=None, n_jobs=1, iid=True, - refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs'): + def __init__(self, estimator, param_grid, scoring=None, loss_func=None, + score_func=None, fit_params=None, n_jobs=1, iid=True, + refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs', + error_score='raise'): + super(GridSearchCV, self).__init__( estimator, scoring, fit_params, n_jobs, iid, - refit, cv, verbose, pre_dispatch) + refit, cv, verbose, pre_dispatch, error_score) self.param_grid = param_grid _check_param_grid(param_grid) @@ -680,6 +697,12 @@ class RandomizedSearchCV(BaseSearchCV): verbose : integer Controls the verbosity: the higher, the more messages. + error_score : 'raise' (default) or numeric + Value to assign to the score if an error occurs in estimator fitting. + If set to 'raise', the error is raised. If a numeric value is given, + FitFailedWarning is raised. This parameter does not affect the refit + step, which will always raise the error. + Attributes ---------- @@ -730,7 +753,8 @@ class RandomizedSearchCV(BaseSearchCV): def __init__(self, estimator, param_distributions, n_iter=10, scoring=None, fit_params=None, n_jobs=1, iid=True, refit=True, cv=None, - verbose=0, pre_dispatch='2*n_jobs', random_state=None): + verbose=0, pre_dispatch='2*n_jobs', random_state=None, + error_score='raise'): self.param_distributions = param_distributions self.n_iter = n_iter @@ -738,7 +762,7 @@ def __init__(self, estimator, param_distributions, n_iter=10, scoring=None, super(RandomizedSearchCV, self).__init__( estimator=estimator, scoring=scoring, fit_params=fit_params, n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose, - pre_dispatch=pre_dispatch) + pre_dispatch=pre_dispatch, error_score=error_score) def fit(self, X, y=None): """Run fit on the estimator with randomly drawn parameters. diff --git a/sklearn/tests/test_grid_search.py b/sklearn/tests/test_grid_search.py index 7791993a44901..ed547c7b9c3a5 100644 --- a/sklearn/tests/test_grid_search.py +++ b/sklearn/tests/test_grid_search.py @@ -17,6 +17,7 @@ from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_not_equal from sklearn.utils.testing import assert_raises +from sklearn.utils.testing import assert_warns from sklearn.utils.testing import assert_raise_message from sklearn.utils.testing import assert_false, assert_true from sklearn.utils.testing import assert_array_equal @@ -44,7 +45,7 @@ from sklearn.metrics import f1_score from sklearn.metrics import make_scorer from sklearn.metrics import roc_auc_score -from sklearn.cross_validation import KFold, StratifiedKFold +from sklearn.cross_validation import KFold, StratifiedKFold, FitFailedWarning from sklearn.preprocessing import Imputer from sklearn.pipeline import Pipeline @@ -674,3 +675,63 @@ def test_grid_search_allows_nans(): ('classifier', MockClassifier()), ]) GridSearchCV(p, {'classifier__foo_param': [1, 2, 3]}, cv=2).fit(X, y) + + + +class FailingClassifier(BaseEstimator): + """Classifier that raises a ValueError on fit()""" + + FAILING_PARAMETER = 2 + + def __init__(self, parameter=None): + self.parameter = parameter + + def fit(self, X, y=None): + if self.parameter == FailingClassifier.FAILING_PARAMETER: + raise ValueError("Failing classifier failed as required") + + def predict(self, X): + return np.zeros(X.shape[0]) + + +def test_grid_search_failing_classifier(): + """GridSearchCV with on_error != 'raise' + + Ensures that a warning is raised and score reset where appropriate. + """ + + X, y = make_classification(n_samples=20, n_features=10, random_state=0) + + clf = FailingClassifier() + + # refit=False because we only want to check that errors caused by fits + # to individual folds will be caught and warnings raised instead. If + # refit was done, then an exception would be raised on refit and not + # caught by grid_search (expected behavior), and this would cause an + # error in this test. + gs = GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy', + refit=False, error_score=0.0) + + assert_warns(FitFailedWarning, gs.fit, X, y) + + # Ensure that grid scores were set to zero as required for those fits + # that are expected to fail. + assert all(np.all(this_point.cv_validation_scores == 0.0) + for this_point in gs.grid_scores_ + if this_point.parameters['parameter'] == + FailingClassifier.FAILING_PARAMETER) + + +def test_grid_search_failing_classifier_raise(): + """GridSearchCV with on_error == 'raise' raises the error""" + + X, y = make_classification(n_samples=20, n_features=10, random_state=0) + + clf = FailingClassifier() + + # refit=False because we want to test the behaviour of the grid search part + gs = GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy', + refit=False, error_score='raise') + + # FailingClassifier issues a ValueError so this is what we look for. + assert_raises(ValueError, gs.fit, X, y)