Skip to content
This repository was archived by the owner on Jan 13, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions _unittests/ut_onnxrt/test_nb_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import unittest
from logging import getLogger
import numpy
from pyquickhelper.pycode import ExtTestCase
from pyquickhelper.pycode import ExtTestCase, ignore_warnings
from skl2onnx.algebra.onnx_ops import OnnxAdd # pylint: disable=E0611
from mlprodict.onnxrt.doc.nb_helper import OnnxNotebook
from mlprodict.tools import get_opset_number_from_onnx
Expand All @@ -16,8 +16,8 @@ def setUp(self):
logger = getLogger('skl2onnx')
logger.disabled = True

@ignore_warnings(DeprecationWarning)
def test_onnxview(self):

idi = numpy.identity(2)
onx = OnnxAdd('X', idi, output_names=['Y'],
op_version=get_opset_number_from_onnx())
Expand Down Expand Up @@ -47,6 +47,21 @@ def test_onnxview(self):
self.assertNotEmpty(res)
self.assertIn('RenderJsDot', str(res))

@ignore_warnings(DeprecationWarning)
def test_onnxview_empty(self):
idi = numpy.identity(2)
onx = OnnxAdd('X', idi, output_names=['Y'],
op_version=get_opset_number_from_onnx())
model_def = onx.to_onnx({'X': idi.astype(numpy.float32)})

mg = OnnxNotebook()
mg.add_context(
{"model": model_def})
cmd = "model --runtime=empty"
res = mg.onnxview(cmd)
self.assertNotEmpty(res)
self.assertIn('RenderJsDot', str(res))


if __name__ == "__main__":
unittest.main()
65 changes: 65 additions & 0 deletions _unittests/ut_onnxrt/test_onnxrt_runtime_empty.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
@brief test log(time=2s)
"""
import unittest
from logging import getLogger
import numpy
from onnx import helper, TensorProto
from pyquickhelper.pycode import ExtTestCase, ignore_warnings
from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611
OnnxAdd)
from mlprodict.onnxrt import OnnxInference
from mlprodict.tools.asv_options_helper import (
get_ir_version_from_onnx, get_opset_number_from_onnx)


class TestOnnxrtRuntimeEmpty(ExtTestCase):

def setUp(self):
logger = getLogger('skl2onnx')
logger.disabled = True

@ignore_warnings(DeprecationWarning)
def test_onnxt_runtime_empty(self):
idi = numpy.identity(2, dtype=numpy.float32)
onx = OnnxAdd('X', idi, output_names=['Y'],
op_version=get_opset_number_from_onnx())
model_def = onx.to_onnx({'X': idi.astype(numpy.float32)})
model_def.ir_version = get_ir_version_from_onnx()
oinf = OnnxInference(model_def, runtime='empty')
self.assertNotEmpty(oinf)

@ignore_warnings(DeprecationWarning)
def test_onnxt_runtime_empty_dot(self):
idi = numpy.identity(2, dtype=numpy.float32)
onx = OnnxAdd('X', idi, output_names=['Y'],
op_version=get_opset_number_from_onnx())
model_def = onx.to_onnx({'X': idi.astype(numpy.float32)})
model_def.ir_version = get_ir_version_from_onnx()
oinf = OnnxInference(model_def, runtime='empty')
self.assertNotEmpty(oinf)
dot = oinf.to_dot()
self.assertIn("-> Y;", dot)

@ignore_warnings(DeprecationWarning)
def test_onnxt_runtime_empty_unknown(self):
X = helper.make_tensor_value_info(
'X', TensorProto.FLOAT, [None, 2]) # pylint: disable=E1101
Y = helper.make_tensor_value_info(
'Y', TensorProto.FLOAT, [None, 2]) # pylint: disable=E1101
Z = helper.make_tensor_value_info(
'Z', TensorProto.FLOAT, [None, 2]) # pylint: disable=E1101
node_def = helper.make_node('Add', ['X', 'Y'], ['Zt'], name='Zt')
node_def2 = helper.make_node('AddUnknown', ['X', 'Zt'], ['Z'], name='Z')
graph_def = helper.make_graph(
[node_def, node_def2], 'test-model', [X, Y], [Z])
model_def = helper.make_model(graph_def, producer_name='onnx-example')
oinf = OnnxInference(model_def, runtime='empty')
self.assertNotEmpty(oinf)
dot = oinf.to_dot()
self.assertIn('AddUnknown', dot)
self.assertNotIn('x{', dot)


if __name__ == "__main__":
unittest.main()
10 changes: 8 additions & 2 deletions mlprodict/onnxrt/doc/nb_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from ..onnx_inference import OnnxInference


def onnxview(graph, recursive=False, local=False, add_rt_shapes=False):
def onnxview(graph, recursive=False, local=False, add_rt_shapes=False,
runtime='python'):
"""
Displays an :epkg:`ONNX` graph into a notebook.

Expand All @@ -20,8 +21,13 @@ def onnxview(graph, recursive=False, local=False, add_rt_shapes=False):
:param add_rt_shapes: add information about the shapes
the runtime was able to find out,
the runtime has to be `'python'`
:param runtime: the view fails if a runtime does not implement a specific
node unless *runtime* is `'empty'`

.. versionchanged:: 0.6
Parameter *runtime* was added.
"""
sess = OnnxInference(graph, skip_run=not add_rt_shapes)
sess = OnnxInference(graph, skip_run=not add_rt_shapes, runtime=runtime)
dot = sess.to_dot(recursive=recursive, add_rt_shapes=add_rt_shapes)
return RenderJsDot(dot, local=local)

Expand Down
4 changes: 2 additions & 2 deletions mlprodict/onnxrt/onnx_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _init(self):
for node in self.sequence_:
domain = node.onnx_node.domain
target_opset = self.target_opset_.get(domain, None)
if self.runtime == 'onnxruntime2':
if self.runtime in ('onnxruntime2', 'empty'):
node.setup_runtime(self.runtime, variables, self.__class__,
target_opset=target_opset, dtype=dtype,
domain=domain, ir_version=self.ir_version_,
Expand Down Expand Up @@ -414,7 +414,7 @@ def to_sequence(self):
k, names[k, 0][0]))
names[k, 0] = ('I', v)
for k, v in outputs.items():
if (k, 0) in names:
if (k, 0) in names and self.runtime != 'empty':
raise RuntimeError( # pragma: no cover
"Output '{}' already exists (tag='{}').".format(
k, names[k, 0][0]))
Expand Down
61 changes: 44 additions & 17 deletions mlprodict/onnxrt/onnx_inference_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
import os
import json
import re
from io import BytesIO
import pickle
import textwrap
Expand All @@ -23,8 +24,8 @@ def __init__(self, oinf):
"""
self.oinf = oinf

def to_dot(self, recursive=False, prefix='', add_rt_shapes=False,
use_onnx=False, **params):
def to_dot(self, recursive=False, prefix='', # pylint: disable=R0914
add_rt_shapes=False, use_onnx=False, **params):
"""
Produces a :epkg:`DOT` language string for the graph.

Expand Down Expand Up @@ -78,6 +79,19 @@ def to_dot(self, recursive=False, prefix='', add_rt_shapes=False,
See an example of representation in notebook
:ref:`onnxvisualizationrst`.
"""
clean_label_reg1 = re.compile("\\\\x\\{[0-9A-F]{1,6}\\}")
clean_label_reg2 = re.compile("\\\\p\\{[0-9P]{1,6}\\}")

def dot_name(text):
return text.replace("/", "_").replace(":", "__")

def dot_label(text):
for reg in [clean_label_reg1, clean_label_reg2]:
fall = reg.findall(text)
for f in fall:
text = text.replace(f, "_")
return text

options = {
'orientation': 'portrait',
'ranksep': '0.25',
Expand Down Expand Up @@ -123,8 +137,10 @@ def to_dot(self, recursive=False, prefix='', add_rt_shapes=False,
sh = shapes.get(dobj['name'], '')
if sh:
sh = "\\nshape={}".format(sh)
exp.append(' {3}{0} [shape=box color=red label="{0}\\n{1}{4}" fontsize={2}];'.format(
dobj['name'], _type_to_string(dobj['type']), fontsize, prefix, sh))
exp.append(
' {3}{0} [shape=box color=red label="{0}\\n{1}{4}" fontsize={2}];'.format(
dot_name(dobj['name']), _type_to_string(dobj['type']),
fontsize, prefix, dot_label(sh)))
inter_vars[obj.name] = obj

# outputs
Expand All @@ -134,8 +150,10 @@ def to_dot(self, recursive=False, prefix='', add_rt_shapes=False,
sh = shapes.get(dobj['name'], '')
if sh:
sh = "\\nshape={}".format(sh)
exp.append(' {3}{0} [shape=box color=green label="{0}\\n{1}{4}" fontsize={2}];'.format(
dobj['name'], _type_to_string(dobj['type']), fontsize, prefix, sh))
exp.append(
' {3}{0} [shape=box color=green label="{0}\\n{1}{4}" fontsize={2}];'.format(
dot_name(dobj['name']), _type_to_string(dobj['type']),
fontsize, prefix, dot_label(sh)))
inter_vars[obj.name] = obj

# initializer
Expand All @@ -152,9 +170,10 @@ def to_dot(self, recursive=False, prefix='', add_rt_shapes=False,
st = st[:50] + '...'
st = st.replace('\n', '\\n')
kind = ""
exp.append(' {6}{0} [shape=box label="{0}\\n{4}{1}({2})\\n{3}" fontsize={5}];'.format(
dobj['name'], dobj['value'].dtype,
dobj['value'].shape, st, kind, fontsize, prefix))
exp.append(
' {6}{0} [shape=box label="{0}\\n{4}{1}({2})\\n{3}" fontsize={5}];'.format(
dot_name(dobj['name']), dobj['value'].dtype,
dobj['value'].shape, dot_label(st), kind, fontsize, prefix))
inter_vars[obj.name] = obj

# nodes
Expand All @@ -169,7 +188,7 @@ def to_dot(self, recursive=False, prefix='', add_rt_shapes=False,
sh = "\\nshape={}".format(sh)
exp.append(
' {2}{0} [shape=box label="{0}{3}" fontsize={1}];'.format(
out, fontsize, prefix, sh))
dot_name(out), fontsize, dot_name(prefix), dot_label(sh)))

dobj = _var_as_dict(node)
if dobj['name'].strip() == '': # pragma: no cover
Expand Down Expand Up @@ -221,29 +240,36 @@ def to_dot(self, recursive=False, prefix='', add_rt_shapes=False,
exp.append(" subgraph cluster_{}{} {{".format(
node.op_type, id(node)))
exp.append(' label="{0}\\n({1}){2}";'.format(
dobj['op_type'], dobj['name'], satts))
dobj['op_type'], dot_name(dobj['name']), satts))
exp.append(' fontsize={0};'.format(fontsize))
exp.append(' color=black;')
exp.append(
'\n'.join(map(lambda s: ' ' + s, subgraph.split('\n'))))

for inp1, inp2 in zip(node.input, body.input):
exp.append(
" {0}{1} -> {2}{3};".format(prefix, inp1, subprefix, inp2.name))
" {0}{1} -> {2}{3};".format(
dot_name(prefix), dot_name(inp1),
dot_name(subprefix), dot_name(inp2.name)))
for out1, out2 in zip(body.output, node.output):
exp.append(
" {0}{1} -> {2}{3};".format(subprefix, out1.name, prefix, out2))
" {0}{1} -> {2}{3};".format(
dot_name(subprefix), dot_name(out1.name),
dot_name(prefix), dot_name(out2)))

else:
exp.append(' {4}{1} [shape=box style="filled,rounded" color=orange label="{0}\\n({1}){2}" fontsize={3}];'.format(
dobj['op_type'], dobj['name'], satts, fontsize, prefix))
dobj['op_type'], dot_name(dobj['name']), satts, fontsize,
dot_name(prefix)))

for inp in node.input:
exp.append(
" {0}{1} -> {0}{2};".format(prefix, inp, node.name))
" {0}{1} -> {0}{2};".format(
dot_name(prefix), dot_name(inp), dot_name(node.name)))
for out in node.output:
exp.append(
" {0}{1} -> {0}{2};".format(prefix, node.name, out))
" {0}{1} -> {0}{2};".format(
dot_name(prefix), dot_name(node.name), dot_name(out)))

exp.append('}')
return "\n".join(exp)
Expand Down Expand Up @@ -428,7 +454,8 @@ def clean_args(args):

if self.oinf.runtime != 'python':
raise ValueError(
"The runtime must be python not '{}'.".format(self.oinf.runtime))
"The runtime must be 'python' not '{}'.".format(
self.oinf.runtime))

# metadata
obj = {}
Expand Down
3 changes: 3 additions & 0 deletions mlprodict/onnxrt/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def load_op(onnx_node, desc=None, options=None, variables=None, dtype=None):
if provider == 'python':
from .ops_cpu import load_op as lo
return lo(onnx_node, desc=desc, options=options)
if provider == 'empty':
from .ops_empty import load_op as lo
return lo(onnx_node, desc=desc, options=options)
if provider == 'onnxruntime2':
from .ops_onnxruntime import load_op as lo
return lo(onnx_node, desc=desc, options=options, # pylint: disable=E1123
Expand Down
25 changes: 25 additions & 0 deletions mlprodict/onnxrt/ops_empty/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# -*- encoding: utf-8 -*-
"""
@file
@brief Shortcut to *ops_cpu*.
"""
from ._op import OpRunOnnxEmpty


def load_op(onnx_node, desc=None, options=None, variables=None, dtype=None):
"""
Gets the operator related to the *onnx* node.
This runtime does nothing and never complains.

:param onnx_node: :epkg:`onnx` node
:param desc: internal representation
:param options: runtime options
:param variables: registered variables created by previous operators
:param dtype: float computation type
:return: runtime class
"""
if desc is None:
raise ValueError( # pragma: no cover
"desc should not be None.")
return OpRunOnnxEmpty(onnx_node, desc, variables=variables,
dtype=dtype, **options)
Loading