[MRG] ENH enable setting pipeline components as parameters (#1769)
Pipeline and FeatureUnion steps may now be set with set_params, and transformers may be replaced with None to effectively remove them.

Also test and improve duck typing of Pipeline methods
jnothman committed Aug 29, 2016
1 parent ebb0645 commit 5b20d48
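
For illustration, a minimal sketch of the behaviour this commit enables (a hypothetical two-step pipeline, not part of the diff below; assignment suppresses the doctest output):

>>> from sklearn.pipeline import Pipeline
>>> from sklearn.preprocessing import StandardScaler
>>> from sklearn.svm import SVC
>>> pipe = Pipeline([('scale', StandardScaler()), ('clf', SVC())])
>>> pipe = pipe.set_params(scale=None)      # disable the scaling step
>>> pipe = pipe.set_params(clf=SVC(C=10))   # or swap in a different estimator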
Showing 7 changed files with 761 additions and 166 deletions.
38 changes: 27 additions & 11 deletions doc/modules/pipeline.rst
@@ -37,17 +37,16 @@ is an estimator object::
>>> from sklearn.pipeline import Pipeline
>>> from sklearn.svm import SVC
>>> from sklearn.decomposition import PCA
- >>> estimators = [('reduce_dim', PCA()), ('svm', SVC())]
- >>> clf = Pipeline(estimators)
- >>> clf # doctest: +NORMALIZE_WHITESPACE
+ >>> estimators = [('reduce_dim', PCA()), ('clf', SVC())]
+ >>> pipe = Pipeline(estimators)
+ >>> pipe # doctest: +NORMALIZE_WHITESPACE
Pipeline(steps=[('reduce_dim', PCA(copy=True, iterated_power=4,
n_components=None, random_state=None, svd_solver='auto', tol=0.0,
- whiten=False)), ('svm', SVC(C=1.0, cache_size=200, class_weight=None,
+ whiten=False)), ('clf', SVC(C=1.0, cache_size=200, class_weight=None,
coef0=0.0, decision_function_shape=None, degree=3, gamma='auto',
kernel='rbf', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False))])


The utility function :func:`make_pipeline` is a shorthand
for constructing pipelines;
it takes a variable number of estimators and returns a pipeline,
@@ -64,23 +63,23 @@ filling in the names automatically::

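>>> # a minimal illustration; make_pipeline names each step after its
>>> # lowercased class:
>>> from sklearn.pipeline import make_pipeline
>>> from sklearn.preprocessing import Binarizer
>>> from sklearn.naive_bayes import MultinomialNB
>>> p = make_pipeline(Binarizer(), MultinomialNB())
>>> [name for name, _ in p.steps]
['binarizer', 'multinomialnb']
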
The estimators of a pipeline are stored as a list in the ``steps`` attribute::

- >>> clf.steps[0]
+ >>> pipe.steps[0]
('reduce_dim', PCA(copy=True, iterated_power=4, n_components=None, random_state=None,
svd_solver='auto', tol=0.0, whiten=False))

and as a ``dict`` in ``named_steps``::

- >>> clf.named_steps['reduce_dim']
+ >>> pipe.named_steps['reduce_dim']
PCA(copy=True, iterated_power=4, n_components=None, random_state=None,
svd_solver='auto', tol=0.0, whiten=False)

Parameters of the estimators in the pipeline can be accessed using the
``<estimator>__<parameter>`` syntax::

- >>> clf.set_params(svm__C=10) # doctest: +NORMALIZE_WHITESPACE
+ >>> pipe.set_params(clf__C=10) # doctest: +NORMALIZE_WHITESPACE
Pipeline(steps=[('reduce_dim', PCA(copy=True, iterated_power=4,
n_components=None, random_state=None, svd_solver='auto', tol=0.0,
- whiten=False)), ('svm', SVC(C=10, cache_size=200, class_weight=None,
+ whiten=False)), ('clf', SVC(C=10, cache_size=200, class_weight=None,
coef0=0.0, decision_function_shape=None, degree=3, gamma='auto',
kernel='rbf', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False))])
@@ -90,9 +89,17 @@ This is particularly important for doing grid searches::

>>> from sklearn.model_selection import GridSearchCV
>>> params = dict(reduce_dim__n_components=[2, 5, 10],
- ...               svm__C=[0.1, 10, 100])
- >>> grid_search = GridSearchCV(clf, param_grid=params)
+ ...               clf__C=[0.1, 10, 100])
+ >>> grid_search = GridSearchCV(pipe, param_grid=params)

Individual steps may also be replaced as parameters, and non-final steps may be
ignored by setting them to ``None``::

>>> from sklearn.linear_model import LogisticRegression
>>> params = dict(reduce_dim=[None, PCA(5), PCA(10)],
...               clf=[SVC(), LogisticRegression()],
...               clf__C=[0.1, 10, 100])
>>> grid_search = GridSearchCV(pipe, param_grid=params)

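A single step can likewise be disabled outside of a grid search; a minimal
sketch (the returned ``Pipeline`` repr is omitted here)::

>>> pipe.set_params(reduce_dim=None)  # doctest: +SKIP
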
.. topic:: Examples:

@@ -172,6 +179,15 @@ Like pipelines, feature unions have a shorthand constructor called
:func:`make_union` that does not require explicit naming of the components.

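A minimal sketch of :func:`make_union` (names again default to the lowercased
class names)::

>>> from sklearn.pipeline import make_union
>>> from sklearn.decomposition import PCA, KernelPCA
>>> union = make_union(PCA(), KernelPCA())
>>> [name for name, _ in union.transformer_list]
['pca', 'kernelpca']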

As with ``Pipeline``, individual steps may be replaced using ``set_params``,
and ignored by setting them to ``None``::

>>> combined.set_params(kernel_pca=None) # doctest: +NORMALIZE_WHITESPACE
FeatureUnion(n_jobs=1, transformer_list=[('linear_pca', PCA(copy=True,
iterated_power=4, n_components=None, random_state=None,
svd_solver='auto', tol=0.0, whiten=False)), ('kernel_pca', None)],
transformer_weights=None)

.. topic:: Examples:

* :ref:`sphx_glr_auto_examples_feature_stacker.py`
6 changes: 6 additions & 0 deletions doc/whats_new.rst
@@ -286,6 +286,12 @@ Enhancements
(`#5805 <https://github.com/scikit-learn/scikit-learn/pull/5805>`_)
By `Ibraim Ganiev`_.

- Added support for substituting or disabling :class:`pipeline.Pipeline`
and :class:`pipeline.FeatureUnion` components using the ``set_params``
interface that powers :mod:`sklearn.grid_search`.
See :ref:`example_plot_compare_reduction.py`. By `Joel Nothman`_ and
`Robert McGibbon`_.

Bug fixes
.........

75 changes: 75 additions & 0 deletions examples/plot_compare_reduction.py
@@ -0,0 +1,75 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
=================================================================
Selecting dimensionality reduction with Pipeline and GridSearchCV
=================================================================
This example constructs a pipeline that does dimensionality
reduction followed by prediction with a support vector
classifier. It demonstrates the use of GridSearchCV and
Pipeline to optimize over different classes of estimators in a
single CV run -- unsupervised PCA and NMF dimensionality
reductions are compared to univariate feature selection during
the grid search.
"""
# Authors: Robert McGibbon, Joel Nothman

from __future__ import print_function, division

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
from sklearn.decomposition import PCA, NMF
from sklearn.feature_selection import SelectKBest, chi2

print(__doc__)

pipe = Pipeline([
    ('reduce_dim', PCA()),
    ('classify', LinearSVC())
])

N_FEATURES_OPTIONS = [2, 4, 8]
C_OPTIONS = [1, 10, 100, 1000]
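# Each dict in param_grid is searched separately: a bare step name such as
# 'reduce_dim' substitutes the whole estimator for that step, while
# 'reduce_dim__n_components' sets a parameter on whichever estimator is there.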
param_grid = [
    {
        'reduce_dim': [PCA(iterated_power=7), NMF()],
        'reduce_dim__n_components': N_FEATURES_OPTIONS,
        'classify__C': C_OPTIONS
    },
    {
        'reduce_dim': [SelectKBest(chi2)],
        'reduce_dim__k': N_FEATURES_OPTIONS,
        'classify__C': C_OPTIONS
    },
]
reducer_labels = ['PCA', 'NMF', 'KBest(chi2)']

grid = GridSearchCV(pipe, cv=3, n_jobs=2, param_grid=param_grid)
digits = load_digits()
grid.fit(digits.data, digits.target)

mean_scores = np.array(grid.results_['test_mean_score'])
# scores are in the order of param_grid iteration, which is alphabetical
mean_scores = mean_scores.reshape(len(C_OPTIONS), -1, len(N_FEATURES_OPTIONS))
# select score for best C
mean_scores = mean_scores.max(axis=0)
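# lay out one cluster of bars per feature count, one bar per reducer,
# with an empty slot between clusters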
bar_offsets = (np.arange(len(N_FEATURES_OPTIONS)) *
               (len(reducer_labels) + 1) + .5)

plt.figure()
COLORS = 'bgrcmyk'
for i, (label, reducer_scores) in enumerate(zip(reducer_labels, mean_scores)):
    plt.bar(bar_offsets + i, reducer_scores, label=label, color=COLORS[i])

plt.title("Comparing feature reduction techniques")
plt.xlabel('Reduced number of features')
plt.xticks(bar_offsets + len(reducer_labels) / 2, N_FEATURES_OPTIONS)
plt.ylabel('Digit classification accuracy')
plt.ylim((0, 1))
plt.legend(loc='upper left')
plt.show()
