[MRG] Dynamically set n_quantiles to min(n_samples, n_quantiles) in QuantileTransformer #13333

Merged
merged 9 commits into from Mar 1, 2019
12 changes: 6 additions & 6 deletions doc/modules/preprocessing.rst
@@ -387,13 +387,13 @@ Using the earlier example with the iris dataset::
... output_distribution='normal', random_state=0)
>>> X_trans = quantile_transformer.fit_transform(X)
>>> quantile_transformer.quantiles_ # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
array([[4.3..., 2..., 1..., 0.1...],
[4.31..., 2.02..., 1.01..., 0.1...],
[4.32..., 2.05..., 1.02..., 0.1...],
array([[4.3, 2. , 1. , 0.1],
[4.4, 2.2, 1.1, 0.1],
[4.4, 2.2, 1.2, 0.1],
...,
[7.84..., 4.34..., 6.84..., 2.5...],
[7.87..., 4.37..., 6.87..., 2.5...],
[7.9..., 4.4..., 6.9..., 2.5...]])
[7.7, 4.1, 6.7, 2.5],
[7.7, 4.2, 6.7, 2.5],
[7.9, 4.4, 6.9, 2.5]])

Thus the median of the input becomes the mean of the output, centered at 0. The
normal output is clipped so that the input's minimum and maximum ---
6 changes: 6 additions & 0 deletions doc/whats_new/v0.21.rst
@@ -376,6 +376,12 @@ Support for Python 3.4 and below has been officially dropped.
:class:`preprocessing.StandardScaler`. :issue:`13007` by
:user:`Raffaello Baluyot <baluyotraf>`

- |Fix| Fixed a bug in :class:`preprocessing.QuantileTransformer` and
:func:`preprocessing.quantile_transform` to force n_quantiles to be at most
equal to n_samples. Values of n_quantiles larger than n_samples were either
useless or resulted in a wrong approximation of the cumulative distribution
function estimator. :issue:`13333` by :user:`Albert Thomas <albertcthomas>`.
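
The "useless" case in this entry can be illustrated without scikit-learn. The following is a minimal sketch, assuming quantiles are computed by linear interpolation between sorted samples; the helpers `linspace` and `quantile` are hypothetical stand-ins written for this example, not the library's code. With as many landmarks as samples, the landmarks are exactly the sorted samples; extra landmarks only interpolate between neighbours and add no information.

```python
def linspace(start, stop, num):
    """Evenly spaced values over [start, stop], endpoints included."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def quantile(sorted_x, p):
    """Quantile at level p in [0, 1], linear interpolation (assumed)."""
    pos = p * (len(sorted_x) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(sorted_x) - 1)
    frac = pos - lo
    return sorted_x[lo] + frac * (sorted_x[hi] - sorted_x[lo])


samples = sorted([2.0, 7.0, 3.0, 11.0, 5.0])
n = len(samples)

# n_quantiles == n_samples: the landmarks are the sorted samples themselves.
exact = [quantile(samples, p) for p in linspace(0.0, 1.0, n)]

# n_quantiles > n_samples: every second landmark is a sorted sample, and the
# landmarks in between are midpoints of adjacent sorted samples.
dense = [quantile(samples, p) for p in linspace(0.0, 1.0, 2 * n - 1)]

print(exact)
print(dense)
```

Because a piecewise-linear curve through `dense` coincides with the one through `exact`, the denser grid costs more memory without refining the estimate under this assumption.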

:mod:`sklearn.svm`
..................

31 changes: 26 additions & 5 deletions sklearn/preprocessing/data.py
@@ -424,7 +424,7 @@ def minmax_scale(X, feature_range=(0, 1), axis=0, copy=True):
X_scaled = X_std * (max - min) + min

where min, max = feature_range.

The transformation is calculated as (when ``axis=0``)::

X_scaled = scale * X + min - X.min(axis=0) * scale
@@ -592,7 +592,7 @@ class StandardScaler(BaseEstimator, TransformerMixin):
-----
NaNs are treated as missing values: disregarded in fit, and maintained in
transform.

We use a biased estimator for the standard deviation, equivalent to
`numpy.std(x, ddof=0)`. Note that the choice of `ddof` is unlikely to
affect model performance.
@@ -2041,9 +2041,13 @@ class QuantileTransformer(BaseEstimator, TransformerMixin):

Parameters
----------
n_quantiles : int, optional (default=1000)
n_quantiles : int, optional (default=1000 or n_samples)
Number of quantiles to be computed. It corresponds to the number
of landmarks used to discretize the cumulative distribution function.
If n_quantiles is larger than the number of samples, n_quantiles is set
to the number of samples as a larger number of quantiles does not give
a better approximation of the cumulative distribution function
estimator.

output_distribution : str, optional (default='uniform')
Marginal distribution for the transformed data. The choices are
@@ -2072,6 +2076,10 @@ class QuantileTransformer(BaseEstimator, TransformerMixin):

Attributes
----------
n_quantiles_ : integer
The actual number of quantiles used to discretize the cumulative
distribution function.

quantiles_ : ndarray, shape (n_quantiles, n_features)
The values corresponding to the quantiles of reference.

@@ -2218,10 +2226,19 @@ def fit(self, X, y=None):
self.subsample))

X = self._check_inputs(X)
n_samples = X.shape[0]

if self.n_quantiles > n_samples:
warnings.warn("n_quantiles (%s) is greater than the total number "
Review comment (Member): I think this is more verbose than is helpful. Something like "n_quantiles (%d) is being reduced to the number of samples (%d)" would do. If you want to add another phrase explaining why, okay, but I think the shorter message is intuitive enough.

"of samples (%s). n_quantiles is set to "
"n_samples."
% (self.n_quantiles, n_samples))
self.n_quantiles_ = max(1, min(self.n_quantiles, n_samples))

rng = check_random_state(self.random_state)

# Create the quantiles of reference
self.references_ = np.linspace(0, 1, self.n_quantiles,
self.references_ = np.linspace(0, 1, self.n_quantiles_,
endpoint=True)
if sparse.issparse(X):
self._sparse_fit(X, rng)
@@ -2443,9 +2460,13 @@ def quantile_transform(X, axis=0, n_quantiles=1000,
Axis used to compute the means and standard deviations along. If 0,
transform each feature, otherwise (if 1) transform each sample.

n_quantiles : int, optional (default=1000)
n_quantiles : int, optional (default=1000 or n_samples)
Number of quantiles to be computed. It corresponds to the number
of landmarks used to discretize the cumulative distribution function.
If n_quantiles is larger than the number of samples, n_quantiles is set
to the number of samples as a larger number of quantiles does not give
a better approximation of the cumulative distribution function
estimator.

output_distribution : str, optional (default='uniform')
Marginal distribution for the transformed data. The choices are
7 changes: 7 additions & 0 deletions sklearn/preprocessing/tests/test_data.py
@@ -1260,6 +1260,13 @@ def test_quantile_transform_check_error():
assert_raise_message(ValueError,
'Expected 2D array, got scalar array instead',
transformer.transform, 10)
# check that a warning is raised if n_quantiles > n_samples
transformer = QuantileTransformer(n_quantiles=100)
warn_msg = "n_quantiles is set to n_samples"
with pytest.warns(UserWarning, match=warn_msg) as record:
transformer.fit(X)
assert len(record) == 1
assert transformer.n_quantiles_ == X.shape[0]
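
The check added above can be reproduced with the standard library alone. This is a minimal sketch, assuming only the clamping expression and warning text shown in the `fit` hunk of this PR; the function name `effective_n_quantiles` is hypothetical and exists only for this illustration, and `warnings.catch_warnings` stands in for `pytest.warns`.

```python
import warnings


def effective_n_quantiles(n_quantiles, n_samples):
    """Hypothetical helper mirroring the clamping added to ``fit``."""
    if n_quantiles > n_samples:
        warnings.warn(
            "n_quantiles (%s) is greater than the total number of samples "
            "(%s). n_quantiles is set to n_samples."
            % (n_quantiles, n_samples)
        )
    # Never more landmarks than samples, and never fewer than one.
    return max(1, min(n_quantiles, n_samples))


# Request 100 quantiles for a dataset with only 10 samples.
with warnings.catch_warnings(record=True) as record:
    warnings.simplefilter("always")  # make sure the warning is not filtered
    n_used = effective_n_quantiles(100, 10)

print(n_used, len(record))
```

When `n_quantiles` does not exceed `n_samples`, the helper returns the requested value unchanged and emits no warning.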


def test_quantile_transform_sparse_ignore_zeros():