Skip to content

Commit

Permalink
MNT do not call fit twice in TransformedTargetRegressor (#11641)
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre authored and jnothman committed Feb 19, 2019
1 parent 8b9c1a3 commit 03df72f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 7 deletions.
21 changes: 14 additions & 7 deletions sklearn/compose/_target.py
Expand Up @@ -113,6 +113,12 @@ def __init__(self, regressor=None, transformer=None,
self.check_inverse = check_inverse

def _fit_transformer(self, y):
"""Check transformer and fit transformer.
Create the default transformer, fit it and make additional inverse
check on a subset (optional).
"""
if (self.transformer is not None and
(self.func is not None or self.inverse_func is not None)):
raise ValueError("'transformer' and functions 'func'/"
Expand Down Expand Up @@ -177,19 +183,20 @@ def fit(self, X, y, sample_weight=None):
y_2d = y
self._fit_transformer(y_2d)

if self.regressor is None:
from ..linear_model import LinearRegression
self.regressor_ = LinearRegression()
else:
self.regressor_ = clone(self.regressor)

# transform y and convert back to 1d array if needed
y_trans = self.transformer_.fit_transform(y_2d)
y_trans = self.transformer_.transform(y_2d)
# FIXME: a FunctionTransformer can return a 1D array even when validate
# is set to True. Therefore, we need to check the number of dimension
# first.
if y_trans.ndim == 2 and y_trans.shape[1] == 1:
y_trans = y_trans.squeeze(axis=1)

if self.regressor is None:
from ..linear_model import LinearRegression
self.regressor_ = LinearRegression()
else:
self.regressor_ = clone(self.regressor)

if sample_weight is None:
self.regressor_.fit(X, y_trans)
else:
Expand Down
28 changes: 28 additions & 0 deletions sklearn/compose/tests/test_target.py
Expand Up @@ -265,3 +265,31 @@ def test_transform_target_regressor_ensure_y_array():
tt.predict(X.tolist())
assert_raises(AssertionError, tt.fit, X, y.tolist())
assert_raises(AssertionError, tt.predict, X)


class DummyTransformer(BaseEstimator, TransformerMixin):
    """Pass-through transformer that counts how many times ``fit`` is called.

    The count is kept in ``fit_counter``, which is also the constructor
    parameter (0 by default). ``transform`` and ``inverse_transform`` both
    return their input unchanged.
    """

    def __init__(self, fit_counter=0):
        self.fit_counter = fit_counter

    def fit(self, X, y=None):
        # Nothing is learned from X or y; only the call itself is recorded.
        self.fit_counter += 1
        return self

    def transform(self, X):
        # Identity mapping.
        return X

    def inverse_transform(self, X):
        # Identity mapping, mirroring ``transform``.
        return X


@pytest.mark.parametrize("check_inverse", [False, True])
def test_transform_target_regressor_count_fit(check_inverse):
    """Check that the transformer is fitted exactly once.

    Regression test for gh-issue #11618, run both with and without the
    inverse check.
    """
    X, y = friedman
    counting_transformer = DummyTransformer()
    ttr = TransformedTargetRegressor(
        transformer=counting_transformer,
        check_inverse=check_inverse,
    )
    ttr.fit(X, y)
    # The fitted transformer must have seen a single call to ``fit``.
    assert ttr.transformer_.fit_counter == 1

0 comments on commit 03df72f

Please sign in to comment.