diff --git a/_doc/sphinxdoc/source/api/onnxrt.rst b/_doc/sphinxdoc/source/api/onnxrt.rst
index 786f9a0e5..9297abbb8 100644
--- a/_doc/sphinxdoc/source/api/onnxrt.rst
+++ b/_doc/sphinxdoc/source/api/onnxrt.rst
@@ -37,6 +37,13 @@ ONNX Structure
 
 .. autosignature:: mlprodict.onnx_tools.onnx_manipulations.select_model_inputs_outputs
 
+onnxruntime
++++++++++++
+
+.. autosignature:: mlprodict.tools.onnx_inference_ort_helper.device_to_providers
+
+.. autosignature:: mlprodict.tools.onnx_inference_ort_helper.get_ort_device
+
 Validation
 ++++++++++
 
diff --git a/_unittests/ut_onnxrt/test_onnxrt_iobinding.py b/_unittests/ut_onnxrt/test_onnxrt_iobinding.py
new file mode 100644
index 000000000..31234afee
--- /dev/null
+++ b/_unittests/ut_onnxrt/test_onnxrt_iobinding.py
@@ -0,0 +1,134 @@
+"""
+@brief      test log(time=6s)
+"""
+import unittest
+import numpy
+from pyquickhelper.pycode import ExtTestCase, ignore_warnings
+from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611,W0611
+    OrtDevice as C_OrtDevice, OrtValue as C_OrtValue)
+from onnxruntime import get_device
+from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611,W0611
+    OnnxAdd)
+from mlprodict.onnxrt import OnnxInference
+from mlprodict.tools.onnx_inference_ort_helper import get_ort_device
+from mlprodict.tools import get_opset_number_from_onnx
+
+
+DEVICE = "cuda" if get_device().upper() == 'GPU' else 'cpu'
+
+
+class TestOnnxrtIOBinding(ExtTestCase):
+
+    @ignore_warnings(DeprecationWarning)
+    def test_onnxt_cpu_numpy_python(self):
+        idi = numpy.identity(2, dtype=numpy.float32)
+        idi2 = (numpy.identity(2) * 2).astype(numpy.float32)
+        onx = OnnxAdd(
+            OnnxAdd('X', idi, op_version=get_opset_number_from_onnx()),
+            idi2, output_names=['Y'],
+            op_version=get_opset_number_from_onnx())
+        model_def = onx.to_onnx({'X': idi.astype(numpy.float32)},
+                                target_opset=get_opset_number_from_onnx())
+        oinf = OnnxInference(model_def)
+        X = numpy.array([[1, 1], [3, 3]])
+        y = oinf.run({'X': X.astype(numpy.float32)})
+        exp = numpy.array([[4, 1], [3, 6]], dtype=numpy.float32)
+        self.assertEqual(list(y), ['Y'])
+        self.assertEqualArray(y['Y'], exp)
+
+    @ignore_warnings(DeprecationWarning)
+    def test_onnxt_cpu_numpy_onnxruntime1(self):
+        idi = numpy.identity(2, dtype=numpy.float32)
+        idi2 = (numpy.identity(2) * 2).astype(numpy.float32)
+        onx = OnnxAdd(
+            OnnxAdd('X', idi, op_version=get_opset_number_from_onnx()),
+            idi2, output_names=['Y'],
+            op_version=get_opset_number_from_onnx())
+        model_def = onx.to_onnx({'X': idi.astype(numpy.float32)},
+                                target_opset=get_opset_number_from_onnx())
+        oinf = OnnxInference(model_def, runtime="onnxruntime1")
+        X = numpy.array([[1, 1], [3, 3]])
+        y = oinf.run({'X': X.astype(numpy.float32)})
+        exp = numpy.array([[4, 1], [3, 6]], dtype=numpy.float32)
+        self.assertEqual(list(y), ['Y'])
+        self.assertEqualArray(y['Y'], exp)
+
+    @ignore_warnings(DeprecationWarning)
+    def test_onnxt_cpu_ortvalue_python(self):
+        idi = numpy.identity(2, dtype=numpy.float32)
+        idi2 = (numpy.identity(2) * 2).astype(numpy.float32)
+        onx = OnnxAdd(
+            OnnxAdd('X', idi, op_version=get_opset_number_from_onnx()),
+            idi2, output_names=['Y'],
+            op_version=get_opset_number_from_onnx())
+        model_def = onx.to_onnx({'X': idi.astype(numpy.float32)},
+                                target_opset=get_opset_number_from_onnx())
+        oinf = OnnxInference(model_def)
+        X = numpy.array([[1, 1], [3, 3]])
+        X32 = X.astype(numpy.float32)
+        ov = C_OrtValue.ortvalue_from_numpy(X32, get_ort_device('cpu'))
+        self.assertRaise(lambda: oinf.run({'X': ov}), AttributeError)
+
+    @ignore_warnings(DeprecationWarning)
+    def test_onnxt_cpu_ortvalue_ort(self):
+        idi = numpy.identity(2, dtype=numpy.float32)
+        idi2 = (numpy.identity(2) * 2).astype(numpy.float32)
+        onx = OnnxAdd(
+            OnnxAdd('X', idi, op_version=get_opset_number_from_onnx()),
+            idi2, output_names=['Y'],
+            op_version=get_opset_number_from_onnx())
+        model_def = onx.to_onnx({'X': idi.astype(numpy.float32)},
+                                target_opset=get_opset_number_from_onnx())
+        oinf = OnnxInference(model_def, runtime="onnxruntime1")
+        X = numpy.array([[1, 1], [3, 3]])
+        X32 = X.astype(numpy.float32)
+        ov = C_OrtValue.ortvalue_from_numpy(X32, get_ort_device('cpu'))
+        y = oinf.run({'X': ov})
+        exp = numpy.array([[4, 1], [3, 6]], dtype=numpy.float32)
+        self.assertEqual(list(y), ['Y'])
+        self.assertEqualArray(y['Y'].numpy(), exp)
+
+    @ignore_warnings(DeprecationWarning)
+    def test_onnxt_cpu_ortvalue_ort_cpu(self):
+        idi = numpy.identity(2, dtype=numpy.float32)
+        idi2 = (numpy.identity(2) * 2).astype(numpy.float32)
+        onx = OnnxAdd(
+            OnnxAdd('X', idi, op_version=get_opset_number_from_onnx()),
+            idi2, output_names=['Y'],
+            op_version=get_opset_number_from_onnx())
+        model_def = onx.to_onnx({'X': idi.astype(numpy.float32)},
+                                target_opset=get_opset_number_from_onnx())
+        self.assertRaise(lambda: OnnxInference(model_def, device='cpu'),
+                         ValueError)
+        oinf = OnnxInference(model_def, runtime="onnxruntime1", device='cpu')
+        X = numpy.array([[1, 1], [3, 3]])
+        X32 = X.astype(numpy.float32)
+        ov = C_OrtValue.ortvalue_from_numpy(X32, get_ort_device('cpu'))
+        y = oinf.run({'X': ov})
+        exp = numpy.array([[4, 1], [3, 6]], dtype=numpy.float32)
+        self.assertEqual(list(y), ['Y'])
+        self.assertEqualArray(y['Y'].numpy(), exp)
+
+    @unittest.skipIf(DEVICE != 'cuda', reason="runs only on GPU")
+    @ignore_warnings(DeprecationWarning)
+    def test_onnxt_ortvalue_ort_gpu(self):
+        idi = numpy.identity(2, dtype=numpy.float32)
+        idi2 = (numpy.identity(2) * 2).astype(numpy.float32)
+        onx = OnnxAdd(
+            OnnxAdd('X', idi, op_version=get_opset_number_from_onnx()),
+            idi2, output_names=['Y'],
+            op_version=get_opset_number_from_onnx())
+        model_def = onx.to_onnx({'X': idi.astype(numpy.float32)},
+                                target_opset=get_opset_number_from_onnx())
+        oinf = OnnxInference(model_def, runtime="onnxruntime1", device='cuda')
+        X = numpy.array([[1, 1], [3, 3]])
+        X32 = X.astype(numpy.float32)
+        ov = C_OrtValue.ortvalue_from_numpy(X32, get_ort_device('cuda'))
+        y = oinf.run({'X': ov})
+        exp = numpy.array([[4, 1], [3, 6]], dtype=numpy.float32)
+        self.assertEqual(list(y), ['Y'])
+        self.assertEqualArray(y['Y'].cpu().numpy(), exp)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/mlprodict/onnxrt/onnx_inference.py b/mlprodict/onnxrt/onnx_inference.py
index 416dcecd9..1a30fa90b 100644
--- a/mlprodict/onnxrt/onnx_inference.py
+++ b/mlprodict/onnxrt/onnx_inference.py
@@ -72,6 +72,8 @@ class OnnxInference:
         be cut to have these new_outputs as the final outputs
     :param new_opset: overwrite the main opset and replaces by
         this new one
+    :param device: device, a string `cpu`, `cuda`, `cuda:0`...,
+        this option is only available with runtime *onnxruntime1*
 
     Among the possible runtime_options, there are:
     * *enable_profiling*: enables profiling for :epkg:`onnxruntime`
@@ -83,7 +85,7 @@ class OnnxInference:
         Parameters *new_outputs*, *new_opset* were added.
 
     .. versionchanged:: 0.8
-        Parameter *static_inputs* was added.
+        Parameters *static_inputs*, *device* were added.
""" def __init__(self, onnx_or_bytes_or_stream, runtime=None, @@ -91,7 +93,8 @@ def __init__(self, onnx_or_bytes_or_stream, runtime=None, input_inplace=False, ir_version=None, target_opset=None, runtime_options=None, session_options=None, inside_loop=False, - static_inputs=None, new_outputs=None, new_opset=None): + static_inputs=None, new_outputs=None, new_opset=None, + device=None): if isinstance(onnx_or_bytes_or_stream, bytes): self.obj = load_model(BytesIO(onnx_or_bytes_or_stream)) elif isinstance(onnx_or_bytes_or_stream, BytesIO): @@ -113,6 +116,10 @@ def __init__(self, onnx_or_bytes_or_stream, runtime=None, self.obj, outputs=new_outputs, infer_shapes=True) if new_opset is not None: self.obj = overwrite_opset(self.obj, new_opset) + if device is not None and runtime != 'onnxruntime1': + raise ValueError( + "Incompatible values, device can be specified with " + "runtime 'onnxruntime1', not %r." % runtime) self.runtime = runtime self.skip_run = skip_run @@ -122,6 +129,7 @@ def __init__(self, onnx_or_bytes_or_stream, runtime=None, self.runtime_options = runtime_options self.inside_loop = inside_loop self.static_inputs = static_inputs + self.device = device self._init() def __getstate__(self): @@ -136,7 +144,8 @@ def __getstate__(self): 'inplace': self.inplace, 'force_target_opset': self.force_target_opset, 'static_inputs': self.static_inputs, - 'inside_loop': self.inside_loop} + 'inside_loop': self.inside_loop, + 'device': self.device} def __setstate__(self, state): """ @@ -152,6 +161,7 @@ def __setstate__(self, state): self.force_target_opset = state['force_target_opset'] self.static_inputs = state['static_inputs'] self.inside_loop = state['inside_loop'] + self.device = state['device'] self._init() def _init(self): @@ -190,7 +200,8 @@ def _init(self): del self.graph_ from .ops_whole.session import OnnxWholeSession self._whole = OnnxWholeSession( - self.obj, self.runtime, self.runtime_options) + self.obj, self.runtime, self.runtime_options, + self.device) self._run = self._run_whole_runtime else: self.sequence_ = self.graph_['sequence'] diff --git a/mlprodict/onnxrt/ops_whole/session.py b/mlprodict/onnxrt/ops_whole/session.py index 78be92140..78e8a2695 100644 --- a/mlprodict/onnxrt/ops_whole/session.py +++ b/mlprodict/onnxrt/ops_whole/session.py @@ -18,15 +18,17 @@ class OnnxWholeSession: """ Runs the prediction for a single :epkg:`ONNX`, it lets the runtime handle the graph logic as well. + + :param onnx_data: :epkg:`ONNX` model or data + :param runtime: runtime to be used, mostly :epkg:`onnxruntime` + :param runtime_options: runtime options + :param device: device, a string `cpu`, `cuda`, `cuda:0`... + + .. versionchanged:: 0.8 + Parameter *device* was added. 
""" - def __init__(self, onnx_data, runtime, runtime_options=None): - """ - @param onnx_data :epkg:`ONNX` model or data - @param runtime runtime to be used, - mostly :epkg:`onnxruntime` - @param runtime_options runtime options - """ + def __init__(self, onnx_data, runtime, runtime_options=None, device=None): if runtime != 'onnxruntime1': raise NotImplementedError( # pragma: no cover "runtime '{}' is not implemented.".format(runtime)) @@ -68,7 +70,8 @@ def __init__(self, onnx_data, runtime, runtime_options=None): "session_options and log_severity_level cannot be defined at the " "same time.") try: - self.sess = InferenceSession(onnx_data, sess_options=sess_options) + self.sess = InferenceSession(onnx_data, sess_options=sess_options, + device=device) except (OrtFail, OrtNotImplemented, OrtInvalidGraph, OrtInvalidArgument, OrtRuntimeException, RuntimeError) as e: raise RuntimeError( diff --git a/mlprodict/tools/onnx_inference_ort_helper.py b/mlprodict/tools/onnx_inference_ort_helper.py new file mode 100644 index 000000000..2aeabeaf4 --- /dev/null +++ b/mlprodict/tools/onnx_inference_ort_helper.py @@ -0,0 +1,64 @@ +# pylint: disable=C0302 +""" +@file +@brief Helpers for :epkg:`onnxruntime`. +""" +from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611,W0611 + OrtDevice as C_OrtDevice) + + +def get_ort_device(device): + """ + Converts device into :epkg:`C_OrtDevice`. + + :param device: any type + :return: :epkg:`C_OrtDevice` + + Example: + + :: + + get_ort_device('cpu') + get_ort_device('gpu') + get_ort_device('cuda') + get_ort_device('cuda:0') + """ + if isinstance(device, C_OrtDevice): + return device + if isinstance(device, str): + if device == 'cpu': + return C_OrtDevice( + C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0) + if device in {'gpu', 'cuda:0', 'cuda', 'gpu:0'}: + return C_OrtDevice( + C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0) + if device.startswith('gpu:'): + idx = int(device[4:]) + return C_OrtDevice( + C_OrtDevice.cuda(), C_OrtDevice.default_memory(), idx) + if device.startswith('cuda:'): + idx = int(device[5:]) + return C_OrtDevice( + C_OrtDevice.cuda(), C_OrtDevice.default_memory(), idx) + raise ValueError( + "Unable to interpret string %r as a device." % device) + raise TypeError( + "Unable to interpret type %r, (%r) as de device." % ( + type(device), device)) + + +def device_to_providers(device): + """ + Returns the corresponding providers for a specific device. + + :param device: :epkg:`C_OrtDevice` + :return: providers + """ + if isinstance(device, str): + device = get_ort_device(device) + if device.device_type() == device.cpu(): + return ['CPUExecutionProvider'] + if device.device_type() == device.cuda(): + return ['CUDAExecutionProvider'] + raise ValueError( # pragma: no cover + "Unexpected device %r." 
% device) diff --git a/mlprodict/tools/ort_wrapper.py b/mlprodict/tools/ort_wrapper.py index e061ab08b..847ee4614 100644 --- a/mlprodict/tools/ort_wrapper.py +++ b/mlprodict/tools/ort_wrapper.py @@ -13,12 +13,15 @@ InferenceSession as OrtInferenceSession, __version__ as onnxrt_version, GraphOptimizationLevel) + from .onnx_inference_ort_helper import get_ort_device, device_to_providers except ImportError: # pragma: no cover SessionOptions = None RunOptions = None OrtInferenceSession = None onnxrt_version = "0.0.0" GraphOptimizationLevel = None + get_ort_device = None + device_to_providers = None try: from onnxruntime.capi.onnxruntime_pybind11_state import ( # pylint: disable=W0611 @@ -26,7 +29,8 @@ NotImplemented as OrtNotImplemented, InvalidArgument as OrtInvalidArgument, InvalidGraph as OrtInvalidGraph, - RuntimeException as OrtRuntimeException) + RuntimeException as OrtRuntimeException, + OrtValue as C_OrtValue) except ImportError: # pragma: no cover SessionOptions = None RunOptions = None @@ -38,6 +42,7 @@ OrtInvalidArgument = RuntimeError OrtInvalidGraph = RuntimeError OrtRuntimeException = RuntimeError + C_OrtValue = None class InferenceSession: # pylint: disable=E0102 @@ -46,22 +51,34 @@ class InferenceSession: # pylint: disable=E0102 :param onnx_bytes: onnx bytes :param session_options: session options + :param log_severity_level: change the logging level + :param device: device, a string `cpu`, `cuda`, `cuda:0`... """ - def __init__(self, onnx_bytes, sess_options=None, log_severity_level=4): + def __init__(self, onnx_bytes, sess_options=None, log_severity_level=4, + device=None): if InferenceSession is None: raise ImportError( # pragma: no cover "onnxruntime is not available.") self.log_severity_level = log_severity_level + if device is None: + self.device = get_ort_device('cpu') + else: + self.device = get_ort_device(device) + self.providers = device_to_providers(self.device) if sess_options is None: self.so = SessionOptions() self.so.log_severity_level = log_severity_level - self.sess = OrtInferenceSession(onnx_bytes, sess_options=self.so) + self.sess = OrtInferenceSession( + onnx_bytes, sess_options=self.so, + providers=self.providers) else: self.sess = OrtInferenceSession( - onnx_bytes, sess_options=sess_options) + onnx_bytes, sess_options=sess_options, + providers=self.providers) self.ro = RunOptions() self.ro.log_severity_level = log_severity_level + self.output_names = [o.name for o in self.get_outputs()] def run(self, output_names, input_feed, run_options=None): """ @@ -72,6 +89,10 @@ def run(self, output_names, input_feed, run_options=None): :param run_options: None or RunOptions :return: array """ + if any(map(lambda v: isinstance(v, C_OrtValue), + input_feed.values())): + return self.sess._sess.run_with_ort_values( + input_feed, self.output_names, run_options or self.ro) return self.sess.run(output_names, input_feed, run_options or self.ro) def get_inputs(self):
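
The two helpers added in ``mlprodict/tools/onnx_inference_ort_helper.py`` split the
work in two steps: ``get_ort_device`` normalizes a string such as ``'cpu'``, ``'gpu'``,
``'cuda'`` or ``'cuda:1'`` into a ``C_OrtDevice``, and ``device_to_providers`` maps
that device to the provider list the session is created with. A minimal sketch of the
intended usage, following the docstring example in the patch:

::

    from mlprodict.tools.onnx_inference_ort_helper import (
        get_ort_device, device_to_providers)

    dev = get_ort_device('cuda:1')       # C_OrtDevice targeting GPU 1
    print(device_to_providers(dev))      # ['CUDAExecutionProvider']
    print(device_to_providers('cpu'))    # ['CPUExecutionProvider'], a string works too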
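End to end, ``device`` is only accepted together with runtime ``'onnxruntime1'`` (any
other runtime raises ``ValueError``), and ``C_OrtValue`` inputs are dispatched to
``run_with_ort_values``, so outputs come back as ``OrtValue`` objects. A short sketch
distilled from ``test_onnxt_cpu_ortvalue_ort_cpu`` above, with the graph reduced to a
single ``OnnxAdd``:

::

    import numpy
    from onnxruntime.capi._pybind_state import OrtValue as C_OrtValue
    from skl2onnx.algebra.onnx_ops import OnnxAdd
    from mlprodict.onnxrt import OnnxInference
    from mlprodict.tools.onnx_inference_ort_helper import get_ort_device
    from mlprodict.tools import get_opset_number_from_onnx

    idi = numpy.identity(2, dtype=numpy.float32)
    onx = OnnxAdd('X', idi, output_names=['Y'],
                  op_version=get_opset_number_from_onnx())
    model_def = onx.to_onnx({'X': idi},
                            target_opset=get_opset_number_from_onnx())

    oinf = OnnxInference(model_def, runtime='onnxruntime1', device='cpu')

    X = numpy.array([[1, 1], [3, 3]], dtype=numpy.float32)
    ov = C_OrtValue.ortvalue_from_numpy(X, get_ort_device('cpu'))
    y = oinf.run({'X': ov})     # 'Y' is returned as an OrtValue
    print(y['Y'].numpy())       # [[2. 1.] [3. 4.]]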