Skip to content

Commit

Permalink
[MRG+1] Catch cases for different class size in MLPClassifier with warm start (scikit-learn#7976) (scikit-learn#8035)
Browse files Browse the repository at this point in the history

* added test that fails

* generate standard value error for different class size

* moved num_classes one class down

* fixed over-indented lines

* standard error occurs a layer up.

* created a different label comparison for warm_start

* spaces around multiplication sign.

* reworded error and added another edge case.

* fixed pep8 violation

* make test shorter

* updated ignore warning
  • Loading branch information
vincentpham1991 authored and raghavrv committed Jan 5, 2017
1 parent 268ea1a commit c76e8dd
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 1 deletion.
26 changes: 26 additions & 0 deletions sklearn/neural_network/multilayer_perceptron.py
Expand Up @@ -908,6 +908,13 @@ def _validate_input(self, X, y, incremental):
self._label_binarizer = LabelBinarizer()
self._label_binarizer.fit(y)
self.classes_ = self._label_binarizer.classes_
elif self.warm_start:
classes = unique_labels(y)
if set(classes) != set(self.classes_):
raise ValueError("warm_start can only be used where `y` has "
"the same classes as in the previous "
"call to fit. Previously got %s, `y` has %s" %
(self.classes_, classes))
else:
classes = unique_labels(y)
if np.setdiff1d(classes, self.classes_, assume_unique=True):
Expand Down Expand Up @@ -939,6 +946,25 @@ def predict(self, X):

return self._label_binarizer.inverse_transform(y_pred)

def fit(self, X, y):
    """Train the model on data matrix X with target(s) y.

    Parameters
    ----------
    X : array-like or sparse matrix, shape (n_samples, n_features)
        The input data.

    y : array-like, shape (n_samples,) or (n_samples, n_outputs)
        The target values (class labels in classification, real numbers
        in regression).

    Returns
    -------
    self : returns a trained MLP model.
    """
    # Only continue incrementally when warm_start is requested AND a
    # ``classes_`` attribute is already present on the estimator;
    # otherwise ``_fit`` is asked for a non-incremental fit.
    resume_training = self.warm_start and hasattr(self, "classes_")
    return self._fit(X, y, incremental=resume_training)

@property
def partial_fit(self):
"""Fit the model to data matrix X and target y.
Expand Down
34 changes: 33 additions & 1 deletion sklearn/neural_network/tests/test_mlp.py
Expand Up @@ -12,7 +12,7 @@

from numpy.testing import assert_almost_equal, assert_array_equal

from sklearn.datasets import load_digits, load_boston
from sklearn.datasets import load_digits, load_boston, load_iris
from sklearn.datasets import make_regression, make_multilabel_classification
from sklearn.exceptions import ConvergenceWarning
from sklearn.externals.six.moves import cStringIO as StringIO
Expand All @@ -24,6 +24,7 @@
from scipy.sparse import csr_matrix
from sklearn.utils.testing import (assert_raises, assert_greater, assert_equal,
assert_false, ignore_warnings)
from sklearn.utils.testing import assert_raise_message


np.seterr(all='warn')
Expand All @@ -49,6 +50,11 @@
Xboston = StandardScaler().fit_transform(boston.data)[: 200]
yboston = boston.target[:200]

# Iris data set; X_iris / y_iris are the fixtures used by test_warm_start.
iris = load_iris()
X_iris, y_iris = iris.data, iris.target


def test_alpha():
# Test that larger alpha yields weights closer to zero
Expand Down Expand Up @@ -556,3 +562,29 @@ def test_adaptive_learning_rate():
clf.fit(X, y)
assert_greater(clf.max_iter, clf.n_iter_)
assert_greater(1e-6, clf._optimizer.learning_rate)


# NOTE: the warning category must be passed by keyword. A positional
# argument is taken by ``ignore_warnings`` as the object to wrap, so
# ``@ignore_warnings(RuntimeError)`` would replace this test with a
# non-callable object and the test would silently never run.
# RuntimeWarning is the category numpy emits for floating-point issues
# under ``np.seterr(all='warn')``.
@ignore_warnings(category=RuntimeWarning)
def test_warm_start():
    """Check class-consistency validation of MLPClassifier with warm_start.

    Refitting a warm-started classifier must succeed when ``y`` holds the
    same set of classes as the previous ``fit`` call, and must raise a
    ``ValueError`` with an explicit message when the set of classes differs.
    """
    X = X_iris
    y = y_iris

    # Same labels as ``y`` (0, 1, 2) but a different class distribution.
    y_3classes = np.array([0] * 40 + [1] * 40 + [2] * 70)

    # Label sets that differ from ``y``: fewer classes, same number of
    # classes with one different label, and more classes.
    y_2classes = np.array([0] * 75 + [1] * 75)
    y_3classes_alt = np.array([0] * 50 + [1] * 50 + [3] * 50)
    y_4classes = np.array([0] * 37 + [1] * 37 + [2] * 38 + [3] * 38)
    y_5classes = np.array(
        [0] * 30 + [1] * 30 + [2] * 30 + [3] * 30 + [4] * 30)

    # No error raised
    clf = MLPClassifier(hidden_layer_sizes=2, solver='lbfgs',
                        warm_start=True).fit(X, y)
    clf.fit(X, y)
    clf.fit(X, y_3classes)

    for y_i in (y_2classes, y_3classes_alt, y_4classes, y_5classes):
        # Start from a freshly fitted estimator so every case is compared
        # against the original classes [0 1 2].
        clf = MLPClassifier(hidden_layer_sizes=2, solver='lbfgs',
                            warm_start=True).fit(X, y)
        message = ('warm_start can only be used where `y` has the same '
                   'classes as in the previous call to fit.'
                   ' Previously got [0 1 2], `y` has %s' % np.unique(y_i))
        assert_raise_message(ValueError, message, clf.fit, X, y_i)

0 comments on commit c76e8dd

Please sign in to comment.