[MRG+1] Add a stopping criterion in SGD, based on the score on a validation set #9043

Merged
ogrisel merged 23 commits into scikit-learn:master from TomDLT:sgd_validation on Jul 5, 2018

Conversation

TomDLT commented Jun 7, 2017

  • Follow-up to #5036, which implemented a stopping criterion based on the training loss.
  • This PR implements a stopping criterion based on the prediction score on a validation set.

In both cases, the stopping criterion (and the API) is identical to the one in the MLP classes:
after each epoch, we compute the validation score or the training loss, and the optimization stops if there is no improvement twice in a row (i.e. the patience is hard-coded to 2).

To match the MLP classes' API, I added two parameters:

  • early_stopping (default False), not ideally named, which selects whether we monitor the validation score (early_stopping=True) or the training loss (early_stopping=False)
  • validation_fraction (default 0.1), which sets the fraction of the training data held out as the validation set.

I also added a new learning rate strategy, learning_rate='adaptive', as found in the MLP classes:
the learning rate is kept constant and is divided by 5 when there is no improvement twice in a row. The optimization stops when the learning rate becomes too small.
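
As an illustration, here is a minimal usage sketch of the parameters described above, assuming the API as finally merged (scikit-learn >= 0.20); the dataset and the exact values are only illustrative:

from sklearn.linear_model import SGDClassifier
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# stop when the score on a held-out 10% validation split stops improving
clf = SGDClassifier(early_stopping=True, validation_fraction=0.1,
                    tol=1e-3, max_iter=1000, random_state=0)
clf.fit(X_train, y_train)
print(clf.n_iter_, clf.score(X_test, y_test))

# 'adaptive' keeps eta0 constant and divides it by 5 each time the stopping
# criterion triggers, stopping only when the learning rate gets too small
clf_adaptive = SGDClassifier(learning_rate='adaptive', eta0=0.01,
                             early_stopping=True, validation_fraction=0.1,
                             tol=1e-3, max_iter=1000, random_state=0)
clf_adaptive.fit(X_train, y_train)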

TODO:

  • make things work
  • add a few tests
  • benchmark the GIL access at each epoch
  • add a few lines in the doc
  • add a whats_new entry
  • merge #5036
TomDLT commented Jun 7, 2017

Here is a benchmark script to check the effect of accessing the GIL at each epoch.
This GIL access is used to compute the prediction score on the validation set, when early_stopping=True.

On my desktop, with n_jobs=6:

# sequential runs with single thread
9.53 s ± 49.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# parallel runs without GIL access at each epoch
1.88 s ± 45.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# parallel runs with GIL access at each epoch (verbose > 0)
1.91 s ± 176 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# parallel runs with GIL access at each epoch (_validation_score)
2.01 s ± 786 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

On a small cluster, with n_jobs = 16

# sequential runs with single thread
32.2 s ± 156 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# parallel runs without GIL access at each epoch
2.47 s ± 185 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# parallel runs with GIL access at each epoch (verbose > 0)
2.68 s ± 131 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# parallel runs with GIL access at each epoch (_validation_score)
3.66 s ± 64.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Benchmark script:

from IPython import get_ipython
import contextlib

import numpy as np

from sklearn.linear_model import SGDRegressor
from sklearn.datasets import make_regression
from sklearn.externals.joblib import Parallel, delayed

ipython = get_ipython()

X, y = make_regression(n_samples=10000, n_features=500, n_informative=50,
                       n_targets=1, bias=10, noise=3., random_state=42)

# tiny validation set, so that the measured overhead is mostly the GIL access itself
validation_fraction = 1. / X.shape[0]


@contextlib.contextmanager
def capture():
    import sys
    from io import StringIO
    oldout = sys.stdout
    try:
        sys.stdout = StringIO()
        yield None
    finally:
        sys.stdout = oldout


def one_run(early_stopping, verbose):
    # tol=-inf: never stop early, always run the full max_iter epochs
    est = SGDRegressor(validation_fraction=validation_fraction,
                       early_stopping=early_stopping,
                       max_iter=100, tol=-np.inf, shuffle=False,
                       random_state=0, verbose=verbose)
    est.fit(X, y)


n_jobs = 16


def single_thread():
    print('single_thread')
    for _ in range(n_jobs):
        one_run(False, 0)


def multi_thread(early_stopping, verbose):
    with capture():
        delayed_one_run = delayed(one_run)
        Parallel(n_jobs=n_jobs, backend='threading')(
            delayed_one_run(early_stopping, verbose)
            for _ in range(n_jobs))


ipython.magic("timeit single_thread()")
ipython.magic("timeit multi_thread(False, 0)")
ipython.magic("timeit multi_thread(False, 1)")
ipython.magic("timeit multi_thread(True, 0)")

TomDLT changed the title from "[WIP] Add a stopping criterion in SGD, based on the score on a validation set" to "[MRG] Add a stopping criterion in SGD, based on the score on a validation set" Jun 14, 2017

TomDLT commented Jun 26, 2017

Now that #5036 is merged, is this planned to be in v0.19? @ogrisel

TomDLT commented Jul 27, 2017

  • Add n_iter_no_change parameter, to match GradientBoosting API
amueller commented Jul 27, 2017

related #9456

amueller commented Jul 27, 2017

I say 👎 for 0.19

jnothman commented Jul 28, 2017

amueller commented Jul 28, 2017

Indeed, though we still haven't released a conda-forge package for the RC. I guess we can move forward with the release without that, though.

TomDLT added some commits Jun 7, 2017

TomDLT commented Oct 12, 2017

Current estimators with early stopping:

  • GradientBoosting(validation_fraction=0.1, n_iter_no_change=None, tol=1e-4)

    • n_iter_no_change=None leads to no stopping criterion.
    • n_iter_no_change!=None enables early stopping based on validation score.
  • MLPClassifier(validation_fraction=0.1, n_iter_no_change=10, early_stopping=False, tol=1e-4)

    • early_stopping=True enables early stopping based on validation score.
    • early_stopping=False uses a stopping criterion based on training loss.
    • n_iter_no_change is a parameter since #9457. It has to be an integer, or inf to disable all stopping criteria.
    • To disable all stopping criteria and force max_iter, you can also use tol=-inf or tol=inf, depending on the stopping strategy.
  • SGDClassifier(tol=1e-4)

    • tol=None leads to no stopping criterion.
    • tol!=None uses a stopping criterion based on training loss.
    • n_iter_no_change is not a parameter. The equivalent value is hard coded and equal to 1.

In this PR (see the sketch after this list):

  • SGDClassifier(validation_fraction=0.1, early_stopping=False, n_iter_no_change=2, tol=1e-4)
    • tol=None leads to no stopping criterion.
    • n_iter_no_change is used for both stopping strategies.
    • early_stopping=True enables early stopping based on validation score.
    • early_stopping=False uses a stopping criterion based on training loss.
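
For concreteness, here is a minimal sketch of the three stopping configurations described above (the values are only illustrative; parameter names are those proposed in this PR):

from sklearn.linear_model import SGDClassifier

# no stopping criterion: always run max_iter epochs
SGDClassifier(tol=None, max_iter=1000)

# stop on the training loss: no improvement larger than tol
# for n_iter_no_change consecutive epochs
SGDClassifier(early_stopping=False, tol=1e-4, n_iter_no_change=2)

# stop on the validation score, computed on a held-out 10% split
SGDClassifier(early_stopping=True, validation_fraction=0.1,
              tol=1e-4, n_iter_no_change=2)
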
Merge branch 'master' into sgd_validation
Conflicts:
	doc/whats_new/v0.20.rst
	sklearn/linear_model/stochastic_gradient.py
@jnothman

This is nice work.

Is it worth adding or modifying an example to show it in action? I know we have an example for gradient boosting's early stopping.

It might be worth adding these frequent parameters to the Glossary. The myriad definitions and implementations of early stopping may also deserve a separate entry as a term.

validation score is not improving by at least tol for
n_iter_no_change consecutive epochs.
.. versionadded:: 0.20

jnothman Jan 31, 2018

How diligent of you to add this :)

clf1 = self.factory(early_stopping=True, random_state=random_state,
                    validation_fraction=validation_fraction,
                    learning_rate='constant', eta0=0.01,
                    tol=None, max_iter=1000, shuffle=shuffle)

jnothman Jan 31, 2018

I don't think it's clear from the documentation what early_stopping=True should do when tol=None

jnothman commented Jan 31, 2018

If I understand correctly, this currently will run an extra epoch relative to the current early_stopping=False behaviour?

TomDLT added some commits Feb 2, 2018

TomDLT commented Feb 2, 2018

if I understand correctly, this currently will run an extra epoch relative to the current early_stopping=False behaviour?

Yes, the default is now n_iter_no_change=2, whereas the previous behavior corresponds to n_iter_no_change=1. To avoid breaking users' code, we could keep the default at 1 and change it in the future...

scikit-learn deleted a comment from sklearn-lgtm Feb 2, 2018

jnothman commented Feb 3, 2018

Yes, the default is now n_iter_no_change=2, whereas the previous behavior corresponds to n_iter_no_change=1. To avoid breaking user's code, we could have the default to 1 and change it in the future...

Will it change anything but n_iter_ and runtime?

jnothman commented Feb 3, 2018

I suppose in the boundary case, it could change the issuance of a ConvergenceWarning...

TomDLT commented Feb 5, 2018

In theory it could make convergence longer, possibly avoiding an early stop and leading to a poorer solution.


Also I realize we don't have pandas in CircleCI. My example does not desperately need it, but I wonder if we prefer to avoid pandas in examples or if I should simply add it in CircleCI?

TomDLT commented Feb 5, 2018

[attached image: figure_1]

jnothman commented Feb 5, 2018

I admit it looks nicer with pandas than it would without.

jnothman commented Feb 5, 2018

Just as long without pandas, but slightly less readable!

sklearn-lgtm commented Jun 27, 2018

This pull request introduces 1 alert when merging 28c46a3 into 3b5abf7 - view on LGTM.com

new alerts:

  • 1 for Result of integer division may be truncated

Comment posted by LGTM.com


TomDLT commented Jun 28, 2018

The previous_loss vs best_loss bug is not present in master: master is equivalent to n_iter_no_change=1, so the best loss is also the previous loss.
This was only a mistake in this PR.
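
To make the distinction concrete, here is a simplified, illustrative sketch of the patience logic (the real implementation lives in the Cython SGD code; the names below are not the actual ones):

def make_stopping_check(tol, n_iter_no_change):
    """Return a callable that, given the loss after each epoch, says whether to stop."""
    state = {'best_loss': float('inf'), 'count': 0}

    def should_stop(loss):
        if loss > state['best_loss'] - tol:  # not enough improvement this epoch
            state['count'] += 1
        else:
            state['count'] = 0
        # compare against the best loss seen so far, not merely the previous one
        state['best_loss'] = min(state['best_loss'], loss)
        return state['count'] >= n_iter_no_change

    return should_stop

# with n_iter_no_change=1 (as on master), the best loss seen so far plays the
# same role as the previous epoch's loss, which is why, as noted above, the
# mix-up only shows up with a larger patience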

glemaitre commented Jun 29, 2018

@ogrisel any other comment?

ogrisel commented Jul 4, 2018

I still have the feeling that the current default behavior (on master, and the choice of n_iter_no_change=1 in this PR) is a bug: it can very often lead to premature stopping, especially on small datasets.

If we consider that this is a bug, we can change the default to n_iter_no_change=5 instead of issuing a FutureWarning.

I would be interested in the opinion of others (maybe @jnothman @amueller ?).

ogrisel approved these changes Jul 4, 2018

Other than the decision on FutureWarning and the default value of the patience parameter, LGTM.

jnothman commented Jul 4, 2018

jnothman commented Jul 4, 2018

ogrisel commented Jul 4, 2018

I am afraid that would make the code too complicated. I think I would rather stick with the FutureWarning that is easier to understand.

jnothman commented Jul 4, 2018

ogrisel commented Jul 4, 2018

Here is a test script on a toy dataset:

from sklearn.linear_model import SGDClassifier
from sklearn.preprocessing import MinMaxScaler
from sklearn.datasets import load_digits
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split


digits = load_digits()

for seed in range(5):
    print(f"random seed: {seed}")
    for n_iter_no_change in [1, 5]:
        model = make_pipeline(
            MinMaxScaler(),
            SGDClassifier(max_iter=1000, tol=1e-3,
                          n_iter_no_change=n_iter_no_change, random_state=seed)
        )
        X_train, X_test, y_train, y_test = train_test_split(
            digits.data, digits.target, test_size=0.2, random_state=seed)
        model.fit(X_train, y_train)
        test_acc = model.score(X_test, y_test)
        print(f"n_iter_no_change: {n_iter_no_change}, "
              f" n_iter: {model.steps[-1][1].n_iter_},"
              f" test acc: {test_acc:0.3f}")

results:

random seed: 0
n_iter_no_change: 1,  n_iter: 11, test acc: 0.922
n_iter_no_change: 5,  n_iter: 46, test acc: 0.958
random seed: 1
n_iter_no_change: 1,  n_iter: 12, test acc: 0.967
n_iter_no_change: 5,  n_iter: 64, test acc: 0.972
random seed: 2
n_iter_no_change: 1,  n_iter: 11, test acc: 0.925
n_iter_no_change: 5,  n_iter: 34, test acc: 0.922
random seed: 3
n_iter_no_change: 1,  n_iter: 12, test acc: 0.900
n_iter_no_change: 5,  n_iter: 48, test acc: 0.950
random seed: 4
n_iter_no_change: 1,  n_iter: 12, test acc: 0.953
n_iter_no_change: 5,  n_iter: 50, test acc: 0.969

As you can see n_iter_no_change=1 results in a detrimental premature stopping most of the time.

ogrisel commented Jul 4, 2018

The effect is even stronger on a small dataset such as iris:

random seed: 0
n_iter_no_change: 1,  n_iter: 3, test acc: 0.600
n_iter_no_change: 5,  n_iter: 19, test acc: 0.767
random seed: 1
n_iter_no_change: 1,  n_iter: 5, test acc: 0.900
n_iter_no_change: 5,  n_iter: 18, test acc: 1.000
random seed: 2
n_iter_no_change: 1,  n_iter: 4, test acc: 0.700
n_iter_no_change: 5,  n_iter: 34, test acc: 0.833
random seed: 3
n_iter_no_change: 1,  n_iter: 5, test acc: 0.700
n_iter_no_change: 5,  n_iter: 22, test acc: 0.933
random seed: 4
n_iter_no_change: 1,  n_iter: 5, test acc: 0.833
n_iter_no_change: 5,  n_iter: 24, test acc: 0.867

Arguably, iris is probably too small for serious machine learning, especially with stochastic solvers, but still.

ogrisel commented Jul 4, 2018

For completeness I have also tried on a larger dataset (covertype), and while it's expected that a large n_iter_no_change is not necessary in that case, it does not seem to hurt test accuracy:

random seed: 0
n_iter_no_change: 1,  n_iter: 6, test acc: 0.712
n_iter_no_change: 5,  n_iter: 10, test acc: 0.709
random seed: 1
n_iter_no_change: 1,  n_iter: 6, test acc: 0.708
n_iter_no_change: 5,  n_iter: 10, test acc: 0.712
random seed: 2
n_iter_no_change: 1,  n_iter: 6, test acc: 0.708
n_iter_no_change: 5,  n_iter: 10, test acc: 0.711
random seed: 3
n_iter_no_change: 1,  n_iter: 6, test acc: 0.710
n_iter_no_change: 5,  n_iter: 10, test acc: 0.714
random seed: 4
n_iter_no_change: 1,  n_iter: 6, test acc: 0.707
n_iter_no_change: 5,  n_iter: 10, test acc: 0.709
ogrisel commented Jul 4, 2018

Same run on covertype but using early stopping on a 10% validation split:

random seed: 0
n_iter_no_change: 1,  n_iter: 5, test acc: 0.710
n_iter_no_change: 5,  n_iter: 13, test acc: 0.711
random seed: 1
n_iter_no_change: 1,  n_iter: 6, test acc: 0.709
n_iter_no_change: 5,  n_iter: 10, test acc: 0.710
random seed: 2
n_iter_no_change: 1,  n_iter: 5, test acc: 0.710
n_iter_no_change: 5,  n_iter: 19, test acc: 0.711
random seed: 3
n_iter_no_change: 1,  n_iter: 4, test acc: 0.708
n_iter_no_change: 5,  n_iter: 22, test acc: 0.712
random seed: 4
n_iter_no_change: 1,  n_iter: 4, test acc: 0.700
n_iter_no_change: 5,  n_iter: 8, test acc: 0.708

In this case, n_iter_no_change=5 is consistently better than n_iter_no_change=1 despite the size of the dataset.
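
For reference, the only change relative to the earlier script is in how the classifier is constructed, roughly as follows (an illustrative sketch, not the exact script that was run):

SGDClassifier(max_iter=1000, tol=1e-3,
              early_stopping=True, validation_fraction=0.1,
              n_iter_no_change=n_iter_no_change, random_state=seed)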

ogrisel commented Jul 4, 2018

@jnothman @TomDLT I will go offline before appveyor has completed. Feel free to merge when green. Based on the runs I made, I am confident that n_iter_no_change=5 by default is a good and safe choice.

jnothman commented Jul 4, 2018

@glemaitre, does this still have your +1?

ogrisel commented Jul 5, 2018

I reported the appveyor failure here: #11438. I believe it's unrelated to this PR.

ogrisel merged commit 0fc7ce6 into scikit-learn:master Jul 5, 2018

8 checks passed

  • LGTM analysis: Python - No alert changes
  • ci/circleci: deploy - Your tests passed on CircleCI!
  • ci/circleci: python2 - Your tests passed on CircleCI!
  • ci/circleci: python3 - Your tests passed on CircleCI!
  • codecov/patch - 97.11% of diff hit (target 95.3%)
  • codecov/project - Absolute coverage decreased by <0.01% but relative coverage increased by +1.8% compared to 175bedb
  • continuous-integration/appveyor/pr - AppVeyor build succeeded
  • continuous-integration/travis-ci/pr - The Travis CI build passed
ogrisel commented Jul 5, 2018

Merged! Thank you very much @TomDLT!

ogrisel deleted the TomDLT:sgd_validation branch Jul 5, 2018
