Skip to content

Commit

Permalink
ENH Add sparse_threshold keyword to ColumnTransformer (#11614)
Browse files Browse the repository at this point in the history
Reasoning: when, e.g., OneHotEncoder is used as part of a ColumnTransformer, it causes the final result to be a sparse matrix. As this is a typical case when we have mixed-dtype data, it means that many pipelines will have to deal with sparse data implicitly, even if there are only a few categorical features with low cardinality.
The idea was first to change the default of `OneHotEncoder`'s `sparse` parameter to False, but based on a gitter discussion (https://gitter.im/scikit-learn/dev?at=5b4e5a69a94c5255523bc9fc) we decided to let ColumnTransformer switch between both based on a threshold. The user still has full control if they want the output to be always or never sparse.
  • Loading branch information
jorisvandenbossche authored and jnothman committed Jul 25, 2018
1 parent c2b7478 commit cf897de
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 27 deletions.
10 changes: 6 additions & 4 deletions doc/modules/compose.rst
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,8 @@ By default, the remaining rating columns are ignored (``remainder='drop'``)::
... remainder='drop')

>>> column_trans.fit(X) # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
ColumnTransformer(n_jobs=1, remainder='drop', transformer_weights=None,
ColumnTransformer(n_jobs=1, remainder='drop', sparse_threshold=0.3,
transformer_weights=None,
transformers=...)

>>> column_trans.get_feature_names()
Expand Down Expand Up @@ -461,7 +462,7 @@ transformation::
... ('title_bow', CountVectorizer(), 'title')],
... remainder='passthrough')

>>> column_trans.fit_transform(X).toarray()
>>> column_trans.fit_transform(X)
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
array([[1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 5, 4],
[1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 3, 5],
Expand All @@ -478,7 +479,7 @@ the transformation::
... ('title_bow', CountVectorizer(), 'title')],
... remainder=MinMaxScaler())

>>> column_trans.fit_transform(X)[:, -2:].toarray()
>>> column_trans.fit_transform(X)[:, -2:]
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
array([[1. , 0.5],
[0. , 1. ],
Expand All @@ -495,7 +496,8 @@ above example would be::
... ('city', CountVectorizer(analyzer=lambda x: [x])),
... ('title', CountVectorizer()))
>>> column_trans # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
ColumnTransformer(n_jobs=1, remainder='drop', transformer_weights=None,
ColumnTransformer(n_jobs=1, remainder='drop', sparse_threshold=0.3,
transformer_weights=None,
transformers=[('countvectorizer-1', ...)

.. topic:: Examples:
Expand Down
54 changes: 41 additions & 13 deletions sklearn/compose/_column_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Author: Andreas Mueller
# Joris Van den Bossche
# License: BSD
from __future__ import division

from itertools import chain

import numpy as np
Expand Down Expand Up @@ -83,6 +85,14 @@ class ColumnTransformer(_BaseComposition, TransformerMixin):
non-specified columns will use the ``remainder`` estimator. The
estimator must support `fit` and `transform`.
sparse_threshold : float, default = 0.3
If the transformed output consists of a mix of sparse and dense data,
it will be stacked as a sparse matrix if the density is lower than this
value. Use ``sparse_threshold=0`` to always return dense.
When the transformed output consists of all sparse or all dense data,
the stacked result will be sparse or dense, respectively, and this
keyword will be ignored.
n_jobs : int, optional
Number of jobs to run in parallel (default 1).
Expand All @@ -108,6 +118,11 @@ class ColumnTransformer(_BaseComposition, TransformerMixin):
Keys are transformer names and values are the fitted transformer
objects.
sparse_output_ : boolean
Boolean flag indicating whether the output of ``transform`` is a
sparse matrix or a dense numpy array, which depends on the output
of the individual transformers and the `sparse_threshold` keyword.
Notes
-----
The order of the columns in the transformed feature matrix follows the
Expand Down Expand Up @@ -141,10 +156,11 @@ class ColumnTransformer(_BaseComposition, TransformerMixin):
"""

def __init__(self, transformers, remainder='drop', n_jobs=1,
transformer_weights=None):
def __init__(self, transformers, remainder='drop', sparse_threshold=0.3,
n_jobs=1, transformer_weights=None):
self.transformers = transformers
self.remainder = remainder
self.sparse_threshold = sparse_threshold
self.n_jobs = n_jobs
self.transformer_weights = transformer_weights

Expand Down Expand Up @@ -374,12 +390,9 @@ def fit(self, X, y=None):
This estimator
"""
self._validate_remainder(X)
self._validate_transformers()

transformers = self._fit_transform(X, y, _fit_one_transformer)
self._update_fitted_transformers(transformers)

# we use fit_transform to make sure to set sparse_output_ (for which we
# need the transformed data) to have consistent output type in predict
self.fit_transform(X, y=y)
return self

def fit_transform(self, X, y=None):
Expand Down Expand Up @@ -409,15 +422,28 @@ def fit_transform(self, X, y=None):
result = self._fit_transform(X, y, _fit_transform_one)

if not result:
self._update_fitted_transformers([])
# All transformers are None
return np.zeros((X.shape[0], 0))

Xs, transformers = zip(*result)

# determine if concatenated output will be sparse or not
if all(sparse.issparse(X) for X in Xs):
self.sparse_output_ = True
elif any(sparse.issparse(X) for X in Xs):
nnz = sum(X.nnz if sparse.issparse(X) else X.size for X in Xs)
total = sum(X.shape[0] * X.shape[1] if sparse.issparse(X)
else X.size for X in Xs)
density = nnz / total
self.sparse_output_ = density < self.sparse_threshold
else:
self.sparse_output_ = False

self._update_fitted_transformers(transformers)
self._validate_output(Xs)

return _hstack(list(Xs))
return _hstack(list(Xs), self.sparse_output_)

def transform(self, X):
"""Transform X separately by each transformer, concatenate results.
Expand Down Expand Up @@ -445,7 +471,7 @@ def transform(self, X):
# All transformers are None
return np.zeros((X.shape[0], 0))

return _hstack(list(Xs))
return _hstack(list(Xs), self.sparse_output_)


def _check_key_type(key, superclass):
Expand Down Expand Up @@ -479,16 +505,17 @@ def _check_key_type(key, superclass):
return False


def _hstack(X):
def _hstack(X, sparse_):
"""
Stacks X horizontally.
Supports input types (X): list of
numpy arrays, sparse arrays and DataFrames
"""
if any(sparse.issparse(f) for f in X):
if sparse_:
return sparse.hstack(X).tocsr()
else:
X = [f.toarray() if sparse.issparse(f) else f for f in X]
return np.hstack(X)


Expand Down Expand Up @@ -658,7 +685,8 @@ def make_column_transformer(*transformers, **kwargs):
... (['numerical_column'], StandardScaler()),
... (['categorical_column'], OneHotEncoder()))
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
ColumnTransformer(n_jobs=1, remainder='drop', transformer_weights=None,
ColumnTransformer(n_jobs=1, remainder='drop', sparse_threshold=0.3,
transformer_weights=None,
transformers=[('standardscaler',
StandardScaler(...),
['numerical_column']),
Expand Down
80 changes: 70 additions & 10 deletions sklearn/compose/tests/test_column_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from sklearn.externals import six
from sklearn.compose import ColumnTransformer, make_column_transformer
from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import StandardScaler, Normalizer
from sklearn.preprocessing import StandardScaler, Normalizer, OneHotEncoder
from sklearn.feature_extraction import DictVectorizer


Expand Down Expand Up @@ -262,14 +262,16 @@ def test_column_transformer_sparse_array():
for remainder, res in [('drop', X_res_first),
('passthrough', X_res_both)]:
ct = ColumnTransformer([('trans', Trans(), col)],
remainder=remainder)
remainder=remainder,
sparse_threshold=0.8)
assert_true(sparse.issparse(ct.fit_transform(X_sparse)))
assert_allclose_dense_sparse(ct.fit_transform(X_sparse), res)
assert_allclose_dense_sparse(ct.fit(X_sparse).transform(X_sparse),
res)

for col in [[0, 1], slice(0, 2)]:
ct = ColumnTransformer([('trans', Trans(), col)])
ct = ColumnTransformer([('trans', Trans(), col)],
sparse_threshold=0.8)
assert_true(sparse.issparse(ct.fit_transform(X_sparse)))
assert_allclose_dense_sparse(ct.fit_transform(X_sparse), X_res_both)
assert_allclose_dense_sparse(ct.fit(X_sparse).transform(X_sparse),
Expand All @@ -279,7 +281,8 @@ def test_column_transformer_sparse_array():
def test_column_transformer_sparse_stacking():
X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
col_trans = ColumnTransformer([('trans1', Trans(), [0]),
('trans2', SparseMatrixTrans(), 1)])
('trans2', SparseMatrixTrans(), 1)],
sparse_threshold=0.8)
col_trans.fit(X_array)
X_trans = col_trans.transform(X_array)
assert_true(sparse.issparse(X_trans))
Expand All @@ -288,6 +291,57 @@ def test_column_transformer_sparse_stacking():
assert len(col_trans.transformers_) == 2
assert col_trans.transformers_[-1][0] != 'remainder'

col_trans = ColumnTransformer([('trans1', Trans(), [0]),
('trans2', SparseMatrixTrans(), 1)],
sparse_threshold=0.1)
col_trans.fit(X_array)
X_trans = col_trans.transform(X_array)
assert not sparse.issparse(X_trans)
assert X_trans.shape == (X_trans.shape[0], X_trans.shape[0] + 1)
assert_array_equal(X_trans[:, 1:], np.eye(X_trans.shape[0]))


def test_column_transformer_sparse_threshold():
X_array = np.array([['a', 'b'], ['A', 'B']], dtype=object).T
# one-hot encoding the above data gives a density of 4 / 8 = 0.5

# if all sparse, keep sparse (even if above threshold)
col_trans = ColumnTransformer([('trans1', OneHotEncoder(), [0]),
('trans2', OneHotEncoder(), [1])],
sparse_threshold=0.2)
res = col_trans.fit_transform(X_array)
assert sparse.issparse(res)
assert col_trans.sparse_output_

# mixed -> density of (4 + 2) / 8 = 0.75 (dense blocks count as fully dense)
for thres in [0.75001, 1]:
col_trans = ColumnTransformer(
[('trans1', OneHotEncoder(sparse=True), [0]),
('trans2', OneHotEncoder(sparse=False), [1])],
sparse_threshold=thres)
res = col_trans.fit_transform(X_array)
assert sparse.issparse(res)
assert col_trans.sparse_output_

for thres in [0.75, 0]:
col_trans = ColumnTransformer(
[('trans1', OneHotEncoder(sparse=True), [0]),
('trans2', OneHotEncoder(sparse=False), [1])],
sparse_threshold=thres)
res = col_trans.fit_transform(X_array)
assert not sparse.issparse(res)
assert not col_trans.sparse_output_

# if nothing is sparse -> no sparse
for thres in [0.33, 0, 1]:
col_trans = ColumnTransformer(
[('trans1', OneHotEncoder(sparse=False), [0]),
('trans2', OneHotEncoder(sparse=False), [1])],
sparse_threshold=thres)
res = col_trans.fit_transform(X_array)
assert not sparse.issparse(res)
assert not col_trans.sparse_output_


def test_column_transformer_error_msg_1D():
X_array = np.array([[0., 1., 2.], [2., 4., 6.]]).T
Expand All @@ -311,9 +365,9 @@ def test_2D_transformer_output():
('trans2', TransNo2D(), 1)])
assert_raise_message(ValueError, "the 'trans2' transformer should be 2D",
ct.fit_transform, X_array)
ct.fit(X_array)
# because fit is also doing transform, this raises already on fit
assert_raise_message(ValueError, "the 'trans2' transformer should be 2D",
ct.transform, X_array)
ct.fit, X_array)


def test_2D_transformer_output_pandas():
Expand All @@ -326,9 +380,9 @@ def test_2D_transformer_output_pandas():
ct = ColumnTransformer([('trans1', TransNo2D(), 'col1')])
assert_raise_message(ValueError, "the 'trans1' transformer should be 2D",
ct.fit_transform, X_df)
ct.fit(X_df)
# because fit is also doing transform, this raises already on fit
assert_raise_message(ValueError, "the 'trans1' transformer should be 2D",
ct.transform, X_df)
ct.fit, X_df)


@pytest.mark.parametrize("remainder", ['drop', 'passthrough'])
Expand Down Expand Up @@ -406,6 +460,7 @@ def test_column_transformer_get_set_params():

exp = {'n_jobs': 1,
'remainder': 'drop',
'sparse_threshold': 0.3,
'trans1': ct.transformers[0][1],
'trans1__copy': True,
'trans1__with_mean': True,
Expand All @@ -425,6 +480,7 @@ def test_column_transformer_get_set_params():
ct.set_params(trans1='passthrough')
exp = {'n_jobs': 1,
'remainder': 'drop',
'sparse_threshold': 0.3,
'trans1': 'passthrough',
'trans2': ct.transformers[1][1],
'trans2__copy': True,
Expand Down Expand Up @@ -708,7 +764,8 @@ def test_column_transformer_sparse_remainder_transformer():
[8, 6, 4]]).T

ct = ColumnTransformer([('trans1', Trans(), [0])],
remainder=SparseMatrixTrans())
remainder=SparseMatrixTrans(),
sparse_threshold=0.8)

X_trans = ct.fit_transform(X_array)
assert sparse.issparse(X_trans)
Expand All @@ -730,7 +787,8 @@ def test_column_transformer_drop_all_sparse_remainder_transformer():
[2, 4, 6],
[8, 6, 4]]).T
ct = ColumnTransformer([('trans1', 'drop', [0])],
remainder=SparseMatrixTrans())
remainder=SparseMatrixTrans(),
sparse_threshold=0.8)

X_trans = ct.fit_transform(X_array)
assert sparse.issparse(X_trans)
Expand All @@ -753,6 +811,7 @@ def test_column_transformer_get_set_params_with_remainder():
'remainder__copy': True,
'remainder__with_mean': True,
'remainder__with_std': True,
'sparse_threshold': 0.3,
'trans1': ct.transformers[0][1],
'trans1__copy': True,
'trans1__with_mean': True,
Expand All @@ -771,6 +830,7 @@ def test_column_transformer_get_set_params_with_remainder():
'remainder__copy': True,
'remainder__with_mean': True,
'remainder__with_std': False,
'sparse_threshold': 0.3,
'trans1': 'passthrough',
'transformers': ct.transformers,
'transformer_weights': None}
Expand Down

0 comments on commit cf897de

Please sign in to comment.