Skip to content

Commit

Permalink
[MRG] FIX Update power_transform docstring and add FutureWarning (#12317)
Browse files Browse the repository at this point in the history
  • Loading branch information
chang authored and jnothman committed Oct 15, 2018
1 parent b420655 commit 854978f
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 23 deletions.
67 changes: 46 additions & 21 deletions sklearn/preprocessing/data.py
Expand Up @@ -2480,7 +2480,7 @@ class PowerTransformer(BaseEstimator, TransformerMixin):
or other situations where normality is desired.
Currently, PowerTransformer supports the Box-Cox transform and the
Yeo-Johson transform. The optimal parameter for stabilizing variance and
Yeo-Johnson transform. The optimal parameter for stabilizing variance and
minimizing skewness is estimated through maximum likelihood.
Box-Cox requires input data to be strictly positive, while Yeo-Johnson
Expand Down Expand Up @@ -2535,8 +2535,8 @@ class PowerTransformer(BaseEstimator, TransformerMixin):
Notes
-----
NaNs are treated as missing values: disregarded in fit, and maintained in
transform.
NaNs are treated as missing values: disregarded in ``fit``, and maintained
in ``transform``.
For a comparison of the different scalers, transformers, and normalizers,
see :ref:`examples/preprocessing/plot_all_scaling.py
Expand Down Expand Up @@ -2835,18 +2835,19 @@ def _check_input(self, X, check_positive=False, check_shape=False,
return X


def power_transform(X, method='warn', standardize=True, copy=True):
    """
    Power transforms are a family of parametric, monotonic transformations
    that are applied to make data more Gaussian-like. This is useful for
    modeling issues related to heteroscedasticity (non-constant variance),
    or other situations where normality is desired.

    Currently, power_transform supports the Box-Cox transform and the
    Yeo-Johnson transform. The optimal parameter for stabilizing variance and
    minimizing skewness is estimated through maximum likelihood.

    Box-Cox requires input data to be strictly positive, while Yeo-Johnson
    supports both positive or negative data.

    By default, zero-mean, unit-variance normalization is applied to the
    transformed data.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        The data to be transformed using a power transformation.

    method : str
        The power transform method. Available methods are:

        - 'yeo-johnson' [1]_, works with positive and negative values
        - 'box-cox' [2]_, only works with strictly positive values

        The default method will be changed from 'box-cox' to 'yeo-johnson'
        in version 0.23. To suppress the FutureWarning, explicitly set the
        parameter.

    standardize : boolean, default=True
        Set to True to apply zero-mean, unit-variance normalization to the
        transformed output.

    copy : boolean, optional, default=True
        Set to False to perform inplace computation during transformation.

    Returns
    -------
    X_trans : array-like, shape (n_samples, n_features)
        The transformed data.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.preprocessing import power_transform
    >>> data = [[1, 2], [3, 2], [4, 5]]
    >>> print(power_transform(data, method='box-cox'))  # doctest: +ELLIPSIS
    [[-1.332... -0.707...]
     [ 0.256... -0.707...]
     [ 1.076...  1.414...]]

    See also
    --------
    PowerTransformer : Equivalent transformation with the
        ``Transformer`` API (e.g. as part of a preprocessing
        :class:`sklearn.pipeline.Pipeline`).

    quantile_transform : Maps data to a standard normal distribution with
        the parameter `output_distribution='normal'`.

    Notes
    -----
    NaNs are treated as missing values: disregarded in ``fit``, and maintained
    in ``transform``.

    For a comparison of the different scalers, transformers, and normalizers,
    see :ref:`examples/preprocessing/plot_all_scaling.py
    <sphx_glr_auto_examples_preprocessing_plot_all_scaling.py>`.

    References
    ----------

    .. [1] I.K. Yeo and R.A. Johnson, "A new family of power transformations
           to improve normality or symmetry." Biometrika, 87(4), pp.954-959,
           (2000).

    .. [2] G.E.P. Box and D.R. Cox, "An Analysis of Transformations", Journal
           of the Royal Statistical Society B, 26, 211-252 (1964).
    """
    # 'warn' is a sentinel default: keep the historical 'box-cox' behavior
    # for now, but tell callers the default flips to 'yeo-johnson' in 0.23.
    if method == 'warn':
        warnings.warn("The default value of 'method' will change from "
                      "'box-cox' to 'yeo-johnson' in version 0.23. Set "
                      "the 'method' argument explicitly to silence this "
                      "warning in the meantime.",
                      FutureWarning)
        method = 'box-cox'
    # Delegate to the estimator so the function and the Transformer API
    # stay behaviorally identical.
    pt = PowerTransformer(method=method, standardize=standardize, copy=copy)
    return pt.fit_transform(X)

Expand Down
28 changes: 26 additions & 2 deletions sklearn/preprocessing/tests/test_data.py
Expand Up @@ -2031,7 +2031,10 @@ def test_power_transformer_1d():
pt = PowerTransformer(method='box-cox', standardize=standardize)

X_trans = pt.fit_transform(X)
X_trans_func = power_transform(X, standardize=standardize)
X_trans_func = power_transform(
X, method='box-cox',
standardize=standardize
)

X_expected, lambda_expected = stats.boxcox(X.flatten())

Expand All @@ -2055,7 +2058,10 @@ def test_power_transformer_2d():
pt = PowerTransformer(method='box-cox', standardize=standardize)

X_trans_class = pt.fit_transform(X)
X_trans_func = power_transform(X, standardize=standardize)
X_trans_func = power_transform(
X, method='box-cox',
standardize=standardize
)

for X_trans in [X_trans_class, X_trans_func]:
for j in range(X_trans.shape[1]):
Expand Down Expand Up @@ -2278,3 +2284,21 @@ def test_power_transformer_copy_False(method, standardize):

X_inv_trans = pt.inverse_transform(X_trans)
assert X_trans is X_inv_trans


def test_power_transform_default_method():
    """Calling power_transform without ``method`` warns and uses 'box-cox'."""
    X = np.abs(X_2d)

    expected_msg = ("The default value of 'method' "
                    "will change from 'box-cox'")
    assert_warns_message(FutureWarning, expected_msg, power_transform, X)

    # Silence the deprecation warning while capturing the default output.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        X_trans_default = power_transform(X)

    # The implicit default must match an explicit 'box-cox' exactly.
    X_trans_boxcox = power_transform(X, method='box-cox')
    assert_array_equal(X_trans_boxcox, X_trans_default)

0 comments on commit 854978f

Please sign in to comment.