diff --git a/_unittests/ut_onnxrt/test_onnx_tools.py b/_unittests/ut_onnxrt/test_onnx_tools.py new file mode 100644 index 000000000..ab398679c --- /dev/null +++ b/_unittests/ut_onnxrt/test_onnx_tools.py @@ -0,0 +1,56 @@ +""" +@brief test log(time=2s) +""" +import unittest +from logging import getLogger +import numpy +from onnx import helper, TensorProto +from pyquickhelper.pycode import ExtTestCase, ignore_warnings +from mlprodict.onnxrt import OnnxInference +from mlprodict.onnxrt.onnx_tools import insert_node +from mlprodict.onnxrt.ops_cpu._op import RuntimeTypeError + + +class TestOnnxTools(ExtTestCase): + + def setUp(self): + logger = getLogger('skl2onnx') + logger.disabled = True + + @ignore_warnings(DeprecationWarning) + def test_onnx_inference_name_confusion(self): + X = helper.make_tensor_value_info( + 'X', TensorProto.FLOAT, [None, 2]) # pylint: disable=E1101 + Y = helper.make_tensor_value_info( + 'Y', TensorProto.FLOAT, [None, 2]) # pylint: disable=E1101 + Z = helper.make_tensor_value_info( + 'Z', TensorProto.FLOAT, [None, 2]) # pylint: disable=E1101 + node_def = helper.make_node('Add', ['X', 'Y'], ['Zt'], name='Zt') + node_def2 = helper.make_node('Add', ['X', 'Zt'], ['Z'], name='Z') + graph_def = helper.make_graph( + [node_def, node_def2], 'test-model', [X, Y], [Z]) + model_def = helper.make_model(graph_def, producer_name='onnx-example') + model_def = insert_node( + model_def, node='Z', op_type='Cast', to=TensorProto.INT64, # pylint: disable=E1101 + name='castop') + self.assertIn('castop', str(model_def)) + + oinf = OnnxInference(model_def) + X = (numpy.random.randn(4, 2) * 100000).astype( # pylint: disable=E1101 + numpy.float32) + Y = (numpy.random.randn(4, 2) * 100000).astype( # pylint: disable=E1101 + numpy.float32) + exp = (X * 2 + Y).astype(numpy.float32) + self.assertRaise(lambda: oinf.run({'X': X, 'Y': Y}), RuntimeTypeError) + + model_def = insert_node( + model_def, node='Z', op_type='Cast', to=TensorProto.FLOAT, # pylint: disable=E1101 + name='castop2') + oinf = OnnxInference(model_def) + res = oinf.run({'X': X, 'Y': Y}) + got = res['Z'] + self.assertEqualArray(exp / 100000, got / 100000, decimal=5) + + +if __name__ == "__main__": + unittest.main() diff --git a/mlprodict/onnxrt/onnx_tools.py b/mlprodict/onnxrt/onnx_tools.py new file mode 100644 index 000000000..2a7c222f2 --- /dev/null +++ b/mlprodict/onnxrt/onnx_tools.py @@ -0,0 +1,104 @@ +""" +@file +@brief Functions to manipulate ONNX file. +""" +from onnx import helper + + +def find_node_name(model, name): + """ + Finds a node by its name. + :param model: onnx graph + :param name: node name + :return: node pointer + """ + if not hasattr(model, "graph"): + raise TypeError( # pragma: no cover + "Parameter model is not an ONNX model but " + "{}".format(type(model))) + for node in model.graph.node: + if node.name == name: + return node + return None + + +def find_node_input_name(node, name): + """ + Finds a node input by its name. + :param node: onnx node + :param name: node name + :return: input index + """ + for i, inode in enumerate(node.input.node): + if inode.name == name: + return i + return -1 + + +def insert_node(model, op_type, node, input_index=0, new_name=None, **attrs): + """ + Inserts a node before one node input. + :param model: onnx graph + :param op_type: + :param node: node or node name + :param input_index: input index or input name + :param attrs: node attributes + :return: updated graph + """ + if isinstance(node, str): + inode = find_node_name(model, node) + else: + inode = node + if isinstance(input_index, str): + input_index_ = find_node_input_name(node, input_index) + if input_index_ == -1: + raise RuntimeError( + "Unable to find input_index %r in node %r." % ( + input_index, node.name)) # pylint: disable=E1120 + input_index = input_index_ + + # guess a new name + names = [] + for n in model.graph.node: + names.extend(n.input) + names.extend(n.output) + names = set(names) + if new_name is None: + new_name = op_type.lower() + root_name = new_name + i = 0 + while new_name in names: + new_name = "%s_%d" % (root_name, i) + i += 1 + + new_node = helper.make_node( + op_type, [inode.input[input_index]], [new_name], **attrs) + inode.input[input_index] = new_name + keep_nodes = list(model.graph.node) + keep_nodes.append(new_node) + + graph = helper.make_graph( + keep_nodes, model.graph.name, model.graph.input, + model.graph.output, model.graph.initializer) + onnx_model = helper.make_model(graph) + onnx_model.ir_version = model.ir_version + onnx_model.producer_name = model.producer_name + onnx_model.producer_version = model.producer_version + onnx_model.domain = model.domain + onnx_model.model_version = model.model_version + onnx_model.doc_string = model.doc_string + if len(model.metadata_props) > 0: + values = {p.key: p.value for p in model.metadata_props} + helper.set_model_props(onnx_model, values) + + del onnx_model.opset_import[:] # pylint: disable=E1101 + for oimp in model.opset_import: + op_set = onnx_model.opset_import.add() # pylint: disable=E1101 + op_set.domain = oimp.domain + op_set.version = oimp.version + + if len(onnx_model.graph.input) != len(model.graph.input): # pylint: disable=E1101 + raise RuntimeError( # pragma: no cover + "Input mismatch {} != {}".format( + len(onnx_model.input), len(model.input))) # pylint: disable=E1101 + return onnx_model