Skip to content
This repository has been archived by the owner on Jan 13, 2024. It is now read-only.

Commit

Permalink
Fixes issue #9, implements onnxruntime runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
sdpython committed Jun 16, 2019
1 parent 4f26141 commit 11c2d7c
Show file tree
Hide file tree
Showing 10 changed files with 230 additions and 5 deletions.
1 change: 1 addition & 0 deletions .local.jenkins.lin.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ virtualenv:

install:
- $PYINT -c "from pip._internal import main;main(\"install --no-cache-dir --no-deps --index http://localhost:8067/simple/ jyquickhelper pyquickhelper --extra-index-url=https://pypi.python.org/simple/\".split())"
- $PYINT -c "from pip._internal import main;main(\"install --no-cache-dir --no-deps --index http://localhost:8067/simple/ onnx onnxruntime skl2onnx onnxmltools onnxconverter_common scikit-onnxruntime --extra-index-url=https://pypi.python.org/simple/\".split())"
- $PYINT -c "from pip._internal import main;main(\"install -r requirements.txt\".split())"
- $PYINT --version
- $PYINT -c "from pip._internal import main;main([\"freeze\"])"
Expand Down
54 changes: 54 additions & 0 deletions _doc/sphinxdoc/source/onnx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@
ONNX
====

*mlprodict* implements two runtimes.
The first uses :epkg:`numpy` and implements
mathematical functions defined by :epkg:`ONNX`.
The second one leverages :epkg:`onnxruntime`.

.. contents::
:local:

Python Runtime
++++++++++++++

This module implements a python runtime for :epkg:`ONNX`.
It is a work constantly in progress. It was started to
facilitate the implementation of :epkg:`scikit-learn`
Expand Down Expand Up @@ -107,3 +118,46 @@ what is working.
print(df2rst(piv))

build_table()

onnxruntime
+++++++++++

This runtime does not load the :epkg:`ONNX` graph in a single
session but instead calls :epkg:`onnxruntime` for each node
independently. This was developed mostly to facilitate
the implementation of converters from :epkg:`scikit-learn`
objects to :epkg:`ONNX`. We create the same table.

.. runpython::
:showcode:
:rst:
:warningout: PendingDeprecationWarning UserWarning RuntimeWarning

from logging import getLogger
from pyquickhelper.loghelper import noLOG
from pandas import DataFrame
from pyquickhelper.pandashelper import df2rst
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils.testing import ignore_warnings
from mlprodict.onnxrt.validate import validate_operator_opsets, summary_report

@ignore_warnings(category=(UserWarning, ConvergenceWarning, RuntimeWarning, FutureWarning))
def build_table():
logger = getLogger('skl2onnx')
logger.disabled = True
rows = validate_operator_opsets(0, debug=None, fLOG=noLOG, runtime='onnxruntime')
df = DataFrame(rows)
piv = summary_report(df)

if "ERROR-msg" in piv.columns:
def shorten(text):
text = str(text)
if len(text) > 75:
text = text[:75] + "..."
return text

piv["ERROR-msg"] = piv["ERROR-msg"].apply(shorten)

print(df2rst(piv))

build_table()
30 changes: 30 additions & 0 deletions _unittests/ut_onnxrt/test_onnxrt_onnxruntime_runtime_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
@brief test log(time=2s)
"""
import unittest
from logging import getLogger
import numpy
from pyquickhelper.pycode import ExtTestCase
from skl2onnx.algebra.onnx_ops import OnnxAdd # pylint: disable=E0611
from mlprodict.onnxrt import OnnxInference


class TestOnnxrtOnnxRuntimeRuntime(ExtTestCase):
    """
    Checks that @see cl OnnxInference produces the expected
    output when it is created with ``runtime='onnxruntime'``.
    """

    def setUp(self):
        # Silences the skl2onnx logger for every test of this class.
        getLogger('skl2onnx').disabled = True

    def test_onnxt_runtime_add(self):
        """
        Runs a single *Add* node (input plus a 2x2 identity constant)
        and compares the output with the same sum computed by numpy.
        """
        identity = numpy.identity(2)
        node = OnnxAdd('X', identity, output_names=['Y'])
        model_def = node.to_onnx({'X': identity.astype(numpy.float32)})
        oinf = OnnxInference(model_def, runtime='onnxruntime')

        X = numpy.array([[1, 2], [3, 4]], dtype=numpy.float32)
        got = oinf.run({'X': X})

        # Only one output is expected and it must be named 'Y'.
        self.assertEqual(sorted(got), ['Y'])
        self.assertEqualArray(identity + X, got['Y'], decimal=6)


if __name__ == "__main__":
unittest.main()
44 changes: 44 additions & 0 deletions _unittests/ut_onnxrt/test_onnxrt_validate_onnxruntime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
@brief test log(time=20s)
"""
import os
import unittest
from logging import getLogger
from pandas import DataFrame
from pyquickhelper.loghelper import fLOG
from pyquickhelper.pycode import get_temp_folder, ExtTestCase
from mlprodict.onnxrt.validate import sklearn_operators, validate_operator_opsets


class TestOnnxrtValidateOnnxRuntime(ExtTestCase):
    """
    Runs the validation of the operators with
    ``runtime='onnxruntime'`` and dumps the results as reports.
    """

    def test_sklearn_operators(self):
        """
        The list of operators must contain more than one element
        and its first element must have three items.
        """
        operators = sklearn_operators()
        self.assertGreater(len(operators), 1)
        self.assertEqual(len(operators[0]), 3)

    def test_validate_sklearn_operators_all_onnxruntime(self):
        """
        Validates the operators with the onnxruntime runtime, then
        writes the rows into a CSV file and an Excel file stored
        in a temporary folder.
        """
        run_as_script = __name__ == "__main__"
        fLOG(__file__, self._testMethodName, OutputPrint=run_as_script)
        getLogger('skl2onnx').disabled = True
        verbose = int(run_as_script)

        # Debugging switch: replace False by True to restrict the
        # validation to LinearRegression starting at opset 10.
        if False:  # pylint: disable=W0125
            selection = dict(debug={"LinearRegression"}, opset_min=10)
        else:
            selection = dict(debug=None)
        rows = validate_operator_opsets(
            verbose, fLOG=fLOG, runtime='onnxruntime', **selection)

        self.assertGreater(len(rows), 1)
        df = DataFrame(rows)
        self.assertGreater(df.shape[1], 1)

        temp = get_temp_folder(
            __file__, "temp_validate_sklearn_operators_all_onnxruntime")
        fLOG("output results")
        csv_report = os.path.join(temp, "sklearn_opsets_report.csv")
        xlsx_report = os.path.join(temp, "sklearn_opsets_report.xlsx")
        df.to_csv(csv_report, index=False)
        df.to_excel(xlsx_report, index=False)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion mlprodict/onnxrt/onnx_inference_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def setup_runtime(self, runtime=None):
if self.desc is None:
raise AttributeError("desc should not be None.")
self.ops_ = load_op(self.onnx_node, desc=self.desc,
options=runtime)
options={'provider': runtime} if runtime else None)

def run(self, values):
"""
Expand Down
3 changes: 3 additions & 0 deletions mlprodict/onnxrt/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,8 @@ def load_op(onnx_node, desc=None, options=None):
if provider == 'CPU':
from .ops_cpu import load_op as lo
return lo(onnx_node, desc=desc, options=options)
elif provider == 'onnxruntime':
from .ops_onnxruntime import load_op as lo
return lo(onnx_node, desc=desc, options=options)
else:
raise ValueError("Unable to handle provider '{}'.".format(provider))
20 changes: 20 additions & 0 deletions mlprodict/onnxrt/ops_onnxruntime/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# -*- encoding: utf-8 -*-
"""
@file
@brief Shortcut to *ops_onnxruntime*.
"""
from ._op import OpRunOnnxRuntime


def load_op(onnx_node, desc=None, options=None):
    """
    Gets the operator related to the *onnx* node.

    @param      onnx_node       :epkg:`onnx` node
    @param      desc            internal representation,
                                it cannot be None
    @param      options         runtime options, a dictionary expanded
                                into keyword arguments for the operator,
                                None is the same as no option
    @return                     instance of *OpRunOnnxRuntime* wrapping
                                the node
    @raises     ValueError      if *desc* is None
    """
    if desc is None:
        raise ValueError("desc should not be None.")
    # *options* defaults to None and ``**None`` raises a TypeError,
    # an empty dictionary is used instead.
    kwargs = {} if options is None else options
    return OpRunOnnxRuntime(onnx_node, desc, **kwargs)
69 changes: 69 additions & 0 deletions mlprodict/onnxrt/ops_onnxruntime/_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# -*- encoding: utf-8 -*-
"""
@file
@brief Shortcut to *ops_onnxruntime*.
"""
import numpy
import onnx.defs
from onnxruntime import InferenceSession
import skl2onnx.algebra.onnx_ops as alg
from skl2onnx.common.data_types import FloatTensorType


# Lookup table from an operator name to an ONNX schema, built once when
# the module is imported. Entries are keyed by name only: if the same
# name is yielded more than once, the last schema yielded replaces the
# previous ones.
_schemas = {
schema.name: schema for schema in onnx.defs.get_all_schemas_with_history()}


class OpRunOnnxRuntime:
    """
    Generic operator relying on :epkg:`onnxruntime`:
    the node is converted into a graph of its own, loaded
    into a dedicated session which computes its outputs.
    """

    def __init__(self, onnx_node, desc=None, **options):
        """
        @param      onnx_node               :epkg:`onnx` node
        @param      desc                    internal representation, the
                                            values found in ``desc['atts']``
                                            are added to *options*
        @param      options                 runtime options
        @raises     ValueError              if an attribute in
                                            ``desc['atts']`` is not
                                            a dictionary with a key
                                            ``'value'``
        """
        self._provider = 'onnxruntime'
        self.onnx_node = onnx_node
        self.desc = desc
        self._schema = _schemas[onnx_node.op_type]

        # Attributes stored in the description overwrite
        # the options sharing the same name.
        if desc is not None and 'atts' in desc:
            for att_name, att in desc['atts'].items():
                if isinstance(att, dict) and 'value' in att:
                    options[att_name] = att['value']
                else:
                    raise ValueError("Unexpected value {}.".format(att))

        self.options = options
        self._init()

    def _init(self):
        """
        Builds the graph holding the node and creates the session.
        The operator class is ``Onnx<op_type>`` taken from
        *skl2onnx.algebra.onnx_ops*, every input and every output
        is declared with ``FloatTensorType()``.
        """
        node = self.onnx_node
        self.alg_class = getattr(alg, 'Onnx' + node.op_type)
        self.inputs = list(node.input)
        self.outputs = list(node.output)
        self.inst_ = self.alg_class(
            *self.inputs, output_names=self.outputs, **self.options)

        typed_inputs = [(name, FloatTensorType()) for name in self.inputs]
        typed_outputs = [(name, FloatTensorType()) for name in self.outputs]
        self.onnx_ = self.inst_.to_onnx(typed_inputs, outputs=typed_outputs)

        # The session is created from the serialized graph.
        self.sess_ = InferenceSession(self.onnx_.SerializeToString())

    def run(self, *args, **kwargs):
        """
        Computes the outputs of the node with the session.

        @param      args        values of the inputs, given in
                                the same order as the node inputs
        @param      kwargs      not used
        @return                 tuple of the values returned by the session
        """
        def to_float32(value):
            # Only objects exposing a float64 *dtype* are converted,
            # anything else is sent unchanged to the session.
            is_double = (hasattr(value, 'dtype') and
                         value.dtype == numpy.float64)
            return value.astype(numpy.float32) if is_double else value

        feeds = {}
        for name, value in zip(self.inputs, args):
            feeds[name] = to_float32(value)
        return tuple(self.sess_.run(None, feeds))
11 changes: 7 additions & 4 deletions mlprodict/onnxrt/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def _measure_absolute_difference(skl_pred, ort_pred):

def enumerate_compatible_opset(model, opset_min=9, opset_max=None,
check_runtime=True, debug=False,
fLOG=print):
runtime='CPU', fLOG=print):
"""
Lists all compatible opsets for a specific model.
Expand All @@ -400,6 +400,7 @@ def enumerate_compatible_opset(model, opset_min=9, opset_max=None,
@param check_runtime checks that runtime can consume the
model and compute predictions
@param debug catch exception (True) or not (False)
@param runtime test a specific runtime, by default ``'CPU'``
@param fLOG logging function
@return dictionaries, each row has the following
keys: opset, exception if any, conversion time,
Expand Down Expand Up @@ -524,7 +525,8 @@ def fct_skl(itt=inst, it=init_types[0][1], ops=opset): # pylint: disable=W0102

# load
try:
sess, t6 = _measure_time(lambda: OnnxInference(ser))
sess, t6 = _measure_time(
lambda: OnnxInference(ser, runtime=runtime))
obs_op['tostring_time'] = t6
except (RuntimeError, ValueError) as e:
if debug:
Expand Down Expand Up @@ -620,7 +622,7 @@ def fct_single(se=sess, xo=Xort_test, it=init_types): # pylint: disable=W0102

@ignore_warnings(category=(UserWarning, ConvergenceWarning, RuntimeWarning))
def validate_operator_opsets(verbose=0, opset_min=9, opset_max=None,
check_runtime=True, debug=None,
check_runtime=True, debug=None, runtime='CPU',
fLOG=print):
"""
Tests all possible configuration for all possible
Expand All @@ -633,6 +635,7 @@ def validate_operator_opsets(verbose=0, opset_min=9, opset_max=None,
@param check_runtime checks the python runtime
@param debug only checks a small list of operators,
set of model names
@param runtime test a specific runtime, by default ``'CPU'``
@param fLOG logging function
@return list of dictionaries
"""
Expand Down Expand Up @@ -668,7 +671,7 @@ def iterate():

for obs in enumerate_compatible_opset(
model, opset_min=opset_min, opset_max=opset_max,
check_runtime=check_runtime,
check_runtime=check_runtime, runtime=runtime,
debug=debug is not None, fLOG=fLOG):
if verbose > 1:
fLOG(" ", obs)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ coverage
guzzle_sphinx_theme
jyquickhelper
onnx
onnxruntime
openpyxl
pylint
pyquickhelper>=1.9
Expand Down

0 comments on commit 11c2d7c

Please sign in to comment.