Skip to content

Commit

Permalink
FIX YeoJohnson transform lambda bounds (#12522)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored and jnothman committed Nov 14, 2018
1 parent afce882 commit 2ccc921
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 25 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new/v0.20.rst
Expand Up @@ -149,6 +149,10 @@ Changelog
:issue:`12317` by :user:`Eric Chang <chang>`.


- |Fix| Fixed a bug in :class:`preprocessing.PowerTransformer` where the
  Yeo-Johnson transform was incorrect for lambda parameters outside of
  `[0, 2]`. :issue:`12522` by :user:`Nicolas Hug <NicolasHug>`.

:mod:`sklearn.utils`
........................

Expand Down
36 changes: 11 additions & 25 deletions sklearn/preprocessing/data.py
Expand Up @@ -2528,7 +2528,7 @@ class PowerTransformer(BaseEstimator, TransformerMixin):
>>> print(pt.fit(data))
PowerTransformer(copy=True, method='yeo-johnson', standardize=True)
>>> print(pt.lambdas_)
[1.38668178e+00 5.93926346e-09]
[ 1.38668178 -3.10053309]
>>> print(pt.transform(data))
[[-1.31616039 -0.70710678]
[ 0.20998268 -0.70710678]
Expand Down Expand Up @@ -2709,23 +2709,18 @@ def _box_cox_inverse_tranform(self, x, lmbda):
def _yeo_johnson_inverse_transform(self, x, lmbda):
"""Return inverse-transformed input x following Yeo-Johnson inverse
transform with parameter lambda.
Notes
-----
We're comparing lmbda to 1e-19 instead of strict equality to 0. See
scipy/special/_boxcox.pxd for a rationale behind this
"""
x_inv = np.zeros(x.shape, dtype=x.dtype)
x_inv = np.zeros_like(x)
pos = x >= 0

# when x >= 0
if lmbda < 1e-19:
if abs(lmbda) < np.spacing(1.):
x_inv[pos] = np.exp(x[pos]) - 1
else: # lmbda != 0
x_inv[pos] = np.power(x[pos] * lmbda + 1, 1 / lmbda) - 1

# when x < 0
if lmbda < 2 - 1e-19:
if abs(lmbda - 2) > np.spacing(1.):
x_inv[~pos] = 1 - np.power(-(2 - lmbda) * x[~pos] + 1,
1 / (2 - lmbda))
else: # lmbda == 2
Expand All @@ -2736,27 +2731,22 @@ def _yeo_johnson_inverse_transform(self, x, lmbda):
def _yeo_johnson_transform(self, x, lmbda):
"""Return transformed input x following Yeo-Johnson transform with
parameter lambda.
Notes
-----
We're comparing lmbda to 1e-19 instead of strict equality to 0. See
scipy/special/_boxcox.pxd for a rationale behind this
"""

out = np.zeros(shape=x.shape, dtype=x.dtype)
out = np.zeros_like(x)
pos = x >= 0 # binary mask

# when x >= 0
if lmbda < 1e-19:
out[pos] = np.log(x[pos] + 1)
if abs(lmbda) < np.spacing(1.):
out[pos] = np.log1p(x[pos])
else: # lmbda != 0
out[pos] = (np.power(x[pos] + 1, lmbda) - 1) / lmbda

# when x < 0
if lmbda < 2 - 1e-19:
if abs(lmbda - 2) > np.spacing(1.):
out[~pos] = -(np.power(-x[~pos] + 1, 2 - lmbda) - 1) / (2 - lmbda)
else: # lmbda == 2
out[~pos] = -np.log(-x[~pos] + 1)
out[~pos] = -np.log1p(-x[~pos])

return out

Expand Down Expand Up @@ -2785,12 +2775,8 @@ def _neg_log_likelihood(lmbda):
x_trans = self._yeo_johnson_transform(x, lmbda)
n_samples = x.shape[0]

# Estimated mean and variance of the normal distribution
est_mean = x_trans.sum() / n_samples
est_var = np.power(x_trans - est_mean, 2).sum() / n_samples

loglike = -n_samples / 2 * np.log(est_var)
loglike += (lmbda - 1) * (np.sign(x) * np.log(np.abs(x) + 1)).sum()
loglike = -n_samples / 2 * np.log(x_trans.var())
loglike += (lmbda - 1) * (np.sign(x) * np.log1p(np.abs(x))).sum()

return -loglike

Expand Down
10 changes: 10 additions & 0 deletions sklearn/preprocessing/tests/test_data.py
Expand Up @@ -2207,6 +2207,16 @@ def test_optimization_power_transformer(method, lmbda):
assert_almost_equal(1, X_inv_trans.std(), decimal=1)


def test_yeo_johnson_darwin_example():
    # Regression check against the worked example in Yeo & Johnson (2000),
    # "A new family of power transformations to improve normality or
    # symmetry" (Darwin's plant-height data): the MLE of lambda should be
    # approximately 1.305.
    data = np.array([6.1, -8.4, 1.0, 2.0, 0.7, 2.9, 3.5, 5.1, 1.8, 3.6,
                     7.0, 3.0, 9.3, 7.5, -6.0]).reshape(-1, 1)
    pt = PowerTransformer(method='yeo-johnson').fit(data)
    assert np.allclose(pt.lambdas_, 1.305, atol=1e-3)


@pytest.mark.parametrize('method', ['box-cox', 'yeo-johnson'])
def test_power_transformer_nans(method):
# Make sure lambda estimation is not influenced by NaN values
Expand Down

0 comments on commit 2ccc921

Please sign in to comment.