diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 5825409f0f112..8bcb14363d69c 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -610,6 +610,15 @@ a transformer that applies a log transformation in a pipeline, do:: array([[ 0. , 0.69314718], [ 1.09861229, 1.38629436]]) +You can ensure that ``func`` and ``inverse_func`` are the inverse of each other +by setting ``check_inverse=True`` and calling ``fit`` before +``transform``. Please note that a warning is raised and can be turned into an +error with a ``filterwarnings``:: + + >>> import warnings + >>> warnings.filterwarnings("error", message=".*check_inverse*.", + ... category=UserWarning, append=False) + For a full code example that demonstrates using a :class:`FunctionTransformer` to do custom feature selection, see :ref:`sphx_glr_auto_examples_preprocessing_plot_function_transformer.py` diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 13efcfd6cc84d..11f75b0d31bfa 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -40,7 +40,7 @@ Classifiers and regressors - Added :class:`naive_bayes.ComplementNB`, which implements the Complement Naive Bayes classifier described in Rennie et al. (2003). By :user:`Michael A. Alcorn `. - + Model evaluation - Added the :func:`metrics.balanced_accuracy` metric and a corresponding @@ -65,6 +65,11 @@ Classifiers and regressors :class:`sklearn.naive_bayes.GaussianNB` to give a precise control over variances calculation. :issue:`9681` by :user:`Dmitry Mottl `. +- A parameter ``check_inverse`` was added to :class:`FunctionTransformer` + to ensure that ``func`` and ``inverse_func`` are the inverse of each + other. + :issue:`9399` by :user:`Guillaume Lemaitre `. + Model evaluation and meta-estimators - A scorer based on :func:`metrics.brier_score_loss` is also available. diff --git a/sklearn/preprocessing/_function_transformer.py b/sklearn/preprocessing/_function_transformer.py index 82955b6977691..f2a1290685992 100644 --- a/sklearn/preprocessing/_function_transformer.py +++ b/sklearn/preprocessing/_function_transformer.py @@ -2,6 +2,7 @@ from ..base import BaseEstimator, TransformerMixin from ..utils import check_array +from ..utils.testing import assert_allclose_dense_sparse from ..externals.six import string_types @@ -19,8 +20,6 @@ class FunctionTransformer(BaseEstimator, TransformerMixin): function. This is useful for stateless transformations such as taking the log of frequencies, doing custom scaling, etc. - A FunctionTransformer will not do any checks on its function's output. - Note: If a lambda is used as the function, then the resulting transformer will not be pickleable. @@ -59,6 +58,13 @@ class FunctionTransformer(BaseEstimator, TransformerMixin): .. deprecated::0.19 + check_inverse : bool, default=True + Whether to check that or ``func`` followed by ``inverse_func`` leads to + the original inputs. It can be used for a sanity check, raising a + warning when the condition is not fulfilled. + + .. versionadded:: 0.20 + kw_args : dict, optional Dictionary of additional keyword arguments to pass to func. @@ -67,16 +73,30 @@ class FunctionTransformer(BaseEstimator, TransformerMixin): """ def __init__(self, func=None, inverse_func=None, validate=True, - accept_sparse=False, pass_y='deprecated', + accept_sparse=False, pass_y='deprecated', check_inverse=True, kw_args=None, inv_kw_args=None): self.func = func self.inverse_func = inverse_func self.validate = validate self.accept_sparse = accept_sparse self.pass_y = pass_y + self.check_inverse = check_inverse self.kw_args = kw_args self.inv_kw_args = inv_kw_args + def _check_inverse_transform(self, X): + """Check that func and inverse_func are the inverse.""" + idx_selected = slice(None, None, max(1, X.shape[0] // 100)) + try: + assert_allclose_dense_sparse( + X[idx_selected], + self.inverse_transform(self.transform(X[idx_selected]))) + except AssertionError: + warnings.warn("The provided functions are not strictly" + " inverse of each other. If you are sure you" + " want to proceed regardless, set" + " 'check_inverse=False'.", UserWarning) + def fit(self, X, y=None): """Fit transformer by checking X. @@ -92,7 +112,10 @@ def fit(self, X, y=None): self """ if self.validate: - check_array(X, self.accept_sparse) + X = check_array(X, self.accept_sparse) + if (self.check_inverse and not (self.func is None or + self.inverse_func is None)): + self._check_inverse_transform(X) return self def transform(self, X, y='deprecated'): diff --git a/sklearn/preprocessing/tests/test_function_transformer.py b/sklearn/preprocessing/tests/test_function_transformer.py index 4e9cb26b64a9d..4d166457777cc 100644 --- a/sklearn/preprocessing/tests/test_function_transformer.py +++ b/sklearn/preprocessing/tests/test_function_transformer.py @@ -1,8 +1,10 @@ import numpy as np +from scipy import sparse from sklearn.preprocessing import FunctionTransformer -from sklearn.utils.testing import assert_equal, assert_array_equal -from sklearn.utils.testing import assert_warns_message +from sklearn.utils.testing import (assert_equal, assert_array_equal, + assert_allclose_dense_sparse) +from sklearn.utils.testing import assert_warns_message, assert_no_warnings def _make_func(args_store, kwargs_store, func=lambda X, *a, **k: X): @@ -126,3 +128,43 @@ def test_inverse_transform(): F.inverse_transform(F.transform(X)), np.around(np.sqrt(X), decimals=3), ) + + +def test_check_inverse(): + X_dense = np.array([1, 4, 9, 16], dtype=np.float64).reshape((2, 2)) + + X_list = [X_dense, + sparse.csr_matrix(X_dense), + sparse.csc_matrix(X_dense)] + + for X in X_list: + if sparse.issparse(X): + accept_sparse = True + else: + accept_sparse = False + trans = FunctionTransformer(func=np.sqrt, + inverse_func=np.around, + accept_sparse=accept_sparse, + check_inverse=True) + assert_warns_message(UserWarning, + "The provided functions are not strictly" + " inverse of each other. If you are sure you" + " want to proceed regardless, set" + " 'check_inverse=False'.", + trans.fit, X) + + trans = FunctionTransformer(func=np.expm1, + inverse_func=np.log1p, + accept_sparse=accept_sparse, + check_inverse=True) + Xt = assert_no_warnings(trans.fit_transform, X) + assert_allclose_dense_sparse(X, trans.inverse_transform(Xt)) + + # check that we don't check inverse when one of the func or inverse is not + # provided. + trans = FunctionTransformer(func=np.expm1, inverse_func=None, + check_inverse=True) + assert_no_warnings(trans.fit, X_dense) + trans = FunctionTransformer(func=None, inverse_func=np.expm1, + check_inverse=True) + assert_no_warnings(trans.fit, X_dense)