diff --git a/_unittests/ut_tools/test_export_onnx.py b/_unittests/ut_tools/test_export_onnx.py index 6ecd902ca..ed6478bef 100644 --- a/_unittests/ut_tools/test_export_onnx.py +++ b/_unittests/ut_tools/test_export_onnx.py @@ -12,6 +12,7 @@ from onnx.helper import ( make_model, make_node, set_model_props, make_tensor, make_graph, make_tensor_value_info) +from onnxruntime import SessionOptions, GraphOptimizationLevel from sklearn.cluster import KMeans import autopep8 from pyquickhelper.pycode import ExtTestCase @@ -909,14 +910,17 @@ def test_export_einsum(self): r = numpy.einsum("bac,cd,def->ebc", x1, x2, x3) seq_clean = decompose_einsum_equation( "bac,cd,def->ebc", strategy='numpy', clean=True) - onx = seq_clean.to_onnx("Y", "X1", "X2", "X3", dtype=numpy.float32) + onx = seq_clean.to_onnx("Y", "X1", "X2", "X3", dtype=numpy.float32, + target_opset=15) - with self.subTest(rt='python'): - oinf = OnnxInference(onx) + with self.subTest(rt='onnxruntime1'): + opts = SessionOptions() + opts.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL + oinf = OnnxInference(onx, runtime='onnxruntime1', runtime_options=opts) rr = oinf.run({'X1': x1, 'X2': x2, 'X3': x3}) self.assertEqualArray(r, rr['Y']) - with self.subTest(rt='onnxruntime1'): - oinf = OnnxInference(onx, runtime='onnxruntime1') + with self.subTest(rt='python'): + oinf = OnnxInference(onx) rr = oinf.run({'X1': x1, 'X2': x2, 'X3': x3}) self.assertEqualArray(r, rr['Y']) diff --git a/mlprodict/onnxrt/ops_whole/session.py b/mlprodict/onnxrt/ops_whole/session.py index 8904a6017..69bf8af39 100644 --- a/mlprodict/onnxrt/ops_whole/session.py +++ b/mlprodict/onnxrt/ops_whole/session.py @@ -32,11 +32,16 @@ def __init__(self, onnx_data, runtime, runtime_options=None): "runtime '{}' is not implemented.".format(runtime)) if hasattr(onnx_data, 'SerializeToString'): onnx_data = onnx_data.SerializeToString() - session_options = ( - None if runtime_options is None - else runtime_options.get('session_options', None)) - self.runtime = runtime - sess_options = session_options or SessionOptions() + if isinstance(runtime_options, SessionOptions): + sess_options = runtime_options + session_options = None + runtime_options = None + else: + session_options = ( + None if runtime_options is None + else runtime_options.get('session_options', None)) + self.runtime = runtime + sess_options = session_options or SessionOptions() self.run_options = RunOptions() if session_options is None: @@ -56,11 +61,11 @@ def __init__(self, onnx_data, runtime, runtime_options=None): GraphOptimizationLevel.ORT_ENABLE_ALL) if runtime_options.get('enable_profiling', True): sess_options.enable_profiling = True - elif 'enable_profiling' in runtime_options: + elif runtime_options is not None and 'enable_profiling' in runtime_options: raise RuntimeError( # pragma: no cover "session_options and enable_profiling cannot be defined at the " "same time.") - elif 'disable_optimisation' in runtime_options: + elif runtime_options is not None and 'disable_optimisation' in runtime_options: raise RuntimeError( # pragma: no cover "session_options and disable_optimisation cannot be defined at the " "same time.")