Skip to content
This repository was archived by the owner on Jan 13, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions _unittests/ut_onnxrt/test_onnx_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""
@brief test log(time=2s)
"""
import unittest
from logging import getLogger
import numpy
from onnx import helper, TensorProto
from pyquickhelper.pycode import ExtTestCase, ignore_warnings
from mlprodict.onnxrt import OnnxInference
from mlprodict.onnxrt.onnx_tools import insert_node
from mlprodict.onnxrt.ops_cpu._op import RuntimeTypeError


class TestOnnxTools(ExtTestCase):

def setUp(self):
logger = getLogger('skl2onnx')
logger.disabled = True

@ignore_warnings(DeprecationWarning)
def test_onnx_inference_name_confusion(self):
X = helper.make_tensor_value_info(
'X', TensorProto.FLOAT, [None, 2]) # pylint: disable=E1101
Y = helper.make_tensor_value_info(
'Y', TensorProto.FLOAT, [None, 2]) # pylint: disable=E1101
Z = helper.make_tensor_value_info(
'Z', TensorProto.FLOAT, [None, 2]) # pylint: disable=E1101
node_def = helper.make_node('Add', ['X', 'Y'], ['Zt'], name='Zt')
node_def2 = helper.make_node('Add', ['X', 'Zt'], ['Z'], name='Z')
graph_def = helper.make_graph(
[node_def, node_def2], 'test-model', [X, Y], [Z])
model_def = helper.make_model(graph_def, producer_name='onnx-example')
model_def = insert_node(
model_def, node='Z', op_type='Cast', to=TensorProto.INT64, # pylint: disable=E1101
name='castop')
self.assertIn('castop', str(model_def))

oinf = OnnxInference(model_def)
X = (numpy.random.randn(4, 2) * 100000).astype( # pylint: disable=E1101
numpy.float32)
Y = (numpy.random.randn(4, 2) * 100000).astype( # pylint: disable=E1101
numpy.float32)
exp = (X * 2 + Y).astype(numpy.float32)
self.assertRaise(lambda: oinf.run({'X': X, 'Y': Y}), RuntimeTypeError)

model_def = insert_node(
model_def, node='Z', op_type='Cast', to=TensorProto.FLOAT, # pylint: disable=E1101
name='castop2')
oinf = OnnxInference(model_def)
res = oinf.run({'X': X, 'Y': Y})
got = res['Z']
self.assertEqualArray(exp / 100000, got / 100000, decimal=5)


if __name__ == "__main__":
unittest.main()
104 changes: 104 additions & 0 deletions mlprodict/onnxrt/onnx_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""
@file
@brief Functions to manipulate ONNX file.
"""
from onnx import helper


def find_node_name(model, name):
"""
Finds a node by its name.
:param model: onnx graph
:param name: node name
:return: node pointer
"""
if not hasattr(model, "graph"):
raise TypeError( # pragma: no cover
"Parameter model is not an ONNX model but "
"{}".format(type(model)))
for node in model.graph.node:
if node.name == name:
return node
return None


def find_node_input_name(node, name):
"""
Finds a node input by its name.
:param node: onnx node
:param name: node name
:return: input index
"""
for i, inode in enumerate(node.input.node):
if inode.name == name:
return i
return -1


def insert_node(model, op_type, node, input_index=0, new_name=None, **attrs):
"""
Inserts a node before one node input.
:param model: onnx graph
:param op_type:
:param node: node or node name
:param input_index: input index or input name
:param attrs: node attributes
:return: updated graph
"""
if isinstance(node, str):
inode = find_node_name(model, node)
else:
inode = node
if isinstance(input_index, str):
input_index_ = find_node_input_name(node, input_index)
if input_index_ == -1:
raise RuntimeError(
"Unable to find input_index %r in node %r." % (
input_index, node.name)) # pylint: disable=E1120
input_index = input_index_

# guess a new name
names = []
for n in model.graph.node:
names.extend(n.input)
names.extend(n.output)
names = set(names)
if new_name is None:
new_name = op_type.lower()
root_name = new_name
i = 0
while new_name in names:
new_name = "%s_%d" % (root_name, i)
i += 1

new_node = helper.make_node(
op_type, [inode.input[input_index]], [new_name], **attrs)
inode.input[input_index] = new_name
keep_nodes = list(model.graph.node)
keep_nodes.append(new_node)

graph = helper.make_graph(
keep_nodes, model.graph.name, model.graph.input,
model.graph.output, model.graph.initializer)
onnx_model = helper.make_model(graph)
onnx_model.ir_version = model.ir_version
onnx_model.producer_name = model.producer_name
onnx_model.producer_version = model.producer_version
onnx_model.domain = model.domain
onnx_model.model_version = model.model_version
onnx_model.doc_string = model.doc_string
if len(model.metadata_props) > 0:
values = {p.key: p.value for p in model.metadata_props}
helper.set_model_props(onnx_model, values)

del onnx_model.opset_import[:] # pylint: disable=E1101
for oimp in model.opset_import:
op_set = onnx_model.opset_import.add() # pylint: disable=E1101
op_set.domain = oimp.domain
op_set.version = oimp.version

if len(onnx_model.graph.input) != len(model.graph.input): # pylint: disable=E1101
raise RuntimeError( # pragma: no cover
"Input mismatch {} != {}".format(
len(onnx_model.input), len(model.input))) # pylint: disable=E1101
return onnx_model