Skip to content

Commit

Permalink
Merge 94eda9c into e2730fa
Browse files Browse the repository at this point in the history
  • Loading branch information
richford committed Oct 28, 2020
2 parents e2730fa + 94eda9c commit d69e39d
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 22 deletions.
80 changes: 59 additions & 21 deletions afqinsight/cross_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def cross_validate_checkpoint(
workdir=None,
checkpoint=True,
force_refresh=False,
serialize_cv=False,
):
"""Evaluate metric(s) by cross-validation and also record fit/score times.
Expand Down Expand Up @@ -293,6 +294,20 @@ def cross_validate_checkpoint(
If a numeric value is given, FitFailedWarning is raised. This parameter
does not affect the refit step, which will always raise the error.
workdir : path-like object, default=None
A string or :term:`python:path-like-object` indicating the directory
in which to store checkpoint files
checkpoint : bool, default=True
If True, checkpoint the parameters, estimators, and scores.
force_refresh : bool, default=False
If True, recompute scores even if the checkpoint file already exists.
Otherwise, load scores from checkpoint files and return.
serialize_cv : bool, default=False
If True, do not use joblib.Parallel to evaluate each CV split.
Returns
-------
scores : dict of float arrays of shape (n_splits,)
Expand Down Expand Up @@ -380,28 +395,51 @@ def cross_validate_checkpoint(

# We clone the estimator to make sure that all the folds are
# independent, and that it is pickle-able.
parallel = Parallel(n_jobs=n_jobs, verbose=verbose, pre_dispatch=pre_dispatch)
scores = parallel(
delayed(_fit_and_score_ckpt)(
workdir=workdir,
checkpoint=checkpoint,
force_refresh=force_refresh,
estimator=clone(estimator),
X=X,
y=y,
scorer=scorers,
train=train,
test=test,
verbose=verbose,
parameters=None,
fit_params=fit_params,
return_train_score=return_train_score,
return_times=True,
return_estimator=return_estimator,
error_score=error_score,
if serialize_cv:
scores = [
_fit_and_score_ckpt(
workdir=workdir,
checkpoint=checkpoint,
force_refresh=force_refresh,
estimator=clone(estimator),
X=X,
y=y,
scorer=scorers,
train=train,
test=test,
verbose=verbose,
parameters=None,
fit_params=fit_params,
return_train_score=return_train_score,
return_times=True,
return_estimator=return_estimator,
error_score=error_score,
)
for train, test in cv.split(X, y, groups)
]
else:
parallel = Parallel(n_jobs=n_jobs, verbose=verbose, pre_dispatch=pre_dispatch)
scores = parallel(
delayed(_fit_and_score_ckpt)(
workdir=workdir,
checkpoint=checkpoint,
force_refresh=force_refresh,
estimator=clone(estimator),
X=X,
y=y,
scorer=scorers,
train=train,
test=test,
verbose=verbose,
parameters=None,
fit_params=fit_params,
return_train_score=return_train_score,
return_times=True,
return_estimator=return_estimator,
error_score=error_score,
)
for train, test in cv.split(X, y, groups)
)
for train, test in cv.split(X, y, groups)
)

zipped_scores = list(zip(*scores))
if return_train_score:
Expand Down
4 changes: 3 additions & 1 deletion afqinsight/tests/test_cross_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@


@pytest.mark.parametrize("return_estimator", [True, False])
def test_cross_validate(return_estimator):
@pytest.mark.parametrize("serialize_cv", [True, False])
def test_cross_validate(return_estimator, serialize_cv):
diabetes = datasets.load_diabetes()
X = diabetes.data[:150]
y = diabetes.target[:150]
Expand All @@ -30,6 +31,7 @@ def test_cross_validate(return_estimator):
checkpoint=True,
workdir=tempdir,
return_estimator=return_estimator,
serialize_cv=serialize_cv,
)
cv_files_1 = os.listdir(tempdir)

Expand Down

0 comments on commit d69e39d

Please sign in to comment.