Preserve float32 input in neural_network.BernoulliRBM.

fit() and transform() accept float32 as well as float64 input, and the
fitted attributes (components_, intercept_hidden_, intercept_visible_,
h_samples_) are allocated with the dtype of the validated input. Adds a
v0.24 changelog entry plus tests for output dtype and float32/float64
agreement. The integer test case uses the builtin ``int`` rather than the
``np.int`` alias, which was deprecated in NumPy 1.20 and removed in 1.24.

diff --git a/doc/whats_new/v0.23.rst b/doc/whats_new/v0.23.rst
index 05097a4fa9516..e3bce16e96ae2 100644
--- a/doc/whats_new/v0.23.rst
+++ b/doc/whats_new/v0.23.rst
@@ -169,7 +169,7 @@ Changelog
   deprecated. It has no effect. :pr:`11950` by
   :user:`Jeremie du Boisberranger <jeremiedbb>`.
 
-- |API| The ``random_state`` parameter has been added to 
+- |API| The ``random_state`` parameter has been added to
   :class:`cluster.AffinityPropagation`. :pr:`16801` by :user:`rcwoolston`
   and :user:`Chiara Marmo <cmarmo>`.
 
diff --git a/doc/whats_new/v0.24.rst b/doc/whats_new/v0.24.rst
index 00557179ac38b..3cebba4182b1d 100644
--- a/doc/whats_new/v0.24.rst
+++ b/doc/whats_new/v0.24.rst
@@ -166,6 +166,10 @@ Changelog
   :pr:`17603`, :pr:`17604`, :pr:`17606`, :pr:`17608`, :pr:`17609`, :pr:`17633`
   by :user:`Alex Henrie <alexhenrie>`.
 
+- |Enhancement| Avoid converting float32 input to float64 in
+  :class:`neural_network.BernoulliRBM`.
+  :pr:`16352` by :user:`Arthur Imbert <Henley13>`.
+
 :mod:`sklearn.preprocessing`
 ............................
 
diff --git a/sklearn/neural_network/_rbm.py b/sklearn/neural_network/_rbm.py
index fcb4e90772598..5a41ad01c696c 100644
--- a/sklearn/neural_network/_rbm.py
+++ b/sklearn/neural_network/_rbm.py
@@ -131,7 +131,7 @@ def transform(self, X):
         """
         check_is_fitted(self)
 
-        X = check_array(X, accept_sparse='csr', dtype=np.float64)
+        X = check_array(X, accept_sparse='csr', dtype=(np.float64, np.float32))
         return self._mean_hiddens(X)
 
     def _mean_hiddens(self, v):
@@ -344,16 +344,20 @@ def fit(self, X, y=None):
         self : BernoulliRBM
             The fitted model.
         """
-        X = self._validate_data(X, accept_sparse='csr', dtype=np.float64)
+        X = self._validate_data(
+            X, accept_sparse='csr', dtype=(np.float64, np.float32)
+        )
         n_samples = X.shape[0]
         rng = check_random_state(self.random_state)
 
         self.components_ = np.asarray(
             rng.normal(0, 0.01, (self.n_components, X.shape[1])),
-            order='F')
-        self.intercept_hidden_ = np.zeros(self.n_components, )
-        self.intercept_visible_ = np.zeros(X.shape[1], )
-        self.h_samples_ = np.zeros((self.batch_size, self.n_components))
+            order='F',
+            dtype=X.dtype)
+        self.intercept_hidden_ = np.zeros(self.n_components, dtype=X.dtype)
+        self.intercept_visible_ = np.zeros(X.shape[1], dtype=X.dtype)
+        self.h_samples_ = np.zeros((self.batch_size, self.n_components),
+                                   dtype=X.dtype)
 
         n_batches = int(np.ceil(float(n_samples) / self.batch_size))
         batch_slices = list(gen_even_slices(n_batches * self.batch_size,
diff --git a/sklearn/neural_network/tests/test_rbm.py b/sklearn/neural_network/tests/test_rbm.py
index e319e0e4f3428..22d3b1c75fb01 100644
--- a/sklearn/neural_network/tests/test_rbm.py
+++ b/sklearn/neural_network/tests/test_rbm.py
@@ -1,9 +1,11 @@
 import sys
 import re
+import pytest
 
 import numpy as np
 from scipy.sparse import csc_matrix, csr_matrix, lil_matrix
-from sklearn.utils._testing import (assert_almost_equal, assert_array_equal)
+from sklearn.utils._testing import (assert_almost_equal, assert_array_equal,
+                                    assert_allclose)
 
 from sklearn.datasets import load_digits
 from io import StringIO
@@ -189,3 +191,43 @@ def test_sparse_and_verbose():
                         r" time = (\d|\.)+s", s)
     finally:
         sys.stdout = old_stdout
+
+
+@pytest.mark.parametrize("dtype_in, dtype_out", [
+    (np.float32, np.float32),
+    (np.float64, np.float64),
+    (int, np.float64)])
+def test_transformer_dtypes_casting(dtype_in, dtype_out):
+    X = Xdigits[:100].astype(dtype_in)
+    rbm = BernoulliRBM(n_components=16, batch_size=5, n_iter=5,
+                       random_state=42)
+    Xt = rbm.fit_transform(X)
+
+    # dtype_in and dtype_out should be consistent
+    assert Xt.dtype == dtype_out, ('transform dtype: {} - original dtype: {}'
+                                   .format(Xt.dtype, X.dtype))
+
+
+def test_convergence_dtype_consistency():
+    # float 64 transformer
+    X_64 = Xdigits[:100].astype(np.float64)
+    rbm_64 = BernoulliRBM(n_components=16, batch_size=5, n_iter=5,
+                          random_state=42)
+    Xt_64 = rbm_64.fit_transform(X_64)
+
+    # float 32 transformer
+    X_32 = Xdigits[:100].astype(np.float32)
+    rbm_32 = BernoulliRBM(n_components=16, batch_size=5, n_iter=5,
+                          random_state=42)
+    Xt_32 = rbm_32.fit_transform(X_32)
+
+    # results and attributes should be close enough in 32 bit and 64 bit
+    assert_allclose(Xt_64, Xt_32,
+                    rtol=1e-06, atol=0)
+    assert_allclose(rbm_64.intercept_hidden_, rbm_32.intercept_hidden_,
+                    rtol=1e-06, atol=0)
+    assert_allclose(rbm_64.intercept_visible_, rbm_32.intercept_visible_,
+                    rtol=1e-05, atol=0)
+    assert_allclose(rbm_64.components_, rbm_32.components_,
+                    rtol=1e-03, atol=0)
+    assert_allclose(rbm_64.h_samples_, rbm_32.h_samples_)