Skip to content

Commit

Permalink
FIX Preserve parameters when slicing a pipeline (#18429)
Browse files Browse the repository at this point in the history
Co-authored-by: Paweł Biernat <pawel.biernat@ardigen.com>
  • Loading branch information
albertvillanova and Paweł Biernat committed Sep 21, 2020
1 parent 06d6f8a commit 174f935
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 11 deletions.
5 changes: 5 additions & 0 deletions doc/whats_new/v0.24.rst
Expand Up @@ -530,6 +530,11 @@ Changelog
will raise a ``ValueError``.
:pr:`17876` by :user:`Cary Goltermann <Ultramann>`.

- |Fix| A slice of a :class:`pipeline.Pipeline` now inherits the parameters of
the original pipeline (`memory` and `verbose`).
:pr:`18429` by :user:`Albert Villanova del Moral <albertvillanova>` and
:user:`Paweł Biernat <pwl>`.

:mod:`sklearn.preprocessing`
............................

Expand Down
6 changes: 4 additions & 2 deletions sklearn/pipeline.py
Expand Up @@ -207,8 +207,10 @@ def __getitem__(self, ind):
"""
if isinstance(ind, slice):
if ind.step not in (1, None):
raise ValueError('Pipeline slicing only supports a step of 1')
return self.__class__(self.steps[ind])
raise ValueError("Pipeline slicing only supports a step of 1")
return self.__class__(
self.steps[ind], memory=self.memory, verbose=self.verbose
)
try:
name, est = self.steps[ind]
except TypeError:
Expand Down
35 changes: 26 additions & 9 deletions sklearn/tests/test_pipeline.py
Expand Up @@ -558,15 +558,32 @@ def test_pipeline_fit_transform():
assert_array_almost_equal(X_trans, X_trans2)


def test_pipeline_slice():
pipe = Pipeline([('transf1', Transf()),
('transf2', Transf()),
('clf', FitParamT())])
pipe2 = pipe[:-1]
assert isinstance(pipe2, Pipeline)
assert pipe2.steps == pipe.steps[:-1]
assert 2 == len(pipe2.named_steps)
assert_raises(ValueError, lambda: pipe[::-1])
@pytest.mark.parametrize("start, end", [(0, 1), (0, 2), (1, 2), (1, 3),
(None, 1), (1, None), (None, None)])
def test_pipeline_slice(start, end):
pipe = Pipeline(
[("transf1", Transf()), ("transf2", Transf()), ("clf", FitParamT())],
memory="123",
verbose=True,
)
pipe_slice = pipe[start:end]
# Test class
assert isinstance(pipe_slice, Pipeline)
# Test steps
assert pipe_slice.steps == pipe.steps[start:end]
# Test named_steps attribute
assert list(pipe_slice.named_steps.items()) == list(
pipe.named_steps.items())[start:end]
# Test the rest of the parameters
pipe_params = pipe.get_params(deep=False)
pipe_slice_params = pipe_slice.get_params(deep=False)
del pipe_params["steps"]
del pipe_slice_params["steps"]
assert pipe_params == pipe_slice_params
# Test exception
msg = "Pipeline slicing only supports a step of 1"
with pytest.raises(ValueError, match=msg):
pipe[start:end:-1]


def test_pipeline_index():
Expand Down

0 comments on commit 174f935

Please sign in to comment.