Skip to content

Commit

Permalink
ENH Expose error_score in cross_validate and cross_val_score (#11583)
Browse files Browse the repository at this point in the history
  • Loading branch information
samronsin authored and jnothman committed Jul 27, 2018
1 parent e3984fb commit 6b38970
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 20 deletions.
7 changes: 7 additions & 0 deletions doc/whats_new/v0.20.rst
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,13 @@ Model evaluation and meta-estimators
hyperparameter optimization and refitting the best model on the whole
dataset. :issue:`11310` by :user:`Matthias Feurer <mfeurer>`.

- Expose ``error_score`` parameter in :func:`model_selection.cross_validate`,
:func:`model_selection.cross_val_score`,
:func:`model_selection.learning_curve` and
:func:`model_selection.validation_curve` to control the behavior triggered
when an error occurs while fitting an estimator (in the private helper
``model_selection._fit_and_score``).
:issue:`11576` by :user:`Samuel O. Ronsin <samronsin>`.

Decomposition and manifold learning

- Speed improvements for both 'exact' and 'barnes_hut' methods in
Expand Down
75 changes: 61 additions & 14 deletions sklearn/model_selection/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv='warn',
n_jobs=1, verbose=0, fit_params=None,
pre_dispatch='2*n_jobs', return_train_score="warn",
return_estimator=False):
return_estimator=False, error_score='raise-deprecating'):
"""Evaluate metric(s) by cross-validation and also record fit/score times.
Read more in the :ref:`User Guide <multimetric_cross_validation>`.
Expand Down Expand Up @@ -139,6 +139,16 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv='warn',
return_estimator : boolean, default False
Whether to return the estimators fitted on each split.
error_score : 'raise' | 'raise-deprecating' or numeric
Value to assign to the score if an error occurs in estimator fitting.
If set to 'raise', the error is raised.
If set to 'raise-deprecating', a FutureWarning is printed before the
error is raised.
If a numeric value is given, FitFailedWarning is raised. This parameter
does not affect the refit step, which will always raise the error.
Default is 'raise-deprecating' but from version 0.22 it will change
to np.nan.
Returns
-------
scores : dict of float arrays of shape=(n_splits,)
Expand Down Expand Up @@ -223,7 +233,8 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv='warn',
delayed(_fit_and_score)(
clone(estimator), X, y, scorers, train, test, verbose, None,
fit_params, return_train_score=return_train_score,
return_times=True, return_estimator=return_estimator)
return_times=True, return_estimator=return_estimator,
error_score=error_score)
for train, test in cv.split(X, y, groups))

zipped_scores = list(zip(*scores))
Expand Down Expand Up @@ -262,7 +273,7 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv='warn',

def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv='warn',
n_jobs=1, verbose=0, fit_params=None,
pre_dispatch='2*n_jobs'):
pre_dispatch='2*n_jobs', error_score='raise-deprecating'):
"""Evaluate a score by cross-validation
Read more in the :ref:`User Guide <cross_validation>`.
Expand Down Expand Up @@ -335,6 +346,16 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv='warn',
- A string, giving an expression as a function of n_jobs,
as in '2*n_jobs'
error_score : 'raise' | 'raise-deprecating' or numeric
Value to assign to the score if an error occurs in estimator fitting.
If set to 'raise', the error is raised.
If set to 'raise-deprecating', a FutureWarning is printed before the
error is raised.
If a numeric value is given, FitFailedWarning is raised. This parameter
does not affect the refit step, which will always raise the error.
Default is 'raise-deprecating' but from version 0.22 it will change
to np.nan.
Returns
-------
scores : array of float, shape=(len(list(cv)),)
Expand Down Expand Up @@ -373,7 +394,8 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv='warn',
return_train_score=False,
n_jobs=n_jobs, verbose=verbose,
fit_params=fit_params,
pre_dispatch=pre_dispatch)
pre_dispatch=pre_dispatch,
error_score=error_score)
return cv_results['test_score']


Expand Down Expand Up @@ -415,12 +437,15 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
verbose : integer
The verbosity level.
error_score : 'raise' or numeric
error_score : 'raise' | 'raise-deprecating' or numeric
Value to assign to the score if an error occurs in estimator fitting.
If set to 'raise', the error is raised. If a numeric value is given,
FitFailedWarning is raised. This parameter does not affect the refit
step, which will always raise the error. Default is 'raise' but from
version 0.22 it will change to np.nan.
If set to 'raise', the error is raised.
If set to 'raise-deprecating', a FutureWarning is printed before the
error is raised.
If a numeric value is given, FitFailedWarning is raised. This parameter
does not affect the refit step, which will always raise the error.
Default is 'raise-deprecating' but from version 0.22 it will change
to np.nan.
parameters : dict or None
Parameters to be set on the estimator.
Expand Down Expand Up @@ -1045,7 +1070,7 @@ def learning_curve(estimator, X, y, groups=None,
train_sizes=np.linspace(0.1, 1.0, 5), cv='warn',
scoring=None, exploit_incremental_learning=False, n_jobs=1,
pre_dispatch="all", verbose=0, shuffle=False,
random_state=None):
random_state=None, error_score='raise-deprecating'):
"""Learning curve.
Determines cross-validated training and test scores for different training
Expand Down Expand Up @@ -1136,6 +1161,16 @@ def learning_curve(estimator, X, y, groups=None,
If None, the random number generator is the RandomState instance used
by `np.random`. Used when ``shuffle`` is True.
error_score : 'raise' | 'raise-deprecating' or numeric
Value to assign to the score if an error occurs in estimator fitting.
If set to 'raise', the error is raised.
If set to 'raise-deprecating', a FutureWarning is printed before the
error is raised.
If a numeric value is given, FitFailedWarning is raised. This parameter
does not affect the refit step, which will always raise the error.
Default is 'raise-deprecating' but from version 0.22 it will change
to np.nan.
Returns
-------
train_sizes_abs : array, shape (n_unique_ticks,), dtype int
Expand Down Expand Up @@ -1194,8 +1229,9 @@ def learning_curve(estimator, X, y, groups=None,
train_test_proportions.append((train[:n_train_samples], test))

out = parallel(delayed(_fit_and_score)(
clone(estimator), X, y, scorer, train, test,
verbose, parameters=None, fit_params=None, return_train_score=True)
clone(estimator), X, y, scorer, train, test, verbose,
parameters=None, fit_params=None, return_train_score=True,
error_score=error_score)
for train, test in train_test_proportions)
out = np.array(out)
n_cv_folds = out.shape[0] // n_unique_ticks
Expand Down Expand Up @@ -1288,7 +1324,7 @@ def _incremental_fit_estimator(estimator, X, y, classes, train, test,

def validation_curve(estimator, X, y, param_name, param_range, groups=None,
cv='warn', scoring=None, n_jobs=1, pre_dispatch="all",
verbose=0):
verbose=0, error_score='raise-deprecating'):
"""Validation curve.
Determine training and test scores for varying parameter values.
Expand Down Expand Up @@ -1359,6 +1395,16 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
verbose : integer, optional
Controls the verbosity: the higher, the more messages.
error_score : 'raise' | 'raise-deprecating' or numeric
Value to assign to the score if an error occurs in estimator fitting.
If set to 'raise', the error is raised.
If set to 'raise-deprecating', a FutureWarning is printed before the
error is raised.
If a numeric value is given, FitFailedWarning is raised. This parameter
does not affect the refit step, which will always raise the error.
Default is 'raise-deprecating' but from version 0.22 it will change
to np.nan.
Returns
-------
train_scores : array, shape (n_ticks, n_cv_folds)
Expand All @@ -1381,7 +1427,8 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
verbose=verbose)
out = parallel(delayed(_fit_and_score)(
clone(estimator), X, y, scorer, train, test, verbose,
parameters={param_name: v}, fit_params=None, return_train_score=True)
parameters={param_name: v}, fit_params=None, return_train_score=True,
error_score=error_score)
# NOTE do not change order of iteration to allow one time cv splitters
for train, test in cv.split(X, y, groups) for v in param_range)
out = np.asarray(out)
Expand Down
3 changes: 3 additions & 0 deletions sklearn/model_selection/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,6 +1385,9 @@ def fit(self, X, y=None):
# Return an array of zeros whose length is X.shape[0]; the contents of X
# are not inspected.
def predict(self, X):
return np.zeros(X.shape[0])

# Return the constant 0.0; X and Y are accepted but never used.
def score(self, X=None, Y=None):
return 0.


@pytest.mark.filterwarnings('ignore: The default of the `iid`') # 0.22
def test_grid_search_failing_classifier():
Expand Down
32 changes: 26 additions & 6 deletions sklearn/model_selection/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,7 @@ def get_params(self, deep=False):
P_sparse = coo_matrix(np.eye(5))


@pytest.mark.filterwarnings('ignore: From version 0.22, errors during fit')
@pytest.mark.filterwarnings('ignore: You should specify a value') # 0.22
# FIXME issue in error_score parameter
def test_cross_val_score():
clf = MockClassifier()

Expand Down Expand Up @@ -279,7 +277,8 @@ def test_cross_val_score():
scores = cross_val_score(clf, X_3d, y2)

clf = MockClassifier(allow_nd=False)
assert_raises(ValueError, cross_val_score, clf, X_3d, y2)
assert_raises(ValueError, cross_val_score, clf, X_3d, y2,
error_score='raise')


@pytest.mark.filterwarnings('ignore: You should specify a value') # 0.22
Expand Down Expand Up @@ -1171,8 +1170,6 @@ def test_learning_curve_with_boolean_indices():
np.linspace(0.1, 1.0, 10))


@pytest.mark.filterwarnings('ignore: From version 0.22, errors during fit')
# FIXME this is an error in the error_score change!
def test_learning_curve_with_shuffle():
# Following test case was designed this way to verify the code
# changes made in pull request: #7506.
Expand All @@ -1195,7 +1192,8 @@ def test_learning_curve_with_shuffle():
assert_array_almost_equal(test_scores_batch.mean(axis=1),
np.array([0.36111111, 0.25, 0.25]))
assert_raises(ValueError, learning_curve, estimator, X, y, cv=cv, n_jobs=1,
train_sizes=np.linspace(0.3, 1.0, 3), groups=groups)
train_sizes=np.linspace(0.3, 1.0, 3), groups=groups,
error_score='raise')

train_sizes_inc, train_scores_inc, test_scores_inc = learning_curve(
estimator, X, y, cv=cv, n_jobs=1, train_sizes=np.linspace(0.3, 1.0, 3),
Expand Down Expand Up @@ -1477,6 +1475,7 @@ def test_fit_and_score():
failing_clf = FailingClassifier(FailingClassifier.FAILING_PARAMETER)
# dummy X data
X = np.arange(1, 10)
y = np.ones(9)
fit_and_score_args = [failing_clf, X, None, dict(), None, None, 0,
None, None]
# passing error score to trigger the warning message
Expand Down Expand Up @@ -1510,3 +1509,24 @@ def test_fit_and_score():
assert_raise_message(ValueError, "Failing classifier failed as required",
_fit_and_score, *fit_and_score_args,
**fit_and_score_kwargs)

# check that functions upstream pass error_score param to _fit_and_score
error_message = ("error_score must be the string 'raise' or a"
" numeric value. (Hint: if using 'raise', please"
" make sure that it has been spelled correctly.)")

assert_raise_message(ValueError, error_message, cross_validate,
failing_clf, X, cv=3, error_score='unvalid-string')

assert_raise_message(ValueError, error_message, cross_val_score,
failing_clf, X, cv=3, error_score='unvalid-string')

assert_raise_message(ValueError, error_message, learning_curve,
failing_clf, X, y, cv=3, error_score='unvalid-string')

assert_raise_message(ValueError, error_message, validation_curve,
failing_clf, X, y, 'parameter',
[FailingClassifier.FAILING_PARAMETER], cv=3,
error_score='unvalid-string')

assert_equal(failing_clf.score(), 0.) # FailingClassifier coverage

0 comments on commit 6b38970

Please sign in to comment.