Skip to content

Commit

Permalink
Fixes #60, visualization of a pipeline with passthrough
Browse files Browse the repository at this point in the history
  • Loading branch information
sdpython committed Jul 9, 2019
1 parent 3100027 commit fff9212
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 2 deletions.
9 changes: 9 additions & 0 deletions README.rst
Expand Up @@ -53,3 +53,12 @@ which partitions the data before fitting a model on each bucket.
* `GitHub/mlinsights <https://github.com/sdpython/mlinsights/>`_
* `documentation <http://www.xavierdupre.fr/app/mlinsights/helpsphinx/index.html>`_
* `Blog <http://www.xavierdupre.fr/app/mlinsights/helpsphinx/blog/main_0000.html#ap-main-0>`_

Function ``pipeline2dot`` converts a pipeline into a graph:

::

from mlinsights.plotting import pipeline2dot
dot = pipeline2dot(clf, df)

.. image:: https://github.com/sdpython/mlinsights/raw/master/_doc/pipeline.png
Binary file added _doc/pipeline.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion _doc/sphinxdoc/source/HISTORY.rst
Expand Up @@ -55,4 +55,4 @@ current - 2019-05-23 - 0.00Mb
* :issue:`7`: add quantile regression (2018-05-07)
* :issue:`5`: replace flake8 by code style (2018-04-14)
* :issue:`1`: change background for cells in notebooks converted into rst then in html, highlight-ipython3 (2018-01-05)
* :issue:`2`: save features and metadatas for the search engine and retrieves them (2017-12-03)
* :issue:`2`: save features and metadatas for the search engine and retrieves them (2017-12-03)
50 changes: 49 additions & 1 deletion _unittests/ut_plotting/test_dot.py
Expand Up @@ -3,12 +3,15 @@
@brief test log(time=2s)
"""
import unittest
from io import StringIO
from textwrap import dedent
import pandas
from sklearn import datasets
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline, FeatureUnion
from sklearn.pipeline import Pipeline, FeatureUnion, make_pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder
from pyquickhelper.pycode import ExtTestCase
Expand Down Expand Up @@ -91,6 +94,51 @@ def test_union_features(self):
self.assertIn("StandardScaler", dot)
self.assertIn("MinMaxScaler", dot)

def test_onehotencoder_dot(self):
data = dedent("""
date,value,notrend,trend,weekday,lag1,lag2,lag3,lag4,lag5,lag6,lag7,lag8
2017-07-10 13:27:04.669830,0.003463591425601385,0.0004596547917981044,0.0030039366338032807,
###0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2017-07-11 13:27:04.669830,0.004411953385609647,0.001342107238927262,0.003069846146682385,1,
###0.003463591425601385,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2017-07-12 13:27:04.669830,0.004277700876279705,0.0011426168863444912,0.0031350839899352135,2,
###0.004411953385609647,0.003463591425601385,0.0,0.0,0.0,0.0,0.0,0.0
2017-07-13 13:27:04.669830,0.006078151848127084,0.0028784976490072987,0.003199654199119785,3,
###0.004277700876279705,0.004411953385609647,0.003463591425601385,0.0,0.0,0.0,0.0,0.0
2017-07-14 13:27:04.669830,0.006336617719481035,0.003073056920386795,0.0032635607990942395,
###4,0.006078151848127084,0.004277700876279705,0.004411953385609647,0.003463591425601385,0.0,0.0,0.0,0.0
2017-07-15 13:27:04.669830,0.008716378909294038,0.0053895711052771985,0.0033268078040168394,5,
###0.006336617719481035,0.006078151848127084,0.004277700876279705,0.004411953385609647,0.003463591425601385,0.0,0.0,0.0
2017-07-17 13:27:04.669830,0.0035533180858140765,0.00010197905397394454,0.003451339031840132,0,
###0.008716378909294038,0.006336617719481035,0.006078151848127084,0.004277700876279705,0.004411953385609647,0.003463591425601385,0.0,0.0
2017-07-18 13:27:04.669830,0.0038464710972236286,0.0003338398676656705,0.003512631229557958,
###1,0.0035533180858140765,0.008716378909294038,0.006336617719481035,0.006078151848127084,0.004277700876279705,
###0.004411953385609647,0.003463591425601385,0.0
2017-07-19 13:27:04.669830,0.004200435956007872,0.0006271561741496745,0.003573279781858197,2,0.0038464710972236286,
###0.0035533180858140765,0.008716378909294038,0.006336617719481035,0.006078151848127084,0.004277700876279705,
###0.004411953385609647,0.003463591425601385
2017-07-20 13:27:04.669830,0.004773874566436903,0.0011405859170371827,0.00363328864939972,3,0.004200435956007872,
###0.0038464710972236286,0.0035533180858140765,0.008716378909294038,0.006336617719481035,0.006078151848127084,
###0.004277700876279705,0.004411953385609647
2017-07-21 13:27:04.669830,0.005866058541412791,0.00217339675927127,0.0036926617821415207,4,0.004773874566436903,
###0.004200435956007872,0.0038464710972236286,0.0035533180858140765,0.008716378909294038,0.006336617719481035,
###0.006078151848127084,0.004277700876279705
""").replace("\n###", "")
df = pandas.read_csv(StringIO(data))
cols = ['lag1', 'lag2', 'lag3',
'lag4', 'lag5', 'lag6', 'lag7', 'lag8']
model = make_pipeline(
make_pipeline(
ColumnTransformer(
[('pass', "passthrough", cols),
("dummies", OneHotEncoder(), ["weekday"])]),
PCA(n_components=2)),
LinearRegression())
train_cols = cols + ['weekday']
model.fit(df, df[train_cols])
dot = pipeline2dot(model, df)
self.assertIn('label="Identity"', dot)


if __name__ == "__main__":
unittest.main()
10 changes: 10 additions & 0 deletions mlinsights/plotting/visualize.py
Expand Up @@ -110,6 +110,16 @@ def _get_name(context, prefix='-v-'):
'type': 'transform'}, info]
return info

elif isinstance(pipe, str):
if pipe == "passthrough":
info = {'name': 'Identity', 'type': 'transform'}
info['outputs'] = data
info['inputs'] = data
info = [info]
else:
raise NotImplementedError(
"Not yet implemented for keyword '{}'.".format(type(pipe)))
return info
else:
raise NotImplementedError(
"Not yet implemented for {}.".format(type(pipe)))
Expand Down

0 comments on commit fff9212

Please sign in to comment.