
ENH 32-bit support for MLPClassifier and MLPRegressor (#17759)
postmalloc committed Jul 4, 2020
1 parent ffbb1b4 commit 9fc0006
Showing 3 changed files with 84 additions and 12 deletions.
10 changes: 7 additions & 3 deletions doc/whats_new/v0.24.rst
@@ -164,8 +164,8 @@ Changelog
class to be used when computing the precision and recall statistics.
:pr:`17569` by :user:`Guillaume Lemaitre <glemaitre>`.

- |Feature| :func:`metrics.plot_confusion_matrix` now supports making colorbar
optional in the matplotlib plot by setting colorbar=False. :pr:`17192` by
:user:`Avi Gupta <avigupta2612>`

:mod:`sklearn.model_selection`
@@ -219,7 +219,11 @@ Changelog

- |Enhancement| Avoid converting float32 input to float64 in
:class:`neural_network.BernoulliRBM`.
:pr:`16352` by :user:`Arthur Imbert <Henley13>`.

- |Enhancement| Support 32-bit computations in :class:`neural_network.MLPClassifier`
and :class:`neural_network.MLPRegressor`.
:pr:`17759` by :user:`Srimukh Sripada <d3b0unce>`.

:mod:`sklearn.preprocessing`
............................
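The new changelog entry above summarises the user-visible effect of this commit: float32 input is no longer upcast to float64 inside the MLP estimators. A minimal usage sketch of that behaviour (the dataset and variable names here are illustrative, not part of the commit):

import numpy as np
from sklearn.datasets import load_digits
from sklearn.neural_network import MLPClassifier

# Fit on float32 data; with this change the learned parameters keep the
# input dtype instead of being upcast to float64.
X, y = load_digits(return_X_y=True)
clf = MLPClassifier(hidden_layer_sizes=(5, 3), random_state=1, max_iter=50)
clf.fit(X.astype(np.float32), y)

print(clf.coefs_[0].dtype)       # float32
print(clf.intercepts_[0].dtype)  # float32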
28 changes: 19 additions & 9 deletions sklearn/neural_network/_multilayer_perceptron.py
@@ -289,7 +289,7 @@ def _backprop(self, X, y, activations, deltas, coef_grads,

return loss, coef_grads, intercept_grads

def _initialize(self, y, layer_units):
def _initialize(self, y, layer_units, dtype):
# set all attributes, allocate weights etc for first call
# Initialize parameters
self.n_iter_ = 0
@@ -315,7 +315,8 @@ def _initialize(self, y, layer_units):

for i in range(self.n_layers_ - 1):
coef_init, intercept_init = self._init_coef(layer_units[i],
layer_units[i + 1])
layer_units[i + 1],
dtype)
self.coefs_.append(coef_init)
self.intercepts_.append(intercept_init)

@@ -328,7 +329,7 @@ def _initialize(self, y, layer_units):
else:
self.best_loss_ = np.inf

def _init_coef(self, fan_in, fan_out):
def _init_coef(self, fan_in, fan_out, dtype):
# Use the initialization method recommended by
# Glorot et al.
factor = 6.
@@ -341,6 +342,8 @@ def _init_coef(self, fan_in, fan_out):
(fan_in, fan_out))
intercept_init = self._random_state.uniform(-init_bound, init_bound,
fan_out)
coef_init = coef_init.astype(dtype, copy=False)
intercept_init = intercept_init.astype(dtype, copy=False)
return coef_init, intercept_init

def _fit(self, X, y, incremental=False):
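For reference, a standalone sketch of the Glorot-style initialisation that _init_coef performs, including the dtype cast added in the hunk above. The helper name below is made up for illustration; the bound formula and the factor of 6 are taken from the surrounding code.

import numpy as np

def glorot_uniform(fan_in, fan_out, dtype, random_state, factor=6.):
    # Draw weights and biases uniformly from [-bound, bound] with
    # bound = sqrt(factor / (fan_in + fan_out)), then cast them to the
    # requested dtype (copy=False avoids a copy when the dtype already matches).
    init_bound = np.sqrt(factor / (fan_in + fan_out))
    coef = random_state.uniform(-init_bound, init_bound, (fan_in, fan_out))
    intercept = random_state.uniform(-init_bound, init_bound, fan_out)
    return coef.astype(dtype, copy=False), intercept.astype(dtype, copy=False)

coef, intercept = glorot_uniform(64, 5, np.float32, np.random.RandomState(0))
print(coef.dtype, intercept.dtype)  # float32 float32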
@@ -357,6 +360,7 @@ def _fit(self, X, y, incremental=False):
hidden_layer_sizes)

X, y = self._validate_input(X, y, incremental)

n_samples, n_features = X.shape

# Ensure y is 2D
@@ -374,17 +378,19 @@ def _fit(self, X, y, incremental=False):
if not hasattr(self, 'coefs_') or (not self.warm_start and not
incremental):
# First time training the model
self._initialize(y, layer_units)
self._initialize(y, layer_units, X.dtype)

# Initialize lists
activations = [X] + [None] * (len(layer_units) - 1)
deltas = [None] * (len(activations) - 1)

coef_grads = [np.empty((n_fan_in_, n_fan_out_)) for n_fan_in_,
coef_grads = [np.empty((n_fan_in_, n_fan_out_), dtype=X.dtype)
for n_fan_in_,
n_fan_out_ in zip(layer_units[:-1],
layer_units[1:])]

intercept_grads = [np.empty(n_fan_out_) for n_fan_out_ in
intercept_grads = [np.empty(n_fan_out_, dtype=X.dtype)
for n_fan_out_ in
layer_units[1:]]

# Run the Stochastic optimization solver
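The gradient buffers above are now allocated with X.dtype rather than NumPy's default of float64. A small illustration (not part of the commit) of why that matters: combining a float64 buffer with float32 activations silently promotes the result back to float64.

import numpy as np

act32 = np.ones((4, 3), dtype=np.float32)    # float32 activations
buf64 = np.empty((4, 3))                     # np.empty defaults to float64
buf32 = np.empty((4, 3), dtype=act32.dtype)  # matches the input dtype

print((act32 * buf64).dtype)  # float64 -- unwanted upcast
print((act32 * buf32).dtype)  # float32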
@@ -960,7 +966,8 @@ def __init__(self, hidden_layer_sizes=(100,), activation="relu", *,

def _validate_input(self, X, y, incremental):
X, y = self._validate_data(X, y, accept_sparse=['csr', 'csc'],
multi_output=True)
multi_output=True,
dtype=(np.float64, np.float32))
if y.ndim == 2 and y.shape[1] == 1:
y = column_or_1d(y, warn=True)

@@ -982,7 +989,9 @@ def _validate_input(self, X, y, incremental):
" `self.classes_` has %s. 'y' has %s." %
(self.classes_, classes))

y = self._label_binarizer.transform(y)
# This downcast to bool is to prevent upcasting when working with
# float32 data
y = self._label_binarizer.transform(y).astype(np.bool)
return X, y

def predict(self, X):
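The comment added in the hunk above explains the cast of the binarised labels to bool. The underlying NumPy promotion rule can be checked directly (illustrative snippet, not part of the commit): float32 combined with int64 is promoted to float64, while float32 combined with bool stays float32.

import numpy as np

act32 = np.ones((2, 3), dtype=np.float32)   # float32 network output
y_int = np.array([[0, 1, 0], [1, 0, 0]])    # int64 by default
y_bool = y_int.astype(bool)

print((act32 - y_int).dtype)   # float64 -- the upcast the comment warns about
print((act32 - y_bool).dtype)  # float32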
@@ -1393,7 +1402,8 @@ def predict(self, X):

def _validate_input(self, X, y, incremental):
X, y = self._validate_data(X, y, accept_sparse=['csr', 'csc'],
multi_output=True, y_numeric=True)
multi_output=True, y_numeric=True,
dtype=(np.float64, np.float32))
if y.ndim == 2 and y.shape[1] == 1:
y = column_or_1d(y, warn=True)
return X, y
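Both _validate_input overrides now pass dtype=(np.float64, np.float32) to _validate_data. The behaviour of such a dtype list can be illustrated with check_array, which the validation machinery delegates to: input whose dtype is already in the list is kept as-is, anything else is converted to the first entry, i.e. float64. (Illustrative snippet, not part of the commit.)

import numpy as np
from sklearn.utils import check_array

X32 = np.zeros((2, 2), dtype=np.float32)
X_int = np.zeros((2, 2), dtype=np.int64)

print(check_array(X32, dtype=(np.float64, np.float32)).dtype)    # float32, kept
print(check_array(X_int, dtype=(np.float64, np.float32)).dtype)  # float64, converted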
58 changes: 58 additions & 0 deletions sklearn/neural_network/tests/test_mlp.py
@@ -716,3 +716,61 @@ def test_early_stopping_stratified():
ValueError,
match='The least populated class in y has only 1 member'):
mlp.fit(X, y)


def test_mlp_classifier_dtypes_casting():
# Compare predictions for different dtypes
mlp_64 = MLPClassifier(alpha=1e-5,
hidden_layer_sizes=(5, 3),
random_state=1, max_iter=50)
mlp_64.fit(X_digits[:300], y_digits[:300])
pred_64 = mlp_64.predict(X_digits[300:])
proba_64 = mlp_64.predict_proba(X_digits[300:])

mlp_32 = MLPClassifier(alpha=1e-5,
hidden_layer_sizes=(5, 3),
random_state=1, max_iter=50)
mlp_32.fit(X_digits[:300].astype(np.float32), y_digits[:300])
pred_32 = mlp_32.predict(X_digits[300:].astype(np.float32))
proba_32 = mlp_32.predict_proba(X_digits[300:].astype(np.float32))

assert_array_equal(pred_64, pred_32)
assert_allclose(proba_64, proba_32, rtol=1e-02)
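# The relaxed tolerance above reflects that training in float32 follows a
# slightly different numerical path than float64, so the probabilities only
# agree approximately.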


def test_mlp_regressor_dtypes_casting():
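# Same comparison as above, for the regressor: models fitted on float64 and
# float32 versions of the same data should give closely matching predictions.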
mlp_64 = MLPRegressor(alpha=1e-5,
hidden_layer_sizes=(5, 3),
random_state=1, max_iter=50)
mlp_64.fit(X_digits[:300], y_digits[:300])
pred_64 = mlp_64.predict(X_digits[300:])

mlp_32 = MLPRegressor(alpha=1e-5,
hidden_layer_sizes=(5, 3),
random_state=1, max_iter=50)
mlp_32.fit(X_digits[:300].astype(np.float32), y_digits[:300])
pred_32 = mlp_32.predict(X_digits[300:].astype(np.float32))

assert_allclose(pred_64, pred_32, rtol=1e-04)


@pytest.mark.parametrize('dtype', [np.float32, np.float64])
@pytest.mark.parametrize('Estimator', [MLPClassifier, MLPRegressor])
def test_mlp_param_dtypes(dtype, Estimator):
# Checks if input dtype is used for network parameters
# and predictions
X, y = X_digits.astype(dtype), y_digits
mlp = Estimator(alpha=1e-5,
hidden_layer_sizes=(5, 3),
random_state=1, max_iter=50)
mlp.fit(X[:300], y[:300])
pred = mlp.predict(X[300:])

assert all([intercept.dtype == dtype
for intercept in mlp.intercepts_])

assert all([coef.dtype == dtype
for coef in mlp.coefs_])

if Estimator == MLPRegressor:
assert pred.dtype == dtype
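# MLPClassifier.predict returns class labels (their dtype follows classes_,
# not X), so the prediction dtype is checked only for the regressor.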
