
[MRG] Add experimental.ColumnTransformer #9012

Merged
merged 90 commits into scikit-learn:master from jorisvandenbossche:amueller/heterogeneous_feature_union May 29, 2018

Conversation

@jorisvandenbossche
Contributor

jorisvandenbossche commented Jun 6, 2017

Continuation of @amueller's PR #3886 (for now just rebased and updated for changes in sklearn)

Fixes #2034.

Closes #2034, closes #3886, closes #8540, closes #8539

@amueller


Member

amueller commented Jun 6, 2017

feel free to squash my commits

('tfidf', TfidfVectorizer()),
('best', TruncatedSVD(n_components=50)),
])),
]), 'body'),


@vene

vene Jun 6, 2017

Member

Are we sold on the tuple-based API here? I'd like it if this were a bit more explicit... (I'd like it to say column_name='body' somehow)


@vene

vene Jun 6, 2017

Member

While we're at it, why is the outer structure a dict and not an ordered list of tuples like FeatureUnion?

Often it is easiest to preprocess data before applying scikit-learn methods, for example using
pandas.
If the preprocessing has parameters that you want to adjust within a
grid-search, however, they need to be inside a transformer. This can be


@vene

vene Jun 6, 2017

Member

they=?

I'd remove the whole sentence and just say ":class:`ColumnTransformer` is a convenient way to perform heterogeneous preprocessing on data columns within a pipeline."

.. note::
:class:`ColumnTransformer` expects a very different data format from the numpy arrays usually used in scikit-learn.
For a numpy array ``X_array``, ``X_array[1]`` will give a single sample (``X_array[1].shape == (n_samples.)``), but all features.


@vene

vene Jun 6, 2017

Member

Do you mean "will give all feature values for the selected sample", e.g. ``X_array[1].shape == (n_features,)``?


@jorisvandenbossche

jorisvandenbossche Jun 6, 2017

Contributor

Yeah, I didn't update the rst docs yet, and I also saw there are still some errors like these


@vene

vene Jun 6, 2017

Member

sure, don't take my comment personally, just making notes

:class:`ColumnTransformer` expects a very different data format from the numpy arrays usually used in scikit-learn.
For a numpy array ``X_array``, ``X_array[1]`` will give a single sample (``X_array[1].shape == (n_samples.)``), but all features.
For columnar data like a dict or pandas dataframe ``X_columns``, ``X_columns[1]`` is expected to give a feature called
``1`` for each sample (``X_columns[1].shape == (n_samples,)``).


@vene

vene Jun 6, 2017

Member

Are we supporting integer column labels?

.. note::
:class:`ColumnTransformer` expects a very different data format from the numpy arrays usually used in scikit-learn.
For a numpy array ``X_array``, ``X_array[1]`` will give a single sample (``X_array[1].shape == (n_samples.)``), but all features.
For columnar data like a dict or pandas dataframe ``X_columns``, ``X_columns[1]`` is expected to give a feature called


@vene

vene Jun 6, 2017

Member

-> a pandas DataFrame

Also this chapter should probably have actual links to the pandas website or something, for readers who might have no idea what we're talking about.

@jorisvandenbossche


Contributor

jorisvandenbossche commented Jun 6, 2017

So the current way to specify a transformer is like this:

ColumnTransformer({"name": (Transformer(), column), ..})

(where 'name' is the transformer name, and column is the column on which to apply the transformer).

There was some discussion and back-and-forth about this in the original PR, and other options mentioned are (as far as I read it correctly):

ColumnTransformer([("name", Transformer(), column), ..]) # more similar to Pipeline interface

or

ColumnTransformer([('column', Transformer()), ..])   # in this case transformer name / column name has to be identical

BTW, when using dicts, I would actually find this interface more logical:

ColumnTransformer({"column": ("name", Transformer()), ..})

This switches the places of the column and the transformer name, giving a (name, trans) tuple similar to the Pipeline interface, and uses the dict key for selection (mimicking getitem, which is also how the values are selected from the input data).
But it has the disadvantage that we cannot extend it to multiple columns with lists (since lists cannot be dict keys).
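
For concreteness, here is a minimal sketch of the list-of-tuples variant, the form that was eventually merged (this assumes the final sklearn.compose location; the column names are illustrative):

from sklearn.compose import ColumnTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import StandardScaler

# Each entry is (name, transformer, column(s)); the explicit name makes the
# transformer addressable in grid search, e.g. ct.set_params(tfidf__max_df=0.9).
ct = ColumnTransformer([
    ('tfidf', TfidfVectorizer(), 'body'),         # a single (text) column
    ('scale', StandardScaler(), ['age', 'fee']),  # a list of columns
])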

@amueller


Member

amueller commented Jun 6, 2017

ColumnTransformer({"column": ("name", Transformer()), ..})

The column is a numpy array, right, so it's not hashable.

I think we could use either the list or dict variant here, and have a helper make_column_transformer or something that does
make_column_transformer({Transformer(): column})
Transformer is hashable, so that works, and we can generate the name from the class name as in make_pipeline.
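
For reference, the helper that eventually shipped as sklearn.compose.make_column_transformer uses positional tuples rather than a dict, but does generate names from the class names as in make_pipeline; a minimal sketch (argument order as in current scikit-learn, where each tuple is (transformer, columns)):

from sklearn.compose import make_column_transformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Names are derived from the class names, e.g. 'standardscaler', 'onehotencoder'.
ct = make_column_transformer(
    (StandardScaler(), ['age', 'income']),
    (OneHotEncoder(), ['city']),
)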

@amueller


Member

amueller commented Jun 6, 2017

Oh I didn't fully read your comment. I think multiple columns are essential, so we can't do that...

@vene


Member

vene commented Jun 7, 2017

Maybe I'm a little bit confused, but then does this overlap in scope with FeatureUnion?

If there is more than one transformer, we need to know what order to use for column-stacking their output, right? So if we use a dict with transformers as keys, can we guarantee a consistent order?

@vene


Member

vene commented Jun 7, 2017

Let's try to discuss this with @amueller before proceeding. I think I'll help on this today.

jorisvandenbossche added some commits Jun 7, 2017

transformations to each field of the data, producing a homogeneous feature
matrix from a heterogeneous data source.
The transformers are applied in parallel, and the feature matrices they output
are concatenated side-by-side into a larger matrix.


@vene

vene Jun 7, 2017

Member

Since this PR adds ColumnTransformer, we can say here something like: for data organized in fields with heterogeneous types, see the related :class:`ColumnTransformer`.
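
To make the difference in scope concrete, a small sketch (assuming the merged sklearn.compose location; the column indices are illustrative):

from sklearn.pipeline import FeatureUnion
from sklearn.compose import ColumnTransformer
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# FeatureUnion: every transformer sees the full X; outputs are hstacked.
union = FeatureUnion([('pca', PCA(n_components=2)),
                      ('scale', StandardScaler())])

# ColumnTransformer: each transformer sees only its own column subset;
# outputs are likewise hstacked.
ct = ColumnTransformer([('pca', PCA(n_components=2), [0, 1, 2]),
                        ('scale', StandardScaler(), [3, 4])])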

@jnothman jnothman referenced this pull request Jun 7, 2017

Closed

[MRG+1] Stacking classifier with pipelines API #8960

7 of 7 tasks complete
@jorisvandenbossche


Contributor

jorisvandenbossche commented Jun 7, 2017

Additional problem that we encountered:

In the meantime (since the original PR was made), transformers need 2D X values. Therefore, I made sure in the ColumnTransformer that the subset of columns is always passed to the transformer as a 2D object, even when the transformer is applied to only a single column.

But by ensuring this, the example using a TfidfVectorizer fails, because that one expects a 1D object of text samples.

So possible options:

  1. Add a way to specify in the ColumnTransformer that for certain transformers the passed X values should be kept 1D.
    E.g. we could have

    ColumnTransformer([('tfidf', TfidfVectorizer(), 'text_col'), ...],
                      flatten_column={'tfidf': True})
    

    where flatten_column (or another name like keep_1d) is False by default (satisfying normal transformers), but you can override this default per transformer.
    The use of a dict here is similar to the transformer_weights keyword.

  2. Adapt the example (which means: letting the user write more boilerplate), e.g. by adding a step in the pipeline to select the single column from the 2D object.
    We could add one line to the current pipeline that holds the TF-IDF (this tuple is one of the transformers in the ColumnTransformer); a self-contained sketch follows this list:

            # Pipeline for standard bag-of-words model for body
            ('body_bow', Pipeline([
    added -->   ('flatten', FunctionTransformer(lambda x: x[:, 0], validate=False)),
                ('tfidf', TfidfVectorizer()),
                ('best', TruncatedSVD(n_components=50)),
            ]), 'body'),
    
  3. Adapt TfidfVectorizer to e.g. have a keyword specifying that 2D data is expected (False by default for backwards compatibility).
    If we want to do this, it should ideally be a separate PR; the second option can then be used as a temporary fix to keep the example working, and be removed in that follow-up PR.
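
A self-contained version of the option 2 workaround (a sketch; the column name 'body' and the SVD size are taken from the PR's example):

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD

# 'flatten' turns the (n_samples, 1) column subset back into the 1D array
# of documents that TfidfVectorizer expects.
body_bow = Pipeline([
    ('flatten', FunctionTransformer(lambda x: x[:, 0], validate=False)),
    ('tfidf', TfidfVectorizer()),
    ('best', TruncatedSVD(n_components=50)),
])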

@jorisvandenbossche


Contributor

jorisvandenbossche commented Jun 7, 2017

Another option would be:

  4. Make a distinction between specifying the columns as a scalar or as a list: when using a scalar, the data is passed as 1D to the transformer; when using a list, as 2D data.
    The disadvantage of this is that for all transformers except the text vectorizers, you will have to specify a single column as a list:

    ColumnTransformer([('tfidf', TfidfVectorizer(), 'text_col'),
                       ('scaler', StandardScaler(), ['other_col'])])
    
    
@vene


Member

vene commented Jun 7, 2017

Option 4 seems too magic and not explicit enough...

@amueller


Member

amueller commented Jun 7, 2017

I don't think it's too magic; it mirrors numpy's indexing semantics:

In [2]: import numpy as np

In [7]: x = np.arange(9).reshape(3, 3)

In [8]: x[1]
Out[8]: array([3, 4, 5])

In [9]: x[[1]]
Out[9]: array([[3, 4, 5]])
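
This is the semantics of option 4, and the convention the merged API ended up with: a scalar column selects 1D data, a list of columns selects 2D data. A sketch against released scikit-learn (the toy data is illustrative):

import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import StandardScaler

X = pd.DataFrame({'text_col': ['a b', 'b c', 'c a'],
                  'other_col': [1.0, 2.0, 3.0]})

ct = ColumnTransformer([
    ('tfidf', TfidfVectorizer(), 'text_col'),     # scalar -> 1D Series
    ('scaler', StandardScaler(), ['other_col']),  # list -> 2D DataFrame
])
Xt = ct.fit_transform(X)  # both outputs are 2D and hstacked side by side
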
@jorisvandenbossche


Contributor

jorisvandenbossche commented May 25, 2018

OK, I updated this today, but ended up needing to make a few more changes than I anticipated. So I think it would be good if at least somebody takes a closer look at the changes. Diff view of only those changes: https://github.com/scikit-learn/scikit-learn/pull/9012/files/04bcb1ec7292a566e1b6b5f2fa0b7d38d60d9102..d298fc310069273f05d8806428fcfd0d2530b77d

What I changed:

  • Some small changes according to the inline feedback of Andy and Joel (for get_feature_names I added a NotImplementedError in case you call it with a 'passthrough' transformer).
  • Renamed the unspecified keyword to remainder, and switched the default from 'drop' to 'passthrough', as discussed above (a usage sketch follows this list).
  • However, this change of default uncovered some issues with the current implementation: I hstack the output of the different transformers, but if the transformers do not all return 2D output, the hstack gives a (not very informative) error or can even produce wrong results.
    I discussed this with Olivier, and we couldn't really think of a good use case for a transformer to return 1D output (some of the vectorizers need 1D input, but all of them output 2D data); it probably mainly occurred due to how I wrote the tests. We therefore decided to be strict for now and require 2D output from each transformer, and I added a check for this in fit/fit_transform to provide a nicer error message.
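
A usage sketch of the renamed keyword (illustrative column names; remainder is passed explicitly here rather than relying on any default):

import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler

X = pd.DataFrame({'a': [1.0, 2.0, 3.0], 'b': [4.0, 5.0, 6.0]})

# remainder='passthrough' appends the untransformed columns ('b' here);
# remainder='drop' would discard them instead.
ct = ColumnTransformer([('scale', StandardScaler(), ['a'])],
                       remainder='passthrough')
Xt = ct.fit_transform(X)  # column 0: scaled 'a'; column 1: 'b' untouched
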
@GaelVaroquaux


Member

GaelVaroquaux commented May 25, 2018

@jnothman

Those changes look good

@jnothman

Some things I noticed

@@ -101,6 +101,105 @@ memory the ``DictVectorizer`` class uses a ``scipy.sparse`` matrix by
default instead of a ``numpy.ndarray``.
.. _column_transformer:


@jnothman

jnothman May 27, 2018

Member

This should be in compose.rst, but perhaps noted at the top of this file


@jorisvandenbossche

jorisvandenbossche May 28, 2018

Contributor

Yes, I know, but also (related to what I mentioned here: #9012 (comment)):

  • when moving to compose.rst, I think we should use a different example (e.g. using transformers from the preprocessing module, as I think that is a more typical use case)
  • we should reference this in preprocessing.rst
  • we should add a better 'typical data science use case' example for the example gallery
  • I would maybe keep the explanation currently in feature_extraction.rst (the example), but shorten it by referring to compose.rst for the general explanation.

I can work on the above this week. But in light of getting this merged sooner rather than later, I would prefer doing it as a follow-up PR, if that is fine? (I can also do a minimal version here and simply move the current docs addition to compose.rst without any of the other mentioned improvements.)

feature_names : list of strings
Names of the features produced by transform.
"""
check_is_fitted(self, 'transformers_')


@jnothman

jnothman May 27, 2018

Member

Shouldn't remainder be handled here too?


@jorisvandenbossche

jorisvandenbossche May 28, 2018

Contributor

Ideally, yes. I am just not fully sure what to do here currently, given that get_feature_names is in general not really well supported.
Ideally I would add the names of the passed-through columns to feature_names: the actual string column names in the case of pandas DataFrames, and for numpy arrays the indices into the array as strings (['0', '1', ..])?

I can also raise an error for now if there are columns passed through, just to make sure that if we improve get_feature_names in the future, it does not lead to a change in behaviour (only the removal of the error).


@jnothman

jnothman May 28, 2018

Member

Can just raise NotImplementedError in case of remainder != 'drop' for now... Or you can tack the remainder transformer onto the end of _iter.

I agree get_feature_names is not quite the right design.


@amueller

amueller Jun 1, 2018

Member

Can we have an issue for this?


@jorisvandenbossche


Contributor

jorisvandenbossche commented May 29, 2018

Last update: added the additional error for get_feature_names, and moved the docs to compose.rst.

As far as I am concerned, somebody can push the green button ;-)

@jnothman


Member

jnothman commented May 29, 2018

Indeed: let's see how this flies!

@jnothman jnothman merged commit 0b6308c into scikit-learn:master May 29, 2018

8 checks passed

ci/circleci: deploy Your tests passed on CircleCI!
ci/circleci: python2 Your tests passed on CircleCI!
ci/circleci: python3 Your tests passed on CircleCI!
codecov/patch 99.2% of diff hit (target 95.1%)
codecov/project 95.2% (+0.09%) compared to 0c424ce
continuous-integration/appveyor/pr AppVeyor build succeeded
continuous-integration/travis-ci/pr The Travis CI build passed
lgtm analysis: Python No alert changes
@jnothman


Member

jnothman commented May 29, 2018

Thanks Joris for some great work on finally making this happen!

@jorisvandenbossche


Contributor

jorisvandenbossche commented May 29, 2018

Woohoo, thanks for merging!

@TomDLT


Member

TomDLT commented May 29, 2018

Great work, congrats!


@glemaitre


Contributor

glemaitre commented May 29, 2018

Really nice!!! Let's stack those columns then ;)

@eyadsibai


eyadsibai commented May 29, 2018

Looking forward to the next release

@jorisvandenbossche jorisvandenbossche deleted the jorisvandenbossche:amueller/heterogeneous_feature_union branch May 30, 2018

@armgilles


armgilles commented May 31, 2018

Next release will be amazing !

@amueller


Member

amueller commented May 31, 2018

OMG this is great! Thank you so much for your work (and patience) on this one @jorisvandenbossche

@partmor


Contributor

partmor commented Jun 1, 2018

Thank you @jorisvandenbossche!! Great stuff.
I have a question regarding this feature (I'm very new to GitHub, not sure if this is the right place... apologies in advance):

Is there going to be an effort (I would like to contribute) to implement get_feature_names in the majority of transformers?

I find that one of the big advantages that DataFrameMapper from sklearn-pandas brought us is the ability to trace the names of derived features (using aliases and df_out=True), given what this means for interpretability (e.g. getting feature importances for a tree-based model after a fairly complex non-sequential preprocessing pipeline). Having get_feature_names working consistently in ColumnTransformer would be the bomb.

What do you guys think?

Thank you in advance.

@amueller


Member

amueller commented Jun 1, 2018

@partmor See #9606 and #6425. What exactly was working with DataFrameMapper that's not currently working? I feel like the main use case will be with OneHotEncoder/CategoricalEncoder, which will provide a get_feature_names. For most multivariate transformations it's hard to get feature names, so I don't know how DataFrameMapper did that.

@partmor


Contributor

partmor commented Jun 1, 2018

@amueller thank you for the links. For instance, if we want to use StandardScaler on a set of numeric variables, ColumnTransformer raises an exception because StandardScaler does not have get_feature_names implemented. For "1 to 1" column transformations like standard scaling, DataFrameMapper could just pass the feature names through: ['x0', 'x1', ...] to ['x0', 'x1', ...]. ColumnTransformer.get_feature_names() just raises an exception when StandardScaler is used in it.
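
A sketch of the behaviour described (assuming a scikit-learn version that still has get_feature_names; it was later superseded by get_feature_names_out):

import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler

X = pd.DataFrame({'city': ['ams', 'bxl', 'ams'], 'num': [1.0, 2.0, 3.0]})

# Works: OneHotEncoder implements get_feature_names.
ct = ColumnTransformer([('onehot', OneHotEncoder(), ['city'])]).fit(X)
ct.get_feature_names()  # e.g. ['onehot__x0_ams', 'onehot__x0_bxl']

# Raises: StandardScaler provides no get_feature_names, even though the
# mapping here is one-to-one.
ct2 = ColumnTransformer([('scale', StandardScaler(), ['num'])]).fit(X)
ct2.get_feature_names()  # AttributeError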

@amueller


Member

amueller commented Jun 2, 2018

@partmor yeah, for univariate things like that it would be easy to implement. We should revisit #6425, keeping in mind that ColumnTransformer relies heavily on it.
