From 3f4c1bfd8d32c9d3a71a6d13b851c257b8b4c097 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Mon, 13 Sep 2021 16:31:17 +0200 Subject: [PATCH] Add function to rename all results in ONNX graphs --- _doc/sphinxdoc/source/api/tools.rst | 8 +- .../test_onnx_conv_graph_optimisation.py | 66 +++++++ .../ut_tools/test_onnx_manipulations.py | 138 +++++++++++++- mlprodict/onnx_conv/convert.py | 12 +- mlprodict/onnx_tools/onnx_manipulations.py | 179 ++++++++++++++++++ .../optim/onnx_optimisation_identity.py | 3 +- .../optim/onnx_optimisation_redundant.py | 3 +- 7 files changed, 398 insertions(+), 11 deletions(-) create mode 100644 _unittests/ut_onnx_conv/test_onnx_conv_graph_optimisation.py diff --git a/_doc/sphinxdoc/source/api/tools.rst b/_doc/sphinxdoc/source/api/tools.rst index e2fb4c1ac..578b46973 100644 --- a/_doc/sphinxdoc/source/api/tools.rst +++ b/_doc/sphinxdoc/source/api/tools.rst @@ -40,9 +40,7 @@ Functions to help understand models or modify them. .. autosignature:: mlprodict.onnx_tools.model_checker.onnx_shaker -.. autosignature:: mlprodict.onnx_tools.optimisation._main_onnx_optim.onnx_optimisations - -.. autosignature:: mlprodict.onnx_tools.optim.onnx_statistics +.. autosignature:: mlprodict.onnx_tools.optim.onnx_helper.onnx_statistics .. autosignature:: mlprodict.onnx_tools.onnx_manipulations.select_model_inputs_outputs @@ -59,8 +57,12 @@ is left unchanged. .. autosignature:: mlprodict.onnx_tools.onnx_tools.ensure_topological_order +.. autosignature:: mlprodict.onnx_tools.onnx_manipulations.onnx_rename_names + .. autosignature:: mlprodict.onnx_tools.optim.onnx_optimisation.onnx_remove_node +.. autosignature:: mlprodict.onnx_tools.optimisation._main_onnx_optim.onnx_optimisations + .. autosignature:: mlprodict.onnx_tools.optim.onnx_optimisation_identity.onnx_remove_node_identity .. autosignature:: mlprodict.onnx_tools.optim.onnx_optimisation_redundant.onnx_remove_node_redundant diff --git a/_unittests/ut_onnx_conv/test_onnx_conv_graph_optimisation.py b/_unittests/ut_onnx_conv/test_onnx_conv_graph_optimisation.py new file mode 100644 index 000000000..03950bdd2 --- /dev/null +++ b/_unittests/ut_onnx_conv/test_onnx_conv_graph_optimisation.py @@ -0,0 +1,66 @@ +""" +@brief test log(time=3s) +""" +from collections import OrderedDict +import unittest +import numpy +from pyquickhelper.pycode import ExtTestCase, ignore_warnings +from sklearn.datasets import load_iris +from sklearn.neighbors import KNeighborsRegressor +from sklearn.metrics import make_scorer +from mlprodict.onnx_conv import to_onnx +from mlprodict.onnxrt import OnnxInference +from mlprodict.tools.asv_options_helper import ( + get_opset_number_from_onnx) +from mlprodict.onnx_conv.scorers.cdist_score import score_cdist_sum + + +class TestOnnxConvGraphOptimisation(ExtTestCase): + + def test_to_onnx_rename_names(self): + data = load_iris() + X, y = data.data, data.target + model = KNeighborsRegressor(n_neighbors=2).fit(X, y) + + model_onnx = to_onnx( + model, X[:1], target_opset=get_opset_number_from_onnx()) + oinf1 = OnnxInference(model_onnx) + y1 = oinf1.run({'X': X})['variable'] + + model_onnx = to_onnx( + model, X[:1], target_opset=get_opset_number_from_onnx(), + rename_strategy='simple') + oinf1 = OnnxInference(model_onnx) + y2 = oinf1.run({'X': X})['variable'] + self.assertEqualArray(y1, y2) + + @ignore_warnings((DeprecationWarning, UserWarning)) + def test_to_onnx_rename_names_scorer(self): + X = numpy.array([[0, 1, 0, 2], + [1, 0, 4, 5], + [9, 8, 5, 6]], dtype=numpy.float64) + Y = X[:2].copy() + Y[0, :] = 0 + + init_types = OrderedDict([('X', X), ('Y', Y)]) + opset = get_opset_number_from_onnx() + scorer = make_scorer( + score_cdist_sum, metric='sqeuclidean', + greater_is_better=False) + + monx1 = to_onnx(scorer, init_types, target_opset=opset, + rewrite_ops=True) + monx2 = to_onnx(scorer, init_types, target_opset=opset, + rewrite_ops=True, rename_strategy='simple') + + oinf1 = OnnxInference(monx1) + oinf2 = OnnxInference(monx2) + res0 = score_cdist_sum(X, Y, metric='sqeuclidean') + res1 = oinf1.run({'X': X, 'Y': Y})['scores'] + res2 = oinf2.run({'X': X, 'Y': Y})['scores'] + self.assertEqualArray(res1, res0, decimal=5) + self.assertEqualArray(res2, res0, decimal=5) + + +if __name__ == "__main__": + unittest.main() diff --git a/_unittests/ut_tools/test_onnx_manipulations.py b/_unittests/ut_tools/test_onnx_manipulations.py index e30c5c9b2..b96bfe0d0 100644 --- a/_unittests/ut_tools/test_onnx_manipulations.py +++ b/_unittests/ut_tools/test_onnx_manipulations.py @@ -2,19 +2,23 @@ @brief test log(time=2s) """ import unittest +from collections import OrderedDict import numpy from pyquickhelper.pycode import ExtTestCase from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611 - OnnxAdd, OnnxMul, OnnxSub) + OnnxAdd, OnnxMul, OnnxSub, OnnxIdentity, OnnxScan, + OnnxReduceSumSquare, OnnxSqueezeApi11) +from skl2onnx.common.data_types import FloatTensorType from mlprodict.onnx_tools.optim.onnx_helper import onnx_statistics from mlprodict.onnxrt import OnnxInference from mlprodict.onnx_tools.optim import onnx_remove_node_unused from mlprodict.onnx_tools.onnx_manipulations import ( - select_model_inputs_outputs, enumerate_model_node_outputs) + select_model_inputs_outputs, enumerate_model_node_outputs, + onnx_rename_names) from mlprodict.tools import get_opset_number_from_onnx -class TestOptimOnnxUnused(ExtTestCase): +class TestOptimOnnxManipulations(ExtTestCase): def test_onnx_remove_unused_outputs(self): dtype = numpy.float32 @@ -202,6 +206,134 @@ def test_enumerate_model_node_outputs(self): expected = ['Ad_Addcst2', 'Ad_C0', 'inter', 'Ad_C02', 'Mu_C0', 'final'] self.assertEqual(nodes2, expected) + def test_onnx_rename_names_exc(self): + dtype = numpy.float32 + x = numpy.array([1, 2, 4, 5, 5, 4]).astype( + numpy.float32).reshape((3, 2)) + cop = OnnxAdd('X', numpy.array([1], dtype=dtype), + op_version=get_opset_number_from_onnx()) + cop2 = OnnxAdd('X', numpy.array([1], dtype=dtype), + op_version=get_opset_number_from_onnx()) + cop3 = OnnxAdd('X', numpy.array([2], dtype=dtype), + op_version=get_opset_number_from_onnx(), + output_names=['inter']) + cop4 = OnnxSub( + OnnxMul(cop, cop3, op_version=get_opset_number_from_onnx()), + cop2, output_names=['final'], + op_version=get_opset_number_from_onnx()) + model_def = cop4.to_onnx({'X': x}) + self.assertRaise( + lambda: onnx_rename_names(model_def, strategy="none"), + ValueError) + + def test_onnx_rename_names_simple(self): + rows = [] + + def flog(*s): + rows.append(" ".join(map(str, s))) + + dtype = numpy.float32 + x = numpy.array([1, 2, 4, 5, 5, 4]).astype( + numpy.float32).reshape((3, 2)) + cop = OnnxAdd('X', numpy.array([1], dtype=dtype), + op_version=get_opset_number_from_onnx()) + cop2 = OnnxAdd('X', numpy.array([1], dtype=dtype), + op_version=get_opset_number_from_onnx()) + cop3 = OnnxAdd('X', numpy.array([2], dtype=dtype), + op_version=get_opset_number_from_onnx(), + output_names=['inter']) + cop4 = OnnxSub( + OnnxMul(cop, cop3, op_version=get_opset_number_from_onnx()), + cop2, output_names=['final'], + op_version=get_opset_number_from_onnx()) + model_def = cop4.to_onnx({'X': x}) + oinf1 = OnnxInference(model_def) + new_model = onnx_rename_names(model_def, verbose=1, fLOG=flog) + total = "\n".join(rows) + self.assertIn("[onnx_rename_names] 'Ad_Addcst1' -> 'i1'", total) + oinf2 = OnnxInference(new_model) + y1 = oinf1.run({'X': x}) + y2 = oinf2.run({'X': x}) + self.assertEqualArray(y1['final'], y2['final']) + + def test_onnx_rename_names_type(self): + rows = [] + + def flog(*s): + rows.append(" ".join(map(str, s))) + + dtype = numpy.float32 + x = numpy.array([1, 2, 4, 5, 5, 4]).astype( + numpy.float32).reshape((3, 2)) + cop = OnnxAdd('X', numpy.array([1], dtype=dtype), + op_version=get_opset_number_from_onnx()) + cop2 = OnnxAdd('X', numpy.array([1], dtype=dtype), + op_version=get_opset_number_from_onnx()) + cop3 = OnnxAdd('X', numpy.array([2], dtype=dtype), + op_version=get_opset_number_from_onnx(), + output_names=['inter']) + cop4 = OnnxSub( + OnnxMul(cop, cop3, op_version=get_opset_number_from_onnx()), + cop2, output_names=['final'], + op_version=get_opset_number_from_onnx()) + model_def = cop4.to_onnx({'X': x}) + oinf1 = OnnxInference(model_def) + new_model = onnx_rename_names( + model_def, verbose=1, fLOG=flog, strategy='type') + total = "\n".join(rows) + self.assertIn("'Ad_Addcst' -> 'i_05'", total) + oinf2 = OnnxInference(new_model) + y1 = oinf1.run({'X': x}) + y2 = oinf2.run({'X': x}) + self.assertEqualArray(y1['final'], y2['final']) + + def test_onnx_rename_node_scan(self): + + def squareform_pdist(X, **kwargs): + opv = get_opset_number_from_onnx() + diff = OnnxSub('next_in', 'next', output_names=[ + 'diff'], op_version=opv) + id_next = OnnxIdentity('next_in', output_names=[ + 'next_out'], op_version=opv) + norm = OnnxReduceSumSquare( + diff, output_names=['norm'], axes=[1], op_version=opv) + flat = OnnxSqueezeApi11( + norm, output_names=['scan_out'], axes=[1], op_version=opv) + scan_body = id_next.to_onnx( + OrderedDict([('next_in', FloatTensorType()), + ('next', FloatTensorType())]), + outputs=[('next_out', FloatTensorType([None, None])), + ('scan_out', FloatTensorType([None]))], + other_outputs=[flat]) + + node = OnnxScan(X, X, output_names=['scan0_{idself}', 'scan1_{idself}'], + num_scan_inputs=1, body=scan_body.graph, op_version=opv, + **kwargs) + return node[1] + + rows = [] + + def flog(*s): + rows.append(" ".join(map(str, s))) + + opv = get_opset_number_from_onnx() + onnx_fct = OnnxIdentity(squareform_pdist( + 'x'), output_names='Y', op_version=opv) + model_def = onnx_fct.to_onnx(inputs=[('x', FloatTensorType())]) + + oinf1 = OnnxInference(model_def) + new_model = onnx_rename_names( + model_def, verbose=1, fLOG=flog, strategy='type') + total = "\n".join(rows) + self.assertNotIn('name: "Re_ReduceSumSquare"', str(new_model)) + self.assertIn("'Re_ReduceSumSquare' -> 'n_24'", total) + oinf2 = OnnxInference(new_model) + x = numpy.array([1, 2, 4, 5, 5, 4]).astype( + numpy.float32).reshape((3, 2)) + y1 = oinf1.run({'x': x}) + y2 = oinf2.run({'x': x}) + self.assertEqualArray(y1['Y'], y2['Y']) + if __name__ == "__main__": unittest.main() diff --git a/mlprodict/onnx_conv/convert.py b/mlprodict/onnx_conv/convert.py index 9aa9ca40d..6cefa461c 100644 --- a/mlprodict/onnx_conv/convert.py +++ b/mlprodict/onnx_conv/convert.py @@ -19,6 +19,7 @@ from skl2onnx import convert_sklearn from skl2onnx.algebra.onnx_operator_mixin import OnnxOperatorMixin from skl2onnx.algebra.type_helper import _guess_type +from ..onnx_tools.onnx_manipulations import onnx_rename_names from .register_rewritten_converters import register_rewritten_operators from .register import register_converters from .scorers import CustomScorerTransform @@ -243,7 +244,7 @@ def guess_schema_from_model(model, tensor_type=None, schema=None): def to_onnx(model, X=None, name=None, initial_types=None, target_opset=None, options=None, rewrite_ops=False, white_op=None, black_op=None, final_types=None, - verbose=0): + rename_strategy=None, verbose=0): """ Converts a model using on :epkg:`sklearn-onnx`. @@ -269,6 +270,8 @@ def to_onnx(model, X=None, name=None, initial_types=None, initial_types but not mandatory, it is used to overwrites the type (if type is not None) and the name of every output. + :param rename_strategy: rename any name in the graph, select shorter + names, see @see fn onnx_rename_names :param verbose: display information while converting the model :return: converted model @@ -348,6 +351,9 @@ def to_onnx(model, X=None, name=None, initial_types=None, onxp = oinf.run(inputs) print(onxp) + + .. versionchanged:: 0.7 + Parameter *rename_strategy* was added. """ if isinstance(model, OnnxOperatorMixin): if not hasattr(model, 'op_version'): @@ -435,4 +441,8 @@ def _guess_type_(X, itype, dtype): final_types=final_types, verbose=verbose) register_rewritten_operators(old_values, old_shapes) + + # optimisation + if rename_strategy is not None: + res = onnx_rename_names(res, strategy=rename_strategy) return res diff --git a/mlprodict/onnx_tools/onnx_manipulations.py b/mlprodict/onnx_tools/onnx_manipulations.py index 06db62629..88a2e4446 100644 --- a/mlprodict/onnx_tools/onnx_manipulations.py +++ b/mlprodict/onnx_tools/onnx_manipulations.py @@ -3,6 +3,7 @@ @brief Implements a class able to compute the predictions from on an :epkg:`ONNX` model. """ +import hashlib from onnx import helper, shape_inference from .onnx2py_helper import guess_proto_dtype from .optim import onnx_remove_node_unused @@ -303,3 +304,181 @@ def overwrite_opset(model, new_opset): op_set.domain = oimp.domain op_set.version = oimp.version return onnx_model + + +def hash_onnx_object(obj, max_size): + """ + Hash the content of an object. + """ + m = hashlib.sha256() + if hasattr(obj, 'op_type'): + # An operator. + m.update(obj.op_type.encode('ascii')) + m.update(str(len(obj.input)).encode('ascii')) + m.update(str(len(obj.output)).encode('ascii')) + if hasattr(obj, 'attribute'): + for att in obj.attribute: + m.update(att.name.encode('ascii')) + m.update(att.SerializeToString()) + else: + # An initializer. + name = obj.name + docf = obj.doc_string + obj.name = '' + obj.doc_string = '' + try: + m.update(obj.SerializeToString()) + except AttributeError as e: + raise RuntimeError( + "Unable to hash object type %r, value=%r." + "" % (type(obj), obj)) from e + finally: + obj.name = name + obj.doc_string = docf + + content = m.hexdigest() + if len(content) > max_size: + content = content[:max_size] + return content.upper() + + +def onnx_rename_names(model, strategy='simple', recursive=True, + verbose=0, fLOG=print, + counts=None, replace=None, taken=None): + """ + Renames all names except the inputs and outputs. + + :param model: onnx model + :param strategy: two strategies are implemented, see below + :param recursive: walk through subgraphs + :param verbose: verbose, if positive, reports on all changed names + :param fLOG: logging function + :param counts: used for recursion + :param replace: used for recursion + :param taken: used for recursion + :return: onnx model (the model is modified in place) + + Strategies: + * `'simple'`: use a letter `n` for node, `r`, `i` for initializer, + this letter is followed by a number + * `'type'`: the name depends on the node type and content, + the hash is kept as small as possible + """ + counts = counts or {'init': 0, 'node': 0, 'result': 0} + replace = replace or {} + taken = taken or set() + graph = model.graph if hasattr(model, 'graph') else model + + for obj in graph.input: + replace[obj.name] = obj.name + for obj in graph.output: + replace[obj.name] = obj.name + + def _check_name_simple(prefix): + if prefix not in replace: + return prefix + c = 1 + final = "%s_%d" % (prefix, c) + while final in taken: + c += 1 + final = "%s_%d" % (prefix, c) + taken.add(final) + return final + + def _check_name_type(obj, prefix): + c = 2 + hash = hash_onnx_object(obj, c) + final = "%s_%s" % (prefix, hash) + while final in taken: + c += 2 + hash = hash_onnx_object(obj, c) + final = "%s_%s" % (prefix, hash) + taken.add(final) + return final + + def get_name_init(init): + if init.name in replace: + return replace[init.name] + if strategy == 'simple': + name = _check_name_simple('i%d' % counts['init']) + counts['init'] += 1 + replace[init.name] = name + if verbose > 0 and fLOG is not None: + fLOG('[onnx_rename_names] %r -> %r' % (init.name, name)) + return name + if strategy == 'type': + name = _check_name_type(init, 'i') + counts['init'] += 1 + replace[init.name] = name + if verbose > 0 and fLOG is not None: + fLOG('[onnx_rename_names] %r -> %r' % (init.name, name)) + return name + raise ValueError( # pragma: no cover + "Unknown strategy %r." % strategy) + + def get_name_node(node): + if node.name in replace: + return replace[node.name] + if strategy == 'simple': + name = _check_name_simple('n%d' % counts['node']) + counts['node'] += 1 + replace[node.name] = name + if verbose > 0 and fLOG is not None: + fLOG('[onnx_rename_names] %r -> %r' % (node.name, name)) + return name + if strategy == 'type': + name = _check_name_type(node, 'n') + counts['node'] += 1 + replace[node.name] = name + if verbose > 0 and fLOG is not None: + fLOG('[onnx_rename_names] %r -> %r' % (node.name, name)) + return name + raise ValueError( # pragma: no cover + "Unknown strategy %r." % strategy) + + def get_name_result(node, i, name, suffix): + if name in replace: + return replace[name] + if strategy == 'simple': + new_name = _check_name_simple('r%d' % counts['result']) + counts['result'] += 1 + replace[name] = new_name + if verbose > 0 and fLOG is not None: + fLOG('[onnx_rename_names] %r -> %r' % (name, new_name)) + return new_name + if strategy == 'type': + new_name = _check_name_type(node, 'r%s%d' % (suffix, i)) + counts['result'] += 1 + replace[name] = new_name + if verbose > 0 and fLOG is not None: + fLOG('[onnx_rename_names] %r -> %r' % (name, new_name)) + return new_name + raise ValueError( # pragma: no cover + "Unknown strategy %r." % strategy) + + def get_name_input(node, i): + return get_name_result(node, i, node.input[i], 'i') + + def get_name_output(node, i): + return get_name_result(node, i, node.output[i], 'o') + + for init in graph.initializer: + init.name = get_name_init(init) + + for node in graph.node: + node.name = get_name_node(node) + for i in range(len(node.input)): # pylint: disable=C0200 + node.input[i] = get_name_input(node, i) + for i in range(len(node.output)): # pylint: disable=C0200 + node.output[i] = get_name_output(node, i) + if not recursive or node.op_type not in {'Scan', 'If', 'Loop'}: + continue + # recursion + for att in node.attribute: + if att.name not in {'if_branch', 'else_branch', 'body'}: + continue + onnx_rename_names( + att.g, strategy=strategy, fLOG=fLOG, verbose=verbose, + counts=counts, replace=replace, taken=taken) + + return model diff --git a/mlprodict/onnx_tools/optim/onnx_optimisation_identity.py b/mlprodict/onnx_tools/optim/onnx_optimisation_identity.py index f721b0aaa..33fa634ea 100644 --- a/mlprodict/onnx_tools/optim/onnx_optimisation_identity.py +++ b/mlprodict/onnx_tools/optim/onnx_optimisation_identity.py @@ -7,8 +7,7 @@ _rename_node_input, _rename_node_output, _apply_optimisation_on_graph, - _apply_remove_node_fct_node -) + _apply_remove_node_fct_node) def onnx_remove_node_identity(onnx_model, recursive=True, debug_info=None, **options): diff --git a/mlprodict/onnx_tools/optim/onnx_optimisation_redundant.py b/mlprodict/onnx_tools/optim/onnx_optimisation_redundant.py index 4c91fe92c..c3e435dd8 100644 --- a/mlprodict/onnx_tools/optim/onnx_optimisation_redundant.py +++ b/mlprodict/onnx_tools/optim/onnx_optimisation_redundant.py @@ -9,8 +9,7 @@ _rename_node_input, _rename_node_output, _apply_optimisation_on_graph, - _apply_remove_node_fct_node -) + _apply_remove_node_fct_node) def _hash_obj_content(obj, max_size=1000):