Permalink
Browse files

ENH: pipeline doctest style improvements

  • Loading branch information...
2 parents 3aa87a6 + 7f357af commit d53352fbb1d0f020b1818fb0f47f4d4032cd2dad @ogrisel ogrisel committed Aug 24, 2011
Showing with 12 additions and 7 deletions.
  1. +12 −7 scikits/learn/pipeline.py
View
@@ -59,11 +59,13 @@ class Pipeline(BaseEstimator):
>>> from scikits.learn import svm
>>> from scikits.learn.datasets import samples_generator
- >>> from scikits.learn.feature_selection import SelectKBest, f_regression
+ >>> from scikits.learn.feature_selection import SelectKBest
+ >>> from scikits.learn.feature_selection import f_regression
>>> from scikits.learn.pipeline import Pipeline
>>> # generate some data to play with
- >>> X, y = samples_generator.make_classification(n_informative=5, n_redundant=0)
+ >>> X, y = samples_generator.make_classification(
+ ... n_informative=5, n_redundant=0, random_state=42)
>>> # ANOVA SVM-C
>>> anova_filter = SelectKBest(f_regression, k=5)
@@ -73,10 +75,13 @@ class Pipeline(BaseEstimator):
>>> # You can set the parameters using the names issued
>>> # For instance, fit using a k of 10 in the SelectKBest
>>> # and a parameter 'C' of the svn
- >>> anova_svm.fit(X, y, anova__k=10, svc__C=.1) #doctest: +ELLIPSIS
- Pipeline(steps=[('anova', SelectKBest(k=10, score_func=<function f_regression at ...>)), ('svc', SVC(C=0.1, coef0=0.0, degree=3, gamma=0.0, kernel='linear', probability=False,
- shrinking=True, tol=0.001))])
- >>> score = anova_svm.score(X)
+ >>> anova_svm.set_params(anova__k=10, svc__C=.1).fit(X, y)
+ ... # doctest: +ELLIPSIS
+ Pipeline(steps=[...])
+
+ >>> prediction = anova_svm.predict(X)
+ >>> anova_svm.score(X, y)
+ 0.75
"""
#--------------------------------------------------------------------------
@@ -126,7 +131,7 @@ def _get_params(self, deep=True):
#--------------------------------------------------------------------------
def _pre_transform(self, X, y=None, **fit_params):
- fit_params_steps = {step: {} for step, _ in self.steps}
+ fit_params_steps = dict((step, {}) for step, _ in self.steps)
for pname, pval in fit_params.iteritems():
step, param = pname.split('__', 1)
fit_params_steps[step][param] = pval

0 comments on commit d53352f

Please sign in to comment.