New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG + 1] ENH add check_inverse in FunctionTransformer #9399
Changes from 4 commits
4659cb4
72d3c54
df07603
9a5777c
bd7ad2f
5c1851b
4fd988c
3a764a7
586e8ca
43f876c
f3c0d10
7a19979
e59f493
6cb5b5d
72e2005
45e0cb3
afdeca7
e4045a1
c8c23fa
4276618
0297a4a
677cd2a
cec6f53
5238a33
31abd47
4d31e52
65b134a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
import warnings | ||
|
||
from ..base import BaseEstimator, TransformerMixin | ||
from ..utils import check_array | ||
from ..utils import check_array, check_random_state, safe_indexing | ||
from ..utils.testing import assert_allclose_dense_sparse | ||
from ..externals.six import string_types | ||
|
||
|
||
|
@@ -59,23 +60,59 @@ class FunctionTransformer(BaseEstimator, TransformerMixin): | |
|
||
.. deprecated::0.19 | ||
|
||
check_inverse : bool, (default=False) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. spurious comma ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On line 23 above (can't put a comment there), it is stated that "A FunctionTransformer will not do any checks on its function's output.", which is still correct for the default, but might get an update to mention check_inverse ? |
||
Whether to check that ``transform`` followed by ``inverse_transform`` | ||
or ``func`` followed by ``inverse_func`` leads to the original inputs. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mentioning both transform/inverse_transform and func/inverse_func is a leftover from copying from the other PR? |
||
|
||
.. versionadded:: 0.20 | ||
|
||
kw_args : dict, optional | ||
Dictionary of additional keyword arguments to pass to func. | ||
|
||
inv_kw_args : dict, optional | ||
Dictionary of additional keyword arguments to pass to inverse_func. | ||
|
||
random_state : int, RandomState instance or None, optional (default=None) | ||
If int, random_state is the seed used by the random number generator; | ||
If RandomState instance, random_state is the random number generator; | ||
If None, the random number generator is the RandomState instance used | ||
by np.random. Note that this is used to compute if func and | ||
inverse_func are the inverse of each other. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would clarify that this is only done for |
||
|
||
|
||
""" | ||
def __init__(self, func=None, inverse_func=None, validate=True, | ||
accept_sparse=False, pass_y='deprecated', | ||
kw_args=None, inv_kw_args=None): | ||
accept_sparse=False, pass_y='deprecated', check_inverse=False, | ||
kw_args=None, inv_kw_args=None, random_state=None): | ||
self.func = func | ||
self.inverse_func = inverse_func | ||
self.validate = validate | ||
self.accept_sparse = accept_sparse | ||
self.pass_y = pass_y | ||
self.check_inverse = check_inverse | ||
self.kw_args = kw_args | ||
self.inv_kw_args = inv_kw_args | ||
self.random_state = random_state | ||
|
||
def _validate_inverse(self, X): | ||
"""Check that func and inverse_func are the inverse.""" | ||
random_state = check_random_state(self.random_state) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. utils.resample? |
||
n_subsample = min(100, X.shape[0]) | ||
subsample_idx = random_state.choice(range(X.shape[0]), | ||
size=n_subsample, | ||
replace=False) | ||
|
||
X_sel = safe_indexing(X, subsample_idx) | ||
print(subsample_idx) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. print? |
||
try: | ||
assert_allclose_dense_sparse( | ||
X_sel, self.inverse_transform(self.transform(X_sel)), | ||
atol=1e-7) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not the default? |
||
except AssertionError: | ||
raise ValueError("The provided functions are not strictly" | ||
" inverse of each other. If you are sure you" | ||
" want to proceed regardless, set" | ||
" 'check_inverse=False'") | ||
|
||
def fit(self, X, y=None): | ||
"""Fit transformer by checking X. | ||
|
@@ -93,6 +130,8 @@ def fit(self, X, y=None): | |
""" | ||
if self.validate: | ||
check_array(X, self.accept_sparse) | ||
if self.check_inverse: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
self._validate_inverse(X) | ||
return self | ||
|
||
def transform(self, X, y='deprecated'): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,10 @@ | ||
import numpy as np | ||
from scipy import sparse | ||
|
||
from sklearn.preprocessing import FunctionTransformer | ||
from sklearn.utils.testing import assert_equal, assert_array_equal | ||
from sklearn.utils.testing import assert_warns_message | ||
from sklearn.utils.testing import (assert_equal, assert_array_equal, | ||
assert_allclose_dense_sparse) | ||
from sklearn.utils.testing import assert_warns_message, assert_raises_regex | ||
|
||
|
||
def _make_func(args_store, kwargs_store, func=lambda X, *a, **k: X): | ||
|
@@ -126,3 +128,33 @@ def test_inverse_transform(): | |
F.inverse_transform(F.transform(X)), | ||
np.around(np.sqrt(X), decimals=3), | ||
) | ||
|
||
|
||
def test_check_inverse(): | ||
X_dense = np.array([1, 4, 9, 16], dtype=np.float64).reshape((2, 2)) | ||
|
||
X_list = [X_dense, | ||
sparse.csr_matrix(X_dense), | ||
sparse.csc_matrix(X_dense)] | ||
|
||
for X in X_list: | ||
print(X) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. print? |
||
if sparse.issparse(X): | ||
accept_sparse = True | ||
else: | ||
accept_sparse = False | ||
trans = FunctionTransformer(func=np.sqrt, | ||
inverse_func=np.around, | ||
accept_sparse=accept_sparse, | ||
check_inverse=True) | ||
assert_raises_regex(ValueError, "The provided functions are not" | ||
" strictly inverse of each other. If you are sure" | ||
" you want to proceed regardless, set" | ||
" 'check_inverse=False'", | ||
trans.fit, X) | ||
trans = FunctionTransformer(func=np.expm1, | ||
inverse_func=np.log1p, | ||
accept_sparse=accept_sparse, | ||
check_inverse=True) | ||
Xt = trans.fit_transform(X) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. assert_no_warning? |
||
assert_allclose_dense_sparse(X, trans.inverse_transform(Xt)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would say "You" but doesn't matter ;)