[MRG+3] ENH Caching Pipeline by memoizing transformer #7990

Merged
merged 5 commits on Feb 13, 2017
Changes from all commits
144 changes: 107 additions & 37 deletions doc/modules/pipeline.rst
@@ -39,13 +39,10 @@ is an estimator object::
>>> from sklearn.decomposition import PCA
>>> estimators = [('reduce_dim', PCA()), ('clf', SVC())]
>>> pipe = Pipeline(estimators)
>>> pipe # doctest: +NORMALIZE_WHITESPACE
Pipeline(steps=[('reduce_dim', PCA(copy=True, iterated_power='auto',
n_components=None, random_state=None, svd_solver='auto', tol=0.0,
whiten=False)), ('clf', SVC(C=1.0, cache_size=200, class_weight=None,
coef0=0.0, decision_function_shape='ovr', degree=3, gamma='auto',
kernel='rbf', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False))])
>>> pipe # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
Pipeline(memory=None,
steps=[('reduce_dim', PCA(copy=True,...)),
('clf', SVC(C=1.0,...))])

The utility function :func:`make_pipeline` is a shorthand
for constructing pipelines;
@@ -56,7 +53,8 @@ filling in the names automatically::
>>> from sklearn.naive_bayes import MultinomialNB
>>> from sklearn.preprocessing import Binarizer
>>> make_pipeline(Binarizer(), MultinomialNB()) # doctest: +NORMALIZE_WHITESPACE
Pipeline(steps=[('binarizer', Binarizer(copy=True, threshold=0.0)),
Pipeline(memory=None,
steps=[('binarizer', Binarizer(copy=True, threshold=0.0)),
('multinomialnb', MultinomialNB(alpha=1.0,
class_prior=None,
fit_prior=True))])
@@ -76,30 +74,26 @@ and as a ``dict`` in ``named_steps``::
Parameters of the estimators in the pipeline can be accessed using the
``<estimator>__<parameter>`` syntax::

>>> pipe.set_params(clf__C=10) # doctest: +NORMALIZE_WHITESPACE
Pipeline(steps=[('reduce_dim', PCA(copy=True, iterated_power='auto',
n_components=None, random_state=None, svd_solver='auto', tol=0.0,
whiten=False)), ('clf', SVC(C=10, cache_size=200, class_weight=None,
coef0=0.0, decision_function_shape='ovr', degree=3, gamma='auto',
kernel='rbf', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False))])

>>> pipe.set_params(clf__C=10) # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
Pipeline(memory=None,
steps=[('reduce_dim', PCA(copy=True, iterated_power='auto',...)),
('clf', SVC(C=10, cache_size=200, class_weight=None,...))])

This is particularly important for doing grid searches::

>>> from sklearn.model_selection import GridSearchCV
>>> params = dict(reduce_dim__n_components=[2, 5, 10],
... clf__C=[0.1, 10, 100])
>>> grid_search = GridSearchCV(pipe, param_grid=params)
>>> param_grid = dict(reduce_dim__n_components=[2, 5, 10],
... clf__C=[0.1, 10, 100])
>>> grid_search = GridSearchCV(pipe, param_grid=param_grid)

Individual steps may also be replaced as parameters, and non-final steps may be
ignored by setting them to ``None``::

>>> from sklearn.linear_model import LogisticRegression
>>> params = dict(reduce_dim=[None, PCA(5), PCA(10)],
... clf=[SVC(), LogisticRegression()],
... clf__C=[0.1, 10, 100])
>>> grid_search = GridSearchCV(pipe, param_grid=params)
>>> param_grid = dict(reduce_dim=[None, PCA(5), PCA(10)],
... clf=[SVC(), LogisticRegression()],
... clf__C=[0.1, 10, 100])
>>> grid_search = GridSearchCV(pipe, param_grid=param_grid)

.. topic:: Examples:

@@ -108,6 +102,7 @@ ignored by setting them to ``None``::
* :ref:`sphx_glr_auto_examples_plot_digits_pipe.py`
* :ref:`sphx_glr_auto_examples_plot_kernel_approximation.py`
* :ref:`sphx_glr_auto_examples_svm_plot_svm_anova.py`
* :ref:`sphx_glr_auto_examples_plot_compare_reduction.py`

.. topic:: See also:

@@ -124,6 +119,84 @@ i.e. if the last estimator is a classifier, the :class:`Pipeline` can be used
as a classifier. If the last estimator is a transformer, again, so is the
pipeline.
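
For instance, here is a minimal sketch (not part of this change; the calls are
marked ``+SKIP`` since they are illustrative only) of using the pipeline
defined above as a classifier::

    >>> from sklearn.datasets import load_iris         # doctest: +SKIP
    >>> iris = load_iris()                              # doctest: +SKIP
    >>> pipe.fit(iris.data, iris.target)                # doctest: +SKIP
    >>> pipe.predict(iris.data[:2])                     # doctest: +SKIP
    >>> pipe.score(iris.data, iris.target)              # doctest: +SKIP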

Caching transformers: avoid repeated computation
-------------------------------------------------

.. currentmodule:: sklearn.pipeline

Fitting transformers may be computationally expensive. With its
``memory`` parameter set, :class:`Pipeline` will cache each transformer
after calling ``fit``.
This feature is used to avoid refitting the transformers within a pipeline
when the parameters and input data are identical. A typical example is a
grid search, in which the transformers need to be fitted only once and can
be reused for each configuration.
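
For instance (a sketch, not part of this change; ``cached_pipe``,
``param_grid``, ``X`` and ``y`` are hypothetical placeholders), a pipeline
built with ``memory`` is passed to ``GridSearchCV`` like any other
estimator::

    >>> from sklearn.model_selection import GridSearchCV         # doctest: +SKIP
    >>> grid = GridSearchCV(cached_pipe, param_grid=param_grid)  # doctest: +SKIP
    >>> grid.fit(X, y)                                            # doctest: +SKIP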

The parameter ``memory`` is needed in order to cache the transformers.
``memory`` can be either a string giving the directory in which to cache the
transformers or a `joblib.Memory <https://pythonhosted.org/joblib/memory.html>`_
object::

>>> from tempfile import mkdtemp
>>> from shutil import rmtree
>>> from sklearn.decomposition import PCA
>>> from sklearn.svm import SVC
>>> from sklearn.pipeline import Pipeline
>>> estimators = [('reduce_dim', PCA()), ('clf', SVC())]
>>> cachedir = mkdtemp()
>>> pipe = Pipeline(estimators, memory=cachedir)
>>> pipe # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
Pipeline(...,
steps=[('reduce_dim', PCA(copy=True,...)),
('clf', SVC(C=1.0,...))])
>>> # Clear the cache directory when you don't need it anymore
>>> rmtree(cachedir)
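
Alternatively (a sketch, not part of this change), a ``joblib.Memory`` object
can be passed instead of a directory string::

    >>> from sklearn.externals.joblib import Memory
    >>> cachedir = mkdtemp()
    >>> memory = Memory(cachedir=cachedir, verbose=1)
    >>> pipe = Pipeline(estimators, memory=memory)
    >>> # Clear the cache directory when you don't need it anymore
    >>> rmtree(cachedir)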

.. warning:: **Side effect of caching transformers**

When a :class:`Pipeline` is used without caching enabled, it is possible to
inspect the original transformer instance directly::

>>> from sklearn.datasets import load_digits
>>> digits = load_digits()
>>> pca1 = PCA()
>>> svm1 = SVC()
>>> pipe = Pipeline([('reduce_dim', pca1), ('clf', svm1)])
>>> pipe.fit(digits.data, digits.target)
... # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
Pipeline(memory=None,
steps=[('reduce_dim', PCA(...)), ('clf', SVC(...))])
>>> # The pca instance can be inspected directly
>>> print(pca1.components_) # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
[[ -1.77484909e-19 ... 4.07058917e-18]]

Enabling caching triggers a clone of the transformers before fitting.
Therefore, the transformer instance given to the pipeline cannot be
inspected directly.
In the following example, accessing the :class:`PCA` instance ``pca2``
will raise an ``AttributeError`` since ``pca2`` is an unfitted
transformer.
Instead, use the attribute ``named_steps`` to inspect estimators within
the pipeline::

>>> cachedir = mkdtemp()
>>> pca2 = PCA()
>>> svm2 = SVC()
>>> cached_pipe = Pipeline([('reduce_dim', pca2), ('clf', svm2)],
... memory=cachedir)
>>> cached_pipe.fit(digits.data, digits.target)
... # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
Pipeline(memory=...,
steps=[('reduce_dim', PCA(...)), ('clf', SVC(...))])
>>> print(cached_pipe.named_steps['reduce_dim'].components_)
... # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
[[ -1.77484909e-19 ... 4.07058917e-18]]
>>> # Remove the cache directory
>>> rmtree(cachedir)

.. topic:: Examples:

* :ref:`sphx_glr_auto_examples_plot_compare_reduction.py`

.. _feature_union:

@@ -164,15 +237,11 @@ and ``value`` is an estimator object::
>>> from sklearn.decomposition import KernelPCA
>>> estimators = [('linear_pca', PCA()), ('kernel_pca', KernelPCA())]
>>> combined = FeatureUnion(estimators)
>>> combined # doctest: +NORMALIZE_WHITESPACE
FeatureUnion(n_jobs=1, transformer_list=[('linear_pca', PCA(copy=True,
iterated_power='auto', n_components=None, random_state=None,
svd_solver='auto', tol=0.0, whiten=False)), ('kernel_pca',
KernelPCA(alpha=1.0, coef0=1, copy_X=True, degree=3,
eigen_solver='auto', fit_inverse_transform=False, gamma=None,
kernel='linear', kernel_params=None, max_iter=None, n_components=None,
n_jobs=1, random_state=None, remove_zero_eig=False, tol=0))],
transformer_weights=None)
>>> combined # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
FeatureUnion(n_jobs=1,
transformer_list=[('linear_pca', PCA(copy=True,...)),
('kernel_pca', KernelPCA(alpha=1.0,...))],
transformer_weights=None)


Like pipelines, feature unions have a shorthand constructor called
@@ -182,11 +251,12 @@ Like pipelines, feature unions have a shorthand constructor called
Like ``Pipeline``, individual steps may be replaced using ``set_params``,
and ignored by setting to ``None``::

>>> combined.set_params(kernel_pca=None) # doctest: +NORMALIZE_WHITESPACE
FeatureUnion(n_jobs=1, transformer_list=[('linear_pca', PCA(copy=True,
iterated_power='auto', n_components=None, random_state=None,
svd_solver='auto', tol=0.0, whiten=False)), ('kernel_pca', None)],
transformer_weights=None)
>>> combined.set_params(kernel_pca=None)
... # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
FeatureUnion(n_jobs=1,
transformer_list=[('linear_pca', PCA(copy=True,...)),
('kernel_pca', None)],
transformer_weights=None)

.. topic:: Examples:

3 changes: 3 additions & 0 deletions doc/whats_new.rst
@@ -56,6 +56,9 @@ Enhancements
- :class:`multioutput.MultiOutputRegressor` and :class:`multioutput.MultiOutputClassifier`
now support online learning using `partial_fit`.
:issue:`8053` by :user:`Peng Yu <yupbank>`.
- :class:`pipeline.Pipeline` can now cache transformers
within a pipeline by using the ``memory`` constructor parameter.
:issue:`7990` by :user:`Guillaume Lemaitre <glemaitre>`.

- :class:`decomposition.PCA`, :class:`decomposition.IncrementalPCA` and
:class:`decomposition.TruncatedSVD` now expose the singular values
67 changes: 61 additions & 6 deletions examples/plot_compare_reduction.py 100644 → 100755
@@ -1,4 +1,4 @@
#!/usr/bin/python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
=================================================================
@@ -7,13 +7,27 @@

This example constructs a pipeline that does dimensionality
reduction followed by prediction with a support vector
classifier. It demonstrates the use of GridSearchCV and
Pipeline to optimize over different classes of estimators in a
single CV run -- unsupervised PCA and NMF dimensionality
classifier. It demonstrates the use of ``GridSearchCV`` and
``Pipeline`` to optimize over different classes of estimators in a
single CV run -- unsupervised ``PCA`` and ``NMF`` dimensionality
reductions are compared to univariate feature selection during
the grid search.

Additionally, ``Pipeline`` can be instantiated with the ``memory`` argument
to memoize the transformers within the pipeline, avoiding fitting the same
transformers over and over.

Note that the use of ``memory`` to enable caching becomes worthwhile when
fitting a transformer is costly.
"""
# Authors: Robert McGibbon, Joel Nothman

###############################################################################
# Illustration of ``Pipeline`` and ``GridSearchCV``
###############################################################################
# This section illustrates the use of a ``Pipeline`` with
# ``GridSearchCV``.

# Authors: Robert McGibbon, Joel Nothman, Guillaume Lemaitre

from __future__ import print_function, division

@@ -49,7 +63,7 @@
]
reducer_labels = ['PCA', 'NMF', 'KBest(chi2)']

grid = GridSearchCV(pipe, cv=3, n_jobs=2, param_grid=param_grid)
grid = GridSearchCV(pipe, cv=3, n_jobs=1, param_grid=param_grid)
digits = load_digits()
grid.fit(digits.data, digits.target)

@@ -72,4 +86,45 @@
plt.ylabel('Digit classification accuracy')
plt.ylim((0, 1))
plt.legend(loc='upper left')

###############################################################################
# Caching transformers within a ``Pipeline``
Review comments on this section:

Member: I'm not very happy with this appended to the existing example. Not
sure what the right solution is, though.

Member: I am trying to fight the proliferation of examples (we have too many)
and to make features easier to discover. What we have found in nilearn is
that the notebook-style examples of sphinx-gallery make it possible to add
sections in examples and hence enhance simple examples without compromising
their readability. What's important is that each section contributes one idea
and is kept simple.

Member: Personally I am fine with it (I prefer that to adding a new example).

Member: I agree. I prefer to avoid a new example.

@ogrisel (Member, Feb 2, 2017): OK, let's revert this example to its previous
state and instead enable the memory on the following example that really
benefits from caching:
examples/model_selection/grid_search_text_feature_extraction.py

Member: Actually, this example does not benefit either because the fit is too
quick and pickling / unpickling large dicts with string keys (the vocabulary
of the vectorizers) is comparatively more expensive.

Member: It's fine here. So we need to think of our examples more like
notebooks. Sounds fine.
###############################################################################
# It is sometimes worthwhile storing the state of a specific transformer
# since it could be used again. Using a pipeline in ``GridSearchCV`` triggers
# such situations. Therefore, we use the argument ``memory`` to enable caching.
#
# .. warning::
#     Note that this example is, however, only an illustration since in this
#     specific case fitting PCA is not necessarily slower than loading it
#     from the cache. Hence, use the ``memory`` constructor parameter when
#     fitting a transformer is costly.

from tempfile import mkdtemp
from shutil import rmtree
from sklearn.externals.joblib import Memory

# Create a temporary folder to store the transformers of the pipeline
cachedir = mkdtemp()
memory = Memory(cachedir=cachedir, verbose=10)
cached_pipe = Pipeline([('reduce_dim', PCA()),
('classify', LinearSVC())],
memory=memory)

# This time, a cached pipeline will be used within the grid search
grid = GridSearchCV(cached_pipe, cv=3, n_jobs=1, param_grid=param_grid)
digits = load_digits()
grid.fit(digits.data, digits.target)

# Delete the temporary cache before exiting
rmtree(cachedir)

###############################################################################
# The ``PCA`` fitting is only computed when evaluating the first
# configuration of the ``C`` parameter of the ``LinearSVC`` classifier. The
# other configurations of ``C`` will trigger the loading of the cached ``PCA``
# estimator data, saving processing time. Therefore, caching the pipeline
# with ``memory`` is highly beneficial when fitting a transformer is costly.
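
# A rough, hypothetical illustration (not in the original example): fitting a
# cached pipeline twice on the same data makes the second ``fit`` reuse the
# cached ``PCA`` result instead of recomputing it; with a verbose ``Memory``
# the cache hits are reported on the console.
cachedir = mkdtemp()
demo_memory = Memory(cachedir=cachedir, verbose=3)
demo_pipe = Pipeline([('reduce_dim', PCA(n_components=8)),
                      ('classify', LinearSVC())],
                     memory=demo_memory)
demo_pipe.fit(digits.data, digits.target)   # fits and caches the PCA step
demo_pipe.fit(digits.data, digits.target)   # loads the PCA fit from the cache
rmtree(cachedir)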

plt.show()