This repository has been archived by the owner on Jan 13, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixes issue #9, implements onnxruntime runtime
- Loading branch information
Showing
10 changed files
with
230 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
""" | ||
@brief test log(time=2s) | ||
""" | ||
import unittest | ||
from logging import getLogger | ||
import numpy | ||
from pyquickhelper.pycode import ExtTestCase | ||
from skl2onnx.algebra.onnx_ops import OnnxAdd # pylint: disable=E0611 | ||
from mlprodict.onnxrt import OnnxInference | ||
|
||
|
||
class TestOnnxrtOnnxRuntimeRuntime(ExtTestCase): | ||
|
||
def setUp(self): | ||
logger = getLogger('skl2onnx') | ||
logger.disabled = True | ||
|
||
def test_onnxt_runtime_add(self): | ||
idi = numpy.identity(2) | ||
onx = OnnxAdd('X', idi, output_names=['Y']) | ||
model_def = onx.to_onnx({'X': idi.astype(numpy.float32)}) | ||
X = numpy.array([[1, 2], [3, 4]], dtype=numpy.float32) | ||
oinf = OnnxInference(model_def, runtime='onnxruntime') | ||
got = oinf.run({'X': X}) | ||
self.assertEqual(list(sorted(got)), ['Y']) | ||
self.assertEqualArray(idi + X, got['Y'], decimal=6) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
""" | ||
@brief test log(time=20s) | ||
""" | ||
import os | ||
import unittest | ||
from logging import getLogger | ||
from pandas import DataFrame | ||
from pyquickhelper.loghelper import fLOG | ||
from pyquickhelper.pycode import get_temp_folder, ExtTestCase | ||
from mlprodict.onnxrt.validate import sklearn_operators, validate_operator_opsets | ||
|
||
|
||
class TestOnnxrtValidateOnnxRuntime(ExtTestCase): | ||
|
||
def test_sklearn_operators(self): | ||
res = sklearn_operators() | ||
self.assertGreater(len(res), 1) | ||
self.assertEqual(len(res[0]), 3) | ||
|
||
def test_validate_sklearn_operators_all_onnxruntime(self): | ||
fLOG(__file__, self._testMethodName, OutputPrint=__name__ == "__main__") | ||
logger = getLogger('skl2onnx') | ||
logger.disabled = True | ||
verbose = 1 if __name__ == "__main__" else 0 | ||
if False: # pylint: disable=W0125 | ||
rows = validate_operator_opsets( | ||
verbose, debug={"LinearRegression"}, opset_min=10, fLOG=fLOG, | ||
runtime='onnxruntime') | ||
else: | ||
rows = validate_operator_opsets(verbose, debug=None, fLOG=fLOG, | ||
runtime='onnxruntime') | ||
self.assertGreater(len(rows), 1) | ||
df = DataFrame(rows) | ||
self.assertGreater(df.shape[1], 1) | ||
temp = get_temp_folder( | ||
__file__, "temp_validate_sklearn_operators_all_onnxruntime") | ||
fLOG("output results") | ||
df.to_csv(os.path.join(temp, "sklearn_opsets_report.csv"), index=False) | ||
df.to_excel(os.path.join( | ||
temp, "sklearn_opsets_report.xlsx"), index=False) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# -*- encoding: utf-8 -*- | ||
""" | ||
@file | ||
@brief Shortcut to *ops_cpu*. | ||
""" | ||
from ._op import OpRunOnnxRuntime | ||
|
||
|
||
def load_op(onnx_node, desc=None, options=None): | ||
""" | ||
Gets the operator related to the *onnx* node. | ||
@param onnx_node :epkg:`onnx` node | ||
@param desc internal representation | ||
@param options runtime options | ||
@return runtime class | ||
""" | ||
if desc is None: | ||
raise ValueError("desc should not be None.") | ||
return OpRunOnnxRuntime(onnx_node, desc, **options) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# -*- encoding: utf-8 -*- | ||
""" | ||
@file | ||
@brief Shortcut to *ops_onnxruntime*. | ||
""" | ||
import numpy | ||
import onnx.defs | ||
from onnxruntime import InferenceSession | ||
import skl2onnx.algebra.onnx_ops as alg | ||
from skl2onnx.common.data_types import FloatTensorType | ||
|
||
|
||
_schemas = { | ||
schema.name: schema for schema in onnx.defs.get_all_schemas_with_history()} | ||
|
||
|
||
class OpRunOnnxRuntime: | ||
""" | ||
Unique operator which calls :epkg:`onnxruntime` | ||
to compute predictions for one operator. | ||
""" | ||
|
||
def __init__(self, onnx_node, desc=None, **options): | ||
""" | ||
@param onnx_node :epkg:`onnx` node | ||
@param desc internal representation | ||
@param options runtime options | ||
""" | ||
self._provider = 'onnxruntime' | ||
self.onnx_node = onnx_node | ||
self.desc = desc | ||
self._schema = _schemas[onnx_node.op_type] | ||
if desc is not None: | ||
if 'atts' in desc: | ||
for a, b in desc['atts'].items(): | ||
if not isinstance(b, dict) or 'value' not in b: | ||
raise ValueError("Unexpected value {}.".format(b)) | ||
options[a] = b['value'] | ||
|
||
self.options = options | ||
self._init() | ||
|
||
def _init(self): | ||
""" | ||
Initializes the node. | ||
""" | ||
self.alg_class = getattr(alg, 'Onnx' + self.onnx_node.op_type) | ||
self.inputs = list(self.onnx_node.input) | ||
self.outputs = list(self.onnx_node.output) | ||
self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs, | ||
**self.options) | ||
inputs = [(name, FloatTensorType()) for name in self.inputs] | ||
outputs = [(name, FloatTensorType()) for name in self.outputs] | ||
self.onnx_ = self.inst_.to_onnx(inputs, outputs=outputs) | ||
self.sess_ = InferenceSession(self.onnx_.SerializeToString()) | ||
|
||
def run(self, *args, **kwargs): | ||
""" | ||
Should be overwritten. | ||
""" | ||
def f32(X): | ||
if hasattr(X, 'dtype') and X.dtype == numpy.float64: | ||
return X.astype(numpy.float32) | ||
else: | ||
return X | ||
|
||
inputs = {name: f32(val) for name, val in zip(self.inputs, args)} | ||
res = self.sess_.run(None, inputs) | ||
return tuple(res) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ coverage | |
guzzle_sphinx_theme | ||
jyquickhelper | ||
onnx | ||
onnxruntime | ||
openpyxl | ||
pylint | ||
pyquickhelper>=1.9 | ||
|