Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH ColumnTransformer.get_feature_names() handles passthrough #14048

Merged
merged 16 commits into from Apr 19, 2020
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 16 additions & 11 deletions sklearn/compose/_column_transformer.py
Expand Up @@ -313,19 +313,18 @@ def _validate_remainder(self, X):
self.remainder)

# Make it possible to check for reordered named columns on transform
if (hasattr(X, 'columns') and
any(_determine_key_type(cols) == 'str'
for cols in self._columns)):
self._has_str_cols = any(_determine_key_type(cols) == 'str'
for cols in self._columns)
if hasattr(X, 'columns'):
self._df_columns = X.columns

self._n_features = X.shape[1]
cols = []
for columns in self._columns:
cols.extend(_get_column_indices(X, columns))
remaining_idx = list(set(range(self._n_features)) - set(cols))
remaining_idx = sorted(remaining_idx) or None

self._remainder = ('remainder', self.remainder, remaining_idx)
remaining_idx = sorted(list(set(range(self._n_features)) - set(cols)))
lrjball marked this conversation as resolved.
Show resolved Hide resolved
self._remainder = ('remainder', self.remainder, remaining_idx or None)

@property
def named_transformers_(self):
Expand Down Expand Up @@ -354,11 +353,15 @@ def get_feature_names(self):
if trans == 'drop' or (
hasattr(column, '__len__') and not len(column)):
continue
elif trans == 'passthrough':
raise NotImplementedError(
"get_feature_names is not yet supported when using "
"a 'passthrough' transformer.")
elif not hasattr(trans, 'get_feature_names'):
if trans == 'passthrough':
if hasattr(self, '_df_columns'):
feature_names.extend([self._df_columns[c]
if isinstance(c, int)
else c for c in column])
else:
feature_names.extend(['x%d' % i for i in column])
continue
if not hasattr(trans, 'get_feature_names'):
raise AttributeError("Transformer %s (type %s) does not "
"provide get_feature_names."
% (str(name), type(trans).__name__))
Expand Down Expand Up @@ -540,6 +543,7 @@ def fit_transform(self, X, y=None):

self._update_fitted_transformers(transformers)
self._validate_output(Xs)
self._output_dims = [X.shape[1] for X in Xs]
lrjball marked this conversation as resolved.
Show resolved Hide resolved

return self._hstack(list(Xs))

Expand Down Expand Up @@ -580,6 +584,7 @@ def transform(self, X):
# name order and count. See #14237 for details.
if (self._remainder[2] is not None and
hasattr(self, '_df_columns') and
self._has_str_cols and
hasattr(X, 'columns')):
n_cols_fit = len(self._df_columns)
n_cols_transform = len(X.columns)
Expand Down
51 changes: 40 additions & 11 deletions sklearn/compose/tests/test_column_transformer.py
Expand Up @@ -668,25 +668,54 @@ def test_column_transformer_get_feature_names():
ct.fit(X)
assert ct.get_feature_names() == ['col0__a', 'col0__b', 'col1__c']

# passthrough transformers not supported
# drop transformer
ct = ColumnTransformer(
[('col0', DictVectorizer(), 0), ('col1', 'drop', 1)])
ct.fit(X)
assert ct.get_feature_names() == ['col0__a', 'col0__b']

# passthrough transformer
ct = ColumnTransformer([('trans', 'passthrough', [0, 1])])
ct.fit(X)
assert_raise_message(
NotImplementedError, 'get_feature_names is not yet supported',
ct.get_feature_names)
assert ct.get_feature_names() == ['x0', 'x1']

ct = ColumnTransformer([('trans', DictVectorizer(), 0)],
remainder='passthrough')
ct.fit(X)
assert_raise_message(
NotImplementedError, 'get_feature_names is not yet supported',
ct.get_feature_names)
assert ct.get_feature_names() == ['trans__a', 'trans__b', 'x1']

# drop transformer
ct = ColumnTransformer(
[('col0', DictVectorizer(), 0), ('col1', 'drop', 1)])
ct = ColumnTransformer([('trans', 'passthrough', [1])],
remainder='passthrough')
ct.fit(X)
assert ct.get_feature_names() == ['col0__a', 'col0__b']

assert ct.get_feature_names() == ['x1', 'x0']

# passthough transformer with a dataframe
pd = pytest.importorskip('pandas')
lrjball marked this conversation as resolved.
Show resolved Hide resolved
X_df = pd.DataFrame(X, columns=['col0', 'col1'])

ct = ColumnTransformer([('trans', 'passthrough', ['col0', 'col1'])])
ct.fit(X_df)
assert ct.get_feature_names() == ['col0', 'col1']

ct = ColumnTransformer([('trans', 'passthrough', [0, 1])])
ct.fit(X_df)
assert ct.get_feature_names() == ['col0', 'col1']

ct = ColumnTransformer([('col0', DictVectorizer(), 0)],
remainder='passthrough')
ct.fit(X_df)
assert ct.get_feature_names() == ['col0__a', 'col0__b', 'col1']

ct = ColumnTransformer([('trans', 'passthrough', ['col1'])],
lrjball marked this conversation as resolved.
Show resolved Hide resolved
remainder='passthrough')
ct.fit(X_df)
assert ct.get_feature_names() == ['col1', 'col0']

ct = ColumnTransformer([('trans', 'passthrough', [1])],
remainder='passthrough')
ct.fit(X_df)
assert ct.get_feature_names() == ['col1', 'col0']


def test_column_transformer_special_strings():
Expand Down