From 9fc00067e944e9d70d503ed97bd1123c52f8894c Mon Sep 17 00:00:00 2001 From: Srimukh Sripada Date: Sat, 4 Jul 2020 08:55:38 +0000 Subject: [PATCH] ENH 32-bit support for MLPClassifier and MLPRegressor (#17759) --- doc/whats_new/v0.24.rst | 10 +++- .../neural_network/_multilayer_perceptron.py | 28 ++++++--- sklearn/neural_network/tests/test_mlp.py | 58 +++++++++++++++++++ 3 files changed, 84 insertions(+), 12 deletions(-) diff --git a/doc/whats_new/v0.24.rst b/doc/whats_new/v0.24.rst index 096d05b91db1d..c41f761de1018 100644 --- a/doc/whats_new/v0.24.rst +++ b/doc/whats_new/v0.24.rst @@ -164,8 +164,8 @@ Changelog class to be used when computing the precision and recall statistics. :pr:`17569` by :user:`Guillaume Lemaitre `. -- |Feature| :func:`metrics.plot_confusion_matrix` now supports making colorbar - optional in the matplotlib plot by setting colorbar=False. :pr:`17192` by +- |Feature| :func:`metrics.plot_confusion_matrix` now supports making colorbar + optional in the matplotlib plot by setting colorbar=False. :pr:`17192` by :user:`Avi Gupta ` :mod:`sklearn.model_selection` @@ -219,7 +219,11 @@ Changelog - |Enhancement| Avoid converting float32 input to float64 in :class:`neural_network.BernoulliRBM`. - :pr:`16352` by :user:`Arthur Imbert `. + :pr:`16352` by :user:`Arthur Imbert `. + +- |Enhancement| Support 32-bit computations in :class:`neural_network.MLPClassifier` + and :class:`neural_network.MLPRegressor`. + :pr:`17759` by :user:`Srimukh Sripada `. :mod:`sklearn.preprocessing` ............................ diff --git a/sklearn/neural_network/_multilayer_perceptron.py b/sklearn/neural_network/_multilayer_perceptron.py index b1045d450508f..07be594f1e6e0 100644 --- a/sklearn/neural_network/_multilayer_perceptron.py +++ b/sklearn/neural_network/_multilayer_perceptron.py @@ -289,7 +289,7 @@ def _backprop(self, X, y, activations, deltas, coef_grads, return loss, coef_grads, intercept_grads - def _initialize(self, y, layer_units): + def _initialize(self, y, layer_units, dtype): # set all attributes, allocate weights etc for first call # Initialize parameters self.n_iter_ = 0 @@ -315,7 +315,8 @@ def _initialize(self, y, layer_units): for i in range(self.n_layers_ - 1): coef_init, intercept_init = self._init_coef(layer_units[i], - layer_units[i + 1]) + layer_units[i + 1], + dtype) self.coefs_.append(coef_init) self.intercepts_.append(intercept_init) @@ -328,7 +329,7 @@ def _initialize(self, y, layer_units): else: self.best_loss_ = np.inf - def _init_coef(self, fan_in, fan_out): + def _init_coef(self, fan_in, fan_out, dtype): # Use the initialization method recommended by # Glorot et al. factor = 6. 
@@ -341,6 +342,8 @@ def _init_coef(self, fan_in, fan_out): (fan_in, fan_out)) intercept_init = self._random_state.uniform(-init_bound, init_bound, fan_out) + coef_init = coef_init.astype(dtype, copy=False) + intercept_init = intercept_init.astype(dtype, copy=False) return coef_init, intercept_init def _fit(self, X, y, incremental=False): @@ -357,6 +360,7 @@ def _fit(self, X, y, incremental=False): hidden_layer_sizes) X, y = self._validate_input(X, y, incremental) + n_samples, n_features = X.shape # Ensure y is 2D @@ -374,17 +378,19 @@ def _fit(self, X, y, incremental=False): if not hasattr(self, 'coefs_') or (not self.warm_start and not incremental): # First time training the model - self._initialize(y, layer_units) + self._initialize(y, layer_units, X.dtype) # Initialize lists activations = [X] + [None] * (len(layer_units) - 1) deltas = [None] * (len(activations) - 1) - coef_grads = [np.empty((n_fan_in_, n_fan_out_)) for n_fan_in_, + coef_grads = [np.empty((n_fan_in_, n_fan_out_), dtype=X.dtype) + for n_fan_in_, n_fan_out_ in zip(layer_units[:-1], layer_units[1:])] - intercept_grads = [np.empty(n_fan_out_) for n_fan_out_ in + intercept_grads = [np.empty(n_fan_out_, dtype=X.dtype) + for n_fan_out_ in layer_units[1:]] # Run the Stochastic optimization solver @@ -960,7 +966,8 @@ def __init__(self, hidden_layer_sizes=(100,), activation="relu", *, def _validate_input(self, X, y, incremental): X, y = self._validate_data(X, y, accept_sparse=['csr', 'csc'], - multi_output=True) + multi_output=True, + dtype=(np.float64, np.float32)) if y.ndim == 2 and y.shape[1] == 1: y = column_or_1d(y, warn=True) @@ -982,7 +989,9 @@ def _validate_input(self, X, y, incremental): " `self.classes_` has %s. 'y' has %s." % (self.classes_, classes)) - y = self._label_binarizer.transform(y) + # This downcast to bool is to prevent upcasting when working with + # float32 data + y = self._label_binarizer.transform(y).astype(np.bool) return X, y def predict(self, X): @@ -1393,7 +1402,8 @@ def predict(self, X): def _validate_input(self, X, y, incremental): X, y = self._validate_data(X, y, accept_sparse=['csr', 'csc'], - multi_output=True, y_numeric=True) + multi_output=True, y_numeric=True, + dtype=(np.float64, np.float32)) if y.ndim == 2 and y.shape[1] == 1: y = column_or_1d(y, warn=True) return X, y diff --git a/sklearn/neural_network/tests/test_mlp.py b/sklearn/neural_network/tests/test_mlp.py index 85b0840445a5a..8a1c5f2a5d232 100644 --- a/sklearn/neural_network/tests/test_mlp.py +++ b/sklearn/neural_network/tests/test_mlp.py @@ -716,3 +716,61 @@ def test_early_stopping_stratified(): ValueError, match='The least populated class in y has only 1 member'): mlp.fit(X, y) + + +def test_mlp_classifier_dtypes_casting(): + # Compare predictions for different dtypes + mlp_64 = MLPClassifier(alpha=1e-5, + hidden_layer_sizes=(5, 3), + random_state=1, max_iter=50) + mlp_64.fit(X_digits[:300], y_digits[:300]) + pred_64 = mlp_64.predict(X_digits[300:]) + proba_64 = mlp_64.predict_proba(X_digits[300:]) + + mlp_32 = MLPClassifier(alpha=1e-5, + hidden_layer_sizes=(5, 3), + random_state=1, max_iter=50) + mlp_32.fit(X_digits[:300].astype(np.float32), y_digits[:300]) + pred_32 = mlp_32.predict(X_digits[300:].astype(np.float32)) + proba_32 = mlp_32.predict_proba(X_digits[300:].astype(np.float32)) + + assert_array_equal(pred_64, pred_32) + assert_allclose(proba_64, proba_32, rtol=1e-02) + + +def test_mlp_regressor_dtypes_casting(): + mlp_64 = MLPRegressor(alpha=1e-5, + hidden_layer_sizes=(5, 3), + random_state=1, max_iter=50) + 
mlp_64.fit(X_digits[:300], y_digits[:300]) + pred_64 = mlp_64.predict(X_digits[300:]) + + mlp_32 = MLPRegressor(alpha=1e-5, + hidden_layer_sizes=(5, 3), + random_state=1, max_iter=50) + mlp_32.fit(X_digits[:300].astype(np.float32), y_digits[:300]) + pred_32 = mlp_32.predict(X_digits[300:].astype(np.float32)) + + assert_allclose(pred_64, pred_32, rtol=1e-04) + + +@pytest.mark.parametrize('dtype', [np.float32, np.float64]) +@pytest.mark.parametrize('Estimator', [MLPClassifier, MLPRegressor]) +def test_mlp_param_dtypes(dtype, Estimator): + # Checks if input dtype is used for network parameters + # and predictions + X, y = X_digits.astype(dtype), y_digits + mlp = Estimator(alpha=1e-5, + hidden_layer_sizes=(5, 3), + random_state=1, max_iter=50) + mlp.fit(X[:300], y[:300]) + pred = mlp.predict(X[300:]) + + assert all([intercept.dtype == dtype + for intercept in mlp.intercepts_]) + + assert all([coef.dtype == dtype + for coef in mlp.coefs_]) + + if Estimator == MLPRegressor: + assert pred.dtype == dtype
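
The tests added above already show the user-visible effect of this patch: when the training data is float32, the fitted coefficients, intercepts, and (for MLPRegressor) the predictions stay in float32 instead of being silently upcast to float64. A minimal standalone sketch of that behaviour, assuming a scikit-learn build that includes this change, could look like this:

import numpy as np
from sklearn.datasets import load_digits
from sklearn.neural_network import MLPClassifier, MLPRegressor

X, y = load_digits(return_X_y=True)
X32 = X.astype(np.float32)

# Classifier: network parameters adopt the dtype of the training data.
clf = MLPClassifier(hidden_layer_sizes=(5, 3), random_state=1, max_iter=50)
clf.fit(X32[:300], y[:300])
assert all(coef.dtype == np.float32 for coef in clf.coefs_)
assert all(intercept.dtype == np.float32 for intercept in clf.intercepts_)

# Regressor: predictions come back in float32 as well.
reg = MLPRegressor(hidden_layer_sizes=(5, 3), random_state=1, max_iter=50)
reg.fit(X32[:300], y[:300])
assert reg.predict(X32[300:]).dtype == np.float32

Running the same code on float64 input keeps everything in float64, matching the behaviour before the patch.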
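
The mechanism is the dtype=(np.float64, np.float32) argument passed to _validate_data in both _validate_input implementations: if the incoming array already uses one of the listed dtypes it is kept, otherwise it is converted to the first entry (float64). _validate_data shares these semantics with the public check_array, so the rule can be illustrated with check_array directly; this is only a sketch of the conversion rule, not code from the patch:

import numpy as np
from sklearn.utils import check_array

accepted = (np.float64, np.float32)

# dtypes already in the accepted tuple are preserved ...
assert check_array(np.ones((3, 2), dtype=np.float64), dtype=accepted).dtype == np.float64
assert check_array(np.ones((3, 2), dtype=np.float32), dtype=accepted).dtype == np.float32

# ... anything else is converted to the first entry, float64.
assert check_array(np.ones((3, 2), dtype=np.int64), dtype=accepted).dtype == np.float64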
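
The .astype(np.bool) added to the classifier's _validate_input deserves a word: LabelBinarizer.transform returns integer targets, and NumPy promotes a mixed float32/int64 operation to float64, which would quietly pull the backward pass back to 64-bit. Downcasting the binarized targets to bool keeps mixed operations in float32. A small sketch of the underlying promotion rule (an illustration of NumPy casting, not code from the patch):

import numpy as np

activations = np.ones(3, dtype=np.float32)    # float32 network output
y_int = np.array([0, 1, 0], dtype=np.int64)   # LabelBinarizer-style integer targets
y_bool = y_int.astype(bool)

assert (activations - y_int).dtype == np.float64   # int64 forces an upcast
assert (activations - y_bool).dtype == np.float32  # bool keeps the result in float32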