
MRG Feature stacker #1173

Merged
merged 11 commits into from over 1 year ago

7 participants

Andreas Mueller Olivier Grisel Lars Buitinck Gael Varoquaux Mathieu Blondel Alexandre Gramfort Vlad Niculae
Andreas Mueller
Owner

This estimator provides a Y-piece for the pipeline.
I used it to combine word n-grams and char n-grams into a single transformer.
Basically it just concatenates the output of several transformers into one large feature matrix.

If you think this is helpful, I'll add some docs and an example.
With this, together with Pipeline, one can build arbitrarily complex graphs (with one source and one sink) of estimators in sklearn :)

TODO

  • tests
  • narrative documentation
  • example

Thanks to the awesome implementation of BaseEstimator, grid search simply works, though with complicated graphs you get parameter names like feature_stacker__first_feature__feature_selection__percentile (more or less from my code ^^).
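To make the nesting concrete, here is a minimal sketch of how such a parameter name arises (the step names are made up to match the string quoted above, and the class is shown under its final FeatureUnion name):

    from sklearn.pipeline import Pipeline, FeatureUnion
    from sklearn.feature_selection import SelectPercentile, f_classif
    from sklearn.svm import SVC

    union = FeatureUnion([
        ("first_feature", Pipeline([
            ("feature_selection", SelectPercentile(f_classif, percentile=10)),
        ])),
    ])
    pipe = Pipeline([("feature_stacker", union), ("svm", SVC())])

    # Each nesting level contributes one segment, joined by double underscores:
    param_grid = {
        "feature_stacker__first_feature__feature_selection__percentile": [10, 50],
    }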

sklearn/linear_model/tests/test_randomized_l1.py
((13 lines not shown))
+
+    # center here because sparse matrices are usually not centered
+    X, y, _, _, _ = center_data(X, y, True, True)
+
+    X_sp = sparse.csr_matrix(X)
+
+    F, _ = f_classif(X, y)
+
+    scaling = 0.3
+    clf = RandomizedLogisticRegression(verbose=False, C=1., random_state=42,
+                                scaling=scaling, n_resampling=50, tol=1e-3)
+    feature_scores = clf.fit(X, y).scores_
+    clf = RandomizedLogisticRegression(verbose=False, C=1., random_state=42,
+                                scaling=scaling, n_resampling=50, tol=1e-3)
+    feature_scores_sp = clf.fit(X_sp, y).scores_
+    assert_equal(feature_scores, feature_scores_sp)
Olivier Grisel Owner

This hunk seems to be unrelated to this PR.

Andreas Mueller Owner

whoops sorry, forked from wrong branch. just a sec.

Olivier Grisel
Owner

Very interesting. I want an example first! (then documentation and tests :)

Andreas Mueller
Owner

on it :)

Olivier Grisel
Owner

@amueller to avoid forking from non-master branches you should use something such as http://volnitsky.com/project/git-prompt/

sklearn/pipeline.py
((22 lines not shown))
+
+    def get_feature_names(self):
+        pass
+
+    def fit(self, X, y=None):
+        for name, trans in self.transformer_list:
+            trans.fit(X, y)
+        return self
+
+    def transform(self, X):
+        features = []
+        for name, trans in self.transformer_list:
+            features.append(trans.transform(X))
+        issparse = [sparse.issparse(f) for f in features]
+        if np.any(issparse):
+            features = sparse.hstack(features).tocsr()
Olivier Grisel Owner

Maybe the tocsr() can be avoided. The downstream model might prefer CSC; ElasticNet, for instance.

Lars Buitinck Owner

Then again, bugs crop up every now and then where estimators that are supposed to handle any sparse format turn out to only handle CSR. It's a good defensive strategy to produce CSR by default (and it's unfortunate that sparse.hstack doesn't do this already).

Andreas Mueller Owner

I wrote this thing in the heat of the battle and I don't remember if there was a reason or if it was just a precaution. I'm inclined to think that I put it there because something, somewhere, broke.
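For context, the conversion is not gratuitous: scipy.sparse.hstack returns a COO matrix by default, which does not support the row-wise access most estimators need. A minimal sketch:

    import numpy as np
    from scipy import sparse

    A = sparse.csr_matrix(np.eye(3))
    B = sparse.csr_matrix(np.ones((3, 2)))

    stacked = sparse.hstack([A, B])
    print(stacked.format)          # 'coo': hstack does not preserve CSR
    print(stacked.tocsr().format)  # 'csr': the explicit conversion above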

Andreas Mueller
Owner

Yes, it should derive from TransformerMixin.
@larsmans can I interpret your comments as meaning that you think this is a good thing to have?

Andreas Mueller
Owner

Added a toy example.

Olivier Grisel
Owner

I think such a feature stacker should provide some way to do feature group normalization in one way or another. But this probably requires some experiments to find out which normalization pattern is useful on such a beast in practice.

Does anybody have practical experience or insight to share on this?

Lars Buitinck
Owner

GREAT idea! However, I don't like the name FeatureStacker much, as stacking implies putting things on top of each other, while this class concatenates things side-by-side.

I tried to find a "plumbing equivalent" of this class to keep with the pipeline metaphor, but I can't seem to find it. It's not quite a tee as it connects the various streams back together in the end. Maybe one of the other devs is more experienced with plumbing? :)

Olivier Grisel
Owner

BTW I think the example could be improved by using a less trivial dataset (e.g. the digits dataset) and showing that the cross-validated score of the best grid-searched parameter set for the pipeline with stacked features beats the pipelines with the individual feature transformers used separately.

Olivier Grisel
Owner

@larsmans maybe FeatureConcatenator?

Olivier Grisel ogrisel closed this
Olivier Grisel
Owner

FeatureUnion?

Lars Buitinck larsmans reopened this
Lars Buitinck
Owner

MultiTransformer?

Andreas Mueller
Owner

I also like FeatureUnion.
Other possibilities: FeatureBinder, FeatureAssembler, FeatureCombiner.
Or maybe move away from "feature"? TransformerUnion, TransformBinder, TransformerBundle?

Hm, I think I like TransformerBundle.

Olivier Grisel
Owner

+1 for FeatureAssembler or FeatureUnion or TransformerBundle

Lars Buitinck
Owner

+1 for TransformerBundle.

Andreas Mueller
Owner

In my application, I found get_feature_names very helpful; I was using text data together with some handcrafted features.
I fear this is hard to do in general. I thought about checking hasattr(trans, "get_feature_names") and otherwise just returning estimator_name_0, estimator_name_1, .... This might be a bit problematic, though, as I don't think there is a reliable way to get the output dimensionality of a transformer :-/

Oh and @ogrisel for the normalization, each feature should be normalized separately, right?
This is "easily" possible by feeding the object pipelines of preprocessing and transformers. As normalization might be quite application-specific, I think this solution is OK for the moment.
The code doesn't actually get too messy doing this.
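A minimal sketch of that nesting: each branch of the union is itself a Pipeline ending in a scaler, so every feature group is normalized separately (the choice of StandardScaler is illustrative, not prescribed by the PR):

    from sklearn.pipeline import Pipeline, FeatureUnion
    from sklearn.decomposition import PCA
    from sklearn.feature_selection import SelectKBest
    from sklearn.preprocessing import StandardScaler

    combined = FeatureUnion([
        ("pca", Pipeline([("reduce", PCA(n_components=2)),
                          ("scale", StandardScaler())])),
        ("select", Pipeline([("kbest", SelectKBest(k=1)),
                             ("scale", StandardScaler())])),
    ])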

Andreas Mueller
Owner

Ugh, I just tried to work on the example and noticed that #1034 isn't in master yet.
Without a good way to look at the grid search results, this PR is a lot less useful, I think.
Have to work on #1034 more :-/

Lars Buitinck
Owner

We might introduce an n_features_out_ attribute/property on all transformers that work on feature vectors. For now, supporting get_feature_names only when all underlying transformers provide it would be good enough, IMHO.
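A sketch of that fallback; note that n_features_out_ is the proposed attribute and does not exist on any transformer yet:

    def union_feature_names(transformer_list):
        # Use real feature names where available, synthesize them otherwise.
        names = []
        for name, trans in transformer_list:
            if hasattr(trans, "get_feature_names"):
                names.extend("%s__%s" % (name, f)
                             for f in trans.get_feature_names())
            else:
                # relies on the hypothetical n_features_out_ attribute
                names.extend("%s_%d" % (name, i)
                             for i in range(trans.n_features_out_))
        return names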

Andreas Mueller
Owner

@larsmans ok, will do that. Should be easy enough.

Andreas Mueller
Owner

Having a bit of a hard time creating a good example :-/

Olivier Grisel
Owner

Have you been able to use this kind of tool successfully for your Kaggle contest? If so, we can stick to a simplistic toy example and say in the narrative documentation which kind of feature bundle has proven useful in practice on which kind of problem (e.g. PCA features + raw TF-IDF for text classification).

Andreas Mueller
Owner

I can tell you how successful I was tomorrow ;)
It was definitely helpful to combine handcrafted features with word n-grams. Doing it with this estimator, I was still able to grid-search CountVectorizer parameters such as min_df, ngram_range, etc. So that definitely helped.
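A hedged sketch of that kind of setup (the classifier and the parameter values are placeholders, not the actual contest code):

    from sklearn.pipeline import Pipeline, FeatureUnion
    from sklearn.feature_extraction.text import CountVectorizer
    from sklearn.linear_model import LogisticRegression
    from sklearn.grid_search import GridSearchCV

    text_features = FeatureUnion([
        ("words", CountVectorizer(analyzer="word")),
        ("chars", CountVectorizer(analyzer="char_wb")),
    ])
    clf = Pipeline([("features", text_features),
                    ("logreg", LogisticRegression())])

    # CountVectorizer parameters stay reachable through the nested names:
    param_grid = {
        "features__words__min_df": [1, 2, 5],
        "features__words__ngram_range": [(1, 1), (1, 2)],
        "features__chars__ngram_range": [(2, 4), (3, 5)],
    }
    grid_search = GridSearchCV(clf, param_grid=param_grid)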

sklearn/pipeline.py
@@ -199,3 +202,81 @@ def score(self, X, y=None):
     def _pairwise(self):
         # check if first estimator expects pairwise input
         return getattr(self.steps[0][1], '_pairwise', False)
+
+
+class FeatureStacker(BaseEstimator, TransformerMixin):
+    """Concatenates results of multiple transformer objects.
+
+    This estimator applies a list of transformer objects in parallel to the
+    input data, then concatenates the results. This is useful to combine
+    several feature extraction mechanisms into a single estimator.
Mathieu Blondel Owner

single feature representation?

Andreas Mueller Owner

I prefer it the way it is: getting the features out is not the important part; the important part is formulating it as an estimator.

Mathieu Blondel Owner

I misunderstood what you meant. Since you're talking about extraction mechanisms, it may be clearer to say "in a single transformer".

Andreas Mueller Owner

agreed

sklearn/pipeline.py
((31 lines not shown))
+        for name, trans in self.transformer_list:
+            if not hasattr(trans, 'get_feature_names'):
+                raise AttributeError("Transformer %s does not provide"
+                        " get_feature_names." % str(name))
+            feature_names.extend([name + "__" + f for f in trans.get_feature_names()])
+        return feature_names
+
+    def fit(self, X, y=None):
+        """Fit all transformers using X.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape (n_samples, n_features)
+            Input data, used to fit transformers.
+        """
+        for name, trans in self.transformer_list:
Mathieu Blondel Owner

supporting n_jobs would be nice :)

Andreas Mueller Owner

In principle +1. Are there any transformers that use n_jobs? I am always afraid of putting it at the wrong abstraction level...

Mathieu Blondel Owner

Since it is embarrassingly parallel and each transformer can take time to fit, I think supporting n_jobs would make sense.

Are there any transformers that use n_jobs?

Not that I know of but I hope that users have enough common sense to not enable n_jobs at two different levels :)

Andreas Mueller Owner

So I see you are of the optimistic persuasion ;)
I'll add it.

sklearn/pipeline.py
((4 lines not shown))
+
+
+class FeatureStacker(BaseEstimator, TransformerMixin):
+    """Concatenates results of multiple transformer objects.
+
+    This estimator applies a list of transformer objects in parallel to the
+    input data, then concatenates the results. This is useful to combine
+    several feature extraction mechanisms into a single estimator.
+
+    Parameters
+    ----------
+    transformers: list of (name, transformer)
+        List of transformer objects to be applied to the data.
+
+    """
+    def __init__(self, transformer_list):
Mathieu Blondel Owner

a transformer_weight option to give more importance to some transformers could be useful!

Andreas Mueller Owner

hm ok, why not.
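For reference, usage would look like this once added (the transformer_weights spelling matches the final implementation further down):

    from sklearn.pipeline import FeatureUnion
    from sklearn.decomposition import PCA
    from sklearn.feature_selection import SelectKBest

    # Columns produced by the "pca" branch are multiplied by 10:
    fs = FeatureUnion([("pca", PCA(n_components=2)),
                       ("select", SelectKBest(k=1))],
                      transformer_weights={"pca": 10})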

Mathieu Blondel
Owner

Nice idea indeed!

Andreas Mueller
Owner

@mblondel any votes on the name?

Mathieu Blondel
Owner

Some I like include FeatureAssembler, FeatureCombiner and FeatureUnion.

Andreas Mueller
Owner

Name votes:
FeatureAssembler II
FeatureCombiner I
FeatureUnion IIII
TransformerBundle III

(If I counted correctly, which is unlikely given my degree in math.)
If no one objects I'll rename to FeatureUnion and change the state of the PR to MRG.

Andreas Mueller
Owner

Renamed; I think this is good to go.

Andreas Mueller
Owner

Any more comments? (GitHub claims this cannot be merged, but I just rebased, so it should be a fast-forward merge.)

Olivier Grisel
Owner

This cannot be merged into master currently, but apart from that +1 for merging :)

Gael Varoquaux

LGTM. :+1: for merge. Thanks @amueller !

Andreas Mueller amueller merged commit d087830 into from
Andreas Mueller amueller closed this
Vlad Niculae
Owner

Thank you for this convenient transformer. In my application I had to hack it a bit, and I wonder whether the feature I wanted could be more generally useful.

Basically, sometimes you want to concatenate the same feature extractor multiple times, and have some of the parameters tied when grid searching.

In my case, I was learning a hyphenator, so my data points consist of 2 strings: the one to the left of the current position and the one to the right of the current position. For this I defined a ProjectionVectorizer that has a column attribute that just says "I only work on X[:, column]" and concatenated two of these. Now, when grid searching, it is common sense to use the same n-gram range for both transformers, so the cleanest way to do this was this quick hack (no error handling):

class HomogeneousFeatureUnion(FeatureUnion):
    def set_params(self, **params):
        # Broadcast every parameter to all transformers in the union
        # (quick hack: no error handling, as noted above).
        for key, value in params.iteritems():
            for _, transf in self.transformer_list:
                transf.set_params(**{key: value})
        return self

This can easily be extended to support both tied params and per-transformer params. I'm not sure whether I overengineered this, but I still have the feeling that this might pop up in other people's applications, so I wanted to raise the question.
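Usage would look something like this (ProjectionVectorizer is the custom class described above, not part of sklearn, and the parameters are illustrative):

    union = HomogeneousFeatureUnion([
        ("left", ProjectionVectorizer(column=0, analyzer="char")),
        ("right", ProjectionVectorizer(column=1, analyzer="char")),
    ])
    # One call sets the tied parameter on every transformer in the union:
    union.set_params(ngram_range=(1, 3))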

doc/modules/classes.rst
@@ -817,6 +817,7 @@ Pairwise metrics
    :template: class.rst

    pipeline.Pipeline
+   pipeline.FeatureUnion


 .. _preprocessing_ref:
doc/modules/pipeline.rst
@@ -84,3 +84,51 @@ The pipeline has all the methods that the last estimator in the pipeline has,
 i.e. if the last estimator is a classifier, the :class:`Pipeline` can be used
 as a classifier. If the last estimator is a transformer, again, so is the
 pipeline.
+
+
+.. _feature_union:
+
+====================================
+FeatureUnion: Concatenating features
+====================================
+
+.. currentmodule:: sklearn.pipeline
+
+:class:`FeatureUnion` combines several transformer objects into a new
+transformer that concatenates their output. A :class:`FeatureUnion` takes
+a list of transformer objects. During fitting, each of these
+is fit to the data independently. For transforming data, the
+transformers are applied in parallel, and their outputs are combined into a
+single output array or matrix.
+
+:class:`FeatureUnion` serves the same purposes as :class:`Pipeline` -
+convenience and joint parameter estimation and validation.
+
+:class:`FeatureUnion` and :class:`Pipeline` can be combined to
+create complex models.
+
+
+Usage
+=====
+
+A :class:`FeatureUnion` is built from a list of ``(key, value)`` pairs, where
+``key`` is a string containing the name you want to give to a given
+transformation and ``value`` is an estimator object::
+
+    >>> from sklearn.pipeline import FeatureUnion
+    >>> from sklearn.decomposition import PCA
+    >>> from sklearn.decomposition import KernelPCA
+    >>> estimators = [('linear_pca', PCA()), ('kernel_pca', KernelPCA())]
+    >>> combined = FeatureUnion(estimators)
+    >>> combined # doctest: +NORMALIZE_WHITESPACE
+    FeatureUnion(n_jobs=1, transformer_list=[('linear_pca', PCA(copy=True,
+        n_components=None, whiten=False)), ('kernel_pca', KernelPCA(alpha=1.0,
+        coef0=1, degree=3, eigen_solver='auto', fit_inverse_transform=False,
+        gamma=0, kernel='linear', max_iter=None, n_components=None, tol=0))],
+        transformer_weights=None)
+
+
+.. topic:: Examples:
+
+ * :ref:`example_feature_stacker.py`
doc/whats_new.rst
@@ -21,9 +21,13 @@ Changelog
    - Speed up of :func:`metrics.precision_recall_curve` by Conrad Lee.

    - Added support for reading/writing svmlight files with pairwise
-   preference attribute (qid in svmlight file format) in
-   :func:`datasets.dump_svmlight_file` and
-   :func:`datasets.load_svmlight_file` by `Fabian Pedregosa`_.
+     preference attribute (qid in svmlight file format) in
+     :func:`datasets.dump_svmlight_file` and
+     :func:`datasets.load_svmlight_file` by `Fabian Pedregosa`_.
+
+   - New estimator :ref:`FeatureUnion <feature_union>` that concatenates results
+     of several transformers by `Andreas Müller`_.
+

 API changes summary
 -------------------
examples/feature_stacker.py
@@ -0,0 +1,60 @@
+"""
+=================================================
+Concatenating multiple feature extraction methods
+=================================================
+
+In many real-world examples, there are many ways to extract features from a
+dataset. Often it is beneficial to combine several methods to obtain good
+performance. This example shows how to use ``FeatureUnion`` to combine
+features obtained by PCA and univariate selection.
+
+Combining features using this transformer has the benefit that it allows
+cross validation and grid searches over the whole process.
+
+The combination used in this example is not particularly helpful on this
+dataset and is only used to illustrate the usage of FeatureUnion.
+"""
+
+# Author: Andreas Mueller <amueller@ais.uni-bonn.de>
+#
+# License: BSD 3-clause
+
+from sklearn.pipeline import Pipeline, FeatureUnion
+from sklearn.grid_search import GridSearchCV
+from sklearn.svm import SVC
+from sklearn.datasets import load_iris
+from sklearn.decomposition import PCA
+from sklearn.feature_selection import SelectKBest
+
+iris = load_iris()
+
+X, y = iris.data, iris.target
+
+# This dataset is way too high-dimensional. Better do PCA:
+pca = PCA(n_components=2)
+
+# Maybe some original features were good, too?
+selection = SelectKBest(k=1)
+
+# Build estimator from PCA and univariate selection:
+combined_features = FeatureUnion([("pca", pca), ("univ_select", selection)])
+
+# Use combined features to transform dataset:
+X_features = combined_features.fit(X, y).transform(X)
+
+# Classify:
+svm = SVC(kernel="linear")
+svm.fit(X_features, y)
+
+# Do grid search over k, n_components and C:
+pipeline = Pipeline([("features", combined_features), ("svm", svm)])
+
+param_grid = dict(features__pca__n_components=[1, 2, 3],
+                  features__univ_select__k=[1, 2],
+                  svm__C=[0.1, 1, 10])
+
+grid_search = GridSearchCV(pipeline, param_grid=param_grid, verbose=10)
+grid_search.fit(X, y)
+print(grid_search.best_estimator_)
sklearn/pipeline.py
@@ -8,9 +8,13 @@
 #         Alexandre Gramfort
 # Licence: BSD

-from .base import BaseEstimator
+import numpy as np
+from scipy import sparse

-__all__ = ['Pipeline']
+from .base import BaseEstimator, TransformerMixin
+from .externals.joblib import Parallel, delayed
+
+__all__ = ['Pipeline', 'FeatureUnion']


 # One round of beers on me if someone finds out why the backslash
@@ -199,3 +203,102 @@ def score(self, X, y=None):
     def _pairwise(self):
         # check if first estimator expects pairwise input
         return getattr(self.steps[0][1], '_pairwise', False)
+
+
+def _fit_one_transformer(transformer, X, y):
+    transformer.fit(X, y)
+
+
+def _transform_one(transformer, name, X, transformer_weights):
+    if transformer_weights is not None and name in transformer_weights:
+        # if we have a weight for this transformer, multiply output
+        return transformer.transform(X) * transformer_weights[name]
+    return transformer.transform(X)
+
+
+class FeatureUnion(BaseEstimator, TransformerMixin):
+    """Concatenates results of multiple transformer objects.
+
+    This estimator applies a list of transformer objects in parallel to the
+    input data, then concatenates the results. This is useful to combine
+    several feature extraction mechanisms into a single transformer.
+
+    Parameters
+    ----------
+    transformer_list: list of (name, transformer)
+        List of transformer objects to be applied to the data.
+
+    n_jobs: int, optional
+        Number of jobs to run in parallel (default 1).
+
+    transformer_weights: dict, optional
+        Multiplicative weights for features per transformer.
+        Keys are transformer names, values the weights.
+
+    """
+    def __init__(self, transformer_list, n_jobs=1, transformer_weights=None):
+        self.transformer_list = transformer_list
+        self.n_jobs = n_jobs
+        self.transformer_weights = transformer_weights
+
+    def get_feature_names(self):
+        """Get feature names from all transformers.
+
+        Returns
+        -------
+        feature_names : list of strings
+            Names of the features produced by transform.
+        """
+        feature_names = []
+        for name, trans in self.transformer_list:
+            if not hasattr(trans, 'get_feature_names'):
+                raise AttributeError("Transformer %s does not provide"
+                        " get_feature_names." % str(name))
+            feature_names.extend([name + "__" + f
+                for f in trans.get_feature_names()])
+        return feature_names
+
+    def fit(self, X, y=None):
+        """Fit all transformers using X.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape (n_samples, n_features)
+            Input data, used to fit transformers.
+        """
+        Parallel(n_jobs=self.n_jobs)(delayed(_fit_one_transformer)(trans, X, y)
+                for name, trans in self.transformer_list)
+        return self
+
+    def transform(self, X):
+        """Transform X separately by each transformer, concatenate results.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape (n_samples, n_features)
+            Input data to be transformed.
+
+        Returns
+        -------
+        X_t : array-like or sparse matrix, shape (n_samples, sum_n_components)
+            hstack of results of transformers. sum_n_components is the
+            sum of n_components (output dimension) over transformers.
+        """
+        Xs = Parallel(n_jobs=self.n_jobs)(
+            delayed(_transform_one)(trans, name, X, self.transformer_weights)
+            for name, trans in self.transformer_list)
+        if any(sparse.issparse(f) for f in Xs):
+            Xs = sparse.hstack(Xs).tocsr()
+        else:
+            Xs = np.hstack(Xs)
+        return Xs
+
+    def get_params(self, deep=True):
+        if not deep:
+            return super(FeatureUnion, self).get_params(deep=False)
+        else:
+            out = dict(self.transformer_list)
+            for name, trans in self.transformer_list:
+                for key, value in trans.get_params(deep=True).iteritems():
+                    out['%s__%s' % (name, key)] = value
+            return out
sklearn/tests/test_common.py
@@ -28,7 +28,7 @@
 # import "special" estimators
 from sklearn.grid_search import GridSearchCV
 from sklearn.decomposition import SparseCoder
-from sklearn.pipeline import Pipeline
+from sklearn.pipeline import Pipeline, FeatureUnion
 from sklearn.pls import _PLS, PLSCanonical, PLSRegression, CCA, PLSSVD
 from sklearn.ensemble import BaseEnsemble
 from sklearn.multiclass import OneVsOneClassifier, OneVsRestClassifier,\
@@ -45,9 +45,9 @@
         SpectralClustering
 from sklearn.linear_model import IsotonicRegression

-dont_test = [Pipeline, GridSearchCV, SparseCoder, EllipticEnvelope,
-        EllipticEnvelop, DictVectorizer, LabelBinarizer, LabelEncoder,
-        TfidfTransformer, IsotonicRegression]
+dont_test = [Pipeline, FeatureUnion, GridSearchCV, SparseCoder,
+        EllipticEnvelope, EllipticEnvelop, DictVectorizer, LabelBinarizer,
+        LabelEncoder, TfidfTransformer, IsotonicRegression]
 meta_estimators = [BaseEnsemble, OneVsOneClassifier, OutputCodeClassifier,
         OneVsRestClassifier, RFE, RFECV]
sklearn/tests/test_pipeline.py
@@ -2,17 +2,21 @@
 Test the pipeline module.
 """
 import numpy as np
+from scipy import sparse

 from nose.tools import assert_raises, assert_equal, assert_false, assert_true
+from numpy.testing import assert_array_equal, \
+        assert_array_almost_equal

 from sklearn.base import BaseEstimator, clone
-from sklearn.pipeline import Pipeline
+from sklearn.pipeline import Pipeline, FeatureUnion
 from sklearn.svm import SVC
 from sklearn.linear_model import LogisticRegression
 from sklearn.feature_selection import SelectKBest, f_classif
 from sklearn.decomposition.pca import PCA, RandomizedPCA
 from sklearn.datasets import load_iris
 from sklearn.preprocessing import StandardScaler
+from sklearn.feature_extraction.text import CountVectorizer


 class IncorrectT(BaseEstimator):
@@ -174,3 +178,67 @@ def test_pipeline_methods_preprocessing_svm():
         assert_equal(decision_function.shape, (n_samples, n_classes))

         pipe.score(X, y)
+
+
+def test_feature_stacker():
+    # basic sanity check for feature stacker
+    iris = load_iris()
+    X = iris.data
+    X -= X.mean(axis=0)
+    y = iris.target
+    pca = RandomizedPCA(n_components=2)
+    select = SelectKBest(k=1)
+    fs = FeatureUnion([("pca", pca), ("select", select)])
+    fs.fit(X, y)
+    X_transformed = fs.transform(X)
+    assert_equal(X_transformed.shape, (X.shape[0], 3))
+
+    # check if it does the expected thing
+    assert_array_almost_equal(X_transformed[:, :-1], pca.fit_transform(X))
+    assert_array_equal(X_transformed[:, -1],
+            select.fit_transform(X, y).ravel())
+
+    # test if it also works for sparse input
+    X_sp = sparse.csr_matrix(X)
+    X_sp_transformed = fs.fit_transform(X_sp, y)
+    assert_array_almost_equal(X_transformed, X_sp_transformed.toarray())
+
+    # test setting parameters
+    fs.set_params(select__k=2)
+    assert_equal(fs.fit_transform(X, y).shape, (X.shape[0], 4))
+
+
+def test_feature_stacker_weights():
+    # test feature stacker with transformer weights
+    iris = load_iris()
+    X = iris.data
+    y = iris.target
+    pca = RandomizedPCA(n_components=2)
+    select = SelectKBest(k=1)
+    fs = FeatureUnion([("pca", pca), ("select", select)],
+            transformer_weights={"pca": 10})
+    fs.fit(X, y)
+    X_transformed = fs.transform(X)
+    # check against expected result
+    assert_array_almost_equal(X_transformed[:, :-1], 10 * pca.fit_transform(X))
+    assert_array_equal(X_transformed[:, -1],
+            select.fit_transform(X, y).ravel())
+
+
+def test_feature_stacker_feature_names():
+    JUNK_FOOD_DOCS = (
+        "the pizza pizza beer copyright",
+        "the pizza burger beer copyright",
+        "the the pizza beer beer copyright",
+        "the burger beer beer copyright",
+        "the coke burger coke copyright",
+        "the coke burger burger",
+    )
+    word_vect = CountVectorizer(analyzer="word")
+    char_vect = CountVectorizer(analyzer="char_wb", ngram_range=(3, 3))
+    ft = FeatureUnion([("chars", char_vect), ("words", word_vect)])
+    ft.fit(JUNK_FOOD_DOCS)
+    feature_names = ft.get_feature_names()
+    for feat in feature_names:
+        assert_true("chars__" in feat or "words__" in feat)
+    assert_equal(len(feature_names), 35)