Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG+1] Clone estimator for each parameter value in validation_curve #9119

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,10 @@ Bug fixes
classes, and some values proposed in the docstring could raise errors.
:issue:`5359` by `Tom Dupre la Tour`_.

- Fixed a bug where :func:`model_selection.validation_curve`
reused the same estimator for each parameter value.
:issue:`7365` by :user:`Aleksandr Sandrovskii <Sundrique>`.

API changes summary
-------------------

Expand Down
2 changes: 1 addition & 1 deletion sklearn/learning_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def validation_curve(estimator, X, y, param_name, param_range, cv=None,
parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch,
verbose=verbose)
out = parallel(delayed(_fit_and_score)(
estimator, X, y, scorer, train, test, verbose,
clone(estimator), X, y, scorer, train, test, verbose,
parameters={param_name: v}, fit_params=None, return_train_score=True)
for train, test in cv for v in param_range)

Expand Down
2 changes: 1 addition & 1 deletion sklearn/model_selection/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,7 +988,7 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch,
verbose=verbose)
out = parallel(delayed(_fit_and_score)(
estimator, X, y, scorer, train, test, verbose,
clone(estimator), X, y, scorer, train, test, verbose,
parameters={param_name: v}, fit_params=None, return_train_score=True)
# NOTE do not change order of iteration to allow one time cv splitters
for train, test in cv.split(X, y, groups) for v in param_range)
Expand Down
27 changes: 27 additions & 0 deletions sklearn/model_selection/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,21 @@ def _is_training_data(self, X):
return X is self.X_subset


class MockEstimatorWithSingleFitCallAllowed(MockEstimatorWithParameter):
    """Dummy classifier that disallows repeated calls of the fit method.

    Used to verify that ``validation_curve`` clones the estimator for
    every (CV split, parameter value) combination: if the same instance
    were reused, the second ``fit`` call would trip the assertion below.
    """

    def fit(self, X_subset, y_subset):
        # Fail loudly if this exact instance has already been fitted.
        assert_false(
            hasattr(self, 'fit_called_'),
            'fit is called the second time'
        )
        # Marker attribute recording that fit has run once.
        self.fit_called_ = True
        # NOTE: the original used ``super(type(self), self)``, which is a
        # latent bug -- under subclassing, ``type(self)`` is the subclass,
        # so the lookup re-enters this same method and recurses forever.
        # Name the class explicitly (Python 2 compatible).
        return super(MockEstimatorWithSingleFitCallAllowed,
                     self).fit(X_subset, y_subset)

    def predict(self, X):
        # Prediction is never exercised by these tests.
        raise NotImplementedError


class MockClassifier(object):
"""Dummy classifier to test the cross-validation"""

Expand Down Expand Up @@ -852,6 +867,18 @@ def test_validation_curve():
assert_array_almost_equal(test_scores.mean(axis=1), 1 - param_range)


def test_validation_curve_clone_estimator():
    """Check that validation_curve clones the estimator per parameter value.

    ``MockEstimatorWithSingleFitCallAllowed`` raises as soon as ``fit``
    is invoked twice on the same instance, so this test only passes when
    each fit receives a freshly cloned estimator.
    """
    X, y = make_classification(n_samples=2, n_features=1, n_informative=1,
                               n_redundant=0, n_classes=2,
                               n_clusters_per_class=1, random_state=0)

    # Ten distinct parameter values over two CV folds: twenty fits total,
    # each of which must happen on its own clone.
    param_range = np.linspace(1, 0, 10)
    _, _ = validation_curve(
        MockEstimatorWithSingleFitCallAllowed(), X, y,
        param_name="param", param_range=param_range, cv=2
    )


def test_validation_curve_cv_splits_consistency():
n_samples = 100
n_splits = 5
Expand Down
25 changes: 25 additions & 0 deletions sklearn/tests/test_learning_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_false
from sklearn.datasets import make_classification

with warnings.catch_warnings():
Expand Down Expand Up @@ -93,6 +94,18 @@ def score(self, X=None, y=None):
return None


class MockEstimatorWithSingleFitCallAllowed(MockEstimatorWithParameter):
    """Dummy classifier that disallows repeated calls of the fit method.

    Lets the deprecated ``sklearn.learning_curve.validation_curve`` tests
    verify that the estimator is cloned for every parameter value rather
    than refitted in place.
    """

    def fit(self, X_subset, y_subset):
        # A second fit on the same instance means cloning did not happen.
        assert_false(
            hasattr(self, 'fit_called_'),
            'fit is called the second time'
        )
        # Marker attribute recording that fit has run once.
        self.fit_called_ = True
        # NOTE: the original used ``super(type(self), self)``, which is a
        # latent bug -- under subclassing, ``type(self)`` is the subclass,
        # so the lookup re-enters this same method and recurses forever.
        # Name the class explicitly (Python 2 compatible).
        return super(MockEstimatorWithSingleFitCallAllowed,
                     self).fit(X_subset, y_subset)


def test_learning_curve():
X, y = make_classification(n_samples=30, n_features=1, n_informative=1,
n_redundant=0, n_classes=2,
Expand Down Expand Up @@ -284,3 +297,15 @@ def test_validation_curve():

assert_array_almost_equal(train_scores.mean(axis=1), param_range)
assert_array_almost_equal(test_scores.mean(axis=1), 1 - param_range)


def test_validation_curve_clone_estimator():
    """validation_curve must fit a fresh clone for each parameter value.

    The mock estimator asserts inside ``fit`` that it has never been
    fitted before, so reusing one instance across the parameter grid
    would fail this test immediately.
    """
    # Tiny two-sample problem: we only care about fit bookkeeping,
    # not about the scores themselves.
    X, y = make_classification(n_samples=2, n_features=1, n_informative=1,
                               n_redundant=0, n_classes=2,
                               n_clusters_per_class=1, random_state=0)

    param_range = np.linspace(1, 0, 10)
    _, _ = validation_curve(
        MockEstimatorWithSingleFitCallAllowed(), X, y,
        param_name="param", param_range=param_range, cv=2
    )