This repository was archived by the owner on Jan 13, 2024. It is now read-only.
Merged
7 changes: 7 additions & 0 deletions _doc/sphinxdoc/source/api/onnxrt.rst
@@ -37,6 +37,13 @@ ONNX Structure

.. autosignature:: mlprodict.onnx_tools.onnx_manipulations.select_model_inputs_outputs

onnxruntime
+++++++++++

.. autosignature:: mlprodict.onnxrt.onnx_inference_ort.device_to_providers

.. autosignature:: mlprodict.onnxrt.onnx_inference_ort.get_ort_device

Validation
++++++++++

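A minimal sketch of how the two helpers documented above combine, importing from mlprodict.tools.onnx_inference_ort_helper, the module added later in this diff (the exact public path is an assumption; the autosignatures above reference a different one):

    # Sketch only; assumes the helper module introduced by this pull request.
    from mlprodict.tools.onnx_inference_ort_helper import (
        device_to_providers, get_ort_device)

    dev = get_ort_device('cuda:0')    # C_OrtDevice for the first GPU
    provs = device_to_providers(dev)  # ['CUDAExecutionProvider']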
134 changes: 134 additions & 0 deletions _unittests/ut_onnxrt/test_onnxrt_iobinding.py
@@ -0,0 +1,134 @@
"""
@brief test log(time=6s)
"""
import unittest
import numpy
from pyquickhelper.pycode import ExtTestCase, ignore_warnings
from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611,W0611
OrtDevice as C_OrtDevice, OrtValue as C_OrtValue)
from onnxruntime import get_device
from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611,W0611
OnnxAdd)
from mlprodict.onnxrt import OnnxInference
from mlprodict.tools.onnx_inference_ort_helper import get_ort_device
from mlprodict.tools import get_opset_number_from_onnx


DEVICE = "cuda" if get_device().upper() == 'GPU' else 'cpu'


class TestOnnxrtIOBinding(ExtTestCase):

@ignore_warnings(DeprecationWarning)
def test_onnxt_cpu_numpy_python(self):
idi = numpy.identity(2, dtype=numpy.float32)
idi2 = (numpy.identity(2) * 2).astype(numpy.float32)
onx = OnnxAdd(
OnnxAdd('X', idi, op_version=get_opset_number_from_onnx()),
idi2, output_names=['Y'],
op_version=get_opset_number_from_onnx())
model_def = onx.to_onnx({'X': idi.astype(numpy.float32)},
target_opset=get_opset_number_from_onnx())
oinf = OnnxInference(model_def)
X = numpy.array([[1, 1], [3, 3]])
y = oinf.run({'X': X.astype(numpy.float32)})
exp = numpy.array([[4, 1], [3, 6]], dtype=numpy.float32)
self.assertEqual(list(y), ['Y'])
self.assertEqualArray(y['Y'], exp)

@ignore_warnings(DeprecationWarning)
def test_onnxt_cpu_numpy_onnxruntime1(self):
idi = numpy.identity(2, dtype=numpy.float32)
idi2 = (numpy.identity(2) * 2).astype(numpy.float32)
onx = OnnxAdd(
OnnxAdd('X', idi, op_version=get_opset_number_from_onnx()),
idi2, output_names=['Y'],
op_version=get_opset_number_from_onnx())
model_def = onx.to_onnx({'X': idi.astype(numpy.float32)},
target_opset=get_opset_number_from_onnx())
oinf = OnnxInference(model_def, runtime="onnxruntime1")
X = numpy.array([[1, 1], [3, 3]])
y = oinf.run({'X': X.astype(numpy.float32)})
exp = numpy.array([[4, 1], [3, 6]], dtype=numpy.float32)
self.assertEqual(list(y), ['Y'])
self.assertEqualArray(y['Y'], exp)

@ignore_warnings(DeprecationWarning)
def test_onnxt_cpu_ortvalue_python(self):
idi = numpy.identity(2, dtype=numpy.float32)
idi2 = (numpy.identity(2) * 2).astype(numpy.float32)
onx = OnnxAdd(
OnnxAdd('X', idi, op_version=get_opset_number_from_onnx()),
idi2, output_names=['Y'],
op_version=get_opset_number_from_onnx())
model_def = onx.to_onnx({'X': idi.astype(numpy.float32)},
target_opset=get_opset_number_from_onnx())
oinf = OnnxInference(model_def)
X = numpy.array([[1, 1], [3, 3]])
X32 = X.astype(numpy.float32)
ov = C_OrtValue.ortvalue_from_numpy(X32, get_ort_device('cpu'))
self.assertRaise(lambda: oinf.run({'X': ov}), AttributeError)

@ignore_warnings(DeprecationWarning)
def test_onnxt_cpu_ortvalue_ort(self):
idi = numpy.identity(2, dtype=numpy.float32)
idi2 = (numpy.identity(2) * 2).astype(numpy.float32)
onx = OnnxAdd(
OnnxAdd('X', idi, op_version=get_opset_number_from_onnx()),
idi2, output_names=['Y'],
op_version=get_opset_number_from_onnx())
model_def = onx.to_onnx({'X': idi.astype(numpy.float32)},
target_opset=get_opset_number_from_onnx())
oinf = OnnxInference(model_def, runtime="onnxruntime1")
X = numpy.array([[1, 1], [3, 3]])
X32 = X.astype(numpy.float32)
ov = C_OrtValue.ortvalue_from_numpy(X32, get_ort_device('cpu'))
y = oinf.run({'X': ov})
exp = numpy.array([[4, 1], [3, 6]], dtype=numpy.float32)
self.assertEqual(list(y), ['Y'])
self.assertEqualArray(y['Y'].numpy(), exp)

@ignore_warnings(DeprecationWarning)
def test_onnxt_cpu_ortvalue_ort_cpu(self):
idi = numpy.identity(2, dtype=numpy.float32)
idi2 = (numpy.identity(2) * 2).astype(numpy.float32)
onx = OnnxAdd(
OnnxAdd('X', idi, op_version=get_opset_number_from_onnx()),
idi2, output_names=['Y'],
op_version=get_opset_number_from_onnx())
model_def = onx.to_onnx({'X': idi.astype(numpy.float32)},
target_opset=get_opset_number_from_onnx())
self.assertRaise(lambda: OnnxInference(model_def, device='cpu'),
ValueError)
oinf = OnnxInference(model_def, runtime="onnxruntime1", device='cpu')
X = numpy.array([[1, 1], [3, 3]])
X32 = X.astype(numpy.float32)
ov = C_OrtValue.ortvalue_from_numpy(X32, get_ort_device('cpu'))
y = oinf.run({'X': ov})
exp = numpy.array([[4, 1], [3, 6]], dtype=numpy.float32)
self.assertEqual(list(y), ['Y'])
self.assertEqualArray(y['Y'].numpy(), exp)

@unittest.skipIf(DEVICE != 'cuda', reason="runs only on GPU")
@ignore_warnings(DeprecationWarning)
def test_onnxt_ortvalue_ort_gpu(self):
idi = numpy.identity(2, dtype=numpy.float32)
idi2 = (numpy.identity(2) * 2).astype(numpy.float32)
onx = OnnxAdd(
OnnxAdd('X', idi, op_version=get_opset_number_from_onnx()),
idi2, output_names=['Y'],
op_version=get_opset_number_from_onnx())
model_def = onx.to_onnx({'X': idi.astype(numpy.float32)},
target_opset=get_opset_number_from_onnx())
oinf = OnnxInference(model_def, runtime="onnxruntime1", device='cuda')
X = numpy.array([[1, 1], [3, 3]])
X32 = X.astype(numpy.float32)
ov = C_OrtValue.ortvalue_from_numpy(X32, get_ort_device('cuda'))
y = oinf.run({'X': ov})
exp = numpy.array([[4, 1], [3, 6]], dtype=numpy.float32)
self.assertEqual(list(y), ['Y'])
self.assertEqualArray(y['Y'].cpu().numpy(), exp)


if __name__ == "__main__":
unittest.main()
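The pattern these tests exercise, as a standalone sketch (model_def built as in the tests above; a CPU-only onnxruntime build is assumed):

    # Hedged sketch, not part of the PR: run with a C_OrtValue input.
    oinf = OnnxInference(model_def, runtime='onnxruntime1', device='cpu')
    X32 = numpy.array([[1, 1], [3, 3]], dtype=numpy.float32)
    ov = C_OrtValue.ortvalue_from_numpy(X32, get_ort_device('cpu'))
    y = oinf.run({'X': ov})   # outputs come back as OrtValue
    print(y['Y'].numpy())     # [[4. 1.] [3. 6.]]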
19 changes: 15 additions & 4 deletions mlprodict/onnxrt/onnx_inference.py
@@ -72,6 +72,8 @@ class OnnxInference:
        be cut to have these new_outputs as the final outputs
    :param new_opset: overwrite the main opset and replaces
        by this new one
    :param device: device, a string `cpu`, `cuda`, `cuda:0`...,
        this option is only available with runtime *onnxruntime1*

    Among the possible runtime_options, there are:
    * *enable_profiling*: enables profiling for :epkg:`onnxruntime`
@@ -83,15 +85,16 @@ class OnnxInference:
        Parameters *new_outputs*, *new_opset* were added.

    .. versionchanged:: 0.8
        Parameter *static_inputs* was added.
        Parameters *static_inputs*, *device* were added.
    """

    def __init__(self, onnx_or_bytes_or_stream, runtime=None,
                 skip_run=False, inplace=True,
                 input_inplace=False, ir_version=None,
                 target_opset=None, runtime_options=None,
                 session_options=None, inside_loop=False,
                 static_inputs=None, new_outputs=None, new_opset=None):
                 static_inputs=None, new_outputs=None, new_opset=None,
                 device=None):
        if isinstance(onnx_or_bytes_or_stream, bytes):
            self.obj = load_model(BytesIO(onnx_or_bytes_or_stream))
        elif isinstance(onnx_or_bytes_or_stream, BytesIO):
@@ -113,6 +116,10 @@ def __init__(self, onnx_or_bytes_or_stream, runtime=None,
            self.obj, outputs=new_outputs, infer_shapes=True)
        if new_opset is not None:
            self.obj = overwrite_opset(self.obj, new_opset)
        if device is not None and runtime != 'onnxruntime1':
            raise ValueError(
                "Incompatible values, device can be specified with "
                "runtime 'onnxruntime1', not %r." % runtime)

        self.runtime = runtime
        self.skip_run = skip_run
@@ -122,6 +129,7 @@ def __init__(self, onnx_or_bytes_or_stream, runtime=None,
        self.runtime_options = runtime_options
        self.inside_loop = inside_loop
        self.static_inputs = static_inputs
        self.device = device
        self._init()

    def __getstate__(self):
@@ -136,7 +144,8 @@ def __getstate__(self):
                'inplace': self.inplace,
                'force_target_opset': self.force_target_opset,
                'static_inputs': self.static_inputs,
                'inside_loop': self.inside_loop}
                'inside_loop': self.inside_loop,
                'device': self.device}

    def __setstate__(self, state):
        """
@@ -152,6 +161,7 @@ def __setstate__(self, state):
        self.force_target_opset = state['force_target_opset']
        self.static_inputs = state['static_inputs']
        self.inside_loop = state['inside_loop']
        self.device = state['device']
        self._init()

    def _init(self):
@@ -190,7 +200,8 @@ def _init(self):
            del self.graph_
            from .ops_whole.session import OnnxWholeSession
            self._whole = OnnxWholeSession(
                self.obj, self.runtime, self.runtime_options)
                self.obj, self.runtime, self.runtime_options,
                self.device)
            self._run = self._run_whole_runtime
        else:
            self.sequence_ = self.graph_['sequence']
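The guard added above means device only combines with the onnxruntime1 runtime; a short sketch of the behavior (model_def assumed from the tests earlier):

    # Sketch of the new constraint, not part of the PR.
    try:
        OnnxInference(model_def, device='cpu')  # default python runtime
    except ValueError:
        pass  # "Incompatible values, device can be specified with ..."
    oinf = OnnxInference(model_def, runtime='onnxruntime1', device='cpu')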
19 changes: 11 additions & 8 deletions mlprodict/onnxrt/ops_whole/session.py
@@ -18,15 +18,17 @@ class OnnxWholeSession:
"""
Runs the prediction for a single :epkg:`ONNX`,
it lets the runtime handle the graph logic as well.

:param onnx_data: :epkg:`ONNX` model or data
:param runtime: runtime to be used, mostly :epkg:`onnxruntime`
:param runtime_options: runtime options
:param device: device, a string `cpu`, `cuda`, `cuda:0`...

.. versionchanged:: 0.8
Parameter *device* was added.
"""

def __init__(self, onnx_data, runtime, runtime_options=None):
"""
@param onnx_data :epkg:`ONNX` model or data
@param runtime runtime to be used,
mostly :epkg:`onnxruntime`
@param runtime_options runtime options
"""
def __init__(self, onnx_data, runtime, runtime_options=None, device=None):
if runtime != 'onnxruntime1':
raise NotImplementedError( # pragma: no cover
"runtime '{}' is not implemented.".format(runtime))
@@ -68,7 +70,8 @@ def __init__(self, onnx_data, runtime, runtime_options=None):
"session_options and log_severity_level cannot be defined at the "
"same time.")
try:
self.sess = InferenceSession(onnx_data, sess_options=sess_options)
self.sess = InferenceSession(onnx_data, sess_options=sess_options,
device=device)
except (OrtFail, OrtNotImplemented, OrtInvalidGraph,
OrtInvalidArgument, OrtRuntimeException, RuntimeError) as e:
raise RuntimeError(
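Used directly (OnnxInference normally builds it), the class above is a thin wrapper around InferenceSession; a sketch assuming model_def from the tests earlier:

    # .sess is the InferenceSession created in the constructor above.
    import numpy
    from mlprodict.onnxrt.ops_whole.session import OnnxWholeSession

    whole = OnnxWholeSession(model_def.SerializeToString(), 'onnxruntime1')
    X32 = numpy.array([[1, 1], [3, 3]], dtype=numpy.float32)
    out = whole.sess.run(None, {'X': X32})  # plain onnxruntime call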
64 changes: 64 additions & 0 deletions mlprodict/tools/onnx_inference_ort_helper.py
@@ -0,0 +1,64 @@
# pylint: disable=C0302
"""
@file
@brief Helpers for :epkg:`onnxruntime`.
"""
from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611,W0611
    OrtDevice as C_OrtDevice)


def get_ort_device(device):
    """
    Converts device into :epkg:`C_OrtDevice`.

    :param device: any type
    :return: :epkg:`C_OrtDevice`

    Example:

    ::

        get_ort_device('cpu')
        get_ort_device('gpu')
        get_ort_device('cuda')
        get_ort_device('cuda:0')
    """
    if isinstance(device, C_OrtDevice):
        return device
    if isinstance(device, str):
        if device == 'cpu':
            return C_OrtDevice(
                C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0)
        if device in {'gpu', 'cuda:0', 'cuda', 'gpu:0'}:
            return C_OrtDevice(
                C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0)
        if device.startswith('gpu:'):
            idx = int(device[4:])
            return C_OrtDevice(
                C_OrtDevice.cuda(), C_OrtDevice.default_memory(), idx)
        if device.startswith('cuda:'):
            idx = int(device[5:])
            return C_OrtDevice(
                C_OrtDevice.cuda(), C_OrtDevice.default_memory(), idx)
        raise ValueError(
            "Unable to interpret string %r as a device." % device)
    raise TypeError(
        "Unable to interpret type %r (%r) as a device." % (
            type(device), device))


def device_to_providers(device):
    """
    Returns the corresponding providers for a specific device.

    :param device: :epkg:`C_OrtDevice`
    :return: providers
    """
    if isinstance(device, str):
        device = get_ort_device(device)
    if device.device_type() == device.cpu():
        return ['CPUExecutionProvider']
    if device.device_type() == device.cuda():
        return ['CUDAExecutionProvider']
    raise ValueError(  # pragma: no cover
        "Unexpected device %r." % device)