
Conversation


@TomDLT TomDLT commented Jun 7, 2017

In both cases, the stopping criterion (and the API) is identical to the one in the MLP classes:
After each epoch, we compute the validation score or the training loss. The optimization stops if there is no improvement twice in a row (i.e. the patience is hard-coded to 2).

To match the MLP classes' API, I added two parameters:

  • early_stopping (default False), perhaps not the best name, which selects whether we monitor the validation score (early_stopping=True) or the training loss (early_stopping=False)
  • validation_fraction (default 0.1), which sets the fraction of the training data held out as the validation set.

I also added a new learning rate strategy, learning_rate='adaptive', as found in the MLP classes:
The learning rate is kept constant, and is divided by 5 when there is no improvement twice in a row. The optimization stops when the learning rate becomes too small.
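
For illustration, here is a minimal usage sketch combining both additions (the dataset and parameter values are arbitrary examples, not part of the PR itself):

from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier

X, y = make_classification(n_samples=1000, n_features=20, random_state=0)

# Monitor the score on a held-out 10% validation split, and divide the
# learning rate by 5 whenever that score stops improving.
clf = SGDClassifier(early_stopping=True, validation_fraction=0.1,
                    learning_rate='adaptive', eta0=0.01,
                    max_iter=1000, random_state=0)
clf.fit(X, y)
print(clf.n_iter_)  # number of epochs actually run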

TODO:


TomDLT commented Jun 7, 2017

Here is a benchmark script to check the effect of acquiring the GIL at each epoch.
The GIL is acquired to compute the prediction score on the validation set when early_stopping=True.

On my desktop, with n_jobs=6:

# sequential runs with single thread
9.53 s ± 49.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# parallel runs without GIL access at each epoch
1.88 s ± 45.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# parallel runs with GIL access at each epoch (verbose > 0)
1.91 s ± 176 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# parallel runs with GIL access at each epoch (_validation_score)
2.01 s ± 786 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

On a small cluster, with n_jobs=16:

# sequential runs with single thread
32.2 s ± 156 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# parallel runs without GIL access at each epoch
2.47 s ± 185 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# parallel runs with GIL access at each epoch (verbose > 0)
2.68 s ± 131 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# parallel runs with GIL access at each epoch (_validation_score)
3.66 s ± 64.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

from IPython import get_ipython
import contextlib

import numpy as np

from sklearn.linear_model import SGDRegressor
from sklearn.datasets import make_regression
from sklearn.externals.joblib import Parallel, delayed

ipython = get_ipython()

X, y = make_regression(n_samples=10000, n_features=500, n_informative=50,
                       n_targets=1, bias=10, noise=3., random_state=42)

# use a single-sample validation set, so that the scoring cost itself is
# negligible and only the per-epoch GIL acquisition is measured
validation_fraction = 1. / X.shape[0]


@contextlib.contextmanager
def capture():
    """Temporarily silence stdout (discards verbose SGD output)."""
    import sys
    from io import StringIO
    oldout = sys.stdout
    try:
        sys.stdout = StringIO()
        yield None
    finally:
        sys.stdout = oldout


# fit a single SGDRegressor for the full max_iter epochs
# (tol=-inf disables the tolerance-based stopping criterion)
def one_run(early_stopping, verbose):
    est = SGDRegressor(validation_fraction=validation_fraction,
                       early_stopping=early_stopping,
                       max_iter=100, tol=-np.inf, shuffle=False,
                       random_state=0, verbose=verbose)
    est.fit(X, y)


n_jobs = 16


def single_thread():
    print('single_thread')
    for _ in range(n_jobs):
        one_run(False, 0)


def multi_thread(early_stopping, verbose):
    with capture():
        delayed_one_run = delayed(one_run)
        Parallel(n_jobs=n_jobs, backend='threading')(
            delayed_one_run(early_stopping, verbose)
            for _ in range(n_jobs))


ipython.magic("timeit single_thread()")
ipython.magic("timeit multi_thread(False, 0)")
ipython.magic("timeit multi_thread(False, 1)")
ipython.magic("timeit multi_thread(True, 0)")

@TomDLT TomDLT force-pushed the sgd_validation branch 3 times, most recently from cfb7bc7 to dfc29b1 Compare June 14, 2017 15:02
@TomDLT TomDLT changed the title [WIP] Add a stopping criterion in SGD, based on the score on a validation set [MRG] Add a stopping criterion in SGD, based on the score on a validation set Jun 14, 2017

TomDLT commented Jun 26, 2017

Now that #5036 is merged, is this planned to be in v0.19? @ogrisel


TomDLT commented Jul 27, 2017

  • Add n_iter_no_change parameter, to match GradientBoosting API

@amueller

related #9456

@amueller

I say 👎 for 0.19


jnothman commented Jul 28, 2017 via email

@amueller

Indeed, though we still haven't released a conda-forge package for the RC. Though I guess we can move forward with the release without that.


TomDLT commented Oct 12, 2017

Current estimators with early stopping:

  • GradientBoosting(validation_fraction=0.1, n_iter_no_change=None, tol=1e-4)

    • n_iter_no_change=None leads to no stopping criterion.
    • n_iter_no_change!=None enables early stopping based on validation score.
  • MLPClassifier(validation_fraction=0.1, n_iter_no_change=10, early_stopping=False, tol=1e-4)

    • early_stopping=True enables early stopping based on validation score.
    • early_stopping=False uses a stopping criterion based on training loss.
    • n_iter_no_change is a parameter since #9457 ([MRG+1] MLPRegressor quits fitting too soon due to self._no_improvement_count). It has to be an integer, or inf to disable all stopping criteria.
    • To disable all stopping criteria and force max_iter, you can also use tol=-inf or tol=inf, depending on the stopping strategy.
  • SGDClassifier(tol=1e-4)

    • tol=None leads to no stopping criterion.
    • tol!=None uses a stopping criterion based on training loss.
    • n_iter_no_change is not a parameter. The equivalent value is hard-coded and equal to 1.

In this PR (a short sketch of the resulting API follows the list):

  • SGDClassifier(validation_fraction=0.1, early_stopping=False, n_iter_no_change=2, tol=1e-4)
    • tol=None leads to no stopping criterion.
    • n_iter_no_change is used for both stopping strategies.
    • early_stopping=True enables early stopping based on validation score.
    • early_stopping=False uses a stopping criterion based on training loss.
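
To make the proposed semantics concrete, a rough sketch (the parameter values are only examples; the final defaults were still under discussion):

from sklearn.linear_model import SGDClassifier

# Stop on training loss: no improvement of more than tol for
# n_iter_no_change consecutive epochs.
clf_loss = SGDClassifier(early_stopping=False, tol=1e-4,
                         n_iter_no_change=2, max_iter=1000)

# Stop on validation score: hold out validation_fraction of the training
# data and monitor its score instead of the training loss.
clf_score = SGDClassifier(early_stopping=True, validation_fraction=0.1,
                          tol=1e-4, n_iter_no_change=2, max_iter=1000)

# No stopping criterion at all: always run max_iter epochs.
clf_none = SGDClassifier(tol=None, max_iter=1000)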

Conflicts:
	doc/whats_new/v0.20.rst
	sklearn/linear_model/stochastic_gradient.py

@jnothman jnothman left a comment


This is nice work.

Is it worth adding or modifying an example to show it in action? I know we have an example for gradient boosting's early stopping.

It might be worth adding these frequent parameters to the Glossary. The myriad definitions and implementations of early stopping may also deserve a separate entry as a term.

The classes :class:`SGDClassifier` and :class:`SGDRegressor` provide two
criteria to stop the algorithm when a given level of convergence is reached:

* With ``early_stopping=True``, the input data is splitted into a training

splitted -> split

:class:`linear_model.PassiveAggressiveRegressor` and
:class:`linear_model.Perceptron` now expose a ``early_stopping`` and
``validation_fraction`` parameters, to stop optimization monitoring the
score on a validation set. :issue:`9043` by `Tom Dupre la Tour`_.

Add another entry for adaptive learning rate, or put it here. I'm not sure if each estimator needs to be listed. You can reference the user guide instead...?

validation score is not improving by at least tol for
n_iter_no_change consecutive epochs.

.. versionadded:: 0.20

How diligent of you to add this :)

@@ -585,6 +627,8 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
cdef double max_change = 0.0
cdef double max_weight = 0.0

cdef short * validation_set_ptr = <short *> validation_set.data

I think you can just as well use a typed memoryview above..?

X_train, X_val, y_train, y_val = tmp[:4]
idx_train, idx_val, sample_weight_train, sample_weight_val = tmp[4:8]

self._X_val = X_val

should we delattr these at the end of fitting?


Done

@@ -282,6 +325,8 @@ def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter,
penalty_type = est._get_penalty_type(est.penalty)
learning_rate_type = est._get_learning_rate_type(learning_rate)

validation_set = est._train_validation_split(X, y, sample_weight)

perhaps validation_mask?

clf1 = self.factory(early_stopping=True, random_state=random_state,
validation_fraction=validation_fraction,
learning_rate='constant', eta0=0.01,
tol=None, max_iter=1000, shuffle=shuffle)

I don't think it's clear from the documentation what early_stopping=True should do when tol=None

def test_loss_function_epsilon(self):
clf = self.factory(epsilon=0.9)
clf.set_params(epsilon=0.1)
assert clf.loss_functions['huber'][1] == 0.1

def test_early_stopping(self):

Should we be using inheritance to share these?

@@ -7,6 +7,7 @@ cimport numpy as np

cdef class SequentialDataset:
cdef int current_index

There are three things called index here. I hope they are documented somewhere.

Also, I'm not sure if we should consider this public interface that's problematic to change... Can't we use index_data_ptr[current_index] directly?


Right, I simplified the call to use index_data_ptr[current_index] directly.
I also added a bit of documentation there.

@jnothman

if I understand correctly, this currently will run an extra epoch relative to the current early_stopping=False behaviour?


TomDLT commented Feb 2, 2018

if I understand correctly, this currently will run an extra epoch relative to the current early_stopping=False behaviour?

Yes, the default is now n_iter_no_change=2, whereas the previous behavior corresponds to n_iter_no_change=1. To avoid breaking users' code, we could set the default to 1 and change it in the future.
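
A user who wants the previous stopping behavior back could do something like the following sketch (assuming the parameter ships as described here):

from sklearn.linear_model import SGDClassifier

# n_iter_no_change=1 reproduces the old criterion: stop after a single epoch
# without sufficient improvement of the training loss.
clf = SGDClassifier(tol=1e-3, n_iter_no_change=1, max_iter=1000)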


ogrisel commented Jun 27, 2018

Because this PR fixes the bug of previous_loss vs best_loss, I think there is no need to add a FutureWarning for n_iter_no_change: we are already changing the estimator's stopping condition by fixing this bug. We just need to document the bug fix on the stopping criterion in the change log.


@jnothman jnothman dismissed their stale review June 27, 2018 23:57

Invalidated by subsequent work


TomDLT commented Jun 28, 2018

The bug of previous_loss vs best_loss is not present in master, since master is equivalent to n_iter_no_change=1, so the best loss is also the previous loss.
This was only a mistake introduced in this PR.

@glemaitre

@ogrisel any other comment?


ogrisel commented Jul 4, 2018

I still have the feeling that the current default behavior (both on master and with the choice of n_iter_no_change=1 in this PR) is a bug: it can very often lead to premature stopping, especially on small datasets.

If we consider that this is a bug, we can change the default to n_iter_no_change=5 instead of issuing a FutureWarning.

I would be interested in the opinion of others (maybe @jnothman @amueller ?).


@ogrisel ogrisel left a comment


Other than the decision on FutureWarning and the default value of the patience parameter, LGTM.

doc/glossary.rst Outdated
``n_iter_no_change``
Number of iterations with no improvement to wait before stopping the
iterative procedure. It is typically used with :term:`early stopping` to
avoid stopping too early.

For googlability, we should mention that this parameter is also named "patience" in other libraries.

.. versionadded:: 0.20

n_iter_no_change : int, default=1
Number of iterations with no improvement to wait before early stopping.

For googlability, we should mention that this parameter is also named "patience" in other libraries.


We should also recommend setting it to a large enough value, such as 5 or 10, to avoid premature stopping.


jnothman commented Jul 4, 2018 via email


jnothman commented Jul 4, 2018 via email


ogrisel commented Jul 4, 2018

I am afraid that would make the code too complicated. I think I would rather stick with the FutureWarning, which is easier to understand.


jnothman commented Jul 4, 2018 via email


ogrisel commented Jul 4, 2018

Here is a test script on a toy dataset:

from sklearn.linear_model import SGDClassifier
from sklearn.preprocessing import MinMaxScaler
from sklearn.datasets import load_digits
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split


digits = load_digits()

for seed in range(5):
    print(f"random seed: {seed}")
    for n_iter_no_change in [1, 5]:
        model = make_pipeline(
            MinMaxScaler(),
            SGDClassifier(max_iter=1000, tol=1e-3,
                          n_iter_no_change=n_iter_no_change, random_state=seed)
        )
        X_train, X_test, y_train, y_test = train_test_split(
            digits.data, digits.target, test_size=0.2, random_state=seed)
        model.fit(X_train, y_train)
        test_acc = model.score(X_test, y_test)
        print(f"n_iter_no_change: {n_iter_no_change}, "
              f" n_iter: {model.steps[-1][1].n_iter_},"
              f" test acc: {test_acc:0.3f}")

results:

random seed: 0
n_iter_no_change: 1,  n_iter: 11, test acc: 0.922
n_iter_no_change: 5,  n_iter: 46, test acc: 0.958
random seed: 1
n_iter_no_change: 1,  n_iter: 12, test acc: 0.967
n_iter_no_change: 5,  n_iter: 64, test acc: 0.972
random seed: 2
n_iter_no_change: 1,  n_iter: 11, test acc: 0.925
n_iter_no_change: 5,  n_iter: 34, test acc: 0.922
random seed: 3
n_iter_no_change: 1,  n_iter: 12, test acc: 0.900
n_iter_no_change: 5,  n_iter: 48, test acc: 0.950
random seed: 4
n_iter_no_change: 1,  n_iter: 12, test acc: 0.953
n_iter_no_change: 5,  n_iter: 50, test acc: 0.969

As you can see, n_iter_no_change=1 results in detrimental premature stopping most of the time.


ogrisel commented Jul 4, 2018

The effect is even stronger on a small dataset such as iris:

random seed: 0
n_iter_no_change: 1,  n_iter: 3, test acc: 0.600
n_iter_no_change: 5,  n_iter: 19, test acc: 0.767
random seed: 1
n_iter_no_change: 1,  n_iter: 5, test acc: 0.900
n_iter_no_change: 5,  n_iter: 18, test acc: 1.000
random seed: 2
n_iter_no_change: 1,  n_iter: 4, test acc: 0.700
n_iter_no_change: 5,  n_iter: 34, test acc: 0.833
random seed: 3
n_iter_no_change: 1,  n_iter: 5, test acc: 0.700
n_iter_no_change: 5,  n_iter: 22, test acc: 0.933
random seed: 4
n_iter_no_change: 1,  n_iter: 5, test acc: 0.833
n_iter_no_change: 5,  n_iter: 24, test acc: 0.867

Arguably, iris is probably too small for serious machine learning, especially with stochastic solvers, but still.


ogrisel commented Jul 4, 2018

For completeness I have also tried a larger dataset (covertype); while, as expected, a large n_iter_no_change is not necessary in that case, it does not seem to hurt test accuracy:

random seed: 0
n_iter_no_change: 1,  n_iter: 6, test acc: 0.712
n_iter_no_change: 5,  n_iter: 10, test acc: 0.709
random seed: 1
n_iter_no_change: 1,  n_iter: 6, test acc: 0.708
n_iter_no_change: 5,  n_iter: 10, test acc: 0.712
random seed: 2
n_iter_no_change: 1,  n_iter: 6, test acc: 0.708
n_iter_no_change: 5,  n_iter: 10, test acc: 0.711
random seed: 3
n_iter_no_change: 1,  n_iter: 6, test acc: 0.710
n_iter_no_change: 5,  n_iter: 10, test acc: 0.714
random seed: 4
n_iter_no_change: 1,  n_iter: 6, test acc: 0.707
n_iter_no_change: 5,  n_iter: 10, test acc: 0.709


ogrisel commented Jul 4, 2018

Same run on covertype but using early stopping on a 10% validation split:

random seed: 0
n_iter_no_change: 1,  n_iter: 5, test acc: 0.710
n_iter_no_change: 5,  n_iter: 13, test acc: 0.711
random seed: 1
n_iter_no_change: 1,  n_iter: 6, test acc: 0.709
n_iter_no_change: 5,  n_iter: 10, test acc: 0.710
random seed: 2
n_iter_no_change: 1,  n_iter: 5, test acc: 0.710
n_iter_no_change: 5,  n_iter: 19, test acc: 0.711
random seed: 3
n_iter_no_change: 1,  n_iter: 4, test acc: 0.708
n_iter_no_change: 5,  n_iter: 22, test acc: 0.712
random seed: 4
n_iter_no_change: 1,  n_iter: 4, test acc: 0.700
n_iter_no_change: 5,  n_iter: 8, test acc: 0.708

In this case, n_iter_no_change=5 is consistently better than n_iter_no_change=1 despite the size of the dataset.


ogrisel commented Jul 4, 2018

@jnothman @TomDLT I will go offline before appveyor has completed. Feel free to merge when green. Based on the runs I made, I am confident that n_iter_no_change=5 by default is the good/safe choice.


jnothman commented Jul 4, 2018 via email


@jnothman jnothman left a comment


@glemaitre, does this still have your +1?

:class:`linear_model.PassiveAggressiveClassifier`,
:class:`linear_model.PassiveAggressiveRegressor` and
:class:`linear_model.Perceptron`, where the stopping criterion was stopping
the algorithm too early. A parameter `n_iter_no_change` was added and set by

Perhaps say "before convergence"


ogrisel commented Jul 5, 2018

I reported the appveyor failure here: #11438. I believe it's unrelated to this PR.

@ogrisel ogrisel merged commit 0fc7ce6 into scikit-learn:master Jul 5, 2018

ogrisel commented Jul 5, 2018

Merged! Thank you very much @TomDLT!

@ogrisel ogrisel deleted the sgd_validation branch July 5, 2018 14:25