[MRG+1] MLPRegressor quits fitting too soon due to self._no_improvement_count #9457

Merged
30 commits merged on Oct 29, 2017
Changes from 23 commits

Commits
c32075e  _no_improvement_count limit not a magic number (nnadeau, Jul 27, 2017)
04e568a  added no_improvement_limit as __init__ arguments (Jul 27, 2017)
011facd  fixed early quitting messages with respect to mutable limit (Jul 27, 2017)
a2aab9c  renamed no_improvement_limit to n_iter_no_change to align with #7071 (Jul 27, 2017)
905f64c  added n_iter_no_change tests (Aug 2, 2017)
a9fbfd9  added test_n_iter_no_change_inf() (Aug 2, 2017)
9ad755e  PEP8 line length fixes (Aug 2, 2017)
1412258  updated class docs (Aug 2, 2017)
152035f  PEP8 line length fixes (Aug 2, 2017)
baf71d3  flake8 line length fixes (Aug 2, 2017)
eca27a7  fixing doctests (Aug 4, 2017)
d4c6df8  removed test as loss fluctuations cannot be guaranteed (Aug 4, 2017)
3012daa  simplified test (Aug 4, 2017)
110d13e  fixed flake8 'E128 continuation line under-indented for visual indent' (Aug 4, 2017)
78481da  flake8 indent lines to the opening parentheses (Aug 4, 2017)
b7f1dc2  updated comment (nnadeau, Aug 15, 2017)
0c1f5b6  updated default value of n_iter_no_change: 2 -> 10 (nnadeau, Aug 15, 2017)
fe97a44  updated documentation (nnadeau, Aug 18, 2017)
5414867  updated `whats_new` (nnadeau, Oct 17, 2017)
5ef758e  updated `test_params_errors` and `_validate_hyperparameters` (nnadeau, Oct 17, 2017)
a8d2931  added `.. versionadded:: 0.20` (nnadeau, Oct 17, 2017)
97aabfc  added `@ignore_warnings(category=ConvergenceWarning)` (nnadeau, Oct 17, 2017)
c289791  fixed flake8 error (nnadeau, Oct 17, 2017)
e60701a  Revert "updated default value of n_iter_no_change: 2 -> 10" (Oct 19, 2017)
558ed2e  Merge branch 'master' into patch-1 (nnadeau, Oct 27, 2017)
cc85e6a  Revert "Revert "updated default value of n_iter_no_change: 2 -> 10"" (Oct 27, 2017)
ab0e5a7  Merge branch 'master' into patch-1 (nnadeau, Oct 27, 2017)
d913172  added `n_iter_no_change` bugfix to whats_new (Oct 27, 2017)
e78fb78  fixing double backticks for documentation (engnadeau, Oct 28, 2017)
443c348  added bugfix to changed models section (engnadeau, Oct 28, 2017)
26 changes: 14 additions & 12 deletions doc/modules/neural_networks_supervised.rst
@@ -91,12 +91,13 @@ training samples::
...
>>> clf.fit(X, y) # doctest: +NORMALIZE_WHITESPACE
MLPClassifier(activation='relu', alpha=1e-05, batch_size='auto',
beta_1=0.9, beta_2=0.999, early_stopping=False,
epsilon=1e-08, hidden_layer_sizes=(5, 2), learning_rate='constant',
learning_rate_init=0.001, max_iter=200, momentum=0.9,
nesterovs_momentum=True, power_t=0.5, random_state=1, shuffle=True,
solver='lbfgs', tol=0.0001, validation_fraction=0.1, verbose=False,
warm_start=False)
beta_1=0.9, beta_2=0.999, early_stopping=False,
epsilon=1e-08, hidden_layer_sizes=(5, 2),
learning_rate='constant', learning_rate_init=0.001,
max_iter=200, momentum=0.9, n_iter_no_change=10,
nesterovs_momentum=True, power_t=0.5, random_state=1,
shuffle=True, solver='lbfgs', tol=0.0001,
validation_fraction=0.1, verbose=False, warm_start=False)

After fitting (training), the model can predict labels for new samples::

@@ -139,12 +140,13 @@ indices where the value is `1` represents the assigned classes of that sample::
...
>>> clf.fit(X, y) # doctest: +NORMALIZE_WHITESPACE
MLPClassifier(activation='relu', alpha=1e-05, batch_size='auto',
beta_1=0.9, beta_2=0.999, early_stopping=False,
epsilon=1e-08, hidden_layer_sizes=(15,), learning_rate='constant',
learning_rate_init=0.001, max_iter=200, momentum=0.9,
nesterovs_momentum=True, power_t=0.5, random_state=1, shuffle=True,
solver='lbfgs', tol=0.0001, validation_fraction=0.1, verbose=False,
warm_start=False)
beta_1=0.9, beta_2=0.999, early_stopping=False,
epsilon=1e-08, hidden_layer_sizes=(15,),
learning_rate='constant', learning_rate_init=0.001,
max_iter=200, momentum=0.9, n_iter_no_change=10,
nesterovs_momentum=True, power_t=0.5, random_state=1,
shuffle=True, solver='lbfgs', tol=0.0001,
validation_fraction=0.1, verbose=False, warm_start=False)
>>> clf.predict([[1., 2.]])
array([[1, 1]])
>>> clf.predict([[0., 0.]])
11 changes: 11 additions & 0 deletions doc/whats_new/v0.20.rst
@@ -65,6 +65,13 @@ Classifiers and regressors
:class:`sklearn.naive_bayes.GaussianNB` to give a precise control over
variances calculation. :issue:`9681` by :user:`Dmitry Mottl <Mottl>`.

- Add `n_iter_no_change` parameter in
:class:`multilayer_perceptron.BaseMultilayerPerceptron`,
:class:`multilayer_perceptron.MLPRegressor`, and
:class:`multilayer_perceptron.MLPClassifier` to give control over
maximum number of epochs to not meet `tol` improvement.
:issue:`9456` by :user:`Nicholas Nadeau <nnadeau>`.

Review comment (Member): double backticks for tol

Review comment (Member): Should be neural_network everywhere instead of multilayer_perceptron
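
As a quick illustration of the new parameter (a sketch added here, not code from the PR): raising n_iter_no_change lets SGD ride out stretches where the loss does not improve by tol instead of quitting early. The dataset and the value 20 are illustrative assumptions.

# Illustrative sketch; assumes scikit-learn >= 0.20, where n_iter_no_change exists.
import numpy as np
from sklearn.neural_network import MLPRegressor

rng = np.random.RandomState(0)
X = rng.rand(200, 10)
y = X.sum(axis=1) + 0.1 * rng.randn(200)

# Allow up to 20 epochs without a tol-sized improvement before SGD stops.
reg = MLPRegressor(solver='sgd', tol=1e-4, n_iter_no_change=20,
                   max_iter=500, random_state=0)
reg.fit(X, y)
print(reg.n_iter_)  # epochs actually run before stopping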

Model evaluation and meta-estimators

- A scorer based on :func:`metrics.brier_score_loss` is also available.
@@ -91,6 +98,10 @@ Classifiers and regressors
identical X values.
:issue:`9432` by :user:`Dallas Card <dallascard>`

- Fixed a bug in :class:`multilayer_perceptron.MLPRegressor` where fitting
quit unexpectedly early due to local minima or fluctuations.
:issue:`9456` by :user:`Nicholas Nadeau <nnadeau>`

Decomposition, manifold learning and clustering

- Fix for uninformative error in :class:`decomposition.IncrementalPCA`:
59 changes: 40 additions & 19 deletions sklearn/neural_network/multilayer_perceptron.py
@@ -51,7 +51,8 @@ def __init__(self, hidden_layer_sizes, activation, solver,
alpha, batch_size, learning_rate, learning_rate_init, power_t,
max_iter, loss, shuffle, random_state, tol, verbose,
warm_start, momentum, nesterovs_momentum, early_stopping,
validation_fraction, beta_1, beta_2, epsilon):
validation_fraction, beta_1, beta_2, epsilon,
n_iter_no_change):
self.activation = activation
self.solver = solver
self.alpha = alpha
Expand All @@ -74,6 +75,7 @@ def __init__(self, hidden_layer_sizes, activation, solver,
self.beta_1 = beta_1
self.beta_2 = beta_2
self.epsilon = epsilon
self.n_iter_no_change = n_iter_no_change

def _unpack(self, packed_parameters):
"""Extract the coefficients and intercepts from packed_parameters."""
@@ -415,6 +417,9 @@ def _validate_hyperparameters(self):
self.beta_2)
if self.epsilon <= 0.0:
raise ValueError("epsilon must be > 0, got %s." % self.epsilon)
if self.n_iter_no_change <= 0:
raise ValueError("n_iter_no_change must be > 0, got %s."
% self.n_iter_no_change)

# raise ValueError if not registered
supported_activations = ('identity', 'logistic', 'tanh', 'relu')
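
A small sketch of what the new check means for callers, assuming scikit-learn 0.20 or later: a non-positive n_iter_no_change is rejected during hyperparameter validation when fit is called. The two-sample dataset is only for illustration.

# Sketch of the new validation behaviour (assumed environment: scikit-learn >= 0.20).
from sklearn.neural_network import MLPClassifier

X = [[0., 0.], [1., 1.]]
y = [0, 1]
try:
    MLPClassifier(n_iter_no_change=0).fit(X, y)
except ValueError as err:
    print(err)  # e.g. "n_iter_no_change must be > 0, got 0."
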
@@ -537,15 +542,17 @@ def _fit_stochastic(self, X, y, activations, deltas, coef_grads,
# for learning rate that needs to be updated at iteration end
self._optimizer.iteration_ends(self.t_)

if self._no_improvement_count > 2:
# not better than last two iterations by tol.
if self._no_improvement_count > self.n_iter_no_change:
# not better than last `n_iter_no_change` iterations by tol
# stop or decrease learning rate
if early_stopping:
msg = ("Validation score did not improve more than "
"tol=%f for two consecutive epochs." % self.tol)
"tol=%f for %d consecutive epochs." % (
self.tol, self.n_iter_no_change))
else:
msg = ("Training loss did not improve more than tol=%f"
" for two consecutive epochs." % self.tol)
" for %d consecutive epochs." % (
self.tol, self.n_iter_no_change))

is_stopping = self._optimizer.trigger_stopping(
msg, self.verbose)
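
For readers tracing the hunk above, the following standalone sketch paraphrases the counter-based rule it generalises; it is not the library's internal code. Training stops once the loss has failed to beat the best loss by more than tol for more than n_iter_no_change consecutive epochs.

# Standalone paraphrase of the no-improvement rule (not scikit-learn internals).
def should_stop(loss_curve, tol=1e-4, n_iter_no_change=10):
    best_loss = float('inf')
    no_improvement_count = 0
    for loss in loss_curve:
        if loss > best_loss - tol:
            no_improvement_count += 1   # not a tol-sized improvement
        else:
            no_improvement_count = 0    # a real improvement resets the counter
        if loss < best_loss:
            best_loss = loss
        if no_improvement_count > n_iter_no_change:
            return True                 # would stop here
    return False

# A short plateau stops training with a small limit but not with a larger one.
curve = [1.0, 0.5, 0.4999, 0.4998, 0.4997, 0.3]
print(should_stop(curve, tol=1e-3, n_iter_no_change=2))  # True
print(should_stop(curve, tol=1e-3, n_iter_no_change=5))  # False
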
@@ -780,9 +787,9 @@ class MLPClassifier(BaseMultilayerPerceptron, ClassifierMixin):

tol : float, optional, default 1e-4
Tolerance for the optimization. When the loss or score is not improving
by at least tol for two consecutive iterations, unless `learning_rate`
is set to 'adaptive', convergence is considered to be reached and
training stops.
by at least tol for `n_iter_no_change` consecutive iterations, unless
`learning_rate` is set to 'adaptive', convergence is considered to be
reached and training stops.

verbose : bool, optional, default False
Whether to print progress messages to stdout.
@@ -804,8 +811,8 @@ class MLPClassifier(BaseMultilayerPerceptron, ClassifierMixin):
Whether to use early stopping to terminate training when validation
score is not improving. If set to true, it will automatically set
aside 10% of training data as validation and terminate training when
validation score is not improving by at least tol for two consecutive
epochs.
validation score is not improving by at least tol for
`n_iter_no_change` consecutive epochs.
Only effective when solver='sgd' or 'adam'

validation_fraction : float, optional, default 0.1
@@ -824,6 +831,12 @@ class MLPClassifier(BaseMultilayerPerceptron, ClassifierMixin):
epsilon : float, optional, default 1e-8
Value for numerical stability in adam. Only used when solver='adam'

n_iter_no_change : int, optional, default 10
Maximum number of epochs to not meet `tol` improvement.
Only effective when solver='sgd' or 'adam'

.. versionadded:: 0.20
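
As an illustration (not part of the diff), the sketch below combines the parameter with early_stopping, where the same counter is driven by the held-out validation score rather than the training loss; dataset and values are assumptions.

# Illustrative sketch; assumes scikit-learn >= 0.20.
from sklearn.datasets import load_digits
from sklearn.neural_network import MLPClassifier

X, y = load_digits(return_X_y=True)

# 10% of the data is held out; training stops once the validation score has
# not improved by tol for n_iter_no_change consecutive epochs.
clf = MLPClassifier(solver='adam', early_stopping=True,
                    validation_fraction=0.1, tol=1e-4,
                    n_iter_no_change=15, max_iter=300, random_state=0)
clf.fit(X, y)
print(clf.n_iter_)  # epochs run before the validation score stalled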

Attributes
----------
classes_ : array or list of array of shape (n_classes,)
@@ -890,7 +903,7 @@ def __init__(self, hidden_layer_sizes=(100,), activation="relu",
verbose=False, warm_start=False, momentum=0.9,
nesterovs_momentum=True, early_stopping=False,
validation_fraction=0.1, beta_1=0.9, beta_2=0.999,
epsilon=1e-8):
epsilon=1e-8, n_iter_no_change=10):

sup = super(MLPClassifier, self)
sup.__init__(hidden_layer_sizes=hidden_layer_sizes,
@@ -903,7 +916,8 @@ def __init__(self, hidden_layer_sizes=(100,), activation="relu",
nesterovs_momentum=nesterovs_momentum,
early_stopping=early_stopping,
validation_fraction=validation_fraction,
beta_1=beta_1, beta_2=beta_2, epsilon=epsilon)
beta_1=beta_1, beta_2=beta_2, epsilon=epsilon,
n_iter_no_change=n_iter_no_change)

def _validate_input(self, X, y, incremental):
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'],
@@ -1157,9 +1171,9 @@ class MLPRegressor(BaseMultilayerPerceptron, RegressorMixin):

tol : float, optional, default 1e-4
Tolerance for the optimization. When the loss or score is not improving
by at least tol for two consecutive iterations, unless `learning_rate`
is set to 'adaptive', convergence is considered to be reached and
training stops.
by at least tol for `n_iter_no_change` consecutive iterations, unless
`learning_rate` is set to 'adaptive', convergence is considered
to be reached and training stops.

Review comment (Member): double backticks or no backticks?

verbose : bool, optional, default False
Whether to print progress messages to stdout.
Expand All @@ -1181,8 +1195,8 @@ class MLPRegressor(BaseMultilayerPerceptron, RegressorMixin):
Whether to use early stopping to terminate training when validation
score is not improving. If set to true, it will automatically set
aside 10% of training data as validation and terminate training when
validation score is not improving by at least tol for two consecutive
epochs.
validation score is not improving by at least tol for
`n_iter_no_change` consecutive epochs.
Only effective when solver='sgd' or 'adam'

validation_fraction : float, optional, default 0.1
Expand All @@ -1201,6 +1215,12 @@ class MLPRegressor(BaseMultilayerPerceptron, RegressorMixin):
epsilon : float, optional, default 1e-8
Value for numerical stability in adam. Only used when solver='adam'

n_iter_no_change : int, optional, default 10
Maximum number of epochs to not meet `tol` improvement.
Only effective when solver='sgd' or 'adam'

.. versionadded:: 0.20

Review comment (Member): double backticks

Attributes
----------
loss_ : float
@@ -1265,7 +1285,7 @@ def __init__(self, hidden_layer_sizes=(100,), activation="relu",
verbose=False, warm_start=False, momentum=0.9,
nesterovs_momentum=True, early_stopping=False,
validation_fraction=0.1, beta_1=0.9, beta_2=0.999,
epsilon=1e-8):
epsilon=1e-8, n_iter_no_change=10):

sup = super(MLPRegressor, self)
sup.__init__(hidden_layer_sizes=hidden_layer_sizes,
@@ -1278,7 +1298,8 @@ def __init__(self, hidden_layer_sizes=(100,), activation="relu",
nesterovs_momentum=nesterovs_momentum,
early_stopping=early_stopping,
validation_fraction=validation_fraction,
beta_1=beta_1, beta_2=beta_2, epsilon=epsilon)
beta_1=beta_1, beta_2=beta_2, epsilon=epsilon,
n_iter_no_change=n_iter_no_change)

def predict(self, X):
"""Predict using the multi-layer perceptron model.
45 changes: 45 additions & 0 deletions sklearn/neural_network/tests/test_mlp.py
@@ -420,6 +420,7 @@ def test_params_errors():
assert_raises(ValueError, clf(beta_2=1).fit, X, y)
assert_raises(ValueError, clf(beta_2=-0.5).fit, X, y)
assert_raises(ValueError, clf(epsilon=-0.5).fit, X, y)
assert_raises(ValueError, clf(n_iter_no_change=-1).fit, X, y)

assert_raises(ValueError, clf(solver='hadoken').fit, X, y)
assert_raises(ValueError, clf(learning_rate='converge').fit, X, y)
@@ -588,3 +589,47 @@ def test_warm_start():
'classes as in the previous call to fit.'
' Previously got [0 1 2], `y` has %s' % np.unique(y_i))
assert_raise_message(ValueError, message, clf.fit, X, y_i)


def test_n_iter_no_change():
# test n_iter_no_change using binary data set
# the classifying fitting process is not prone to loss curve fluctuations
X = X_digits_binary[:100]
y = y_digits_binary[:100]
tol = 0.01
max_iter = 3000

# test multiple n_iter_no_change
for n_iter_no_change in [2, 5, 10, 50, 100]:
clf = MLPClassifier(tol=tol, max_iter=max_iter, solver='sgd',
n_iter_no_change=n_iter_no_change)
clf.fit(X, y)

# validate n_iter_no_change
assert_equal(clf._no_improvement_count, n_iter_no_change + 1)
assert_greater(max_iter, clf.n_iter_)


@ignore_warnings(category=ConvergenceWarning)
def test_n_iter_no_change_inf():
Review comment (Member): please add @ignore_warnings(category=ConvergenceWarning) to ignore the warning as max_iter is reached

# test n_iter_no_change using binary data set
# the fitting process should go to max_iter iterations
X = X_digits_binary[:100]
y = y_digits_binary[:100]

# set a ridiculous tolerance
# this should always trigger _update_no_improvement_count()
tol = 1e9

# fit
n_iter_no_change = np.inf
max_iter = 3000
clf = MLPClassifier(tol=tol, max_iter=max_iter, solver='sgd',
n_iter_no_change=n_iter_no_change)
clf.fit(X, y)

# validate n_iter_no_change doesn't cause early stopping
assert_equal(clf.n_iter_, max_iter)

# validate _update_no_improvement_count() was always triggered
assert_equal(clf._no_improvement_count, clf.n_iter_ - 1)
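
One note on why np.inf works in the test above (an observation, not from the PR discussion): the stop condition is a plain greater-than comparison, so _no_improvement_count can never exceed infinity and training always runs to max_iter.

import numpy as np

# The comparison used by the stop condition is never true against np.inf.
print(10 ** 9 > np.inf)  # False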