[MRG+1] FIX n_iter -> max_iter conversion in SGDClassifier #9558

Merged: 10 commits, Aug 16, 2017
2 changes: 1 addition & 1 deletion doc/modules/kernel_approximation.rst
@@ -63,7 +63,7 @@ a linear algorithm, for example a linear SVM::
>>> clf.fit(X_features, y)
SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1,
eta0=0.0, fit_intercept=True, l1_ratio=0.15,
learning_rate='optimal', loss='hinge', max_iter=5, n_iter=None,
learning_rate='optimal', loss='hinge', max_iter=None, n_iter=None,
n_jobs=1, penalty='l2', power_t=0.5, random_state=None,
shuffle=True, tol=None, verbose=0, warm_start=False)
>>> clf.score(X_features, y)
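The doctest above is the one whose expected repr changes. A stand-alone sketch of the same pipeline, assuming the toy X/y from that documentation section, and with max_iter/tol passed explicitly so no transitional warning is emitted:

import numpy as np
from sklearn.kernel_approximation import RBFSampler
from sklearn.linear_model import SGDClassifier

# Toy data assumed to match the kernel_approximation documentation example.
X = np.array([[0, 0], [1, 1], [1, 0], [0, 1]])
y = np.array([0, 0, 1, 1])

# Approximate an RBF kernel map, then train a linear SGD classifier on it.
rbf_feature = RBFSampler(gamma=1, random_state=1)
X_features = rbf_feature.fit_transform(X)

# Passing max_iter (or tol) explicitly keeps the pre-0.19 behaviour and
# avoids the FutureWarning about the default change planned for 0.21.
clf = SGDClassifier(max_iter=5, tol=None)
clf.fit(X_features, y)
print(clf.score(X_features, y))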
2 changes: 1 addition & 1 deletion doc/modules/sgd.rst
@@ -63,7 +63,7 @@ for the training samples::
>>> clf.fit(X, y)
SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1,
eta0=0.0, fit_intercept=True, l1_ratio=0.15,
learning_rate='optimal', loss='hinge', max_iter=5, n_iter=None,
learning_rate='optimal', loss='hinge', max_iter=None, n_iter=None,
n_jobs=1, penalty='l2', power_t=0.5, random_state=None,
shuffle=True, tol=None, verbose=0, warm_start=False)

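Both documentation doctests change because, with this patch, the repr shows exactly what was passed to the constructor instead of a value rewritten in __init__. A small sketch of the new behaviour, using a hypothetical two-sample dataset:

from sklearn.linear_model import SGDClassifier

clf = SGDClassifier()
print(clf.max_iter)   # None: nothing is resolved at construction time any more

# The effective number of passes is only decided when fitting; with both
# max_iter and tol left unset this emits the transitional FutureWarning and
# internally falls back to the pre-0.19 default of 5 passes.
clf.fit([[0., 0.], [1., 1.]], [0, 1])
print(clf.max_iter)   # still None; the resolved value is kept privately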
9 changes: 5 additions & 4 deletions sklearn/linear_model/passive_aggressive.py
@@ -114,7 +114,7 @@ class PassiveAggressiveClassifier(BaseSGDClassifier):
>>> clf = PassiveAggressiveClassifier(random_state=0)
>>> clf.fit(X, y)
PassiveAggressiveClassifier(C=1.0, average=False, class_weight=None,
fit_intercept=True, loss='hinge', max_iter=5, n_iter=None,
fit_intercept=True, loss='hinge', max_iter=None, n_iter=None,
n_jobs=1, random_state=0, shuffle=True, tol=None, verbose=0,
warm_start=False)
>>> print(clf.coef_)
@@ -319,9 +319,9 @@ class PassiveAggressiveRegressor(BaseSGDRegressor):
>>> regr = PassiveAggressiveRegressor(random_state=0)
>>> regr.fit(X, y)
PassiveAggressiveRegressor(C=1.0, average=False, epsilon=0.1,
fit_intercept=True, loss='epsilon_insensitive', max_iter=5,
n_iter=None, random_state=0, shuffle=True, tol=None,
verbose=0, warm_start=False)
fit_intercept=True, loss='epsilon_insensitive',
max_iter=None, n_iter=None, random_state=0, shuffle=True,
tol=None, verbose=0, warm_start=False)
>>> print(regr.coef_)
[ 20.48736655 34.18818427 67.59122734 87.94731329]
>>> print(regr.intercept_)
@@ -377,6 +377,7 @@ def partial_fit(self, X, y):
-------
self : returns an instance of self.
"""
self._validate_params()
lr = "pa1" if self.loss == "epsilon_insensitive" else "pa2"
return self._partial_fit(X, y, alpha=1.0, C=self.C,
loss="epsilon_insensitive",
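Because the n_iter deprecation handling now lives in _validate_params rather than in __init__, partial_fit has to call it explicitly (the line added above), and that is where the deprecation warning is now emitted. A hedged sketch of the observable effect, with made-up toy data:

import warnings
import numpy as np
from sklearn.linear_model import PassiveAggressiveRegressor

X = np.random.RandomState(0).randn(10, 3)
y = X.sum(axis=1)

regr = PassiveAggressiveRegressor(n_iter=3)   # no warning at construction any more

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    regr.partial_fit(X, y)                    # DeprecationWarning fires here
print([w.category.__name__ for w in caught])  # expect ['DeprecationWarning']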
82 changes: 45 additions & 37 deletions sklearn/linear_model/stochastic_gradient.py
@@ -66,30 +66,12 @@ def __init__(self, loss, penalty='l2', alpha=0.0001, C=1.0,
self.power_t = power_t
self.warm_start = warm_start
self.average = average

if n_iter is not None:
warnings.warn("n_iter parameter is deprecated in 0.19 and will be"
" removed in 0.21. Use max_iter and tol instead.",
DeprecationWarning)
# Same behavior as before 0.19
self.max_iter = n_iter
tol = None

elif tol is None and max_iter is None:
warnings.warn(
"max_iter and tol parameters have been added in %s in 0.19. If"
" both are left unset, they default to max_iter=5 and tol=None"
". If tol is not None, max_iter defaults to max_iter=1000. "
"From 0.21, default max_iter will be 1000, "
"and default tol will be 1e-3." % type(self), FutureWarning)
# Before 0.19, default was n_iter=5
self.max_iter = 5
else:
self.max_iter = max_iter if max_iter is not None else 1000

self.n_iter = n_iter
self.max_iter = max_iter
self.tol = tol

self._validate_params()
# current tests expect init to do parameter validation
# but we are not allowed to set attributes
self._validate_params(set_max_iter=False)

def set_params(self, *args, **kwargs):
super(BaseSGD, self).set_params(*args, **kwargs)
@@ -100,11 +82,11 @@ def set_params(self, *args, **kwargs):
def fit(self, X, y):
"""Fit model."""

def _validate_params(self):
def _validate_params(self, set_max_iter=True):
"""Validate input params. """
if not isinstance(self.shuffle, bool):
raise ValueError("shuffle must be either True or False")
if self.max_iter <= 0:
if self.max_iter is not None and self.max_iter <= 0:
raise ValueError("max_iter must be > zero. Got %f" % self.max_iter)
if not (0.0 <= self.l1_ratio <= 1.0):
raise ValueError("l1_ratio must be in [0, 1]")
@@ -125,6 +107,31 @@ def _validate_params(self):
if self.loss not in self.loss_functions:
raise ValueError("The loss %s is not supported. " % self.loss)

if not set_max_iter:
return
# n_iter deprecation, set self._max_iter, self._tol
self._tol = self.tol
if self.n_iter is not None:
warnings.warn("n_iter parameter is deprecated in 0.19 and will be"
" removed in 0.21. Use max_iter and tol instead.",
DeprecationWarning)
# Same behavior as before 0.19
max_iter = self.n_iter
self._tol = None

elif self.tol is None and self.max_iter is None:
warnings.warn(
"max_iter and tol parameters have been added in %s in 0.19. If"
" both are left unset, they default to max_iter=5 and tol=None"
". If tol is not None, max_iter defaults to max_iter=1000. "
"From 0.21, default max_iter will be 1000, "
"and default tol will be 1e-3." % type(self), FutureWarning)
# Before 0.19, default was n_iter=5
max_iter = 5
else:
max_iter = self.max_iter if self.max_iter is not None else 1000
self._max_iter = max_iter

def _get_loss_function(self, loss):
"""Get concrete ``LossFunction`` object for str ``loss``. """
try:
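The set_max_iter handling added to _validate_params (shown in the hunk above) is the heart of the patch: instead of overwriting self.max_iter in __init__, it derives private _max_iter/_tol values. The helper below is a stand-alone sketch of that decision table; it is not part of the patch, and the name resolve_iter_params is made up for illustration:

import warnings

def resolve_iter_params(n_iter=None, max_iter=None, tol=None):
    """Return (effective_max_iter, effective_tol) the way _validate_params does."""
    if n_iter is not None:
        # Deprecated path: reproduce the pre-0.19 behaviour exactly.
        warnings.warn("n_iter is deprecated; use max_iter and tol instead.",
                      DeprecationWarning)
        return n_iter, None
    if tol is None and max_iter is None:
        # Transitional defaults: 5 passes and no tolerance-based stopping.
        warnings.warn("defaults will change to max_iter=1000, tol=1e-3 in 0.21",
                      FutureWarning)
        return 5, None
    return (max_iter if max_iter is not None else 1000), tol

# Mirrors test_tol_and_max_iter_default_values:
#   resolve_iter_params()            -> (5, None)
#   resolve_iter_params(n_iter=42)   -> (42, None)
#   resolve_iter_params(tol=1e-2)    -> (1000, 0.01)
#   resolve_iter_params(max_iter=42) -> (42, None)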
@@ -365,7 +372,6 @@ def _partial_fit(self, X, y, alpha, C,

n_samples, n_features = X.shape

self._validate_params()
_check_partial_fit_first_call(self, classes)

n_classes = self.classes_.shape[0]
@@ -405,6 +411,7 @@ def _partial_fit(self, X, y, alpha, C, loss, learning_rate,

def _fit(self, X, y, alpha, C, loss, learning_rate, coef_init=None,
intercept_init=None, sample_weight=None):
self._validate_params()
if hasattr(self, "classes_"):
self.classes_ = None

@@ -433,11 +440,11 @@ def _fit(self, X, y, alpha, C, loss, learning_rate, coef_init=None,
# Clear iteration count for multiple call to fit.
self.t_ = 1.0

self._partial_fit(X, y, alpha, C, loss, learning_rate, self.max_iter,
self._partial_fit(X, y, alpha, C, loss, learning_rate, self._max_iter,
classes, sample_weight, coef_init, intercept_init)

if (self.tol is not None and self.tol > -np.inf
and self.n_iter_ == self.max_iter):
if (self._tol is not None and self._tol > -np.inf
and self.n_iter_ == self._max_iter):
warnings.warn("Maximum number of iteration reached before "
"convergence. Consider increasing max_iter to "
"improve the fit.",
@@ -530,6 +537,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
-------
self : returns an instance of self.
"""
self._validate_params()
if self.class_weight in ['balanced']:
raise ValueError("class_weight '{0}' is not supported for "
"partial_fit. In order to use 'balanced' weights,"
@@ -753,7 +761,7 @@ class SGDClassifier(BaseSGDClassifier):
... #doctest: +NORMALIZE_WHITESPACE
SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1,
eta0=0.0, fit_intercept=True, l1_ratio=0.15,
learning_rate='optimal', loss='hinge', max_iter=5, n_iter=None,
learning_rate='optimal', loss='hinge', max_iter=None, n_iter=None,
n_jobs=1, penalty='l2', power_t=0.5, random_state=None,
shuffle=True, tol=None, verbose=0, warm_start=False)

@@ -933,8 +941,6 @@ def _partial_fit(self, X, y, alpha, C, loss, learning_rate,

n_samples, n_features = X.shape

self._validate_params()

# Allocate datastructures from input arguments
sample_weight = self._validate_sample_weight(sample_weight, n_samples)

@@ -976,6 +982,7 @@ def partial_fit(self, X, y, sample_weight=None):
-------
self : returns an instance of self.
"""
self._validate_params()
return self._partial_fit(X, y, self.alpha, C=1.0,
loss=self.loss,
learning_rate=self.learning_rate, max_iter=1,
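The regressor's partial_fit likewise validates first and then performs a single pass (max_iter=1) over whatever data it is given. A hedged usage sketch with made-up streaming mini-batches:

import numpy as np
from sklearn.linear_model import SGDRegressor

rng = np.random.RandomState(0)
true_coef = np.array([1., 2., 3., 4.])

# max_iter is irrelevant for partial_fit, but passing it (or tol) explicitly
# keeps the transitional FutureWarning quiet.
reg = SGDRegressor(max_iter=5, tol=None, random_state=0)

for _ in range(50):                       # 50 hypothetical mini-batches
    X_batch = rng.randn(32, 4)
    y_batch = np.dot(X_batch, true_coef) + 0.01 * rng.randn(32)
    reg.partial_fit(X_batch, y_batch)     # one epoch over this batch

print(reg.coef_)                          # moves toward [1, 2, 3, 4]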
@@ -984,6 +991,7 @@ def partial_fit(self, X, y, sample_weight=None):

def _fit(self, X, y, alpha, C, loss, learning_rate, coef_init=None,
intercept_init=None, sample_weight=None):
self._validate_params()
if self.warm_start and getattr(self, "coef_", None) is not None:
if coef_init is None:
coef_init = self.coef_
@@ -1003,11 +1011,11 @@ def _fit(self, X, y, alpha, C, loss, learning_rate, coef_init=None,
self.t_ = 1.0

self._partial_fit(X, y, alpha, C, loss, learning_rate,
self.max_iter, sample_weight, coef_init,
self._max_iter, sample_weight, coef_init,
intercept_init)

if (self.tol is not None and self.tol > -np.inf
and self.n_iter_ == self.max_iter):
if (self._tol is not None and self._tol > -np.inf
and self.n_iter_ == self._max_iter):
warnings.warn("Maximum number of iteration reached before "
"convergence. Consider increasing max_iter to "
"improve the fit.",
@@ -1096,7 +1104,7 @@ def _fit_regressor(self, X, y, alpha, C, loss, learning_rate,
# Windows
seed = random_state.randint(0, np.iinfo(np.int32).max)

tol = self.tol if self.tol is not None else -np.inf
tol = self._tol if self._tol is not None else -np.inf

if self.average > 0:
self.standard_coef_, self.standard_intercept_, \
@@ -1306,7 +1314,7 @@ class SGDRegressor(BaseSGDRegressor):
... #doctest: +NORMALIZE_WHITESPACE
SGDRegressor(alpha=0.0001, average=False, epsilon=0.1, eta0=0.01,
fit_intercept=True, l1_ratio=0.15, learning_rate='invscaling',
loss='squared_loss', max_iter=5, n_iter=None, penalty='l2',
loss='squared_loss', max_iter=None, n_iter=None, penalty='l2',
power_t=0.25, random_state=None, shuffle=True, tol=None,
verbose=0, warm_start=False)

34 changes: 20 additions & 14 deletions sklearn/linear_model/tests/test_sgd.py
@@ -1207,12 +1207,13 @@ def test_tol_parameter():
def test_future_and_deprecation_warnings():
# Test that warnings are raised. Will be removed in 0.21

def init(max_iter=None, tol=None, n_iter=None):
sgd = SGDClassifier(max_iter=max_iter, tol=tol, n_iter=n_iter)
sgd._validate_params()

# When all default values are used
msg_future = "max_iter and tol parameters have been added in "
assert_warns_message(FutureWarning, msg_future, SGDClassifier)

def init(max_iter=None, tol=None, n_iter=None):
SGDClassifier(max_iter=max_iter, tol=tol, n_iter=n_iter)
assert_warns_message(FutureWarning, msg_future, init)

# When n_iter is specified
msg_deprecation = "n_iter parameter is deprecated"
@@ -1228,24 +1229,29 @@ def init(max_iter=None, tol=None, n_iter=None):
def test_tol_and_max_iter_default_values():
# Test that the default values are correctly changed
est = SGDClassifier()
assert_equal(est.tol, None)
assert_equal(est.max_iter, 5)
est._validate_params()
assert_equal(est._tol, None)
assert_equal(est._max_iter, 5)

est = SGDClassifier(n_iter=42)
assert_equal(est.tol, None)
assert_equal(est.max_iter, 42)
est._validate_params()
assert_equal(est._tol, None)
assert_equal(est._max_iter, 42)

est = SGDClassifier(tol=1e-2)
assert_equal(est.tol, 1e-2)
assert_equal(est.max_iter, 1000)
est._validate_params()
assert_equal(est._tol, 1e-2)
assert_equal(est._max_iter, 1000)

est = SGDClassifier(max_iter=42)
assert_equal(est.tol, None)
assert_equal(est.max_iter, 42)
est._validate_params()
assert_equal(est._tol, None)
assert_equal(est._max_iter, 42)

est = SGDClassifier(max_iter=42, tol=1e-2)
assert_equal(est.tol, 1e-2)
assert_equal(est.max_iter, 42)
est._validate_params()
assert_equal(est._tol, 1e-2)
assert_equal(est._max_iter, 42)
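The updated tests make the new contract explicit: the public max_iter/tol attributes keep exactly what the user passed (which is what get_params, set_params and clone rely on), while the resolved values live on the private _max_iter/_tol attributes set by _validate_params. A condensed sketch of that contract, mirroring the assertions above:

from sklearn.base import clone
from sklearn.linear_model import SGDClassifier

est = SGDClassifier(n_iter=42)      # constructing no longer warns or rewrites
print(est.max_iter, est.tol)        # None None  -- exactly as passed
print(clone(est).n_iter)            # 42         -- cloning round-trips cleanly

est._validate_params()              # normally invoked inside fit/partial_fit;
                                    # emits the n_iter DeprecationWarning here
print(est._max_iter, est._tol)      # 42 None    -- the resolved private values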


def _test_gradient_common(loss_function, cases):
Expand Down