[MRG+1] MLPRegressor quits fitting too soon due to self._no_improvement_count #9457
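Summary, as inferred from the diff below: `MLPClassifier` and `MLPRegressor` previously aborted stochastic fitting once the loss (or validation score) failed to improve by `tol` for more than two consecutive epochs; this PR makes that window configurable via a new `n_iter_no_change` parameter (default 10, `.. versionadded:: 0.20`). A minimal usage sketch, assuming scikit-learn 0.20+ where this parameter landed; the toy data is illustrative, not from the PR:

```python
import numpy as np
from sklearn.neural_network import MLPRegressor

rng = np.random.RandomState(0)
X = rng.rand(200, 5)
y = X.sum(axis=1) + 0.1 * rng.randn(200)

# A wider stagnation window lets SGD ride out flat stretches of the
# loss curve instead of giving up after two stagnant epochs.
reg = MLPRegressor(solver='sgd', max_iter=3000, tol=1e-4,
                   n_iter_no_change=50, random_state=0)
reg.fit(X, y)
print(reg.n_iter_)  # epochs actually run before stopping
```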
sklearn/neural_network/multilayer_perceptron.py

```diff
@@ -51,7 +51,8 @@ def __init__(self, hidden_layer_sizes, activation, solver,
                  alpha, batch_size, learning_rate, learning_rate_init, power_t,
                  max_iter, loss, shuffle, random_state, tol, verbose,
                  warm_start, momentum, nesterovs_momentum, early_stopping,
-                 validation_fraction, beta_1, beta_2, epsilon):
+                 validation_fraction, beta_1, beta_2, epsilon,
+                 n_iter_no_change):
         self.activation = activation
         self.solver = solver
         self.alpha = alpha
```
```diff
@@ -74,6 +75,7 @@ def __init__(self, hidden_layer_sizes, activation, solver,
         self.beta_1 = beta_1
         self.beta_2 = beta_2
         self.epsilon = epsilon
+        self.n_iter_no_change = n_iter_no_change

     def _unpack(self, packed_parameters):
         """Extract the coefficients and intercepts from packed_parameters."""
```
```diff
@@ -415,6 +417,9 @@ def _validate_hyperparameters(self):
                              self.beta_2)
         if self.epsilon <= 0.0:
             raise ValueError("epsilon must be > 0, got %s." % self.epsilon)
+        if self.n_iter_no_change <= 0:
+            raise ValueError("n_iter_no_change must be > 0, got %s."
+                             % self.n_iter_no_change)

         # raise ValueError if not registered
         supported_activations = ('identity', 'logistic', 'tanh', 'relu')
```
```diff
@@ -537,15 +542,17 @@ def _fit_stochastic(self, X, y, activations, deltas, coef_grads,
                 # for learning rate that needs to be updated at iteration end
                 self._optimizer.iteration_ends(self.t_)

-                if self._no_improvement_count > 2:
-                    # not better than last two iterations by tol.
+                if self._no_improvement_count > self.n_iter_no_change:
+                    # not better than last `n_iter_no_change` iterations by tol
                     # stop or decrease learning rate
                     if early_stopping:
                         msg = ("Validation score did not improve more than "
-                               "tol=%f for two consecutive epochs." % self.tol)
+                               "tol=%f for %d consecutive epochs." % (
+                                   self.tol, self.n_iter_no_change))
                     else:
                         msg = ("Training loss did not improve more than tol=%f"
-                               " for two consecutive epochs." % self.tol)
+                               " for %d consecutive epochs." % (
+                                   self.tol, self.n_iter_no_change))

                     is_stopping = self._optimizer.trigger_stopping(
                         msg, self.verbose)
```
```diff
@@ -780,9 +787,9 @@ class MLPClassifier(BaseMultilayerPerceptron, ClassifierMixin):

     tol : float, optional, default 1e-4
         Tolerance for the optimization. When the loss or score is not improving
-        by at least tol for two consecutive iterations, unless `learning_rate`
-        is set to 'adaptive', convergence is considered to be reached and
-        training stops.
+        by at least tol for `n_iter_no_change` consecutive iterations, unless
+        `learning_rate` is set to 'adaptive', convergence is considered to be
+        reached and training stops.

     verbose : bool, optional, default False
         Whether to print progress messages to stdout.
```
```diff
@@ -804,8 +811,8 @@ class MLPClassifier(BaseMultilayerPerceptron, ClassifierMixin):
         Whether to use early stopping to terminate training when validation
         score is not improving. If set to true, it will automatically set
         aside 10% of training data as validation and terminate training when
-        validation score is not improving by at least tol for two consecutive
-        epochs.
+        validation score is not improving by at least tol for
+        `n_iter_no_change` consecutive epochs.
         Only effective when solver='sgd' or 'adam'

     validation_fraction : float, optional, default 0.1
```
```diff
@@ -824,6 +831,12 @@ class MLPClassifier(BaseMultilayerPerceptron, ClassifierMixin):
     epsilon : float, optional, default 1e-8
         Value for numerical stability in adam. Only used when solver='adam'

+    n_iter_no_change : int, optional, default 10
+        Maximum number of epochs to not meet `tol` improvement.
+        Only effective when solver='sgd' or 'adam'
+
+        .. versionadded:: 0.20
+
     Attributes
     ----------
     classes_ : array or list of array of shape (n_classes,)
```
```diff
@@ -890,7 +903,7 @@ def __init__(self, hidden_layer_sizes=(100,), activation="relu",
                  verbose=False, warm_start=False, momentum=0.9,
                  nesterovs_momentum=True, early_stopping=False,
                  validation_fraction=0.1, beta_1=0.9, beta_2=0.999,
-                 epsilon=1e-8):
+                 epsilon=1e-8, n_iter_no_change=10):

         sup = super(MLPClassifier, self)
         sup.__init__(hidden_layer_sizes=hidden_layer_sizes,
```
```diff
@@ -903,7 +916,8 @@ def __init__(self, hidden_layer_sizes=(100,), activation="relu",
                      nesterovs_momentum=nesterovs_momentum,
                      early_stopping=early_stopping,
                      validation_fraction=validation_fraction,
-                     beta_1=beta_1, beta_2=beta_2, epsilon=epsilon)
+                     beta_1=beta_1, beta_2=beta_2, epsilon=epsilon,
+                     n_iter_no_change=n_iter_no_change)

     def _validate_input(self, X, y, incremental):
         X, y = check_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'],
```
```diff
@@ -1157,9 +1171,9 @@ class MLPRegressor(BaseMultilayerPerceptron, RegressorMixin):

     tol : float, optional, default 1e-4
         Tolerance for the optimization. When the loss or score is not improving
-        by at least tol for two consecutive iterations, unless `learning_rate`
-        is set to 'adaptive', convergence is considered to be reached and
-        training stops.
+        by at least tol for `n_iter_no_change` consecutive iterations, unless
+        `learning_rate` is set to 'adaptive', convergence is considered
+        to be reached and training stops.

     verbose : bool, optional, default False
         Whether to print progress messages to stdout.
```

Review comment (on the `n_iter_no_change` docstring line): double backticks or no backticks?
```diff
@@ -1181,8 +1195,8 @@ class MLPRegressor(BaseMultilayerPerceptron, RegressorMixin):
         Whether to use early stopping to terminate training when validation
         score is not improving. If set to true, it will automatically set
         aside 10% of training data as validation and terminate training when
-        validation score is not improving by at least tol for two consecutive
-        epochs.
+        validation score is not improving by at least tol for
+        `n_iter_no_change` consecutive epochs.
         Only effective when solver='sgd' or 'adam'

     validation_fraction : float, optional, default 0.1
```
```diff
@@ -1201,6 +1215,12 @@ class MLPRegressor(BaseMultilayerPerceptron, RegressorMixin):
     epsilon : float, optional, default 1e-8
         Value for numerical stability in adam. Only used when solver='adam'

+    n_iter_no_change : int, optional, default 10
+        Maximum number of epochs to not meet `tol` improvement.
+        Only effective when solver='sgd' or 'adam'
+
+        .. versionadded:: 0.20
+
     Attributes
     ----------
     loss_ : float
```

Review comment (on the `tol` reference in the docstring): double backticks
```diff
@@ -1265,7 +1285,7 @@ def __init__(self, hidden_layer_sizes=(100,), activation="relu",
                  verbose=False, warm_start=False, momentum=0.9,
                  nesterovs_momentum=True, early_stopping=False,
                  validation_fraction=0.1, beta_1=0.9, beta_2=0.999,
-                 epsilon=1e-8):
+                 epsilon=1e-8, n_iter_no_change=10):

         sup = super(MLPRegressor, self)
         sup.__init__(hidden_layer_sizes=hidden_layer_sizes,
```
```diff
@@ -1278,7 +1298,8 @@ def __init__(self, hidden_layer_sizes=(100,), activation="relu",
                      nesterovs_momentum=nesterovs_momentum,
                      early_stopping=early_stopping,
                      validation_fraction=validation_fraction,
-                     beta_1=beta_1, beta_2=beta_2, epsilon=epsilon)
+                     beta_1=beta_1, beta_2=beta_2, epsilon=epsilon,
+                     n_iter_no_change=n_iter_no_change)

     def predict(self, X):
         """Predict using the multi-layer perceptron model.
```
sklearn/neural_network/tests/test_mlp.py
```diff
@@ -420,6 +420,7 @@ def test_params_errors():
     assert_raises(ValueError, clf(beta_2=1).fit, X, y)
     assert_raises(ValueError, clf(beta_2=-0.5).fit, X, y)
     assert_raises(ValueError, clf(epsilon=-0.5).fit, X, y)
+    assert_raises(ValueError, clf(n_iter_no_change=-1).fit, X, y)

     assert_raises(ValueError, clf(solver='hadoken').fit, X, y)
     assert_raises(ValueError, clf(learning_rate='converge').fit, X, y)
```
```diff
@@ -588,3 +589,47 @@ def test_warm_start():
                    'classes as in the previous call to fit.'
                    ' Previously got [0 1 2], `y` has %s' % np.unique(y_i))
         assert_raise_message(ValueError, message, clf.fit, X, y_i)
+
+
+def test_n_iter_no_change():
+    # test n_iter_no_change using binary data set
+    # the classifying fitting process is not prone to loss curve fluctuations
+    X = X_digits_binary[:100]
+    y = y_digits_binary[:100]
+    tol = 0.01
+    max_iter = 3000
+
+    # test multiple n_iter_no_change
+    for n_iter_no_change in [2, 5, 10, 50, 100]:
+        clf = MLPClassifier(tol=tol, max_iter=max_iter, solver='sgd',
+                            n_iter_no_change=n_iter_no_change)
+        clf.fit(X, y)
+
+        # validate n_iter_no_change
+        assert_equal(clf._no_improvement_count, n_iter_no_change + 1)
+        assert_greater(max_iter, clf.n_iter_)
+
+
+@ignore_warnings(category=ConvergenceWarning)
+def test_n_iter_no_change_inf():
+    # test n_iter_no_change using binary data set
+    # the fitting process should go to max_iter iterations
+    X = X_digits_binary[:100]
+    y = y_digits_binary[:100]
+
+    # set a ridiculous tolerance
+    # this should always trigger _update_no_improvement_count()
+    tol = 1e9
+
+    # fit
+    n_iter_no_change = np.inf
+    max_iter = 3000
+    clf = MLPClassifier(tol=tol, max_iter=max_iter, solver='sgd',
+                        n_iter_no_change=n_iter_no_change)
+    clf.fit(X, y)
+
+    # validate n_iter_no_change doesn't cause early stopping
+    assert_equal(clf.n_iter_, max_iter)
+
+    # validate _update_no_improvement_count() was always triggered
+    assert_equal(clf._no_improvement_count, clf.n_iter_ - 1)
```

Review comment (on `test_n_iter_no_change_inf`): please add
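One subtlety worth noting about the second test: `np.inf` passes the new `n_iter_no_change <= 0` validation (infinity is greater than zero), and `self._no_improvement_count > np.inf` can never be true, so the window check never fires and the solver runs for the full `max_iter` epochs — exactly what `test_n_iter_no_change_inf` asserts.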
Review comment: double backticks for `tol`
Review comment: Should be `neural_network` everywhere instead of `multilayer_perceptron`