From e42ce4072be886132f4647613a7c5600751902fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Mon, 21 Feb 2022 12:45:38 +0100 Subject: [PATCH 1/6] Supports domains in Xop API --- _doc/sphinxdoc/source/conf.py | 1 - _unittests/ut_npy/test_xop.py | 157 +++++++++++++++++++++++++--------- mlprodict/npy/xop.py | 137 ++++++++++++++++++++--------- 3 files changed, 215 insertions(+), 80 deletions(-) diff --git a/_doc/sphinxdoc/source/conf.py b/_doc/sphinxdoc/source/conf.py index 312f5fb08..d454d1e0c 100644 --- a/_doc/sphinxdoc/source/conf.py +++ b/_doc/sphinxdoc/source/conf.py @@ -261,4 +261,3 @@ 'https://github.com/onnx/onnx/blob/master/docs/Operators-ml.md' '#ai.onnx.ml.TreeEnsembleRegressor', }) - diff --git a/_unittests/ut_npy/test_xop.py b/_unittests/ut_npy/test_xop.py index 86b2bea24..bc3ce5c17 100644 --- a/_unittests/ut_npy/test_xop.py +++ b/_unittests/ut_npy/test_xop.py @@ -25,20 +25,39 @@ def test_float32(self): self.assertEqual(numpy.float32, numpy.dtype('float32')) def test_impossible(self): - cl = loadop("OnnxAdd") + cl = loadop("Add") self.assertEqual(cl.__name__, "OnnxAdd") - cl = loadop("OnnxCast") + cl = loadop("Cast") self.assertEqual(cl.__name__, "OnnxCast") cl = loadop("Cast_13") self.assertEqual(cl.__name__, "OnnxCast_13") - cl = loadop("OnnxCast_13") + cl = loadop("Cast_13") self.assertEqual(cl.__name__, "OnnxCast_13") - self.assertRaise(lambda: loadop("OnnxImpossible"), ValueError) - self.assertRaise(lambda: loadop("OnnxImpossible_1"), ValueError) - self.assertRaise(lambda: loadop("OnnxCast_9999"), ValueError) + self.assertRaise(lambda: loadop("OnnxCast"), ValueError) + self.assertRaise(lambda: loadop("Impossible"), ValueError) + self.assertRaise(lambda: loadop("Impossible_1"), ValueError) + self.assertRaise(lambda: loadop("Cast_9999"), ValueError) def test_onnx_abs(self): - OnnxAbs = loadop("OnnxAbs") + OnnxAbs = loadop("Abs") + ov = OnnxAbs('X', output_names=['Y']) + onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=0) + oinf = OnnxInference(onx) + x = numpy.array([-2, 2], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(numpy.abs(x), got['Y']) + + def test_onnx_abs_domain(self): + OnnxAbs = loadop(("", "Abs")) + ov = OnnxAbs('X', output_names=['Y']) + onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=0) + oinf = OnnxInference(onx) + x = numpy.array([-2, 2], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(numpy.abs(x), got['Y']) + + def test_onnx_abs_domain_ai(self): + OnnxAbs = loadop(("ai.onnx", "Abs")) ov = OnnxAbs('X', output_names=['Y']) onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=0) oinf = OnnxInference(onx) @@ -56,7 +75,7 @@ def test_onnx_add(self): self.assertEqualArray(x + x, got['Y']) def test_onnx_add_cst(self): - OnnxAdd = loadop("OnnxAdd") + OnnxAdd = loadop("Add") ov = OnnxAdd('X', numpy.array([1], dtype=numpy.float32), output_names=['Y']) onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=0) @@ -72,7 +91,7 @@ def test_number2alpha(self): self.assertEqual(sel, sel2) def test_onnx_add_sub_left(self): - OnnxAdd, OnnxSub = loadop("OnnxAdd", "OnnxSub") + OnnxAdd, OnnxSub = loadop("Add", "Sub") self.assertEqual(OnnxAdd.operator_name, 'Add') self.assertEqual(OnnxSub.operator_name, 'Sub') ov = OnnxAdd('X', 'X') @@ -84,7 +103,7 @@ def test_onnx_add_sub_left(self): self.assertEqualArray(x, got['Y']) def test_onnx_add_sub_right(self): - OnnxAdd, OnnxSub = loadop("OnnxAdd", "OnnxSub") + OnnxAdd, OnnxSub = loadop("Add", "Sub") self.assertEqual(OnnxAdd.operator_name, 'Add') self.assertEqual(OnnxSub.operator_name, 'Sub') ov = OnnxAdd('X', 'X') @@ -96,7 +115,7 @@ def test_onnx_add_sub_right(self): self.assertEqualArray(-x, got['Y']) def test_onnx_transpose(self): - OnnxTranspose = loadop("OnnxTranspose") + OnnxTranspose = loadop("Transpose") ov = OnnxTranspose('X', perm=[1, 0], output_names=['Y']) onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=0) self.assertIn('perm', str(onx)) @@ -106,7 +125,7 @@ def test_onnx_transpose(self): self.assertEqualArray(x.T, got['Y']) def test_onnx_transpose3(self): - OnnxTranspose = loadop("OnnxTranspose") + OnnxTranspose = loadop("Transpose") ov = OnnxTranspose('X', perm=[1, 0, 2], output_names=['Y']) onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=0) self.assertIn('perm', str(onx)) @@ -116,7 +135,7 @@ def test_onnx_transpose3(self): self.assertEqualArray(numpy.transpose(x, axes=(1, 0, 2)), got['Y']) def test_onnx_cast(self): - OnnxCast = loadop("OnnxCast") + OnnxCast = loadop("Cast") ov = OnnxCast('X', to=numpy.int64, output_names=['Y']) onx = ov.to_onnx(numpy.float32, numpy.int64, verbose=0) self.assertIn('to', str(onx)) @@ -126,7 +145,7 @@ def test_onnx_cast(self): self.assertEqualArray(x.astype(numpy.int64), got['Y']) def test_onnx_dict(self): - OnnxCast = loadop("OnnxCast") + OnnxCast = loadop("Cast") ov = OnnxCast('X', to=numpy.int64, output_names=['Y']) onx = ov.to_onnx({'X': numpy.float32}, {'Y': numpy.int64}, verbose=0) self.assertIn('to', str(onx)) @@ -136,7 +155,7 @@ def test_onnx_dict(self): self.assertEqualArray(x.astype(numpy.int64), got['Y']) def test_onnx_var(self): - OnnxCast = loadop("OnnxCast") + OnnxCast = loadop("Cast") ov = OnnxCast('X', to=numpy.int64, output_names=['Y']) onx = ov.to_onnx(Variable('X', numpy.float32), Variable('Y', numpy.float32), verbose=0) @@ -147,7 +166,7 @@ def test_onnx_var(self): self.assertEqualArray(x.astype(numpy.int64), got['Y']) def test_onnx_var_list(self): - OnnxCast = loadop("OnnxCast") + OnnxCast = loadop("Cast") ov = OnnxCast('X', to=numpy.int64, output_names=['Y']) onx = ov.to_onnx([Variable('X', numpy.float32)], [Variable('Y', numpy.float32)], verbose=0) @@ -159,7 +178,7 @@ def test_onnx_var_list(self): def test_if(self): OnnxConstant, OnnxIf, OnnxGreater = loadop( - "OnnxConstant", "OnnxIf", "OnnxGreater") + "Constant", "If", "Greater") bthen = OnnxConstant( value_floats=numpy.array([0], dtype=numpy.float32), output_names=['res_then']) @@ -195,7 +214,7 @@ def test_if(self): def test_if2(self): OnnxAdd, OnnxSub, OnnxIf, OnnxGreater, OnnxReduceSum = loadop( - "OnnxAdd", "OnnxSub", "OnnxIf", "OnnxGreater", "OnnxReduceSum") + "Add", "Sub", "If", "Greater", "ReduceSum") node = OnnxAdd('x1', 'x2', output_names=['absxythen']) then_body = node.to_onnx( @@ -223,7 +242,7 @@ def test_if2(self): self.assertIn("out_red0 -> _greater;", dot) def test_onnx_abs_shape_variable(self): - OnnxAbs = loadop("OnnxAbs") + OnnxAbs = loadop("Abs") ov = OnnxAbs('X', output_names=['Y']) onx = ov.to_onnx([Variable('X', numpy.float32, [1, 2])], [Variable('Y', numpy.float32, [1, 2])], @@ -241,7 +260,7 @@ def test_onnx_abs_shape_variable(self): self.assertEqual(shape, (1, 2)) def test_onnx_abs_shape_variable_batch(self): - OnnxAbs = loadop("OnnxAbs") + OnnxAbs = loadop("Abs") ov = OnnxAbs('X', output_names=['Y']) onx = ov.to_onnx([Variable('X', numpy.float32, [None, 2])], [Variable('Y', numpy.float32, [None, 2])], @@ -258,7 +277,7 @@ def test_onnx_abs_shape_variable_batch(self): self.assertEqual(shape, (None, 2)) def test_onnx_abs_shape_numpy(self): - OnnxAbs = loadop("OnnxAbs") + OnnxAbs = loadop("Abs") ov = OnnxAbs('X', output_names=['Y']) x = numpy.array([-2, 2], dtype=numpy.float32) onx = ov.to_onnx({'X': x}, {'Y': x}, verbose=0) @@ -364,7 +383,7 @@ def test_syntax_onnx(self): self.assertEqualArray(y, numpy.array([[[2]]], dtype=numpy.float32)) def test_onnx_abs_undefined(self): - OnnxAbs = loadop("OnnxAbs") + OnnxAbs = loadop("Abs") ov = OnnxAbs('X', output_names=['Y']) onx = ov.to_onnx(numpy.float32, verbose=0) oinf = OnnxInference(onx) @@ -377,7 +396,7 @@ def test_onnx_abs_undefined(self): self.assertEqualArray(numpy.abs(x), got['Y']) def test_onnx_add_sub_left_undefined(self): - OnnxAdd, OnnxSub = loadop("OnnxAdd", "OnnxSub") + OnnxAdd, OnnxSub = loadop("Add", "Sub") self.assertEqual(OnnxAdd.operator_name, 'Add') self.assertEqual(OnnxSub.operator_name, 'Sub') ov = OnnxAdd('X', 'X') @@ -396,7 +415,7 @@ def test_onnx_add_sub_left_undefined(self): def test_topk_classic(self): opv = max_supported_opset() - OnnxIdentity, OnnxTopK = loadop("OnnxIdentity", "OnnxTopK") + OnnxIdentity, OnnxTopK = loadop("Identity", "TopK") X = numpy.array([[0, 1, 2, 3, 4], [1, -1, -2, 4, 5], [2, -2, -3, 5, -4]], @@ -422,7 +441,7 @@ def test_topk_classic(self): def test_topk_iter(self): opv = max_supported_opset() - OnnxIdentity, OnnxTopK = loadop("OnnxIdentity", "OnnxTopK") + OnnxIdentity, OnnxTopK = loadop("Identity", "TopK") X = numpy.array([[0, 1, 2, 3, 4], [1, -1, -2, 4, 5], [2, -2, -3, 5, -4]], @@ -448,7 +467,7 @@ def test_topk_iter(self): self.assertEqualArray(exp, got['Yi']) def test_onnx_add_op(self): - OnnxAbs, OnnxIdentity = loadop("OnnxAbs", "OnnxIdentity") + OnnxAbs, OnnxIdentity = loadop("Abs", "Identity", verbose=0) ov = OnnxAbs('X') ovf = ov + ov last = OnnxIdentity(ovf, output_names=['Y']) @@ -458,8 +477,42 @@ def test_onnx_add_op(self): got = oinf.run({'X': x}) self.assertEqualArray(numpy.abs(x) * 2, got['Y']) + def test_onnx_add_op_onnxruntime(self): + OnnxAbs, OnnxIdentity = loadop("Abs", "Identity") + ov = OnnxAbs('X') + ovf = ov + ov + last = OnnxIdentity(ovf, output_names=['Y']) + onx = last.to_onnx(numpy.float32, numpy.float32, verbose=0) + + opv = max_supported_opset() + ov = OnnxAbs('X', op_version=opv) + ovf = ov + ov + last = OnnxIdentity(ovf, output_names=['Y'], op_version=opv) + onx = last.to_onnx(numpy.float32, numpy.float32, verbose=0, + target_opset=opv) + + oinf = OnnxInference(onx, runtime='onnxruntime1') + x = numpy.array([-2, 2], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(numpy.abs(x) * 2, got['Y']) + + def test_onnx_add_op_onnxruntime_specific(self): + OnnxAbs_13, OnnxIdentity_14 = loadop("Abs_13", "Identity_14") + + opv = max_supported_opset() + ov = OnnxAbs_13('X') + ovf = ov + ov + last = OnnxIdentity_14(ovf, output_names=['Y'], op_version=opv) + onx = last.to_onnx(numpy.float32, numpy.float32, verbose=0, + target_opset=opv) + + oinf = OnnxInference(onx, runtime='onnxruntime1') + x = numpy.array([-2, 2], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(numpy.abs(x) * 2, got['Y']) + def test_onnx_sub_op(self): - OnnxAbs, OnnxIdentity = loadop("OnnxAbs", "OnnxIdentity") + OnnxAbs, OnnxIdentity = loadop("Abs", "Identity") ov = OnnxAbs('X') ovf = ov + ov - ov last = OnnxIdentity(ovf, output_names=['Y']) @@ -470,7 +523,7 @@ def test_onnx_sub_op(self): self.assertEqualArray(numpy.abs(x), got['Y']) def test_onnx_mul_op(self): - OnnxAbs, OnnxIdentity = loadop("OnnxAbs", "OnnxIdentity") + OnnxAbs, OnnxIdentity = loadop("Abs", "Identity") ov = OnnxAbs('X') ovf = ov * ov last = OnnxIdentity(ovf, output_names=['Y']) @@ -481,7 +534,7 @@ def test_onnx_mul_op(self): self.assertEqualArray(numpy.abs(x) ** 2, got['Y']) def test_onnx_div_op(self): - OnnxAbs, OnnxIdentity = loadop("OnnxAbs", "OnnxIdentity") + OnnxAbs, OnnxIdentity = loadop("Abs", "Identity") ov = OnnxAbs('X') ovf = ov / (ov + ov) last = OnnxIdentity(ovf, output_names=['Y']) @@ -493,7 +546,7 @@ def test_onnx_div_op(self): self.assertEqualArray(a / (a + a), got['Y']) def test_onnx_pow_op(self): - OnnxAbs, OnnxIdentity = loadop("OnnxAbs", "OnnxIdentity") + OnnxAbs, OnnxIdentity = loadop("Abs", "Identity") ov = OnnxAbs('X') ovf = ov ** ov last = OnnxIdentity(ovf, output_names=['Y']) @@ -505,7 +558,7 @@ def test_onnx_pow_op(self): self.assertEqualArray(a ** a, got['Y']) def test_onnx_matmul_op(self): - OnnxAbs, OnnxIdentity = loadop("OnnxAbs", "OnnxIdentity") + OnnxAbs, OnnxIdentity = loadop("Abs", "Identity") ov = OnnxAbs('X') ovf = ov @ ov last = OnnxIdentity(ovf, output_names=['Y']) @@ -517,7 +570,7 @@ def test_onnx_matmul_op(self): self.assertEqualArray(a @ a, got['Y']) def test_onnx_greater_op(self): - OnnxAbs, OnnxIdentity = loadop("OnnxAbs", "OnnxIdentity") + OnnxAbs, OnnxIdentity = loadop("Abs", "Identity") ov = OnnxAbs('X') ovi = OnnxIdentity('X') ovf = ov > ovi @@ -530,7 +583,7 @@ def test_onnx_greater_op(self): self.assertEqualArray(a > x, got['Y']) def test_onnx_less_op(self): - OnnxAbs, OnnxIdentity = loadop("OnnxAbs", "OnnxIdentity") + OnnxAbs, OnnxIdentity = loadop("Abs", "Identity") ov = OnnxAbs('X') ovi = OnnxIdentity('X') ovf = ov < ovi @@ -543,7 +596,7 @@ def test_onnx_less_op(self): self.assertEqualArray(a < x, got['Y']) def test_onnx_equal_op(self): - OnnxAbs, OnnxIdentity = loadop("OnnxAbs", "OnnxIdentity") + OnnxAbs, OnnxIdentity = loadop("Abs", "Identity") ov = OnnxAbs('X') ovi = OnnxIdentity('X') ovf = ov == ovi @@ -556,7 +609,7 @@ def test_onnx_equal_op(self): self.assertEqualArray(a == x, got['Y']) def test_onnx_and_op(self): - OnnxAbs, OnnxIdentity = loadop("OnnxAbs", "OnnxIdentity") + OnnxAbs, OnnxIdentity = loadop("Abs", "Identity") ov = OnnxAbs('X') ovi = OnnxIdentity('X') ovf = (ov == ovi).and_(ov > ovi) @@ -569,7 +622,7 @@ def test_onnx_and_op(self): self.assertEqualArray(a == -10, got['Y']) def test_onnx_or_op(self): - OnnxAbs, OnnxIdentity = loadop("OnnxAbs", "OnnxIdentity") + OnnxAbs, OnnxIdentity = loadop("Abs", "Identity") ov = OnnxAbs('X') ovi = OnnxIdentity('X') ovf = (ov == ovi).or_(ov > ovi) @@ -582,7 +635,7 @@ def test_onnx_or_op(self): self.assertEqualArray(a >= x, got['Y']) def test_onnx_abs_op(self): - OnnxIdentity = loadop("OnnxIdentity") + OnnxIdentity = loadop("Identity") ovi = OnnxIdentity('X') ovf = abs(ovi) last = OnnxIdentity(ovf, output_names=['Y']) @@ -594,7 +647,7 @@ def test_onnx_abs_op(self): self.assertEqualArray(a, got['Y']) def test_onnx_not_op(self): - OnnxIdentity = loadop("OnnxIdentity") + OnnxIdentity = loadop("Identity") ovi = OnnxIdentity('X') ovf = (abs(ovi) == ovi).not_() last = OnnxIdentity(ovf, output_names=['Y']) @@ -606,7 +659,7 @@ def test_onnx_not_op(self): self.assertEqualArray(a != x, got['Y']) def test_onnx_mod_op(self): - OnnxIdentity = loadop("OnnxIdentity") + OnnxIdentity = loadop("Identity") ovi = OnnxIdentity('X') ovf = ovi % numpy.array([10], dtype=numpy.int64) last = OnnxIdentity(ovf, output_names=['Y']) @@ -616,7 +669,31 @@ def test_onnx_mod_op(self): got = oinf.run({'X': x}) self.assertEqualArray(x % 10, got['Y']) + def test_onnx_ml_operator(self): + OnnxNormalizer = loadop(('ai.onnx.ml', "Normalizer")) + self.assertEqual(OnnxNormalizer.__name__, + 'OnnxAiOnnxMlNormalizer') + last = OnnxNormalizer('X', norm='L1', output_names=['Y']) + onx = last.to_onnx(numpy.float32, numpy.float32, verbose=0) + oinf = OnnxInference(onx) + x = numpy.array([[-2, 2], [0, 3]], dtype=numpy.float32) + got = oinf.run({'X': x}) + a = numpy.abs(x) + self.assertEqualArray(x / a.sum(axis=1, keepdims=True), got['Y']) + + def test_onnx_ml_operator_shortcut(self): + OnnxNormalizer = loadop("Normalizer") + self.assertEqual(OnnxNormalizer.__name__, + 'OnnxAiOnnxMlNormalizer') + last = OnnxNormalizer('X', norm='L1', output_names=['Y']) + onx = last.to_onnx(numpy.float32, numpy.float32, verbose=0) + oinf = OnnxInference(onx) + x = numpy.array([[-2, 2], [0, 3]], dtype=numpy.float32) + got = oinf.run({'X': x}) + a = numpy.abs(x) + self.assertEqualArray(x / a.sum(axis=1, keepdims=True), got['Y']) + if __name__ == "__main__": - # TestXOps().test_topk_iter() + # TestXOps().test_onnx_add_op() unittest.main(verbosity=2) diff --git a/mlprodict/npy/xop.py b/mlprodict/npy/xop.py index f35ccd688..da795c888 100644 --- a/mlprodict/npy/xop.py +++ b/mlprodict/npy/xop.py @@ -6,6 +6,7 @@ .. versionadded:: 0.9 """ import os +import pprint import numpy from scipy.sparse.coo import coo_matrix import onnx @@ -35,8 +36,22 @@ def _default_OPSET_TO_IR_VERSION(): return { 1: 3, 2: 3, 3: 3, 4: 3, 5: 3, 6: 3, 7: 3, 8: 4, 9: 4, 10: 5, 11: 6, 12: 7, - 13: 7, 14: 7, 15: 8 - } + 13: 7, 14: 7, 15: 8} + + +def _domain_to_class_name(domain): + if domain == 'ai.onnx': + return '' + dom = domain.split('.') + res = [] + for d in dom: + if len(d) == 0: + res.append(d) + elif len(d) == 1: + res.append(d.upper()) + else: + res.append(d[0].upper() + d[1:]) + return "".join(res) def _populate_schemas(): @@ -45,6 +60,7 @@ def _populate_schemas(): """ res = {} versions = {} + domains = {} for schema in onnx.defs.get_all_schemas_with_history(): if schema.support_level == schema.SupportType.EXPERIMENTAL: # Skips experimental operators. @@ -53,15 +69,39 @@ def _populate_schemas(): if schema.name in res: if schema.since_version > res[schema.name].since_version: # We keep the most recent one. - res[schema.name] = schema + res[schema.domain, schema.name] = schema else: - res[schema.name] = schema + res[schema.domain, schema.name] = schema full_name = schema.name + '_' + str(schema.since_version) - res[full_name] = schema - if schema.name not in versions: - versions[schema.name] = set() - versions[schema.name].add(full_name) - return res, versions + res[schema.domain, full_name] = schema + key = schema.domain, schema.name + if key not in versions: + versions[key] = set() + if schema.name not in domains: + domains[schema.name] = set() + domains[schema.name].add(schema.domain) + versions[key].add(full_name) + return res, versions, domains + + +def _find_operator_domain(name): + """ + Determines the domain of an operator. + Raises an exception if not found or if there is an ambiguity. + + :param name: operator name + :return: domain + """ + if name not in _all_domains: + raise ValueError( + "Unable to guess domain for operator %r. " + "Not found in %r." % (name, list(_all_domains))) + domains = _all_domains[name] + if len(domains) == 1: + return list(domains)[0] + raise ValueError( + "Unable to guess domain of operator %r, found domains %r." % ( + name, domains)) def ClassFactory(class_name, op_name, inputs, outputs, @@ -215,20 +255,40 @@ def _c(obj, label, i): if operator_names is None: operator_names = list(_all_schemas_versions) + # type verification + ops = [] + for name in operator_names: + if isinstance(name, str): + if name.startswith('Onnx'): + raise ValueError( + "Operator name cannot starts with Onnx: %r." % name) + domain = _find_operator_domain(name.split('_', maxsplit=1)[0]) + ops.append((domain, name)) + elif isinstance(name, tuple) and len(name) == 2: + if name[1].startswith('Onnx'): + raise ValueError( + "Operator name cannot starts with Onnx: %r." % name) + ops.append(name) + else: + raise ValueError( + "Operator to fetch must be a string or a " + "`tuple(domain, name)` not %r." % (name)) + operator_names = ops + + # versions res = _all_schemas cls = {} set_names = dict() set_skip = set() - for pos, op_name in enumerate(operator_names): - set_names[op_name] = pos + for pos, (op_domain, op_name) in enumerate(operator_names): + if op_domain == 'ai.onnx': + op_domain = '' + set_names[op_domain, op_name] = pos if '_' in op_name: n = op_name.split('_')[0] - if n.startswith('Onnx'): - set_skip.add(n) - else: - set_skip.add('Onnx' + n) + set_skip.add((op_domain, n)) if n not in set_names: - set_names[n] = -1 + set_names[op_domain, n] = -1 if verbose > 1 and fLOG is not None: fLOG("[_dynamic_class_creation] set_names=%r" % set_names) @@ -237,49 +297,49 @@ def _c(obj, label, i): returned_classes = [] positions = {} - for op_name, position in set_names.items(): - cl_name = op_name if op_name.startswith('Onnx') else 'Onnx' + op_name + for (op_domain, op_name), position in set_names.items(): + cl_name = 'Onnx' + _domain_to_class_name(op_domain) + op_name if verbose > 3 and fLOG is not None: - fLOG('[_dynamic_class_creation] cl_name=%r op_name=%r (in=%d)' % ( - cl_name, op_name, 1 if cl_name in _all_classes else 0)) + fLOG('[_dynamic_class_creation] cl_name=%r op_domain=%r op_name=%r (in=%d)' % ( + cl_name, op_domain, op_name, 1 if cl_name in _all_classes else 0)) if cl_name in _all_classes: if cl_name not in set_skip: if position >= 0: returned_classes.append((position, _all_classes[cl_name])) continue - name = op_name[4:] if op_name.startswith('Onnx') else op_name - name_keep = name - if '_' in name: - names = [name] + # operator name without domain + if '_' in op_name: + names = [op_name] else: try: - names = _all_schemas_versions[name].copy() + names = _all_schemas_versions[op_domain, op_name].copy() except KeyError as e: raise ValueError( - "Operator %r (or %r) does not exists." % ( - name, op_name)) from e - names.add(name) + "Operator %r (domain=%r) does not exists." % ( + op_name, op_domain)) from e + names.add(op_name) if verbose > 0 and fLOG is not None: - fLOG("[_dynamic_class_creation] op_name=%r, cl_name=%r names=%r" - "" % (op_name, cl_name, names)) + fLOG("[_dynamic_class_creation] op_domain=%r op_name=%r, cl_name=%r names=%r" + "" % (op_domain, op_name, cl_name, names)) for name in names: try: - schema = res[name] + schema = res[op_domain, name] except KeyError as e: raise ValueError( - "Operator %r (or %r) does not exists." % ( - name, op_name)) from e + "Operator (%r, %r) does not exists (available=%r)" % ( + op_domain, name, pprint.pformat(list(res)))) from e inputs = [_c(o, 'I', i) for i, o in enumerate(schema.inputs)] outputs = [_c(o, 'O', i) for i, o in enumerate(schema.outputs)] args = [p for p in schema.attributes] if '_' in name: - class_name = "Onnx" + name + class_name = "Onnx" + _domain_to_class_name(op_domain) + name else: - class_name = "Onnx" + schema.name + class_name = ( + "Onnx" + _domain_to_class_name(op_domain) + schema.name) if verbose > 0 and fLOG is not None: fLOG("[_dynamic_class_creation] op_name=%r, cl_name=%r cache=%r" @@ -305,7 +365,7 @@ def _c(obj, label, i): getattr(schema, 'deprecated', False), schema.since_version, {}) cls[class_name] = cl - if name == name_keep: + if name == op_name: positions[class_name] = position # Retrieves past classes. @@ -701,8 +761,7 @@ def _post_process_attributes(self): def find_schema(self, op_version): """ - Checks if there is an existing schema for a - specific version. + Checks if there is an existing schema for a specific version. :param op_version: requested version :return: schema @@ -1463,6 +1522,6 @@ def to_onnx(self, inputs=None, outputs=None, return onnx_model -_all_schemas, _all_schemas_versions = _populate_schemas() +_all_schemas, _all_schemas_versions, _all_domains = _populate_schemas() _all_classes = {} onnx_load_factory = Xop = OnnxLoadFactory() From 164cbb1aaaca2ab344c22deef7abb2196ea29904 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Mon, 21 Feb 2022 14:55:53 +0100 Subject: [PATCH 2/6] fix documentation --- mlprodict/npy/xop.py | 2 +- mlprodict/npy/xop_auto_import_.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlprodict/npy/xop.py b/mlprodict/npy/xop.py index da795c888..14b752f2a 100644 --- a/mlprodict/npy/xop.py +++ b/mlprodict/npy/xop.py @@ -261,7 +261,7 @@ def _c(obj, label, i): if isinstance(name, str): if name.startswith('Onnx'): raise ValueError( - "Operator name cannot starts with Onnx: %r." % name) + "Operator name cannot start with Onnx: %r." % name) domain = _find_operator_domain(name.split('_', maxsplit=1)[0]) ops.append((domain, name)) elif isinstance(name, tuple) and len(name) == 2: diff --git a/mlprodict/npy/xop_auto_import_.py b/mlprodict/npy/xop_auto_import_.py index c44fc6f7d..ccc1b7703 100644 --- a/mlprodict/npy/xop_auto_import_.py +++ b/mlprodict/npy/xop_auto_import_.py @@ -18,7 +18,7 @@ def _update_module(): for cl in res: setattr(this, cl.__name__, cl) name = cl.__name__.split('_')[0] - unique.add(name) + unique.add((cl.domain, cl.operator_name)) res = _dynamic_class_creation(list(unique)) for cl in res: setattr(this, cl.__name__, cl) From 9d44b8611fc3804263b3c70e0fd023156d12ff96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Mon, 21 Feb 2022 19:25:02 +0100 Subject: [PATCH 3/6] fix issue when requesting one output among an undefined number --- _unittests/ut_npy/test_xop.py | 158 ++++++++++++++++++++++++++++-- mlprodict/npy/xop.py | 129 ++++++++++++++++-------- mlprodict/npy/xop_auto.py | 6 +- mlprodict/npy/xop_auto_import_.py | 1 - mlprodict/npy/xop_opset.py | 18 ++++ mlprodict/npy/xop_variable.py | 2 +- 6 files changed, 261 insertions(+), 53 deletions(-) diff --git a/_unittests/ut_npy/test_xop.py b/_unittests/ut_npy/test_xop.py index bc3ce5c17..53ea932cc 100644 --- a/_unittests/ut_npy/test_xop.py +++ b/_unittests/ut_npy/test_xop.py @@ -1,6 +1,6 @@ # pylint: disable=E0611 """ -@brief test log(time=10s) +@brief test log(time=15s) """ import unittest import numpy @@ -11,12 +11,18 @@ make_graph, make_tensor_value_info) from onnx.shape_inference import infer_shapes from pyquickhelper.pycode import ExtTestCase -from mlprodict.npy.xop import loadop -from mlprodict.npy.xop_variable import Variable, max_supported_opset -from mlprodict.npy.xop import _GraphBuilder from mlprodict.onnxrt import OnnxInference from mlprodict.plotting.text_plot import onnx_simple_text_plot from mlprodict.onnx_tools.onnx2py_helper import get_dtype_shape +from mlprodict.npy.xop import loadop +from mlprodict.npy.xop_auto import get_domain_list +from mlprodict.npy.xop_variable import ( + Variable, max_supported_opset, + numpy_type_prototype, is_numpy_dtype) +from mlprodict.npy.xop import _GraphBuilder +from mlprodict.npy.xop_opset import ( + OnnxReduceSumApi11, OnnxSplitApi11, OnnxSqueezeApi11, + OnnxUnsqueezeApi11, OnnxReduceL2_typed, OnnxReshapeApi13) class TestXOps(ExtTestCase): @@ -24,6 +30,50 @@ class TestXOps(ExtTestCase): def test_float32(self): self.assertEqual(numpy.float32, numpy.dtype('float32')) + def test_numpy_dtype(self): + self.assertEqual(is_numpy_dtype(numpy.float32), True) + self.assertEqual(is_numpy_dtype(numpy.dtype('float32')), True) + self.assertEqual(is_numpy_dtype({}), False) + + def test_numpy_type_prototype(self): + self.assertEqual( + numpy_type_prototype(numpy.float32), TensorProto.FLOAT) + self.assertEqual( + numpy_type_prototype(numpy.dtype('float32')), TensorProto.FLOAT) + self.assertRaise(lambda: numpy_type_prototype(5), TypeError) + + def test_get_domain_list(self): + self.assertEqual(['', 'ai.onnx.ml', 'ai.onnx.preview.training'], + get_domain_list()) + + def test_variable(self): + var = Variable('X', numpy.float32) + self.assertEqual(var.is_named('X'), True) + self.assertEqual(var.name, 'X') + self.assertEqual(var.dtype, numpy.float32) + self.assertEqual(var.proto_type, TensorProto.FLOAT) + self.assertRaise(lambda: Variable('X', 5), TypeError) + self.assertRaise(lambda: var.is_named(4), TypeError) + self.assertRaise( + lambda: Variable('X', numpy.float32, added_dtype=5), + TypeError) + self.assertRaise(lambda: Variable('X', shape='t'), TypeError) + self.assertRaise(lambda: Variable('X', added_shape='t'), TypeError) + var = Variable('X', numpy.float32) + r = repr(var) + self.assertEqual(r, "Variable('X', dtype=)") + var = Variable('X', added_dtype=numpy.float32) + r = repr(var) + self.assertEqual( + r, "Variable('X', added_dtype=)") + self.assertRaise(lambda: var == 'T', TypeError) + var2 = var + self.assertEqual(var == var2, True) + self.assertEqual(var == Variable('Y'), False) + self.assertEqual(var == Variable('X', numpy.float32), False) + self.assertEqual( + var == Variable('X', added_dtype=numpy.float32), True) + def test_impossible(self): cl = loadop("Add") self.assertEqual(cl.__name__, "OnnxAdd") @@ -239,7 +289,7 @@ def test_if2(self): {'y': numpy.float32}) oinf = OnnxInference(model_def) dot = oinf.to_dot() - self.assertIn("out_red0 -> _greater;", dot) + self.assertIn("reduced0 -> _greater;", dot) def test_onnx_abs_shape_variable(self): OnnxAbs = loadop("Abs") @@ -693,7 +743,103 @@ def test_onnx_ml_operator_shortcut(self): a = numpy.abs(x) self.assertEqualArray(x / a.sum(axis=1, keepdims=True), got['Y']) + def test_opset_reduce_sum(self): + for opv in range(10, max_supported_opset() + 1): + with self.subTest(opv=opv): + node = OnnxReduceSumApi11( + 'X', axes=numpy.array([1], dtype=numpy.int64), + op_version=opv, output_names=['Y']) + onx = node.to_onnx(numpy.float32, numpy.float32, + target_opset=opv) + oinf = OnnxInference(onx) + x = numpy.array([[4, 5], [5.5, -6]], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(x.sum(axis=1, keepdims=1), got['Y']) + + def test_opset_squeeze(self): + for opv in range(10, max_supported_opset() + 1): + with self.subTest(opv=opv): + node = OnnxSqueezeApi11( + 'X', axes=numpy.array([0], dtype=numpy.int64), + op_version=opv, output_names=['Y']) + onx = node.to_onnx(numpy.float32, numpy.float32, + target_opset=opv) + oinf = OnnxInference(onx) + x = numpy.array([[4, 5]], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(numpy.squeeze(x, axis=0), got['Y']) + + def test_opset_unsqueeze(self): + for opv in range(10, max_supported_opset() + 1): + with self.subTest(opv=opv): + node = OnnxUnsqueezeApi11( + 'X', axes=numpy.array([0], dtype=numpy.int64), + op_version=opv, output_names=['Y']) + onx = node.to_onnx(numpy.float32, numpy.float32, + target_opset=opv) + oinf = OnnxInference(onx) + x = numpy.array([4, 5], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(x[numpy.newaxis, :], got['Y']) + + def test_opset_reshape(self): + for opv in range(10, max_supported_opset() + 1): + with self.subTest(opv=opv): + node = OnnxReshapeApi13( + 'X', numpy.array([2, 1, 1], dtype=numpy.int64), + op_version=opv, output_names=['Y']) + onx = node.to_onnx(numpy.float32, numpy.float32, + target_opset=opv) + oinf = OnnxInference(onx) + x = numpy.array([4, 5], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray( + x[:, numpy.newaxis, numpy.newaxis], got['Y']) + + def test_opset_reduce_l2_typed(self): + for dtype in [numpy.float32, numpy.float64]: + for opv in range(10, max_supported_opset() + 1): + with self.subTest(opv=opv, dtype=dtype): + node = OnnxReduceL2_typed( + dtype, 'X', numpy.array([1], dtype=numpy.int64), + op_version=opv, output_names=['Y']) + onx = node.to_onnx(numpy.float32, numpy.float32, + target_opset=opv) + oinf = OnnxInference(onx) + x = numpy.array([[4, 5], [6.7, 7.8]], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray( + (x ** 2).sum(axis=1, keepdims=1) ** 0.5, got['Y']) + + def test_opset_split(self): + OnnxSub = loadop("Sub") + for dtype in [numpy.float32, numpy.float64]: + for opv in range(10, max_supported_opset() + 1): + with self.subTest(opv=opv, dtype=dtype): + node_split = OnnxSplitApi11( + 'X', split=numpy.array([1, 1], dtype=numpy.int64), + axis=1, op_version=opv) + node1 = node_split[0] + node2 = node_split[1] + node = OnnxSub(node1, node2, op_version=opv, + output_names=['Y']) + onx = node.to_onnx(numpy.float32, numpy.float32, + target_opset=opv) + oinf = OnnxInference(onx, runtime='onnxruntime1') + x = numpy.array([[4, 5], [6.7, 7.8]], dtype=numpy.float32) + x_copy = x.copy() + expected = (x[:, :1] - x[:, 1:]).copy() + got = oinf.run({'X': x}) + self.assertEqualArray(expected, got['Y']) + self.assertEqualArray(x, x_copy) + oinf = OnnxInference(onx, runtime='python') + x = numpy.array([[4, 5], [6.7, 7.8]], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(expected, got['Y']) + # This not always hold, computation may happen in place. + # self.assertEqualArray(x, x_copy) + if __name__ == "__main__": - # TestXOps().test_onnx_add_op() + # TestXOps().test_if2() unittest.main(verbosity=2) diff --git a/mlprodict/npy/xop.py b/mlprodict/npy/xop.py index 14b752f2a..d7282d255 100644 --- a/mlprodict/npy/xop.py +++ b/mlprodict/npy/xop.py @@ -456,7 +456,7 @@ def __init__(self, onx_op, index, op_version=None): @property def inputs(self): "Returns the only inputs in a list." - inp = self.onx_op.inputs + inp = self.onx_op.output return [inp[self.index]] def add_to(self, builder): @@ -586,6 +586,7 @@ def __init__(self, *inputs, op_version=None, output_names=None, self.domain = domain self.kwargs = kwargs self.onnx_prefix_name = None + self.max_item_ = None # check inputs if len(inputs) == 0: @@ -752,13 +753,32 @@ def _post_process_attributes(self): if not isinstance(value, int): try: to = numpy_type_prototype(value) - except ValueError as e: + except ValueError as e: # pragma: no cover raise ValueError( "Unable to convert argument to in operator cast, " "type is %r, value is %r." % (type(value), value)) from e self.kwargs['to'] = to return + def update_max_item(self, index): + """ + Some operators return a undefined number of outputs. + The method is called when require one of them (with `__getitem__`) + and keeps the greater requested index assuming the node does + not output any result beyond that index. + + :param index: requested index + """ + if self.max_item_ is None: + self.max_item_ = index + else: + self.max_item_ = max(self.max_item_, index) + if self.expected_outputs is None: + self.expected_outputs = [] + while len(self.expected_outputs) <= self.max_item_: + self.expected_outputs.append( + (("NEWOUTPUT", len(self.expected_outputs)), None)) + def find_schema(self, op_version): """ Checks if there is an existing schema for a specific version. @@ -824,6 +844,7 @@ def __getitem__(self, index): Returns an accessor to one of the output of this node. """ + self.update_max_item(index) return OnnxOperatorItem(self, index, self.op_version) def __iter__(self): @@ -837,10 +858,15 @@ def __iter__(self): rg = self.output_range if rg[0] == rg[1] and rg[0] > 0: n = rg[0] + if n is None and self.max_item_ is not None: + n = self.max_item_ + 1 if n is None: raise RuntimeError( - "Unable to guess the number of outputs of node type %r." % + "Unable to guess the number of outputs of node type %r. " + "Uses operator [] to select a specific output." % self.__class__.__name__) + if self.max_item_ is not None: + n = max(n, self.max_item_ + 1) for i in range(n): yield self[i] @@ -852,9 +878,12 @@ def add_to(self, builder): it must have a method `add_node` """ inputs = builder.get_input_names(self, self.inputs) - n_outputs = ( - self.output_range[0] if self.output_names is None - else len(self.output_names)) + if self.output_names is not None: + n_outputs = len(self.output_names) + elif self.expected_outputs is not None: + n_outputs = len(self.expected_outputs) + else: + n_outputs = self.output_range[0] outputs = [builder.get_output_name(self, i) for i in range(n_outputs)] builder.add_node( self.operator_name, @@ -915,6 +944,38 @@ def _process_input(inputs, set_inputs, inp, new_inputs): "Unexpected input type %r in node type %r." % ( type(inp), type(obj))) + def _get_type(node, name=None, outputs=None): + if outputs is None: + return None + if isinstance(outputs, Variable): + if name is None: + return outputs.dtype + if isinstance(name, Variable): + return outputs.dtype or name.dtype + else: + raise RuntimeError( # pragma: no cover + "Unable to handle outputs=%r." % outputs) + if isinstance(outputs, dict): + if name is None: + raise RuntimeError( + "Unable to get type among %r, name=None." % ( + outputs, )) + if isinstance(name, Variable): + n = name.name + else: + n = name + if n not in outputs: + return None + return outputs[n] + if isinstance(outputs, list): + raise NotImplementedError( + "Unexpected type for name=%r, outputs=%r." % ( + name, outputs)) + if is_numpy_dtype(outputs): + return outputs + raise RuntimeError( # pragma: no cover + "Unable to handle outputs=%r." % outputs) + node_outputs = [self] if other_outputs is not None: node_outputs += other_outputs @@ -938,8 +999,11 @@ def _process_input(inputs, set_inputs, inp, new_inputs): memo.extend(stack) new_stack = [] for obj in stack: - for inp in obj.inputs: - _process_input(inputs, set_inputs, inp, new_inputs) + if isinstance(obj, OnnxOperatorItem): + pass + else: + for inp in obj.inputs: + _process_input(inputs, set_inputs, inp, new_inputs) stack = new_stack # eliminate duplicates @@ -951,38 +1015,6 @@ def _process_input(inputs, set_inputs, inp, new_inputs): done.add(id(node)) nodes.append(node) - def _get_type(node, name=None, outputs=None): - if outputs is None: - return None - if isinstance(outputs, Variable): - if name is None: - return outputs.dtype - if isinstance(name, Variable): - return outputs.dtype or name.dtype - else: - raise RuntimeError( # pragma: no cover - "Unable to handle outputs=%r." % outputs) - if isinstance(outputs, dict): - if name is None: - raise RuntimeError( - "Unable to get type among %r, name=None." % ( - outputs, )) - if isinstance(name, Variable): - n = name.name - else: - n = name - if n not in outputs: - return None - return outputs[n] - if isinstance(outputs, list): - raise NotImplementedError( - "Unexpected type for name=%r, outputs=%r." % ( - name, outputs)) - if is_numpy_dtype(outputs): - return outputs - raise RuntimeError( # pragma: no cover - "Unable to handle outputs=%r." % outputs) - # outputs set_names = set() new_outputs = [] @@ -1320,8 +1352,21 @@ def get_output_name(self, node, index): return name if node.output_names is None: - prefix = node.onnx_prefix - n = '%s%d' % (prefix, index) + if node.expected_outputs is None: + prefix = node.onnx_prefix + n = '%s%d' % (prefix, index) + else: + n = node.expected_outputs[index][0] + if isinstance(n, tuple): + if n[0] == 'NEWOUTPUT': + # This case happen for node with undefined number + # of outputs like Split. + prefix = node.onnx_prefix + n = '%s%d' % (prefix, index) + else: + raise RuntimeError( + "Unexpected value for node=%r and output=%r." % ( + node, n)) else: output = node.output_names[index] if isinstance(output, Variable): diff --git a/mlprodict/npy/xop_auto.py b/mlprodict/npy/xop_auto.py index 34c952af4..157a66d99 100644 --- a/mlprodict/npy/xop_auto.py +++ b/mlprodict/npy/xop_auto.py @@ -13,7 +13,7 @@ def _get_doc_template(): try: from jinja2 import Template - except ImportError: + except ImportError: # pragma no cover class Template: "Docstring template" @@ -134,13 +134,13 @@ def get_rst_doc(op_name=None): schemas = [ schema for schema in onnx.defs.get_all_schemas_with_history() if schema.name == op_name] - if len(schemas) > 1: + if len(schemas) > 1: # pragma: no cover raise RuntimeError( "Multiple operators have the same name '{}'.".format(op_name)) elif not isinstance(op_name, list): schemas = [op_name] if len(schemas) == 0: - raise ValueError( + raise ValueError( # pragma: no cover "Unable to find any operator with name '{}'.".format(op_name)) # from onnx.backend.sample.ops import collect_sample_implementations diff --git a/mlprodict/npy/xop_auto_import_.py b/mlprodict/npy/xop_auto_import_.py index ccc1b7703..935388042 100644 --- a/mlprodict/npy/xop_auto_import_.py +++ b/mlprodict/npy/xop_auto_import_.py @@ -17,7 +17,6 @@ def _update_module(): unique = set() for cl in res: setattr(this, cl.__name__, cl) - name = cl.__name__.split('_')[0] unique.add((cl.domain, cl.operator_name)) res = _dynamic_class_creation(list(unique)) for cl in res: diff --git a/mlprodict/npy/xop_opset.py b/mlprodict/npy/xop_opset.py index 63c5153e4..d8a25d734 100644 --- a/mlprodict/npy/xop_opset.py +++ b/mlprodict/npy/xop_opset.py @@ -6,6 +6,7 @@ .. versionadded:: 0.9 """ import numpy +from .xop import loadop def OnnxReduceSumApi11(*x, axes=None, keepdims=1, op_version=None, @@ -16,6 +17,7 @@ def OnnxReduceSumApi11(*x, axes=None, keepdims=1, op_version=None, if op_version is None: raise RuntimeError("op_version must be specified.") if op_version is None or op_version >= 13: + OnnxReduceSum = loadop('ReduceSum') if axes is None: return OnnxReduceSum( *x, keepdims=keepdims, op_version=op_version, @@ -25,6 +27,7 @@ def OnnxReduceSumApi11(*x, axes=None, keepdims=1, op_version=None, keepdims=keepdims, op_version=op_version, output_names=output_names) if op_version >= 11: + OnnxReduceSum_11 = loadop('ReduceSum_11') if axes is None: return OnnxReduceSum_11( *x, keepdims=keepdims, @@ -32,6 +35,7 @@ def OnnxReduceSumApi11(*x, axes=None, keepdims=1, op_version=None, return OnnxReduceSum_11( *x, axes=axes, keepdims=keepdims, op_version=op_version, output_names=output_names) + OnnxReduceSum_1 = loadop('ReduceSum_1') if axes is None: return OnnxReduceSum_1(*x, keepdims=keepdims, op_version=op_version, @@ -48,6 +52,7 @@ def OnnxSplitApi11(*x, axis=0, split=None, op_version=None, if op_version is None: raise RuntimeError("op_version must be specified.") if op_version is None or op_version >= 13: + OnnxSplit = loadop('Split') if split is None: return OnnxSplit( *x, axis=axis, op_version=op_version, @@ -56,6 +61,7 @@ def OnnxSplitApi11(*x, axis=0, split=None, op_version=None, *x, numpy.array(split, dtype=numpy.int64), axis=axis, op_version=op_version, output_names=output_names) if op_version >= 11: + OnnxSplit_11 = loadop('Split_11') if split is None: return OnnxSplit_11( *x, axis=axis, op_version=op_version, @@ -63,6 +69,7 @@ def OnnxSplitApi11(*x, axis=0, split=None, op_version=None, return OnnxSplit_11( *x, split=split, axis=axis, op_version=op_version, output_names=output_names) + OnnxSplit_2 = loadop('Split_2') if split is None: return OnnxSplit_2( *x, axis=axis, op_version=op_version, output_names=output_names) @@ -78,13 +85,16 @@ def OnnxSqueezeApi11(*x, axes=None, op_version=None, if op_version is None: raise RuntimeError("op_version must be specified.") if op_version is None or op_version >= 13: + OnnxSqueeze = loadop('Squeeze') return OnnxSqueeze( *x, numpy.array(axes, dtype=numpy.int64), op_version=op_version, output_names=output_names) if op_version >= 11: + OnnxSqueeze_11 = loadop('Squeeze_11') return OnnxSqueeze_11( *x, axes=axes, op_version=op_version, output_names=output_names) + OnnxSqueeze_1 = loadop('Squeeze_1') return OnnxSqueeze_1(*x, axes=axes, op_version=op_version, output_names=output_names) @@ -97,13 +107,16 @@ def OnnxUnsqueezeApi11(*x, axes=None, op_version=None, if op_version is None: raise RuntimeError("op_version must be specified.") if op_version is None or op_version >= 13: + OnnxUnsqueeze = loadop('Unsqueeze') return OnnxUnsqueeze( *x, numpy.array(axes, dtype=numpy.int64), op_version=op_version, output_names=output_names) if op_version >= 11: + OnnxUnsqueeze_11 = loadop('Unsqueeze_11') return OnnxUnsqueeze_11( *x, axes=axes, op_version=op_version, output_names=output_names) + OnnxUnsqueeze_1 = loadop('Unsqueeze_1') return OnnxUnsqueeze_1(*x, axes=axes, op_version=op_version, output_names=output_names) @@ -113,7 +126,9 @@ def OnnxReduceL2_typed(dtype, x, axes=None, keepdims=1, op_version=None, """ Adds operator ReduceL2 for float or double. """ + OnnxMul, OnnxSqrt = loadop('Mul', 'Sqrt') if dtype == numpy.float32: + OnnxReduceL2 = loadop('ReduceL2') return OnnxReduceL2( x, axes=axes, keepdims=keepdims, op_version=op_version, output_names=output_names) @@ -132,11 +147,14 @@ def OnnxReshapeApi13(*x, allowzero=0, op_version=None, if op_version is None: raise RuntimeError("op_version must be specified.") if op_version is None or op_version >= 14: + OnnxReshape = loadop('Reshape') return OnnxReshape( *x, allowzero=allowzero, op_version=op_version, output_names=output_names) if op_version >= 13: + OnnxReshape_13 = loadop('Reshape_13') return OnnxReshape_13( *x, op_version=op_version, output_names=output_names) + OnnxReshape_5 = loadop('Reshape_5') return OnnxReshape_5( *x, op_version=op_version, output_names=output_names) diff --git a/mlprodict/npy/xop_variable.py b/mlprodict/npy/xop_variable.py index f0e2845a5..8a0b528ff 100644 --- a/mlprodict/npy/xop_variable.py +++ b/mlprodict/npy/xop_variable.py @@ -153,7 +153,7 @@ def copy_add(self, dtype): :return: @see cl Variable """ if self.added_dtype_ is not None: - raise RuntimeError( + raise RuntimeError( # pragma: no cover "Cannot copy as added_dtype is not None.") if isinstance(dtype, numpy.ndarray): dtype, shape = dtype.dtype, dtype.shape From 3bfc2327af90b7c65b442454043970e3a135ecfa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Mon, 21 Feb 2022 22:20:41 +0100 Subject: [PATCH 4/6] Update test_shape_inference_xop.py --- _unittests/ut_onnxrt/test_shape_inference_xop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_unittests/ut_onnxrt/test_shape_inference_xop.py b/_unittests/ut_onnxrt/test_shape_inference_xop.py index f629c797b..3d6bde605 100644 --- a/_unittests/ut_onnxrt/test_shape_inference_xop.py +++ b/_unittests/ut_onnxrt/test_shape_inference_xop.py @@ -39,7 +39,7 @@ def check_infer_shapes(self, onx, out, rt): onnx_simple_text_plot(onx))) def test_onnx_shape_inference(self): - OnnxAdd = loadop('OnnxAdd') + OnnxAdd = loadop('Add') dtype = numpy.float32 for opset in TestOnnxShapeInferenceXop.opsets: with self.subTest(opset=opset): From 8392dc2239990116e1777a9e99d4268e94f32ff9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Tue, 22 Feb 2022 02:10:02 +0100 Subject: [PATCH 5/6] Xop API, adds class OnnxSubOnnx to insert an ONNX graph into another --- .../source/_exts/generate_onnx_ops.py | 8 +- _doc/sphinxdoc/source/api/ast.rst | 38 ++++ _doc/sphinxdoc/source/api/index.rst | 7 + _doc/sphinxdoc/source/api/xop.rst | 36 ++-- _doc/sphinxdoc/source/tutorial/index.rst | 21 ++- .../source/tutorial/numpy_api_onnx.rst | 28 +-- _doc/sphinxdoc/source/tutorial/onnx_numpy.rst | 16 +- .../tutorial/{onnx.rst => onnx_runtime.rst} | 38 +--- _doc/sphinxdoc/source/tutorial/skl.rst | 33 ++++ _doc/sphinxdoc/source/tutorial/xop_api.rst | 4 + _unittests/ut_npy/test_xop_convert.py | 59 ++++++ mlprodict/npy/xop.py | 45 ++++- mlprodict/npy/xop_auto.py | 3 +- mlprodict/npy/xop_auto_import_.py | 2 +- mlprodict/npy/xop_convert.py | 169 ++++++++++++++++++ mlprodict/npy/xop_opset.py | 2 +- mlprodict/npy/xop_variable.py | 2 +- .../onnx_grammar/onnx_translation.py | 11 +- mlprodict/onnxrt/doc/doc_helper.py | 2 +- 19 files changed, 431 insertions(+), 93 deletions(-) create mode 100644 _doc/sphinxdoc/source/api/ast.rst rename _doc/sphinxdoc/source/tutorial/{onnx.rst => onnx_runtime.rst} (82%) create mode 100644 _doc/sphinxdoc/source/tutorial/skl.rst create mode 100644 _doc/sphinxdoc/source/tutorial/xop_api.rst create mode 100644 _unittests/ut_npy/test_xop_convert.py create mode 100644 mlprodict/npy/xop_convert.py diff --git a/_doc/sphinxdoc/source/_exts/generate_onnx_ops.py b/_doc/sphinxdoc/source/_exts/generate_onnx_ops.py index db68cbe8e..3c645a645 100644 --- a/_doc/sphinxdoc/source/_exts/generate_onnx_ops.py +++ b/_doc/sphinxdoc/source/_exts/generate_onnx_ops.py @@ -22,7 +22,7 @@ class SupportedOnnxOpsDirective(Directive): has_content = False def run(self): - cls = _dynamic_class_creation() + cls = _dynamic_class_creation(include_past=True) cls_name = [(c.__name__, c) for c in cls] rows = [] sorted_cls_name = list(sorted(cls_name)) @@ -32,8 +32,8 @@ def make_ref(cl): return ":ref:`l-xop-onnx-{}`".format(cl.__name__) table = [] - cut = len(sorted_cls_name) // 3 + \ - (1 if len(sorted_cls_name) % 3 else 0) + cut = (len(sorted_cls_name) // 3 + + (1 if len(sorted_cls_name) % 3 else 0)) for i in range(cut): row = [] row.append(make_ref(sorted_cls_name[i][1])) @@ -56,9 +56,9 @@ def make_ref(cl): nested_parse_with_titles(self.state, st, node) main += node - rows.append('') for name, cl in sorted_cls_name: rows = [] + rows.append('') rows.append('.. _l-xop-onnx-{}:'.format(cl.__name__)) rows.append('') rows.append(cl.__name__) diff --git a/_doc/sphinxdoc/source/api/ast.rst b/_doc/sphinxdoc/source/api/ast.rst new file mode 100644 index 000000000..3e75f4eef --- /dev/null +++ b/_doc/sphinxdoc/source/api/ast.rst @@ -0,0 +1,38 @@ + +=== +AST +=== + +.. contents:: + :local: + +Main functions +============== + +.. autosignature:: mlprodict.onnx_tools.onnx_translation.translate_fct2onnx + +Additional functions +==================== + +.. autosignature:: mlprodict.onnx_tools.onnx_translation.get_default_context + +.. autosignature:: mlprodict.onnx_tools.onnx_translation.get_default_context_cpl + +.. autosignature:: mlprodict.onnx_tools.onnx_translation.py_make_float_array + +.. autosignature:: mlprodict.onnx_tools.onnx_translation.py_opp + +.. autosignature:: mlprodict.onnx_tools.onnx_translation.py_mul + +.. autosignature:: mlprodict.onnx_tools.onnx_translation.py_pow + +.. autosignature:: mlprodict.onnx_tools.onnx_translation.squareform_pdist + +Grammar Objects +=============== + +.. autosignature:: mlprodict.onnx_tools.onnx_grammar.node_visitor_translator.CodeNodeVisitor + +.. autosignature:: mlprodict.onnx_tools.onnx_grammar.onnx_translator.CodeTranslator + +.. autosignature:: mlprodict.onnx_tools.onnx_grammar.onnx_translator.OnnxTranslator diff --git a/_doc/sphinxdoc/source/api/index.rst b/_doc/sphinxdoc/source/api/index.rst index 28d161bc6..05db1a967 100644 --- a/_doc/sphinxdoc/source/api/index.rst +++ b/_doc/sphinxdoc/source/api/index.rst @@ -11,8 +11,15 @@ This is a summary of functions this modules provides. onnx_conv sklapi + +**Write ONNX graphs** + +.. toctree:: + :maxdepth: 1 + npy xop + ast **ONNX runtime** diff --git a/_doc/sphinxdoc/source/api/xop.rst b/_doc/sphinxdoc/source/api/xop.rst index d650f4bf3..8ffa37954 100644 --- a/_doc/sphinxdoc/source/api/xop.rst +++ b/_doc/sphinxdoc/source/api/xop.rst @@ -1,32 +1,40 @@ .. _l-xop-onnxpy: -Create ONNX graphs -================== +======= +Xop API +======= .. contents:: :local: -Example -+++++++ - -Converters -++++++++++ - API -+++ +=== + +Automated gathering of operators +++++++++++++++++++++++++++++++++ .. autosignature:: mlprodict.npy.xop.ClassFactory .. autosignature:: mlprodict.npy.xop.dynamic_class_creation -.. autosignature:: mlprodict.npy.xops_variable.Variable +.. autosignature:: mlprodict.npy.xop._GraphBuilder + +Main classes +++++++++++++ + +.. autosignature:: mlprodict.npy.xop_variable.Variable + +.. autosignature:: mlprodict.npy.xop.OnnxOperator + +.. autosignature:: mlprodict.npy.xop.OnnxOperatorItem -.. autosignature:: mlprodict.npy.xop_ops._GraphBuilder +.. autosignature:: mlprodict.npy.xop_convert.OnnxSubOnnx -.. autosignature:: mlprodict.npy.xop_ops.OnnxOperator +.. autosignature:: mlprodict.npy.xop_convert.OnnxSubEstimator -.. autosignature:: mlprodict.npy.xop_ops.OnnxOperatorItem +Helpers to handle API changing with opsets +++++++++++++++++++++++++++++++++++++++++++ .. autosignature:: mlprodict.npy.xop_opset.OnnxReduceSumApi11 @@ -41,7 +49,7 @@ API .. autosignature:: mlprodict.npy.xop_opset.OnnxReshapeApi13 Available ONNX operators -++++++++++++++++++++++++ +======================== .. toctree:: diff --git a/_doc/sphinxdoc/source/tutorial/index.rst b/_doc/sphinxdoc/source/tutorial/index.rst index b0336d008..afe8bec14 100644 --- a/_doc/sphinxdoc/source/tutorial/index.rst +++ b/_doc/sphinxdoc/source/tutorial/index.rst @@ -5,11 +5,26 @@ Tutorial The only tutorial is about :epkg:`ONNX` and only one piece this module can do. More should follow. +.. contents:: + :local: + +Run inference ++++++++++++++ + .. toctree:: :maxdepth: 1 - onnx - onnx_numpy - numpy_api_onnx + skl + onnx_runtime optim benchmark + +Write custom ONNX graph ++++++++++++++++++++++++ + +.. toctree:: + :maxdepth: 1 + + onnx_numpy + numpy_api_onnx + xop_api diff --git a/_doc/sphinxdoc/source/tutorial/numpy_api_onnx.rst b/_doc/sphinxdoc/source/tutorial/numpy_api_onnx.rst index 9d428e8d8..b2a73bb2d 100644 --- a/_doc/sphinxdoc/source/tutorial/numpy_api_onnx.rst +++ b/_doc/sphinxdoc/source/tutorial/numpy_api_onnx.rst @@ -48,7 +48,7 @@ Following example shows how to replace *numpy* by *ONNX*. .. runpython:: :showcode: - :warningout: DeprecationWarning + :warningout: DeprecationWarning, FutureWarning :process: from typing import Any @@ -187,7 +187,7 @@ One instance is added in a pipeline trained on the Iris dataset. .. runpython:: :showcode: - :warningout: DeprecationWarning + :warningout: DeprecationWarning, FutureWarning :process: from typing import Any @@ -342,7 +342,7 @@ is used. Let's see how to do it. .. runpython:: :showcode: - :warningout: DeprecationWarning + :warningout: DeprecationWarning, FutureWarning import numpy from pandas import DataFrame @@ -455,7 +455,7 @@ the class is a transformer and automatically adds method .. runpython:: :showcode: - :warningout: DeprecationWarning + :warningout: DeprecationWarning, FutureWarning import numpy from pandas import DataFrame @@ -517,7 +517,7 @@ with arguments :class:`onnxnumpy_np .. runpython:: :showcode: - :warningout: DeprecationWarning + :warningout: DeprecationWarning, FutureWarning :process: from typing import Any @@ -565,7 +565,7 @@ as an argument of `to_onnx`. .. runpython:: :showcode: - :warningout: DeprecationWarning + :warningout: DeprecationWarning, FutureWarning :process: from typing import Any @@ -621,7 +621,7 @@ another operator. .. runpython:: :showcode: - :warningout: DeprecationWarning + :warningout: DeprecationWarning, FutureWarning :process: import numpy as np @@ -713,7 +713,7 @@ the conversion to ONNX :meth:`to_algebra .. runpython:: :showcode: - :warningout: DeprecationWarning + :warningout: DeprecationWarning, FutureWarning :process: from typing import Any @@ -763,7 +763,7 @@ types. If types are different, one must be cast into the other one. .. runpython:: :showcode: :exception: - :warningout: DeprecationWarning + :warningout: DeprecationWarning, FutureWarning :process: from typing import Any @@ -790,7 +790,7 @@ except one. .. runpython:: :showcode: :exception: - :warningout: DeprecationWarning + :warningout: DeprecationWarning, FutureWarning :process: from typing import Any @@ -841,7 +841,7 @@ a new one supporting custom functions implemented this API. .. runpython:: :showcode: :exception: - :warningout: DeprecationWarning + :warningout: DeprecationWarning, FutureWarning :process: from typing import Any @@ -893,7 +893,7 @@ does. However it produces the following error. .. runpython:: :showcode: :exception: - :warningout: DeprecationWarning + :warningout: DeprecationWarning, FutureWarning :process: import numpy @@ -947,7 +947,7 @@ in class @see cl OnnxVar. .. runpython:: :showcode: - :warningout: DeprecationWarning + :warningout: DeprecationWarning, FutureWarning :process: from typing import Any @@ -995,7 +995,7 @@ is called. .. runpython:: :showcode: - :warningout: DeprecationWarning + :warningout: DeprecationWarning, FutureWarning :process: from typing import Any diff --git a/_doc/sphinxdoc/source/tutorial/onnx_numpy.rst b/_doc/sphinxdoc/source/tutorial/onnx_numpy.rst index b540bdc7d..56fe17a68 100644 --- a/_doc/sphinxdoc/source/tutorial/onnx_numpy.rst +++ b/_doc/sphinxdoc/source/tutorial/onnx_numpy.rst @@ -1,8 +1,8 @@ .. _l-numpy2onnx-tutorial: -From numpy to ONNX -================== +Create custom ONNX graphs +========================= Converting a :epkg:`scikit-learn` pipeline is easy when the pipeline contains only pieces implemented in :epkg:`scikit-learn` @@ -25,7 +25,7 @@ the first examples of `sklearn-onnx tutorial`. .. runpython:: :showcode: - :warningout: DeprecationWarning + :warningout: DeprecationWarning, FutureWarning import numpy from sklearn.pipeline import make_pipeline @@ -55,8 +55,8 @@ into *ONNX*. Even if function :epkg:`numpy:log` does exist in ONNX specification this problem is equivalent to a translation from a language, Python, to another one, ONNX. -Translating numpy to ONNX -+++++++++++++++++++++++++ +Translating numpy to ONNX with AST +++++++++++++++++++++++++++++++++++ .. index:: algebric function @@ -81,7 +81,7 @@ produces the :epkg:`ONNX` graph. .. runpython:: :showcode: - :warningout: DeprecationWarning + :warningout: DeprecationWarning, FutureWarning :process: :store_in_file: fct2onnx_expsine.py @@ -95,7 +95,7 @@ produces the :epkg:`ONNX` graph. # The function to convert into ONNX. def kernel_call_ynone(X, length_scale=1.2, periodicity=1.1, - pi=3.141592653589793): + pi=3.141592653589793, op_version=15): # squareform(pdist(X, ...)) in one function. dists = squareform_pdist(X, metric='euclidean') @@ -140,7 +140,7 @@ produces the :epkg:`ONNX` graph. # Calls the ONNX algebric function to produce the ONNX graph. inputs = {'X': x.astype(numpy.float32)} - onnx_g = onnx_model.to_onnx(inputs, target_opset=12) + onnx_g = onnx_model.to_onnx(inputs, target_opset=15) # Creates a python runtime associated to the ONNX function. oinf = OnnxInference(onnx_g) diff --git a/_doc/sphinxdoc/source/tutorial/onnx.rst b/_doc/sphinxdoc/source/tutorial/onnx_runtime.rst similarity index 82% rename from _doc/sphinxdoc/source/tutorial/onnx.rst rename to _doc/sphinxdoc/source/tutorial/onnx_runtime.rst index f16d4a6dc..b77bfeed7 100644 --- a/_doc/sphinxdoc/source/tutorial/onnx.rst +++ b/_doc/sphinxdoc/source/tutorial/onnx_runtime.rst @@ -1,8 +1,8 @@ .. _l-onnx-tutorial: -ONNX and Python Runtime -======================= +Execute ONNX graphs +=================== This package implements a python runtime for ONNX in class :class:`OnnxInference `. @@ -184,37 +184,3 @@ As a consequence, interdiate results cannot be seen anymore. oinf = OnnxInference(model_def, runtime='python_compiled') print(oinf.run({'X': X_test[:5]})) - -From scikit-learn to ONNX -+++++++++++++++++++++++++ - -Function `skl2onnx.to_onnx `_ is the -main entrypoint to convert a *scikit-learn* pipeline into ONNX. -The same function was extended in this package into -:func:`to_onnx ` to handle -dataframes, an extended list of supported converters, scorers. -It works exactly the same: - -.. runpython:: - :showcode: - :warningout: DeprecationWarning - - import numpy - from sklearn.datasets import load_iris - from sklearn.model_selection import train_test_split - from sklearn.cluster import KMeans - from mlprodict.onnx_conv import to_onnx - from mlprodict.onnxrt import OnnxInference - - iris = load_iris() - X = iris.data.astype(numpy.float32) - X_train, X_test = train_test_split(X) - clr = KMeans(n_clusters=3) - clr.fit(X_train) - - model_def = to_onnx(clr, X_train.astype(numpy.float32), - target_opset=12) - - oinf = OnnxInference(model_def, runtime='python') - print(oinf.run({'X': X_test[:5]})) diff --git a/_doc/sphinxdoc/source/tutorial/skl.rst b/_doc/sphinxdoc/source/tutorial/skl.rst new file mode 100644 index 000000000..912a4ae8f --- /dev/null +++ b/_doc/sphinxdoc/source/tutorial/skl.rst @@ -0,0 +1,33 @@ +From scikit-learn to ONNX +========================= + +Function `skl2onnx.to_onnx `_ is the +main entrypoint to convert a *scikit-learn* pipeline into ONNX. +The same function was extended in this package into +:func:`to_onnx ` to handle +dataframes, an extended list of supported converters, scorers. +It works exactly the same: + +.. runpython:: + :showcode: + :warningout: DeprecationWarning + + import numpy + from sklearn.datasets import load_iris + from sklearn.model_selection import train_test_split + from sklearn.cluster import KMeans + from mlprodict.onnx_conv import to_onnx + from mlprodict.onnxrt import OnnxInference + + iris = load_iris() + X = iris.data.astype(numpy.float32) + X_train, X_test = train_test_split(X) + clr = KMeans(n_clusters=3) + clr.fit(X_train) + + model_def = to_onnx(clr, X_train.astype(numpy.float32), + target_opset=12) + + oinf = OnnxInference(model_def, runtime='python') + print(oinf.run({'X': X_test[:5]})) diff --git a/_doc/sphinxdoc/source/tutorial/xop_api.rst b/_doc/sphinxdoc/source/tutorial/xop_api.rst new file mode 100644 index 000000000..a299f3edf --- /dev/null +++ b/_doc/sphinxdoc/source/tutorial/xop_api.rst @@ -0,0 +1,4 @@ +Xop API +======= + +*to be completed* diff --git a/_unittests/ut_npy/test_xop_convert.py b/_unittests/ut_npy/test_xop_convert.py new file mode 100644 index 000000000..a422bf354 --- /dev/null +++ b/_unittests/ut_npy/test_xop_convert.py @@ -0,0 +1,59 @@ +# pylint: disable=E0611 +""" +@brief test log(time=15s) +""" +import unittest +import numpy +from pyquickhelper.pycode import ExtTestCase +from mlprodict.onnxrt import OnnxInference +from mlprodict.npy.xop import loadop +from mlprodict.npy.xop_convert import OnnxSubOnnx + + +class TestXOpsConvert(ExtTestCase): + + def test_onnx_abs(self): + OnnxAbs = loadop("Abs") + ov = OnnxAbs('X', output_names=['Y']) + onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=0) + + sub = OnnxSubOnnx(onx, 'X', output_names=['Y']) + onx = sub.to_onnx(numpy.float32, numpy.float32, verbose=0) + + oinf = OnnxInference(onx) + x = numpy.array([-2, 2], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(numpy.abs(x), got['Y']) + + def test_onnx_add(self): + OnnxAdd = loadop("Add") + ov = OnnxAdd('X', numpy.array([2], dtype=numpy.float32), + output_names=['Y']) + onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=0) + + sub = OnnxSubOnnx(onx, 'X', output_names=['Y']) + onx = sub.to_onnx(numpy.float32, numpy.float32, verbose=0) + + oinf = OnnxInference(onx) + x = numpy.array([-2, 2], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(x + 2, got['Y']) + + def test_onnx_cast(self): + OnnxCast = loadop("Cast") + ov = OnnxCast('X', to=numpy.int64, output_names=['Y']) + onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=0) + + sub = OnnxSubOnnx(onx, 'X', output_names=['Y']) + onx = sub.to_onnx(numpy.float32, numpy.int64, verbose=0) + r = repr(sub) + self.assertStartsWith('OnnxSubOnnx(..., output_name', r) + + oinf = OnnxInference(onx) + x = numpy.array([-2.4, 2.4], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(x.astype(numpy.int64), got['Y']) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/mlprodict/npy/xop.py b/mlprodict/npy/xop.py index d7282d255..5e4e8b19c 100644 --- a/mlprodict/npy/xop.py +++ b/mlprodict/npy/xop.py @@ -1,7 +1,7 @@ # pylint: disable=E1101,C0302 """ @file -@brief Easier API to build onnx graphs. Inspired from :epkg:`skl2onnx`. +@brief Xop API to build onnx graphs. Inspired from :epkg:`skl2onnx`. .. versionadded:: 0.9 """ @@ -14,7 +14,7 @@ from onnx.helper import ( make_node, make_graph, make_model, make_tensor_value_info) -from onnx.numpy_helper import from_array +from onnx.numpy_helper import from_array, to_array from onnx.shape_inference import infer_shapes from ._cache import cache_folder from .xop_variable import ( @@ -40,6 +40,18 @@ def _default_OPSET_TO_IR_VERSION(): def _domain_to_class_name(domain): + """ + Converts domain into a name. + + :param domain: domain name such as `ai.onnx.ml` + :return: string + + .. runpython:: + :showcode: + + from mlprodict.npy.xop import _domain_to_class_name + print(_domain_to_class_name('ai.onnx.ml')) + """ if domain == 'ai.onnx': return '' dom = domain.split('.') @@ -229,7 +241,8 @@ def __init__(self, *args, **kwargs): return newclass -def _dynamic_class_creation(operator_names=None, cache=False, verbose=0, fLOG=print): +def _dynamic_class_creation(operator_names=None, cache=False, include_past=False, + verbose=0, fLOG=print): """ Automatically generates classes for each of the operators module *onnx* defines and described at @@ -242,6 +255,7 @@ def _dynamic_class_creation(operator_names=None, cache=False, verbose=0, fLOG=pr :param operator_names: list of operators to request or None for all :param cache: extract the documentation from onnx package and saves it on disk it True + :param include_past: includes past versions if operator_names is None :param verbose: display some progress :param fLOG: logging function :return: list of requested operators as a tuple @@ -254,6 +268,14 @@ def _c(obj, label, i): cache_dir = cache_folder() if operator_names is None: operator_names = list(_all_schemas_versions) + if include_past: + add = [] + for domain, op in operator_names: + add.extend( + [(domain, k) + for k in _all_schemas_versions[domain, op]]) + operator_names.extend(add) + operator_names.sort() # type verification ops = [] @@ -284,7 +306,7 @@ def _c(obj, label, i): if op_domain == 'ai.onnx': op_domain = '' set_names[op_domain, op_name] = pos - if '_' in op_name: + if '_' in op_name and not include_past: n = op_name.split('_')[0] set_skip.add((op_domain, n)) if n not in set_names: @@ -1424,6 +1446,19 @@ def get_input_names(self, node, inputs): "Unexpected type for an input %r." % type(i)) return names + def add_initializer(self, name, init): + """ + Adds an initializer to the graph. + + :param name: initializer name + :param init: initializer to copy + :return: created intializer + """ + value = to_array(init) + val = from_array(value, name) + self.initializer.append(val) + return val + def add_node(self, op_type, name, inputs, outputs, domain='', opset=None, **attributes): """ @@ -1435,6 +1470,7 @@ def add_node(self, op_type, name, inputs, outputs, domain='', :param outputs: outputs name list :param domain: node domain :param opset: node opset + :return: created node """ if not isinstance(inputs, list): raise TypeError( # pragma: no cover @@ -1456,6 +1492,7 @@ def add_node(self, op_type, name, inputs, outputs, domain='', node = make_node(op_type, inputs, outputs, name=name, domain=domain, **attributes) self.node.append(node) + return node def _process_io(self, inputs, input_names): if inputs is None: diff --git a/mlprodict/npy/xop_auto.py b/mlprodict/npy/xop_auto.py index 157a66d99..427bf4f76 100644 --- a/mlprodict/npy/xop_auto.py +++ b/mlprodict/npy/xop_auto.py @@ -1,6 +1,7 @@ """ @file -@brief Automates the generation of the documentation. +@brief Automates the generation of operators for the +documentation for the Xop API. .. versionadded:: 0.9 """ diff --git a/mlprodict/npy/xop_auto_import_.py b/mlprodict/npy/xop_auto_import_.py index 935388042..a6d82d076 100644 --- a/mlprodict/npy/xop_auto_import_.py +++ b/mlprodict/npy/xop_auto_import_.py @@ -1,6 +1,6 @@ """ @file -@brief Importing this file takes time. It should be avoided. +@brief Xop API. Importing this file takes time. It should be avoided. .. versionadded:: 0.9 """ diff --git a/mlprodict/npy/xop_convert.py b/mlprodict/npy/xop_convert.py new file mode 100644 index 000000000..8ede2e55c --- /dev/null +++ b/mlprodict/npy/xop_convert.py @@ -0,0 +1,169 @@ +""" +@file +@brief Easier API to build onnx graphs. Inspired from :epkg:`skl2onnx`. + +.. versionadded:: 0.9 +""" +from .xop import OnnxOperator + + +class OnnxSubOnnx(OnnxOperator): + """ + This operator is used to insert existing ONNX into + the ONNX graph being built. + """ + + domain = 'mlprodict' + since_version = 1 + expected_inputs = None + expected_outputs = None + input_range = [1, 1e9] + output_range = [1, 1e9] + + def __init__(self, model, *inputs, output_names=None): + if model is None: + raise ValueError("Model cannot be None.") + if len(inputs) > len(model.graph.input): + raise RuntimeError( + "Unexpected number of inputs %r > expected %r." % ( + len(inputs), len(model.graph.input))) + if (output_names is not None and + len(output_names) != len(model.graph.output)): + raise RuntimeError( + "Unexpected number of outputs %r != expected %r." % ( + len(output_names), len(model.graph.output))) + OnnxOperator.__init__(self, *inputs, output_names=output_names) + self.model = model + + def __repr__(self): + "usual" + atts = {} + for att in ['output_names']: + value = getattr(self, att, None) + if value is not None: + atts[att] = value + atts.update(self.kwargs) + msg = ", ".join("%s=%r" % (k, v) for k, v in atts.items()) + if len(atts) > 0: + msg = ", " + msg + return "%s(...%s)" % ( + self.__class__.__name__, msg) + + def add_to(self, builder): + """ + Adds to graph builder. + + :param builder: instance of @see cl _GraphBuilder, + it must have a method `add_node` + """ + inputs = builder.get_input_names(self, self.inputs) + n_outputs = len(self.model.graph.output) + outputs = [builder.get_output_name(self, i) for i in range(n_outputs)] + + mapped_names = {} + + # adding initializers + for init in self.model.graph.initializer: + new_name = builder.get_unique_name(init.name) + mapped_names[init.name] = new_name + builder.add_initializer(new_name, init) + + # linking inputs + for name in inputs: + new_name = builder.get_unique_name(name) + mapped_names[name] = new_name + builder.add_node( + 'Identity', builder.get_unique_name('_sub_' + name), + [name], [new_name]) + + # adding nodes + for node in self.model.graph.node: + new_inputs = [mapped_names[i] for i in node.input] + new_outputs = [] + for o in node.output: + new_name = builder.get_unique_name(o) + mapped_names[o] = new_name + new_outputs.append(new_name) + + atts = {} + for att in node.attribute: + if att.type == 2: + value = att.i + atts[att.name] = value + continue + raise NotImplementedError( + "Unable to copy attribute type %r (%r)." % ( + att.type, att)) + + builder.add_node( + node.op_type, + builder.get_unique_name('_sub_' + node.name), + new_inputs, new_outputs, domain=node.domain, **atts) + + # linking outputs + for out, name in zip(self.model.graph.output, outputs): + builder.add_node( + 'Identity', builder.get_unique_name('_sub_' + out.name), + [mapped_names[out.name]], [name]) + + +class OnnxSubEstimator(OnnxSubOnnx): + """ + This operator is used to call the converter of a model + to insert the node coming from the conversion into a + bigger ONNX graph. It supports model from :epkg:`scikit-learn` + using :epkg:`sklearn-onnx`. + + :param model: model to convert + :param inputs: inputs + :param op_version: targetted opset + :param options: to rewrite the options used to convert the model + :param input_types: the implementation may be wrong in guessing + the input types of the model, this parameter can be used + to overwrite them, usually a dictionary + `{ input_name: numpy array as an example }` + :param kwargs: any other parameters such as black listed or + white listed operators + """ + + since_version = 1 + expected_inputs = None + expected_outputs = None + input_range = [1, 1e9] + output_range = [1, 1e9] + + def __init__(self, model, *inputs, op_version=None, + output_names=None, options=None, + input_types=None, **kwargs): + if model is None: + raise ValueError("Model cannot be None.") + OnnxSubOnnx.__init__( + self, *inputs, op_version=op_version, + output_names=output_names, **kwargs) + self.model = model + self.options = options + self.input_types = input_types + + def __repr__(self): + "usual" + atts = {} + for att in ['op_version', 'output_names', 'options', + 'input_types']: + value = getattr(self, att, None) + if value is not None: + atts[att] = value + atts.update(self.kwargs) + msg = ", ".join("%s=%r" for k, v in atts.items()) + if len(atts) > 0: + msg += ", " + return "%s(%r%s)" % ( + self.__class__.__name__, self.model, msg) + + def add_to(self, builder): + """ + Adds to graph builder. + + :param builder: instance of @see cl _GraphBuilder, + it must have a method `add_node` + """ + raise NotImplementedError() diff --git a/mlprodict/npy/xop_opset.py b/mlprodict/npy/xop_opset.py index d8a25d734..d5f790c39 100644 --- a/mlprodict/npy/xop_opset.py +++ b/mlprodict/npy/xop_opset.py @@ -1,7 +1,7 @@ # pylint: disable=E0602 """ @file -@brief Easier API to build onnx graphs. Inspired from :epkg:`skl2onnx`. +@brief Xop API to build onnx graphs. Inspired from :epkg:`skl2onnx`. .. versionadded:: 0.9 """ diff --git a/mlprodict/npy/xop_variable.py b/mlprodict/npy/xop_variable.py index 8a0b528ff..cc0cbace5 100644 --- a/mlprodict/npy/xop_variable.py +++ b/mlprodict/npy/xop_variable.py @@ -1,6 +1,6 @@ """ @file -@brief Easier API to build onnx graphs. Inspired from :epkg:`skl2onnx`. +@brief Xop API to build onnx graphs. Inspired from :epkg:`skl2onnx`. .. versionadded:: 0.9 """ diff --git a/mlprodict/onnx_tools/onnx_grammar/onnx_translation.py b/mlprodict/onnx_tools/onnx_grammar/onnx_translation.py index 8d7e2df1b..c33dc8dff 100644 --- a/mlprodict/onnx_tools/onnx_grammar/onnx_translation.py +++ b/mlprodict/onnx_tools/onnx_grammar/onnx_translation.py @@ -197,10 +197,10 @@ def trs(x, y): import numpy from mlprodict.onnx_tools.onnx_grammar import translate_fct2onnx + from mlprodict.plotting.text_plot import onnx_simple_text_plot from mlprodict.onnxrt import OnnxInference from skl2onnx.algebra.onnx_ops import ( - OnnxAdd, OnnxTranspose, OnnxMul, OnnxIdentity - ) + OnnxAdd, OnnxTranspose, OnnxMul, OnnxIdentity) ctx = {'OnnxAdd': OnnxAdd, 'OnnxTranspose': OnnxTranspose, @@ -222,16 +222,17 @@ def trs(x, y): trs, context={'numpy.transpose': numpy.transpose}, cpl=True, context_cpl=ctx, output_names=['Z']) - onnx_code = onnx_fct('x', 'y', opset_version=12) - print('ONNX code:', onnx_code) + onnx_code = onnx_fct('x', 'y', op_version=12) onnx_g = onnx_code.to_onnx(inputs, target_opset=12) + print("ONNX model") + print(onnx_simple_text_plot(onnx_g)) oinf = OnnxInference(onnx_g) res = oinf.run(inputs) + print('-----------') print("ONNX inference:", res['Z']) - print("ONNX graph:", onnx_g) The function to be converted may include python functions which must not be converted. In that case, their name diff --git a/mlprodict/onnxrt/doc/doc_helper.py b/mlprodict/onnxrt/doc/doc_helper.py index ffae0816a..4a5a30753 100644 --- a/mlprodict/onnxrt/doc/doc_helper.py +++ b/mlprodict/onnxrt/doc/doc_helper.py @@ -352,7 +352,7 @@ def visual_rst_template(): Fitted on a problem type *{{ kind }}* (see :func:`find_suitable_problem `), - method {{ method }} matches output {{ output_index }}. + method `{{ method }}` matches output {{ output_index }}. {{ optim_param }} :: From 688cfbc156a006bf158e0a450a5ff97553a7e872 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Tue, 22 Feb 2022 11:42:59 +0100 Subject: [PATCH 6/6] finalize to_onnx --- _unittests/ut_npy/test_xop_convert.py | 44 ++++++++- mlprodict/npy/xop_convert.py | 126 +++++++++++++++++++++----- 2 files changed, 148 insertions(+), 22 deletions(-) diff --git a/_unittests/ut_npy/test_xop_convert.py b/_unittests/ut_npy/test_xop_convert.py index a422bf354..447479d7f 100644 --- a/_unittests/ut_npy/test_xop_convert.py +++ b/_unittests/ut_npy/test_xop_convert.py @@ -5,9 +5,12 @@ import unittest import numpy from pyquickhelper.pycode import ExtTestCase +from sklearn.datasets import make_regression +from sklearn.linear_model import LinearRegression from mlprodict.onnxrt import OnnxInference from mlprodict.npy.xop import loadop -from mlprodict.npy.xop_convert import OnnxSubOnnx +from mlprodict.npy.xop_convert import OnnxSubOnnx, OnnxSubEstimator +from mlprodict.npy.xop_variable import max_supported_opset class TestXOpsConvert(ExtTestCase): @@ -54,6 +57,45 @@ def test_onnx_cast(self): got = oinf.run({'X': x}) self.assertEqualArray(x.astype(numpy.int64), got['Y']) + def test_onnx_lr(self): + X, y = make_regression(n_features=2) # pylint: disable=W0632 + lr = LinearRegression() + lr.fit(X, y) + X32 = X.astype(numpy.float32) + + OnnxIdentity, OnnxReshape = loadop("Identity", "Reshape") + ov = OnnxIdentity('X') + self.assertRaise(lambda: OnnxSubEstimator(lr, ov), NotImplementedError) + sub = OnnxSubEstimator( + lr, ov, op_version=max_supported_opset(), + initial_types=X32[:1]) + r = repr(sub) + self.assertStartsWith('OnnxSubEstimator(LinearRegression()', r) + last = OnnxReshape(sub, numpy.array([-1], dtype=numpy.int64), + output_names=['Y']) + onx = last.to_onnx(numpy.float32, numpy.float32, verbose=0) + + oinf = OnnxInference(onx) + got = oinf.run({'X': X32}) + expected = lr.predict(X32) + self.assertEqualArray(expected, got['Y'], decimal=4) + + def test_onnx_lr_only(self): + X, y = make_regression(n_features=2) # pylint: disable=W0632 + lr = LinearRegression() + lr.fit(X, y) + X32 = X.astype(numpy.float32) + + last = OnnxSubEstimator( + lr, 'X', op_version=max_supported_opset(), + initial_types=X32[:1], output_names=['Y']) + onx = last.to_onnx(numpy.float32, numpy.float32, verbose=0) + + oinf = OnnxInference(onx) + got = oinf.run({'X': X32}) + expected = lr.predict(X32) + self.assertEqualArray(expected, got['Y'].ravel(), decimal=4) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/mlprodict/npy/xop_convert.py b/mlprodict/npy/xop_convert.py index 8ede2e55c..3f0dc7b7e 100644 --- a/mlprodict/npy/xop_convert.py +++ b/mlprodict/npy/xop_convert.py @@ -4,6 +4,7 @@ .. versionadded:: 0.9 """ +import numpy from .xop import OnnxOperator @@ -69,16 +70,21 @@ def add_to(self, builder): builder.add_initializer(new_name, init) # linking inputs - for name in inputs: - new_name = builder.get_unique_name(name) - mapped_names[name] = new_name + for inp, name in zip(self.model.graph.input, inputs): + new_name = builder.get_unique_name(inp.name) + mapped_names[inp.name] = new_name builder.add_node( 'Identity', builder.get_unique_name('_sub_' + name), [name], [new_name]) # adding nodes for node in self.model.graph.node: - new_inputs = [mapped_names[i] for i in node.input] + new_inputs = [] + for i in node.input: + if i not in mapped_names: + raise RuntimeError( + "Unable to find input %r in %r." % (i, mapped_names)) + new_inputs.append(mapped_names[i]) new_outputs = [] for o in node.output: new_name = builder.get_unique_name(o) @@ -87,10 +93,14 @@ def add_to(self, builder): atts = {} for att in node.attribute: - if att.type == 2: + if att.type == 2: # .i value = att.i atts[att.name] = value continue + if att.type == 6: # .floats + value = list(att.floats) + atts[att.name] = value + continue raise NotImplementedError( "Unable to copy attribute type %r (%r)." % ( att.type, att)) @@ -118,7 +128,7 @@ class OnnxSubEstimator(OnnxSubOnnx): :param inputs: inputs :param op_version: targetted opset :param options: to rewrite the options used to convert the model - :param input_types: the implementation may be wrong in guessing + :param initial_types: the implementation may be wrong in guessing the input types of the model, this parameter can be used to overwrite them, usually a dictionary `{ input_name: numpy array as an example }` @@ -134,36 +144,110 @@ class OnnxSubEstimator(OnnxSubOnnx): def __init__(self, model, *inputs, op_version=None, output_names=None, options=None, - input_types=None, **kwargs): + initial_types=None, **kwargs): if model is None: raise ValueError("Model cannot be None.") + onx = OnnxSubEstimator._to_onnx( + model, inputs, op_version=op_version, options=options, + initial_types=initial_types, **kwargs) OnnxSubOnnx.__init__( - self, *inputs, op_version=op_version, - output_names=output_names, **kwargs) - self.model = model + self, onx, *inputs, output_names=output_names) + self.ml_model = model self.options = options - self.input_types = input_types + self.initial_types = initial_types + self.op_version = op_version def __repr__(self): "usual" atts = {} for att in ['op_version', 'output_names', 'options', - 'input_types']: + 'initial_types']: value = getattr(self, att, None) if value is not None: atts[att] = value atts.update(self.kwargs) - msg = ", ".join("%s=%r" for k, v in atts.items()) + msg = ", ".join("%s=%r" % (k, v) for k, v in atts.items()) if len(atts) > 0: - msg += ", " + msg = ", " + msg return "%s(%r%s)" % ( - self.__class__.__name__, self.model, msg) + self.__class__.__name__, self.ml_model, msg) - def add_to(self, builder): + @staticmethod + def _to_onnx(model, inputs, op_version=None, options=None, + initial_types=None, **kwargs): """ - Adds to graph builder. - - :param builder: instance of @see cl _GraphBuilder, - it must have a method `add_node` + Converts a model into ONNX and inserts it into an ONNX graph. + + :param model: a trained machine learned model + :param inputs: inputs + :param op_version: opset versions or None to use the latest one + :param options: options to change the behaviour of the converter + :param kwargs: additional parameters such as black listed or while listed + operators + :return: ONNX model + + The method currently supports models trained with + :epkg:`scikit-learn`, :epkg:`xgboost`, :epkg`:lightgbm`. """ - raise NotImplementedError() + from sklearn.base import BaseEstimator + + if isinstance(model, BaseEstimator): + return OnnxSubEstimator._to_onnx_sklearn( + model, inputs, op_version=op_version, options=options, + initial_types=initial_types, **kwargs) + raise RuntimeError( + "Unable to convert into ONNX model type %r." % type(model)) + + @staticmethod + def _to_onnx_sklearn(model, inputs, op_version=None, options=None, + initial_types=None, **kwargs): + """ + Converts a :epkg:`scikit-learn` model into ONNX + and inserts it into an ONNX graph. The library relies on + function @see fn to_onnx and library :epkg:`skearn-onnx`. + + :param model: a trained machine learned model + :param inputs: inputs + :param op_version: opset versions or None to use the latest one + :param initial_types: if None, the input types are guessed from the + inputs. The function converts into ONNX the previous + node of the graph and tries to infer the initial_types + with the little informations it has. It may not work. + It is recommended to specify this parameter. + :param options: options to change the behaviour of the converter + :param kwargs: additional parameters such as black listed or while listed + operators + :return: ONNX model + + Default options is `{'zipmap': False}` for a classifier. + """ + from ..onnx_conv.convert import to_onnx + if options is None: + from sklearn.base import ClassifierMixin + if isinstance(model, ClassifierMixin): + options = {'zipmap': False} + if initial_types is None: + # Let's to infer them from previous nodes. + raise NotImplementedError( + "initial_types is None and the method cannot guess the " + "initial_types of the model.") + + if isinstance(initial_types, numpy.ndarray): + if len(inputs) != 1: + raise RuntimeError( + "The model has %s inputs but only %d input are " + "described in 'initial_types'." % ( + len(inputs), 1)) + X = initial_types + initial_types = None + elif len(inputs) != len(initial_types): + raise RuntimeError( + "The model has %s inputs but only %d input are " + "described in 'initial_types'." % ( + len(inputs), len(initial_types))) + else: + X = None + + onx = to_onnx(model, X, initial_types=initial_types, options=options, + rewrite_ops=True, target_opset=op_version, **kwargs) + return onx