[MRG+1] FIX n_iter -> max_iter conversion in SGDClassifier (#9558)
* move n_iter -> max_iter conversion and warning into _validate_params in SGDClassifier for proper deprecation (see the sketch after this list).

* move validate_params so we have self._max_iter in _fit

* validate params in init because the tests want me to

* better check for input validation

* fix deprecation tests to call _validate_params

* fix parameter validation in PA classifier

* fix max_iter in doctests

* pep8 /doctest whitespace

* more doctests

* maybe I'll find them all....
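
After this change, both warnings are raised by _validate_params (invoked from fit and partial_fit) rather than by the constructor. A minimal sketch of the intended behavior at this commit; the warnings plumbing is illustrative, only the SGDClassifier calls come from the patch:

    import warnings
    from sklearn.linear_model import SGDClassifier

    # max_iter/tol left unset: FutureWarning about the 0.21 defaults, raised
    # when params are validated, not at construction time.
    clf = SGDClassifier()
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        clf._validate_params()
    assert any(w.category is FutureWarning for w in caught)
    assert clf._max_iter == 5          # pre-0.19 default, kept until 0.21

    # n_iter is deprecated and maps onto max_iter with the old behavior.
    clf = SGDClassifier(n_iter=7)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        clf._validate_params()
    assert any(w.category is DeprecationWarning for w in caught)
    assert clf._max_iter == 7 and clf._tol is None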
amueller authored and jnothman committed Aug 23, 2017
1 parent 85e1575 commit bc97fb4
Showing 5 changed files with 72 additions and 57 deletions.
2 changes: 1 addition & 1 deletion doc/modules/kernel_approximation.rst
@@ -63,7 +63,7 @@ a linear algorithm, for example a linear SVM::
>>> clf.fit(X_features, y)
SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1,
eta0=0.0, fit_intercept=True, l1_ratio=0.15,
learning_rate='optimal', loss='hinge', max_iter=5, n_iter=None,
learning_rate='optimal', loss='hinge', max_iter=None, n_iter=None,
n_jobs=1, penalty='l2', power_t=0.5, random_state=None,
shuffle=True, tol=None, verbose=0, warm_start=False)
>>> clf.score(X_features, y)
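The doctest reprs above (and in sgd.rst below) change from max_iter=5 to max_iter=None because __init__ now stores the constructor arguments unmodified; the effective value is resolved into self._max_iter only when _validate_params runs. A doctest-style sketch, assuming the code at this commit:

    >>> import warnings
    >>> from sklearn.linear_model import SGDClassifier
    >>> clf = SGDClassifier()
    >>> print(clf.max_iter)                  # stored as passed, hence the repr
    None
    >>> with warnings.catch_warnings():
    ...     warnings.simplefilter("ignore")  # silence the FutureWarning
    ...     clf._validate_params()
    >>> clf._max_iter                        # resolved default, 5 until 0.21
    5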
2 changes: 1 addition & 1 deletion doc/modules/sgd.rst
@@ -63,7 +63,7 @@ for the training samples::
>>> clf.fit(X, y)
SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1,
eta0=0.0, fit_intercept=True, l1_ratio=0.15,
learning_rate='optimal', loss='hinge', max_iter=5, n_iter=None,
learning_rate='optimal', loss='hinge', max_iter=None, n_iter=None,
n_jobs=1, penalty='l2', power_t=0.5, random_state=None,
shuffle=True, tol=None, verbose=0, warm_start=False)

9 changes: 5 additions & 4 deletions sklearn/linear_model/passive_aggressive.py
@@ -114,7 +114,7 @@ class PassiveAggressiveClassifier(BaseSGDClassifier):
>>> clf = PassiveAggressiveClassifier(random_state=0)
>>> clf.fit(X, y)
PassiveAggressiveClassifier(C=1.0, average=False, class_weight=None,
fit_intercept=True, loss='hinge', max_iter=5, n_iter=None,
fit_intercept=True, loss='hinge', max_iter=None, n_iter=None,
n_jobs=1, random_state=0, shuffle=True, tol=None, verbose=0,
warm_start=False)
>>> print(clf.coef_)
@@ -319,9 +319,9 @@ class PassiveAggressiveRegressor(BaseSGDRegressor):
>>> regr = PassiveAggressiveRegressor(random_state=0)
>>> regr.fit(X, y)
PassiveAggressiveRegressor(C=1.0, average=False, epsilon=0.1,
fit_intercept=True, loss='epsilon_insensitive', max_iter=5,
n_iter=None, random_state=0, shuffle=True, tol=None,
verbose=0, warm_start=False)
fit_intercept=True, loss='epsilon_insensitive',
max_iter=None, n_iter=None, random_state=0, shuffle=True,
tol=None, verbose=0, warm_start=False)
>>> print(regr.coef_)
[ 20.48736655 34.18818427 67.59122734 87.94731329]
>>> print(regr.intercept_)
@@ -377,6 +377,7 @@ def partial_fit(self, X, y):
-------
self : returns an instance of self.
"""
self._validate_params()
lr = "pa1" if self.loss == "epsilon_insensitive" else "pa2"
return self._partial_fit(X, y, alpha=1.0, C=self.C,
loss="epsilon_insensitive",
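With _validate_params() now called at the top of partial_fit, parameter checking also covers values changed after construction. A sketch under the same assumption; the direct attribute assignment deliberately bypasses __init__ and set_params to show partial_fit's own check:

    import numpy as np
    from sklearn.linear_model import PassiveAggressiveRegressor

    X = np.array([[0.0], [1.0]])
    y = np.array([0.0, 1.0])

    est = PassiveAggressiveRegressor()
    est.max_iter = -3                # invalid, and no setter ran to catch it
    try:
        est.partial_fit(X, y)        # _validate_params raises before fitting
    except ValueError as err:
        print(err)                   # max_iter must be > zero. Got -3.000000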
82 changes: 45 additions & 37 deletions sklearn/linear_model/stochastic_gradient.py
@@ -66,30 +66,12 @@ def __init__(self, loss, penalty='l2', alpha=0.0001, C=1.0,
self.power_t = power_t
self.warm_start = warm_start
self.average = average

if n_iter is not None:
warnings.warn("n_iter parameter is deprecated in 0.19 and will be"
" removed in 0.21. Use max_iter and tol instead.",
DeprecationWarning)
# Same behavior as before 0.19
self.max_iter = n_iter
tol = None

elif tol is None and max_iter is None:
warnings.warn(
"max_iter and tol parameters have been added in %s in 0.19. If"
" both are left unset, they default to max_iter=5 and tol=None"
". If tol is not None, max_iter defaults to max_iter=1000. "
"From 0.21, default max_iter will be 1000, "
"and default tol will be 1e-3." % type(self), FutureWarning)
# Before 0.19, default was n_iter=5
self.max_iter = 5
else:
self.max_iter = max_iter if max_iter is not None else 1000

self.n_iter = n_iter
self.max_iter = max_iter
self.tol = tol

self._validate_params()
# current tests expect init to do parameter validation
# but we are not allowed to set attributes
self._validate_params(set_max_iter=False)

def set_params(self, *args, **kwargs):
super(BaseSGD, self).set_params(*args, **kwargs)
@@ -100,11 +82,11 @@ def set_params(self, *args, **kwargs):
def fit(self, X, y):
"""Fit model."""

def _validate_params(self):
def _validate_params(self, set_max_iter=True):
"""Validate input params. """
if not isinstance(self.shuffle, bool):
raise ValueError("shuffle must be either True or False")
if self.max_iter <= 0:
if self.max_iter is not None and self.max_iter <= 0:
raise ValueError("max_iter must be > zero. Got %f" % self.max_iter)
if not (0.0 <= self.l1_ratio <= 1.0):
raise ValueError("l1_ratio must be in [0, 1]")
@@ -125,6 +107,31 @@ def _validate_params(self):
if self.loss not in self.loss_functions:
raise ValueError("The loss %s is not supported. " % self.loss)

if not set_max_iter:
return
# n_iter deprecation, set self._max_iter, self._tol
self._tol = self.tol
if self.n_iter is not None:
warnings.warn("n_iter parameter is deprecated in 0.19 and will be"
" removed in 0.21. Use max_iter and tol instead.",
DeprecationWarning)
# Same behavior as before 0.19
max_iter = self.n_iter
self._tol = None

elif self.tol is None and self.max_iter is None:
warnings.warn(
"max_iter and tol parameters have been added in %s in 0.19. If"
" both are left unset, they default to max_iter=5 and tol=None"
". If tol is not None, max_iter defaults to max_iter=1000. "
"From 0.21, default max_iter will be 1000, "
"and default tol will be 1e-3." % type(self), FutureWarning)
# Before 0.19, default was n_iter=5
max_iter = 5
else:
max_iter = self.max_iter if self.max_iter is not None else 1000
self._max_iter = max_iter

def _get_loss_function(self, loss):
"""Get concrete ``LossFunction`` object for str ``loss``. """
try:
@@ -365,7 +372,6 @@ def _partial_fit(self, X, y, alpha, C,

n_samples, n_features = X.shape

self._validate_params()
_check_partial_fit_first_call(self, classes)

n_classes = self.classes_.shape[0]
@@ -405,6 +411,7 @@

def _fit(self, X, y, alpha, C, loss, learning_rate, coef_init=None,
intercept_init=None, sample_weight=None):
self._validate_params()
if hasattr(self, "classes_"):
self.classes_ = None

@@ -433,11 +440,11 @@ def _fit(self, X, y, alpha, C, loss, learning_rate, coef_init=None,
# Clear iteration count for multiple call to fit.
self.t_ = 1.0

self._partial_fit(X, y, alpha, C, loss, learning_rate, self.max_iter,
self._partial_fit(X, y, alpha, C, loss, learning_rate, self._max_iter,
classes, sample_weight, coef_init, intercept_init)

if (self.tol is not None and self.tol > -np.inf
and self.n_iter_ == self.max_iter):
if (self._tol is not None and self._tol > -np.inf
and self.n_iter_ == self._max_iter):
warnings.warn("Maximum number of iteration reached before "
"convergence. Consider increasing max_iter to "
"improve the fit.",
@@ -530,6 +537,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
-------
self : returns an instance of self.
"""
self._validate_params()
if self.class_weight in ['balanced']:
raise ValueError("class_weight '{0}' is not supported for "
"partial_fit. In order to use 'balanced' weights,"
@@ -753,7 +761,7 @@ class SGDClassifier(BaseSGDClassifier):
... #doctest: +NORMALIZE_WHITESPACE
SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1,
eta0=0.0, fit_intercept=True, l1_ratio=0.15,
learning_rate='optimal', loss='hinge', max_iter=5, n_iter=None,
learning_rate='optimal', loss='hinge', max_iter=None, n_iter=None,
n_jobs=1, penalty='l2', power_t=0.5, random_state=None,
shuffle=True, tol=None, verbose=0, warm_start=False)
@@ -933,8 +941,6 @@ def _partial_fit(self, X, y, alpha, C, loss, learning_rate,

n_samples, n_features = X.shape

self._validate_params()

# Allocate datastructures from input arguments
sample_weight = self._validate_sample_weight(sample_weight, n_samples)

@@ -976,6 +982,7 @@ def partial_fit(self, X, y, sample_weight=None):
-------
self : returns an instance of self.
"""
self._validate_params()
return self._partial_fit(X, y, self.alpha, C=1.0,
loss=self.loss,
learning_rate=self.learning_rate, max_iter=1,
@@ -984,6 +991,7 @@ def partial_fit(self, X, y, sample_weight=None):

def _fit(self, X, y, alpha, C, loss, learning_rate, coef_init=None,
intercept_init=None, sample_weight=None):
self._validate_params()
if self.warm_start and getattr(self, "coef_", None) is not None:
if coef_init is None:
coef_init = self.coef_
@@ -1003,11 +1011,11 @@ def _fit(self, X, y, alpha, C, loss, learning_rate, coef_init=None,
self.t_ = 1.0

self._partial_fit(X, y, alpha, C, loss, learning_rate,
self.max_iter, sample_weight, coef_init,
self._max_iter, sample_weight, coef_init,
intercept_init)

if (self.tol is not None and self.tol > -np.inf
and self.n_iter_ == self.max_iter):
if (self._tol is not None and self._tol > -np.inf
and self.n_iter_ == self._max_iter):
warnings.warn("Maximum number of iteration reached before "
"convergence. Consider increasing max_iter to "
"improve the fit.",
@@ -1096,7 +1104,7 @@ def _fit_regressor(self, X, y, alpha, C, loss, learning_rate,
# Windows
seed = random_state.randint(0, np.iinfo(np.int32).max)

tol = self.tol if self.tol is not None else -np.inf
tol = self._tol if self._tol is not None else -np.inf

if self.average > 0:
self.standard_coef_, self.standard_intercept_, \
@@ -1306,7 +1314,7 @@ class SGDRegressor(BaseSGDRegressor):
... #doctest: +NORMALIZE_WHITESPACE
SGDRegressor(alpha=0.0001, average=False, epsilon=0.1, eta0=0.01,
fit_intercept=True, l1_ratio=0.15, learning_rate='invscaling',
loss='squared_loss', max_iter=5, n_iter=None, penalty='l2',
loss='squared_loss', max_iter=None, n_iter=None, penalty='l2',
power_t=0.25, random_state=None, shuffle=True, tol=None,
verbose=0, warm_start=False)
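The set_max_iter flag splits validation in two: __init__ and set_params still reject invalid values (the tests expect that), but the derived _max_iter/_tol pair is computed only at fit-time validation, since estimators must not set extra attributes from __init__. A sketch, assuming the code at this commit:

    from sklearn.linear_model import SGDClassifier

    # __init__ validates eagerly (set_max_iter=False) but sets no derived state.
    try:
        SGDClassifier(l1_ratio=2.0)      # rejected at construction
    except ValueError as err:
        print(err)                       # l1_ratio must be in [0, 1]

    clf = SGDClassifier(max_iter=42)
    print(hasattr(clf, "_max_iter"))     # False: resolved only on validation
    clf._validate_params()               # no warning: max_iter was given
    print(clf._max_iter)                 # 42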
34 changes: 20 additions & 14 deletions sklearn/linear_model/tests/test_sgd.py
@@ -1207,12 +1207,13 @@ def test_tol_parameter():
def test_future_and_deprecation_warnings():
# Test that warnings are raised. Will be removed in 0.21

def init(max_iter=None, tol=None, n_iter=None):
sgd = SGDClassifier(max_iter=max_iter, tol=tol, n_iter=n_iter)
sgd._validate_params()

# When all default values are used
msg_future = "max_iter and tol parameters have been added in "
assert_warns_message(FutureWarning, msg_future, SGDClassifier)

def init(max_iter=None, tol=None, n_iter=None):
SGDClassifier(max_iter=max_iter, tol=tol, n_iter=n_iter)
assert_warns_message(FutureWarning, msg_future, init)

# When n_iter is specified
msg_deprecation = "n_iter parameter is deprecated"
@@ -1228,24 +1229,29 @@ def init(max_iter=None, tol=None, n_iter=None):
def test_tol_and_max_iter_default_values():
# Test that the default values are correctly changed
est = SGDClassifier()
assert_equal(est.tol, None)
assert_equal(est.max_iter, 5)
est._validate_params()
assert_equal(est._tol, None)
assert_equal(est._max_iter, 5)

est = SGDClassifier(n_iter=42)
assert_equal(est.tol, None)
assert_equal(est.max_iter, 42)
est._validate_params()
assert_equal(est._tol, None)
assert_equal(est._max_iter, 42)

est = SGDClassifier(tol=1e-2)
assert_equal(est.tol, 1e-2)
assert_equal(est.max_iter, 1000)
est._validate_params()
assert_equal(est._tol, 1e-2)
assert_equal(est._max_iter, 1000)

est = SGDClassifier(max_iter=42)
assert_equal(est.tol, None)
assert_equal(est.max_iter, 42)
est._validate_params()
assert_equal(est._tol, None)
assert_equal(est._max_iter, 42)

est = SGDClassifier(max_iter=42, tol=1e-2)
assert_equal(est.tol, 1e-2)
assert_equal(est.max_iter, 42)
est._validate_params()
assert_equal(est._tol, 1e-2)
assert_equal(est._max_iter, 42)


def _test_gradient_common(loss_function, cases):
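For reference, the resolution these tests pin down, written out as a compact sketch of the 0.19 deprecation scheme:

    # (max_iter, tol,  n_iter) as passed -> (_max_iter, _tol) after _validate_params
    # (None,     None, None)             -> (5,    None)  + FutureWarning
    # (None,     None, 42)               -> (42,   None)  + DeprecationWarning
    # (None,     1e-2, None)             -> (1000, 1e-2)
    # (42,       None, None)             -> (42,   None)
    # (42,       1e-2, None)             -> (42,   1e-2)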
