Skip to content

Commit

Permalink
Refactored lc example, removed unrelated changes, included training s…
Browse files Browse the repository at this point in the history
…coring time to score time
  • Loading branch information
H4dr1en committed May 29, 2019
1 parent 679c72b commit ea0cc71
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 178 deletions.
212 changes: 49 additions & 163 deletions examples/model_selection/plot_learning_curve.py
@@ -1,20 +1,3 @@
"""
========================
Plotting Learning Curves
========================
On the left side the learning curve of a naive Bayes classifier is shown for
the digits dataset. Note that the training score and the cross-validation score
are both not very good at the end. However, the shape of the curve can be found
in more complex datasets very often: the training score is very high at the
beginning and decreases and the cross-validation score is very low at the
beginning and increases. On the right side we see the learning curve of an SVM
with RBF kernel. We can see clearly that the training score is still around
the maximum and the validation score could be increased with more training
samples.
"""
print(__doc__)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.naive_bayes import GaussianNB
Expand Down Expand Up @@ -79,31 +62,59 @@ def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None,
be big enough to contain at least one sample from each class.
(default: np.linspace(0.1, 1.0, 5))
"""
plt.figure()
plt.title(title)
fig, axes = plt.subplots(1, 3, figsize=(20, 5))
axes[0].set_title(title)
if ylim is not None:
plt.ylim(*ylim)
plt.xlabel("Training examples")
plt.ylabel("Score")
train_sizes, train_scores, test_scores = learning_curve(
estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes)
axes[0].set_ylim(*ylim)
axes[0].set_xlabel("Training examples")
axes[0].set_ylabel("Score")

train_sizes, train_scores, test_scores, fit_times, _ = \
learning_curve(estimator, X, y, cv=cv, n_jobs=n_jobs,
train_sizes=train_sizes,
return_times=True)
train_scores_mean = np.mean(train_scores, axis=1)
train_scores_std = np.std(train_scores, axis=1)
test_scores_mean = np.mean(test_scores, axis=1)
test_scores_std = np.std(test_scores, axis=1)
plt.grid()

plt.fill_between(train_sizes, train_scores_mean - train_scores_std,
train_scores_mean + train_scores_std, alpha=0.1,
color="r")
plt.fill_between(train_sizes, test_scores_mean - test_scores_std,
test_scores_mean + test_scores_std, alpha=0.1, color="g")
plt.plot(train_sizes, train_scores_mean, 'o-', color="r",
label="Training score")
plt.plot(train_sizes, test_scores_mean, 'o-', color="g",
label="Cross-validation score")

plt.legend(loc="best")
fit_times_mean = np.mean(fit_times, axis=1)
fit_times_std = np.std(fit_times, axis=1)

# Plot learning curve
axes[0].grid()
axes[0].fill_between(train_sizes, train_scores_mean - train_scores_std,
train_scores_mean + train_scores_std, alpha=0.1,
color="r")
axes[0].fill_between(train_sizes, test_scores_mean - test_scores_std,
test_scores_mean + test_scores_std, alpha=0.1,
color="g")
axes[0].plot(train_sizes, train_scores_mean, 'o-', color="r",
label="Training score")
axes[0].plot(train_sizes, test_scores_mean, 'o-', color="g",
label="Cross-validation score")
axes[0].legend(loc="best")

# Plot n_samples vs fit_times
axes[1].grid()
p = axes[1].plot(train_sizes, fit_times_mean, 'o-')
axes[1].fill_between(train_sizes, fit_times_mean - fit_times_std,
fit_times_mean + fit_times_std, alpha=0.1,
color=p[0].get_color())
axes[1].set_xlabel("Training examples")
axes[1].set_ylabel("fit_times")
axes[1].set_title("Scalability of the model")

# Plot fit_time vs score
axes[2].grid()
p = axes[2].plot(fit_times_mean, test_scores_mean, 'o-')
axes[2].fill_between(fit_times_mean,
test_scores_mean - test_scores_std,
test_scores_mean + test_scores_std,
alpha=0.1, color=p[0].get_color())
axes[2].set_xlabel("fit_times")
axes[2].set_ylabel("Score")
axes[2].set_title("Performance of the model")

return plt


Expand All @@ -125,129 +136,4 @@ def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None,
estimator = SVC(gamma=0.001)
plot_learning_curve(estimator, title, X, y, (0.7, 1.01), cv=cv, n_jobs=4)

plt.show()

###############################################################################
# Plotting the calculation times
# ------------------------------
# On the left side, the training time of several learning curves are shown for
# various estimators. On the right side, the training time of those estimators
# are shown with their corresponding cross-validated score. Note that even if
# KNeighborsRegressor is the fastest estimator (as shown on the left side),
# SGDRegressor is the estimator giving the best score for this dataset as shown
# on the right side.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
from sklearn.linear_model import SGDRegressor
from sklearn.svm import SVR
from sklearn.neighbors import KNeighborsRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import ShuffleSplit


def plot_learning_curve_times(estimators, X, y, train_sizes,
cv=None, n_jobs=None):
"""
Generate a simple plot of the test and training learning curve.
Parameters
----------
estimators : A list of object type that implement the "fit" and "predict"
methods. An object of that type which is cloned for each validation.
X : array-like, shape (n_samples, n_features)
Training vector, where n_samples is the number of samples and
n_features is the number of features.
y : array-like, shape (n_samples) or (n_samples, n_features), optional
Target relative to X for classification or regression;
None for unsupervised learning.
train_sizes : array-like, shape (n_ticks,), dtype float or int
Relative or absolute numbers of training examples that will be used to
generate the learning curve. If the dtype is float, it is regarded as a
fraction of the maximum size of the training set (that is determined
by the selected validation method), i.e. it has to be within (0, 1].
Otherwise it is interpreted as absolute sizes of the training sets.
Note that for classification the number of samples usually have to
be big enough to contain at least one sample from each class.
cv : int, cross-validation generator or an iterable, optional
Determines the cross-validation splitting strategy.
Possible inputs for cv are:
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`CV splitter`,
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if ``y`` is binary or multiclass,
:class:`StratifiedKFold` used. If the estimator is not a classifier
or if ``y`` is neither binary nor multiclass, :class:`KFold` is used.
Refer :ref:`User Guide <cross_validation>` for the various
cross-validators that can be used here.
n_jobs : int or None, optional (default=-1)
Number of jobs to run in parallel.
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
for more details.
"""
fig, axes = plt.subplots(1, 2, figsize=(18, 6))

axes[0].set_title("Learning Curves computation times")
axes[0].set_xlabel("Training examples")
axes[0].set_ylabel("Fit times (s)")
axes[0].grid()

axes[1].set_title("Estimators fit times with scores")
axes[1].set_xlabel("Fit times (s)")
axes[1].set_ylabel("Cross-validation score")
axes[1].grid()

for name, estimator in estimators:
train_sizes, _, test_scores, fit_times, _ = \
learning_curve(estimator(), X, y, cv=cv, n_jobs=n_jobs,
train_sizes=train_sizes, return_times=True)

test_scores_mean = np.mean(test_scores, axis=1)
test_scores_std = np.std(test_scores, axis=1)
fit_times_mean = np.mean(fit_times, axis=1)
fit_times_std = np.std(fit_times, axis=1)

p = axes[0].plot(train_sizes, fit_times_mean, 'o-', label=name)

axes[0].fill_between(train_sizes, fit_times_mean - fit_times_std,
fit_times_mean + fit_times_std, alpha=0.1,
color=p[0].get_color())

p = axes[1].plot(fit_times_mean, test_scores_mean, 'o-',
label=name)

axes[1].fill_between(fit_times_mean,
test_scores_mean - test_scores_std,
test_scores_mean + test_scores_std,
alpha=0.1, color=p[0].get_color())

axes[0].legend(loc="best")
axes[1].legend(loc="best")
return fig


X, y = make_regression(n_samples=int(1e4), n_features=50, n_informative=25,
bias=-92, noise=100)

cv = ShuffleSplit(n_splits=5, test_size=0.2, random_state=0)

estimators = []
estimators.append(("SGDRegressor", SGDRegressor))
estimators.append(("KNeighborsRegressor", KNeighborsRegressor))
estimators.append(("SVR", SVR))
estimators.append(("RandomForestRegressor", RandomForestRegressor))

train_sizes = np.logspace(np.log10(1e-3), np.log10(1), 8)
fig = plot_learning_curve_times(estimators, X, y, train_sizes, cv=cv,
n_jobs=4)
fig.show()
plt.show()
12 changes: 6 additions & 6 deletions examples/plot_kernel_ridge_regression.py
Expand Up @@ -151,12 +151,12 @@

svr = SVR(kernel='rbf', C=1e1, gamma=0.1)
kr = KernelRidge(kernel='rbf', alpha=0.1, gamma=0.1)
train_sizes, train_scores_svr, test_scores_svr = learning_curve(
svr, X[:100], y[:100], train_sizes=np.linspace(0.1, 1, 10),
scoring="neg_mean_squared_error", cv=10)
train_sizes_abs, train_scores_kr, test_scores_kr = learning_curve(
kr, X[:100], y[:100], train_sizes=np.linspace(0.1, 1, 10),
scoring="neg_mean_squared_error", cv=10)
train_sizes, train_scores_svr, test_scores_svr = \
learning_curve(svr, X[:100], y[:100], train_sizes=np.linspace(0.1, 1, 10),
scoring="neg_mean_squared_error", cv=10)
train_sizes_abs, train_scores_kr, test_scores_kr = \
learning_curve(kr, X[:100], y[:100], train_sizes=np.linspace(0.1, 1, 10),
scoring="neg_mean_squared_error", cv=10)

plt.plot(train_sizes, -test_scores_svr.mean(1), 'o-', color="r",
label="SVR")
Expand Down
19 changes: 10 additions & 9 deletions sklearn/model_selection/_validation.py
Expand Up @@ -476,7 +476,7 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
msg = ''
else:
msg = '%s' % (', '.join('%s=%s' % (k, v)
for k, v in parameters.items()))
for k, v in parameters.items()))
print("[CV] %s %s" % (msg, (64 - len(msg)) * '.'))

# Adjust length of sample weights
Expand Down Expand Up @@ -510,10 +510,10 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
elif isinstance(error_score, numbers.Number):
if is_multimetric:
test_scores = dict(zip(scorer.keys(),
[error_score, ] * n_scorers))
[error_score, ] * n_scorers))
if return_train_score:
train_scores = dict(zip(scorer.keys(),
[error_score, ] * n_scorers))
[error_score, ] * n_scorers))
else:
test_scores = error_score
if return_train_score:
Expand Down Expand Up @@ -1198,7 +1198,7 @@ def learning_curve(estimator, X, y, groups=None,
If a numeric value is given, FitFailedWarning is raised. This parameter
does not affect the refit step, which will always raise the error.
return_times : boolean, optional, default: False
return_times : boolean, optional (default: False)
Whether to return the fit/score times.
Returns
Expand Down Expand Up @@ -1354,22 +1354,23 @@ def _incremental_fit_estimator(estimator, X, y, classes, train, test,
X_partial_train, y_partial_train = _safe_split(estimator, X, y,
partial_train)
X_test, y_test = _safe_split(estimator, X, y, test, train_subset)
start = time.time()
start_fit = time.time()
if y_partial_train is None:
estimator.partial_fit(X_partial_train, classes=classes)
else:
estimator.partial_fit(X_partial_train, y_partial_train,
classes=classes)
fit_time = time.time() - start
fit_time = time.time() - start_fit
fit_times.append(fit_time)

start_score = time.time()

test_scores.append(_score(estimator, X_test, y_test, scorer))
train_scores.append(_score(estimator, X_train, y_train, scorer))

score_time = time.time() - start - fit_time
score_time = time.time() - start_score
score_times.append(score_time)

train_scores.append(_score(estimator, X_train, y_train, scorer))

ret = (train_scores, test_scores, fit_times, score_times) \
if return_times else (train_scores, test_scores)

Expand Down

0 comments on commit ea0cc71

Please sign in to comment.