Skip to content
This repository has been archived by the owner on Jan 13, 2024. It is now read-only.

Commit

Permalink
Fixes #57, dataframe when converting model
Browse files Browse the repository at this point in the history
  • Loading branch information
sdpython committed Sep 15, 2019
1 parent 9c3f460 commit 31c21fc
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 1 deletion.
61 changes: 61 additions & 0 deletions _unittests/ut_onnx_conv/test_onnx_conv_dataframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
@brief test log(time=2s)
"""
import unittest
from logging import getLogger
from io import StringIO
import numpy
import pandas
from pyquickhelper.pycode import ExtTestCase
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from mlprodict.onnx_conv import to_onnx
from mlprodict.onnxrt import OnnxInference


class TestOnnxConvDataframe(ExtTestCase):
    """
    Checks the conversion into ONNX of a pipeline trained on a dataframe.
    """

    def setUp(self):
        # Disables the logger named 'skl2onnx' before every test.
        getLogger('skl2onnx').disabled = True

    def test_pipeline_dataframe(self):
        """
        Trains a pipeline on a dataframe, converts it into ONNX and
        compares the runtime output with the pipeline output.
        """
        text = """
fixed_acidity,volatile_acidity,citric_acid,residual_sugar,chlorides,free_sulfur_dioxide,total_sulfur_dioxide,density,pH,sulphates,alcohol,quality,color
7.4,0.7,0.0,1.9,0.076,11.0,34.0,0.9978,3.51,0.56,9.4,5,red
7.8,0.88,0.0,2.6,0.098,25.0,67.0,0.9968,3.2,0.68,9.8,5,red
7.8,0.76,0.04,2.3,0.092,15.0,54.0,0.997,3.26,0.65,9.8,5,red
11.2,0.28,0.56,1.9,0.075,17.0,60.0,0.998,3.16,0.58,9.8,6,red
""".replace(" ", "")
        X_train = pandas.read_csv(StringIO(text))

        # Every column except 'color' is cast to float32 so that the
        # dataframe given to the converter and the arrays given to the
        # runtime share the same types.
        numeric_features = [
            name for name in X_train.columns if name != 'color']
        for name in numeric_features:
            X_train[name] = X_train[name].astype(numpy.float32)

        # 'color' is one-hot encoded, then only the first encoded column
        # is kept; the other columns go through unchanged.
        color_steps = Pipeline([
            ('one', OneHotEncoder()),
            ('select', ColumnTransformer(
                [('sel1', 'passthrough', [0])])),
        ])
        preprocessing = ColumnTransformer([
            ("color", color_steps, ['color']),
            ("others", "passthrough", numeric_features),
        ])
        pipe = Pipeline([("prep", preprocessing)])

        pipe.fit(X_train)
        model_onnx = to_onnx(pipe, X_train)
        oinf = OnnxInference(model_onnx)

        expected = pipe.transform(X_train)

        # The runtime receives one input per column, each of them
        # reshaped into a single-column matrix.
        n_rows = X_train.shape[0]
        feeds = {
            name: X_train[name].values.reshape((n_rows, 1))
            for name in X_train.columns
        }
        got = oinf.run(feeds)['transformed_column']
        self.assertEqualArray(expected, got)


# Runs the unit tests of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
98 changes: 97 additions & 1 deletion mlprodict/onnx_conv/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
"""
from collections import OrderedDict
import numpy
import pandas
from sklearn.metrics.scorer import _PredictScorer
from sklearn import __all__ as sklearn__all__, __version__ as sklearn_version
from skl2onnx.common.data_types import FloatTensorType, DoubleTensorType, DataType
from skl2onnx.algebra.onnx_operator_mixin import OnnxOperatorMixin
from skl2onnx.algebra.type_helper import guess_initial_types
from skl2onnx import convert_sklearn
from skl2onnx.algebra.type_helper import _guess_type
from .rewritten_converters import register_rewritten_operators
from .register import register_converters
from .scorers import CustomScorerTransform
Expand Down Expand Up @@ -67,6 +68,31 @@ def convert_scorer(fct, initial_types, name=None,
custom_parsers=custom_parsers)


def guess_initial_types(X, initial_types):
    """
    Infers the initial types from an array or a dataframe
    unless they are explicitly given.

    @param      X               array or dataframe, only used when
                                *initial_types* is None
    @param      initial_types   hints about X, returned unchanged
                                when not None
    @return                     data types, a list of tuples *(name, type)*
                                when they are guessed

    A dataframe produces one entry per column, named after the column,
    with its shape set to ``[None, 1]``. Any other input produces a
    single entry named ``'X'``.
    The function raises *NotImplementedError* if both *X*
    and *initial_types* are None.
    """
    # Explicit hints always win, X is not even looked at.
    if initial_types is not None:
        return initial_types
    if X is None:
        raise NotImplementedError("Initial types must be specified.")

    if isinstance(X, pandas.DataFrame):
        # Only the first row is passed to the type guesser.
        first_row = X[:1]
        guessed = []
        for name in first_row.columns:
            col_type = _guess_type(first_row[name].values)
            # Every column is declared as a matrix with a single column.
            col_type.shape = [None, 1]
            guessed.append((name, col_type))
        return guessed

    if isinstance(X, numpy.ndarray):
        # Only the first row is passed to the type guesser.
        X = X[:1]
    return [('X', _guess_type(X))]


def to_onnx(model, X=None, name=None, initial_types=None,
target_opset=None, options=None,
dtype=numpy.float32, rewrite_ops=False):
Expand Down Expand Up @@ -94,6 +120,76 @@ def to_onnx(model, X=None, name=None, initial_types=None,
For example, :epkg:`ONNX` only supports *TreeEnsembleRegressor*
for float but not for double. It becomes available
if ``dtype=numpy.float64`` and ``rewrite_ops=True``.
.. faqref::
:title: How to deal with a dataframe as input?
Each column of the dataframe is considered as a named input.
The first step is to make sure that every column type is correct.
:epkg:`pandas` tends to select the least generic type to
hold the content of one column. :epkg:`ONNX` does not automatically
cast the data it receives. The data must have the same type when
the model is converted and when the converted model receives
the data to predict.
.. runpython::
:showcode:
from io import StringIO
from textwrap import dedent
import numpy
import pandas
from pyquickhelper.pycode import ExtTestCase
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from mlprodict.onnx_conv import to_onnx
from mlprodict.onnxrt import OnnxInference
text = dedent('''
__SCHEMA__
7.4,0.7,0.0,1.9,0.076,11.0,34.0,0.9978,3.51,0.56,9.4,5,red
7.8,0.88,0.0,2.6,0.098,25.0,67.0,0.9968,3.2,0.68,9.8,5,red
7.8,0.76,0.04,2.3,0.092,15.0,54.0,0.997,3.26,0.65,9.8,5,red
11.2,0.28,0.56,1.9,0.075,17.0,60.0,0.998,3.16,0.58,9.8,6,red
''')
text = text.replace(
"__SCHEMA__",
"fixed_acidity,volatile_acidity,citric_acid,residual_sugar,chlorides,"
"free_sulfur_dioxide,total_sulfur_dioxide,density,pH,sulphates,"
"alcohol,quality,color")
X_train = pandas.read_csv(StringIO(text))
for c in X_train.columns:
if c != 'color':
X_train[c] = X_train[c].astype(numpy.float32)
numeric_features = [c for c in X_train if c != 'color']
pipe = Pipeline([
("prep", ColumnTransformer([
("color", Pipeline([
('one', OneHotEncoder()),
('select', ColumnTransformer(
[('sel1', 'passthrough', [0])]))
]), ['color']),
("others", "passthrough", numeric_features)
])),
])
pipe.fit(X_train)
pred = pipe.transform(X_train)
print(pred)
model_onnx = to_onnx(pipe, X_train)
oinf = OnnxInference(model_onnx)
# The dataframe is converted into a dictionary,
# each key is a column name, each value is a numpy array.
inputs = {c: X_train[c].values for c in X_train.columns}
inputs = {c: v.reshape((v.shape[0], 1)) for c, v in inputs.items()}
onxp = oinf.run(inputs)
print(onxp)
"""
if isinstance(model, OnnxOperatorMixin):
if options is not None:
Expand Down

0 comments on commit 31c21fc

Please sign in to comment.