[MRG + 2] ENH enable setting pipeline components as parameters #1769

Merged (29 commits) on Aug 29, 2016

jnothman
Member

Until now, get_params() would return the steps of a pipeline by name, but setting them would fail silently (by setting an unused attribute); fixes bug #1800.

This allows users to grid search over alternative estimators for some step, as illustrated in the included example, or even to delete a step in a search by setting it to None.
But it may also be more directly practical: while a user may currently use get_params to extract an estimator from a nested pipeline using the double-underscore naming convention, e.g. for selective serialisation, they cannot use the reciprocal set_params which this PR enables.

This changeset also prohibits step names that equal initialisation parameters of the pipeline; otherwise FeatureUnion.set_params(transformer_weights=foo) would be ambiguous.
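To make the description above concrete, here is a minimal sketch of the behaviour, written against a toy class rather than scikit-learn itself. `ToyPipeline`, its `_init_params` tuple and the string "estimators" are all hypothetical stand-ins; the real `Pipeline` and `FeatureUnion` have more constructor parameters (for example `transformer_weights`) and handle nested `name__param` keys as well.

```python
class ToyPipeline:
    """Hypothetical stand-in for a pipeline; not scikit-learn's implementation."""

    # Constructor parameter names; step names must not reuse them,
    # otherwise set_params(steps=...) would be ambiguous.
    _init_params = ("steps",)

    def __init__(self, steps):
        self.steps = list(steps)
        self._validate_names()

    def _validate_names(self):
        names = [name for name, _ in self.steps]
        clashes = set(names) & set(self._init_params)
        if clashes:
            raise ValueError(
                "Step names conflict with constructor parameters: %r"
                % sorted(clashes)
            )
        if len(set(names)) != len(names):
            raise ValueError("Step names must be unique: %r" % names)

    def get_params(self):
        # Steps are exposed by name alongside the constructor parameters.
        params = {"steps": self.steps}
        params.update(dict(self.steps))
        return params

    def set_params(self, **params):
        # Replace the whole step list first, then individual steps by name.
        if "steps" in params:
            self.steps = list(params.pop("steps"))
            self._validate_names()
        for key, value in params.items():
            names = [name for name, _ in self.steps]
            if key not in names:
                raise ValueError("Invalid parameter %r" % key)
            self.steps[names.index(key)] = (key, value)
        return self


pipe = ToyPipeline([("sel", "select-k-best"), ("clf", "logistic")])
pipe.set_params(clf="linear-svc")  # actually replaces the 'clf' step
```

In this sketch a step named `steps` is rejected at construction time, which is the kind of ambiguity the name restriction is meant to rule out.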

@@ -140,47 +196,46 @@ def fit_transform(self, X, y=None, **fit_params):
        else:
            return self.steps[-1][-1].fit(Xt, y, **fit_params).transform(Xt)

    def run_pipeline(self, X, est_method, est_args=(), est_kwargs={}):
Member

This causes a run_pipeline to appear on FeatureUnion, so run would be a better name. I don't really see the use of this at present, though. Also, the docstring is incomplete (the types should be documented).

Member Author

Its only purpose was to refactor. I made it public as a "why not?" decision, which is probably why I forgot the docstring. Or perhaps because parameter descriptions are lacking throughout the class, and are incomplete in FeatureUnion.

Member Author

Btw, this only applies to Pipeline, not _BasePipeline, so run_pipeline may indeed be appropriate (unless you think run is sufficiently descriptive). An equivalent does not easily apply to FeatureUnion.

@larsmans
Member

I don't like the way this overloads the meaning of set_params. I'm thinking instead of either a list-like or a dict-like interface, with a __setitem__ to set a step to a different estimator. WDYT?
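For illustration only, the alternative suggested here might look roughly like the following. This is a sketch of the idea, not code from the PR: `IndexablePipeline` and its `_index` helper are hypothetical names, and the "estimators" are plain strings.

```python
class IndexablePipeline:
    """Hypothetical pipeline whose steps can be read and replaced by name or position."""

    def __init__(self, steps):
        self.steps = list(steps)

    def _index(self, key):
        # Integer keys address steps by position (list-like access).
        if isinstance(key, int):
            return key
        # Any other key is treated as a step name (dict-like access).
        for i, (name, _) in enumerate(self.steps):
            if name == key:
                return i
        raise KeyError(key)

    def __getitem__(self, key):
        return self.steps[self._index(key)][1]

    def __setitem__(self, key, estimator):
        i = self._index(key)
        name = self.steps[i][0]
        # Keep the step name, swap the estimator.
        self.steps[i] = (name, estimator)


pipe = IndexablePipeline([("sel", "select-k-best"), ("clf", "logistic")])
pipe["clf"] = "linear-svc"  # dict-like replacement by name
pipe[0] = "variance-threshold"  # list-like replacement by position
```

Note that an interface like this would not by itself let a parameter search swap a step, since such a search goes through set_params.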

@jnothman
Member Author

@larsmans, the problem is that in order for the underscore notation to work, let's say param "X__Y", apparently "X" needs to be a valid param (though perhaps this is a bug in BaseEstimator.set_params, but I have no idea what motivation there is for the current code and so wouldn't dare remove it). Therefore "X" needs to be returned by get_params(deep=True), as it had been previously. But calling pipeline.set_params(X=blah) would just perform pipeline.X = blah to no meaningful effect.
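The silent failure described in this comment can be reproduced with a small toy model. The classes below are hypothetical and only imitate the relevant behaviour (a generic set_params that validates names against get_params and then calls setattr); they are not the actual BaseEstimator or Pipeline code, and nested `X__Y` handling is left out.

```python
class NaiveParamsMixin:
    """Toy set_params: validate the name, then set an attribute of that name."""

    def set_params(self, **params):
        valid = self.get_params()
        for key, value in params.items():
            if key not in valid:
                raise ValueError("Invalid parameter %r" % key)
            setattr(self, key, value)
        return self


class OldStylePipeline(NaiveParamsMixin):
    """Toy pipeline that exposes steps by name but does not override set_params."""

    def __init__(self, steps):
        self.steps = steps

    def get_params(self):
        out = {"steps": self.steps}
        out.update(dict(self.steps))  # step names appear as valid parameters
        return out


pipe = OldStylePipeline([("sel", "selector"), ("est", "old-estimator")])
pipe.set_params(est="new-estimator")

# The call succeeds, but it only created an unused attribute:
unused_attribute = pipe.est
step_actually_used = dict(pipe.steps)["est"]
```

Here `unused_attribute` holds the new value while `step_actually_used` is still the old estimator, which is the "no meaningful effect" case.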

Now, one option is to make pipeline.X = blah meaningful, one of:

  • add BasePipeline.__{g,s}etattr__ which identifies step names and handles them differently
  • actually just store the steps by name in pipeline.__dict__, and compile pipeline.steps (and the unused pipeline.named_steps) on the fly: steps = [getattr(self, name) for name in self._step_names].

I personally think actually having the steps as attributes on the object is ugly. Already having to have the step names and the initialiser arguments in the same namespace is ugly. Having to also make sure step names don't conflict with actual attributes of the classifier (including things like 'fit', 'transform' which are certainly possible step names in existing code) seems unreasonable.

As much as modifying some of the standard semantics of set_params is undesirable, breaking existing code is less desirable. But I'm open to suggestions.

So I pass the ball back to you...

@amueller
Member

I like the feature a lot, but I haven't had time to look into the implementation, and won't have until next week :-/

@amueller
Member

It would be nice if we could also set an estimator in the pipeline to None, meaning the step should be skipped.

@jnothman
Member Author

Please help with a test failure: it's failing because _BasePipeline inherits from BaseEstimator, and hence is presumed to be testable among all estimators in sklearn/tests/test_common.py. Is the correct fix to make _BasePipeline some form of mixin (and hence not use super to call BaseEstimator.set_params), or to add _BasePipeline to sklearn.test_common.dont_test?

@amueller
Member

The correct method is to make _BasePipeline an abstract base class by making the constructor an abstract method. Look at other base classes for examples.
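A minimal sketch of that pattern, using Python 3 syntax and hypothetical class names (the codebase at the time supported Python 2 and would have declared the metaclass differently). Presumably the point is that a class with an abstract constructor cannot be instantiated, so tests that build every estimator can leave it out.

```python
from abc import ABCMeta, abstractmethod


class BasePipelineSketch(metaclass=ABCMeta):
    """Hypothetical abstract base: shared logic lives here, but it cannot be instantiated."""

    @abstractmethod
    def __init__(self, steps):
        self.steps = steps

    def step_names(self):
        return [name for name, _ in self.steps]


class ConcretePipelineSketch(BasePipelineSketch):
    """Concrete subclass: overriding __init__ makes the class instantiable."""

    def __init__(self, steps):
        # Calling the abstract implementation through super() is allowed.
        super().__init__(steps)


concrete = ConcretePipelineSketch([("clf", "logistic")])

try:
    BasePipelineSketch([])
    base_is_instantiable = True
except TypeError:
    # ABCMeta refuses to instantiate classes with abstract methods.
    base_is_instantiable = False
```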

@jnothman
Member Author

@amueller I like the idea of setting a step to None, but I think there's too much risk of some function -- be it a method on a _BasePipeline descendant, or some external library -- forgetting to handle that case. Why not just let the user supply a dummy? If sklearn lacks a library of dummies, this is a use-case worth considering.

@amueller
Member

Yeah I also thought about that. I am not sure I buy your argument, though ;)
It would mean some extra code in the Pipeline, but it would be much more user friendly.

@jnothman
Member Author

Presumably you would also need to raise an error if the final step in the pipeline is None... How about, when we're happy with this patch, we consider the ramifications of that enhancement?

@jnothman
Member Author

Related to @larsmans' comment, set_params order of setting is now significant:

grid_clf = GridSearchCV(
    Pipeline([('sel', SelectKBest(chi2)), ('est', LogisticRegression())]),
    param_grid={'est': [LogisticRegression(), LinearSVC()], 'est__C': [0.1, 1.0]}
)

If 'est__C' is set before 'est' -- dependent on the implementation of dict.iterkeys were it left to BaseEstimator.set_params -- it won't have any effect (and worse, GridSearchCV will act as if it did the right thing).

It is also possible with this patch to set 'steps' as well as one of the steps, and the current implementation doesn't address this in set_params.

We could ignore this last somewhat-pathological case and generally ensure that BaseEstimator.set_params iterates in order of string length, or do those without underscores before those with. Or we can implement orderings on a per-estimator basis (which overriding set_params here does, but there needs to be a comment to that effect; an alternative would be for estimators to explicitly define a parameter ordering when appropriate).

And it's worth considering whether there's a use-case where the user would require an explicit parameter ordering, and whether we care to facilitate that...
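One of the options mentioned above, applying parameters without double underscores before those with them, can be sketched as follows. This is an illustration under assumptions, not the PR's implementation: `ToyEstimator`, `ordered_param_items` and `apply_params` are hypothetical names, steps are held in a plain dict, and the pathological case of setting 'steps' together with a single step is not handled.

```python
class ToyEstimator:
    """Hypothetical estimator with a single hyperparameter C."""

    def __init__(self, label, C=1.0):
        self.label = label
        self.C = C


def ordered_param_items(params):
    """Return (name, value) pairs sorted so that shallower names come first."""
    return sorted(params.items(), key=lambda kv: (kv[0].count("__"), kv[0]))


def apply_params(steps, params):
    """Replace whole steps first, then set nested parameters on the new steps."""
    for key, value in ordered_param_items(params):
        name, sep, sub_name = key.partition("__")
        if sep:
            # Nested parameter such as 'est__C': set it on the current step.
            setattr(steps[name], sub_name, value)
        else:
            # Plain step name such as 'est': swap the estimator itself.
            steps[name] = value
    return steps


steps = {"est": ToyEstimator("logreg")}
old_est = steps["est"]
new_est = ToyEstimator("svc")

# The dict lists 'est__C' first, yet 'est' is still applied before it.
apply_params(steps, {"est__C": 0.1, "est": new_est})
```

With this ordering the new estimator receives `C=0.1`; applied in the opposite order, the value would land on the estimator that is about to be discarded.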

@jnothman
Member Author

jnothman commented Apr 7, 2013

These latest changes attempt to:

@jnothman
Member Author

jnothman commented Apr 8, 2013

And this should possibly be augmented by an example of using advanced grid_search and pipeline features such as:

pipeline = Pipeline([('sel', SelectKBest()), ('clf', LinearSVC())])
param_grid = [{'sel__k': [5, 10, 20], 'sel__score_func': [chi2, f_classif], 'clf__C': [.1, 1, 10]}, {'sel': [None], 'clf__C': [.1, 1, 10]}]
search = GridSearchCV(pipeline, param_grid=param_grid)

not that I think this is a particularly strong motivating example.

@amueller
Member

amueller commented Apr 8, 2013

This is great :)
Unfortunately there is a deadline for ICCV next week, so I won't have much time this week either :-/ Sorry!


def score(self, X, y=None):
@property
Member

Why use a property here? What's the benefit compared to a method?

Member Author

As elsewhere, it ensures that hasattr(pipeline, 'inverse_transform') iff hasattr(step, 'inverse_transform') for all step in pipeline. Makes ducktyping meaningful.
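The effect of using a property can be seen in a small sketch. The classes below are hypothetical toys, not the PR's code: the property getter touches `inverse_transform` on every step, so an AttributeError from any step propagates and `hasattr` on the pipeline returns False.

```python
class Doubler:
    """Toy transformer that can be inverted."""

    def transform(self, X):
        return [2 * x for x in X]

    def inverse_transform(self, X):
        return [x / 2 for x in X]


class AbsoluteValue:
    """Toy transformer with no inverse_transform."""

    def transform(self, X):
        return [abs(x) for x in X]


class DuckTypedPipeline:
    """Hypothetical pipeline that exposes inverse_transform only if every step has it."""

    def __init__(self, steps):
        self.steps = list(steps)

    @property
    def inverse_transform(self):
        for _, step in self.steps:
            # Raises AttributeError if the step lacks the method,
            # which makes hasattr(pipeline, 'inverse_transform') False.
            getattr(step, "inverse_transform")
        return self._inverse_transform

    def _inverse_transform(self, X):
        # Undo the steps in reverse order.
        for _, step in reversed(self.steps):
            X = step.inverse_transform(X)
        return X


invertible = DuckTypedPipeline([("a", Doubler()), ("b", Doubler())])
not_invertible = DuckTypedPipeline([("a", Doubler()), ("b", AbsoluteValue())])
```

With a plain method instead of a property, `hasattr` would be True for both pipelines and the failure would only surface when the method is called.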

@amueller
Member

amueller commented May 7, 2013

Currently this PR does 3 things, if I'm correct:

  • add meaningful docstrings
  • fix the duck typing
  • make it possible to replace estimators with set_params.

Could you maybe cherry-pick the first into master? The diff is very hard to read :-/
Also, it might be worth splitting up the other two parts into two independent PRs if that is not too much work.
Did you have any comments on @GaelVaroquaux's approach?
For the third part, I am really for the feature, but I think we need to discuss the API for it.

@jnothman
Member Author

jnothman commented May 7, 2013

Yes, the PR is bloated. If I find the time, I'll try to separate out the stuff that's not in the title of this PR.

Anyone that doesn't think steps should be set as parameters needs to contend with some other fix for #1800. So as far as I'm concerned, it's a must, or a might-as-well. I am not sure what API issues you mean. Do you mean:

  • should it actually set an attribute on the estimator for consistency with BaseEstimator.set_params? The current implementation can be changed to do that easily, but trying to do the reverse creates backwards-compatibility issues.
  • should there be some other way of specifying parameter-setting priority, rather than overwriting set_params? I consider this an implementation detail.
  • should we support setting to None? I can't see why not.

Gaël's approach is about how to fix the duck-typing and keep the code readable (which is one reason I can't just cherry-pick it into master). He tests hasattr explicitly and raises an AttributeError in the else clause, which seems redundant and verbose. But if this promises substantially more readable, explicit code, I have no problem with it (except that it consumes three more repeated lines per method). He also delegates operation to an underscore-prefixed version of each method, which I think should be avoided: all these methods do the same thing, modulo the method called on the final estimator, so the code is more explicit, and less bug-prone, when refactored into something like _run_pipeline. (But again, this is not precisely the topic of this PR.)

@amueller
Member

amueller commented May 8, 2013

Thanks for your comments. I think we should fix the duck-typing first, and then discuss your improvement.
I must admit I did not look too closely into the implementation of set_params; I was just a bit uneasy with it ;) There is probably no better way to resolve #1800. I think I just have to read the code in detail to convince myself that it's a good idea ;)

@jnothman
Member Author

jnothman commented May 8, 2013

Okay, I've squashed the changes into two commits: one for #1805 and the comments, and one for the parameter setting. When you (or others) give that first commit the okay, I'll pull it into master. Then we can discuss the second commit separately.

@amueller
Member

amueller commented May 8, 2013

Thanks for the quick update. I think it looks good but I'll have a closer look later. @GaelVaroquaux, any opinion? Still the same?

for name, trans in self.transformer_list)
delayed(_transform_one)(trans, name, weight, X)
for name, trans, weight in self._iter())
if not Xs:
Member

Is this tested now? Now this happens when all transformers are None, right? I think I understand now what happens ;) (still maybe add a one-line comment "all transformers are None"? Is having an empty list of transformers allowed?)

Member Author

Commented. I suppose an empty list of transformers is allowed but untested.
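As a rough illustration of the case discussed here, the sketch below skips None transformers and handles an empty result list. It uses plain Python lists and callables instead of arrays and estimators, `union_transform` is a hypothetical name, and returning one empty row per sample is an assumption about what a sensible empty result looks like rather than something stated in this thread.

```python
def union_transform(transformer_list, X):
    """Apply each non-None transformer to X and concatenate the outputs column-wise.

    X is a list of rows; each transformer is a callable returning a list of rows.
    """
    Xs = [trans(X) for _, trans in transformer_list if trans is not None]
    if not Xs:
        # All transformers are None (or the list is empty):
        # assume the result keeps the number of samples and has no features.
        return [[] for _ in X]
    # Join the i-th row of every block into a single output row.
    return [sum((block[i] for block in Xs), []) for i in range(len(X))]


def square(X):
    return [[row[0] ** 2] for row in X]


def identity(X):
    return [list(row) for row in X]


X = [[1], [2]]
combined = union_transform([("sq", square), ("id", identity), ("skip", None)], X)
all_none = union_transform([("sq", None), ("id", None)], X)
```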

@amueller
Member

I think tests that are missing are:

  • init pipeline, set steps to something weird, call fit_transform (I think this is currently not guarded against, because there is no call to validate_steps in fit_transform).
  • init pipeline, set some steps to None, fit and predict/transform. (validate_steps is never called with a step being None).
  • maybe set a step to None at creation time?

@jnothman
Member Author

I've fixed those things. Also, despite my comment to @MechCoder, I've now added support for None as the last estimator when [inverse_]transform is called...

@amueller
Member

LGTM, though Python 2.6 complains: ImportError: cannot import name assert_dict_equal

@jnothman
Member Author

I've submitted a patch that just uses assert_equal when assert_dict_equal is not importable; it's only cosmetic anyway.

@jnothman
Member Author

I'm astounded that this might actually, finally, be merged. If @MechCoder or @amueller would like to give its recent changes a final pass, that'd be great.

@amueller
Member

LGTM.

Sorry, apparently I've been putting out fires in other places for the last 3 years (I've wanted this for a while).

@jnothman
Member Author

hahaha ;)


@jnothman
Member Author

Well, I suppose I'd better merge before what's new merge conflicts appear!

@jnothman jnothman merged commit 5b20d48 into scikit-learn:master Aug 29, 2016
@MechCoder
Member

Sorry for the delay. Does it still need a review? :p

@MechCoder
Member

Thanks for your effort and long wait!

@amueller
Member

Wohoo!

@betatim
Member

betatim commented Aug 29, 2016

This will be great!

@jnothman
Member Author

Is it worth highlighting this feature alongside model_selection changes in what's new? It happens to fit very nicely with being able to access values of each parameter searched by grid search.

TomDLT pushed a commit to TomDLT/scikit-learn that referenced this pull request Oct 3, 2016
…arn#1769)

Pipeline and FeatureUnion steps may now be set with set_params, and transformers may be replaced with None to effectively remove them.

Also test and improve ducktyping of Pipeline methods
paulha pushed a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017
…arn#1769)

Pipeline and FeatureUnion steps may now be set with set_params, and transformers may be replaced with None to effectively remove them.

Also test and improve ducktyping of Pipeline methods