This repository has been archived by the owner on Jan 13, 2024. It is now read-only.
Fixes #161, add support for disable_optimisation
sdpython committed Aug 12, 2020
1 parent 301bddc commit 066963d
Showing 11 changed files with 109 additions and 51 deletions.
49 changes: 26 additions & 23 deletions _unittests/ut__skl2onnx/test_sklearn_gaussian_process.py

@@ -87,11 +87,12 @@ def remove_dim1(self, arr):

     def check_outputs(self, model, model_onnx, Xtest,
                       predict_attributes, decimal=5,
-                      skip_if_float32=False):
+                      skip_if_float32=False, disable_optimisation=False):
         if predict_attributes is None:
             predict_attributes = {}
         exp = model.predict(Xtest, **predict_attributes)
-        sess = OnnxInference(model_onnx)
+        runtime_options = dict(disable_optimisation=disable_optimisation)
+        sess = OnnxInference(model_onnx, runtime_options=runtime_options)
         got = sess.run({'X': Xtest})
         got = [got[k] for k in sess.output_names]
         if isinstance(exp, tuple):

@@ -212,14 +213,15 @@ def test_gpr_rbf_fitted_return_std_true(self):
             gp, initial_types=[('X', FloatTensorType([None, None]))],
             options=options, target_opset=TARGET_OPSET)
         self.assertTrue(model_onnx is not None)
-        self.check_outputs(gp, model_onnx, Xtest_.astype(np.float32),
-                           predict_attributes=options[
-                               GaussianProcessRegressor],
-                           decimal=4)
+        self.check_outputs(
+            gp, model_onnx, Xtest_.astype(np.float32),
+            predict_attributes=options[GaussianProcessRegressor],
+            decimal=4, disable_optimisation=True)
         dump_data_and_model(Xtest_.astype(np.float32), gp, model_onnx,
                             verbose=False,
                             basename="SklearnGaussianProcessRBFStd-Out0",
-                            check_error="misses a kernel")
+                            check_error="misses a kernel",
+                            disable_optimisation=True)

     def test_gpr_rbf_fitted_return_std_exp_sine_squared_true(self):

@@ -239,10 +241,10 @@ def test_gpr_rbf_fitted_return_std_exp_sine_squared_true(self):
             Xtest_.astype(np.float64), gp, model_onnx,
             verbose=False,
             basename="SklearnGaussianProcessExpSineSquaredStdT-Out0-Dec3",
-            check_error="misses a kernel")
+            check_error="misses a kernel", disable_optimisation=True)
         self.check_outputs(gp, model_onnx, Xtest_.astype(np.float64),
                            predict_attributes=options[GaussianProcessRegressor],
-                           decimal=4)
+                           decimal=4, disable_optimisation=True)

     def test_gpr_rbf_fitted_return_std_exp_sine_squared_false(self):

@@ -284,11 +286,11 @@ def test_gpr_rbf_fitted_return_std_exp_sine_squared_double_true(self):
         dump_data_and_model(
             Xtest_.astype(np.float64), gp, model_onnx,
             basename="SklearnGaussianProcessExpSineSquaredStdDouble-Out0-Dec4",
-            check_error="misses a kernel")
-        self.check_outputs(gp, model_onnx, Xtest_.astype(np.float64),
-                           predict_attributes=options[
-                               GaussianProcessRegressor],
-                           decimal=4)
+            check_error="misses a kernel", disable_optimisation=True)
+        self.check_outputs(
+            gp, model_onnx, Xtest_.astype(np.float64),
+            predict_attributes=options[GaussianProcessRegressor],
+            decimal=4, disable_optimisation=True)

     def test_gpr_rbf_fitted_return_std_dot_product_true(self):

@@ -307,11 +309,11 @@ def test_gpr_rbf_fitted_return_std_dot_product_true(self):
         dump_data_and_model(
             Xtest_.astype(np.float64), gp, model_onnx,
             basename="SklearnGaussianProcessDotProductStdDouble-Out0-Dec3",
-            check_error="misses a kernel")
-        self.check_outputs(gp, model_onnx, Xtest_.astype(np.float64),
-                           predict_attributes=options[
-                               GaussianProcessRegressor],
-                           decimal=3)
+            check_error="misses a kernel", disable_optimisation=True)
+        self.check_outputs(
+            gp, model_onnx, Xtest_.astype(np.float64),
+            predict_attributes=options[GaussianProcessRegressor],
+            decimal=3, disable_optimisation=True)

     def test_gpr_rbf_fitted_return_std_rational_quadratic_true(self):

@@ -330,10 +332,11 @@ def test_gpr_rbf_fitted_return_std_rational_quadratic_true(self):
         dump_data_and_model(
             Xtest_.astype(np.float64), gp, model_onnx,
             basename="SklearnGaussianProcessRationalQuadraticStdDouble-Out0",
-            check_error="misses a kernel")
-        self.check_outputs(gp, model_onnx, Xtest_.astype(np.float64),
-                           predict_attributes=options[
-                               GaussianProcessRegressor])
+            check_error="misses a kernel", disable_optimisation=True)
+        self.check_outputs(
+            gp, model_onnx, Xtest_.astype(np.float64),
+            predict_attributes=options[GaussianProcessRegressor],
+            disable_optimisation=True)

     def test_gpr_fitted_shapes(self):
         data = load_iris()
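The pattern these tests exercise can be reproduced outside the test suite. The sketch below is illustrative only, not part of the commit, and assumes recent scikit-learn, skl2onnx and mlprodict: it converts a fitted GaussianProcessRegressor with return_std=True and runs it through OnnxInference with onnxruntime optimisations disabled via the new keyword.

import numpy as np
from sklearn.datasets import load_iris
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import DotProduct
from skl2onnx import to_onnx
from mlprodict.onnxrt import OnnxInference

# Fit a small Gaussian process regressor.
X, y = load_iris(return_X_y=True)
gp = GaussianProcessRegressor(kernel=DotProduct(), alpha=1.).fit(X, y)

# First call computes sklearn's cached inverse kernel matrix,
# which the converter needs when return_std is requested.
exp_mean, exp_std = gp.predict(X, return_std=True)

options = {GaussianProcessRegressor: {'return_std': True}}
model_onnx = to_onnx(gp, X.astype(np.float64), options=options)

# runtime_options is the keyword introduced by this commit.
oinf = OnnxInference(model_onnx,
                     runtime_options=dict(disable_optimisation=True))
got = oinf.run({'X': X.astype(np.float64)})
np.testing.assert_almost_equal(
    exp_mean.ravel(), got[oinf.output_names[0]].ravel(), decimal=4)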
9 changes: 6 additions & 3 deletions mlprodict/asv_benchmark/_create_asv_helper.py

@@ -49,9 +49,12 @@
         "Pillow": [],
         "pybind11": [],
         "scipy": [],
-        "onnxconverter-common": ["http://localhost:8067/simple/"],  # "git+https://github.com/xadupre/onnxconverter-common.git@jenkins"],
-        "skl2onnx": ["http://localhost:8067/simple/"],  # "git+https://github.com/xadupre/sklearn-onnx.git@jenkins"],
-        "scikit-learn": ["http://localhost:8067/simple/"],  # "git+https://github.com/scikit-learn/scikit-learn.git"],
+        # "git+https://github.com/xadupre/onnxconverter-common.git@jenkins"],
+        "onnxconverter-common": ["http://localhost:8067/simple/"],
+        # "git+https://github.com/xadupre/sklearn-onnx.git@jenkins"],
+        "skl2onnx": ["http://localhost:8067/simple/"],
+        # "git+https://github.com/scikit-learn/scikit-learn.git"],
+        "scikit-learn": ["http://localhost:8067/simple/"],
         "xgboost": [],
     },
     "benchmark_dir": "benches",
13 changes: 9 additions & 4 deletions mlprodict/onnxrt/onnx_inference.py

@@ -44,7 +44,7 @@ class OnnxInference:
     def __init__(self, onnx_or_bytes_or_stream, runtime=None,
                  skip_run=False, inplace=True,
                  input_inplace=False, ir_version=None,
-                 target_opset=None):
+                 target_opset=None, runtime_options=None):
         """
         @param onnx_or_bytes_or_stream :epkg:`onnx` object,
             bytes, or filename or stream

@@ -58,6 +58,7 @@ def __init__(self, onnx_or_bytes_or_stream, runtime=None,
             <mlprodict.onnxrt.onnx_inference.OnnxInference._guess_inplace>`
         @param ir_version if not None, overwrite the default version
         @param target_opset used to overwrite *target_opset*
+        @param runtime_options specific options for the runtime
         """
         if isinstance(onnx_or_bytes_or_stream, bytes):
             self.obj = load_model(BytesIO(onnx_or_bytes_or_stream))

@@ -80,6 +81,7 @@ def __init__(self, onnx_or_bytes_or_stream, runtime=None,
         self.input_inplace = input_inplace
         self.inplace = inplace
         self.force_target_opset = target_opset
+        self.runtime_options = runtime_options
         self._init()

     def __getstate__(self):

@@ -128,7 +130,8 @@ def _init(self):
             # Loads the onnx with onnxruntime as a single file.
             del self.graph_
             from .ops_whole.session import OnnxWholeSession
-            self._whole = OnnxWholeSession(self.obj, self.runtime)
+            self._whole = OnnxWholeSession(
+                self.obj, self.runtime, self.runtime_options)
             self._run = self._run_whole_runtime
         else:
             self.sequence_ = self.graph_['sequence']

@@ -141,11 +144,13 @@ def _init(self):
                 if self.runtime == 'onnxruntime2':
                     node.setup_runtime(self.runtime, variables, self.__class__,
                                        target_opset=target_opset, dtype=dtype,
-                                       domain=domain, ir_version=self.ir_version_)
+                                       domain=domain, ir_version=self.ir_version_,
+                                       runtime_options=self.runtime_options)
                 else:
                     node.setup_runtime(self.runtime, variables, self.__class__,
                                        target_opset=target_opset, domain=domain,
-                                       ir_version=self.ir_version_)
+                                       ir_version=self.ir_version_,
+                                       runtime_options=self.runtime_options)
                 if hasattr(node, 'ops_') and hasattr(node.ops_, 'typed_outputs_'):
                     for k, v in node.ops_.typed_outputs_:
                         variables[k] = v
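To see the new keyword end to end: with runtime='onnxruntime1' the options reach OnnxWholeSession (one session over the whole graph), while other runtimes forward them to each node's setup_runtime. A short self-contained sketch, illustrative and not from the commit; the tiny Add graph and the pinned opset are assumptions made so the snippet stands alone.

import numpy as np
from onnx import TensorProto, helper
from mlprodict.onnxrt import OnnxInference

# Build a minimal ONNX graph computing Y = X + X.
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None, 2])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 2])
graph = helper.make_graph([helper.make_node('Add', ['X', 'X'], ['Y'])],
                          'tiny', [X], [Y])
# Pin opset and IR version so older onnxruntime builds accept the model.
model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 13)])
model.ir_version = 7

# One whole-graph onnxruntime session, graph optimisations turned off.
oinf = OnnxInference(model.SerializeToString(), runtime='onnxruntime1',
                     runtime_options=dict(disable_optimisation=True))
print(oinf.run({'X': np.ones((3, 2), dtype=np.float32)}))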
5 changes: 4 additions & 1 deletion mlprodict/onnxrt/onnx_inference_node.py

@@ -72,7 +72,7 @@ def __repr__(self):

     def setup_runtime(self, runtime=None, variables=None, rt_class=None,
                       target_opset=None, dtype=None, domain=None,
-                      ir_version=None):
+                      ir_version=None, runtime_options=None):
         """
         Loads runtime.

@@ -85,6 +85,7 @@ def setup_runtime(self, runtime=None, variables=None, rt_class=None,
         @param domain node domain
         @param ir_version if not None, changes the default value
             given by :epkg:`ONNX`
+        @param runtime_options runtime options
         """
         if self.desc is None:
             raise AttributeError(

@@ -98,6 +99,8 @@ def setup_runtime(self, runtime=None, variables=None, rt_class=None,
             options['target_opset'] = target_opset
         if ir_version is not None:
             options['ir_version'] = ir_version
+        if runtime_options is not None:
+            options.update(runtime_options)
         if runtime == 'onnxruntime2':
             self.ops_ = load_op(self.onnx_node, desc=self.desc,
                                 options=options if options else None,
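Because runtime_options is merged with options.update() after target_opset and ir_version are set, its entries are applied last and can override either key. A standalone illustration of that merge order, with assumed values rather than library code:

# Mirror of the option assembly in setup_runtime above.
options = {}
target_opset, ir_version = 12, 6
runtime_options = {'disable_optimisation': True, 'ir_version': 7}

if target_opset is not None:
    options['target_opset'] = target_opset
if ir_version is not None:
    options['ir_version'] = ir_version
if runtime_options is not None:
    options.update(runtime_options)  # applied last, so these keys win

print(options)
# {'target_opset': 12, 'ir_version': 7, 'disable_optimisation': True}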
12 changes: 9 additions & 3 deletions mlprodict/onnxrt/ops_onnxruntime/_op.py

@@ -7,7 +7,8 @@
 import onnx.defs
 from onnx.helper import make_tensor
 from onnx import TensorProto
-from onnxruntime import InferenceSession, SessionOptions, RunOptions
+from onnxruntime import (
+    InferenceSession, SessionOptions, RunOptions, GraphOptimizationLevel)
 try:
     from onnxruntime.capi.onnxruntime_pybind11_state import (
         InvalidArgument as OrtInvalidArgument,

@@ -112,7 +113,9 @@ def _init(self, variables=None):
         options = self.options.copy()
         target_opset = options.pop('target_opset', None)
         domain = options.pop('domain', None)
+        disable_optimisation = options.pop('disable_optimisation', False)
         ir_version = options.pop('ir_version', None)
+
         if domain == '' and target_opset < 9:
             # target_opset should be >= 9 not {} for main domain.
             # We assume it was the case when the graph was created.
@@ -229,9 +232,12 @@ def _init(self, variables=None):
             pass
         if ir_version is not None:
             self.onnx_.ir_version = ir_version
+        if disable_optimisation:
+            sess_options.graph_optimization_level = (
+                GraphOptimizationLevel.ORT_DISABLE_ALL)
         try:
-            self.sess_ = InferenceSession(self.onnx_.SerializeToString(),
-                                          sess_options=sess_options)
+            self.sess_ = InferenceSession(
+                self.onnx_.SerializeToString(), sess_options=sess_options)
         except (RuntimeError, OrtNotImplemented, OrtInvalidGraph, OrtFail) as e:
             raise RuntimeError("Unable to load node '{}' (output type was {})\n{}".format(
                 self.onnx_node.op_type, "guessed" if forced else "inferred",
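For reference, this is the underlying onnxruntime API the flag maps to: disabling graph rewrites means setting graph_optimization_level to GraphOptimizationLevel.ORT_DISABLE_ALL on the session options. A minimal sketch with an inline one-node model so it runs on its own; it is illustrative, not code from the commit.

import numpy as np
from onnx import TensorProto, helper
from onnxruntime import (
    GraphOptimizationLevel, InferenceSession, SessionOptions)

# Inline one-node Identity model so the example is self-contained.
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 2])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 2])
graph = helper.make_graph(
    [helper.make_node('Identity', ['X'], ['Y'])], 'g', [X], [Y])
# Pin opset and IR version for older onnxruntime builds.
model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 13)])
model.ir_version = 7

so = SessionOptions()
so.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
sess = InferenceSession(model.SerializeToString(), sess_options=so)
print(sess.run(None, {'X': np.ones((1, 2), dtype=np.float32)}))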
13 changes: 9 additions & 4 deletions mlprodict/onnxrt/ops_whole/session.py

@@ -5,15 +5,15 @@
 """
 from io import BytesIO
 import onnx
-from onnxruntime import InferenceSession, SessionOptions, RunOptions
+from onnxruntime import (
+    InferenceSession, SessionOptions, RunOptions, GraphOptimizationLevel)
 try:
     from onnxruntime.capi.onnxruntime_pybind11_state import (
         Fail as OrtFail,
         InvalidGraph as OrtInvalidGraph,
         InvalidArgument as OrtInvalidArgument,
         NotImplemented as OrtNotImplemented,
-        RuntimeException as OrtRuntimeException,
-    )
+        RuntimeException as OrtRuntimeException)
 except ImportError:  # pragma: no cover
     OrtFail = Exception
     OrtNotImplemented = RuntimeError

@@ -29,11 +29,12 @@ class OnnxWholeSession:
     it lets the runtime handle the graph logic as well.
     """

-    def __init__(self, onnx_data, runtime):
+    def __init__(self, onnx_data, runtime, runtime_options=None):
         """
         @param onnx_data :epkg:`ONNX` model or data
         @param runtime runtime to be used,
             mostly :epkg:`onnxruntime`
+        @param runtime_options runtime options
         """
         if runtime != 'onnxruntime1':
             raise NotImplementedError(  # pragma: no cover
@@ -53,6 +54,10 @@ def __init__(self, onnx_data, runtime):
         except AttributeError:  # pragma: no cover
             # onnxruntime not recent enough.
             pass
+        if (runtime_options is not None and
+                runtime_options.get('disable_optimisation', False)):
+            sess_options.graph_optimization_level = (
+                GraphOptimizationLevel.ORT_DISABLE_ALL)
         try:
             self.sess = InferenceSession(onnx_data, sess_options=sess_options)
         except (OrtFail, OrtNotImplemented, OrtInvalidGraph,
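The guard above can be read as a small helper: runtime_options=None leaves onnxruntime's default optimisation level untouched, and only an explicit disable_optimisation=True flips it. A standalone sketch with an assumed helper name, not part of the commit:

from onnxruntime import GraphOptimizationLevel, SessionOptions

def apply_runtime_options(sess_options, runtime_options=None):
    # Same guard as above: only an explicit disable_optimisation=True
    # changes the level; None keeps onnxruntime defaults.
    if (runtime_options is not None and
            runtime_options.get('disable_optimisation', False)):
        sess_options.graph_optimization_level = (
            GraphOptimizationLevel.ORT_DISABLE_ALL)
    return sess_options

print(apply_runtime_options(SessionOptions()).graph_optimization_level)
print(apply_runtime_options(
    SessionOptions(),
    dict(disable_optimisation=True)).graph_optimization_level)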
10 changes: 7 additions & 3 deletions mlprodict/testing/test_utils/tests_helper.py

@@ -149,7 +149,7 @@ def dump_data_and_model(  # pylint: disable=R0912
         context=None, allow_failure=None, methods=None,
         dump_error_log=None, benchmark=None, comparable_outputs=None,
         intermediate_steps=False, fail_evenif_notimplemented=False,
-        verbose=False, classes=None, check_error=None):
+        verbose=False, classes=None, check_error=None, disable_optimisation=False):
     """
     Saves data with pickle, saves the model with pickle and *onnx*,
     runs and saves the predictions for the given model.

@@ -199,6 +199,8 @@
         (only for classifier, mandatory if option 'nocl' is used)
     :param check_error: do not raise an exception if the error message
         contains this text
+    :param disable_optimisation: disable all optimisations *onnxruntime*
+        could do
     :return: the created files

     Some convention for the name,

@@ -367,14 +369,16 @@
                     b, runtime_test, options=extract_options(basename),
                     context=context, verbose=verbose,
                     comparable_outputs=comparable_outputs,
-                    intermediate_steps=intermediate_steps)
+                    intermediate_steps=intermediate_steps,
+                    disable_optimisation=disable_optimisation)
             elif check_error:
                 try:
                     output, lambda_onnx = compare_backend(
                         b, runtime_test, options=extract_options(basename),
                         context=context, verbose=verbose,
                         comparable_outputs=comparable_outputs,
-                        intermediate_steps=intermediate_steps)
+                        intermediate_steps=intermediate_steps,
+                        disable_optimisation=disable_optimisation)
                 except Exception as e:  # pragma: no cover
                     if check_error in str(e):
                         warnings.warn(str(e))
11 changes: 8 additions & 3 deletions mlprodict/testing/test_utils/utils_backend.py

@@ -8,7 +8,8 @@

 def compare_backend(backend, test, decimal=5, options=None, verbose=False,
                     context=None, comparable_outputs=None,
-                    intermediate_steps=False, classes=None):
+                    intermediate_steps=False, classes=None,
+                    disable_optimisation=False):
     """
     The function compares the expected output (computed with
     the model before being converted to ONNX) and the ONNX output.

@@ -27,6 +28,8 @@ def compare_backend(backend, test, decimal=5, options=None, verbose=False,
     :param intermediate_steps: displays intermediate steps
         in case of an error
     :param classes: classes names (if option 'nocl' is used)
+    :param disable_optimisation: disable optimisation onnxruntime
+        could do

     The function does not return anything but raises an error
     if the comparison failed.

@@ -36,10 +39,12 @@ def compare_backend(backend, test, decimal=5, options=None, verbose=False,
         return compare_runtime_ort(
             test, decimal, options=options, verbose=verbose,
             comparable_outputs=comparable_outputs,
-            intermediate_steps=False, classes=classes)
+            intermediate_steps=False, classes=classes,
+            disable_optimisation=disable_optimisation)
     if backend == "python":
         return compare_runtime_pyrt(
             test, decimal, options=options, verbose=verbose,
             comparable_outputs=comparable_outputs,
-            intermediate_steps=intermediate_steps, classes=classes)
+            intermediate_steps=intermediate_steps, classes=classes,
+            disable_optimisation=disable_optimisation)
     raise ValueError("Does not support backend '{0}'.".format(backend))
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
def compare_runtime_session( # pylint: disable=R0912
cls_session, test, decimal=5, options=None,
verbose=False, context=None, comparable_outputs=None,
intermediate_steps=False, classes=None):
intermediate_steps=False, classes=None,
disable_optimisation=False):
"""
The function compares the expected output (computed with
the model before being converted to ONNX) and the ONNX output
Expand All @@ -40,6 +41,7 @@ def compare_runtime_session( # pylint: disable=R0912
:param intermediate_steps: displays intermediate steps
in case of an error
:param classes: classes names (if option 'nocl' is used)
:param disable_optimisation: disable optimisation the runtime may do
:return: tuple (outut, lambda function to run the predictions)
The function does not return anything but raises an error
Expand All @@ -53,6 +55,7 @@ def compare_runtime_session( # pylint: disable=R0912
print("[compare_runtime] test '{}' loaded".format(test['onnx']))

onx = test['onnx']

if options is None:
if isinstance(onx, str):
options = extract_options(onx)
Expand All @@ -67,8 +70,12 @@ def compare_runtime_session( # pylint: disable=R0912
if verbose: # pragma no cover
print("[compare_runtime] InferenceSession('{}')".format(onx))

runtime_options = dict(disable_optimisation=disable_optimisation)
try:
sess = cls_session(onx)
sess = cls_session(onx, runtime_options=runtime_options)
except TypeError as e:
raise TypeError(
"Wrong signature for '{}'.".format(cls_session.__name__))
except ExpectedAssertionError as expe: # pragma no cover
raise expe
except Exception as e: # pylint: disable=W0703
Expand Down
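One detail in the new TypeError branch: re-raising with `from e` preserves the original traceback, which is what tells the reader the keyword was rejected by an older cls_session signature. A tiny standalone demonstration, not part of the commit:

def old_session(onx):  # stands in for a cls_session without the new keyword
    return onx

try:
    try:
        old_session('model.onnx', runtime_options={})
    except TypeError as e:
        raise TypeError("Wrong signature for 'old_session'.") from e
except TypeError as err:
    print(err)            # Wrong signature for 'old_session'.
    print(err.__cause__)  # original: unexpected keyword argument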
(The diffs for the remaining two of the eleven changed files are not shown here.)
