Skip to content

Commit

Permalink
Enable grid search with classifiers that may throw an error on indivi…
Browse files Browse the repository at this point in the history
…dual fits.
  • Loading branch information
Michal Romaniuk authored and jnothman committed Oct 2, 2014
1 parent 74b9563 commit 58be184
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 23 deletions.
8 changes: 8 additions & 0 deletions doc/whats_new.rst
Expand Up @@ -69,6 +69,12 @@ Enhancements
- ``DictVectorizer`` can now perform ``fit_transform`` on an iterable in a
single pass, when giving the option ``sort=False``. By Dan Blanchard.

- :class:`GridSearchCV` and :class:`RandomizedSearchCV` can now be
configured to work with estimators that may fail and raise errors on
individual folds. This option is controlled by the ``error_score``
parameter. This does not affect errors raised on re-fit. By
`Michal Romaniuk`_.


Documentation improvements
..........................
Expand Down Expand Up @@ -2985,3 +2991,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
.. _Jatin Shah: http://jatinshah.org/

.. _Dougal Sutherland: https://github.com/dougalsutherland

.. _Michal Romaniuk: https://github.com/romaniukm
48 changes: 39 additions & 9 deletions sklearn/cross_validation.py
Expand Up @@ -1150,9 +1150,13 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
return np.array(scores)[:, 0]


def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters,
fit_params, return_train_score=False,
return_parameters=False):
class FitFailedWarning(RuntimeWarning):
    """Warning issued when an estimator fails to fit on a train/test split.

    It is emitted instead of propagating the original exception when a
    numeric ``error_score`` is used; the score for the failing split is then
    set to that value.
    """


def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
parameters, fit_params, return_train_score=False,
return_parameters=False, error_score='raise'):
"""Fit estimator and compute scores for a given dataset split.
Parameters
Expand Down Expand Up @@ -1180,6 +1184,12 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters,
verbose : integer
The verbosity level.
error_score : 'raise' (default) or numeric
Value to assign to the score if an error occurs in estimator fitting.
If set to 'raise', the error is raised. If a numeric value is given,
a FitFailedWarning is issued instead and the score is set to that value.
This parameter does not affect the refit step, which will always raise
the error.
parameters : dict or None
Parameters to be set on the estimator.
Expand Down Expand Up @@ -1231,13 +1241,33 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters,

X_train, y_train = _safe_split(estimator, X, y, train)
X_test, y_test = _safe_split(estimator, X, y, test, train)
if y_train is None:
estimator.fit(X_train, **fit_params)

try:
if y_train is None:
estimator.fit(X_train, **fit_params)
else:
estimator.fit(X_train, y_train, **fit_params)

except Exception as e:
if error_score == 'raise':
raise
elif isinstance(error_score, numbers.Number):
test_score = error_score
if return_train_score:
train_score = error_score
warnings.warn("Classifier fit failed. The score on this train-test"
" partition for these parameters will be set to %f. "
"Details: \n%r" % (error_score, e), FitFailedWarning)
else:
raise ValueError("error_score must be the string 'raise' or a"
" numeric value. (Hint: if using 'raise', please"
" make sure that it has been spelled correctly.)"
)

else:
estimator.fit(X_train, y_train, **fit_params)
test_score = _score(estimator, X_test, y_test, scorer)
if return_train_score:
train_score = _score(estimator, X_train, y_train, scorer)
test_score = _score(estimator, X_test, y_test, scorer)
if return_train_score:
train_score = _score(estimator, X_train, y_train, scorer)

scoring_time = time.time() - start_time

Expand Down
50 changes: 37 additions & 13 deletions sklearn/grid_search.py
Expand Up @@ -182,7 +182,7 @@ def __len__(self):


def fit_grid_point(X, y, estimator, parameters, train, test, scorer,
verbose, **fit_params):
verbose, error_score='raise', **fit_params):
"""Run fit on one set of parameters.
Parameters
Expand Down Expand Up @@ -215,6 +215,11 @@ def fit_grid_point(X, y, estimator, parameters, train, test, scorer,
**fit_params : kwargs
Additional parameter passed to the fit function of the estimator.
error_score : 'raise' (default) or numeric
Value to assign to the score if an error occurs in estimator fitting.
If set to 'raise', the error is raised. If a numeric value is given,
a FitFailedWarning is issued instead and the score is set to that value.
This parameter does not affect the refit step, which will always raise
the error.
Returns
-------
Expand All @@ -229,7 +234,7 @@ def fit_grid_point(X, y, estimator, parameters, train, test, scorer,
"""
score, n_samples_test, _ = _fit_and_score(estimator, X, y, scorer, train,
test, verbose, parameters,
fit_params)
fit_params, error_score)
return score, parameters, n_samples_test


Expand Down Expand Up @@ -284,7 +289,8 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
@abstractmethod
def __init__(self, estimator, scoring=None,
fit_params=None, n_jobs=1, iid=True,
refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs'):
refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs',
error_score='raise'):

self.scoring = scoring
self.estimator = estimator
Expand All @@ -295,6 +301,7 @@ def __init__(self, estimator, scoring=None,
self.cv = cv
self.verbose = verbose
self.pre_dispatch = pre_dispatch
self.error_score = error_score

def score(self, X, y=None):
"""Returns the score on the given data, if the estimator has been refit
Expand Down Expand Up @@ -385,9 +392,10 @@ def _fit(self, X, y, parameter_iterable):
)(
delayed(_fit_and_score)(clone(base_estimator), X, y, self.scorer_,
train, test, self.verbose, parameters,
self.fit_params, return_parameters=True)
for parameters in parameter_iterable
for train, test in cv)
self.fit_params, return_parameters=True,
error_score=self.error_score)
for parameters in parameter_iterable
for train, test in cv)

# Out is a list of triplet: score, estimator, n_test_samples
n_fits = len(out)
Expand Down Expand Up @@ -506,6 +514,13 @@ class GridSearchCV(BaseSearchCV):
verbose : integer
Controls the verbosity: the higher, the more messages.
error_score : 'raise' (default) or numeric
Value to assign to the score if an error occurs in estimator fitting.
If set to 'raise', the error is raised. If a numeric value is given,
a FitFailedWarning is issued instead and the score is set to that value.
This parameter does not affect the refit step, which will always raise
the error.
Examples
--------
>>> from sklearn import svm, grid_search, datasets
Expand All @@ -515,7 +530,7 @@ class GridSearchCV(BaseSearchCV):
>>> clf = grid_search.GridSearchCV(svr, parameters)
>>> clf.fit(iris.data, iris.target)
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
GridSearchCV(cv=None,
GridSearchCV(cv=None, error_score=...,
estimator=SVC(C=1.0, cache_size=..., class_weight=..., coef0=...,
degree=..., gamma=..., kernel='rbf', max_iter=-1,
probability=False, random_state=None, shrinking=True,
Expand Down Expand Up @@ -580,12 +595,14 @@ class GridSearchCV(BaseSearchCV):
"""

def __init__(self, estimator, param_grid, scoring=None,
fit_params=None, n_jobs=1, iid=True,
refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs'):
def __init__(self, estimator, param_grid, scoring=None, loss_func=None,
score_func=None, fit_params=None, n_jobs=1, iid=True,
refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs',
error_score='raise'):

super(GridSearchCV, self).__init__(
estimator, scoring, fit_params, n_jobs, iid,
refit, cv, verbose, pre_dispatch)
refit, cv, verbose, pre_dispatch, error_score)
self.param_grid = param_grid
_check_param_grid(param_grid)

Expand Down Expand Up @@ -680,6 +697,12 @@ class RandomizedSearchCV(BaseSearchCV):
verbose : integer
Controls the verbosity: the higher, the more messages.
error_score : 'raise' (default) or numeric
Value to assign to the score if an error occurs in estimator fitting.
If set to 'raise', the error is raised. If a numeric value is given,
a FitFailedWarning is issued instead and the score is set to that value.
This parameter does not affect the refit step, which will always raise
the error.
Attributes
----------
Expand Down Expand Up @@ -730,15 +753,16 @@ class RandomizedSearchCV(BaseSearchCV):

def __init__(self, estimator, param_distributions, n_iter=10, scoring=None,
fit_params=None, n_jobs=1, iid=True, refit=True, cv=None,
verbose=0, pre_dispatch='2*n_jobs', random_state=None):
verbose=0, pre_dispatch='2*n_jobs', random_state=None,
error_score='raise'):

self.param_distributions = param_distributions
self.n_iter = n_iter
self.random_state = random_state
super(RandomizedSearchCV, self).__init__(
estimator=estimator, scoring=scoring, fit_params=fit_params,
n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
pre_dispatch=pre_dispatch)
pre_dispatch=pre_dispatch, error_score=error_score)

def fit(self, X, y=None):
"""Run fit on the estimator with randomly drawn parameters.
Expand Down
63 changes: 62 additions & 1 deletion sklearn/tests/test_grid_search.py
Expand Up @@ -17,6 +17,7 @@
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_not_equal
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_warns
from sklearn.utils.testing import assert_raise_message
from sklearn.utils.testing import assert_false, assert_true
from sklearn.utils.testing import assert_array_equal
Expand Down Expand Up @@ -44,7 +45,7 @@
from sklearn.metrics import f1_score
from sklearn.metrics import make_scorer
from sklearn.metrics import roc_auc_score
from sklearn.cross_validation import KFold, StratifiedKFold
from sklearn.cross_validation import KFold, StratifiedKFold, FitFailedWarning
from sklearn.preprocessing import Imputer
from sklearn.pipeline import Pipeline

Expand Down Expand Up @@ -674,3 +675,63 @@ def test_grid_search_allows_nans():
('classifier', MockClassifier()),
])
GridSearchCV(p, {'classifier__foo_param': [1, 2, 3]}, cv=2).fit(X, y)



class FailingClassifier(BaseEstimator):
    """Mock classifier whose fit() raises a ValueError for one parameter value.

    Fitting fails only when ``parameter`` equals ``FAILING_PARAMETER``; any
    other value fits successfully.
    """

    # Value of ``parameter`` that makes fit() fail.
    FAILING_PARAMETER = 2

    def __init__(self, parameter=None):
        self.parameter = parameter

    def fit(self, X, y=None):
        """Raise ValueError if configured to fail, otherwise return self."""
        if self.parameter == FailingClassifier.FAILING_PARAMETER:
            raise ValueError("Failing classifier failed as required")
        # Follow the estimator convention of returning the fitted object.
        return self

    def predict(self, X):
        """Predict zero for every row of X."""
        return np.zeros(X.shape[0])


def test_grid_search_failing_classifier():
    """GridSearchCV with a numeric error_score tolerates failing fits.

    Ensures that a FitFailedWarning is issued and that the scores of the
    failing fits are set to error_score.
    """

    X, y = make_classification(n_samples=20, n_features=10, random_state=0)

    clf = FailingClassifier()

    # refit=False because we only want to check that errors caused by fits
    # to individual folds will be caught and warnings raised instead. If
    # refit was done, then an exception would be raised on refit and not
    # caught by grid_search (expected behavior), and this would cause an
    # error in this test.
    gs = GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy',
                      refit=False, error_score=0.0)

    assert_warns(FitFailedWarning, gs.fit, X, y)

    failed_points = [this_point for this_point in gs.grid_scores_
                     if this_point.parameters['parameter'] ==
                     FailingClassifier.FAILING_PARAMETER]

    # Without this check the all() below would pass vacuously if no grid
    # point used the failing parameter.
    assert failed_points

    # Ensure that grid scores were set to zero as required for those fits
    # that are expected to fail.
    assert all(np.all(this_point.cv_validation_scores == 0.0)
               for this_point in failed_points)


def test_grid_search_failing_classifier_raise():
    """GridSearchCV with error_score == 'raise' propagates the fit error."""

    X, y = make_classification(n_samples=20, n_features=10, random_state=0)

    clf = FailingClassifier()

    # refit=False because we want to test the behaviour of the grid search part
    gs = GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy',
                      refit=False, error_score='raise')

    # FailingClassifier issues a ValueError; match its message so that an
    # unrelated ValueError cannot make this test pass.
    assert_raise_message(ValueError, "Failing classifier failed as required",
                         gs.fit, X, y)

0 comments on commit 58be184

Please sign in to comment.