
[MRG+1] Pipeline can now be sliced or indexed #2568

Merged
merged 19 commits into from
Mar 7, 2019


jnothman
Member

@jnothman jnothman commented Nov 2, 2013

This PR offers an alternative to #2561 and #2562, making it easy to apply inverse transforms (or transforms) over only a sub-sequence of steps in a pipeline. Thus:

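from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC

X, y = load_iris(return_X_y=True)  # any labelled dataset works here
pca = PCA()
clf = LinearSVC()
pipeline = Pipeline([('pca', pca), ('clf', clf)])
pipeline.fit(X, y)
pipeline[:-1].inverse_transform(pipeline[-1].coef_)  # pass the coefficients through PCA's inverse_transform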

@schwarty, is this sufficient for your needs?

Closes #8431, an alternative API
Closes #8448, an alternative API

@GaelVaroquaux
Member

I have a strong dislike for overriding "__" methods in order to give a domain-specific behavior to objects. I find that it leads to code that is not explicit and hard to read unless you know the package very well. Of course, a method with an explicit name ("get_estimators") would not raise such criticism from me. However, when discussing @schwarty's use case, the reason that we had envisaged a "get_estimated" method was that, for better or worse, it could be useful in many composite estimators other than the pipeline, for instance the GridSearch, and it would somewhat abstract the details of the compositing done by the estimator.

@jnothman
Member Author

jnothman commented Nov 2, 2013

And sorry, I forgot to commit changes to sklearn.utils.testing. Now tests should pass.

@coveralls

Coverage remained the same when pulling d02a64a on jnothman:pipeline_slice into f2ceb4f on scikit-learn:master.

@jnothman
Member Author

jnothman commented Nov 2, 2013

I have a strong like for APIs that are convenient and intuitive with minimal surprises, and in that, magic methods are little different. I don't find it surprising that a Pipeline should have syntactic behaviours akin to a list or an ordered dict (had Pipeline explicitly declared itself a Python Sequence I would not be surprised, and hence do not consider this "domain behaviour").

I find the code resulting from this proposal much more intuitive, unsurprising and explicit than get_estimated, while being much more easily read than Pipeline(pipeline.steps[:-1]).inverse_transform(pipeline.steps[-1][1].coef_).
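A minimal check that the two spellings compared here compute the same thing, assuming scikit-learn 0.21 or later (the release whose what's-new file this PR updates) and using the iris data purely for illustration:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC

X, y = load_iris(return_X_y=True)
pipeline = Pipeline([('pca', PCA()), ('clf', LinearSVC())])
pipeline.fit(X, y)

# Indexing/slicing syntax proposed in this PR
via_slicing = pipeline[:-1].inverse_transform(pipeline[-1].coef_)

# Explicit construction of a sub-pipeline from the steps list
explicit = Pipeline(pipeline.steps[:-1]).inverse_transform(
    pipeline.steps[-1][1].coef_)

print(np.allclose(via_slicing, explicit))  # True: both wrap the same fitted PCA object
```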

@larsmans
Member

I actually like this idea because it follows a Python convention, but it still needs to be documented somehow. In particular, it's not immediately obvious that slicing a pipeline makes a shallow copy: setting the steps in a slice changes the slice, but fitting one of the estimators changes the original.
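A sketch of the shallow-copy behaviour described here, assuming scikit-learn 0.21 or later; the PCA/LogisticRegression pipeline is chosen only for illustration:

```python
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

X, y = load_iris(return_X_y=True)
pipe = Pipeline([('pca', PCA(n_components=2)), ('clf', LogisticRegression())])

sub = pipe[:1]  # a new Pipeline object holding the *same* PCA instance

# Fitting through the slice fits the shared estimator, so the original sees it
sub.fit(X)
print(hasattr(pipe.named_steps['pca'], 'components_'))  # True

# Replacing an entry in the slice's own steps list leaves the original untouched
sub.steps[0] = ('pca', PCA(n_components=3))
print(pipe.named_steps['pca'].n_components)  # 2
```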

@jnothman
Member Author

Certainly a clone is not sufficient, but would it be more friendly (if
more expensive) to use a deep copy for people to access fitted models?

Also, the semantics are no different from other Python containers, but that
doesn't make them ideal to this purpose.


@larsmans
Member

No, but the semantics are different from those of NumPy arrays. For simplicity's sake, I think a shallow copy is fine, it's just that we have to spell it out somewhere.
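To make the contrast concrete in plain NumPy and Python (no scikit-learn involved): basic NumPy slicing returns a view, whereas a list slice is a new list whose elements are shared. A Pipeline slice, as described in this thread, follows the list behaviour.

```python
import numpy as np

# NumPy basic slicing returns a view: writing through it changes the original
arr = np.arange(5)
view = arr[:3]
view[0] = 99
print(arr[0])  # 99

# A list slice is a shallow copy: rebinding a slot leaves the original list alone,
# but the element objects themselves are shared
items = [[1], [2], [3]]
part = items[:2]
part[0] = ['replaced']
part[1].append(20)
print(items)  # [[1], [2, 20], [3]]
```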

@jnothman
Member Author

Well, I don't think we'll support __setitem__ (should we?) which is where they differ more...
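A sketch of what that means in practice, assuming Pipeline defines `__getitem__` but no `__setitem__` (true of the released versions I am aware of): item assignment raises `TypeError`, and replacing a step goes through `set_params` with the step name as the keyword.

```python
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler, StandardScaler

pipe = Pipeline([('scaler', StandardScaler()), ('clf', LogisticRegression())])

# No __setitem__, so Python itself rejects item assignment
raised = False
try:
    pipe[0] = MinMaxScaler()
except TypeError:
    raised = True
print(raised)  # True

# Replace a step by name; on a slice this only rebinds the slice's own steps list
sub = pipe[:1]
sub.set_params(scaler=MinMaxScaler())
print(type(sub[0]).__name__, type(pipe[0]).__name__)  # MinMaxScaler StandardScaler
```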

@jaquesgrobler
Member

I agree with @larsmans on this one - if it's documented well I'm quite +1 on this PR as it's quite intuitive. Though I see @GaelVaroquaux 's point too. Where are we in terms of moving forward on either this PR or the alternatives?

@jnothman
Member Author

Added to docstrings/tests, narrative doc, example

@GaelVaroquaux
Member

I am busy preparing a course for statistics in Python for beginners. The
variety of notation, conventions, data structures, syntaxes and
shorthands across modules (pandas, statsmodels, matplotlib, numpy,
scipy.stats) makes the course really challenging.

Think about this PR in this respect: how will students or beginners
discover a code base written using scikit-learn?

To quote the zen of Python:

Explicit is better than implicit.
...
There should be one-- and preferably only one --obvious way to do it.

I am aware that "Beautiful is better than ugly" could be applied here. I
am just trying to motivate why I don't want this feature in: it will make
it harder for people to understand what is going on in code using
scikit-learn.

In addition, I don't believe that this PR solves a very general
problem, as it is specific to the pipeline. We cannot go down the path of
hacking semantics such as indexing semantics for each estimator. For
instance, it could make sense to have ensembles indexable also. The
indexing would have a completely different meaning.

@jnothman
Member Author

As with any religious scripture, we can agree on the text of the Zen, but only marginally on its interpretation. We can probably also agree that numpy, matplotlib and probably pandas aren't exemplary disciples of Zen, where other priorities (chiefly compatibility) came into play. So I wish you luck in teaching people the "one obvious way to do it" in that context, knowing that they will read plenty of code that differs.

Anyway, if there should be one obvious way to do it, let's consider the alternatives to pipeline[-1]:

pipeline.steps[-1][1]

is both inexplicit and ugly, but we see it often enough. We could convert step tuples to namedtuples (does the API design allow us to do this in __init__?) to provide:

pipeline.steps[-1].estimator

which is explicit and not ugly, but verbose. Or we could provide a get_estimator method.
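A sketch of both alternatives mentioned here. Everything in it is hypothetical: `Step`, its `estimator` field and the `get_estimator` helper are illustrative names, not scikit-learn API. A namedtuple still unpacks like the existing `(name, estimator)` pairs, so it would stay compatible with code that iterates over `steps`.

```python
from collections import namedtuple

from sklearn.decomposition import PCA
from sklearn.svm import LinearSVC

# Hypothetical: steps stored as namedtuples rather than plain 2-tuples
Step = namedtuple('Step', ['name', 'estimator'])
steps = [Step('pca', PCA()), Step('clf', LinearSVC())]

# Tuple unpacking keeps working...
name, est = steps[-1]
# ...and attribute access becomes available
print(steps[-1].estimator is est)  # True


def get_estimator(steps, key):
    """Hypothetical helper: look up a step by position (int) or by name (str)."""
    if isinstance(key, str):
        for step in steps:
            if step.name == key:
                return step.estimator
        raise KeyError(key)
    return steps[key].estimator


print(get_estimator(steps, 'clf') is get_estimator(steps, -1))  # True
```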

For getting a sub-pipeline, Pipeline(pipeline.steps[:-1]) isn't terrible. But for those of us that want to quickly inspect model coefficients in the original feature space, particularly in an interactive session, it's also excessively verbose (especially if from sklearn.pipeline import Pipeline is counted in that verbosity).

@jnothman
Member Author

PS: I can't find that error in the Travis log. Any clues?

@jaquesgrobler
Member

@jnothman regarding travis, I assume it's the doctest one:

Check that the pseudo likelihood is computed without clipping. ... ok
test_rbm.test_rbm_verbose ... ok
Make sure RBM works with sparse input when verbose=True ... ok
Doctest: sklearn.pipeline.Pipeline ... FAIL
Doctest: sklearn.preprocessing.data.OneHotEncoder ... ok
Doctest: sklearn.preprocessing.data.PolynomialFeatures ... ok

Though this doesn't give much on what or how

-- trying to see if I can narrow it down

@jaquesgrobler
Member

@jnothman

Here you go, I think:

======================================================================
FAIL: Doctest: sklearn.pipeline.Pipeline
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/lib/python2.7/doctest.py", line 2201, in runTest
    raise self.failureException(self.format_failure(new.getvalue()))
AssertionError: Failed doctest test for sklearn.pipeline.Pipeline
  File "/home/travis/build/scikit-learn/scikit-learn/sklearn/pipeline.py", line 26, in Pipeline

----------------------------------------------------------------------
File "/home/travis/build/scikit-learn/scikit-learn/sklearn/pipeline.py", line 78, in sklearn.pipeline.Pipeline
Failed example:
    coef = anova_svm[-1].coef_
Expected:
    (1, 5)
Got nothing
----------------------------------------------------------------------
File "/home/travis/build/scikit-learn/scikit-learn/sklearn/pipeline.py", line 80, in sklearn.pipeline.Pipeline
Failed example:
    anova_svm['clf'] is anova_svm[-1]
Exception raised:
    Traceback (most recent call last):
      File "/usr/lib/python2.7/doctest.py", line 1289, in __run
        compileflags, 1) in test.globs
      File "<doctest sklearn.pipeline.Pipeline[15]>", line 1, in <module>
        anova_svm['clf'] is anova_svm[-1]
      File "/home/travis/build/scikit-learn/scikit-learn/sklearn/pipeline.py", line 129, in __getitem__
        return self.named_steps[ind]
    KeyError: 'clf'
----------------------------------------------------------------------
File "/home/travis/build/scikit-learn/scikit-learn/sklearn/pipeline.py", line 82, in sklearn.pipeline.Pipeline
Failed example:
    coef.shape
Expected nothing
Got:
    (1, 10)

>>  raise self.failureException(self.format_failure(<StringIO.StringIO instance at 0x616f440>.getvalue()))
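The log shows two separate problems: the expected shape sits under the assignment line (which prints nothing) rather than under `coef.shape`, and `anova_svm['clf']` raises `KeyError`, so no step in that docstring's pipeline was named `'clf'`. A sketch of the same checks with each expected value next to the expression that produces it; the `SelectKBest` + linear `SVC` setup and the step names `'anova'`/`'svc'` are assumptions for illustration, not the docstring's actual text.

```python
from sklearn.datasets import make_classification
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC

X, y = make_classification(n_features=20, n_informative=5, n_redundant=0,
                           random_state=42)

# Indexing by name only works for names actually registered in the steps
anova_svm = Pipeline([('anova', SelectKBest(f_classif, k=5)),
                      ('svc', SVC(kernel='linear'))])
anova_svm.fit(X, y)

coef = anova_svm[-1].coef_                # assignment: nothing to check here
print(anova_svm['svc'] is anova_svm[-1])  # True
print(coef.shape)                         # (1, 5): binary task, k=5 selected features

# The selector's inverse_transform maps the coefficients back to all 20 inputs
print(anova_svm[:1].inverse_transform(coef).shape)  # (1, 20)
```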

@jaquesgrobler
Member

That's the only failure

@jnothman
Member Author

Oh. I must have modified it after testing it locally. Thanks.


@coveralls

Coverage remained the same when pulling d305273 on jnothman:pipeline_slice into f2ceb4f on scikit-learn:master.

@ogrisel
Member

ogrisel commented Nov 29, 2013

I also find indexing a pipeline intuitive, as it is fundamentally an ordered sequence. To me this is the one obvious way to do it, that is, to construct a sub-pipeline without leading or trailing estimators and visualize the partial (inverse) transformations they produce on test data for model inspection purposes.
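A sketch of that inspection workflow, assuming scikit-learn 0.21 or later; the scaler/PCA/logistic-regression pipeline and the iris split are illustrative choices only:

```python
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

pipe = Pipeline([('scale', StandardScaler()),
                 ('pca', PCA(n_components=2)),
                 ('clf', LogisticRegression())])
pipe.fit(X_train, y_train)

# Output of the first step only (standardised test features)
scaled = pipe[:1].transform(X_test)
# Output of everything except the classifier: a 2-D projection, e.g. for a scatter plot
projected = pipe[:-1].transform(X_test)
# Partial inverse: map the projection back to the original feature space
reconstructed = pipe[:-1].inverse_transform(projected)
print(scaled.shape, projected.shape, reconstructed.shape)
```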

@GaelVaroquaux
Member

> As with any religious scripture, we can agree on the text of the Zen, but only marginally on its interpretation.

:). I like that analogy.

> We can probably also agree that numpy, matplotlib and probably pandas
> aren't exemplary disciples of Zen, where other priorities (chiefly
> compatibility) came into play.

Yes, and it's a big problem for beginners.

> So I wish you luck in teaching people the "one obvious way to do it" in
> that context, knowing that they will read plenty of code that differs.

We shouldn't make things worse. They are already pretty bad.

> For getting a sub-pipeline, Pipeline(pipeline.steps[:-1]) isn't
> terrible. But for those of us that want to quickly inspect model
> coefficients in the original feature space,

I'd like to stress that this proposed solution doesn't help at all with the
problem that we are facing at the lab, which is that in the case of
composed estimators (pipeline, grid-search, multi-task), we have to write
custom code to retrieve model parameters. So we are proposing an
extension of API that is very custom to an estimator. This kind of
approach raises red flags for me as a software architect.
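For the grid-search case raised here, a sketch of how pipeline indexing composes with an outer meta-estimator, assuming `GridSearchCV` with its default `refit=True` so that `best_estimator_` is a fitted Pipeline. Note that it still relies on `best_estimator_`, an attribute specific to grid search, which is the kind of per-estimator retrieval code being criticised.

```python
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

X, y = load_iris(return_X_y=True)
pipe = Pipeline([('pca', PCA()), ('clf', LogisticRegression(max_iter=1000))])
search = GridSearchCV(pipe, {'pca__n_components': [2, 3]}, cv=3)
search.fit(X, y)

best = search.best_estimator_   # grid-search-specific attribute
coef = best[-1].coef_           # classifier coefficients in PCA space
# PCA.inverse_transform also adds back mean_, so this mapping is affine, not purely linear
coef_in_input_space = best[:-1].inverse_transform(coef)
print(coef.shape[1] == best[0].n_components_, coef_in_input_space.shape)
```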

@agramfort
Member

agramfort commented Mar 1, 2019 via email

@GaelVaroquaux
Member

I'm still not sold on overriding the dunder methods (after all these years :D ).

I heard the arguments against a method called "get_slice" (which are that "slice" is a word that non-Python users might not identify with what we are doing here). I would suggest "get_segment" or "get_portion" (I prefer "get_segment").

@jnothman
Member Author

jnothman commented Mar 1, 2019 via email

@agramfort
Member

agramfort commented Mar 1, 2019 via email

@amueller
Member

amueller commented Mar 1, 2019

I feel like any possible word will be less intuitive than using slicing syntax and will make it more complicated.

@@ -188,6 +199,26 @@ def _iter(self, with_final=True):
            if trans is not None and trans != 'passthrough':
                yield idx, name, trans

    def __getitem__(self, ind):
        """Returns a sub-pipeline or a single esimtator in the pipeline
Contributor


typo in "estimator"

returns another Pipeline instance which copies a slice of this
Pipeline. This copy is shallow: modifying (or fitting) estimators in
the sub-pipeline will affect the larger pipeline and vice-versa.
However, replacing a value in `step` will not affect a copy.
Contributor


Perhaps clarify: replacing a value in `step` in either the original pipeline instance or the sub-pipeline instance.
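A self-contained toy sketching the dispatch this docstring describes: an integer returns the estimator at that position, a string returns the step with that name, and a slice returns a new container built from a sliced copy of `steps`, so estimator objects are shared but the list is not. `ToyPipeline` is a hypothetical class, not scikit-learn's implementation, and rejecting slice steps other than 1 is an assumption of this sketch.

```python
class ToyPipeline:
    """Hypothetical minimal container mimicking Pipeline indexing semantics."""

    def __init__(self, steps):
        self.steps = steps

    @property
    def named_steps(self):
        return dict(self.steps)

    def __len__(self):
        return len(self.steps)

    def __getitem__(self, ind):
        if isinstance(ind, slice):
            if ind.step not in (1, None):
                raise ValueError('only contiguous slices (step of 1) are supported')
            # self.steps[ind] is a new list, but its (name, estimator) tuples are shared
            return self.__class__(self.steps[ind])
        try:
            name, est = self.steps[ind]
        except TypeError:
            # Not an integer index: fall back to lookup by step name
            return self.named_steps[ind]
        return est


a, b, c = object(), object(), object()
pipe = ToyPipeline([('a', a), ('b', b), ('c', c)])
sub = pipe[:2]
print(pipe[-1] is c, pipe['b'] is b, sub[0] is a, len(sub))  # True True True 2

sub.steps[0] = ('a', 'replaced')
print(pipe[0] is a)  # True: replacing a step in the copy leaves the original alone
```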

assert pipe['transf'] == transf
assert pipe[-1] == clf
assert pipe['clf'] == clf
assert_raises(IndexError, lambda: pipe[3])
Contributor


Perhaps another test could be added for sub-pipeline index over several steps that exceeds the max. The present test gets at the case where a single estimator is returned, but not the case where a sub-pipeline is returned as a Pipeline() instance.
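One shape such a test could take, with `StandardScaler` and `LogisticRegression` standing in for the test's `transf` and `clf` (an assumption; the real test uses its own helper estimators). A slice whose bounds run past the last step does not raise, following list semantics, so the assertions target the returned Pipeline's contents rather than an exception.

```python
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

transf = StandardScaler()
clf = LogisticRegression()
pipe = Pipeline([('transf', transf), ('clf', clf)])

# A single out-of-range index raises, as the existing assertion checks
index_error = False
try:
    pipe[3]
except IndexError:
    index_error = True

# A slice past the end is clipped like a list slice and still returns a Pipeline
sub = pipe[1:5]
print(isinstance(sub, Pipeline), [name for name, _ in sub.steps])  # True ['clf']
print(sub[0] is clf)  # True
```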

Member

@amueller amueller left a comment


you didn't allow pushing to the branch so I suggested changes ;)

sklearn/pipeline.py Outdated
Member

@amueller amueller left a comment


Looks good apart from my suggestions.

amueller and others added 4 commits March 2, 2019 23:30
Co-Authored-By: jnothman <joel.nothman@gmail.com>
Co-Authored-By: jnothman <joel.nothman@gmail.com>
Co-Authored-By: jnothman <joel.nothman@gmail.com>
@amueller
Member

amueller commented Mar 4, 2019

hm can you fix it up or allow me to push?

@amueller
Member

amueller commented Mar 4, 2019

see amueller@3733569

@jnothman
Member Author

jnothman commented Mar 6, 2019 via email

@amueller
Member

amueller commented Mar 6, 2019

right, you're still travelling. Sorry to bug you. I should have known, given that they just sent me your boarding pass ;)

@amueller
Member

amueller commented Mar 6, 2019

but looks good to merge? @rth @ogrisel @qinhanmin2014 any thoughts / wanna press the green button?

doc/whats_new/v0.21.rst Outdated
Member

@adrinjalali adrinjalali left a comment


Otherwise looks good :)

Co-Authored-By: jnothman <joel.nothman@gmail.com>
@adrinjalali adrinjalali merged commit 2207121 into scikit-learn:master Mar 7, 2019
@jnothman
Member Author

jnothman commented Mar 7, 2019 via email

xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
* ENH Pipeline can now be sliced or indexed

* Additional assertion imports for testing

* DOC Documentation and example for Pipeline slicing

* FIX put doctest lines in correct order

* DOC improve compose Pipeline docs

* Fix doctest

* Fix merge error

* DOCs improved after Alex's comments

* This is not the right place to change to LinearSVC

* missed one

* DOC add what's new

* Fix doctest

* doctest tweaks

Co-Authored-By: jnothman <joel.nothman@gmail.com>

* doctest tweaks

Co-Authored-By: jnothman <joel.nothman@gmail.com>

* doctest tweaks

Co-Authored-By: jnothman <joel.nothman@gmail.com>

* fix doctests

* Correct step name

* Update doc/whats_new/v0.21.rst

Co-Authored-By: jnothman <joel.nothman@gmail.com>
koenvandevelde pushed a commit to koenvandevelde/scikit-learn that referenced this pull request Jul 12, 2019

10 participants