[MRG+1] QuantileTransformer #8363

Merged
merged 107 commits on Jun 9, 2017
Changes from 102 commits
Commits (107)
14f3b15
resurrect quantile scaler
turian Jul 21, 2013
cc8d264
move the code in the pre-processing module
Feb 15, 2017
b78b689
first draft
Feb 15, 2017
bb1829a
Add tests.
tguillemot Feb 15, 2017
5c9bcbc
Fix bug in QuantileNormalizer.
tguillemot Feb 15, 2017
8d4b9cc
Add quantile_normalizer.
tguillemot Feb 15, 2017
45e48f7
Implement pickling
Feb 15, 2017
0a646c1
create a specific function for dense transform
Feb 15, 2017
4dbdb6e
Create a fit function for the dense case
Feb 15, 2017
f723edb
Create a toy examples
Feb 15, 2017
5c8d496
First draft with sparse matrices
Feb 16, 2017
bcbf79b
remove useless functions and non-negative sparse compatibility
Feb 16, 2017
1be3f5b
fix slice call
Feb 16, 2017
86b4a22
Fix tests of QuantileNormalizer.
tguillemot Feb 16, 2017
a742a61
Fix estimator compatibility
Feb 16, 2017
79927b6
fix doc
Feb 16, 2017
cc680a7
Add negative ValueError tests for QuantileNormalizer.
tguillemot Feb 16, 2017
1260a70
Fix cosmetics
Feb 16, 2017
c043c07
Fix compatibility numpy <= 1.8
Feb 16, 2017
0a7dc4d
Add n_features tests and correct ValueError.
tguillemot Feb 16, 2017
36a8870
PEP8
Feb 16, 2017
94e26ad
fix fill_value for early scipy compatibility
Feb 16, 2017
f552529
simplify sampling
Feb 16, 2017
8a4592c
Fix tests.
tguillemot Feb 16, 2017
9070871
removing last pring
Feb 16, 2017
cbe4da9
Change choice for permutation
Feb 16, 2017
1051fbb
cosmetics
Feb 16, 2017
790b0cb
fix remove remaining choice
Feb 16, 2017
9713089
DOC
Feb 16, 2017
a1052de
Fix inconsistencies
Feb 16, 2017
5b48b22
pep8
Feb 16, 2017
45172fa
Add checker for init parameters.
tguillemot Feb 17, 2017
ef3b403
hack bounds and make a test
Feb 17, 2017
adc1f37
FIX/TST bounds are provided by the fitting and not X at transform
Feb 17, 2017
22ea4f9
PEP8
Feb 17, 2017
81a3721
FIX/TST axis should be <= 1
Feb 17, 2017
055d8aa
PEP8
Feb 17, 2017
777e353
ENH Add parameter ignore_implicit_zeros
Feb 21, 2017
63708c2
ENH match output distribution
Feb 21, 2017
6e6eb52
ENH clip the data to avoid infinity due to output PDF
Feb 21, 2017
1aba0fe
FIX ENH restraint to uniform and norm
Feb 22, 2017
d1a94f5
[MRG] ENH Add example comparing the distribution of all scaling prepr…
raghavrv Feb 22, 2017
f1282f2
TST Validity of output_pdf
Feb 22, 2017
11709a3
EXA Use OrderedDict; Make it easier to add more transformations
raghavrv Feb 22, 2017
cf5fa8d
FIX PEP8 and replace scipy.stats by str in example
Feb 23, 2017
0150f62
FIX remove useless import
Feb 23, 2017
81c08cc
COSMET change variable names
Feb 27, 2017
adde8cf
FIX change output_pdf occurence to output_distribution
Feb 27, 2017
fe009c9
FIX partial fixies from comments
Feb 28, 2017
6ec43a8
COMIT change class name and code structure
Feb 28, 2017
e94cd48
COSMIT change direction to inverse
Feb 28, 2017
9c13d2a
FIX factorize transform in _transform_col
Feb 28, 2017
5d544ef
PEP8
Feb 28, 2017
d9b3e7a
FIX change the magic 10
Feb 28, 2017
23b3a91
FIX add interp1d to fixes
Feb 28, 2017
04dc89a
FIX/TST allow negative entries when ignore_implicit_zeros is True
Feb 28, 2017
9377cc2
FIX use np.interp instead of sp.interpolate.interp1d
Mar 1, 2017
c132211
FIX/TST fix tests
Mar 1, 2017
9b66d71
DOC start checking doc
Mar 2, 2017
fb88fa1
TST add test to check the behaviour of interp numpy
Mar 14, 2017
f46aea9
TST/EHN Add the possibility to add noise to compute quantile
Mar 14, 2017
d55295a
FIX factorize quantile computation
Mar 14, 2017
38127d5
FIX fixes issues
Mar 14, 2017
90fa3bd
PEP8
Mar 14, 2017
ba8339d
FIX/DOC correct doc
Mar 14, 2017
9a1b79e
TST/DOC improve doc and add random state
Mar 15, 2017
dabd403
EXA add examples to illustrate the use of smoothing_noise
Mar 15, 2017
29c24e0
FIX/DOC fix some grammar
Mar 15, 2017
3023a2f
DOC fix example
Mar 15, 2017
17db1ff
DOC/EXA make plot titles more succint
Mar 15, 2017
1de03ba
EXA improve explanation
Mar 15, 2017
79f6e97
EXA improve the docstring
Mar 15, 2017
12a3f47
DOC add a bit more documentation
Mar 15, 2017
9226f73
FIX advance review
Apr 5, 2017
b47158f
TST add subsampling test
Apr 6, 2017
bd928ed
DOC/TST better example for the docstring
Apr 7, 2017
c70aba0
DOC add ellipsis to docstring
Apr 7, 2017
9a9556c
FIX address olivier comments
Apr 8, 2017
6b105a9
FIX remove random_state in sparse.rand
Apr 8, 2017
dc39f9e
FIX spelling doc
Apr 8, 2017
c3cf631
FIX cite example in user guide and docstring
Apr 11, 2017
570c5d0
FIX olivier comments
Apr 12, 2017
da5604d
EHN improve the example comparing all the pre-processing methods
Apr 13, 2017
7871513
FIX/DOC remove title
Apr 13, 2017
52e4edf
FIX change the scaling of the figure
Apr 13, 2017
28cc2af
FIX plotting layout
Apr 19, 2017
6cdf964
FIX ratio w/h
Apr 19, 2017
58c64c2
Reorder and reword the plot_all_scaling example
ogrisel Apr 18, 2017
1a181fa
Fix aspect ratio and better explanations in the plot_all_scaling.py e…
ogrisel Apr 20, 2017
cb04d53
Fix broken link and remove useless sentence
ogrisel Apr 20, 2017
eac7071
FIX fix couples of spelling
Apr 20, 2017
37afa44
FIX comments joel
glemaitre Apr 23, 2017
a4719b4
FIX/DOC address documentation comments
glemaitre Apr 23, 2017
07906cc
FIX address comments joel
glemaitre May 7, 2017
d4d6bb4
FIX inline sparse and dense transform
glemaitre May 7, 2017
0b5be04
PEP8
glemaitre May 7, 2017
c740628
TST/DOC temporary skipping test
glemaitre May 7, 2017
6c2d7cf
FIX raise an error if n_quantiles > subsample
glemaitre May 16, 2017
22708c9
FIX wording in smoothing_noise example
ogrisel Jun 6, 2017
4d2fe63
EXA Denis comments
glemaitre Jun 7, 2017
49c94b3
FIX rephrasing
glemaitre Jun 7, 2017
2c85eb3
FIX make smoothing_noise to be a boolearn and change doc
glemaitre Jun 7, 2017
0f1bc24
Merge remote-tracking branch 'origin/master' into quantile_scaler
glemaitre Jun 7, 2017
db08c55
FIX address comments
glemaitre Jun 7, 2017
be207c7
FIX verbose the doc slightly more
glemaitre Jun 7, 2017
7b17f14
PEP8/DOC
glemaitre Jun 8, 2017
7046a6d
ENH: 2-ways interpolation to avoid smoothing_noise
GaelVaroquaux Jun 9, 2017
2 changes: 2 additions & 0 deletions doc/modules/classes.rst
@@ -1194,6 +1194,7 @@ See the :ref:`metrics` section of the user guide for further details.
preprocessing.Normalizer
preprocessing.OneHotEncoder
preprocessing.PolynomialFeatures
preprocessing.QuantileTransformer
preprocessing.RobustScaler
preprocessing.StandardScaler

@@ -1207,6 +1208,7 @@ See the :ref:`metrics` section of the user guide for further details.
preprocessing.maxabs_scale
preprocessing.minmax_scale
preprocessing.normalize
preprocessing.quantile_transform
preprocessing.robust_scale
preprocessing.scale

82 changes: 75 additions & 7 deletions doc/modules/preprocessing.rst
@@ -10,6 +10,13 @@ The ``sklearn.preprocessing`` package provides several common
utility functions and transformer classes to change raw feature vectors
into a representation that is more suitable for the downstream estimators.

In general, learning algorithms benefit from standardization of the data set. If
some outliers are present in the set, robust scalers or transformers are more
appropriate. The behaviors of the different scalers, transformers, and
normalizers on a dataset containing marginal outliers are highlighted in
:ref:`sphx_glr_auto_examples_preprocessing_plot_all_scaling.py`.


.. _preprocessing_scaler:

Standardization, or mean removal and variance scaling
@@ -39,10 +46,10 @@ operation on a single array-like dataset::

>>> from sklearn import preprocessing
>>> import numpy as np
>>> X = np.array([[ 1., -1., 2.],
... [ 2., 0., 0.],
... [ 0., 1., -1.]])
>>> X_scaled = preprocessing.scale(X)
>>> X_train = np.array([[ 1., -1., 2.],
... [ 2., 0., 0.],
... [ 0., 1., -1.]])
>>> X_scaled = preprocessing.scale(X_train)

>>> X_scaled # doctest: +ELLIPSIS
array([[ 0. ..., -1.22..., 1.33...],
@@ -71,7 +78,7 @@ able to later reapply the same transformation on the testing set.
This class is hence suitable for use in the early steps of a
:class:`sklearn.pipeline.Pipeline`::

>>> scaler = preprocessing.StandardScaler().fit(X)
>>> scaler = preprocessing.StandardScaler().fit(X_train)
>>> scaler
StandardScaler(copy=True, with_mean=True, with_std=True)

@@ -81,7 +88,7 @@ This class is hence suitable for use in the early steps of a
>>> scaler.scale_ # doctest: +ELLIPSIS
array([ 0.81..., 0.81..., 1.24...])

>>> scaler.transform(X) # doctest: +ELLIPSIS
>>> scaler.transform(X_train) # doctest: +ELLIPSIS
array([[ 0. ..., -1.22..., 1.33...],
[ 1.22..., 0. ..., -0.26...],
[-1.22..., 1.22..., -1.06...]])
@@ -90,7 +97,8 @@ This class is hence suitable for use in the early steps of a
The scaler instance can then be used on new data to transform it the
same way it did on the training set::

>>> scaler.transform([[-1., 1., 0.]]) # doctest: +ELLIPSIS
>>> X_test = [[-1., 1., 0.]]
>>> scaler.transform(X_test) # doctest: +ELLIPSIS
array([[-2.44..., 1.22..., -0.26...]])

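The arithmetic behind this doctest can be sketched with plain numpy (an illustrative sketch of the statistics ``StandardScaler`` stores as ``mean_`` and ``scale_``, not the library code):

```python
import numpy as np

X_train = np.array([[1., -1., 2.],
                    [2., 0., 0.],
                    [0., 1., -1.]])

# Per-feature statistics estimated on the training set
mean = X_train.mean(axis=0)
scale = X_train.std(axis=0)

# Applying the stored statistics to new data mirrors scaler.transform(X_test)
X_test = np.array([[-1., 1., 0.]])
X_test_scaled = (X_test - mean) / scale
```

The result matches the ``array([[-2.44..., 1.22..., -0.26...]])`` output shown above.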
It is possible to disable either centering or scaling by either
@@ -248,6 +256,66 @@ a :class:`KernelCenterer` can transform the kernel matrix
so that it contains inner products in the feature space
defined by :math:`\phi` followed by removal of the mean in that space.

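The double-centering this paragraph describes can be sketched in numpy (an illustrative sketch, not the scikit-learn implementation):

```python
import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(5, 3)
K = X @ X.T  # linear kernel, so phi is the identity map

# Double centering: K_c = K - 1_n K - K 1_n + 1_n K 1_n,
# where 1_n is the n x n matrix with all entries 1/n
n = K.shape[0]
one_n = np.full((n, n), 1.0 / n)
K_centered = K - one_n @ K - K @ one_n + one_n @ K @ one_n

# For the linear kernel this equals the kernel of mean-centered data
Xc = X - X.mean(axis=0)
```

For the linear kernel, ``K_centered`` coincides with ``Xc @ Xc.T``, i.e. inner products after removing the mean in feature space.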
.. _preprocessing_transformer:

Non-linear transformation
=========================

Like scalers, :class:`QuantileTransformer` puts each feature into the same
range or distribution. However, by performing a rank transformation, it smooths
out unusual distributions and is less influenced by outliers than scaling
methods. It does, however, distort correlations and distances within and across
features.

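The underlying idea, mapping each value through the empirical quantile function, can be sketched in a few lines of numpy (illustrative only; the actual estimator also handles subsampling, sparse input, and inverse transforms):

```python
import numpy as np

rng = np.random.RandomState(0)
x = rng.lognormal(size=1000)  # a heavily skewed feature

# "Fit": estimate a grid of quantiles from the data
references = np.linspace(0, 1, 100)
quantiles = np.percentile(x, references * 100)

# "Transform": interpolate each value's rank, giving a roughly
# uniform output between 0 and 1
x_uniform = np.interp(x, quantiles, references)
```

Because the output is (approximately) the empirical CDF of the input, outliers can only move to the edges of [0, 1] rather than stretching the whole scale.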
:class:`QuantileTransformer` and :func:`quantile_transform` provide a
non-parametric transformation based on the quantile function to map the data to
a uniform distribution with values between 0 and 1::

>>> from sklearn.datasets import load_iris
>>> from sklearn.model_selection import train_test_split
>>> iris = load_iris()
>>> X, y = iris.data, iris.target
>>> X_train, X_test, y_train, y_test = train_test_split(X, y)
>>> quantile_transformer = preprocessing.QuantileTransformer(
... smoothing_noise=1e-12)
Review comment (Member):

This should just be:

quantile_transformer = preprocessing.QuantileTransformer()

>>> X_train_trans = quantile_transformer.fit_transform(X_train)
>>> X_test_trans = quantile_transformer.transform(X_test)
Review comment (Member):

This doesn't show anything but usage. Can we put some kind of assertion regarding the transformed data here, e.g. report np.percentile(..., [0, 25, 50, 75, 100]) for each of X_train[:, 0], X_train_trans[:, 0], X_test_trans[:, 0]


>>> np.percentile(X_train[:, 0], [0, 25, 50, 75, 100])
... # doctest: +ELLIPSIS, +SKIP
array([...])
>>> np.percentile(X_train_trans[:, 0], [0, 25, 50, 75, 100])
... # doctest: +ELLIPSIS, +SKIP
array([...])
>>> np.percentile(X_test[:, 0], [0, 25, 50, 75, 100])
... # doctest: +ELLIPSIS, +SKIP
array([...])
>>> np.percentile(X_test_trans[:, 0], [0, 25, 50, 75, 100])
... # doctest: +ELLIPSIS, +SKIP
array([...])
Review comment (Member):

I don't see the point of having doctests with "array([...])" as the output. I think we should display the actual percentiles up to 2 digits. We need to fix the random_state=0 value to ensure that the results are reproducible.

Reply (Member, Author):

It was due to some numpy < 1.8 support (it was skipped for the moment)

Review comment (Member):

I get non deterministic results even if I fix the random_state=0 in QuantileTransformer.


It is also possible to map the transformed data to a normal distribution by
setting ``output_distribution='normal'``::

>>> quantile_transformer = preprocessing.QuantileTransformer(
... smoothing_noise=True, output_distribution='normal')
Review comment (Member):

Let's not put smoothing_noise=True on doctest snippets that document another parameter. Here the following would be enough:

>>> quantile_transformer = preprocessing.QuantileTransformer(
...     output_distribution='normal')

>>> X_trans = quantile_transformer.fit_transform(X)
>>> quantile_transformer.quantiles_ # doctest: +ELLIPSIS
array([...])

Review comment (Member):

Say "Thus the median of the input becomes the mean of the output, centered at 0. The normal output is clipped so that the input's maximum and minimum do not become infinite under the transformation."

Thus the median of the input becomes the mean of the output, centered at 0. The
normal output is clipped so that the input's minimum and maximum ---
corresponding to the 1e-7 and 1 - 1e-7 quantiles respectively --- do not
become infinite under the transformation.

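The clipping step can be illustrated with a small numpy/scipy sketch (illustrative only, not the library implementation):

```python
import numpy as np
from scipy import stats

rng = np.random.RandomState(0)
x = rng.lognormal(size=1000)

# Map each value to its empirical quantile in [0, 1]
references = np.linspace(0, 1, 100)
u = np.interp(x, np.percentile(x, references * 100), references)

# Clip away from 0 and 1 before applying the inverse normal CDF,
# so the minimum and maximum map to finite values rather than -inf/+inf
u = np.clip(u, 1e-7, 1 - 1e-7)
x_normal = stats.norm.ppf(u)
```

Without the ``np.clip`` call, ``stats.norm.ppf`` would return ``-inf`` for the input's minimum and ``+inf`` for its maximum.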
:class:`QuantileTransformer` provides a ``smoothing_noise`` parameter to
Review comment (Member):

provides a smoothing_noise parameter (set to True by default) to ....

make the interpretation more intuitive when inspecting the
transformation. This is particularly useful when feature values are
replicated identically many times in the training set (e.g. prices, ordinal
values such as user ratings, coarse-grained units of time, etc.). See
:ref:`sphx_glr_auto_examples_preprocessing_plot_smoothing_noise_quantile_transform.py`
for more details.

.. _preprocessing_normalization:

Normalization
13 changes: 10 additions & 3 deletions doc/whats_new.rst
@@ -57,6 +57,13 @@ New features
during the first epochs of ridge and logistic regression.
By `Arthur Mensch`_.

- Added :class:`preprocessing.QuantileTransformer` class and
:func:`preprocessing.quantile_transform` function for feature
normalization based on quantiles.
:issue:`8363` by :user:`Denis Engemann <dengemann>`,
:user:`Guillaume Lemaitre <glemaitre>`, `Olivier Grisel`_, `Raghav RV`_,
and :user:`Thierry Guillemot <tguillemot>`.

Enhancements
............

@@ -161,7 +168,7 @@ Enhancements
- Add ``sample_weight`` parameter to :func:`metrics.cohen_kappa_score` by
Victor Poughon.

- In :class:`gaussian_process.GaussianProcessRegressor`, method ``predict``
- In :class:`gaussian_process.GaussianProcessRegressor`, method ``predict``
is a lot faster with ``return_std=True`` by :user:`Hadrien Bertrand <hbertrand>`.

Bug fixes
@@ -254,7 +261,7 @@ Bug fixes
multiple inheritance context.
:issue:`8316` by :user:`Holger Peters <HolgerPeters>`.

- Fix :func:`sklearn.linear_model.BayesianRidge.fit` to return
- Fix :func:`sklearn.linear_model.BayesianRidge.fit` to return
ridge parameter `alpha_` and `lambda_` consistent with calculated
coefficients `coef_` and `intercept_`.
:issue:`8224` by :user:`Peter Gedeck <gedeck>`.
@@ -5059,4 +5066,4 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
.. _Anish Shah: https://github.com/AnishShah

.. _Neeraj Gangwar: http://neerajgangwar.in
.. _Arthur Mensch: https://amensch.fr
.. _Arthur Mensch: https://amensch.fr