
[MRG+2] Deprecate n_iter in SGDClassifier and implement max_iter #5036

Merged
8 commits merged into scikit-learn:master on Jun 23, 2017

Conversation

@TomDLT (Member) commented on Jul 27, 2015

Solves #5022.
In SGDClassifier, SGDRegressor, Perceptron, PassiveAggressive:

  • Deprecate n_iter. The default is now None. If not None, it warns and sets max_iter = n_iter and tol = 0, to reproduce the exact previous behavior.
  • Implement max_iter and tol. The stopping criterion in sgd_fast._plain_sgd() is identical to the one in the new SAG solver for Ridge and LogisticRegression.
  • Add self.n_iter_ after the fit. For multiclass classifiers, we keep the maximum n_iter_ over all binary (OvA) fits.
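
A minimal sketch of the deprecation logic described in the list above (hedged: the helper name resolve_stopping_params and the warning text are illustrative, not the actual scikit-learn code):

import warnings


def resolve_stopping_params(n_iter, max_iter, tol):
    """Map the deprecated n_iter onto (max_iter, tol), as described above."""
    if n_iter is not None:
        warnings.warn("n_iter parameter is deprecated; use max_iter and tol instead.",
                      DeprecationWarning)
        # Exact previous behavior: always run n_iter epochs, never stop on tolerance.
        return n_iter, 0
    return max_iter, tol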
@TomDLT force-pushed the TomDLT:sgd_maxiter branch 2 times, most recently from 8a5f890 to 43d931b on Jul 27, 2015
@@ -700,7 +728,7 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,

w.reset_wscale()

@ogrisel (Member), Aug 17, 2015

Please raise a ConvergenceWarning with an informative message if max_iter == epoch + 1.

@TomDLT (Member, Author), Sep 7, 2015

It would raise a warning for each partial_fit (which has max_iter=1).
Instead, I suggest raising it in the _fit method.

@ogrisel (Member), Sep 8, 2015

You could disable the convergence warning by passing tol=0 as a local variable only when called from partial_fit while passing tol=self.tol when called from fit.
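
A minimal sketch of that idea (hedged: this toy class is illustrative, not the actual SGD code, and the _fit signature is an assumption):

import warnings

from sklearn.exceptions import ConvergenceWarning


class SGDLikeEstimator:
    """Toy example: route fit and partial_fit through one _fit with different tol."""

    def __init__(self, max_iter=100, tol=1e-3):
        self.max_iter = max_iter
        self.tol = tol

    def _fit(self, X, y, max_iter, tol):
        n_iter_ = max_iter  # pretend the solver used every allowed epoch
        if tol > 0 and n_iter_ == max_iter:
            warnings.warn("Maximum number of iterations reached without convergence.",
                          ConvergenceWarning)
        self.n_iter_ = n_iter_
        return self

    def fit(self, X, y):
        # Full fit: honor the user's tol, so the warning can fire.
        return self._fit(X, y, self.max_iter, self.tol)

    def partial_fit(self, X, y):
        # Single epoch: pass tol=0 locally, so no ConvergenceWarning is raised.
        return self._fit(X, y, 1, 0)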

@ogrisel (Member) commented on Aug 17, 2015

About your remark in #5022: you suggested that we could avoid deprecating n_iter and instead default it to None, so that the max_iter + tol criterion is used by default (issuing a ConvergenceWarning if tol is not reached before max_iter), while still letting the user ask for a specific number of n_iter passes over the data, in which case max_iter and tol are ignored and no ConvergenceWarning is ever raised.

I am +0 for this convenience feature. @amueller do you have an opinion in this regard?

@amueller (Member) commented on Aug 18, 2015

I'm -0 ;)

Do we want a validation set for SGDClassifier in the future? And if so, do we introduce this using a validation=boolean variable? This is a step in the right direction but we should think about where this is going.

@TomDLT force-pushed the TomDLT:sgd_maxiter branch from 43d931b to 5310f67 on Sep 7, 2015
@TomDLT (Member, Author) commented on Sep 7, 2015

you suggested we could avoid deprecating n_iter...

I thought of it as a deprecation: a temporary way to avoid breaking any user code before removing n_iter completely.
I am not sure I understand what the classic deprecation would be.

Do we want a validation set for SGDClassifier in the future?...

Do you mean to give a validation set with a performance goal as a stopping criterion, directly in the SGD solver?

@ogrisel (Member) commented on Sep 8, 2015

I thought of it as a deprecation: a temporary way to avoid breaking any user code before removing n_iter completely.
I am not sure I understand what the classic deprecation would be.

Classic deprecation is what you did in this PR: raise a DeprecationWarning now while still behaving the same if the user passes n_iter != None, and remove the n_iter kwarg completely in two releases.

Do we want a validation set for SGDClassifier in the future?...

Do you mean to give a validation set with a performance goal as a stopping criterion, directly in the SGD solver?

Yes, early stopping based on the lack of improvement as measured on a validation set. The validation set could be specified as a number between 0 and 1 (typically 0.1 by default), and the model would extract it internally in the fit method by randomly splitting the user-provided data into train and validation folds. But this is outside the scope of this PR.
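
A rough sketch of that future feature (hedged: parameter names such as validation_fraction and n_iter_no_change are illustrative, and this is not part of this PR):

import numpy as np
from sklearn.model_selection import train_test_split


def fit_with_early_stopping(model, X, y, max_iter=1000, validation_fraction=0.1,
                            tol=1e-3, n_iter_no_change=5, random_state=0):
    """Stop when the validation score stops improving (model must support partial_fit)."""
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=validation_fraction, random_state=random_state)
    best_score, stale = -np.inf, 0
    for epoch in range(max_iter):
        model.partial_fit(X_train, y_train)  # classifiers need classes= on the first call
        score = model.score(X_val, y_val)
        if score > best_score + tol:
            best_score, stale = score, 0
        else:
            stale += 1
        if stale >= n_iter_no_change:
            break  # no improvement on the validation set: stop early
    return model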

verbose=0, loss="hinge", n_jobs=1, random_state=None,
warm_start=False, class_weight=None):

def __init__(self, C=1.0, fit_intercept=True, max_iter=5, tol=1e-4,

@ogrisel (Member), Sep 8, 2015

Now that we have a good stopping criterion, I think we should set max_iter=100 by default and expect the stopping criterion to kick in before that in 99% of the cases.

@amueller (Member), Sep 8, 2015

On 09/08/2015 02:33 AM, Olivier Grisel wrote:

Now that we have a good stopping criterion, I think we should set max_iter=100 by default and expect the stopping criterion to kick in before that in 99% of the cases.

Did someone do experiments on how well that works in practice?

The number of passes over the training data (aka epochs).
max_iter : int, optional
The maximum number of passes over the training data (aka epochs).
The maximum number of iterations is set to 1 if using partial_fit.

@ogrisel (Member), Sep 8, 2015

I would rather say that this parameter only impacts the behavior of the fit method, not the partial_fit method.

def __init__(self, C=1.0, fit_intercept=True, n_iter=5, shuffle=True,
verbose=0, loss="epsilon_insensitive",
epsilon=DEFAULT_EPSILON, random_state=None, warm_start=False):
def __init__(self, C=1.0, fit_intercept=True, max_iter=5, tol=1e-4,

@ogrisel (Member), Sep 8, 2015

max_iter=100 as well here.

@@ -73,6 +76,13 @@ def __init__(self, loss, penalty='l2', alpha=0.0001, C=1.0,
self.warm_start = warm_start
self.average = average

if n_iter is not None:
warnings.warn("n_iter parameter is deprecated and will be removed"

@ogrisel (Member), Sep 8, 2015

It's better to be very explicit: "n_iter parameter is deprecated in 0.17 and will be removed in 0.19. ..."
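
For illustration, the completed call could look like this (hedged: the version numbers are taken from this comment, the rest of the wording is an assumption, not the final merged message):

import warnings

warnings.warn("n_iter parameter is deprecated in 0.17 and will be removed in 0.19. "
              "Use max_iter and tol instead.", DeprecationWarning)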

classes, sample_weight, coef_init, intercept_init)

if self.n_iter_ == self.max_iter:

@ogrisel (Member), Sep 8, 2015

Please change this test to:

if self.tol > 0 and self.n_iter_ == self.max_iter:

so that the user can disable the ConvergenceWarning when they intentionally decide to always perform max_iter iterations (effectively disabling the stopping criterion).
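
With the check proposed here, tol=0 would intentionally run all epochs without any warning, e.g. (hedged usage sketch against the API introduced in this PR):

import numpy as np
from sklearn.linear_model import SGDRegressor

rng = np.random.RandomState(0)
X, y = rng.randn(100, 3), rng.randn(100)

# tol=0 disables the stopping criterion and, with the check above, the ConvergenceWarning
model = SGDRegressor(max_iter=50, tol=0).fit(X, y)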

@ogrisel (Member) commented on Sep 8, 2015

The handling of partial_fit is wrong. The n_iter_ attribute should not necessarily be defined when the user calls only partial_fit to train the model, as the model is not aware of how many passes over the full dataset have been performed:

>>> from sklearn.datasets import load_boston
>>> from sklearn.linear_model import SGDRegressor
>>> from sklearn.utils import gen_batches
>>> boston = load_boston()
>>> n_samples, n_features = boston.data.shape
>>> n_samples, n_features
(506, 13)
>>> all_batches = list(gen_batches(n_samples, 100))
>>> m = SGDRegressor(max_iter=2)
>>> for batch in all_batches:
...     m.fit(boston.data[batch], boston.target[batch])
...
>>> m.t_
13.0

m.t_ should be the total number of samples the model has seen since the beginning of the incremental fit. Iterating over all_batches should be equivalent to performing one pass over the data (one epoch), and therefore m.t_ should be n_samples + 1 = 507 instead of the current max_iter * len(all_batches) + 1.

In particular, calling

m = SGDRegressor(max_iter=1, tol=0, shuffle=False, random_state=0)
m.fit(boston.data, boston.target)

should be equivalent (same m.t_, m.coef_ and m.intercept_) to:

# max_iter should not impact incremental fitting at all
m = SGDRegressor(max_iter=42, shuffle=False, random_state=0)
for batch in all_batches:
    m.partial_fit(boston.data[batch], boston.target[batch])

m.t_ should be n_samples + 1 in both cases.

Furthermore:

m = SGDRegressor(max_iter=10, tol=0, shuffle=False, random_state=0)
m.fit(boston.data, boston.target)

should be equivalent to:

m = SGDRegressor(max_iter=42, shuffle=False, random_state=0)
for i in range(10):
    for batch in all_batches:
        m.partial_fit(boston.data[batch], boston.target[batch])

m.t_ should be 10 * n_samples + 1 in both cases.

The fact that the tests do not fail means that partial_fit testing is not strong enough for SGDRegressor (and probably SGDClassifier and related models).

@ogrisel (Member) commented on Sep 8, 2015

Also for some reason I don't get the expected convergence warning when I do:

>>> SGDRegressor(max_iter=1, tol=1e-15).fit(boston.data, boston.target)
@ogrisel (Member) commented on Sep 8, 2015

On the other hand, I would expect the following model to converge earlier (n_iter_ should be smaller than max_iter):

>>> SGDRegressor(max_iter=10000, tol=1e-2).fit(boston.data, boston.target).n_iter_
10000
@TomDLT (Member, Author) commented on Sep 8, 2015

Edit: solved by scaling the data (#5036 (comment))


Thanks for the review.

About your last comment, it comes from the fact that SGD converges quite slowly on the Boston dataset, as you can see in the following plot (tested on master):

import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import load_boston
from sklearn.linear_model import SGDRegressor

boston = load_boston()
X, y = boston.data, boston.target
n_features = X.shape[1]

iter_range = np.arange(1, 11) * 1000
coefs = np.zeros((n_features, iter_range.size))

for i, n_iter in enumerate(iter_range):
    reg = SGDRegressor(n_iter=n_iter).fit(X, y)
    coefs[:, i] = reg.coef_

for i in range(n_features):
    plt.plot(iter_range, coefs[i, :])

plt.xlabel("n_iter")
plt.ylabel("coefs")
plt.show()

[Figure: each coefficient of SGDRegressor plotted against n_iter on the Boston dataset; the coefficients do not stabilize, showing the unstable convergence.]


I changed the default max_iter, I added the ConvergenceWarning in SGDRegressor, and I am looking into your comment about t_ and partial_fit.

@TomDLT (Member, Author) commented on Sep 8, 2015

Actually I think the code handles t_ correctly.

>>> all_batches
[slice(0, 100, None),
 slice(100, 200, None),
 slice(200, 300, None),
 slice(300, 400, None),
 slice(400, 500, None),
 slice(500, 506, None)]
>>> for batch in all_batches:
...     m.fit(boston.data[batch], boston.target[batch])
...
>>> m.t_
13.0

You call fit on the last slice (with only 6 samples) with max_iter=2, so it is expected that m.t_ = 2 * 6 + 1 = 13. Warm starting is not activated by default, and I guess we want t_ to restart even in the case of warm starting.

With partial_fit, t_ seems correct:

from sklearn.datasets import load_boston
from sklearn.linear_model import SGDRegressor
from sklearn.utils import gen_batches

boston = load_boston()
n_samples, n_features = boston.data.shape
all_batches = list(gen_batches(n_samples, 100))

for max_iter in range(1, 11):
    # one full pass with fit
    m1 = SGDRegressor(max_iter=max_iter, tol=0, shuffle=False, random_state=0)
    m1.fit(boston.data, boston.target)

    # batches with partial_fit
    m2 = SGDRegressor(max_iter=42, shuffle=False, random_state=0)
    for _ in range(max_iter):
        for batch in all_batches:
            m2.partial_fit(boston.data[batch], boston.target[batch])

    print("%d, %f" % (m1.t_, m1.coef_[1]))
    print("%d, %f" % (m2.t_, m2.coef_[1]))

gives

507, 216965230251.886780
507, 216965230251.886414
1013, 891040660054.554077
1013, 891040660054.553955
1519, 724458977715.773315
1519, 724458977715.773071
2025, 753346099309.831177
2025, 753346099309.830688
2531, 836014190413.200317
2531, 836014190413.200317
3037, 942892124924.625488
3037, 764967807987.214600
3543, 365984307776.865479
3543, 498725629682.746338
4049, 81118161082.128448
4049, 575624185079.631470
4555, 47289103200.144722
4555, 618775289008.635620
5061, -32564036946.908470
5061, 266499003140.677155

partial_fit and fit give almost the same results; I don't understand why they are not perfectly identical. SGD is very unstable on the Boston dataset (cf. the previous plot), which explains the large difference with more iterations. This behavior was already present in the master branch.

About n_iter_ with partial_fit, is it really misleading to leave it at 1?

@TomDLT force-pushed the TomDLT:sgd_maxiter branch 3 times, most recently from ebd2a42 to 7e30cff on Sep 8, 2015
@amueller (Member) commented on Sep 8, 2015

For the default parameters of SGD to work, the data needs to be scaled.
Maybe we should warn about this, or add scaling to SGDClassifier and SGDRegressor?

@amueller (Member) commented on Sep 8, 2015

Tag 0.17, @ogrisel? Not sure if it's ready.

@amueller (Member) commented on Sep 8, 2015

I'm not sure if we should do this change simultaneously with a default scaling change?

@TomDLT force-pushed the TomDLT:sgd_maxiter branch from 7e30cff to 9c4c73e on Sep 9, 2015
@TomDLT (Member, Author) commented on Sep 10, 2015

for default parameters of SGD to work, the data needs to be scaled.

Indeed, adding a StandardScaler step solves the unstable behavior in the previous plot, and fit and partial_fit then give exactly the same values. Thanks.
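
For reference, a minimal sketch of that scaling step (hedged: load_boston matches the snippets in this thread but has been removed from recent scikit-learn versions; the max_iter/tol values are illustrative):

from sklearn.datasets import load_boston
from sklearn.linear_model import SGDRegressor
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

boston = load_boston()
# Standardize the features before SGD so the default learning rate behaves well
model = make_pipeline(StandardScaler(), SGDRegressor(max_iter=1000, tol=1e-3))
model.fit(boston.data, boston.target)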

I'm not sure if we should do this change simultaneously with a default scaling change?

The scaling problem is not really linked to n_iter and max_iter.
I would suggest handling it in a separate PR.


I think this PR is OK, unless I missed your point, @ogrisel.

@amueller added this to the 0.17 milestone on Sep 10, 2015
@TomDLT force-pushed the TomDLT:sgd_maxiter branch from 9c4c73e to b7b5ffc on Sep 11, 2015
@amueller (Member) commented on Sep 11, 2015

Hm, this is a behavior-breaking change, right?
Shouldn't we add tol=None and warn that it'll change in the future?

@TomDLT (Member, Author) commented on Sep 11, 2015

SGD(n_iter=n_iter)'s behavior is not changed and raises a DeprecationWarning.
SGD()'s behavior is changed, however.

It might be better to warn about the future change.
But if we add tol=None, we need to change back to max_iter=5.

# When n_iter=None, and at least one of tol and max_iter is specified
assert_no_warnings(init, 100, None, None)
assert_no_warnings(init, None, 1e-3, None)
assert_no_warnings(init, 100, 1e-3, None)

@ogrisel (Member), Jun 22, 2017

Please add assertions for the resulting values of clf.max_iter and clf.tol for each of these cases, e.g. something like:

clf = assert_no_warnings(SGDClassifier, max_iter=100, tol=1e-3, n_iter=None)
assert clf.max_iter == 100
assert clf.tol == 1e-3

@TomDLT (Member, Author), Jun 22, 2017

cf. test_tol_and_max_iter_default_values?

@ogrisel (Member), Jun 23, 2017

Fine.

@TomDLT (Member, Author) commented on Jun 22, 2017

I would have expected a larger number of iterations because the loss has decreased by more than 1e-3 during the last iteration.

Good point: the loss accumulator was not reset after each epoch, and was not scaled by n_samples. I fixed it. Now your snippet gives:

-- Epoch 1
Norm: 24.28, NNZs: 4, Bias: 9.764409, T: 150, Avg. loss: 0.222557
Total training time: 0.00 seconds.
-- Epoch 2
Norm: 29.40, NNZs: 4, Bias: 9.535794, T: 300, Avg. loss: 0.195226
Total training time: 0.00 seconds.
-- Epoch 3
Norm: 26.35, NNZs: 4, Bias: 9.535794, T: 450, Avg. loss: 0.000000
Total training time: 0.00 seconds.
-- Epoch 4
Norm: 23.88, NNZs: 4, Bias: 9.535794, T: 600, Avg. loss: 0.000000
Total training time: 0.00 seconds.
convergence after 4 epochs took 0.00 seconds
effective n_iter: 4
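
For clarity, a hedged Python sketch of the corrected per-epoch bookkeeping (the real loop lives in the Cython _plain_sgd; all names here are illustrative):

def run_epochs(one_epoch_loss, n_samples, max_iter, tol):
    """one_epoch_loss() performs one pass over the data and returns the summed loss."""
    best_loss = float("inf")
    for epoch in range(max_iter):
        sumloss = one_epoch_loss()        # the accumulator starts from zero each epoch
        avg_loss = sumloss / n_samples    # scale by the number of samples
        if tol > 0 and avg_loss > best_loss - tol:
            return epoch + 1              # loss no longer improving by more than tol: stop
        best_loss = min(best_loss, avg_loss)
    return max_iter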
@TomDLT force-pushed the TomDLT:sgd_maxiter branch from 4d3f785 to e2cd080 on Jun 22, 2017
@ogrisel (Member) commented on Jun 23, 2017

@TomDLT I have pushed two small improvements to my ogrisel/sgd_maxiter branch. Could you please include them in your PR so that CI runs on them? Other than that, I think I am +1.

@ogrisel (Member) commented on Jun 23, 2017

@TomDLT please allow other scikit-learn devs to push to the branch of your future PRs :)

@ogrisel (Member) left a review comment

+1 once ogrisel/sgd_maxiter is included in this PR.

@TomDLT force-pushed the TomDLT:sgd_maxiter branch from e2cd080 to cf77163 on Jun 23, 2017
@TomDLT (Member, Author) commented on Jun 23, 2017

Done and all green

@ogrisel merged commit edeb3af into scikit-learn:master on Jun 23, 2017
5 checks passed:

  • ci/circleci: Your tests passed on CircleCI!
  • codecov/patch: 98.77% of diff hit (target 96.3%)
  • codecov/project: Absolute coverage decreased by <0.01% but relative coverage increased by +2.46% compared to ebf2bf8
  • continuous-integration/appveyor/pr: AppVeyor build succeeded
  • continuous-integration/travis-ci/pr: The Travis CI build passed
@ogrisel (Member) commented on Jun 23, 2017

Merged 🎉

@jnothman (Member) commented on Jun 24, 2017

@amueller (Member) commented on Jun 28, 2017

Wohooo!!

dmohns added a commit to dmohns/scikit-learn that referenced this pull request Aug 7, 2017
dmohns added a commit to dmohns/scikit-learn that referenced this pull request Aug 7, 2017
NelleV added a commit to NelleV/scikit-learn that referenced this pull request Aug 11, 2017
paulha added a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017
AishwaryaRK added a commit to AishwaryaRK/scikit-learn that referenced this pull request Aug 29, 2017
maskani-moh added a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017
jwjohnson314 pushed a commit to jwjohnson314/scikit-learn that referenced this pull request Dec 18, 2017
@TomDLT deleted the TomDLT:sgd_maxiter branch on Jun 15, 2018