From f2e806dc2eb99daa6aeda74f7bf97929eb89b4bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Tue, 15 Feb 2022 02:01:10 +0100 Subject: [PATCH 01/13] exploration --- _doc/sphinxdoc/source/api/index.rst | 1 + _doc/sphinxdoc/source/api/xop.rst | 35 ++ _unittests/ut_npy/test_xop.py | 27 ++ mlprodict/npy/_cache/__init__.py | 14 + mlprodict/npy/xauto.py | 240 +++++++++++++ mlprodict/npy/xop.py | 526 ++++++++++++++++++++++++++++ mlprodict/npy/xop_classes.py | 228 ++++++++++++ mlprodict/npy/xops.py | 228 ++++++++++++ mlprodict/npy/xops_opset.py | 142 ++++++++ 9 files changed, 1441 insertions(+) create mode 100644 _doc/sphinxdoc/source/api/xop.rst create mode 100644 _unittests/ut_npy/test_xop.py create mode 100644 mlprodict/npy/_cache/__init__.py create mode 100644 mlprodict/npy/xauto.py create mode 100644 mlprodict/npy/xop.py create mode 100644 mlprodict/npy/xop_classes.py create mode 100644 mlprodict/npy/xops.py create mode 100644 mlprodict/npy/xops_opset.py diff --git a/_doc/sphinxdoc/source/api/index.rst b/_doc/sphinxdoc/source/api/index.rst index fcd8793db..28d161bc6 100644 --- a/_doc/sphinxdoc/source/api/index.rst +++ b/_doc/sphinxdoc/source/api/index.rst @@ -12,6 +12,7 @@ This is a summary of functions this modules provides. onnx_conv sklapi npy + xop **ONNX runtime** diff --git a/_doc/sphinxdoc/source/api/xop.rst b/_doc/sphinxdoc/source/api/xop.rst new file mode 100644 index 000000000..94fb3cd38 --- /dev/null +++ b/_doc/sphinxdoc/source/api/xop.rst @@ -0,0 +1,35 @@ + +.. _l-xop-onnxpy: + +Create ONNX graphs +================== + +.. contents:: + :local: + +API ++++ + +.. autosignature:: mlprodict.npy.xops.ClassFactory + +.. autosignature:: mlprodict.npy.xops.dynamic_class_creation + +.. autosignature:: mlprodict.npy.xops_classes.Variable + +.. autosignature:: mlprodict.npy.xops_classes.GraphBuilder + +.. autosignature:: mlprodict.npy.xop.OnnxOperator + +.. autosignature:: mlprodict.npy.xop.OnnxOperatorItem + +.. autosignature:: mlprodict.npy.xops_opset.OnnxReduceSumApi11 + +.. autosignature:: mlprodict.npy.xops_opset.OnnxSplitApi11 + +.. autosignature:: mlprodict.npy.xops_opset.OnnxSqueezeApi11 + +.. autosignature:: mlprodict.npy.xops_opset.OnnxUnsqueezeApi11 + +.. autosignature:: mlprodict.npy.xops_opset.OnnxReduceL2_typed + +.. autosignature:: mlprodict.npy.xops_opset.OnnxReshapeApi13 diff --git a/_unittests/ut_npy/test_xop.py b/_unittests/ut_npy/test_xop.py new file mode 100644 index 000000000..62d7d1a0e --- /dev/null +++ b/_unittests/ut_npy/test_xop.py @@ -0,0 +1,27 @@ +# pylint: disable=E0611 +""" +@brief test log(time=3s) +""" +import unittest +import numpy +from pyquickhelper.pycode import ExtTestCase +from mlprodict.npy.xops import OnnxAbs +from mlprodict.onnxrt import OnnxInference + + +class TestXOps(ExtTestCase): + + def test_float32(self): + self.assertEqual(numpy.float32, numpy.dtype('float32')) + + def test_onnx_abs(self): + ov = OnnxAbs('X', output_names=['Y']) + onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=1) + oinf = OnnxInference(onx) + x = numpy.array([-2, 2], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(numpy.abs(x), got['Y']) + + +if __name__ == "__main__": + unittest.main() diff --git a/mlprodict/npy/_cache/__init__.py b/mlprodict/npy/_cache/__init__.py new file mode 100644 index 000000000..42ecaa53b --- /dev/null +++ b/mlprodict/npy/_cache/__init__.py @@ -0,0 +1,14 @@ +""" +@file +@brief Cache documentation for OnnxOps. + +.. versionadded:: 0.9 +""" +import os + + +def cache_folder(): + """ + Returns this folder. + """ + return os.path.abspath(os.path.dirname(__file__)) diff --git a/mlprodict/npy/xauto.py b/mlprodict/npy/xauto.py new file mode 100644 index 000000000..7f8e3f3fb --- /dev/null +++ b/mlprodict/npy/xauto.py @@ -0,0 +1,240 @@ +""" +@file +@brief Automates the generation of the documentation. + +.. versionadded:: 0.9 +""" +import textwrap +import onnx +import onnx.defs +from onnx.defs import OpSchema + + +def _get_doc_template(): + try: + from jinja2 import Template + except ImportError: + class Template: + "Docstring template" + def __init__(self, *args): + pass + + def render(self, **context): + schemas = context['schemas'] + rows = [] + for sch in schemas: + doc = sch.doc or '' + name = sch.name + if name is None: + raise RuntimeError("An operator must have a name.") + rows.extend([name, "=" * len(name), + "", doc, ""]) + return "\n".join(rows) + + return Template(textwrap.dedent(""" + {% for sch in schemas %} + + {{format_name_with_domain(sch)}} + {{'=' * len(format_name_with_domain(sch))}} + + **Version** + + *Onnx name:* `{{sch.name}} <{{build_doc_url(sch)}}{{sch.name}}>`_ + + {% if sch.support_level == OpSchema.SupportType.EXPERIMENTAL %} + No versioning maintained for experimental ops. + {% else %} + This version of the operator has been {% if + sch.deprecated %}deprecated{% else %}available{% endif %} since + version {{sch.since_version}}{% if + sch.domain %} of domain {{sch.domain}}{% endif %}. + {% if len(sch.versions) > 1 %} + Other versions of this operator: + {% for v in sch.version[:-1] %} {{v}} {% endfor %} + {% endif %} + {% endif %} + + **Summary** + + {{process_documentation(sch.doc)}} + + {% if sch.attributes %} + **Attributes** + + {% for _, attr in sorted(sch.attributes.items()) %}* *{{attr.name}}*{% + if attr.required %} (required){% endif %}: {{attr.description}} {% + if attr.default_value %}Default value is + ``{{str(attr.default_value).replace('\\n', ' ').strip()}}``{% + endif %} + {% endfor %} + {% endif %} + + {% if sch.inputs %} + **Inputs** + + {% if sch.min_input != sch.max_input %}Between {{sch.min_input + }} and {{sch.max_input}} inputs. + {% endif %} + {% for ii, inp in enumerate(sch.inputs) %} + * *{{getname(inp, ii)}}*{{format_option(inp)}}{{inp.typeStr}}: {{ + inp.description}}{% endfor %} + {% endif %} + + {% if sch.outputs %} + **Outputs** + + {% if sch.min_output != sch.max_output %}Between {{sch.min_output + }} and {{sch.max_output}} outputs. + {% endif %} + {% for ii, out in enumerate(sch.outputs) %} + * *{{getname(out, ii)}}*{{format_option(out)}}{{out.typeStr}}: {{ + out.description}}{% endfor %} + {% endif %} + + {% if sch.type_constraints %} + **Type Constraints** + + {% for ii, type_constraint in enumerate(sch.type_constraints) + %}* {{getconstraint(type_constraint, ii)}}: {{ + type_constraint.description}} + {% endfor %} + {% endif %} + + {% endfor %} + """)) + + +_template_operator = _get_doc_template() + + +def get_domain_list(): + """ + Returns the list of available domains. + """ + return list(sorted(set(map(lambda s: s.domain, + onnx.defs.get_all_schemas_with_history())))) + + +def get_rst_doc(op_name=None): + """ + Returns a documentation in RST format + for all :class:`OnnxOperator`. + + :param op_name: operator name of None for all + :return: string + + The function relies on module :epkg:`jinja2` or replaces it + with a simple rendering if not present. + """ + if op_name is None: + schemas = onnx.defs.get_all_schemas_with_history() + elif isinstance(op_name, str): + schemas = [schema for schema in onnx.defs.get_all_schemas_with_history( + ) if schema.name == op_name] + if len(schemas) > 1: + raise RuntimeError( + "Multiple operators have the same name '{}'.".format(op_name)) + elif not isinstance(op_name, list): + schemas = [op_name] + if len(schemas) == 0: + raise ValueError( + "Unable to find any operator with name '{}'.".format(op_name)) + + # from onnx.backend.sample.ops import collect_sample_implementations + # from onnx.backend.test.case import collect_snippets + # SNIPPETS = collect_snippets() + # SAMPLE_IMPLEMENTATIONS = collect_sample_implementations() + def format_name_with_domain(sch): + if sch.domain: + return '{} ({})'.format(sch.name, sch.domain) + else: + return sch.name + + def format_option(obj): + opts = [] + if OpSchema.FormalParameterOption.Optional == obj.option: + opts.append('optional') + elif OpSchema.FormalParameterOption.Variadic == obj.option: + opts.append('variadic') + if getattr(obj, 'isHomogeneous', False): + opts.append('heterogeneous') + if opts: + return " (%s)" % ", ".join(opts) + else: + return "" + + def getconstraint(const, ii): + if const.type_param_str: + name = const.type_param_str + else: + name = str(ii) + if const.allowed_type_strs: + name += " " + ", ".join(const.allowed_type_strs) + return name + + def getname(obj, i): + name = obj.name + if len(name) == 0: + return str(i) + else: + return name + + def process_documentation(doc): + if doc is None: + doc = '' + doc = textwrap.dedent(doc) + main_docs_url = "https://github.com/onnx/onnx/blob/master/" + rep = { + '[the doc](IR.md)': '`ONNX <{0}docs/IR.md>`_', + '[the doc](Broadcasting.md)': + '`Broadcasting in ONNX <{0}docs/Broadcasting.md>`_', + '
': '', + '
': '', + '
': '* ', + '
': ' ', + '': '', + '
': '', + '': '``', + '': '``', + '
': '\n', + } + for k, v in rep.items(): + doc = doc.replace(k, v.format(main_docs_url)) + move = 0 + lines = [] + for line in doc.split('\n'): + if line.startswith("```"): + if move > 0: + move -= 4 + lines.append("\n") + else: + lines.append("::\n") + move += 4 + elif move > 0: + lines.append(" " * move + line) + else: + lines.append(line) + return "\n".join(lines) + + def build_doc_url(sch): + doc_url = "https://github.com/onnx/onnx/blob/master/docs/Operators" + if "ml" in sch.domain: + doc_url += "-ml" + doc_url += ".md" + doc_url += "#" + if sch.domain not in (None, '', 'ai.onnx'): + doc_url += sch.domain + "." + return doc_url + + fnwd = format_name_with_domain + tmpl = _template_operator + docs = tmpl.render(schemas=schemas, OpSchema=OpSchema, + len=len, getattr=getattr, sorted=sorted, + format_option=format_option, + getconstraint=getconstraint, + getname=getname, enumerate=enumerate, + format_name_with_domain=fnwd, + process_documentation=process_documentation, + build_doc_url=build_doc_url, + str=str) + return docs diff --git a/mlprodict/npy/xop.py b/mlprodict/npy/xop.py new file mode 100644 index 000000000..0d438622e --- /dev/null +++ b/mlprodict/npy/xop.py @@ -0,0 +1,526 @@ +# pylint: disable=E1101 +""" +@file +@brief Easier API to build onnx graphs. Inspired from :epkg:`skl2onnx`. + +.. versionadded:: 0.9 +""" +from logging import getLogger +import numpy +from scipy.sparse import coo_matrix +from onnx import GraphProto, TensorProto +from onnx.helper import make_graph, make_model # pylint: disable=W0611 +from onnx.numpy_helper import from_array +from ..tools.asv_options_helper import get_opset_number_from_onnx +from .xop_classes import Variable, GraphBuilder + + +logger = getLogger('mlprodict.xop') + + +class OnnxOperatorItem: + """ + Accessor to one of the output returned by a *OnnxOperator*. + + :param onx_op: OnnxOperator + :param index: integer + """ + + def __init__(self, onx_op, index, op_version=None): + if not isinstance(index, int): + raise TypeError("index must be an integer.") + self.onx_op = onx_op + self.index = index + self.op_version = op_version + + def __str__(self): + """ + usual + """ + return "%s[%d]" % (str(self.onx_op), self.index) + + def get_output_name(self, i=0): + """ + Returns the output. + """ + if i != 0: + raise IndexError("Can only return the first item.") + return self.onx_op.get_output_name(self.index) + + def get_output(self, i=0): + """ + Returns the output. + """ + if i != 0: + raise IndexError("Can only return the first item.") + return self.onx_op.get_output(self.index) + + @property + def outputs(self): + """ + Returns the outputs of the node. + """ + if self.onx_op is None: + raise RuntimeError( + "self.onx_op cannot be None, type(self)={}".format( + type(self))) + if self.index is None: + raise RuntimeError( + "self.index cannot be None, type(self)={}".format( + type(self))) + outputs = self.onx_op.outputs + if outputs is None: + raise RuntimeError( + "self.onx_op.outputs cannot be None, " + "type(self)={}, type(self.onx_op)={}, " + "type(self.onx_op.state)={}".format( + type(self), type(self.onx_op), type(self.onx_op.state))) + return outputs[self.index:self.index + 1] + + def get_output_type_inference(self, input_shapes=None): + """ + Returns the inferred shape. + """ + if self.onx_op is None: + raise RuntimeError( + "self.onx_op cannot be None, type(self)={}".format( + type(self))) + if self.index is None: + raise RuntimeError( + "self.index cannot be None, type(self)={}".format( + type(self))) + outputs = self.onx_op.get_output_type_inference(input_shapes) + if outputs is None: + raise RuntimeError( + "self.onx_op.outputs cannot be None, " + "type(self)={}, type(self.onx_op)={}, " + "type(self.onx_op.state)={}".format( + type(self), type(self.onx_op), type(self.onx_op.state))) + return outputs[self.index:self.index + 1] + + +class OnnxOperator: + """ + Ancestor to every *ONNX* operator exposed in + :mod:`mlprodict.npy.xops` and :mod:`mlprodict.npy.xops_ml`. + + :param inputs: list of inputs expected by the operator + :param op_version: to select a specific version of the operator + :param output_names: used defined names for the outputs + :param domain: to overwrite the default domain + :param global_context: operator *If* executes one subgraph + whose nodes may use one existing output in the current + context. If not used in the main graph, these operators + are not linked to the output and cannot be retrieved. + *global_context* is a dictionary mapped the subgraph input + names to these operators. + :param clear_subgraph_inputs: clears subgraphs outputs. + Operator *If* does take subgraphs as attribute, + there are subgraphs with no inputs and + global variable as hidden inputs. + :param kwargs: additional parameters of the operator + + .. versionadd:: 0.9 + """ + + def __init__(self, *inputs, op_version=None, output_names=None, + domain=None, global_context=None, + clear_subgraph_inputs=False, **kwargs): + + if (output_names is None and + self.__class__.__name__.startswith("OnnxScan")): + raise NotImplementedError( + "The class cannot infer the number of variables " + "for node '{}' yet. output_names must be specified" + ".".format(self.__class__.__name__)) + if isinstance(output_names, (str, Variable)): + output_names = [output_names] + if isinstance(output_names[0], str): + output_names[0] = Variable(output_names[0]) + elif isinstance(output_names, list): + if len(output_names) == 0: + raise ValueError( + "output_names cannot be empty (operator %r)." + "" % self.__class__.__name__) + output_names = output_names.copy() + for i in range(len(output_names)): + if isinstance(output_names[i], str): + output_names[i] = Variable(output_names[i]) + elif output_names is not None: + raise TypeError( + "output_names must be a string or a list not %r." + "" % type(output_names)) + + if op_version is None: + if domain == '': + self.op_version = get_latest_tested_opset_version() + else: + self.op_version = None + else: + self.op_version = op_version + self.since_version = self.__class__.since_version + + if (self.op_version is not None and + self.op_version < self.since_version): + schema = self.find_schema(self.op_version) + self.since_version = schema.since_version + self.expected_inputs = schema.expected_inputs.copy() + self.expected_outputs = schema.expected_outputs.copy() + self.input_range = schema.input_range + self.output_range = schema.output_range + else: + self.expected_inputs = ( + None if self.__class__.expected_inputs is None + else self.__class__.expected_inputs.copy()) + self.expected_outputs = ( + None if self.__class__.expected_outputs is None + else self.__class__.expected_outputs.copy()) + self.input_range = self.__class__.input_range + self.output_range = self.__class__.output_range + if self.__class__.__name__ not in { + 'OnnxScan', 'OnnxLoop', 'OnnxIf'}: + # TODO: the minimum opset depends on embedded graph + # by default, it takes the given op_version but the + # optimal value could be lower. + self.op_version = self.since_version + if self.op_version is None: + self.op_version = self.since_version + + if (self.op_version is not None and + self.op_version < self.since_version): + raise RuntimeError( + "Operator '{}': requested version {} < " + "{} schema version.".format( + self.__class__.__name__, + self.op_version, self.since_version)) + + self.state = None + self.domain = domain + self.kwargs = kwargs + self.onnx_prefix_name = None + + # check inputs + if len(inputs) == 0: + if self.input_range[0] == self.input_range[1]: + self.inputs = [OnnxOperator.UnscopedVariable(_[0]) + for _ in self.expected_inputs] + else: + # The number of inputs may vary. + self.inputs = None + else: + self.inputs = [] + for inp in inputs: + if isinstance(inp, str): + self.inputs.append(Variable(inp)) + elif isinstance(inp, (OnnxOperator, Variable, + OnnxOperatorItem)): + self.inputs.append(inp) + elif isinstance(inp, (numpy.ndarray, coo_matrix, TensorProto)): + self.inputs.append( + OnnxOperator.ConstantVariable(inp)) + else: + raise TypeError( + "Unable to interpret the input name for type {} in " + "operator '{}' (value={}).".format( + type(inp), self.__class__.__name__, inp)) + + if self.inputs is not None: + if (len(self.inputs) < self.input_range[0] or + len(self.inputs) > self.input_range[1]): + raise RuntimeError( + "Operator '{}' expects a number of inputs " + "in [{}, {}] not {} (expected opset={}, " + "class opset={})".format( + self.operator_name, *self.input_range, + len(self.inputs), op_version, self.op_version)) + # global context + if global_context is None: + self.global_context = None + else: + if not isinstance(global_context, dict): + raise TypeError( + "global_context must be a dictionary not %r." + "" % type(global_context)) + for k, v in global_context.items(): + if not isinstance(v, (OnnxOperator, OnnxOperatorItem)): + raise TypeError( + "Value %r in must be an OnnxOperator or an " + "OnnxOperatorItem not %r." % (k, type(v))) + self.global_context = global_context + + # check output + self.output_names = output_names + self.output_variables = None + + if self.output_names is not None: + if len(self.output_names) == 0: + raise ValueError( + "output_names can be None but cannot be empty for " + "operator %r." % self) + if self.output_variables is None: + self.output_variables = [None for o in self.output_names] + for i in range(len(self.output_names)): + name = self.output_names[i] + if isinstance(name, Variable): + self.output_variables[i] = name + else: + raise TypeError("output_names must be a list of strings " + "and element %r is %r (%r)" % ( + i, type(name), name)) + if all(map(lambda x: x is None, self.output_variables)): + self.output_variables = None + + if (self.output_names is not None and ( + self.expected_outputs is None or + len(self.output_names) > len(self.expected_outputs))): + if self.expected_outputs is None: + self.expected_outputs = [] + for i in range(len(self.expected_outputs), + len(self.output_names)): + self.expected_outputs.append((self.output_names[i], None)) + + if (self.expected_inputs is None or + len(self.inputs) > len(self.expected_inputs)): + if self.expected_inputs is None: + self.expected_inputs = [] + for i in range(len(self.expected_inputs), + len(self.inputs)): + inp = self.inputs[i] + if isinstance(inp, str): + inp = (inp, None) + elif hasattr(inp, 'add_to'): + # OnnxOperator + existing = set(_[0] for _ in self.expected_inputs) + i = 10 + name = "input%d" % (10 + i) + while name in existing: + i += 1 + name = "input%d" % (10 + i) + inp = (name, None) + self.expected_inputs.append(inp) + + self.output_names_ = None + self._post_process_attributes( + clear_subgraph_inputs=clear_subgraph_inputs) + logger.debug( + '[Ops] +%s-%d (%s) id=%d', + self.__class__.__name__, self.op_version, self.domain, id(self)) + + def _post_process_attributes(self, clear_subgraph_inputs=False): + """ + Walks through attributes and replaces them by ONNX + values. + """ + # Looks into attributes if there is any tuple + # (GraphProto, OnnxOperator). In that case, the function + # replaces the tuple by the graph proto and keeps + # in attributes graph_algebra the OnnxOperator + # which is the source of it. + updates = {} + graph_algebra = {} + for k, v in self.kwargs.items(): + if isinstance(v, tuple) and isinstance(v[0], GraphProto): + updates[k] = v[0] + graph_algebra[k] = v[1] + if len(graph_algebra) > 0: + self.kwargs.update(updates) + self.graph_algebra = graph_algebra + + if clear_subgraph_inputs: + for k, v in self.kwargs.items(): + if isinstance(v, GraphProto): + del v.input[:] + + if self.__class__.__name__ == "OnnxConstantOfShape": + if "value" in self.kwargs: + value = self.kwargs['value'] + if isinstance(value, TensorProto): + return + if isinstance(value, numpy.ndarray): + if value.shape == (1, ): + val = value[0] + elif len(value.shape) == 0: + val = value + else: + raise RuntimeError( + "Unexpected shape %r for value, it must be " + "an array of one element." % value.shape) + self.kwargs['value'] = from_array( + numpy.array([val], dtype=value.dtype)) + return + raise TypeError( + "Unexpected type %r for value. It should be an array " + "of one element." % type(value)) + return + + if self.__class__.__name__ == "OnnxCast": + if "to" in self.kwargs: + value = self.kwargs['to'] + if isinstance(value, int): + return + to = guess_proto_type(_guess_numpy_type(value, None)) + self.kwargs['to'] = to + return + + def find_schema(self, op_version): + """ + Checks if there is an existing schema for a + specific version. + + :param op_version: requested version + :return: schema + """ + if not hasattr(self.__class__, 'past_version'): + raise RuntimeError("Missing attribute 'past_version', there is " + "no other available schema.") + found = None + for v in self.past_version.values(): + if v.since_version > op_version: + continue + if found is None or v.since_version > found.since_version: + found = v + if found is None: + raise RuntimeError( + "Operator '{}': requested version {} < " + "{} schema version.".format( + self.__class__.__name__, + op_version, self.since_version)) + return found + + def __str__(self): + """ + usual + """ + return "{}({} in) -> {}".format( + self.__class__.__name__, + len(self.inputs) if self.inputs is not None else 0, + [str(o) for o in self.output_names] + if self.output_names is not None else "?") + + def set_onnx_name_prefix(self, onnx_prefix_name): + """ + Provides a name to define a prefix in the onnx graph + to avoid to get unreadable node names. The method + does not overwrite an existing name, it propagates + the prefix to inputs and stops the propagation + if the prefix is already defined. + """ + if self.onnx_prefix_name is None: + self.onnx_prefix_name = onnx_prefix_name + for inp in self.inputs: + if hasattr(inp, 'set_onnx_prefix_name'): + inp.set_onnx_name_prefix(onnx_prefix_name) + return self + + @property + def onnx_prefix(self): + if self.onnx_prefix_name is None: + name = self.__class__.__name__ + if name.startswith("Onnx"): + name = name[4:] + return name[:2] + return self.onnx_prefix_name + + def __getitem__(self, index): + """ + Returns an accessor to one of the output + of this node. + """ + return OnnxOperatorItem(self, index, self.op_version) + + def _node_to_graph(self, other_outputs=None): + """ + Builds a graph as a list of nodes to walk through in that order. + """ + outputs = [self] + if other_outputs is not None: + outputs += other_outputs + + # walk through graphs + stack = list(outputs) + inputs = [] + memo = [] + while len(stack) > 0: + memo.extend(stack) + new_stack = [] + for obj in stack: + for inp in obj.inputs: + if isinstance(inp, OnnxOperator): + new_stack.append(inp) + else: + inputs.append(inp) + stack = new_stack + + # eliminate duplicates + done = set() + nodes = [] + for node in memo: + if id(node) in done: + continue + done.add(id(node)) + nodes.append(node) + return nodes, inputs + + def add_to(self, builder): + """ + Adds to graph builder. + """ + inputs = builder.get_input_names(self, self.inputs) + outputs = builder.get_output_names(self, self.output_names) + builder.add_node( + self.operator_name, + builder.get_unique_name('_' + self.operator_name.lower()), + inputs, outputs, domain=self.domain, opset=self.op_version, + **self.kwargs) + + def to_onnx(self, inputs=None, outputs=None, + other_outputs=None, target_opset=None, + verbose=0): + """ + Converts this operator into an ONNX graph. + + :param inputs: specific inputs (as a dictionary) or + default inputs if not specified + :param outputs: specific outputs + :param other_outputs: additional outputs to consider + as graph outputs but not outputs of this particular + node + :param target_opset: dictionary with target opset per domain, + None for the default one + :param verbose: prints information + """ + if isinstance(target_opset, dict): + dom = self.domain or '' + target_opset = target_opset.get(dom, None) + elif isinstance(target_opset, int): + if self.domain not in ('', None): + # The target_opset is for the domain '' + # We ignore it. + target_opset = None + elif target_opset is not None: + raise TypeError( + "target_opset must be a dictionary {domain: " + "target_opset} not %r for operator %r." % ( + target_opset, self.__class__.__name__)) + + if self.domain in ('', None) and target_opset == 1: + raise RuntimeError("target_opset cannot be 1.") + if (self.op_version is not None and target_opset is not None and + self.op_version > target_opset): + raise RuntimeError( + "target_opset={} is lower than the version={} requested " + "for this node '{}'.".format( + target_opset, self.op_version, self.__class__.__name__)) + + # get the graph + nodes, graph_inputs = self._node_to_graph(other_outputs) + if len(nodes) == 0: + raise RuntimeError( # pragma: no cover + "Node list is empty.") + builder = GraphBuilder() + for node in nodes: + node.add_to(builder) + + return builder.to_onnx(inputs=inputs, outputs=outputs, + target_opset=target_opset, + verbose=verbose) diff --git a/mlprodict/npy/xop_classes.py b/mlprodict/npy/xop_classes.py new file mode 100644 index 000000000..9ef2faaf7 --- /dev/null +++ b/mlprodict/npy/xop_classes.py @@ -0,0 +1,228 @@ +""" +@file +@brief Easier API to build onnx graphs. Inspired from :epkg:`skl2onnx`. + +.. versionadded:: 0.9 +""" +import numpy +from onnx.helper import ( + make_node, make_graph, make_model, + make_tensor_value_info) +from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE +from ..tools.asv_options_helper import get_opset_number_from_onnx + + +def _default_OPSET_TO_IR_VERSION(): + return { + 1: 3, 2: 3, 3: 3, 4: 3, 5: 3, 6: 3, + 7: 3, 8: 4, 9: 4, 10: 5, 11: 6, 12: 7, + 13: 7, 14: 7, 15: 8 + } + + +class Variable: + """ + An input to an ONNX graph. + """ + + def __init__(self, name, dtype=None): + self.name = name + self.dtype = dtype + + def __repr__(self): + "usual" + return "%s(%r, %r)" % ( + self.__class__.__name__, self.name, self.dtype) + + +class GraphBuilder: + """ + Graph builder. + """ + + def __init__(self): + from .xop import OnnxOperator, OnnxOperatorItem + self.initializer = [] + self.node = [] + self.input = [] + self.output = [] + self.opsets = {} + self.names = set() + self.input_names = set() + self.output_names = {} + self.cl_onnx_op = OnnxOperator + self.cl_onnx_op_item = OnnxOperatorItem + + def get_unique_name(self, name): + """ + Returns a unique name to name an output. + """ + if not isinstance(name, str): + raise TypeError( # pragma: no cover + "name must be a string not %r." % type(name)) + if name not in self.names: + self.names.add(name) + return name + i = 1 + new_name = "%s_%d" % (name, i) + while new_name in self.names: + i += 1 + new_name = "%s_%d" % (name, i) + self.names.add(new_name) + return new_name + + def get_output_names(self, node, outputs): + """ + Returns a new output name for a node if it exists + or create a new one. + """ + names = [] + for index, name in enumerate(outputs): + key = id(node), index + if key in self.output_names: + name = self.output_names[key] + else: + output = node.output_names[index] + if isinstance(output, str): + n = output + elif isinstance(output, Variable): + n = output.name + else: + raise TypeError( # pragma: no cover + "Unexpected type %r for output %d." % ( + type(output), index)) + name = self.get_unique_name(n) + self.output_names[key] = name + names.append(name) + return names + + def get_input_names(self, node, inputs): + """ + Returns input names for node *node* and inputs *inputs*. + + :param node: node + :param inputs: inputs + :return: name + """ + names = [] + for i in inputs: + if isinstance(i, str): + names.append(i) + self.input_names.add(i) + self.names.add(i) + elif isinstance(i, Variable): + names.append(i.name) + self.names.add(i.name) + self.input_names.add(i.name) + elif isinstance(i, self.cl_onnx_op): + name = self.get_output_name(i, 0) + names.append(name) + self.names.add(name) + elif isinstance(i, self.cl_onnx_op_item): + name = self.get_output_name(i.onnx_op, i.index) + names.append(name) + self.names.add(name) + else: + raise TypeError( # pragma: no cover + "Unexpected type for an input %r." % type(i)) + return names + + def add_node(self, op_type, name, inputs, outputs, domain='', + opset=None, **attributes): + """ + Adds a node to the graph. + + :param op_type: operator type + :param name: node name + :param inputs: inputs name list + :param outputs: outputs name list + :param domain: node domain + :param opset: node opset + """ + if not isinstance(inputs, list): + raise TypeError( # pragma: no cover + "inputs must be a list not %r." % type(inputs)) + if not isinstance(outputs, list): + raise TypeError( # pragma: no cover + "inputs must be a list not %r." % type(outputs)) + if any(map(lambda x: not isinstance(x, str), inputs)): + raise TypeError( # pragma: no cover + "inputs must be all strings not %r." % inputs) + if any(map(lambda x: not isinstance(x, (str, Variable)), outputs)): + raise TypeError( # pragma: no cover + "outputs must be all strings not %r." % outputs) + if opset is not None: + if domain not in self.opsets: + self.opsets[domain] = opset + else: + self.opsets[domain] = max(opset, self.opsets[domain]) + node = make_node(op_type, inputs, outputs, name=name, + domain=domain) + self.node.append(node) + + def _process_io(self, inputs, input_names): + if inputs is None: + return [ + make_tensor_value_info( + 'X', TensorProto.FLOAT, None) # pylint: disable=disable=E1101 + for name in self.input_names] + + if inputs in NP_TYPE_TO_TENSOR_TYPE: + inputs = [inputs] + elif numpy.dtype(inputs) in NP_TYPE_TO_TENSOR_TYPE: + inputs = [inputs] + if len(input_names) != len(inputs): + raise RuntimeError( # pragma: no cover + "Mismatch between %r and %r." % (input_names, inputs)) + if isinstance(input_names, dict): + if len(input_names) == 1: + input_names = list(input_names.values()) + else: + raise NotImplementedError( + "Unexpected %r." % input_names) + res = [] + for inp, name in zip(inputs, input_names): + if inp in NP_TYPE_TO_TENSOR_TYPE: + res.append( + make_tensor_value_info( + name, NP_TYPE_TO_TENSOR_TYPE[inp], None)) + elif numpy.dtype(inp) in NP_TYPE_TO_TENSOR_TYPE: + res.append( + make_tensor_value_info( + name, NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(inp)], None)) + else: + raise RuntimeError( + "Unexpected tuple(%r, %r)." % (inp, name)) + return res + + def to_onnx(self, inputs=None, outputs=None, + target_opset=None, verbose=0): + """ + Converts this operator into an ONNX graph. + + :param inputs: specific inputs (as a dictionary) or + default inputs if not specified + :param outputs: specific outputs + :param target_opset: dictionary with target opset per domain, + None for the default one + :param verbose: prints information + :return: onnx graph + """ + # inputs and outputs + self.input = self._process_io(inputs, self.input_names) + self.output = self._process_io(outputs, self.output_names) + + graph = make_graph( + self.node, 'XOP', self.input, self.output, self.initializer) + onnx_model = make_model(graph) + opv = self.opsets.get('', get_opset_number_from_onnx()) + opset2ir = _default_OPSET_TO_IR_VERSION() + irv = opset2ir.get(opv, max(opset2ir.values())) + onnx_model.ir_version = irv + + del onnx_model.opset_import[:] # pylint: disable=disable=E1101 + for k, v in self.opsets.items(): + op_set = onnx_model.opset_import.add() # pylint: disable=disable=E1101 + op_set.domain = k or '' + op_set.version = v + return onnx_model diff --git a/mlprodict/npy/xops.py b/mlprodict/npy/xops.py new file mode 100644 index 000000000..82a6661e7 --- /dev/null +++ b/mlprodict/npy/xops.py @@ -0,0 +1,228 @@ +""" +@file +@brief Easier API to build onnx graphs. Inspired from :epkg:`skl2onnx`. + +.. versionadded:: 0.9 +""" +import sys +import os +import numpy +from scipy.sparse.coo import coo_matrix +import onnx +from .xauto import get_rst_doc +from ._cache import cache_folder +from .xop_classes import Variable + + +def ClassFactory(class_name, op_name, inputs, outputs, + input_range, output_range, + domain, attr_names, doc, + deprecated, since_version, + past_version): + """ + Dynamically creates a class for a specific operator. + + :param class_name: class name + :param op_name: operator type + :param inputs: expected inputs + :param outputs: expected outputs + :param input_range: input range + :param output_range: output_range + :param domain: domain + :param attr_names: attributes names + :param doc: docstring + :param deprecated: is the operator deprecated + :param since_version: available since version + :param past_version: list of versions + """ + from .xop import OnnxOperator, OnnxOperatorItem + + def __init__(self, *args, **kwargs): + + op_version = kwargs.pop('op_version', None) + if isinstance(op_version, dict): + op_version = op_version.get(domain, None) + + if op_version is None: + if len(args) == 0 and input_range[0] == input_range[1]: + args = [_[0] for _ in self.__class__.expected_inputs] + if not (input_range[0] <= len(args) <= input_range[1]): + raise RuntimeError("Unexpected number of inputs, " + "got {}, expecting {} for operator " + "'{}'.".format( + len(args), len(inputs), op_name)) + + attr_names = self.attr_names + if '_' in self.__class__.__name__: + op_version_class = int(self.__class__.__name__.split('_')[-1]) + if op_version is None: + op_version = op_version_class + try: + op_version = min(op_version, op_version_class) + except TypeError: + raise TypeError( # pylint: disable=W0707 + "Could not compare versions {} ? {} for " + "class '{}' since_version {}. Parameter 'op_version' " + "is probably missing when the class " + "is instantiated.".format( + op_version, op_version_class, class_name, + since_version)) + else: + op_version_class = None + + # By default, the op_version is None. + # None means the latest available. + if op_version is None: + op_version = since_version + + found = None + if op_version is not None: + # attr_names refers to the most recent version of + # this operator. We may need an older one. + for op in range(op_version, 0, -1): + name = '{}_{}'.format(self.__class__.__name__, op) + if name in self.past_version: + found = (name, op) + attr_names = self.past_version[name].attr_names + break + if (op_version_class is not None and found is not None and + found[-1] != op_version_class): + raise RuntimeError( + "op_version={} does not refer to the same opset as the class " + "name ('{}').".format(op_version, self.__class__.__name__)) + for key in kwargs: + if key in {'output_names', 'op_version', 'domain', 'ir_version', + 'global_context', 'clear_subgraph_inputs'}: + continue + if key not in attr_names: + raise TypeError("Argument '%s' not valid for '%s' opset=%s." + % (key, op_name, op_version)) + + if op_version is not None: + kwargs['op_version'] = op_version + # This class can only be created by a user. Let's check + # types are either a variable, an operator or an array. + for i, a in enumerate(args): + if isinstance(a, tuple): + if len(a) != 2: + raise TypeError( + "Input %r is a tuple or class %r, it must have two " + "elements (name, type) not %r." % (i, class_name, a)) + if not isinstance(a[0], str): + raise TypeError( + "Input %r is a tuple or class %r, it must be a tuple " + "(name, type) not %r." % (i, class_name, a)) + continue + if not isinstance(a, ( + Variable, OnnxOperator, numpy.ndarray, str, + OnnxOperatorItem, coo_matrix)): + raise TypeError( + "Unexpected type %r for input %r of operator %r. " + "It must be an instance of Variable (or a string), " + "OnnxOperator, OnnxOperatorItem, numpy.ndarray, " + "coo_matrix)." % ( + type(a), i, class_name)) + OnnxOperator.__init__(self, *args, **kwargs) + + newclass = type(class_name, (OnnxOperator,), + {"__init__": __init__, '__doc__': doc, + 'expected_inputs': inputs, + 'expected_outputs': outputs, + 'operator_name': op_name, + 'input_range': input_range, + 'output_range': output_range, + 'domain': domain, + 'is_deprecated': deprecated, + 'since_version': since_version, + 'past_version': past_version, + 'attr_names': attr_names, + '__module__': __name__}) + return newclass + + +def dynamic_class_creation(cache=False): + """ + Automatically generates classes for each of the operators + module *onnx* defines and described at + `Operators + `_ + and `Operators + `_. + """ + cache_dir = cache_folder() + res = {} + for schema in onnx.defs.get_all_schemas_with_history(): + if schema.support_level == schema.SupportType.EXPERIMENTAL: + # Skips experimental operators. + continue + # Multiple version can coexist. The last one is kept. + if schema.name in res: + if schema.since_version > res[schema.name].since_version: + # We keep the most recent one. + res[schema.name] = schema + else: + res[schema.name] = schema + res[schema.name + '_' + str(schema.since_version)] = schema + cls = {} + + def _c(obj, label, i): + name = '%s%d' % (obj.name or label, i) + tys = obj.typeStr or '' + return (name, tys) + + for name in sorted(res): + schema = res[name] + inputs = [_c(o, 'I', i) for i, o in enumerate(schema.inputs)] + outputs = [_c(o, 'O', i) for i, o in enumerate(schema.outputs)] + args = [p for p in schema.attributes] + + if '_' in name: + class_name = "Onnx" + name + else: + class_name = "Onnx" + schema.name + + filename = os.path.join( + cache_dir, + schema.name + '_' + str(schema.since_version) + ".rst") + if not cache and os.path.exists(filename): + with open(filename, "r", encoding="utf-8") as f: + doc = f.read() + else: + doc = get_rst_doc(schema) + if cache: + with open(filename, 'w', encoding='utf-8') as f: + f.write(doc) + + cl = ClassFactory(class_name, schema.name, inputs, outputs, + [schema.min_input, schema.max_input], + [schema.min_output, schema.max_output], + schema.domain, args, + "**Version**" + doc.split('**Version**')[-1], + getattr(schema, 'deprecated', False), + schema.since_version, {}) + cls[class_name] = cl + + # Retrieves past classes. + for name in cls: # pylint: disable=C0206 + if '_' not in name: + continue + main, version = name.split('_') + last = cls[main] + last.past_version[name] = cls[name] + + return cls + + +def _update_module(): + """ + Dynamically updates the module with operators defined + by *ONNX*. + """ + res = dynamic_class_creation() + this = sys.modules[__name__] + for k, v in res.items(): + setattr(this, k, v) + + +_update_module() diff --git a/mlprodict/npy/xops_opset.py b/mlprodict/npy/xops_opset.py new file mode 100644 index 000000000..75e9a3c72 --- /dev/null +++ b/mlprodict/npy/xops_opset.py @@ -0,0 +1,142 @@ +# pylint: disable=E0602 +""" +@file +@brief Easier API to build onnx graphs. Inspired from :epkg:`skl2onnx`. + +.. versionadded:: 0.9 +""" +import numpy as np + + +def OnnxReduceSumApi11(*x, axes=None, keepdims=1, op_version=None, + output_names=None): + """ + Adds operator ReduceSum with opset>=13 following API from opset 12. + """ + if op_version is None: + raise RuntimeError("op_version must be specified.") + if op_version is None or op_version >= 13: + if axes is None: + return OnnxReduceSum( + *x, keepdims=keepdims, op_version=op_version, + output_names=output_names) + return OnnxReduceSum( + *x, np.array(axes, dtype=np.int64), + keepdims=keepdims, op_version=op_version, + output_names=output_names) + if op_version >= 11: + if axes is None: + return OnnxReduceSum_11( + *x, keepdims=keepdims, + op_version=op_version, output_names=output_names) + return OnnxReduceSum_11( + *x, axes=axes, keepdims=keepdims, + op_version=op_version, output_names=output_names) + if axes is None: + return OnnxReduceSum_1(*x, keepdims=keepdims, + op_version=op_version, + output_names=output_names) + return OnnxReduceSum_1(*x, axes=axes, keepdims=keepdims, + op_version=op_version, output_names=output_names) + + +def OnnxSplitApi11(*x, axis=0, split=None, op_version=None, + output_names=None): + """ + Adds operator Split with opset>=13 following API from opset 11. + """ + if op_version is None: + raise RuntimeError("op_version must be specified.") + if op_version is None or op_version >= 13: + if split is None: + return OnnxSplit( + *x, axis=axis, op_version=op_version, + output_names=output_names) + return OnnxSplit( + *x, np.array(split, dtype=np.int64), axis=axis, + op_version=op_version, output_names=output_names) + if op_version >= 11: + if split is None: + return OnnxSplit_11( + *x, axis=axis, op_version=op_version, + output_names=output_names) + return OnnxSplit_11( + *x, split=split, axis=axis, op_version=op_version, + output_names=output_names) + if split is None: + return OnnxSplit_2( + *x, axis=axis, op_version=op_version, output_names=output_names) + return OnnxSplit_2(*x, split=split, axis=axis, + op_version=op_version, output_names=output_names) + + +def OnnxSqueezeApi11(*x, axes=None, op_version=None, + output_names=None): + """ + Adds operator Squeeze with opset>=13 following API from opset 11. + """ + if op_version is None: + raise RuntimeError("op_version must be specified.") + if op_version is None or op_version >= 13: + return OnnxSqueeze( + *x, np.array(axes, dtype=np.int64), + op_version=op_version, output_names=output_names) + if op_version >= 11: + return OnnxSqueeze_11( + *x, axes=axes, op_version=op_version, + output_names=output_names) + return OnnxSqueeze_1(*x, axes=axes, + op_version=op_version, output_names=output_names) + + +def OnnxUnsqueezeApi11(*x, axes=None, op_version=None, + output_names=None): + """ + Adds operator Unsqueeze with opset>=13 following API from opset 11. + """ + if op_version is None: + raise RuntimeError("op_version must be specified.") + if op_version is None or op_version >= 13: + return OnnxUnsqueeze( + *x, np.array(axes, dtype=np.int64), + op_version=op_version, output_names=output_names) + if op_version >= 11: + return OnnxUnsqueeze_11( + *x, axes=axes, op_version=op_version, + output_names=output_names) + return OnnxUnsqueeze_1(*x, axes=axes, + op_version=op_version, output_names=output_names) + + +def OnnxReduceL2_typed(dtype, x, axes=None, keepdims=1, op_version=None, + output_names=None): + """ + Adds operator ReduceL2 for float or double. + """ + if dtype == np.float32: + return OnnxReduceL2( + x, axes=axes, keepdims=keepdims, + op_version=op_version, output_names=output_names) + x2 = OnnxMul(x, x, op_version=op_version) + red = OnnxReduceSumApi11( + x2, axes=[1], keepdims=1, op_version=op_version) + return OnnxSqrt( + red, op_version=op_version, output_names=output_names) + + +def OnnxReshapeApi13(*x, allowzero=0, op_version=None, + output_names=None): + """ + Adds operator Reshape with opset>=14 following API from opset 13. + """ + if op_version is None: + raise RuntimeError("op_version must be specified.") + if op_version is None or op_version >= 14: + return OnnxReshape( + *x, allowzero=allowzero, + op_version=op_version, output_names=output_names) + if op_version >= 13: + return OnnxReshape_13( + *x, op_version=op_version, output_names=output_names) + return OnnxReshape_5( + *x, op_version=op_version, output_names=output_names) From 67e7e8d2bfa0df7b7973c5e98d8c56ee9be34418 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Tue, 15 Feb 2022 14:18:43 +0100 Subject: [PATCH 02/13] doc --- .gitignore | 1 + _unittests/ut_cli/test_cli_dynamic_doc.py | 26 +++++++++++++++++++++++ _unittests/ut_npy/test_xop.py | 19 ++++++++++++++++- _unittests/ut_npy/test_xop_doc.py | 23 ++++++++++++++++++++ mlprodict/__main__.py | 7 +++--- mlprodict/cli/onnx_code.py | 11 ++++++++++ mlprodict/npy/xop.py | 4 +--- mlprodict/npy/{xauto.py => xop_auto.py} | 1 + mlprodict/npy/xop_classes.py | 6 +++--- mlprodict/npy/xops.py | 8 ++++--- 10 files changed, 93 insertions(+), 13 deletions(-) create mode 100644 _unittests/ut_cli/test_cli_dynamic_doc.py create mode 100644 _unittests/ut_npy/test_xop_doc.py rename mlprodict/npy/{xauto.py => xop_auto.py} (99%) diff --git a/.gitignore b/.gitignore index 1659a0de0..4ae4749c7 100644 --- a/.gitignore +++ b/.gitignore @@ -319,3 +319,4 @@ cache-*.pickle onnxruntime*.json *net*.tar* _unittests/unittests.out +mlprodict/npy/_cache/*.rst diff --git a/_unittests/ut_cli/test_cli_dynamic_doc.py b/_unittests/ut_cli/test_cli_dynamic_doc.py new file mode 100644 index 000000000..82576dc16 --- /dev/null +++ b/_unittests/ut_cli/test_cli_dynamic_doc.py @@ -0,0 +1,26 @@ +""" +@brief test tree node (time=10s) +""" +import unittest +from pyquickhelper.loghelper import BufferedPrint +from pyquickhelper.pycode import ExtTestCase +from mlprodict.__main__ import main + + +class TestCliDynamicDoc(ExtTestCase): + + def test_cli_onnx_code_help(self): + st = BufferedPrint() + main(args=["dynamic_doc", "--help"], fLOG=st.fprint) + res = str(st) + self.assertIn("Generates the documentation", res) + + def test_cli_onnx_code(self): + st = BufferedPrint() + main(args=["dynamic_doc", '--verbose', '1'], fLOG=st.fprint) + res = str(st) + self.assertIn("Abs", res) + + +if __name__ == "__main__": + unittest.main() diff --git a/_unittests/ut_npy/test_xop.py b/_unittests/ut_npy/test_xop.py index 62d7d1a0e..05f4c5bdc 100644 --- a/_unittests/ut_npy/test_xop.py +++ b/_unittests/ut_npy/test_xop.py @@ -5,7 +5,7 @@ import unittest import numpy from pyquickhelper.pycode import ExtTestCase -from mlprodict.npy.xops import OnnxAbs +from mlprodict.npy.xops import OnnxAbs, OnnxAdd from mlprodict.onnxrt import OnnxInference @@ -22,6 +22,23 @@ def test_onnx_abs(self): got = oinf.run({'X': x}) self.assertEqualArray(numpy.abs(x), got['Y']) + def test_onnx_add(self): + ov = OnnxAdd('X', 'X', output_names=['Y']) + onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=1) + oinf = OnnxInference(onx) + x = numpy.array([-2, 2], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(x + x, got['Y']) + + def test_onnx_add_cst(self): + ov = OnnxAdd('X', numpy.array([1], dtype=numpy.float32), + output_names=['Y']) + onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=1) + oinf = OnnxInference(onx) + x = numpy.array([-2, 2], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(x + 1, got['Y']) + if __name__ == "__main__": unittest.main() diff --git a/_unittests/ut_npy/test_xop_doc.py b/_unittests/ut_npy/test_xop_doc.py new file mode 100644 index 000000000..4d0597c41 --- /dev/null +++ b/_unittests/ut_npy/test_xop_doc.py @@ -0,0 +1,23 @@ +""" +@brief test log(time=3s) +""" +import unittest +from pyquickhelper.pycode import ExtTestCase +from mlprodict.npy.xops import dynamic_class_creation +from mlprodict.npy.xop_auto import get_rst_doc + + +class TestXopDoc(ExtTestCase): + + @classmethod + def setUpClass(cls): + cls._algebra = dynamic_class_creation() + ExtTestCase.setUpClass() + + def test_doc_onnx(self): + rst = get_rst_doc() + self.assertIn("**Summary**", rst) + + +if __name__ == "__main__": + unittest.main() diff --git a/mlprodict/__main__.py b/mlprodict/__main__.py index 5ce1c0754..e1cd8196b 100644 --- a/mlprodict/__main__.py +++ b/mlprodict/__main__.py @@ -22,7 +22,7 @@ def main(args, fLOG=print): from .cli.asv2csv import asv2csv from .cli.replay import benchmark_replay from .cli.einsum import einsum_test - from .cli.onnx_code import onnx_code + from .cli.onnx_code import onnx_code, dynamic_doc from .cli.validate import latency except ImportError: # pragma: no cover from mlprodict.cli.validate import validate_runtime @@ -32,7 +32,7 @@ def main(args, fLOG=print): from mlprodict.cli.asv2csv import asv2csv from mlprodict.cli.replay import benchmark_replay from mlprodict.cli.einsum import einsum_test - from mlprodict.cli.onnx_code import onnx_code + from mlprodict.cli.onnx_code import onnx_code, dynamic_doc from mlprodict.cli.validate import latency fcts = dict(validate_runtime=validate_runtime, @@ -44,7 +44,8 @@ def main(args, fLOG=print): benchmark_replay=benchmark_replay, einsum_test=einsum_test, onnx_code=onnx_code, - latency=latency) + latency=latency, + dynamic_doc=dynamic_doc) try: from pyquickhelper.cli import cli_main_helper except ImportError: # pragma: no cover diff --git a/mlprodict/cli/onnx_code.py b/mlprodict/cli/onnx_code.py index 6dcf1ed9c..4cc035499 100644 --- a/mlprodict/cli/onnx_code.py +++ b/mlprodict/cli/onnx_code.py @@ -59,3 +59,14 @@ def onnx_code(filename, format="onnx", output=None, verbose=0, name=None, f.write(code) else: fLOG(code) + + +def dynamic_doc(verbose=0, fLOG=print): + """ + Generates the documentation for ONNX operators. + + :param verbose: displays the list of operator + :param fLOG: logging function + """ + from ..npy.xops import dynamic_class_creation + dynamic_class_creation(cache=True, verbose=verbose, fLOG=fLOG) diff --git a/mlprodict/npy/xop.py b/mlprodict/npy/xop.py index 0d438622e..779a2092b 100644 --- a/mlprodict/npy/xop.py +++ b/mlprodict/npy/xop.py @@ -11,7 +11,6 @@ from onnx import GraphProto, TensorProto from onnx.helper import make_graph, make_model # pylint: disable=W0611 from onnx.numpy_helper import from_array -from ..tools.asv_options_helper import get_opset_number_from_onnx from .xop_classes import Variable, GraphBuilder @@ -216,8 +215,7 @@ def __init__(self, *inputs, op_version=None, output_names=None, OnnxOperatorItem)): self.inputs.append(inp) elif isinstance(inp, (numpy.ndarray, coo_matrix, TensorProto)): - self.inputs.append( - OnnxOperator.ConstantVariable(inp)) + self.inputs.append(inp) else: raise TypeError( "Unable to interpret the input name for type {} in " diff --git a/mlprodict/npy/xauto.py b/mlprodict/npy/xop_auto.py similarity index 99% rename from mlprodict/npy/xauto.py rename to mlprodict/npy/xop_auto.py index 7f8e3f3fb..e5ad81c2a 100644 --- a/mlprodict/npy/xauto.py +++ b/mlprodict/npy/xop_auto.py @@ -16,6 +16,7 @@ def _get_doc_template(): except ImportError: class Template: "Docstring template" + def __init__(self, *args): pass diff --git a/mlprodict/npy/xop_classes.py b/mlprodict/npy/xop_classes.py index 9ef2faaf7..5035eb5ca 100644 --- a/mlprodict/npy/xop_classes.py +++ b/mlprodict/npy/xop_classes.py @@ -164,7 +164,7 @@ def _process_io(self, inputs, input_names): if inputs is None: return [ make_tensor_value_info( - 'X', TensorProto.FLOAT, None) # pylint: disable=disable=E1101 + 'X', TensorProto.FLOAT, None) # pylint: disable=E1101 for name in self.input_names] if inputs in NP_TYPE_TO_TENSOR_TYPE: @@ -220,9 +220,9 @@ def to_onnx(self, inputs=None, outputs=None, irv = opset2ir.get(opv, max(opset2ir.values())) onnx_model.ir_version = irv - del onnx_model.opset_import[:] # pylint: disable=disable=E1101 + del onnx_model.opset_import[:] # pylint: disable=E1101 for k, v in self.opsets.items(): - op_set = onnx_model.opset_import.add() # pylint: disable=disable=E1101 + op_set = onnx_model.opset_import.add() # pylint: disable=E1101 op_set.domain = k or '' op_set.version = v return onnx_model diff --git a/mlprodict/npy/xops.py b/mlprodict/npy/xops.py index 82a6661e7..5a081cf64 100644 --- a/mlprodict/npy/xops.py +++ b/mlprodict/npy/xops.py @@ -9,9 +9,9 @@ import numpy from scipy.sparse.coo import coo_matrix import onnx -from .xauto import get_rst_doc -from ._cache import cache_folder +from .xop_auto import get_rst_doc from .xop_classes import Variable +from ._cache import cache_folder def ClassFactory(class_name, op_name, inputs, outputs, @@ -140,7 +140,7 @@ def __init__(self, *args, **kwargs): return newclass -def dynamic_class_creation(cache=False): +def dynamic_class_creation(cache=False, verbose=0, fLOG=print): """ Automatically generates classes for each of the operators module *onnx* defines and described at @@ -172,6 +172,8 @@ def _c(obj, label, i): return (name, tys) for name in sorted(res): + if verbose > 0 and fLOG is not None: + fLOG(name) schema = res[name] inputs = [_c(o, 'I', i) for i, o in enumerate(schema.inputs)] outputs = [_c(o, 'O', i) for i, o in enumerate(schema.outputs)] From 5ef46ac30e1437b4c0c837d3eefcf3f041a8b6a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Tue, 15 Feb 2022 19:59:49 +0100 Subject: [PATCH 03/13] bug --- _unittests/ut_npy/test_xop.py | 50 +++++++- mlprodict/npy/onnx_numpy_compiler.py | 6 +- mlprodict/npy/onnx_numpy_wrapper.py | 3 +- mlprodict/npy/onnx_sklearn_wrapper.py | 10 +- mlprodict/npy/onnx_variable.py | 3 +- mlprodict/npy/xop.py | 123 +++++++++++++++++--- mlprodict/npy/xop_classes.py | 135 +++++++++++++++------- mlprodict/npy/xops.py | 99 ++++++++++++---- mlprodict/onnx_tools/onnx2py_helper.py | 2 +- mlprodict/onnx_tools/optim/onnx_helper.py | 3 +- 10 files changed, 336 insertions(+), 98 deletions(-) diff --git a/_unittests/ut_npy/test_xop.py b/_unittests/ut_npy/test_xop.py index 05f4c5bdc..d8369a575 100644 --- a/_unittests/ut_npy/test_xop.py +++ b/_unittests/ut_npy/test_xop.py @@ -5,7 +5,8 @@ import unittest import numpy from pyquickhelper.pycode import ExtTestCase -from mlprodict.npy.xops import OnnxAbs, OnnxAdd +from mlprodict.npy.xops import loadop +from mlprodict.npy.xop_classes import GraphBuilder from mlprodict.onnxrt import OnnxInference @@ -14,31 +15,72 @@ class TestXOps(ExtTestCase): def test_float32(self): self.assertEqual(numpy.float32, numpy.dtype('float32')) + def test_impossible(self): + cl = loadop("OnnxAdd") + self.assertEqual(cl.__name__, "OnnxAdd") + cl = loadop("OnnxCast") + self.assertEqual(cl.__name__, "OnnxCast") + cl = loadop("Cast_13") + self.assertEqual(cl.__name__, "OnnxCast_13") + cl = loadop("OnnxCast_13") + self.assertEqual(cl.__name__, "OnnxCast_13") + self.assertRaise(lambda: loadop("OnnxImpossible"), ValueError) + def test_onnx_abs(self): + OnnxAbs = loadop("OnnxAbs") ov = OnnxAbs('X', output_names=['Y']) - onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=1) + onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=0) oinf = OnnxInference(onx) x = numpy.array([-2, 2], dtype=numpy.float32) got = oinf.run({'X': x}) self.assertEqualArray(numpy.abs(x), got['Y']) def test_onnx_add(self): + OnnxAdd = loadop("Add") ov = OnnxAdd('X', 'X', output_names=['Y']) - onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=1) + onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=0) oinf = OnnxInference(onx) x = numpy.array([-2, 2], dtype=numpy.float32) got = oinf.run({'X': x}) self.assertEqualArray(x + x, got['Y']) def test_onnx_add_cst(self): + OnnxAdd = loadop("OnnxAdd") ov = OnnxAdd('X', numpy.array([1], dtype=numpy.float32), output_names=['Y']) - onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=1) + onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=0) oinf = OnnxInference(onx) x = numpy.array([-2, 2], dtype=numpy.float32) got = oinf.run({'X': x}) self.assertEqualArray(x + 1, got['Y']) + def test_number2alpha(self): + sel = [GraphBuilder.number2alpha(i) for i in range(0, 100001)] + sel2 = sel.copy() + sel2.sort() + self.assertEqual(sel, sel2) + + def test_onnx_add_sub_left(self): + OnnxAdd, OnnxSub = loadop("OnnxAdd", "OnnxSub") + ov = OnnxAdd('X', 'X') + ov2 = OnnxSub(ov, 'X', output_names=['Y']) + onx = ov2.to_onnx(numpy.float32, numpy.float32, verbose=0) + oinf = OnnxInference(onx) + x = numpy.array([-2, 2], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(x, got['Y']) + + def test_onnx_add_sub_right(self): + OnnxAdd, OnnxSub = loadop("OnnxAdd", "OnnxSub") + ov = OnnxAdd('X', 'X') + ov2 = OnnxSub('X', ov, output_names=['Y']) + onx = ov2.to_onnx(numpy.float32, numpy.float32, verbose=0) + oinf = OnnxInference(onx) + x = numpy.array([-2, 2], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(-x, got['Y']) + + if __name__ == "__main__": unittest.main() diff --git a/mlprodict/npy/onnx_numpy_compiler.py b/mlprodict/npy/onnx_numpy_compiler.py index 0e588bcc8..060af2df9 100644 --- a/mlprodict/npy/onnx_numpy_compiler.py +++ b/mlprodict/npy/onnx_numpy_compiler.py @@ -7,14 +7,11 @@ import inspect from typing import Any import numpy -from skl2onnx.common.data_types import guess_numpy_type -from skl2onnx import __max_supported_opset__ from ..tools.ort_wrapper import InferenceSession from ..onnx_tools.optim._main_onnx_optim import onnx_optimisations from ..onnxrt import OnnxInference from .onnx_version import FctVersion from .onnx_numpy_annotation import get_args_kwargs -from .onnx_variable import OnnxVar class OnnxNumpyFunction: @@ -127,6 +124,7 @@ def __init__(self, fct, op_version=None, runtime=None, signature=None, "." % (type(version), version)) self.fctsig = fctsig if op_version is None: + from skl2onnx import __max_supported_opset__ op_version = __max_supported_opset__ if hasattr(fct, 'SerializeToString'): self.fct_ = None @@ -339,6 +337,8 @@ def _to_onnx(self, op_version=None, signature=None, version=None): Returns the onnx graph produced by function `fct_`. """ if self.onnx_ is None and self.fct_ is not None: + from .onnx_variable import OnnxVar + inputs, outputs, kwargs, n_optional, n_variables = ( # pylint: disable=W0612 self._parse_annotation( signature=signature, version=version)) diff --git a/mlprodict/npy/onnx_numpy_wrapper.py b/mlprodict/npy/onnx_numpy_wrapper.py index 47d93e9f5..02b0181ab 100644 --- a/mlprodict/npy/onnx_numpy_wrapper.py +++ b/mlprodict/npy/onnx_numpy_wrapper.py @@ -8,7 +8,6 @@ from .onnx_version import FctVersion from .onnx_numpy_annotation import get_args_kwargs from .onnx_numpy_compiler import OnnxNumpyCompiler -from .onnx_variable import OnnxVar class _created_classes: @@ -53,6 +52,7 @@ def __call__(self, *args, **kwargs): """ Calls the compiled function with arguments `args`. """ + from .onnx_variable import OnnxVar try: return self.compiled(*args, **kwargs) except (TypeError, RuntimeError, ValueError) as e: @@ -191,6 +191,7 @@ def __call__(self, *args, **kwargs): tensor in *args* defines the templated version of the function to convert into *ONNX*. """ + from .onnx_variable import OnnxVar if len(self.kwargs) == 0: others = None else: diff --git a/mlprodict/npy/onnx_sklearn_wrapper.py b/mlprodict/npy/onnx_sklearn_wrapper.py index cb362eafc..f0a75e6b3 100644 --- a/mlprodict/npy/onnx_sklearn_wrapper.py +++ b/mlprodict/npy/onnx_sklearn_wrapper.py @@ -9,10 +9,6 @@ from sklearn.base import ( ClassifierMixin, ClusterMixin, RegressorMixin, TransformerMixin) -from skl2onnx import update_registered_converter -from skl2onnx.common.data_types import Int64TensorType -from skl2onnx.algebra.onnx_ops import OnnxIdentity # pylint: disable=E0611 -from .onnx_variable import OnnxVar, TupleOnnxAny from .onnx_numpy_wrapper import _created_classes_inst, wrapper_onnxnumpy_np from .onnx_numpy_annotation import NDArraySameType, NDArrayType @@ -67,6 +63,7 @@ def _common_shape_calculator_int_t(operator): raise RuntimeError( "This function only supports two outputs not %r." % len( operator.outputs)) + from skl2onnx.common.data_types import Int64TensorType op = operator.raw_operator cl = X[0].type.__class__ dim = [X[0].type.shape[0], getattr(op, 'n_outputs_', None)] @@ -107,6 +104,8 @@ def _common_converter_t(scope, operator, container): "This function only supports one output not %r." % len( operator.outputs)) + from skl2onnx.algebra.onnx_ops import OnnxIdentity # pylint: disable=E0611 + from .onnx_variable import OnnxVar xvar = OnnxVar(X[0]) fct_cl = operator.onnx_numpy_fct_ @@ -157,6 +156,8 @@ def _common_converter_int_t(scope, operator, container): "This function only supports two outputs not %r." % len( operator.outputs)) + from skl2onnx.algebra.onnx_ops import OnnxIdentity # pylint: disable=E0611 + from .onnx_variable import OnnxVar, TupleOnnxAny xvar = OnnxVar(X[0]) fct_cl = operator.onnx_numpy_fct_ @@ -281,6 +282,7 @@ def addattr(operator, obj): lambda scope, operator, container: cvtc(scope, addattr(operator, obj), container)) + from skl2onnx import update_registered_converter update_registered_converter( model, alias, convert_fct=local_convert_fct, shape_fct=local_shape_fct, overwrite=overwrite, diff --git a/mlprodict/npy/onnx_variable.py b/mlprodict/npy/onnx_variable.py index 2009c3acf..89464f516 100644 --- a/mlprodict/npy/onnx_variable.py +++ b/mlprodict/npy/onnx_variable.py @@ -6,7 +6,6 @@ """ import numpy from onnx.helper import make_tensor -from skl2onnx.common.data_types import guess_numpy_type from skl2onnx.common._topology import Variable # pylint: disable=E0611,E0001 from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611 OnnxAdd, OnnxAnd, @@ -27,7 +26,7 @@ OnnxTopK, OnnxTranspose, OnnxWhere) from skl2onnx.algebra.onnx_operator import OnnxOperatorItem -from skl2onnx.common.data_types import _guess_numpy_type +from skl2onnx.common.data_types import guess_numpy_type, _guess_numpy_type from ..onnx_tools.onnx2py_helper import guess_proto_dtype diff --git a/mlprodict/npy/xop.py b/mlprodict/npy/xop.py index 779a2092b..ffbbd6bf6 100644 --- a/mlprodict/npy/xop.py +++ b/mlprodict/npy/xop.py @@ -11,6 +11,7 @@ from onnx import GraphProto, TensorProto from onnx.helper import make_graph, make_model # pylint: disable=W0611 from onnx.numpy_helper import from_array +from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE from .xop_classes import Variable, GraphBuilder @@ -426,17 +427,18 @@ def __getitem__(self, index): """ return OnnxOperatorItem(self, index, self.op_version) - def _node_to_graph(self, other_outputs=None): + def _node_to_graph(self, other_outputs=None, inputs=None, outputs=None): """ Builds a graph as a list of nodes to walk through in that order. """ - outputs = [self] + node_outputs = [self] if other_outputs is not None: - outputs += other_outputs + node_outputs += other_outputs # walk through graphs - stack = list(outputs) - inputs = [] + stack = list(node_outputs) + new_inputs = [] + set_inputs = set() memo = [] while len(stack) > 0: memo.extend(stack) @@ -445,26 +447,101 @@ def _node_to_graph(self, other_outputs=None): for inp in obj.inputs: if isinstance(inp, OnnxOperator): new_stack.append(inp) + elif (isinstance(inp, Variable) and + inp.name not in set_inputs): + set_inputs.add(inp.name) + if inputs is None: + new_inputs.append(inp) + elif isinstance(inputs, dict): + if inp in inputs: + new_inputs.append((inp, inputs[inp])) + else: + raise ValueError( # pragma: no cover + "Unable to find input %r in %r." % ( + inp, inputs)) + elif (inputs in NP_TYPE_TO_TENSOR_TYPE or + numpy.dtype(inputs) in NP_TYPE_TO_TENSOR_TYPE): + new_inputs.append((inp, inputs)) + else: + raise RuntimeError( # pragma: no cover + "Unable to handle inputs=%r." % inputs) + elif isinstance(inp, numpy.ndarray): + pass else: - inputs.append(inp) + raise TypeError( + "Unexpected input type %r in node type %r." % ( + type(inp), type(obj))) stack = new_stack + if len(new_inputs) == 0: + raise RuntimeError( + "No detected inputs inputs=%r outputs=%r." % ( + inputs, outputs)) + # eliminate duplicates done = set() nodes = [] - for node in memo: + for node in reversed(memo): if id(node) in done: continue done.add(id(node)) nodes.append(node) - return nodes, inputs + + def _get_type(node, name=None, outputs=None): + if outputs is None: + raise NotImplementedError( + "outputs is None, expected_outputs=%r" % ( + node.expected_outputs, )) + if isinstance(outputs, dict): + if name is None: + raise RuntimeError( + "Unable to get type among %r, name=None." % ( + outputs, )) + if name not in outputs: + raise ValueError( # pragma: no cover + "Unable to find %r in %r." % ( + name, outputs)) + return outputs[name] + if isinstance(outputs, list): + raise NotImplementedError( + "Unexpected type for name=%r, outputs=%r." % ( + name, outputs)) + if (outputs in NP_TYPE_TO_TENSOR_TYPE or + numpy.dtype(outputs) in NP_TYPE_TO_TENSOR_TYPE): + return outputs + raise RuntimeError( # pragma: no cover + "Unable to handle outputs=%r." % outputs) + + + # outputs + new_outputs = [] + for node in node_outputs: + if node.output_names is None: + n = self.output_range[0] + for i in range(n): + to = _get_type(node, outputs=outputs) + new_outputs.append(('out%d' % i, to)) + else: + for o in self.output_names: + to = _get_type(node, o, outputs=outputs) + new_outputs.append((o, to)) + if len(new_outputs) == 0: + raise RuntimeError( + "No detected outputs inputs=%r outputs=%r." % ( + inputs, outputs)) + + return nodes, new_inputs, new_outputs def add_to(self, builder): """ Adds to graph builder. """ inputs = builder.get_input_names(self, self.inputs) - outputs = builder.get_output_names(self, self.output_names) + n_outputs = ( + self.output_range[0] if self.output_names is None + else len(self.output_names)) + outputs = [builder.get_output_name(self, i) + for i in range(n_outputs)] builder.add_node( self.operator_name, builder.get_unique_name('_' + self.operator_name.lower()), @@ -477,16 +554,16 @@ def to_onnx(self, inputs=None, outputs=None, """ Converts this operator into an ONNX graph. - :param inputs: specific inputs (as a dictionary) or - default inputs if not specified - :param outputs: specific outputs - :param other_outputs: additional outputs to consider + :param inputs: information about type + :param outputs: information about types + :param other_outputs: additional nodes to consider as graph outputs but not outputs of this particular node :param target_opset: dictionary with target opset per domain, None for the default one :param verbose: prints information """ + # opsets if isinstance(target_opset, dict): dom = self.domain or '' target_opset = target_opset.get(dom, None) @@ -509,16 +586,32 @@ def to_onnx(self, inputs=None, outputs=None, "target_opset={} is lower than the version={} requested " "for this node '{}'.".format( target_opset, self.op_version, self.__class__.__name__)) + + # inputs, outputs + if isinstance(inputs, list): + raise NotImplementedError( + "Unable to process inputs=%r." % (inputs, )) + if isinstance(outputs, list): + raise NotImplementedError( + "Unable to process outputs=%r." % (outputs, )) + # get the graph - nodes, graph_inputs = self._node_to_graph(other_outputs) + nodes, graph_inputs, graph_outputs = self._node_to_graph( + other_outputs, inputs, outputs) if len(nodes) == 0: raise RuntimeError( # pragma: no cover "Node list is empty.") + if verbose > 1: + for i, n in enumerate(nodes): + print("nodes[%d]=%r" % (i, n)) + for i, n in enumerate(graph_inputs): + print("graph_inputs[%d]=%r" % (i, n)) builder = GraphBuilder() for node in nodes: node.add_to(builder) - return builder.to_onnx(inputs=inputs, outputs=outputs, + return builder.to_onnx(inputs=graph_inputs, + outputs=graph_outputs, target_opset=target_opset, verbose=verbose) diff --git a/mlprodict/npy/xop_classes.py b/mlprodict/npy/xop_classes.py index 5035eb5ca..d85cdc8d8 100644 --- a/mlprodict/npy/xop_classes.py +++ b/mlprodict/npy/xop_classes.py @@ -8,6 +8,7 @@ from onnx.helper import ( make_node, make_graph, make_model, make_tensor_value_info) +from onnx.numpy_helper import from_array from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE from ..tools.asv_options_helper import get_opset_number_from_onnx @@ -50,9 +51,17 @@ def __init__(self): self.names = set() self.input_names = set() self.output_names = {} + self.output_names_rev = {} self.cl_onnx_op = OnnxOperator self.cl_onnx_op_item = OnnxOperatorItem + @staticmethod + def number2alpha(index): + dec = str(int(index)) + if len(dec) == 1: + return dec + return chr(96 + len(dec)) + dec + def get_unique_name(self, name): """ Returns a unique name to name an output. @@ -64,37 +73,45 @@ def get_unique_name(self, name): self.names.add(name) return name i = 1 - new_name = "%s_%d" % (name, i) + new_name = "%s_%s" % (name, self.number2alpha(i)) while new_name in self.names: i += 1 - new_name = "%s_%d" % (name, i) + new_name = "%s_%s" % (name, self.number2alpha(i)) self.names.add(new_name) return new_name - def get_output_names(self, node, outputs): + def get_output_name(self, node, index): """ - Returns a new output name for a node if it exists - or create a new one. + Returns the output name for a node. """ - names = [] - for index, name in enumerate(outputs): - key = id(node), index - if key in self.output_names: - name = self.output_names[key] - else: - output = node.output_names[index] - if isinstance(output, str): - n = output - elif isinstance(output, Variable): - n = output.name - else: - raise TypeError( # pragma: no cover - "Unexpected type %r for output %d." % ( - type(output), index)) - name = self.get_unique_name(n) - self.output_names[key] = name - names.append(name) - return names + key = id(node), index + if key in self.output_names: + name = self.output_names[key] + return name + if node.output_names is None: + prefix = node.onnx_prefix_name if node.onnx_prefix_name else 'out' + output = '%s%d' % (prefix, index) + else: + output = node.output_names[index] + if isinstance(output, Variable): + n = output.name + else: + raise TypeError( # pragma: no cover + "Unexpected type %r for output %d." % ( + type(output), index)) + name = self.get_unique_name(n) + self.output_names[key] = name + self.output_names_rev[name] = key + if node.output_names is not None: + var = node.output_names[index] + if isinstance(var, Variable): + var = var.name + if var != name: + raise RuntimeError( + "Output unique name %r is different from the " + "expected name %r at position %r." % ( + name, node.output_names[index], index)) + return name def get_input_names(self, node, inputs): """ @@ -106,11 +123,7 @@ def get_input_names(self, node, inputs): """ names = [] for i in inputs: - if isinstance(i, str): - names.append(i) - self.input_names.add(i) - self.names.add(i) - elif isinstance(i, Variable): + if isinstance(i, Variable): names.append(i.name) self.names.add(i.name) self.input_names.add(i.name) @@ -122,6 +135,13 @@ def get_input_names(self, node, inputs): name = self.get_output_name(i.onnx_op, i.index) names.append(name) self.names.add(name) + elif isinstance(i, numpy.ndarray): + # Adding an initializer + name = self.get_unique_name('init') + init = from_array(i, name) + self.initializer.append(init) + names.append(name) + self.names.add(name) else: raise TypeError( # pragma: no cover "Unexpected type for an input %r." % type(i)) @@ -160,39 +180,69 @@ def add_node(self, op_type, name, inputs, outputs, domain='', domain=domain) self.node.append(node) - def _process_io(self, inputs, input_names): + def _process_io(self, inputs, input_names, output=False): if inputs is None: return [ make_tensor_value_info( 'X', TensorProto.FLOAT, None) # pylint: disable=E1101 for name in self.input_names] - if inputs in NP_TYPE_TO_TENSOR_TYPE: - inputs = [inputs] - elif numpy.dtype(inputs) in NP_TYPE_TO_TENSOR_TYPE: - inputs = [inputs] + if not isinstance(inputs, list): + if inputs in NP_TYPE_TO_TENSOR_TYPE: + inputs = [inputs] + elif numpy.dtype(inputs) in NP_TYPE_TO_TENSOR_TYPE: + inputs = [inputs] + if output and isinstance(input_names, dict): + keep_names = {} + for inp in inputs: + if isinstance(inp, Variable) and inp.name in input_names: + keep_names[inp.name] = input_names[inp.name] + elif isinstance(inp, tuple) and len(inp) == 2: + var, dt = inp + if var.name in input_names: + keep_names[var.name] = input_names[var.name] + else: + raise TypeError( + "Unexpected type %r in %r." % (inp, inputs)) + input_names = keep_names if len(input_names) != len(inputs): raise RuntimeError( # pragma: no cover - "Mismatch between %r and %r." % (input_names, inputs)) + "Mismatch between %r and %r (output=%r)." % ( + input_names, inputs, output)) if isinstance(input_names, dict): if len(input_names) == 1: input_names = list(input_names.values()) else: raise NotImplementedError( - "Unexpected %r." % input_names) + "Unexpected %r (output=%r)." % (input_names, output)) res = [] for inp, name in zip(inputs, input_names): - if inp in NP_TYPE_TO_TENSOR_TYPE: + if isinstance(inp, tuple): + if len(inp) != 2: + raise RuntimeError( + "Unexpected value %r (output=%r)." % ( + inp, output)) + dname, dtype = inp + if isinstance(dname, Variable): + dname = dname.name + if dname != name: + raise RuntimeError( + "Unexpected name %r != %r (inp=%r, output=%r)." % ( + dname, name, inp, output)) + else: + dtype = inp + if dtype in NP_TYPE_TO_TENSOR_TYPE: res.append( make_tensor_value_info( - name, NP_TYPE_TO_TENSOR_TYPE[inp], None)) - elif numpy.dtype(inp) in NP_TYPE_TO_TENSOR_TYPE: + name, NP_TYPE_TO_TENSOR_TYPE[dtype], None)) + elif numpy.dtype(dtype) in NP_TYPE_TO_TENSOR_TYPE: res.append( make_tensor_value_info( - name, NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(inp)], None)) + name, NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(dtype)], None)) else: raise RuntimeError( - "Unexpected tuple(%r, %r)." % (inp, name)) + "Unexpected tuple(%r, %r) - output=%r." % ( + inp, name, output)) return res def to_onnx(self, inputs=None, outputs=None, @@ -210,7 +260,8 @@ def to_onnx(self, inputs=None, outputs=None, """ # inputs and outputs self.input = self._process_io(inputs, self.input_names) - self.output = self._process_io(outputs, self.output_names) + self.output = self._process_io( + outputs, self.output_names_rev, output=True) graph = make_graph( self.node, 'XOP', self.input, self.output, self.initializer) diff --git a/mlprodict/npy/xops.py b/mlprodict/npy/xops.py index 5a081cf64..6d07cdeb2 100644 --- a/mlprodict/npy/xops.py +++ b/mlprodict/npy/xops.py @@ -140,17 +140,10 @@ def __init__(self, *args, **kwargs): return newclass -def dynamic_class_creation(cache=False, verbose=0, fLOG=print): +def _populate_schemas(): """ - Automatically generates classes for each of the operators - module *onnx* defines and described at - `Operators - `_ - and `Operators - `_. + Populates all schemas. """ - cache_dir = cache_folder() res = {} for schema in onnx.defs.get_all_schemas_with_history(): if schema.support_level == schema.SupportType.EXPERIMENTAL: @@ -164,22 +157,69 @@ def dynamic_class_creation(cache=False, verbose=0, fLOG=print): else: res[schema.name] = schema res[schema.name + '_' + str(schema.since_version)] = schema - cls = {} + return res + +def _dynamic_class_creation(operator_names, cache=False, verbose=0, fLOG=print): + """ + Automatically generates classes for each of the operators + module *onnx* defines and described at + `Operators + `_ + and `Operators + `_. + """ def _c(obj, label, i): name = '%s%d' % (obj.name or label, i) tys = obj.typeStr or '' return (name, tys) - for name in sorted(res): + cache_dir = cache_folder() + + res = _all_schemas + cls = {} + set_names = set() + set_skip = set() + for op_name in operator_names: + set_names.add(op_name) + if '_' in op_name: + n = op_name.split('_')[0] + if n.startswith('Onnx'): + set_skip.add(n) + else: + set_skip.add('Onnx' + n) + set_names.add(n) + + if verbose > 1 and fLOG is not None: + fLOG("[_dynamic_class_creation] set_names=%r" % set_names) + fLOG("[_dynamic_class_creation] set_skip=%r" % set_skip) + + for op_name in set_names: + cl_name = op_name if op_name.startswith('Onnx') else 'Onnx' + op_name + if verbose > 1 and fLOG is not None: + fLOG('[_dynamic_class_creation] cl_name=%r op_name=%r (in=%d)' % ( + cl_name, op_name, 1 if cl_name in _all_classes else 0)) + if cl_name in _all_classes: + if cl_name not in set_skip: + yield _all_classes[cl_name] + continue if verbose > 0 and fLOG is not None: - fLOG(name) - schema = res[name] + fLOG("[_dynamic_class_creation] op_name=%r, cl_name=%r" % ( + op_name, cl_name)) + + name = op_name[4:] if op_name.startswith('Onnx') else op_name + try: + schema = res[name] + except KeyError as e: + raise ValueError( + "Operator %r (or %r) does not exists." % ( + name, op_name)) from e inputs = [_c(o, 'I', i) for i, o in enumerate(schema.inputs)] outputs = [_c(o, 'O', i) for i, o in enumerate(schema.outputs)] args = [p for p in schema.attributes] - if '_' in name: + if '_' in op_name: class_name = "Onnx" + name else: class_name = "Onnx" + schema.name @@ -210,21 +250,30 @@ def _c(obj, label, i): if '_' not in name: continue main, version = name.split('_') - last = cls[main] + if main in cls: + last = cls[main] + else: + last = _all_classes[main] last.past_version[name] = cls[name] - return cls + _all_classes.update(cls) + for v in cls.values(): + if v not in set_skip: + yield v -def _update_module(): +_all_schemas = _populate_schemas() +_all_classes = {} + + +def loadop(*names, cache=False, verbose=0, fLOG=print): """ - Dynamically updates the module with operators defined - by *ONNX*. + Dynamically creates a class for a every operator type in + the given list. """ - res = dynamic_class_creation() - this = sys.modules[__name__] - for k, v in res.items(): - setattr(this, k, v) - + res = tuple(_dynamic_class_creation( + names, cache=cache, verbose=verbose, fLOG=fLOG)) + if len(res) == 1: + return res[0] + return res -_update_module() diff --git a/mlprodict/onnx_tools/onnx2py_helper.py b/mlprodict/onnx_tools/onnx2py_helper.py index 97fe76843..435840c75 100644 --- a/mlprodict/onnx_tools/onnx2py_helper.py +++ b/mlprodict/onnx_tools/onnx2py_helper.py @@ -9,7 +9,6 @@ from scipy.sparse import coo_matrix from onnx import onnx_pb as onnx_proto, TensorProto from onnx.numpy_helper import to_array, from_array as onnx_from_array -from skl2onnx.common.data_types import _guess_numpy_type def to_bytes(val): @@ -608,6 +607,7 @@ def to_skl2onnx_type(name, elem_type, shape): :param shape: expected shape :return: data type """ + from skl2onnx.common.data_types import _guess_numpy_type elem = guess_numpy_type_from_string(elem_type) shape = list(None if d == 0 else d for d in shape) return (name, _guess_numpy_type(elem, shape)) diff --git a/mlprodict/onnx_tools/optim/onnx_helper.py b/mlprodict/onnx_tools/optim/onnx_helper.py index 706899c98..f8f124adb 100644 --- a/mlprodict/onnx_tools/optim/onnx_helper.py +++ b/mlprodict/onnx_tools/optim/onnx_helper.py @@ -5,7 +5,6 @@ from collections import Counter from onnx.helper import make_graph from onnx import ValueInfoProto -from skl2onnx.common._topology import Variable from ._onnx_optimisation_common import _apply_optimisation_on_graph from .onnx_optimisation import onnx_remove_node @@ -146,6 +145,8 @@ def change_input_first_dimension(onnx_model, N=None, debug_info=None): @param debug_info unused @return modified model onnx """ + from skl2onnx.common._topology import Variable + def _make_value_info(variable): value_info = ValueInfoProto() value_info.name = variable.full_name From f62bc61c4300719c127aec5cac3aa2fdbd52ca6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Wed, 16 Feb 2022 01:56:46 +0100 Subject: [PATCH 04/13] updates --- _doc/sphinxdoc/source/api/xop.rst | 4 +- _unittests/ut_npy/test_xop.py | 5 +- _unittests/ut_npy/test_xop_doc.py | 4 +- mlprodict/cli/onnx_code.py | 4 +- mlprodict/npy/onnx_numpy_compiler.py | 2 +- mlprodict/npy/xop.py | 144 ++++++------------ mlprodict/npy/xop_auto.py | 1 + mlprodict/npy/{xops.py => xop_factory.py} | 8 +- .../{xop_classes.py => xop_graph_builder.py} | 131 ++++++++-------- mlprodict/npy/xop_variable.py | 92 +++++++++++ 10 files changed, 215 insertions(+), 180 deletions(-) rename mlprodict/npy/{xops.py => xop_factory.py} (99%) rename mlprodict/npy/{xop_classes.py => xop_graph_builder.py} (72%) create mode 100644 mlprodict/npy/xop_variable.py diff --git a/_doc/sphinxdoc/source/api/xop.rst b/_doc/sphinxdoc/source/api/xop.rst index 94fb3cd38..363817fa0 100644 --- a/_doc/sphinxdoc/source/api/xop.rst +++ b/_doc/sphinxdoc/source/api/xop.rst @@ -14,9 +14,9 @@ API .. autosignature:: mlprodict.npy.xops.dynamic_class_creation -.. autosignature:: mlprodict.npy.xops_classes.Variable +.. autosignature:: mlprodict.npy.xops_variable.Variable -.. autosignature:: mlprodict.npy.xops_classes.GraphBuilder +.. autosignature:: mlprodict.npy.xops_graph_builder.GraphBuilder .. autosignature:: mlprodict.npy.xop.OnnxOperator diff --git a/_unittests/ut_npy/test_xop.py b/_unittests/ut_npy/test_xop.py index d8369a575..27c32d275 100644 --- a/_unittests/ut_npy/test_xop.py +++ b/_unittests/ut_npy/test_xop.py @@ -5,8 +5,8 @@ import unittest import numpy from pyquickhelper.pycode import ExtTestCase -from mlprodict.npy.xops import loadop -from mlprodict.npy.xop_classes import GraphBuilder +from mlprodict.npy.xop_factory import loadop +from mlprodict.npy.xop_graph_builder import GraphBuilder from mlprodict.onnxrt import OnnxInference @@ -81,6 +81,5 @@ def test_onnx_add_sub_right(self): self.assertEqualArray(-x, got['Y']) - if __name__ == "__main__": unittest.main() diff --git a/_unittests/ut_npy/test_xop_doc.py b/_unittests/ut_npy/test_xop_doc.py index 4d0597c41..3bb38995b 100644 --- a/_unittests/ut_npy/test_xop_doc.py +++ b/_unittests/ut_npy/test_xop_doc.py @@ -3,7 +3,7 @@ """ import unittest from pyquickhelper.pycode import ExtTestCase -from mlprodict.npy.xops import dynamic_class_creation +from mlprodict.npy.xops import _dynamic_class_creation from mlprodict.npy.xop_auto import get_rst_doc @@ -11,7 +11,7 @@ class TestXopDoc(ExtTestCase): @classmethod def setUpClass(cls): - cls._algebra = dynamic_class_creation() + cls._algebra = _dynamic_class_creation() ExtTestCase.setUpClass() def test_doc_onnx(self): diff --git a/mlprodict/cli/onnx_code.py b/mlprodict/cli/onnx_code.py index 4cc035499..8bfefea8f 100644 --- a/mlprodict/cli/onnx_code.py +++ b/mlprodict/cli/onnx_code.py @@ -68,5 +68,5 @@ def dynamic_doc(verbose=0, fLOG=print): :param verbose: displays the list of operator :param fLOG: logging function """ - from ..npy.xops import dynamic_class_creation - dynamic_class_creation(cache=True, verbose=verbose, fLOG=fLOG) + from ..npy.xops import _dynamic_class_creation + _dynamic_class_creation(cache=True, verbose=verbose, fLOG=fLOG) diff --git a/mlprodict/npy/onnx_numpy_compiler.py b/mlprodict/npy/onnx_numpy_compiler.py index 060af2df9..00880a44b 100644 --- a/mlprodict/npy/onnx_numpy_compiler.py +++ b/mlprodict/npy/onnx_numpy_compiler.py @@ -352,7 +352,7 @@ def _to_onnx(self, op_version=None, signature=None, version=None): getattr(self.fct_, '__module__', None))) names_in = [oi[0] for oi in inputs] names_out = [oi[0] for oi in outputs] - names_var = [OnnxVar(n, dtype=guess_numpy_type(dt[1])) + names_var = [OnnxVar(n, dtype=numpy_type_prototype(dt[1])) for n, dt in zip(names_in, inputs)] if 'op_version' in self.fct_.__code__.co_varnames: diff --git a/mlprodict/npy/xop.py b/mlprodict/npy/xop.py index ffbbd6bf6..331e4b531 100644 --- a/mlprodict/npy/xop.py +++ b/mlprodict/npy/xop.py @@ -5,30 +5,27 @@ .. versionadded:: 0.9 """ -from logging import getLogger import numpy from scipy.sparse import coo_matrix from onnx import GraphProto, TensorProto from onnx.helper import make_graph, make_model # pylint: disable=W0611 from onnx.numpy_helper import from_array -from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE -from .xop_classes import Variable, GraphBuilder - - -logger = getLogger('mlprodict.xop') +from .xop_variable import Variable, is_numpy_dtype +from .xop_graph_builder import GraphBuilder class OnnxOperatorItem: """ - Accessor to one of the output returned by a *OnnxOperator*. + Accessor to one of the output returned by a @see cl OnnxOperator. - :param onx_op: OnnxOperator + :param onx_op: @see cl OnnxOperator :param index: integer + :param op_version: defines the opset version """ def __init__(self, onx_op, index, op_version=None): if not isinstance(index, int): - raise TypeError("index must be an integer.") + raise TypeError("index must be an integer not %r." % type(index)) self.onx_op = onx_op self.index = index self.op_version = op_version @@ -41,7 +38,7 @@ def __str__(self): def get_output_name(self, i=0): """ - Returns the output. + Returns the output name at position *i*. """ if i != 0: raise IndexError("Can only return the first item.") @@ -55,49 +52,6 @@ def get_output(self, i=0): raise IndexError("Can only return the first item.") return self.onx_op.get_output(self.index) - @property - def outputs(self): - """ - Returns the outputs of the node. - """ - if self.onx_op is None: - raise RuntimeError( - "self.onx_op cannot be None, type(self)={}".format( - type(self))) - if self.index is None: - raise RuntimeError( - "self.index cannot be None, type(self)={}".format( - type(self))) - outputs = self.onx_op.outputs - if outputs is None: - raise RuntimeError( - "self.onx_op.outputs cannot be None, " - "type(self)={}, type(self.onx_op)={}, " - "type(self.onx_op.state)={}".format( - type(self), type(self.onx_op), type(self.onx_op.state))) - return outputs[self.index:self.index + 1] - - def get_output_type_inference(self, input_shapes=None): - """ - Returns the inferred shape. - """ - if self.onx_op is None: - raise RuntimeError( - "self.onx_op cannot be None, type(self)={}".format( - type(self))) - if self.index is None: - raise RuntimeError( - "self.index cannot be None, type(self)={}".format( - type(self))) - outputs = self.onx_op.get_output_type_inference(input_shapes) - if outputs is None: - raise RuntimeError( - "self.onx_op.outputs cannot be None, " - "type(self)={}, type(self.onx_op)={}, " - "type(self.onx_op.state)={}".format( - type(self), type(self.onx_op), type(self.onx_op.state))) - return outputs[self.index:self.index + 1] - class OnnxOperator: """ @@ -114,18 +68,13 @@ class OnnxOperator: are not linked to the output and cannot be retrieved. *global_context* is a dictionary mapped the subgraph input names to these operators. - :param clear_subgraph_inputs: clears subgraphs outputs. - Operator *If* does take subgraphs as attribute, - there are subgraphs with no inputs and - global variable as hidden inputs. :param kwargs: additional parameters of the operator .. versionadd:: 0.9 """ def __init__(self, *inputs, op_version=None, output_names=None, - domain=None, global_context=None, - clear_subgraph_inputs=False, **kwargs): + domain=None, global_context=None, **kwargs): if (output_names is None and self.__class__.__name__.startswith("OnnxScan")): @@ -143,7 +92,7 @@ def __init__(self, *inputs, op_version=None, output_names=None, "output_names cannot be empty (operator %r)." "" % self.__class__.__name__) output_names = output_names.copy() - for i in range(len(output_names)): + for i in range(len(output_names)): # pylint: disable=C0200 if isinstance(output_names[i], str): output_names[i] = Variable(output_names[i]) elif output_names is not None: @@ -179,7 +128,7 @@ def __init__(self, *inputs, op_version=None, output_names=None, self.output_range = self.__class__.output_range if self.__class__.__name__ not in { 'OnnxScan', 'OnnxLoop', 'OnnxIf'}: - # TODO: the minimum opset depends on embedded graph + # The minimum opset depends on embedded graph # by default, it takes the given op_version but the # optimal value could be lower. self.op_version = self.since_version @@ -258,7 +207,7 @@ def __init__(self, *inputs, op_version=None, output_names=None, "operator %r." % self) if self.output_variables is None: self.output_variables = [None for o in self.output_names] - for i in range(len(self.output_names)): + for i in range(len(self.output_names)): # pylint: disable=C0200 name = self.output_names[i] if isinstance(name, Variable): self.output_variables[i] = name @@ -298,17 +247,26 @@ def __init__(self, *inputs, op_version=None, output_names=None, inp = (name, None) self.expected_inputs.append(inp) - self.output_names_ = None - self._post_process_attributes( - clear_subgraph_inputs=clear_subgraph_inputs) - logger.debug( - '[Ops] +%s-%d (%s) id=%d', - self.__class__.__name__, self.op_version, self.domain, id(self)) + self._post_process_attributes() + self._check() + + def _check(self): + input_types = (Variable, OnnxOperator, numpy.ndarray) + for o in self.inputs: + if not isinstance(o, input_types): + raise TypeError( + "Wrong type for inputs %r." % ( + self.inputs, )) + if self.output_names is not None: + for o in self.output_names: + if not isinstance(o, Variable): + raise TypeError( + "Wrong type for output_names %r." % ( + self.output_names, )) - def _post_process_attributes(self, clear_subgraph_inputs=False): + def _post_process_attributes(self): """ - Walks through attributes and replaces them by ONNX - values. + Walks through attributes and replaces them by ONNX values. """ # Looks into attributes if there is any tuple # (GraphProto, OnnxOperator). In that case, the function @@ -321,15 +279,11 @@ def _post_process_attributes(self, clear_subgraph_inputs=False): if isinstance(v, tuple) and isinstance(v[0], GraphProto): updates[k] = v[0] graph_algebra[k] = v[1] + if len(graph_algebra) > 0: self.kwargs.update(updates) self.graph_algebra = graph_algebra - if clear_subgraph_inputs: - for k, v in self.kwargs.items(): - if isinstance(v, GraphProto): - del v.input[:] - if self.__class__.__name__ == "OnnxConstantOfShape": if "value" in self.kwargs: value = self.kwargs['value'] @@ -355,9 +309,7 @@ def _post_process_attributes(self, clear_subgraph_inputs=False): if self.__class__.__name__ == "OnnxCast": if "to" in self.kwargs: value = self.kwargs['to'] - if isinstance(value, int): - return - to = guess_proto_type(_guess_numpy_type(value, None)) + stop self.kwargs['to'] = to return @@ -413,11 +365,12 @@ def set_onnx_name_prefix(self, onnx_prefix_name): @property def onnx_prefix(self): + "Returns a prefix for results coming out from this node." if self.onnx_prefix_name is None: name = self.__class__.__name__ if name.startswith("Onnx"): name = name[4:] - return name[:2] + return 'out_' + name[:3].lower() return self.onnx_prefix_name def __getitem__(self, index): @@ -447,8 +400,9 @@ def _node_to_graph(self, other_outputs=None, inputs=None, outputs=None): for inp in obj.inputs: if isinstance(inp, OnnxOperator): new_stack.append(inp) - elif (isinstance(inp, Variable) and - inp.name not in set_inputs): + elif isinstance(inp, Variable): + if inp.name in set_inputs: + continue set_inputs.add(inp.name) if inputs is None: new_inputs.append(inp) @@ -459,9 +413,8 @@ def _node_to_graph(self, other_outputs=None, inputs=None, outputs=None): raise ValueError( # pragma: no cover "Unable to find input %r in %r." % ( inp, inputs)) - elif (inputs in NP_TYPE_TO_TENSOR_TYPE or - numpy.dtype(inputs) in NP_TYPE_TO_TENSOR_TYPE): - new_inputs.append((inp, inputs)) + elif is_numpy_dtype(inputs): + new_inputs.append(inp.copy_add(inputs)) else: raise RuntimeError( # pragma: no cover "Unable to handle inputs=%r." % inputs) @@ -506,12 +459,10 @@ def _get_type(node, name=None, outputs=None): raise NotImplementedError( "Unexpected type for name=%r, outputs=%r." % ( name, outputs)) - if (outputs in NP_TYPE_TO_TENSOR_TYPE or - numpy.dtype(outputs) in NP_TYPE_TO_TENSOR_TYPE): + if is_numpy_dtype(outputs): return outputs raise RuntimeError( # pragma: no cover "Unable to handle outputs=%r." % outputs) - # outputs new_outputs = [] @@ -520,16 +471,18 @@ def _get_type(node, name=None, outputs=None): n = self.output_range[0] for i in range(n): to = _get_type(node, outputs=outputs) - new_outputs.append(('out%d' % i, to)) + res = ('out%d' % i, to) + new_outputs.append(Variable(res[0], added_dtype=to)) else: for o in self.output_names: to = _get_type(node, o, outputs=outputs) - new_outputs.append((o, to)) + res = (o, to) + new_outputs.append(o.copy_add(to)) if len(new_outputs) == 0: raise RuntimeError( "No detected outputs inputs=%r outputs=%r." % ( inputs, outputs)) - + return nodes, new_inputs, new_outputs def add_to(self, builder): @@ -540,8 +493,7 @@ def add_to(self, builder): n_outputs = ( self.output_range[0] if self.output_names is None else len(self.output_names)) - outputs = [builder.get_output_name(self, i) - for i in range(n_outputs)] + outputs = [builder.get_output_name(self, i) for i in range(n_outputs)] builder.add_node( self.operator_name, builder.get_unique_name('_' + self.operator_name.lower()), @@ -569,8 +521,7 @@ def to_onnx(self, inputs=None, outputs=None, target_opset = target_opset.get(dom, None) elif isinstance(target_opset, int): if self.domain not in ('', None): - # The target_opset is for the domain '' - # We ignore it. + # The target_opset is for the domain '' we ignore it. target_opset = None elif target_opset is not None: raise TypeError( @@ -586,7 +537,7 @@ def to_onnx(self, inputs=None, outputs=None, "target_opset={} is lower than the version={} requested " "for this node '{}'.".format( target_opset, self.op_version, self.__class__.__name__)) - + # inputs, outputs if isinstance(inputs, list): raise NotImplementedError( @@ -595,7 +546,6 @@ def to_onnx(self, inputs=None, outputs=None, raise NotImplementedError( "Unable to process outputs=%r." % (outputs, )) - # get the graph nodes, graph_inputs, graph_outputs = self._node_to_graph( other_outputs, inputs, outputs) diff --git a/mlprodict/npy/xop_auto.py b/mlprodict/npy/xop_auto.py index e5ad81c2a..746e16ebc 100644 --- a/mlprodict/npy/xop_auto.py +++ b/mlprodict/npy/xop_auto.py @@ -21,6 +21,7 @@ def __init__(self, *args): pass def render(self, **context): + "render" schemas = context['schemas'] rows = [] for sch in schemas: diff --git a/mlprodict/npy/xops.py b/mlprodict/npy/xop_factory.py similarity index 99% rename from mlprodict/npy/xops.py rename to mlprodict/npy/xop_factory.py index 6d07cdeb2..ff97a1a8a 100644 --- a/mlprodict/npy/xops.py +++ b/mlprodict/npy/xop_factory.py @@ -4,13 +4,12 @@ .. versionadded:: 0.9 """ -import sys import os import numpy from scipy.sparse.coo import coo_matrix import onnx from .xop_auto import get_rst_doc -from .xop_classes import Variable +from .xop_variable import Variable from ._cache import cache_folder @@ -176,7 +175,7 @@ def _c(obj, label, i): return (name, tys) cache_dir = cache_folder() - + res = _all_schemas cls = {} set_names = set() @@ -249,7 +248,7 @@ def _c(obj, label, i): for name in cls: # pylint: disable=C0206 if '_' not in name: continue - main, version = name.split('_') + main, _ = name.split('_') if main in cls: last = cls[main] else: @@ -276,4 +275,3 @@ def loadop(*names, cache=False, verbose=0, fLOG=print): if len(res) == 1: return res[0] return res - diff --git a/mlprodict/npy/xop_classes.py b/mlprodict/npy/xop_graph_builder.py similarity index 72% rename from mlprodict/npy/xop_classes.py rename to mlprodict/npy/xop_graph_builder.py index d85cdc8d8..3fad687b7 100644 --- a/mlprodict/npy/xop_classes.py +++ b/mlprodict/npy/xop_graph_builder.py @@ -5,15 +5,26 @@ .. versionadded:: 0.9 """ import numpy +from onnx import TensorProto from onnx.helper import ( make_node, make_graph, make_model, make_tensor_value_info) from onnx.numpy_helper import from_array -from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE from ..tools.asv_options_helper import get_opset_number_from_onnx +from .xop_variable import Variable, is_numpy_dtype, numpy_type_prototype def _default_OPSET_TO_IR_VERSION(): + """ + Returns the default mapping between opset and ir_version. + + .. runpython:: + :showcode: + + import pprint + from mlprodict.npy.xop_graph_builder import _default_OPSET_TO_IR_VERSION + pprint.pprint(_default_OPSET_TO_IR_VERSION()) + """ return { 1: 3, 2: 3, 3: 3, 4: 3, 5: 3, 6: 3, 7: 3, 8: 4, 9: 4, 10: 5, 11: 6, 12: 7, @@ -21,21 +32,6 @@ def _default_OPSET_TO_IR_VERSION(): } -class Variable: - """ - An input to an ONNX graph. - """ - - def __init__(self, name, dtype=None): - self.name = name - self.dtype = dtype - - def __repr__(self): - "usual" - return "%s(%r, %r)" % ( - self.__class__.__name__, self.name, self.dtype) - - class GraphBuilder: """ Graph builder. @@ -49,7 +45,7 @@ def __init__(self): self.output = [] self.opsets = {} self.names = set() - self.input_names = set() + self.input_names = {} self.output_names = {} self.output_names_rev = {} self.cl_onnx_op = OnnxOperator @@ -57,6 +53,10 @@ def __init__(self): @staticmethod def number2alpha(index): + """ + Converts a numbers into a string keeping the same + alphabetical order. + """ dec = str(int(index)) if len(dec) == 1: return dec @@ -88,17 +88,20 @@ def get_output_name(self, node, index): if key in self.output_names: name = self.output_names[key] return name + if node.output_names is None: - prefix = node.onnx_prefix_name if node.onnx_prefix_name else 'out' + prefix = node.onnx_prefix output = '%s%d' % (prefix, index) else: output = node.output_names[index] + if isinstance(output, Variable): n = output.name else: raise TypeError( # pragma: no cover "Unexpected type %r for output %d." % ( type(output), index)) + name = self.get_unique_name(n) self.output_names[key] = name self.output_names_rev[name] = key @@ -126,7 +129,7 @@ def get_input_names(self, node, inputs): if isinstance(i, Variable): names.append(i.name) self.names.add(i.name) - self.input_names.add(i.name) + self.input_names[i.name] = i elif isinstance(i, self.cl_onnx_op): name = self.get_output_name(i, 0) names.append(name) @@ -180,7 +183,7 @@ def add_node(self, op_type, name, inputs, outputs, domain='', domain=domain) self.node.append(node) - def _process_io(self, inputs, input_names, output=False): + def _process_io(self, inputs, input_names): if inputs is None: return [ make_tensor_value_info( @@ -188,61 +191,54 @@ def _process_io(self, inputs, input_names, output=False): for name in self.input_names] if not isinstance(inputs, list): - if inputs in NP_TYPE_TO_TENSOR_TYPE: + if is_numpy_dtype(inputs): inputs = [inputs] - elif numpy.dtype(inputs) in NP_TYPE_TO_TENSOR_TYPE: - inputs = [inputs] - if output and isinstance(input_names, dict): - keep_names = {} + + if input_names is None: + # outputs + input_names = [] for inp in inputs: - if isinstance(inp, Variable) and inp.name in input_names: - keep_names[inp.name] = input_names[inp.name] + if isinstance(inp, Variable): + if inp.name in self.output_names_rev: + input_names.append(inp) elif isinstance(inp, tuple) and len(inp) == 2: - var, dt = inp - if var.name in input_names: - keep_names[var.name] = input_names[var.name] + var, dtype = inp + if var.name in self.output_names_rev: + input_names.append(Variable(var.name, dtype)) else: raise TypeError( "Unexpected type %r in %r." % (inp, inputs)) - input_names = keep_names + if len(input_names) == 0: + raise RuntimeError( + "Unable to cross %r and %r." % (input, self.output_names_rev)) + elif not isinstance(input_names, list): + raise RuntimeError( + "Unexpected type for input_names %r." % type(input_names)) + if len(input_names) != len(inputs): raise RuntimeError( # pragma: no cover - "Mismatch between %r and %r (output=%r)." % ( - input_names, inputs, output)) - if isinstance(input_names, dict): - if len(input_names) == 1: - input_names = list(input_names.values()) - else: - raise NotImplementedError( - "Unexpected %r (output=%r)." % (input_names, output)) + "Mismatch between %r and %r." % ( + input_names, inputs)) + res = [] - for inp, name in zip(inputs, input_names): - if isinstance(inp, tuple): - if len(inp) != 2: - raise RuntimeError( - "Unexpected value %r (output=%r)." % ( - inp, output)) - dname, dtype = inp - if isinstance(dname, Variable): - dname = dname.name - if dname != name: - raise RuntimeError( - "Unexpected name %r != %r (inp=%r, output=%r)." % ( - dname, name, inp, output)) - else: - dtype = inp - if dtype in NP_TYPE_TO_TENSOR_TYPE: - res.append( - make_tensor_value_info( - name, NP_TYPE_TO_TENSOR_TYPE[dtype], None)) - elif numpy.dtype(dtype) in NP_TYPE_TO_TENSOR_TYPE: - res.append( - make_tensor_value_info( - name, NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(dtype)], None)) - else: + for inp, var in zip(inputs, input_names): + if isinstance(inp, (str, tuple)): + raise TypeError( + "inp not Variable but %r (%r)." % (type(inp), inp)) + if isinstance(var, (str, tuple)): + raise TypeError( + "var not Variable but %r (%r)." % (type(var), var)) + if isinstance(var, (str, tuple)): + raise TypeError( + "var not Variable but %r (%r)." % (type(var), var)) + # inp: Variable + # var: str + if inp != var: raise RuntimeError( - "Unexpected tuple(%r, %r) - output=%r." % ( - inp, name, output)) + "Unexpected %r != %r." % (inp, var)) + pt = numpy_type_prototype(inp.added_dtype or inp.dtype) + res.append(make_tensor_value_info(inp.name, pt, None)) + return res def to_onnx(self, inputs=None, outputs=None, @@ -259,9 +255,8 @@ def to_onnx(self, inputs=None, outputs=None, :return: onnx graph """ # inputs and outputs - self.input = self._process_io(inputs, self.input_names) - self.output = self._process_io( - outputs, self.output_names_rev, output=True) + self.input = self._process_io(inputs, list(self.input_names.values())) + self.output = self._process_io(outputs, None) graph = make_graph( self.node, 'XOP', self.input, self.output, self.initializer) diff --git a/mlprodict/npy/xop_variable.py b/mlprodict/npy/xop_variable.py new file mode 100644 index 000000000..632050e02 --- /dev/null +++ b/mlprodict/npy/xop_variable.py @@ -0,0 +1,92 @@ +""" +@file +@brief Easier API to build onnx graphs. Inspired from :epkg:`skl2onnx`. + +.. versionadded:: 0.9 +""" +import numpy +from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE +from ..tools.asv_options_helper import get_opset_number_from_onnx + + +def is_numpy_dtype(dtype): + """ + Tells if a dtype is a numpy dtype. + + :param dtype: anything + :return: boolean + """ + if isinstance(dtype, (list, dict)): + return False + if dtype in NP_TYPE_TO_TENSOR_TYPE: + return True + dt = numpy.dtype(dtype) + if dt in NP_TYPE_TO_TENSOR_TYPE: + return True + return False + + +def numpy_type_prototype(dtype): + """ + Converts a numpy dtyp into a TensorProto dtype. + + :param dtype: dtype + :return: proto dtype + """ + if dtype in NP_TYPE_TO_TENSOR_TYPE: + return NP_TYPE_TO_TENSOR_TYPE[dtype] + dt = numpy.dtype(dtype) + if dt in NP_TYPE_TO_TENSOR_TYPE: + return NP_TYPE_TO_TENSOR_TYPE[dt] + raise ValueError( + "Unable to convert dtype %r into ProtoType." % dtype) + + +class Variable: + """ + An input to an ONNX graph. + """ + + def __init__(self, name, dtype=None, added_dtype=None): + self.name = name + self.dtype = dtype + self.added_dtype = added_dtype + + def __repr__(self): + "usual" + return "%s(%r, %r, %r)" % ( + self.__class__.__name__, self.name, self.dtype, self.added_dtype) + + def is_named(self, name): + "Tells the variable is named like that." + if not isinstance(name, str): + raise TypeError( + "name is expected to be a string not %r." % type(name)) + return self.name == name + + def copy_add(self, dtype): + """ + Returns a copy of this variable with a new dtype. + + :param dtype: added type + :return: @see cl Variable + """ + if self.added_dtype is not None: + raise RuntimeError( + "Cannot copy as added_dtype is not None.") + return Variable(self.name, self.dtype, dtype) + + def __eq__(self, other): + """ + Compares every attributes. + """ + if not isinstance(other, Variable): + raise TypeError( + "Unexpected type %r." % type(other)) + if self.name != other.name: + return False + dt1 = self.added_dtype or self.dtype + dt2 = other.added_dtype or other.dtype + if dt1 != dt2: + return False + return True From 1839f56eb42bb7cf6780e83c7c40d15e484a2560 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Wed, 16 Feb 2022 10:46:15 +0100 Subject: [PATCH 05/13] Fixes the API for easy cases --- _unittests/ut_cli/test_cli_onnx_code.py | 2 +- _unittests/ut_npy/test_xop.py | 36 ++++++++++++++++- _unittests/ut_npy/test_xop_doc.py | 4 +- mlprodict/cli/onnx_code.py | 2 +- mlprodict/npy/onnx_numpy_compiler.py | 1 + mlprodict/npy/xop.py | 16 ++++++-- mlprodict/npy/xop_factory.py | 44 ++++++++++++++------- mlprodict/npy/xop_graph_builder.py | 23 ++++++----- mlprodict/npy/xop_variable.py | 51 +++++++++++++++++++------ mlprodict/npy/xops_opset.py | 12 +++--- 10 files changed, 140 insertions(+), 51 deletions(-) diff --git a/_unittests/ut_cli/test_cli_onnx_code.py b/_unittests/ut_cli/test_cli_onnx_code.py index ead90c2f0..861aac2d3 100644 --- a/_unittests/ut_cli/test_cli_onnx_code.py +++ b/_unittests/ut_cli/test_cli_onnx_code.py @@ -1,5 +1,5 @@ """ -@brief test tree node (time=10s) +@brief test tree node (time=15s) """ import os import unittest diff --git a/_unittests/ut_npy/test_xop.py b/_unittests/ut_npy/test_xop.py index 27c32d275..31287a9fb 100644 --- a/_unittests/ut_npy/test_xop.py +++ b/_unittests/ut_npy/test_xop.py @@ -1,6 +1,6 @@ # pylint: disable=E0611 """ -@brief test log(time=3s) +@brief test log(time=5s) """ import unittest import numpy @@ -62,6 +62,8 @@ def test_number2alpha(self): def test_onnx_add_sub_left(self): OnnxAdd, OnnxSub = loadop("OnnxAdd", "OnnxSub") + self.assertEqual(OnnxAdd.operator_name, 'Add') + self.assertEqual(OnnxSub.operator_name, 'Sub') ov = OnnxAdd('X', 'X') ov2 = OnnxSub(ov, 'X', output_names=['Y']) onx = ov2.to_onnx(numpy.float32, numpy.float32, verbose=0) @@ -72,6 +74,8 @@ def test_onnx_add_sub_left(self): def test_onnx_add_sub_right(self): OnnxAdd, OnnxSub = loadop("OnnxAdd", "OnnxSub") + self.assertEqual(OnnxAdd.operator_name, 'Add') + self.assertEqual(OnnxSub.operator_name, 'Sub') ov = OnnxAdd('X', 'X') ov2 = OnnxSub('X', ov, output_names=['Y']) onx = ov2.to_onnx(numpy.float32, numpy.float32, verbose=0) @@ -80,6 +84,36 @@ def test_onnx_add_sub_right(self): got = oinf.run({'X': x}) self.assertEqualArray(-x, got['Y']) + def test_onnx_transpose(self): + OnnxTranspose = loadop("OnnxTranspose") + ov = OnnxTranspose('X', perm=[1, 0], output_names=['Y']) + onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=0) + self.assertIn('perm', str(onx)) + oinf = OnnxInference(onx) + x = numpy.array([[-2, 2]], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(x.T, got['Y']) + + def test_onnx_transpose3(self): + OnnxTranspose = loadop("OnnxTranspose") + ov = OnnxTranspose('X', perm=[1, 0, 2], output_names=['Y']) + onx = ov.to_onnx(numpy.float32, numpy.float32, verbose=0) + self.assertIn('perm', str(onx)) + oinf = OnnxInference(onx) + x = numpy.array([[[-2, 2]]], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(numpy.transpose(x, axes=(1, 0, 2)), got['Y']) + + def test_onnx_cast(self): + OnnxCast = loadop("OnnxCast") + ov = OnnxCast('X', to=numpy.int64, output_names=['Y']) + onx = ov.to_onnx(numpy.float32, numpy.int64, verbose=0) + self.assertIn('to', str(onx)) + oinf = OnnxInference(onx) + x = numpy.array([[-2.1, 2.1]], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(x.astype(numpy.int64), got['Y']) + if __name__ == "__main__": unittest.main() diff --git a/_unittests/ut_npy/test_xop_doc.py b/_unittests/ut_npy/test_xop_doc.py index 3bb38995b..ee29cb77d 100644 --- a/_unittests/ut_npy/test_xop_doc.py +++ b/_unittests/ut_npy/test_xop_doc.py @@ -1,9 +1,9 @@ """ -@brief test log(time=3s) +@brief test log(time=10s) """ import unittest from pyquickhelper.pycode import ExtTestCase -from mlprodict.npy.xops import _dynamic_class_creation +from mlprodict.npy.xop_factory import _dynamic_class_creation from mlprodict.npy.xop_auto import get_rst_doc diff --git a/mlprodict/cli/onnx_code.py b/mlprodict/cli/onnx_code.py index 8bfefea8f..fb63d237c 100644 --- a/mlprodict/cli/onnx_code.py +++ b/mlprodict/cli/onnx_code.py @@ -68,5 +68,5 @@ def dynamic_doc(verbose=0, fLOG=print): :param verbose: displays the list of operator :param fLOG: logging function """ - from ..npy.xops import _dynamic_class_creation + from ..npy.xop_factory import _dynamic_class_creation _dynamic_class_creation(cache=True, verbose=verbose, fLOG=fLOG) diff --git a/mlprodict/npy/onnx_numpy_compiler.py b/mlprodict/npy/onnx_numpy_compiler.py index 00880a44b..ec9933718 100644 --- a/mlprodict/npy/onnx_numpy_compiler.py +++ b/mlprodict/npy/onnx_numpy_compiler.py @@ -338,6 +338,7 @@ def _to_onnx(self, op_version=None, signature=None, version=None): """ if self.onnx_ is None and self.fct_ is not None: from .onnx_variable import OnnxVar + from .xop_variable import numpy_type_prototype inputs, outputs, kwargs, n_optional, n_variables = ( # pylint: disable=W0612 self._parse_annotation( diff --git a/mlprodict/npy/xop.py b/mlprodict/npy/xop.py index 331e4b531..0dd1bbbc4 100644 --- a/mlprodict/npy/xop.py +++ b/mlprodict/npy/xop.py @@ -10,7 +10,7 @@ from onnx import GraphProto, TensorProto from onnx.helper import make_graph, make_model # pylint: disable=W0611 from onnx.numpy_helper import from_array -from .xop_variable import Variable, is_numpy_dtype +from .xop_variable import Variable, is_numpy_dtype, numpy_type_prototype from .xop_graph_builder import GraphBuilder @@ -102,6 +102,7 @@ def __init__(self, *inputs, op_version=None, output_names=None, if op_version is None: if domain == '': + from ..tools.asv_options_helper import get_opset_number_from_onnx self.op_version = get_latest_tested_opset_version() else: self.op_version = None @@ -309,8 +310,14 @@ def _post_process_attributes(self): if self.__class__.__name__ == "OnnxCast": if "to" in self.kwargs: value = self.kwargs['to'] - stop - self.kwargs['to'] = to + if not isinstance(value, int): + try: + to = numpy_type_prototype(value) + except ValueError as e: + raise ValueError( + "Unable to convert argument to in operator cast, " + "type is %r, value is %r." % (type(value), value)) from e + self.kwargs['to'] = to return def find_schema(self, op_version): @@ -488,6 +495,9 @@ def _get_type(node, name=None, outputs=None): def add_to(self, builder): """ Adds to graph builder. + + :param builder: instance of @see cl GraphBuilder, + it must have a method `add_node` """ inputs = builder.get_input_names(self, self.inputs) n_outputs = ( diff --git a/mlprodict/npy/xop_factory.py b/mlprodict/npy/xop_factory.py index ff97a1a8a..241a13a19 100644 --- a/mlprodict/npy/xop_factory.py +++ b/mlprodict/npy/xop_factory.py @@ -159,7 +159,7 @@ def _populate_schemas(): return res -def _dynamic_class_creation(operator_names, cache=False, verbose=0, fLOG=print): +def _dynamic_class_creation(operator_names=None, cache=False, verbose=0, fLOG=print): """ Automatically generates classes for each of the operators module *onnx* defines and described at @@ -168,6 +168,13 @@ def _dynamic_class_creation(operator_names, cache=False, verbose=0, fLOG=print): and `Operators `_. + + :param operator_names: list of operators to request or None for all + :param cache: extract the documentation from onnx package and + saves it on disk it True + :param verbose: display some progress + :param fLOG: logging function + :return: list of requested operators as a tuple """ def _c(obj, label, i): name = '%s%d' % (obj.name or label, i) @@ -175,33 +182,40 @@ def _c(obj, label, i): return (name, tys) cache_dir = cache_folder() + if operator_names is None: + operator_names = list(_all_schemas) res = _all_schemas cls = {} - set_names = set() + set_names = dict() set_skip = set() - for op_name in operator_names: - set_names.add(op_name) + for pos, op_name in enumerate(operator_names): + set_names[op_name] = pos if '_' in op_name: n = op_name.split('_')[0] if n.startswith('Onnx'): set_skip.add(n) else: set_skip.add('Onnx' + n) - set_names.add(n) + if n not in set_names: + set_names[n] = -1 if verbose > 1 and fLOG is not None: fLOG("[_dynamic_class_creation] set_names=%r" % set_names) fLOG("[_dynamic_class_creation] set_skip=%r" % set_skip) - for op_name in set_names: + returned_classes = [] + positions = {} + + for op_name, position in set_names.items(): cl_name = op_name if op_name.startswith('Onnx') else 'Onnx' + op_name if verbose > 1 and fLOG is not None: fLOG('[_dynamic_class_creation] cl_name=%r op_name=%r (in=%d)' % ( cl_name, op_name, 1 if cl_name in _all_classes else 0)) if cl_name in _all_classes: if cl_name not in set_skip: - yield _all_classes[cl_name] + if position >= 0: + returned_classes.append((position, _all_classes[cl_name])) continue if verbose > 0 and fLOG is not None: fLOG("[_dynamic_class_creation] op_name=%r, cl_name=%r" % ( @@ -243,22 +257,26 @@ def _c(obj, label, i): getattr(schema, 'deprecated', False), schema.since_version, {}) cls[class_name] = cl + positions[class_name] = position # Retrieves past classes. for name in cls: # pylint: disable=C0206 if '_' not in name: continue main, _ = name.split('_') - if main in cls: + if main in cls: # pylint: disable=R1715 last = cls[main] else: last = _all_classes[main] last.past_version[name] = cls[name] _all_classes.update(cls) - for v in cls.values(): - if v not in set_skip: - yield v + for cl_name, v in cls.items(): + if v not in set_skip and positions.get(cl_name, -1) >= 0: + returned_classes.append((positions[cl_name], v)) + + returned_classes.sort() + return tuple(e[1] for e in returned_classes) _all_schemas = _populate_schemas() @@ -270,8 +288,8 @@ def loadop(*names, cache=False, verbose=0, fLOG=print): Dynamically creates a class for a every operator type in the given list. """ - res = tuple(_dynamic_class_creation( - names, cache=cache, verbose=verbose, fLOG=fLOG)) + res = _dynamic_class_creation( + names, cache=cache, verbose=verbose, fLOG=fLOG) if len(res) == 1: return res[0] return res diff --git a/mlprodict/npy/xop_graph_builder.py b/mlprodict/npy/xop_graph_builder.py index 3fad687b7..222fb4788 100644 --- a/mlprodict/npy/xop_graph_builder.py +++ b/mlprodict/npy/xop_graph_builder.py @@ -11,7 +11,7 @@ make_tensor_value_info) from onnx.numpy_helper import from_array from ..tools.asv_options_helper import get_opset_number_from_onnx -from .xop_variable import Variable, is_numpy_dtype, numpy_type_prototype +from .xop_variable import Variable, is_numpy_dtype def _default_OPSET_TO_IR_VERSION(): @@ -91,16 +91,15 @@ def get_output_name(self, node, index): if node.output_names is None: prefix = node.onnx_prefix - output = '%s%d' % (prefix, index) + n = '%s%d' % (prefix, index) else: output = node.output_names[index] - - if isinstance(output, Variable): - n = output.name - else: - raise TypeError( # pragma: no cover - "Unexpected type %r for output %d." % ( - type(output), index)) + if isinstance(output, Variable): + n = output.name + else: + raise TypeError( # pragma: no cover + "Unexpected type %r for output %d (output_names=%r)." % ( + type(output), index, node.output_names)) name = self.get_unique_name(n) self.output_names[key] = name @@ -180,7 +179,7 @@ def add_node(self, op_type, name, inputs, outputs, domain='', else: self.opsets[domain] = max(opset, self.opsets[domain]) node = make_node(op_type, inputs, outputs, name=name, - domain=domain) + domain=domain, **attributes) self.node.append(node) def _process_io(self, inputs, input_names): @@ -236,8 +235,8 @@ def _process_io(self, inputs, input_names): if inp != var: raise RuntimeError( "Unexpected %r != %r." % (inp, var)) - pt = numpy_type_prototype(inp.added_dtype or inp.dtype) - res.append(make_tensor_value_info(inp.name, pt, None)) + res.append(make_tensor_value_info( + inp.name, inp.proto_added_type, None)) return res diff --git a/mlprodict/npy/xop_variable.py b/mlprodict/npy/xop_variable.py index 632050e02..aed28b6cd 100644 --- a/mlprodict/npy/xop_variable.py +++ b/mlprodict/npy/xop_variable.py @@ -6,7 +6,6 @@ """ import numpy from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE -from ..tools.asv_options_helper import get_opset_number_from_onnx def is_numpy_dtype(dtype): @@ -47,15 +46,43 @@ class Variable: An input to an ONNX graph. """ - def __init__(self, name, dtype=None, added_dtype=None): - self.name = name - self.dtype = dtype - self.added_dtype = added_dtype + def __init__(self, name, dtype=None, shape=None, added_dtype=None): + self.name_ = name + self.dtype_ = dtype + self.added_dtype_ = added_dtype + self.shape_ = shape + + @property + def name(self): + "Returns the variable name." + return self.name_ + + @property + def proto_type(self): + "Returns the proto type for `self.dtype_`." + if self.dtype_ is None: + return 0 + return numpy_type_prototype(self.dtype_) + + @property + def proto_added_type(self): + "Returns the proto type for `self.added_dtype_` or `self.dtype_`." + dt = self.added_dtype_ or self.dtype_ + if dt is None: + return 0 + return numpy_type_prototype(dt) def __repr__(self): "usual" - return "%s(%r, %r, %r)" % ( - self.__class__.__name__, self.name, self.dtype, self.added_dtype) + kwargs = dict(dtype=self.dtype_, shape=self.shape_, + added_dtype=self.added_dtype_) + kwargs = {k: v for k, v in kwargs.items() if v is not None} + if len(kwargs) > 0: + msg = ", " + ", ".join("%s=%r" % (k, v) for k, v in kwargs.items()) + else: + msg = '' + return "%s(%r%s)" % ( + self.__class__.__name__, self.name_, msg) def is_named(self, name): "Tells the variable is named like that." @@ -71,10 +98,10 @@ def copy_add(self, dtype): :param dtype: added type :return: @see cl Variable """ - if self.added_dtype is not None: + if self.added_dtype_ is not None: raise RuntimeError( "Cannot copy as added_dtype is not None.") - return Variable(self.name, self.dtype, dtype) + return Variable(self.name_, self.dtype_, self.shape_, dtype) def __eq__(self, other): """ @@ -85,8 +112,8 @@ def __eq__(self, other): "Unexpected type %r." % type(other)) if self.name != other.name: return False - dt1 = self.added_dtype or self.dtype - dt2 = other.added_dtype or other.dtype - if dt1 != dt2: + if self.shape_ != other.shape_: + return False + if self.dtype_ != other.dtype_: return False return True diff --git a/mlprodict/npy/xops_opset.py b/mlprodict/npy/xops_opset.py index 75e9a3c72..63c5153e4 100644 --- a/mlprodict/npy/xops_opset.py +++ b/mlprodict/npy/xops_opset.py @@ -5,7 +5,7 @@ .. versionadded:: 0.9 """ -import numpy as np +import numpy def OnnxReduceSumApi11(*x, axes=None, keepdims=1, op_version=None, @@ -21,7 +21,7 @@ def OnnxReduceSumApi11(*x, axes=None, keepdims=1, op_version=None, *x, keepdims=keepdims, op_version=op_version, output_names=output_names) return OnnxReduceSum( - *x, np.array(axes, dtype=np.int64), + *x, numpy.array(axes, dtype=numpy.int64), keepdims=keepdims, op_version=op_version, output_names=output_names) if op_version >= 11: @@ -53,7 +53,7 @@ def OnnxSplitApi11(*x, axis=0, split=None, op_version=None, *x, axis=axis, op_version=op_version, output_names=output_names) return OnnxSplit( - *x, np.array(split, dtype=np.int64), axis=axis, + *x, numpy.array(split, dtype=numpy.int64), axis=axis, op_version=op_version, output_names=output_names) if op_version >= 11: if split is None: @@ -79,7 +79,7 @@ def OnnxSqueezeApi11(*x, axes=None, op_version=None, raise RuntimeError("op_version must be specified.") if op_version is None or op_version >= 13: return OnnxSqueeze( - *x, np.array(axes, dtype=np.int64), + *x, numpy.array(axes, dtype=numpy.int64), op_version=op_version, output_names=output_names) if op_version >= 11: return OnnxSqueeze_11( @@ -98,7 +98,7 @@ def OnnxUnsqueezeApi11(*x, axes=None, op_version=None, raise RuntimeError("op_version must be specified.") if op_version is None or op_version >= 13: return OnnxUnsqueeze( - *x, np.array(axes, dtype=np.int64), + *x, numpy.array(axes, dtype=numpy.int64), op_version=op_version, output_names=output_names) if op_version >= 11: return OnnxUnsqueeze_11( @@ -113,7 +113,7 @@ def OnnxReduceL2_typed(dtype, x, axes=None, keepdims=1, op_version=None, """ Adds operator ReduceL2 for float or double. """ - if dtype == np.float32: + if dtype == numpy.float32: return OnnxReduceL2( x, axes=axes, keepdims=keepdims, op_version=op_version, output_names=output_names) From 28fdb13a8d4b8078c7b21a5eee8147aa58911d53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Wed, 16 Feb 2022 11:22:52 +0100 Subject: [PATCH 06/13] Update onnx_numpy_compiler.py --- mlprodict/npy/onnx_numpy_compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlprodict/npy/onnx_numpy_compiler.py b/mlprodict/npy/onnx_numpy_compiler.py index ec9933718..aba9da898 100644 --- a/mlprodict/npy/onnx_numpy_compiler.py +++ b/mlprodict/npy/onnx_numpy_compiler.py @@ -337,8 +337,8 @@ def _to_onnx(self, op_version=None, signature=None, version=None): Returns the onnx graph produced by function `fct_`. """ if self.onnx_ is None and self.fct_ is not None: + from skl2onnx.common.data_types import guess_numpy_type from .onnx_variable import OnnxVar - from .xop_variable import numpy_type_prototype inputs, outputs, kwargs, n_optional, n_variables = ( # pylint: disable=W0612 self._parse_annotation( @@ -353,7 +353,7 @@ def _to_onnx(self, op_version=None, signature=None, version=None): getattr(self.fct_, '__module__', None))) names_in = [oi[0] for oi in inputs] names_out = [oi[0] for oi in outputs] - names_var = [OnnxVar(n, dtype=numpy_type_prototype(dt[1])) + names_var = [OnnxVar(n, dtype=guess_numpy_type(dt[1])) for n, dt in zip(names_in, inputs)] if 'op_version' in self.fct_.__code__.co_varnames: From da1f259e2c5263a588f3d785804648ccb60059d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Wed, 16 Feb 2022 12:19:34 +0100 Subject: [PATCH 07/13] documentation, delayed import --- .../source/_exts/generate_onnx_ops.py | 77 +++++++++++++++++++ _doc/sphinxdoc/source/api/xop.rst | 13 ++++ _doc/sphinxdoc/source/api/xop_supported.rst | 5 ++ _doc/sphinxdoc/source/conf.py | 3 + _unittests/ut_cli/test_cli_dynamic_doc.py | 4 +- _unittests/ut_npy/test_xop_doc.py | 6 +- mlprodict/npy/xop.py | 2 +- mlprodict/npy/xop_auto_import_.py | 27 +++++++ mlprodict/npy/xop_factory.py | 2 +- mlprodict/npy/xop_graph_builder.py | 3 +- 10 files changed, 136 insertions(+), 6 deletions(-) create mode 100644 _doc/sphinxdoc/source/_exts/generate_onnx_ops.py create mode 100644 _doc/sphinxdoc/source/api/xop_supported.rst create mode 100644 mlprodict/npy/xop_auto_import_.py diff --git a/_doc/sphinxdoc/source/_exts/generate_onnx_ops.py b/_doc/sphinxdoc/source/_exts/generate_onnx_ops.py new file mode 100644 index 000000000..8877636c7 --- /dev/null +++ b/_doc/sphinxdoc/source/_exts/generate_onnx_ops.py @@ -0,0 +1,77 @@ +""" +Extension for sphinx to display the onnx nodes. +""" +from docutils import nodes +from docutils.parsers.rst import Directive +from docutils.statemachine import StringList +import sphinx +from sphinx.util.nodes import nested_parse_with_titles +from tabulate import tabulate +from mlprodict.npy.xop_factory import _dynamic_class_creation + + +class SupportedOnnxOpsDirective(Directive): + """ + Automatically displays the list of supported ONNX models + *skl2onnx* can use to build converters. + """ + required_arguments = False + optional_arguments = 0 + final_argument_whitespace = True + option_spec = {} + has_content = False + + def run(self): + cls = _dynamic_class_creation() + cls_name = [(c.__name__, c) for c in cls] + rows = [] + sorted_cls_name = list(sorted(cls_name)) + main = nodes.container() + + def make_ref(cl): + return ":ref:`l-xop-onnx-{}`".format(cl.__name__) + + table = [] + cut = len(sorted_cls_name) // 3 + \ + (1 if len(sorted_cls_name) % 3 else 0) + for i in range(cut): + row = [] + row.append(make_ref(sorted_cls_name[i][1])) + if i + cut < len(sorted_cls_name): + row.append(make_ref(sorted_cls_name[i + cut][1])) + if i + cut * 2 < len(sorted_cls_name): + row.append(make_ref(sorted_cls_name[i + cut * 2][1])) + else: + row.append('') + else: + row.append('') + row.append('') + table.append(row) + + rst = tabulate(table, tablefmt="rst") + rows = rst.split("\n") + + node = nodes.container() + st = StringList(rows) + nested_parse_with_titles(self.state, st, node) + main += node + + rows.append('') + for name, cl in sorted_cls_name: + rows = [] + rows.append('.. _l-xop-onnx-{}:'.format(cl.__name__)) + rows.append('') + rows.append(cl.__name__) + rows.append('=' * len(cl.__name__)) + rows.append('') + rows.append( + ".. autoclass:: mlprodict.npy.xop.xop_auto_import_.{}".format(name)) + st = StringList(rows) + node = nodes.container() + nested_parse_with_titles(self.state, st, node) + main += node + + +def setup(app): + app.add_directive('supported-onnx-ops', SupportedOnnxOpsDirective) + return {'version': sphinx.__display_version__, 'parallel_read_safe': True} diff --git a/_doc/sphinxdoc/source/api/xop.rst b/_doc/sphinxdoc/source/api/xop.rst index 363817fa0..819351f22 100644 --- a/_doc/sphinxdoc/source/api/xop.rst +++ b/_doc/sphinxdoc/source/api/xop.rst @@ -7,6 +7,12 @@ Create ONNX graphs .. contents:: :local: +Example ++++++++ + +Converters +++++++++++ + API +++ @@ -33,3 +39,10 @@ API .. autosignature:: mlprodict.npy.xops_opset.OnnxReduceL2_typed .. autosignature:: mlprodict.npy.xops_opset.OnnxReshapeApi13 + +Available ONNX operators +++++++++++++++++++++++++ + +.. toctree:: + + xop_supported diff --git a/_doc/sphinxdoc/source/api/xop_supported.rst b/_doc/sphinxdoc/source/api/xop_supported.rst new file mode 100644 index 000000000..abb6e5fc6 --- /dev/null +++ b/_doc/sphinxdoc/source/api/xop_supported.rst @@ -0,0 +1,5 @@ + +Supported ONNX operators +======================== + +.. supported-onnx-ops:: diff --git a/_doc/sphinxdoc/source/conf.py b/_doc/sphinxdoc/source/conf.py index 1b79b764e..2ac86d231 100644 --- a/_doc/sphinxdoc/source/conf.py +++ b/_doc/sphinxdoc/source/conf.py @@ -20,11 +20,13 @@ try: import generate_visual_graphs import generate_automated_pages + import generate_onnx_ops except ImportError: # pragma: no cover this = os.path.dirname(__file__) sys.path.append(os.path.join(this, '_exts')) import generate_visual_graphs import generate_automated_pages + import generate_onnx_ops sys.path.insert(0, os.path.abspath(os.path.join(os.path.split(__file__)[0]))) @@ -44,6 +46,7 @@ 'sphinxcontrib.blockdiag', 'generate_automated_pages', 'generate_visual_graphs', + 'generate_onnx_ops', ]) html_css_files = ['my-styles.css'] diff --git a/_unittests/ut_cli/test_cli_dynamic_doc.py b/_unittests/ut_cli/test_cli_dynamic_doc.py index 82576dc16..1253b24ec 100644 --- a/_unittests/ut_cli/test_cli_dynamic_doc.py +++ b/_unittests/ut_cli/test_cli_dynamic_doc.py @@ -1,5 +1,5 @@ """ -@brief test tree node (time=10s) +@brief test tree node (time=23s) """ import unittest from pyquickhelper.loghelper import BufferedPrint @@ -13,7 +13,7 @@ def test_cli_onnx_code_help(self): st = BufferedPrint() main(args=["dynamic_doc", "--help"], fLOG=st.fprint) res = str(st) - self.assertIn("Generates the documentation", res) + self.assertIn("Generates", res) def test_cli_onnx_code(self): st = BufferedPrint() diff --git a/_unittests/ut_npy/test_xop_doc.py b/_unittests/ut_npy/test_xop_doc.py index ee29cb77d..09035f09d 100644 --- a/_unittests/ut_npy/test_xop_doc.py +++ b/_unittests/ut_npy/test_xop_doc.py @@ -18,6 +18,10 @@ def test_doc_onnx(self): rst = get_rst_doc() self.assertIn("**Summary**", rst) + def test_auto_import(self): + from mlprodict.npy.xop_auto_import_ import OnnxAdd # pylint: disable=E0611 + self.assertEqual(OnnxAdd.__name__, 'OnnxAdd') + if __name__ == "__main__": - unittest.main() + unittest.main(verbosity=2) diff --git a/mlprodict/npy/xop.py b/mlprodict/npy/xop.py index 0dd1bbbc4..9475d9a9e 100644 --- a/mlprodict/npy/xop.py +++ b/mlprodict/npy/xop.py @@ -103,7 +103,7 @@ def __init__(self, *inputs, op_version=None, output_names=None, if op_version is None: if domain == '': from ..tools.asv_options_helper import get_opset_number_from_onnx - self.op_version = get_latest_tested_opset_version() + self.op_version = get_opset_number_from_onnx() else: self.op_version = None else: diff --git a/mlprodict/npy/xop_auto_import_.py b/mlprodict/npy/xop_auto_import_.py new file mode 100644 index 000000000..3db5fc730 --- /dev/null +++ b/mlprodict/npy/xop_auto_import_.py @@ -0,0 +1,27 @@ +""" +@file +@brief Importing this file takes time. It should be avoided. + +.. versionadded:: 0.9 +""" +import sys +from .xop_factory import _dynamic_class_creation + + +def _update_module(): + """ + Dynamically updates the module with operators defined by *ONNX*. + """ + res = _dynamic_class_creation() + this = sys.modules[__name__] + unique = set() + for cl in res: + setattr(this, cl.__name__, cl) + name = cl.__name__.split('_')[0] + unique.add(name) + res = _dynamic_class_creation(list(unique)) + for cl in res: + setattr(this, cl.__name__, cl) + + +_update_module() diff --git a/mlprodict/npy/xop_factory.py b/mlprodict/npy/xop_factory.py index 241a13a19..ef3d8d860 100644 --- a/mlprodict/npy/xop_factory.py +++ b/mlprodict/npy/xop_factory.py @@ -209,7 +209,7 @@ def _c(obj, label, i): for op_name, position in set_names.items(): cl_name = op_name if op_name.startswith('Onnx') else 'Onnx' + op_name - if verbose > 1 and fLOG is not None: + if verbose > 3 and fLOG is not None: fLOG('[_dynamic_class_creation] cl_name=%r op_name=%r (in=%d)' % ( cl_name, op_name, 1 if cl_name in _all_classes else 0)) if cl_name in _all_classes: diff --git a/mlprodict/npy/xop_graph_builder.py b/mlprodict/npy/xop_graph_builder.py index 222fb4788..ef85e461e 100644 --- a/mlprodict/npy/xop_graph_builder.py +++ b/mlprodict/npy/xop_graph_builder.py @@ -10,7 +10,6 @@ make_node, make_graph, make_model, make_tensor_value_info) from onnx.numpy_helper import from_array -from ..tools.asv_options_helper import get_opset_number_from_onnx from .xop_variable import Variable, is_numpy_dtype @@ -254,6 +253,8 @@ def to_onnx(self, inputs=None, outputs=None, :return: onnx graph """ # inputs and outputs + from ..tools.asv_options_helper import get_opset_number_from_onnx + self.input = self._process_io(inputs, list(self.input_names.values())) self.output = self._process_io(outputs, None) From 4839335b0960ee4a949fefcf7349c7ec15dbebb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Wed, 16 Feb 2022 13:54:49 +0100 Subject: [PATCH 08/13] improves import --- mlprodict/npy/xop.py | 9 +++++---- mlprodict/npy/xop_auto.py | 13 ++++++------- mlprodict/npy/xop_factory.py | 4 ++-- mlprodict/npy/xop_graph_builder.py | 6 ++---- mlprodict/npy/xop_variable.py | 14 ++++++++++++++ 5 files changed, 29 insertions(+), 17 deletions(-) diff --git a/mlprodict/npy/xop.py b/mlprodict/npy/xop.py index 9475d9a9e..4076de054 100644 --- a/mlprodict/npy/xop.py +++ b/mlprodict/npy/xop.py @@ -10,8 +10,8 @@ from onnx import GraphProto, TensorProto from onnx.helper import make_graph, make_model # pylint: disable=W0611 from onnx.numpy_helper import from_array -from .xop_variable import Variable, is_numpy_dtype, numpy_type_prototype -from .xop_graph_builder import GraphBuilder +from .xop_variable import ( + Variable, is_numpy_dtype, numpy_type_prototype, max_supported_opset) class OnnxOperatorItem: @@ -102,8 +102,7 @@ def __init__(self, *inputs, op_version=None, output_names=None, if op_version is None: if domain == '': - from ..tools.asv_options_helper import get_opset_number_from_onnx - self.op_version = get_opset_number_from_onnx() + self.op_version = max_supported_opset() else: self.op_version = None else: @@ -525,6 +524,8 @@ def to_onnx(self, inputs=None, outputs=None, None for the default one :param verbose: prints information """ + from .xop_graph_builder import GraphBuilder + # opsets if isinstance(target_opset, dict): dom = self.domain or '' diff --git a/mlprodict/npy/xop_auto.py b/mlprodict/npy/xop_auto.py index 746e16ebc..34c952af4 100644 --- a/mlprodict/npy/xop_auto.py +++ b/mlprodict/npy/xop_auto.py @@ -131,8 +131,9 @@ def get_rst_doc(op_name=None): if op_name is None: schemas = onnx.defs.get_all_schemas_with_history() elif isinstance(op_name, str): - schemas = [schema for schema in onnx.defs.get_all_schemas_with_history( - ) if schema.name == op_name] + schemas = [ + schema for schema in onnx.defs.get_all_schemas_with_history() + if schema.name == op_name] if len(schemas) > 1: raise RuntimeError( "Multiple operators have the same name '{}'.".format(op_name)) @@ -149,8 +150,7 @@ def get_rst_doc(op_name=None): def format_name_with_domain(sch): if sch.domain: return '{} ({})'.format(sch.name, sch.domain) - else: - return sch.name + return sch.name def format_option(obj): opts = [] @@ -162,8 +162,7 @@ def format_option(obj): opts.append('heterogeneous') if opts: return " (%s)" % ", ".join(opts) - else: - return "" + return "" def getconstraint(const, ii): if const.type_param_str: @@ -219,7 +218,7 @@ def process_documentation(doc): return "\n".join(lines) def build_doc_url(sch): - doc_url = "https://github.com/onnx/onnx/blob/master/docs/Operators" + doc_url = "https://github.com/onnx/onnx/blob/main/docs/Operators" if "ml" in sch.domain: doc_url += "-ml" doc_url += ".md" diff --git a/mlprodict/npy/xop_factory.py b/mlprodict/npy/xop_factory.py index ef3d8d860..69850366d 100644 --- a/mlprodict/npy/xop_factory.py +++ b/mlprodict/npy/xop_factory.py @@ -8,9 +8,9 @@ import numpy from scipy.sparse.coo import coo_matrix import onnx -from .xop_auto import get_rst_doc -from .xop_variable import Variable from ._cache import cache_folder +from .xop_variable import Variable +from .xop_auto import get_rst_doc def ClassFactory(class_name, op_name, inputs, outputs, diff --git a/mlprodict/npy/xop_graph_builder.py b/mlprodict/npy/xop_graph_builder.py index ef85e461e..a8579b885 100644 --- a/mlprodict/npy/xop_graph_builder.py +++ b/mlprodict/npy/xop_graph_builder.py @@ -10,7 +10,7 @@ make_node, make_graph, make_model, make_tensor_value_info) from onnx.numpy_helper import from_array -from .xop_variable import Variable, is_numpy_dtype +from .xop_variable import Variable, is_numpy_dtype, max_supported_opset def _default_OPSET_TO_IR_VERSION(): @@ -253,15 +253,13 @@ def to_onnx(self, inputs=None, outputs=None, :return: onnx graph """ # inputs and outputs - from ..tools.asv_options_helper import get_opset_number_from_onnx - self.input = self._process_io(inputs, list(self.input_names.values())) self.output = self._process_io(outputs, None) graph = make_graph( self.node, 'XOP', self.input, self.output, self.initializer) onnx_model = make_model(graph) - opv = self.opsets.get('', get_opset_number_from_onnx()) + opv = self.opsets.get('', max_supported_opset()) opset2ir = _default_OPSET_TO_IR_VERSION() irv = opset2ir.get(opv, max(opset2ir.values())) onnx_model.ir_version = irv diff --git a/mlprodict/npy/xop_variable.py b/mlprodict/npy/xop_variable.py index aed28b6cd..8e3aea7b3 100644 --- a/mlprodict/npy/xop_variable.py +++ b/mlprodict/npy/xop_variable.py @@ -6,6 +6,20 @@ """ import numpy from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE +from onnx.defs import onnx_opset_version + + +def max_supported_opset(): + """ + Returns the latest supported opset for the main domain. + + .. runpython:: + :showcode: + + from mlprodict.npy.xop_variable import max_supported_opset + print("max_supported_opset() returns", max_supported_opset()) + """ + return min(15, onnx_opset_version()) def is_numpy_dtype(dtype): From b1a562ae27a6e455cb803ab0d4910ab2048497b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Wed, 16 Feb 2022 15:22:28 +0100 Subject: [PATCH 09/13] refactoring --- _doc/sphinxdoc/source/api/xop.rst | 22 +- _unittests/ut_npy/test_xop.py | 39 +- _unittests/ut_npy/test_xop_doc.py | 2 +- mlprodict/cli/onnx_code.py | 2 +- mlprodict/npy/xop.py | 835 ++++++----------- mlprodict/npy/xop_auto_import_.py | 2 +- mlprodict/npy/xop_factory.py | 295 ------ mlprodict/npy/xop_graph_builder.py | 272 ------ mlprodict/npy/xop_ops.py | 869 ++++++++++++++++++ mlprodict/npy/{xops_opset.py => xop_opset.py} | 0 mlprodict/npy/xop_variable.py | 39 +- 11 files changed, 1231 insertions(+), 1146 deletions(-) delete mode 100644 mlprodict/npy/xop_factory.py delete mode 100644 mlprodict/npy/xop_graph_builder.py create mode 100644 mlprodict/npy/xop_ops.py rename mlprodict/npy/{xops_opset.py => xop_opset.py} (100%) diff --git a/_doc/sphinxdoc/source/api/xop.rst b/_doc/sphinxdoc/source/api/xop.rst index 819351f22..d650f4bf3 100644 --- a/_doc/sphinxdoc/source/api/xop.rst +++ b/_doc/sphinxdoc/source/api/xop.rst @@ -16,29 +16,29 @@ Converters API +++ -.. autosignature:: mlprodict.npy.xops.ClassFactory +.. autosignature:: mlprodict.npy.xop.ClassFactory -.. autosignature:: mlprodict.npy.xops.dynamic_class_creation +.. autosignature:: mlprodict.npy.xop.dynamic_class_creation .. autosignature:: mlprodict.npy.xops_variable.Variable -.. autosignature:: mlprodict.npy.xops_graph_builder.GraphBuilder +.. autosignature:: mlprodict.npy.xop_ops._GraphBuilder -.. autosignature:: mlprodict.npy.xop.OnnxOperator +.. autosignature:: mlprodict.npy.xop_ops.OnnxOperator -.. autosignature:: mlprodict.npy.xop.OnnxOperatorItem +.. autosignature:: mlprodict.npy.xop_ops.OnnxOperatorItem -.. autosignature:: mlprodict.npy.xops_opset.OnnxReduceSumApi11 +.. autosignature:: mlprodict.npy.xop_opset.OnnxReduceSumApi11 -.. autosignature:: mlprodict.npy.xops_opset.OnnxSplitApi11 +.. autosignature:: mlprodict.npy.xop_opset.OnnxSplitApi11 -.. autosignature:: mlprodict.npy.xops_opset.OnnxSqueezeApi11 +.. autosignature:: mlprodict.npy.xop_opset.OnnxSqueezeApi11 -.. autosignature:: mlprodict.npy.xops_opset.OnnxUnsqueezeApi11 +.. autosignature:: mlprodict.npy.xop_opset.OnnxUnsqueezeApi11 -.. autosignature:: mlprodict.npy.xops_opset.OnnxReduceL2_typed +.. autosignature:: mlprodict.npy.xop_opset.OnnxReduceL2_typed -.. autosignature:: mlprodict.npy.xops_opset.OnnxReshapeApi13 +.. autosignature:: mlprodict.npy.xop_opset.OnnxReshapeApi13 Available ONNX operators ++++++++++++++++++++++++ diff --git a/_unittests/ut_npy/test_xop.py b/_unittests/ut_npy/test_xop.py index 31287a9fb..51d437d40 100644 --- a/_unittests/ut_npy/test_xop.py +++ b/_unittests/ut_npy/test_xop.py @@ -5,8 +5,9 @@ import unittest import numpy from pyquickhelper.pycode import ExtTestCase -from mlprodict.npy.xop_factory import loadop -from mlprodict.npy.xop_graph_builder import GraphBuilder +from mlprodict.npy.xop import loadop +from mlprodict.npy.xop_variable import Variable +from mlprodict.npy.xop_ops import _GraphBuilder from mlprodict.onnxrt import OnnxInference @@ -55,7 +56,7 @@ def test_onnx_add_cst(self): self.assertEqualArray(x + 1, got['Y']) def test_number2alpha(self): - sel = [GraphBuilder.number2alpha(i) for i in range(0, 100001)] + sel = [_GraphBuilder.number2alpha(i) for i in range(0, 100001)] sel2 = sel.copy() sel2.sort() self.assertEqual(sel, sel2) @@ -114,6 +115,38 @@ def test_onnx_cast(self): got = oinf.run({'X': x}) self.assertEqualArray(x.astype(numpy.int64), got['Y']) + def test_onnx_dict(self): + OnnxCast = loadop("OnnxCast") + ov = OnnxCast('X', to=numpy.int64, output_names=['Y']) + onx = ov.to_onnx({'X': numpy.float32}, {'Y': numpy.int64}, verbose=0) + self.assertIn('to', str(onx)) + oinf = OnnxInference(onx) + x = numpy.array([[-2.1, 2.1]], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(x.astype(numpy.int64), got['Y']) + + def test_onnx_var(self): + OnnxCast = loadop("OnnxCast") + ov = OnnxCast('X', to=numpy.int64, output_names=['Y']) + onx = ov.to_onnx(Variable('X', numpy.float32), + Variable('Y', numpy.float32), verbose=0) + self.assertIn('to', str(onx)) + oinf = OnnxInference(onx) + x = numpy.array([[-2.1, 2.1]], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(x.astype(numpy.int64), got['Y']) + + def test_onnx_var_list(self): + OnnxCast = loadop("OnnxCast") + ov = OnnxCast('X', to=numpy.int64, output_names=['Y']) + onx = ov.to_onnx([Variable('X', numpy.float32)], + [Variable('Y', numpy.float32)], verbose=0) + self.assertIn('to', str(onx)) + oinf = OnnxInference(onx) + x = numpy.array([[-2.1, 2.1]], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(x.astype(numpy.int64), got['Y']) + if __name__ == "__main__": unittest.main() diff --git a/_unittests/ut_npy/test_xop_doc.py b/_unittests/ut_npy/test_xop_doc.py index 09035f09d..713283825 100644 --- a/_unittests/ut_npy/test_xop_doc.py +++ b/_unittests/ut_npy/test_xop_doc.py @@ -3,7 +3,7 @@ """ import unittest from pyquickhelper.pycode import ExtTestCase -from mlprodict.npy.xop_factory import _dynamic_class_creation +from mlprodict.npy.xop import _dynamic_class_creation from mlprodict.npy.xop_auto import get_rst_doc diff --git a/mlprodict/cli/onnx_code.py b/mlprodict/cli/onnx_code.py index fb63d237c..e4a335f0b 100644 --- a/mlprodict/cli/onnx_code.py +++ b/mlprodict/cli/onnx_code.py @@ -68,5 +68,5 @@ def dynamic_doc(verbose=0, fLOG=print): :param verbose: displays the list of operator :param fLOG: logging function """ - from ..npy.xop_factory import _dynamic_class_creation + from ..npy.xop import _dynamic_class_creation _dynamic_class_creation(cache=True, verbose=verbose, fLOG=fLOG) diff --git a/mlprodict/npy/xop.py b/mlprodict/npy/xop.py index 4076de054..66b44ac30 100644 --- a/mlprodict/npy/xop.py +++ b/mlprodict/npy/xop.py @@ -1,578 +1,295 @@ -# pylint: disable=E1101 """ @file @brief Easier API to build onnx graphs. Inspired from :epkg:`skl2onnx`. .. versionadded:: 0.9 """ +import os import numpy -from scipy.sparse import coo_matrix -from onnx import GraphProto, TensorProto -from onnx.helper import make_graph, make_model # pylint: disable=W0611 -from onnx.numpy_helper import from_array -from .xop_variable import ( - Variable, is_numpy_dtype, numpy_type_prototype, max_supported_opset) - - -class OnnxOperatorItem: +from scipy.sparse.coo import coo_matrix +import onnx +from ._cache import cache_folder +from .xop_variable import Variable +from .xop_auto import get_rst_doc + + +def ClassFactory(class_name, op_name, inputs, outputs, + input_range, output_range, + domain, attr_names, doc, + deprecated, since_version, + past_version): """ - Accessor to one of the output returned by a @see cl OnnxOperator. - - :param onx_op: @see cl OnnxOperator - :param index: integer - :param op_version: defines the opset version + Dynamically creates a class for a specific operator. + + :param class_name: class name + :param op_name: operator type + :param inputs: expected inputs + :param outputs: expected outputs + :param input_range: input range + :param output_range: output_range + :param domain: domain + :param attr_names: attributes names + :param doc: docstring + :param deprecated: is the operator deprecated + :param since_version: available since version + :param past_version: list of versions """ + from .xop_ops import OnnxOperator, OnnxOperatorItem - def __init__(self, onx_op, index, op_version=None): - if not isinstance(index, int): - raise TypeError("index must be an integer not %r." % type(index)) - self.onx_op = onx_op - self.index = index - self.op_version = op_version - - def __str__(self): - """ - usual - """ - return "%s[%d]" % (str(self.onx_op), self.index) - - def get_output_name(self, i=0): - """ - Returns the output name at position *i*. - """ - if i != 0: - raise IndexError("Can only return the first item.") - return self.onx_op.get_output_name(self.index) - - def get_output(self, i=0): - """ - Returns the output. - """ - if i != 0: - raise IndexError("Can only return the first item.") - return self.onx_op.get_output(self.index) - - -class OnnxOperator: - """ - Ancestor to every *ONNX* operator exposed in - :mod:`mlprodict.npy.xops` and :mod:`mlprodict.npy.xops_ml`. - - :param inputs: list of inputs expected by the operator - :param op_version: to select a specific version of the operator - :param output_names: used defined names for the outputs - :param domain: to overwrite the default domain - :param global_context: operator *If* executes one subgraph - whose nodes may use one existing output in the current - context. If not used in the main graph, these operators - are not linked to the output and cannot be retrieved. - *global_context* is a dictionary mapped the subgraph input - names to these operators. - :param kwargs: additional parameters of the operator - - .. versionadd:: 0.9 - """ + def __init__(self, *args, **kwargs): - def __init__(self, *inputs, op_version=None, output_names=None, - domain=None, global_context=None, **kwargs): - - if (output_names is None and - self.__class__.__name__.startswith("OnnxScan")): - raise NotImplementedError( - "The class cannot infer the number of variables " - "for node '{}' yet. output_names must be specified" - ".".format(self.__class__.__name__)) - if isinstance(output_names, (str, Variable)): - output_names = [output_names] - if isinstance(output_names[0], str): - output_names[0] = Variable(output_names[0]) - elif isinstance(output_names, list): - if len(output_names) == 0: - raise ValueError( - "output_names cannot be empty (operator %r)." - "" % self.__class__.__name__) - output_names = output_names.copy() - for i in range(len(output_names)): # pylint: disable=C0200 - if isinstance(output_names[i], str): - output_names[i] = Variable(output_names[i]) - elif output_names is not None: - raise TypeError( - "output_names must be a string or a list not %r." - "" % type(output_names)) + op_version = kwargs.pop('op_version', None) + if isinstance(op_version, dict): + op_version = op_version.get(domain, None) if op_version is None: - if domain == '': - self.op_version = max_supported_opset() - else: - self.op_version = None + if len(args) == 0 and input_range[0] == input_range[1]: + args = [_[0] for _ in self.__class__.expected_inputs] + if not (input_range[0] <= len(args) <= input_range[1]): + raise RuntimeError("Unexpected number of inputs, " + "got {}, expecting {} for operator " + "'{}'.".format( + len(args), len(inputs), op_name)) + + attr_names = self.attr_names + if '_' in self.__class__.__name__: + op_version_class = int(self.__class__.__name__.split('_')[-1]) + if op_version is None: + op_version = op_version_class + try: + op_version = min(op_version, op_version_class) + except TypeError: + raise TypeError( # pylint: disable=W0707 + "Could not compare versions {} ? {} for " + "class '{}' since_version {}. Parameter 'op_version' " + "is probably missing when the class " + "is instantiated.".format( + op_version, op_version_class, class_name, + since_version)) else: - self.op_version = op_version - self.since_version = self.__class__.since_version - - if (self.op_version is not None and - self.op_version < self.since_version): - schema = self.find_schema(self.op_version) - self.since_version = schema.since_version - self.expected_inputs = schema.expected_inputs.copy() - self.expected_outputs = schema.expected_outputs.copy() - self.input_range = schema.input_range - self.output_range = schema.output_range - else: - self.expected_inputs = ( - None if self.__class__.expected_inputs is None - else self.__class__.expected_inputs.copy()) - self.expected_outputs = ( - None if self.__class__.expected_outputs is None - else self.__class__.expected_outputs.copy()) - self.input_range = self.__class__.input_range - self.output_range = self.__class__.output_range - if self.__class__.__name__ not in { - 'OnnxScan', 'OnnxLoop', 'OnnxIf'}: - # The minimum opset depends on embedded graph - # by default, it takes the given op_version but the - # optimal value could be lower. - self.op_version = self.since_version - if self.op_version is None: - self.op_version = self.since_version - - if (self.op_version is not None and - self.op_version < self.since_version): + op_version_class = None + + # By default, the op_version is None. + # None means the latest available. + if op_version is None: + op_version = since_version + + found = None + if op_version is not None: + # attr_names refers to the most recent version of + # this operator. We may need an older one. + for op in range(op_version, 0, -1): + name = '{}_{}'.format(self.__class__.__name__, op) + if name in self.past_version: + found = (name, op) + attr_names = self.past_version[name].attr_names + break + if (op_version_class is not None and found is not None and + found[-1] != op_version_class): raise RuntimeError( - "Operator '{}': requested version {} < " - "{} schema version.".format( - self.__class__.__name__, - self.op_version, self.since_version)) - - self.state = None - self.domain = domain - self.kwargs = kwargs - self.onnx_prefix_name = None - - # check inputs - if len(inputs) == 0: - if self.input_range[0] == self.input_range[1]: - self.inputs = [OnnxOperator.UnscopedVariable(_[0]) - for _ in self.expected_inputs] - else: - # The number of inputs may vary. - self.inputs = None - else: - self.inputs = [] - for inp in inputs: - if isinstance(inp, str): - self.inputs.append(Variable(inp)) - elif isinstance(inp, (OnnxOperator, Variable, - OnnxOperatorItem)): - self.inputs.append(inp) - elif isinstance(inp, (numpy.ndarray, coo_matrix, TensorProto)): - self.inputs.append(inp) - else: + "op_version={} does not refer to the same opset as the class " + "name ('{}').".format(op_version, self.__class__.__name__)) + for key in kwargs: + if key in {'output_names', 'op_version', 'domain', 'ir_version', + 'global_context', 'clear_subgraph_inputs'}: + continue + if key not in attr_names: + raise TypeError("Argument '%s' not valid for '%s' opset=%s." + % (key, op_name, op_version)) + + if op_version is not None: + kwargs['op_version'] = op_version + # This class can only be created by a user. Let's check + # types are either a variable, an operator or an array. + for i, a in enumerate(args): + if isinstance(a, tuple): + if len(a) != 2: raise TypeError( - "Unable to interpret the input name for type {} in " - "operator '{}' (value={}).".format( - type(inp), self.__class__.__name__, inp)) - - if self.inputs is not None: - if (len(self.inputs) < self.input_range[0] or - len(self.inputs) > self.input_range[1]): - raise RuntimeError( - "Operator '{}' expects a number of inputs " - "in [{}, {}] not {} (expected opset={}, " - "class opset={})".format( - self.operator_name, *self.input_range, - len(self.inputs), op_version, self.op_version)) - # global context - if global_context is None: - self.global_context = None - else: - if not isinstance(global_context, dict): - raise TypeError( - "global_context must be a dictionary not %r." - "" % type(global_context)) - for k, v in global_context.items(): - if not isinstance(v, (OnnxOperator, OnnxOperatorItem)): + "Input %r is a tuple or class %r, it must have two " + "elements (name, type) not %r." % (i, class_name, a)) + if not isinstance(a[0], str): raise TypeError( - "Value %r in must be an OnnxOperator or an " - "OnnxOperatorItem not %r." % (k, type(v))) - self.global_context = global_context - - # check output - self.output_names = output_names - self.output_variables = None - - if self.output_names is not None: - if len(self.output_names) == 0: - raise ValueError( - "output_names can be None but cannot be empty for " - "operator %r." % self) - if self.output_variables is None: - self.output_variables = [None for o in self.output_names] - for i in range(len(self.output_names)): # pylint: disable=C0200 - name = self.output_names[i] - if isinstance(name, Variable): - self.output_variables[i] = name - else: - raise TypeError("output_names must be a list of strings " - "and element %r is %r (%r)" % ( - i, type(name), name)) - if all(map(lambda x: x is None, self.output_variables)): - self.output_variables = None - - if (self.output_names is not None and ( - self.expected_outputs is None or - len(self.output_names) > len(self.expected_outputs))): - if self.expected_outputs is None: - self.expected_outputs = [] - for i in range(len(self.expected_outputs), - len(self.output_names)): - self.expected_outputs.append((self.output_names[i], None)) - - if (self.expected_inputs is None or - len(self.inputs) > len(self.expected_inputs)): - if self.expected_inputs is None: - self.expected_inputs = [] - for i in range(len(self.expected_inputs), - len(self.inputs)): - inp = self.inputs[i] - if isinstance(inp, str): - inp = (inp, None) - elif hasattr(inp, 'add_to'): - # OnnxOperator - existing = set(_[0] for _ in self.expected_inputs) - i = 10 - name = "input%d" % (10 + i) - while name in existing: - i += 1 - name = "input%d" % (10 + i) - inp = (name, None) - self.expected_inputs.append(inp) - - self._post_process_attributes() - self._check() - - def _check(self): - input_types = (Variable, OnnxOperator, numpy.ndarray) - for o in self.inputs: - if not isinstance(o, input_types): - raise TypeError( - "Wrong type for inputs %r." % ( - self.inputs, )) - if self.output_names is not None: - for o in self.output_names: - if not isinstance(o, Variable): - raise TypeError( - "Wrong type for output_names %r." % ( - self.output_names, )) - - def _post_process_attributes(self): - """ - Walks through attributes and replaces them by ONNX values. - """ - # Looks into attributes if there is any tuple - # (GraphProto, OnnxOperator). In that case, the function - # replaces the tuple by the graph proto and keeps - # in attributes graph_algebra the OnnxOperator - # which is the source of it. - updates = {} - graph_algebra = {} - for k, v in self.kwargs.items(): - if isinstance(v, tuple) and isinstance(v[0], GraphProto): - updates[k] = v[0] - graph_algebra[k] = v[1] - - if len(graph_algebra) > 0: - self.kwargs.update(updates) - self.graph_algebra = graph_algebra - - if self.__class__.__name__ == "OnnxConstantOfShape": - if "value" in self.kwargs: - value = self.kwargs['value'] - if isinstance(value, TensorProto): - return - if isinstance(value, numpy.ndarray): - if value.shape == (1, ): - val = value[0] - elif len(value.shape) == 0: - val = value - else: - raise RuntimeError( - "Unexpected shape %r for value, it must be " - "an array of one element." % value.shape) - self.kwargs['value'] = from_array( - numpy.array([val], dtype=value.dtype)) - return - raise TypeError( - "Unexpected type %r for value. It should be an array " - "of one element." % type(value)) - return - - if self.__class__.__name__ == "OnnxCast": - if "to" in self.kwargs: - value = self.kwargs['to'] - if not isinstance(value, int): - try: - to = numpy_type_prototype(value) - except ValueError as e: - raise ValueError( - "Unable to convert argument to in operator cast, " - "type is %r, value is %r." % (type(value), value)) from e - self.kwargs['to'] = to - return - - def find_schema(self, op_version): - """ - Checks if there is an existing schema for a - specific version. - - :param op_version: requested version - :return: schema - """ - if not hasattr(self.__class__, 'past_version'): - raise RuntimeError("Missing attribute 'past_version', there is " - "no other available schema.") - found = None - for v in self.past_version.values(): - if v.since_version > op_version: - continue - if found is None or v.since_version > found.since_version: - found = v - if found is None: - raise RuntimeError( - "Operator '{}': requested version {} < " - "{} schema version.".format( - self.__class__.__name__, - op_version, self.since_version)) - return found - - def __str__(self): - """ - usual - """ - return "{}({} in) -> {}".format( - self.__class__.__name__, - len(self.inputs) if self.inputs is not None else 0, - [str(o) for o in self.output_names] - if self.output_names is not None else "?") - - def set_onnx_name_prefix(self, onnx_prefix_name): - """ - Provides a name to define a prefix in the onnx graph - to avoid to get unreadable node names. The method - does not overwrite an existing name, it propagates - the prefix to inputs and stops the propagation - if the prefix is already defined. - """ - if self.onnx_prefix_name is None: - self.onnx_prefix_name = onnx_prefix_name - for inp in self.inputs: - if hasattr(inp, 'set_onnx_prefix_name'): - inp.set_onnx_name_prefix(onnx_prefix_name) - return self - - @property - def onnx_prefix(self): - "Returns a prefix for results coming out from this node." - if self.onnx_prefix_name is None: - name = self.__class__.__name__ - if name.startswith("Onnx"): - name = name[4:] - return 'out_' + name[:3].lower() - return self.onnx_prefix_name - - def __getitem__(self, index): - """ - Returns an accessor to one of the output - of this node. - """ - return OnnxOperatorItem(self, index, self.op_version) - - def _node_to_graph(self, other_outputs=None, inputs=None, outputs=None): - """ - Builds a graph as a list of nodes to walk through in that order. - """ - node_outputs = [self] - if other_outputs is not None: - node_outputs += other_outputs - - # walk through graphs - stack = list(node_outputs) - new_inputs = [] - set_inputs = set() - memo = [] - while len(stack) > 0: - memo.extend(stack) - new_stack = [] - for obj in stack: - for inp in obj.inputs: - if isinstance(inp, OnnxOperator): - new_stack.append(inp) - elif isinstance(inp, Variable): - if inp.name in set_inputs: - continue - set_inputs.add(inp.name) - if inputs is None: - new_inputs.append(inp) - elif isinstance(inputs, dict): - if inp in inputs: - new_inputs.append((inp, inputs[inp])) - else: - raise ValueError( # pragma: no cover - "Unable to find input %r in %r." % ( - inp, inputs)) - elif is_numpy_dtype(inputs): - new_inputs.append(inp.copy_add(inputs)) - else: - raise RuntimeError( # pragma: no cover - "Unable to handle inputs=%r." % inputs) - elif isinstance(inp, numpy.ndarray): - pass - else: - raise TypeError( - "Unexpected input type %r in node type %r." % ( - type(inp), type(obj))) - stack = new_stack - - if len(new_inputs) == 0: - raise RuntimeError( - "No detected inputs inputs=%r outputs=%r." % ( - inputs, outputs)) - - # eliminate duplicates - done = set() - nodes = [] - for node in reversed(memo): - if id(node) in done: + "Input %r is a tuple or class %r, it must be a tuple " + "(name, type) not %r." % (i, class_name, a)) continue - done.add(id(node)) - nodes.append(node) - - def _get_type(node, name=None, outputs=None): - if outputs is None: - raise NotImplementedError( - "outputs is None, expected_outputs=%r" % ( - node.expected_outputs, )) - if isinstance(outputs, dict): - if name is None: - raise RuntimeError( - "Unable to get type among %r, name=None." % ( - outputs, )) - if name not in outputs: - raise ValueError( # pragma: no cover - "Unable to find %r in %r." % ( - name, outputs)) - return outputs[name] - if isinstance(outputs, list): - raise NotImplementedError( - "Unexpected type for name=%r, outputs=%r." % ( - name, outputs)) - if is_numpy_dtype(outputs): - return outputs - raise RuntimeError( # pragma: no cover - "Unable to handle outputs=%r." % outputs) - - # outputs - new_outputs = [] - for node in node_outputs: - if node.output_names is None: - n = self.output_range[0] - for i in range(n): - to = _get_type(node, outputs=outputs) - res = ('out%d' % i, to) - new_outputs.append(Variable(res[0], added_dtype=to)) + if not isinstance(a, ( + Variable, OnnxOperator, numpy.ndarray, str, + OnnxOperatorItem, coo_matrix)): + raise TypeError( + "Unexpected type %r for input %r of operator %r. " + "It must be an instance of Variable (or a string), " + "OnnxOperator, OnnxOperatorItem, numpy.ndarray, " + "coo_matrix)." % ( + type(a), i, class_name)) + OnnxOperator.__init__(self, *args, **kwargs) + + newclass = type(class_name, (OnnxOperator,), + {"__init__": __init__, '__doc__': doc, + 'expected_inputs': inputs, + 'expected_outputs': outputs, + 'operator_name': op_name, + 'input_range': input_range, + 'output_range': output_range, + 'domain': domain, + 'is_deprecated': deprecated, + 'since_version': since_version, + 'past_version': past_version, + 'attr_names': attr_names, + '__module__': __name__}) + return newclass + + +def _populate_schemas(): + """ + Populates all schemas. + """ + res = {} + for schema in onnx.defs.get_all_schemas_with_history(): + if schema.support_level == schema.SupportType.EXPERIMENTAL: + # Skips experimental operators. + continue + # Multiple version can coexist. The last one is kept. + if schema.name in res: + if schema.since_version > res[schema.name].since_version: + # We keep the most recent one. + res[schema.name] = schema + else: + res[schema.name] = schema + res[schema.name + '_' + str(schema.since_version)] = schema + return res + + +def _dynamic_class_creation(operator_names=None, cache=False, verbose=0, fLOG=print): + """ + Automatically generates classes for each of the operators + module *onnx* defines and described at + `Operators + `_ + and `Operators + `_. + + :param operator_names: list of operators to request or None for all + :param cache: extract the documentation from onnx package and + saves it on disk it True + :param verbose: display some progress + :param fLOG: logging function + :return: list of requested operators as a tuple + """ + def _c(obj, label, i): + name = '%s%d' % (obj.name or label, i) + tys = obj.typeStr or '' + return (name, tys) + + cache_dir = cache_folder() + if operator_names is None: + operator_names = list(_all_schemas) + + res = _all_schemas + cls = {} + set_names = dict() + set_skip = set() + for pos, op_name in enumerate(operator_names): + set_names[op_name] = pos + if '_' in op_name: + n = op_name.split('_')[0] + if n.startswith('Onnx'): + set_skip.add(n) else: - for o in self.output_names: - to = _get_type(node, o, outputs=outputs) - res = (o, to) - new_outputs.append(o.copy_add(to)) - if len(new_outputs) == 0: - raise RuntimeError( - "No detected outputs inputs=%r outputs=%r." % ( - inputs, outputs)) - - return nodes, new_inputs, new_outputs - - def add_to(self, builder): - """ - Adds to graph builder. - - :param builder: instance of @see cl GraphBuilder, - it must have a method `add_node` - """ - inputs = builder.get_input_names(self, self.inputs) - n_outputs = ( - self.output_range[0] if self.output_names is None - else len(self.output_names)) - outputs = [builder.get_output_name(self, i) for i in range(n_outputs)] - builder.add_node( - self.operator_name, - builder.get_unique_name('_' + self.operator_name.lower()), - inputs, outputs, domain=self.domain, opset=self.op_version, - **self.kwargs) - - def to_onnx(self, inputs=None, outputs=None, - other_outputs=None, target_opset=None, - verbose=0): - """ - Converts this operator into an ONNX graph. - - :param inputs: information about type - :param outputs: information about types - :param other_outputs: additional nodes to consider - as graph outputs but not outputs of this particular - node - :param target_opset: dictionary with target opset per domain, - None for the default one - :param verbose: prints information - """ - from .xop_graph_builder import GraphBuilder - - # opsets - if isinstance(target_opset, dict): - dom = self.domain or '' - target_opset = target_opset.get(dom, None) - elif isinstance(target_opset, int): - if self.domain not in ('', None): - # The target_opset is for the domain '' we ignore it. - target_opset = None - elif target_opset is not None: - raise TypeError( - "target_opset must be a dictionary {domain: " - "target_opset} not %r for operator %r." % ( - target_opset, self.__class__.__name__)) - - if self.domain in ('', None) and target_opset == 1: - raise RuntimeError("target_opset cannot be 1.") - if (self.op_version is not None and target_opset is not None and - self.op_version > target_opset): - raise RuntimeError( - "target_opset={} is lower than the version={} requested " - "for this node '{}'.".format( - target_opset, self.op_version, self.__class__.__name__)) - - # inputs, outputs - if isinstance(inputs, list): - raise NotImplementedError( - "Unable to process inputs=%r." % (inputs, )) - if isinstance(outputs, list): - raise NotImplementedError( - "Unable to process outputs=%r." % (outputs, )) - - # get the graph - nodes, graph_inputs, graph_outputs = self._node_to_graph( - other_outputs, inputs, outputs) - if len(nodes) == 0: - raise RuntimeError( # pragma: no cover - "Node list is empty.") - if verbose > 1: - for i, n in enumerate(nodes): - print("nodes[%d]=%r" % (i, n)) - for i, n in enumerate(graph_inputs): - print("graph_inputs[%d]=%r" % (i, n)) - builder = GraphBuilder() - for node in nodes: - node.add_to(builder) - - return builder.to_onnx(inputs=graph_inputs, - outputs=graph_outputs, - target_opset=target_opset, - verbose=verbose) + set_skip.add('Onnx' + n) + if n not in set_names: + set_names[n] = -1 + + if verbose > 1 and fLOG is not None: + fLOG("[_dynamic_class_creation] set_names=%r" % set_names) + fLOG("[_dynamic_class_creation] set_skip=%r" % set_skip) + + returned_classes = [] + positions = {} + + for op_name, position in set_names.items(): + cl_name = op_name if op_name.startswith('Onnx') else 'Onnx' + op_name + if verbose > 3 and fLOG is not None: + fLOG('[_dynamic_class_creation] cl_name=%r op_name=%r (in=%d)' % ( + cl_name, op_name, 1 if cl_name in _all_classes else 0)) + if cl_name in _all_classes: + if cl_name not in set_skip: + if position >= 0: + returned_classes.append((position, _all_classes[cl_name])) + continue + if verbose > 0 and fLOG is not None: + fLOG("[_dynamic_class_creation] op_name=%r, cl_name=%r" % ( + op_name, cl_name)) + + name = op_name[4:] if op_name.startswith('Onnx') else op_name + try: + schema = res[name] + except KeyError as e: + raise ValueError( + "Operator %r (or %r) does not exists." % ( + name, op_name)) from e + inputs = [_c(o, 'I', i) for i, o in enumerate(schema.inputs)] + outputs = [_c(o, 'O', i) for i, o in enumerate(schema.outputs)] + args = [p for p in schema.attributes] + + if '_' in op_name: + class_name = "Onnx" + name + else: + class_name = "Onnx" + schema.name + + filename = os.path.join( + cache_dir, + schema.name + '_' + str(schema.since_version) + ".rst") + if not cache and os.path.exists(filename): + with open(filename, "r", encoding="utf-8") as f: + doc = f.read() + else: + doc = get_rst_doc(schema) + if cache: + with open(filename, 'w', encoding='utf-8') as f: + f.write(doc) + + cl = ClassFactory(class_name, schema.name, inputs, outputs, + [schema.min_input, schema.max_input], + [schema.min_output, schema.max_output], + schema.domain, args, + "**Version**" + doc.split('**Version**')[-1], + getattr(schema, 'deprecated', False), + schema.since_version, {}) + cls[class_name] = cl + positions[class_name] = position + + # Retrieves past classes. + for name in cls: # pylint: disable=C0206 + if '_' not in name: + continue + main, _ = name.split('_') + if main in cls: # pylint: disable=R1715 + last = cls[main] + else: + last = _all_classes[main] + last.past_version[name] = cls[name] + + _all_classes.update(cls) + for cl_name, v in cls.items(): + if v not in set_skip and positions.get(cl_name, -1) >= 0: + returned_classes.append((positions[cl_name], v)) + + returned_classes.sort() + return tuple(e[1] for e in returned_classes) + + +_all_schemas = _populate_schemas() +_all_classes = {} + + +def loadop(*names, cache=False, verbose=0, fLOG=print): + """ + Dynamically creates a class for a every operator type in + the given list. + """ + res = _dynamic_class_creation( + names, cache=cache, verbose=verbose, fLOG=fLOG) + if len(res) == 1: + return res[0] + return res diff --git a/mlprodict/npy/xop_auto_import_.py b/mlprodict/npy/xop_auto_import_.py index 3db5fc730..c44fc6f7d 100644 --- a/mlprodict/npy/xop_auto_import_.py +++ b/mlprodict/npy/xop_auto_import_.py @@ -5,7 +5,7 @@ .. versionadded:: 0.9 """ import sys -from .xop_factory import _dynamic_class_creation +from .xop import _dynamic_class_creation def _update_module(): diff --git a/mlprodict/npy/xop_factory.py b/mlprodict/npy/xop_factory.py deleted file mode 100644 index 69850366d..000000000 --- a/mlprodict/npy/xop_factory.py +++ /dev/null @@ -1,295 +0,0 @@ -""" -@file -@brief Easier API to build onnx graphs. Inspired from :epkg:`skl2onnx`. - -.. versionadded:: 0.9 -""" -import os -import numpy -from scipy.sparse.coo import coo_matrix -import onnx -from ._cache import cache_folder -from .xop_variable import Variable -from .xop_auto import get_rst_doc - - -def ClassFactory(class_name, op_name, inputs, outputs, - input_range, output_range, - domain, attr_names, doc, - deprecated, since_version, - past_version): - """ - Dynamically creates a class for a specific operator. - - :param class_name: class name - :param op_name: operator type - :param inputs: expected inputs - :param outputs: expected outputs - :param input_range: input range - :param output_range: output_range - :param domain: domain - :param attr_names: attributes names - :param doc: docstring - :param deprecated: is the operator deprecated - :param since_version: available since version - :param past_version: list of versions - """ - from .xop import OnnxOperator, OnnxOperatorItem - - def __init__(self, *args, **kwargs): - - op_version = kwargs.pop('op_version', None) - if isinstance(op_version, dict): - op_version = op_version.get(domain, None) - - if op_version is None: - if len(args) == 0 and input_range[0] == input_range[1]: - args = [_[0] for _ in self.__class__.expected_inputs] - if not (input_range[0] <= len(args) <= input_range[1]): - raise RuntimeError("Unexpected number of inputs, " - "got {}, expecting {} for operator " - "'{}'.".format( - len(args), len(inputs), op_name)) - - attr_names = self.attr_names - if '_' in self.__class__.__name__: - op_version_class = int(self.__class__.__name__.split('_')[-1]) - if op_version is None: - op_version = op_version_class - try: - op_version = min(op_version, op_version_class) - except TypeError: - raise TypeError( # pylint: disable=W0707 - "Could not compare versions {} ? {} for " - "class '{}' since_version {}. Parameter 'op_version' " - "is probably missing when the class " - "is instantiated.".format( - op_version, op_version_class, class_name, - since_version)) - else: - op_version_class = None - - # By default, the op_version is None. - # None means the latest available. - if op_version is None: - op_version = since_version - - found = None - if op_version is not None: - # attr_names refers to the most recent version of - # this operator. We may need an older one. - for op in range(op_version, 0, -1): - name = '{}_{}'.format(self.__class__.__name__, op) - if name in self.past_version: - found = (name, op) - attr_names = self.past_version[name].attr_names - break - if (op_version_class is not None and found is not None and - found[-1] != op_version_class): - raise RuntimeError( - "op_version={} does not refer to the same opset as the class " - "name ('{}').".format(op_version, self.__class__.__name__)) - for key in kwargs: - if key in {'output_names', 'op_version', 'domain', 'ir_version', - 'global_context', 'clear_subgraph_inputs'}: - continue - if key not in attr_names: - raise TypeError("Argument '%s' not valid for '%s' opset=%s." - % (key, op_name, op_version)) - - if op_version is not None: - kwargs['op_version'] = op_version - # This class can only be created by a user. Let's check - # types are either a variable, an operator or an array. - for i, a in enumerate(args): - if isinstance(a, tuple): - if len(a) != 2: - raise TypeError( - "Input %r is a tuple or class %r, it must have two " - "elements (name, type) not %r." % (i, class_name, a)) - if not isinstance(a[0], str): - raise TypeError( - "Input %r is a tuple or class %r, it must be a tuple " - "(name, type) not %r." % (i, class_name, a)) - continue - if not isinstance(a, ( - Variable, OnnxOperator, numpy.ndarray, str, - OnnxOperatorItem, coo_matrix)): - raise TypeError( - "Unexpected type %r for input %r of operator %r. " - "It must be an instance of Variable (or a string), " - "OnnxOperator, OnnxOperatorItem, numpy.ndarray, " - "coo_matrix)." % ( - type(a), i, class_name)) - OnnxOperator.__init__(self, *args, **kwargs) - - newclass = type(class_name, (OnnxOperator,), - {"__init__": __init__, '__doc__': doc, - 'expected_inputs': inputs, - 'expected_outputs': outputs, - 'operator_name': op_name, - 'input_range': input_range, - 'output_range': output_range, - 'domain': domain, - 'is_deprecated': deprecated, - 'since_version': since_version, - 'past_version': past_version, - 'attr_names': attr_names, - '__module__': __name__}) - return newclass - - -def _populate_schemas(): - """ - Populates all schemas. - """ - res = {} - for schema in onnx.defs.get_all_schemas_with_history(): - if schema.support_level == schema.SupportType.EXPERIMENTAL: - # Skips experimental operators. - continue - # Multiple version can coexist. The last one is kept. - if schema.name in res: - if schema.since_version > res[schema.name].since_version: - # We keep the most recent one. - res[schema.name] = schema - else: - res[schema.name] = schema - res[schema.name + '_' + str(schema.since_version)] = schema - return res - - -def _dynamic_class_creation(operator_names=None, cache=False, verbose=0, fLOG=print): - """ - Automatically generates classes for each of the operators - module *onnx* defines and described at - `Operators - `_ - and `Operators - `_. - - :param operator_names: list of operators to request or None for all - :param cache: extract the documentation from onnx package and - saves it on disk it True - :param verbose: display some progress - :param fLOG: logging function - :return: list of requested operators as a tuple - """ - def _c(obj, label, i): - name = '%s%d' % (obj.name or label, i) - tys = obj.typeStr or '' - return (name, tys) - - cache_dir = cache_folder() - if operator_names is None: - operator_names = list(_all_schemas) - - res = _all_schemas - cls = {} - set_names = dict() - set_skip = set() - for pos, op_name in enumerate(operator_names): - set_names[op_name] = pos - if '_' in op_name: - n = op_name.split('_')[0] - if n.startswith('Onnx'): - set_skip.add(n) - else: - set_skip.add('Onnx' + n) - if n not in set_names: - set_names[n] = -1 - - if verbose > 1 and fLOG is not None: - fLOG("[_dynamic_class_creation] set_names=%r" % set_names) - fLOG("[_dynamic_class_creation] set_skip=%r" % set_skip) - - returned_classes = [] - positions = {} - - for op_name, position in set_names.items(): - cl_name = op_name if op_name.startswith('Onnx') else 'Onnx' + op_name - if verbose > 3 and fLOG is not None: - fLOG('[_dynamic_class_creation] cl_name=%r op_name=%r (in=%d)' % ( - cl_name, op_name, 1 if cl_name in _all_classes else 0)) - if cl_name in _all_classes: - if cl_name not in set_skip: - if position >= 0: - returned_classes.append((position, _all_classes[cl_name])) - continue - if verbose > 0 and fLOG is not None: - fLOG("[_dynamic_class_creation] op_name=%r, cl_name=%r" % ( - op_name, cl_name)) - - name = op_name[4:] if op_name.startswith('Onnx') else op_name - try: - schema = res[name] - except KeyError as e: - raise ValueError( - "Operator %r (or %r) does not exists." % ( - name, op_name)) from e - inputs = [_c(o, 'I', i) for i, o in enumerate(schema.inputs)] - outputs = [_c(o, 'O', i) for i, o in enumerate(schema.outputs)] - args = [p for p in schema.attributes] - - if '_' in op_name: - class_name = "Onnx" + name - else: - class_name = "Onnx" + schema.name - - filename = os.path.join( - cache_dir, - schema.name + '_' + str(schema.since_version) + ".rst") - if not cache and os.path.exists(filename): - with open(filename, "r", encoding="utf-8") as f: - doc = f.read() - else: - doc = get_rst_doc(schema) - if cache: - with open(filename, 'w', encoding='utf-8') as f: - f.write(doc) - - cl = ClassFactory(class_name, schema.name, inputs, outputs, - [schema.min_input, schema.max_input], - [schema.min_output, schema.max_output], - schema.domain, args, - "**Version**" + doc.split('**Version**')[-1], - getattr(schema, 'deprecated', False), - schema.since_version, {}) - cls[class_name] = cl - positions[class_name] = position - - # Retrieves past classes. - for name in cls: # pylint: disable=C0206 - if '_' not in name: - continue - main, _ = name.split('_') - if main in cls: # pylint: disable=R1715 - last = cls[main] - else: - last = _all_classes[main] - last.past_version[name] = cls[name] - - _all_classes.update(cls) - for cl_name, v in cls.items(): - if v not in set_skip and positions.get(cl_name, -1) >= 0: - returned_classes.append((positions[cl_name], v)) - - returned_classes.sort() - return tuple(e[1] for e in returned_classes) - - -_all_schemas = _populate_schemas() -_all_classes = {} - - -def loadop(*names, cache=False, verbose=0, fLOG=print): - """ - Dynamically creates a class for a every operator type in - the given list. - """ - res = _dynamic_class_creation( - names, cache=cache, verbose=verbose, fLOG=fLOG) - if len(res) == 1: - return res[0] - return res diff --git a/mlprodict/npy/xop_graph_builder.py b/mlprodict/npy/xop_graph_builder.py deleted file mode 100644 index a8579b885..000000000 --- a/mlprodict/npy/xop_graph_builder.py +++ /dev/null @@ -1,272 +0,0 @@ -""" -@file -@brief Easier API to build onnx graphs. Inspired from :epkg:`skl2onnx`. - -.. versionadded:: 0.9 -""" -import numpy -from onnx import TensorProto -from onnx.helper import ( - make_node, make_graph, make_model, - make_tensor_value_info) -from onnx.numpy_helper import from_array -from .xop_variable import Variable, is_numpy_dtype, max_supported_opset - - -def _default_OPSET_TO_IR_VERSION(): - """ - Returns the default mapping between opset and ir_version. - - .. runpython:: - :showcode: - - import pprint - from mlprodict.npy.xop_graph_builder import _default_OPSET_TO_IR_VERSION - pprint.pprint(_default_OPSET_TO_IR_VERSION()) - """ - return { - 1: 3, 2: 3, 3: 3, 4: 3, 5: 3, 6: 3, - 7: 3, 8: 4, 9: 4, 10: 5, 11: 6, 12: 7, - 13: 7, 14: 7, 15: 8 - } - - -class GraphBuilder: - """ - Graph builder. - """ - - def __init__(self): - from .xop import OnnxOperator, OnnxOperatorItem - self.initializer = [] - self.node = [] - self.input = [] - self.output = [] - self.opsets = {} - self.names = set() - self.input_names = {} - self.output_names = {} - self.output_names_rev = {} - self.cl_onnx_op = OnnxOperator - self.cl_onnx_op_item = OnnxOperatorItem - - @staticmethod - def number2alpha(index): - """ - Converts a numbers into a string keeping the same - alphabetical order. - """ - dec = str(int(index)) - if len(dec) == 1: - return dec - return chr(96 + len(dec)) + dec - - def get_unique_name(self, name): - """ - Returns a unique name to name an output. - """ - if not isinstance(name, str): - raise TypeError( # pragma: no cover - "name must be a string not %r." % type(name)) - if name not in self.names: - self.names.add(name) - return name - i = 1 - new_name = "%s_%s" % (name, self.number2alpha(i)) - while new_name in self.names: - i += 1 - new_name = "%s_%s" % (name, self.number2alpha(i)) - self.names.add(new_name) - return new_name - - def get_output_name(self, node, index): - """ - Returns the output name for a node. - """ - key = id(node), index - if key in self.output_names: - name = self.output_names[key] - return name - - if node.output_names is None: - prefix = node.onnx_prefix - n = '%s%d' % (prefix, index) - else: - output = node.output_names[index] - if isinstance(output, Variable): - n = output.name - else: - raise TypeError( # pragma: no cover - "Unexpected type %r for output %d (output_names=%r)." % ( - type(output), index, node.output_names)) - - name = self.get_unique_name(n) - self.output_names[key] = name - self.output_names_rev[name] = key - if node.output_names is not None: - var = node.output_names[index] - if isinstance(var, Variable): - var = var.name - if var != name: - raise RuntimeError( - "Output unique name %r is different from the " - "expected name %r at position %r." % ( - name, node.output_names[index], index)) - return name - - def get_input_names(self, node, inputs): - """ - Returns input names for node *node* and inputs *inputs*. - - :param node: node - :param inputs: inputs - :return: name - """ - names = [] - for i in inputs: - if isinstance(i, Variable): - names.append(i.name) - self.names.add(i.name) - self.input_names[i.name] = i - elif isinstance(i, self.cl_onnx_op): - name = self.get_output_name(i, 0) - names.append(name) - self.names.add(name) - elif isinstance(i, self.cl_onnx_op_item): - name = self.get_output_name(i.onnx_op, i.index) - names.append(name) - self.names.add(name) - elif isinstance(i, numpy.ndarray): - # Adding an initializer - name = self.get_unique_name('init') - init = from_array(i, name) - self.initializer.append(init) - names.append(name) - self.names.add(name) - else: - raise TypeError( # pragma: no cover - "Unexpected type for an input %r." % type(i)) - return names - - def add_node(self, op_type, name, inputs, outputs, domain='', - opset=None, **attributes): - """ - Adds a node to the graph. - - :param op_type: operator type - :param name: node name - :param inputs: inputs name list - :param outputs: outputs name list - :param domain: node domain - :param opset: node opset - """ - if not isinstance(inputs, list): - raise TypeError( # pragma: no cover - "inputs must be a list not %r." % type(inputs)) - if not isinstance(outputs, list): - raise TypeError( # pragma: no cover - "inputs must be a list not %r." % type(outputs)) - if any(map(lambda x: not isinstance(x, str), inputs)): - raise TypeError( # pragma: no cover - "inputs must be all strings not %r." % inputs) - if any(map(lambda x: not isinstance(x, (str, Variable)), outputs)): - raise TypeError( # pragma: no cover - "outputs must be all strings not %r." % outputs) - if opset is not None: - if domain not in self.opsets: - self.opsets[domain] = opset - else: - self.opsets[domain] = max(opset, self.opsets[domain]) - node = make_node(op_type, inputs, outputs, name=name, - domain=domain, **attributes) - self.node.append(node) - - def _process_io(self, inputs, input_names): - if inputs is None: - return [ - make_tensor_value_info( - 'X', TensorProto.FLOAT, None) # pylint: disable=E1101 - for name in self.input_names] - - if not isinstance(inputs, list): - if is_numpy_dtype(inputs): - inputs = [inputs] - - if input_names is None: - # outputs - input_names = [] - for inp in inputs: - if isinstance(inp, Variable): - if inp.name in self.output_names_rev: - input_names.append(inp) - elif isinstance(inp, tuple) and len(inp) == 2: - var, dtype = inp - if var.name in self.output_names_rev: - input_names.append(Variable(var.name, dtype)) - else: - raise TypeError( - "Unexpected type %r in %r." % (inp, inputs)) - if len(input_names) == 0: - raise RuntimeError( - "Unable to cross %r and %r." % (input, self.output_names_rev)) - elif not isinstance(input_names, list): - raise RuntimeError( - "Unexpected type for input_names %r." % type(input_names)) - - if len(input_names) != len(inputs): - raise RuntimeError( # pragma: no cover - "Mismatch between %r and %r." % ( - input_names, inputs)) - - res = [] - for inp, var in zip(inputs, input_names): - if isinstance(inp, (str, tuple)): - raise TypeError( - "inp not Variable but %r (%r)." % (type(inp), inp)) - if isinstance(var, (str, tuple)): - raise TypeError( - "var not Variable but %r (%r)." % (type(var), var)) - if isinstance(var, (str, tuple)): - raise TypeError( - "var not Variable but %r (%r)." % (type(var), var)) - # inp: Variable - # var: str - if inp != var: - raise RuntimeError( - "Unexpected %r != %r." % (inp, var)) - res.append(make_tensor_value_info( - inp.name, inp.proto_added_type, None)) - - return res - - def to_onnx(self, inputs=None, outputs=None, - target_opset=None, verbose=0): - """ - Converts this operator into an ONNX graph. - - :param inputs: specific inputs (as a dictionary) or - default inputs if not specified - :param outputs: specific outputs - :param target_opset: dictionary with target opset per domain, - None for the default one - :param verbose: prints information - :return: onnx graph - """ - # inputs and outputs - self.input = self._process_io(inputs, list(self.input_names.values())) - self.output = self._process_io(outputs, None) - - graph = make_graph( - self.node, 'XOP', self.input, self.output, self.initializer) - onnx_model = make_model(graph) - opv = self.opsets.get('', max_supported_opset()) - opset2ir = _default_OPSET_TO_IR_VERSION() - irv = opset2ir.get(opv, max(opset2ir.values())) - onnx_model.ir_version = irv - - del onnx_model.opset_import[:] # pylint: disable=E1101 - for k, v in self.opsets.items(): - op_set = onnx_model.opset_import.add() # pylint: disable=E1101 - op_set.domain = k or '' - op_set.version = v - return onnx_model diff --git a/mlprodict/npy/xop_ops.py b/mlprodict/npy/xop_ops.py new file mode 100644 index 000000000..47be11dbf --- /dev/null +++ b/mlprodict/npy/xop_ops.py @@ -0,0 +1,869 @@ +# pylint: disable=E1101 +""" +@file +@brief Easier API to build onnx graphs. Inspired from :epkg:`skl2onnx`. + +.. versionadded:: 0.9 +""" +import numpy +from scipy.sparse import coo_matrix +from onnx import GraphProto, TensorProto +from onnx.helper import ( + make_node, make_graph, make_model, + make_tensor_value_info) +from onnx.numpy_helper import from_array +from .xop_variable import ( + Variable, is_numpy_dtype, numpy_type_prototype, max_supported_opset) + + +class OnnxOperatorItem: + """ + Accessor to one of the output returned by a @see cl OnnxOperator. + + :param onx_op: @see cl OnnxOperator + :param index: integer + :param op_version: defines the opset version + """ + + def __init__(self, onx_op, index, op_version=None): + if not isinstance(index, int): + raise TypeError("index must be an integer not %r." % type(index)) + self.onx_op = onx_op + self.index = index + self.op_version = op_version + + def __str__(self): + """ + usual + """ + return "%s[%d]" % (str(self.onx_op), self.index) + + def get_output_name(self, i=0): + """ + Returns the output name at position *i*. + """ + if i != 0: + raise IndexError("Can only return the first item.") + return self.onx_op.get_output_name(self.index) + + def get_output(self, i=0): + """ + Returns the output. + """ + if i != 0: + raise IndexError("Can only return the first item.") + return self.onx_op.get_output(self.index) + + +class OnnxOperator: + """ + Ancestor to every *ONNX* operator exposed in + :mod:`mlprodict.npy.xops` and :mod:`mlprodict.npy.xops_ml`. + + :param inputs: list of inputs expected by the operator + :param op_version: to select a specific version of the operator + :param output_names: used defined names for the outputs + :param domain: to overwrite the default domain + :param global_context: operator *If* executes one subgraph + whose nodes may use one existing output in the current + context. If not used in the main graph, these operators + are not linked to the output and cannot be retrieved. + *global_context* is a dictionary mapped the subgraph input + names to these operators. + :param kwargs: additional parameters of the operator + + .. versionadd:: 0.9 + """ + + def __init__(self, *inputs, op_version=None, output_names=None, + domain=None, global_context=None, **kwargs): + + if (output_names is None and + self.__class__.__name__.startswith("OnnxScan")): + raise NotImplementedError( + "The class cannot infer the number of variables " + "for node '{}' yet. output_names must be specified" + ".".format(self.__class__.__name__)) + if isinstance(output_names, (str, Variable)): + output_names = [output_names] + if isinstance(output_names[0], str): + output_names[0] = Variable(output_names[0]) + elif isinstance(output_names, list): + if len(output_names) == 0: + raise ValueError( + "output_names cannot be empty (operator %r)." + "" % self.__class__.__name__) + output_names = output_names.copy() + for i in range(len(output_names)): # pylint: disable=C0200 + if isinstance(output_names[i], str): + output_names[i] = Variable(output_names[i]) + elif output_names is not None: + raise TypeError( + "output_names must be a string or a list not %r." + "" % type(output_names)) + + if op_version is None: + if domain == '': + self.op_version = max_supported_opset() + else: + self.op_version = None + else: + self.op_version = op_version + self.since_version = self.__class__.since_version + + if (self.op_version is not None and + self.op_version < self.since_version): + schema = self.find_schema(self.op_version) + self.since_version = schema.since_version + self.expected_inputs = schema.expected_inputs.copy() + self.expected_outputs = schema.expected_outputs.copy() + self.input_range = schema.input_range + self.output_range = schema.output_range + else: + self.expected_inputs = ( + None if self.__class__.expected_inputs is None + else self.__class__.expected_inputs.copy()) + self.expected_outputs = ( + None if self.__class__.expected_outputs is None + else self.__class__.expected_outputs.copy()) + self.input_range = self.__class__.input_range + self.output_range = self.__class__.output_range + if self.__class__.__name__ not in { + 'OnnxScan', 'OnnxLoop', 'OnnxIf'}: + # The minimum opset depends on embedded graph + # by default, it takes the given op_version but the + # optimal value could be lower. + self.op_version = self.since_version + if self.op_version is None: + self.op_version = self.since_version + + if (self.op_version is not None and + self.op_version < self.since_version): + raise RuntimeError( + "Operator '{}': requested version {} < " + "{} schema version.".format( + self.__class__.__name__, + self.op_version, self.since_version)) + + self.state = None + self.domain = domain + self.kwargs = kwargs + self.onnx_prefix_name = None + + # check inputs + if len(inputs) == 0: + if self.input_range[0] == self.input_range[1]: + self.inputs = [OnnxOperator.UnscopedVariable(_[0]) + for _ in self.expected_inputs] + else: + # The number of inputs may vary. + self.inputs = None + else: + self.inputs = [] + for inp in inputs: + if isinstance(inp, str): + self.inputs.append(Variable(inp)) + elif isinstance(inp, (OnnxOperator, Variable, + OnnxOperatorItem)): + self.inputs.append(inp) + elif isinstance(inp, (numpy.ndarray, coo_matrix, TensorProto)): + self.inputs.append(inp) + else: + raise TypeError( + "Unable to interpret the input name for type {} in " + "operator '{}' (value={}).".format( + type(inp), self.__class__.__name__, inp)) + + if self.inputs is not None: + if (len(self.inputs) < self.input_range[0] or + len(self.inputs) > self.input_range[1]): + raise RuntimeError( + "Operator '{}' expects a number of inputs " + "in [{}, {}] not {} (expected opset={}, " + "class opset={})".format( + self.operator_name, *self.input_range, + len(self.inputs), op_version, self.op_version)) + # global context + if global_context is None: + self.global_context = None + else: + if not isinstance(global_context, dict): + raise TypeError( + "global_context must be a dictionary not %r." + "" % type(global_context)) + for k, v in global_context.items(): + if not isinstance(v, (OnnxOperator, OnnxOperatorItem)): + raise TypeError( + "Value %r in must be an OnnxOperator or an " + "OnnxOperatorItem not %r." % (k, type(v))) + self.global_context = global_context + + # check output + self.output_names = output_names + self.output_variables = None + + if self.output_names is not None: + if len(self.output_names) == 0: + raise ValueError( + "output_names can be None but cannot be empty for " + "operator %r." % self) + if self.output_variables is None: + self.output_variables = [None for o in self.output_names] + for i in range(len(self.output_names)): # pylint: disable=C0200 + name = self.output_names[i] + if isinstance(name, Variable): + self.output_variables[i] = name + else: + raise TypeError("output_names must be a list of strings " + "and element %r is %r (%r)" % ( + i, type(name), name)) + if all(map(lambda x: x is None, self.output_variables)): + self.output_variables = None + + if (self.output_names is not None and ( + self.expected_outputs is None or + len(self.output_names) > len(self.expected_outputs))): + if self.expected_outputs is None: + self.expected_outputs = [] + for i in range(len(self.expected_outputs), + len(self.output_names)): + self.expected_outputs.append((self.output_names[i], None)) + + if (self.expected_inputs is None or + len(self.inputs) > len(self.expected_inputs)): + if self.expected_inputs is None: + self.expected_inputs = [] + for i in range(len(self.expected_inputs), + len(self.inputs)): + inp = self.inputs[i] + if isinstance(inp, str): + inp = (inp, None) + elif hasattr(inp, 'add_to'): + # OnnxOperator + existing = set(_[0] for _ in self.expected_inputs) + i = 10 + name = "input%d" % (10 + i) + while name in existing: + i += 1 + name = "input%d" % (10 + i) + inp = (name, None) + self.expected_inputs.append(inp) + + self._post_process_attributes() + self._check() + + def _check(self): + input_types = (Variable, OnnxOperator, numpy.ndarray) + for o in self.inputs: + if not isinstance(o, input_types): + raise TypeError( + "Wrong type for inputs %r." % ( + self.inputs, )) + if self.output_names is not None: + for o in self.output_names: + if not isinstance(o, Variable): + raise TypeError( + "Wrong type for output_names %r." % ( + self.output_names, )) + + def _post_process_attributes(self): + """ + Walks through attributes and replaces them by ONNX values. + """ + # Looks into attributes if there is any tuple + # (GraphProto, OnnxOperator). In that case, the function + # replaces the tuple by the graph proto and keeps + # in attributes graph_algebra the OnnxOperator + # which is the source of it. + updates = {} + graph_algebra = {} + for k, v in self.kwargs.items(): + if isinstance(v, tuple) and isinstance(v[0], GraphProto): + updates[k] = v[0] + graph_algebra[k] = v[1] + + if len(graph_algebra) > 0: + self.kwargs.update(updates) + self.graph_algebra = graph_algebra + + if self.__class__.__name__ == "OnnxConstantOfShape": + if "value" in self.kwargs: + value = self.kwargs['value'] + if isinstance(value, TensorProto): + return + if isinstance(value, numpy.ndarray): + if value.shape == (1, ): + val = value[0] + elif len(value.shape) == 0: + val = value + else: + raise RuntimeError( + "Unexpected shape %r for value, it must be " + "an array of one element." % value.shape) + self.kwargs['value'] = from_array( + numpy.array([val], dtype=value.dtype)) + return + raise TypeError( + "Unexpected type %r for value. It should be an array " + "of one element." % type(value)) + return + + if self.__class__.__name__ == "OnnxCast": + if "to" in self.kwargs: + value = self.kwargs['to'] + if not isinstance(value, int): + try: + to = numpy_type_prototype(value) + except ValueError as e: + raise ValueError( + "Unable to convert argument to in operator cast, " + "type is %r, value is %r." % (type(value), value)) from e + self.kwargs['to'] = to + return + + def find_schema(self, op_version): + """ + Checks if there is an existing schema for a + specific version. + + :param op_version: requested version + :return: schema + """ + if not hasattr(self.__class__, 'past_version'): + raise RuntimeError("Missing attribute 'past_version', there is " + "no other available schema.") + found = None + for v in self.past_version.values(): + if v.since_version > op_version: + continue + if found is None or v.since_version > found.since_version: + found = v + if found is None: + raise RuntimeError( + "Operator '{}': requested version {} < " + "{} schema version.".format( + self.__class__.__name__, + op_version, self.since_version)) + return found + + def __str__(self): + """ + usual + """ + return "{}({} in) -> {}".format( + self.__class__.__name__, + len(self.inputs) if self.inputs is not None else 0, + [str(o) for o in self.output_names] + if self.output_names is not None else "?") + + def set_onnx_name_prefix(self, onnx_prefix_name): + """ + Provides a name to define a prefix in the onnx graph + to avoid to get unreadable node names. The method + does not overwrite an existing name, it propagates + the prefix to inputs and stops the propagation + if the prefix is already defined. + """ + if self.onnx_prefix_name is None: + self.onnx_prefix_name = onnx_prefix_name + for inp in self.inputs: + if hasattr(inp, 'set_onnx_prefix_name'): + inp.set_onnx_name_prefix(onnx_prefix_name) + return self + + @property + def onnx_prefix(self): + "Returns a prefix for results coming out from this node." + if self.onnx_prefix_name is None: + name = self.__class__.__name__ + if name.startswith("Onnx"): + name = name[4:] + return 'out_' + name[:3].lower() + return self.onnx_prefix_name + + def __getitem__(self, index): + """ + Returns an accessor to one of the output + of this node. + """ + return OnnxOperatorItem(self, index, self.op_version) + + def _node_to_graph(self, other_outputs=None, inputs=None, outputs=None): + """ + Builds a graph as a list of nodes to walk through in that order. + """ + def _preprocess_list(inputs): + new_inputs = {} + for el in inputs: + if isinstance(el, str): + new_inputs[el] = Variable(el) + elif isinstance(el, Variable): + new_inputs[el.name] = el + else: + raise TypeError( + "Unable to handle input type %r (%r)." % ( + type(el), el)) + return new_inputs + + def _process_input(inputs, set_inputs, inp, new_inputs): + if isinstance(inp, OnnxOperator): + new_stack.append(inp) + elif isinstance(inp, Variable): + if inp.name in set_inputs: + return + set_inputs.add(inp.name) + if inputs is None: + new_inputs.append(inp) + elif isinstance(inputs, dict): + if inp.name in inputs: + new_inputs.append(inp.copy_merge(inputs[inp.name])) + else: + raise ValueError( # pragma: no cover + "Unable to find input %r in %r." % ( + inp, inputs)) + elif is_numpy_dtype(inputs): + new_inputs.append(inp.copy_add(inputs)) + elif isinstance(inputs, Variable): + if inp.name == inputs.name: + new_inputs.append(inp.copy_merge(inputs)) + else: + new_inputs.append(inp) + else: + raise RuntimeError( # pragma: no cover + "Unable to handle inputs=%r." % inputs) + elif isinstance(inp, numpy.ndarray): + pass + else: + raise TypeError( + "Unexpected input type %r in node type %r." % ( + type(inp), type(obj))) + + node_outputs = [self] + if other_outputs is not None: + node_outputs += other_outputs + + # preprocess inputs, outputs + _keep_inputs = None + if isinstance(inputs, list): + _keep_inputs = inputs + inputs = _preprocess_list(inputs) + _keep_outputs = None + if isinstance(outputs, list): + _keep_outputs = outputs + outputs = _preprocess_list(outputs) + + # walk through graphs + stack = list(node_outputs) + new_inputs = [] + set_inputs = set() + memo = [] + while len(stack) > 0: + memo.extend(stack) + new_stack = [] + for obj in stack: + for inp in obj.inputs: + _process_input(inputs, set_inputs, inp, new_inputs) + stack = new_stack + + if len(new_inputs) == 0: + raise RuntimeError( + "No detected inputs inputs=%r outputs=%r." % ( + inputs, outputs)) + + # eliminate duplicates + done = set() + nodes = [] + for node in reversed(memo): + if id(node) in done: + continue + done.add(id(node)) + nodes.append(node) + + def _get_type(node, name=None, outputs=None): + if outputs is None: + raise NotImplementedError( + "outputs is None, expected_outputs=%r" % ( + node.expected_outputs, )) + if isinstance(outputs, Variable): + if name is None: + return outputs.dtype + if isinstance(name, Variable): + return outputs.dtype or name.dtype + else: + raise RuntimeError( # pragma: no cover + "Unable to handle outputs=%r." % outputs) + if isinstance(outputs, dict): + if name is None: + raise RuntimeError( + "Unable to get type among %r, name=None." % ( + outputs, )) + if isinstance(name, Variable): + n = name.name + else: + n = name + if n not in outputs: + raise ValueError( # pragma: no cover + "Unable to find %r in %r." % ( + name, outputs)) + return outputs[n] + if isinstance(outputs, list): + raise NotImplementedError( + "Unexpected type for name=%r, outputs=%r." % ( + name, outputs)) + if is_numpy_dtype(outputs): + return outputs + raise RuntimeError( # pragma: no cover + "Unable to handle outputs=%r." % outputs) + + # outputs + new_outputs = [] + for node in node_outputs: + if node.output_names is None: + n = self.output_range[0] + for i in range(n): + to = _get_type(node, outputs=outputs) + res = ('out%d' % i, to) + new_outputs.append(Variable(res[0], added_dtype=to)) + else: + for o in self.output_names: + to = _get_type(node, o, outputs=outputs) + res = (o, to) + new_outputs.append(o.copy_merge(to)) + if len(new_outputs) == 0: + raise RuntimeError( + "No detected outputs inputs=%r outputs=%r." % ( + inputs, outputs)) + + return nodes, new_inputs, new_outputs + + def add_to(self, builder): + """ + Adds to graph builder. + + :param builder: instance of @see cl _GraphBuilder, + it must have a method `add_node` + """ + inputs = builder.get_input_names(self, self.inputs) + n_outputs = ( + self.output_range[0] if self.output_names is None + else len(self.output_names)) + outputs = [builder.get_output_name(self, i) for i in range(n_outputs)] + builder.add_node( + self.operator_name, + builder.get_unique_name('_' + self.operator_name.lower()), + inputs, outputs, domain=self.domain, opset=self.op_version, + **self.kwargs) + + def to_onnx(self, inputs=None, outputs=None, + other_outputs=None, target_opset=None, + verbose=0): + """ + Converts this operator into an ONNX graph. + + :param inputs: information about type + :param outputs: information about types + :param other_outputs: additional nodes to consider + as graph outputs but not outputs of this particular + node + :param target_opset: dictionary with target opset per domain, + None for the default one + :param verbose: prints information + """ + # opsets + if isinstance(target_opset, dict): + dom = self.domain or '' + target_opset = target_opset.get(dom, None) + elif isinstance(target_opset, int): + if self.domain not in ('', None): + # The target_opset is for the domain '' we ignore it. + target_opset = None + elif target_opset is not None: + raise TypeError( + "target_opset must be a dictionary {domain: " + "target_opset} not %r for operator %r." % ( + target_opset, self.__class__.__name__)) + + if self.domain in ('', None) and target_opset == 1: + raise RuntimeError("target_opset cannot be 1.") + if (self.op_version is not None and target_opset is not None and + self.op_version > target_opset): + raise RuntimeError( + "target_opset={} is lower than the version={} requested " + "for this node '{}'.".format( + target_opset, self.op_version, self.__class__.__name__)) + + # get the graph + nodes, graph_inputs, graph_outputs = self._node_to_graph( + other_outputs, inputs, outputs) + if len(nodes) == 0: + raise RuntimeError( # pragma: no cover + "Node list is empty.") + if verbose > 1: + for i, n in enumerate(nodes): + print("nodes[%d]=%r" % (i, n)) + for i, n in enumerate(graph_inputs): + print("graph_inputs[%d]=%r" % (i, n)) + builder = _GraphBuilder() + for node in nodes: + node.add_to(builder) + + return builder.to_onnx(inputs=graph_inputs, + outputs=graph_outputs, + target_opset=target_opset, + verbose=verbose) + + +def _default_OPSET_TO_IR_VERSION(): + """ + Returns the default mapping between opset and ir_version. + + .. runpython:: + :showcode: + + import pprint + from mlprodict.npy.xop_graph_builder import _default_OPSET_TO_IR_VERSION + pprint.pprint(_default_OPSET_TO_IR_VERSION()) + """ + return { + 1: 3, 2: 3, 3: 3, 4: 3, 5: 3, 6: 3, + 7: 3, 8: 4, 9: 4, 10: 5, 11: 6, 12: 7, + 13: 7, 14: 7, 15: 8 + } + + +class _GraphBuilder: + """ + Graph builder. + """ + + def __init__(self): + self.initializer = [] + self.node = [] + self.input = [] + self.output = [] + self.opsets = {} + self.names = set() + self.input_names = {} + self.output_names = {} + self.output_names_rev = {} + + @staticmethod + def number2alpha(index): + """ + Converts a numbers into a string keeping the same + alphabetical order. + """ + dec = str(int(index)) + if len(dec) == 1: + return dec + return chr(96 + len(dec)) + dec + + def get_unique_name(self, name): + """ + Returns a unique name to name an output. + """ + if not isinstance(name, str): + raise TypeError( # pragma: no cover + "name must be a string not %r." % type(name)) + if name not in self.names: + self.names.add(name) + return name + i = 1 + new_name = "%s_%s" % (name, self.number2alpha(i)) + while new_name in self.names: + i += 1 + new_name = "%s_%s" % (name, self.number2alpha(i)) + self.names.add(new_name) + return new_name + + def get_output_name(self, node, index): + """ + Returns the output name for a node. + """ + key = id(node), index + if key in self.output_names: + name = self.output_names[key] + return name + + if node.output_names is None: + prefix = node.onnx_prefix + n = '%s%d' % (prefix, index) + else: + output = node.output_names[index] + if isinstance(output, Variable): + n = output.name + else: + raise TypeError( # pragma: no cover + "Unexpected type %r for output %d (output_names=%r)." % ( + type(output), index, node.output_names)) + + name = self.get_unique_name(n) + self.output_names[key] = name + self.output_names_rev[name] = key + if node.output_names is not None: + var = node.output_names[index] + if isinstance(var, Variable): + var = var.name + if var != name: + raise RuntimeError( + "Output unique name %r is different from the " + "expected name %r at position %r." % ( + name, node.output_names[index], index)) + return name + + def get_input_names(self, node, inputs): + """ + Returns input names for node *node* and inputs *inputs*. + + :param node: node + :param inputs: inputs + :return: name + """ + names = [] + for i in inputs: + if isinstance(i, Variable): + names.append(i.name) + self.names.add(i.name) + self.input_names[i.name] = i + elif isinstance(i, OnnxOperator): + name = self.get_output_name(i, 0) + names.append(name) + self.names.add(name) + elif isinstance(i, OnnxOperatorItem): + name = self.get_output_name(i.onnx_op, i.index) + names.append(name) + self.names.add(name) + elif isinstance(i, numpy.ndarray): + # Adding an initializer + name = self.get_unique_name('init') + init = from_array(i, name) + self.initializer.append(init) + names.append(name) + self.names.add(name) + else: + raise TypeError( # pragma: no cover + "Unexpected type for an input %r." % type(i)) + return names + + def add_node(self, op_type, name, inputs, outputs, domain='', + opset=None, **attributes): + """ + Adds a node to the graph. + + :param op_type: operator type + :param name: node name + :param inputs: inputs name list + :param outputs: outputs name list + :param domain: node domain + :param opset: node opset + """ + if not isinstance(inputs, list): + raise TypeError( # pragma: no cover + "inputs must be a list not %r." % type(inputs)) + if not isinstance(outputs, list): + raise TypeError( # pragma: no cover + "inputs must be a list not %r." % type(outputs)) + if any(map(lambda x: not isinstance(x, str), inputs)): + raise TypeError( # pragma: no cover + "inputs must be all strings not %r." % inputs) + if any(map(lambda x: not isinstance(x, (str, Variable)), outputs)): + raise TypeError( # pragma: no cover + "outputs must be all strings not %r." % outputs) + if opset is not None: + if domain not in self.opsets: + self.opsets[domain] = opset + else: + self.opsets[domain] = max(opset, self.opsets[domain]) + node = make_node(op_type, inputs, outputs, name=name, + domain=domain, **attributes) + self.node.append(node) + + def _process_io(self, inputs, input_names): + if inputs is None: + return [ + make_tensor_value_info( + 'X', TensorProto.FLOAT, None) # pylint: disable=E1101 + for name in self.input_names] + + if not isinstance(inputs, list): + if is_numpy_dtype(inputs): + inputs = [inputs] + + if input_names is None: + # outputs + input_names = [] + for inp in inputs: + if isinstance(inp, Variable): + if inp.name in self.output_names_rev: + input_names.append(inp) + elif isinstance(inp, tuple) and len(inp) == 2: + var, dtype = inp + if var.name in self.output_names_rev: + input_names.append(Variable(var.name, dtype)) + else: + raise TypeError( + "Unexpected type %r in %r." % (inp, inputs)) + if len(input_names) == 0: + raise RuntimeError( + "Unable to cross %r and %r." % (input, self.output_names_rev)) + elif not isinstance(input_names, list): + raise RuntimeError( + "Unexpected type for input_names %r." % type(input_names)) + + if len(input_names) != len(inputs): + raise RuntimeError( # pragma: no cover + "Mismatch between %r and %r." % ( + input_names, inputs)) + + res = [] + for inp, var in zip(inputs, input_names): + if isinstance(inp, (str, tuple)): + raise TypeError( + "inp not Variable but %r (%r)." % (type(inp), inp)) + if isinstance(var, (str, tuple)): + raise TypeError( + "var not Variable but %r (%r)." % (type(var), var)) + if isinstance(var, (str, tuple)): + raise TypeError( + "var not Variable but %r (%r)." % (type(var), var)) + # inp: Variable + # var: str + if inp != var: + raise RuntimeError( + "Unexpected %r != %r." % (inp, var)) + res.append(make_tensor_value_info( + inp.name, inp.proto_added_type, None)) + + return res + + def to_onnx(self, inputs=None, outputs=None, + target_opset=None, verbose=0): + """ + Converts this operator into an ONNX graph. + + :param inputs: specific inputs (as a dictionary) or + default inputs if not specified + :param outputs: specific outputs + :param target_opset: dictionary with target opset per domain, + None for the default one + :param verbose: prints information + :return: onnx graph + """ + # inputs and outputs + self.input = self._process_io(inputs, list(self.input_names.values())) + self.output = self._process_io(outputs, None) + + graph = make_graph( + self.node, 'XOP', self.input, self.output, self.initializer) + onnx_model = make_model(graph) + opv = self.opsets.get('', max_supported_opset()) + opset2ir = _default_OPSET_TO_IR_VERSION() + irv = opset2ir.get(opv, max(opset2ir.values())) + onnx_model.ir_version = irv + + del onnx_model.opset_import[:] # pylint: disable=E1101 + for k, v in self.opsets.items(): + op_set = onnx_model.opset_import.add() # pylint: disable=E1101 + op_set.domain = k or '' + op_set.version = v + return onnx_model diff --git a/mlprodict/npy/xops_opset.py b/mlprodict/npy/xop_opset.py similarity index 100% rename from mlprodict/npy/xops_opset.py rename to mlprodict/npy/xop_opset.py diff --git a/mlprodict/npy/xop_variable.py b/mlprodict/npy/xop_variable.py index 8e3aea7b3..27c7f0c10 100644 --- a/mlprodict/npy/xop_variable.py +++ b/mlprodict/npy/xop_variable.py @@ -29,7 +29,7 @@ def is_numpy_dtype(dtype): :param dtype: anything :return: boolean """ - if isinstance(dtype, (list, dict)): + if isinstance(dtype, (list, dict, Variable)): return False if dtype in NP_TYPE_TO_TENSOR_TYPE: return True @@ -57,10 +57,24 @@ def numpy_type_prototype(dtype): class Variable: """ - An input to an ONNX graph. + An input or output to an ONNX graph. + + :param name: name + :param dtype: :epkg:`numpy` dtype (can be None) + :param shape: shape (can be None) + :param added_dtype: :epkg:`numpy` dtype specified at conversion type + (can be None) """ def __init__(self, name, dtype=None, shape=None, added_dtype=None): + if dtype is not None: + if isinstance(dtype, (int, Variable, tuple)): + raise TypeError( + "Unexpected type %r for dtype." % type(dtype)) + if added_dtype is not None: + if isinstance(added_dtype, (int, Variable, tuple)): + raise TypeError( + "Unexpected type %r for added_dtype." % type(added_dtype)) self.name_ = name self.dtype_ = dtype self.added_dtype_ = added_dtype @@ -68,9 +82,14 @@ def __init__(self, name, dtype=None, shape=None, added_dtype=None): @property def name(self): - "Returns the variable name." + "Returns the variable name (`self.name_`)." return self.name_ + @property + def dtype(self): + "Returns `self.dtype_`." + return self.dtype_ + @property def proto_type(self): "Returns the proto type for `self.dtype_`." @@ -117,6 +136,20 @@ def copy_add(self, dtype): "Cannot copy as added_dtype is not None.") return Variable(self.name_, self.dtype_, self.shape_, dtype) + def copy_merge(self, var): + """ + Merges information from both Variable. + """ + if not isinstance(var, Variable): + return self.copy_add(var) + res = Variable(self.name_, self.dtype_, + self.shape_, self.added_dtype_) + if self.added_dtype_ is None and var.dtype_ is not None: + res.added_dtype_ = var.dtype_ + if self.shape_ is None and var.shape_ is not None: + res.shape_ = var.shape_ + return res + def __eq__(self, other): """ Compares every attributes. From c5237f68d285a1db7336b192d3a13f1d8f55d86d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Thu, 17 Feb 2022 01:38:29 +0100 Subject: [PATCH 10/13] support for if --- _unittests/ut_npy/test_xop.py | 66 +++++++++++++++++++++++++++++++++++ mlprodict/npy/xop_ops.py | 23 ++++++------ 2 files changed, 78 insertions(+), 11 deletions(-) diff --git a/_unittests/ut_npy/test_xop.py b/_unittests/ut_npy/test_xop.py index 51d437d40..29da99fbe 100644 --- a/_unittests/ut_npy/test_xop.py +++ b/_unittests/ut_npy/test_xop.py @@ -147,6 +147,72 @@ def test_onnx_var_list(self): got = oinf.run({'X': x}) self.assertEqualArray(x.astype(numpy.int64), got['Y']) + def test_if(self): + OnnxConstant, OnnxIf, OnnxGreater = loadop( + "OnnxConstant", "OnnxIf", "OnnxGreater") + bthen = OnnxConstant( + value_floats=numpy.array([0], dtype=numpy.float32), + output_names=['res_then']) + bthen.set_onnx_name_prefix('then') + + belse = OnnxConstant( + value_floats=numpy.array([1], dtype=numpy.float32), + output_names=['res_else']) + belse.set_onnx_name_prefix('else') + + bthen_body = bthen.to_onnx( + [], [Variable('res_then', numpy.float32)]) + belse_body = belse.to_onnx( + [], [Variable('res_else', numpy.float32)]) + + onx = OnnxIf( + OnnxGreater('X', numpy.array([0], dtype=numpy.float32)), + output_names=['Z'], + then_branch=bthen_body.graph, + else_branch=belse_body.graph) + + x = numpy.array([1, 2], dtype=numpy.float32) + model_def = onx.to_onnx({'X': numpy.float32}, {'Z': numpy.float32}) + got = OnnxInference(model_def).run({'X': x}) + self.assertEqualArray(numpy.array([0.], dtype=numpy.float32), + got['Z']) + + x = numpy.array([-1, -2], dtype=numpy.float32) + y = numpy.array([-1, -3], dtype=numpy.float32) + model_def = onx.to_onnx({'X': numpy.float32}, {'Z': numpy.float32}) + got = OnnxInference(model_def).run({'X': x}) + self.assertEqualArray( + numpy.array([1.], dtype=numpy.float32), got['Z']) + + def test_if2(self): + OnnxAdd, OnnxSub, OnnxIf, OnnxGreater, OnnxReduceSum = loadop( + "OnnxAdd", "OnnxSub", "OnnxIf", "OnnxGreater", "OnnxReduceSum") + + node = OnnxAdd('x1', 'x2', output_names=['absxythen']) + then_body = node.to_onnx( + [Variable('x1', numpy.float32), + Variable('x2', numpy.float32)], + {'absxythen': numpy.float32}) + node = OnnxSub('x1', 'x2', output_names=['absxyelse']) + else_body = node.to_onnx( + [Variable('x1', numpy.float32), + Variable('x2', numpy.float32)], + {'absxyelse': numpy.float32}) + del else_body.graph.input[:] + del then_body.graph.input[:] + + cond = OnnxGreater(OnnxReduceSum('x1'), OnnxReduceSum('x2')) + ifnode = OnnxIf(cond, then_branch=then_body.graph, + else_branch=else_body.graph, + output_names=['y']) + model_def = ifnode.to_onnx( + [Variable('x1', numpy.float32), + Variable('x2', numpy.float32)], + {'y': numpy.float32}) + oinf = OnnxInference(model_def) + dot = oinf.to_dot() + self.assertIn("out_red0 -> _greater;", dot) + if __name__ == "__main__": unittest.main() diff --git a/mlprodict/npy/xop_ops.py b/mlprodict/npy/xop_ops.py index 47be11dbf..290e1fccb 100644 --- a/mlprodict/npy/xop_ops.py +++ b/mlprodict/npy/xop_ops.py @@ -465,11 +465,6 @@ def _process_input(inputs, set_inputs, inp, new_inputs): _process_input(inputs, set_inputs, inp, new_inputs) stack = new_stack - if len(new_inputs) == 0: - raise RuntimeError( - "No detected inputs inputs=%r outputs=%r." % ( - inputs, outputs)) - # eliminate duplicates done = set() nodes = [] @@ -815,15 +810,21 @@ def _process_io(self, inputs, input_names): "Mismatch between %r and %r." % ( input_names, inputs)) + if isinstance(input_names, list): + d_input_names = {inp.name: inp for inp in input_names} + elif isinstance(input_names, dict): + d_input_names = input_names + else: + raise TypeError( + "Unexpected type for input_names %r (%r)." % ( + type(input_names), input_names)) res = [] - for inp, var in zip(inputs, input_names): - if isinstance(inp, (str, tuple)): + for inp in inputs: + if not isinstance(inp, Variable): raise TypeError( "inp not Variable but %r (%r)." % (type(inp), inp)) - if isinstance(var, (str, tuple)): - raise TypeError( - "var not Variable but %r (%r)." % (type(var), var)) - if isinstance(var, (str, tuple)): + var = d_input_names[inp.name] + if not isinstance(var, Variable): raise TypeError( "var not Variable but %r (%r)." % (type(var), var)) # inp: Variable From 34cb616f2b82c5f430c14a1f424a6020d0b63fb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Thu, 17 Feb 2022 13:00:31 +0100 Subject: [PATCH 11/13] fix shapes --- _unittests/ut_npy/test_xop.py | 54 +++++++++++++++++++++++++- mlprodict/npy/xop_ops.py | 2 +- mlprodict/npy/xop_variable.py | 54 +++++++++++++++++++------- mlprodict/onnx_tools/onnx2py_helper.py | 39 +++++++++++++++++-- 4 files changed, 129 insertions(+), 20 deletions(-) diff --git a/_unittests/ut_npy/test_xop.py b/_unittests/ut_npy/test_xop.py index 29da99fbe..b58aa96ed 100644 --- a/_unittests/ut_npy/test_xop.py +++ b/_unittests/ut_npy/test_xop.py @@ -4,11 +4,14 @@ """ import unittest import numpy +from onnx import TensorProto from pyquickhelper.pycode import ExtTestCase from mlprodict.npy.xop import loadop from mlprodict.npy.xop_variable import Variable from mlprodict.npy.xop_ops import _GraphBuilder from mlprodict.onnxrt import OnnxInference +from mlprodict.plotting.text_plot import onnx_simple_text_plot +from mlprodict.onnx_tools.onnx2py_helper import get_dtype_shape class TestXOps(ExtTestCase): @@ -178,7 +181,6 @@ def test_if(self): got['Z']) x = numpy.array([-1, -2], dtype=numpy.float32) - y = numpy.array([-1, -3], dtype=numpy.float32) model_def = onx.to_onnx({'X': numpy.float32}, {'Z': numpy.float32}) got = OnnxInference(model_def).run({'X': x}) self.assertEqualArray( @@ -213,6 +215,56 @@ def test_if2(self): dot = oinf.to_dot() self.assertIn("out_red0 -> _greater;", dot) + def test_onnx_abs_shape_variable(self): + OnnxAbs = loadop("OnnxAbs") + ov = OnnxAbs('X', output_names=['Y']) + onx = ov.to_onnx([Variable('X', numpy.float32, [1, 2])], + [Variable('Y', numpy.float32, [1, 2])], + verbose=0) + oinf = OnnxInference(onx) + x = numpy.array([[-2, 2]], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(numpy.abs(x), got['Y']) + self.assertIn("input: name='X'", onnx_simple_text_plot(onx)) + dtype, shape = get_dtype_shape(onx.graph.input[0]) + self.assertEqual(dtype, TensorProto.FLOAT) + self.assertEqual(shape, (1, 2)) + dtype, shape = get_dtype_shape(onx.graph.output[0]) + self.assertEqual(dtype, TensorProto.FLOAT) + self.assertEqual(shape, (1, 2)) + + def test_onnx_abs_shape_variable_batch(self): + OnnxAbs = loadop("OnnxAbs") + ov = OnnxAbs('X', output_names=['Y']) + onx = ov.to_onnx([Variable('X', numpy.float32, [None, 2])], + [Variable('Y', numpy.float32, [None, 2])], + verbose=0) + oinf = OnnxInference(onx) + x = numpy.array([[-2, 2]], dtype=numpy.float32) + got = oinf.run({'X': x}) + self.assertEqualArray(numpy.abs(x), got['Y']) + dtype, shape = get_dtype_shape(onx.graph.input[0]) + self.assertEqual(dtype, TensorProto.FLOAT) + self.assertEqual(shape, (None, 2)) + dtype, shape = get_dtype_shape(onx.graph.output[0]) + self.assertEqual(dtype, TensorProto.FLOAT) + self.assertEqual(shape, (None, 2)) + + def test_onnx_abs_shape_numpy(self): + OnnxAbs = loadop("OnnxAbs") + ov = OnnxAbs('X', output_names=['Y']) + x = numpy.array([-2, 2], dtype=numpy.float32) + onx = ov.to_onnx({'X': x}, {'Y': x}, verbose=0) + oinf = OnnxInference(onx) + got = oinf.run({'X': x}) + self.assertEqualArray(numpy.abs(x), got['Y']) + dtype, shape = get_dtype_shape(onx.graph.input[0]) + self.assertEqual(dtype, TensorProto.FLOAT) + self.assertEqual(shape, (2, )) + dtype, shape = get_dtype_shape(onx.graph.output[0]) + self.assertEqual(dtype, TensorProto.FLOAT) + self.assertEqual(shape, (2, )) + if __name__ == "__main__": unittest.main() diff --git a/mlprodict/npy/xop_ops.py b/mlprodict/npy/xop_ops.py index 290e1fccb..c6b921049 100644 --- a/mlprodict/npy/xop_ops.py +++ b/mlprodict/npy/xop_ops.py @@ -833,7 +833,7 @@ def _process_io(self, inputs, input_names): raise RuntimeError( "Unexpected %r != %r." % (inp, var)) res.append(make_tensor_value_info( - inp.name, inp.proto_added_type, None)) + inp.name, inp.proto_added_type, inp.proto_added_shape)) return res diff --git a/mlprodict/npy/xop_variable.py b/mlprodict/npy/xop_variable.py index 27c7f0c10..f0e2845a5 100644 --- a/mlprodict/npy/xop_variable.py +++ b/mlprodict/npy/xop_variable.py @@ -64,21 +64,33 @@ class Variable: :param shape: shape (can be None) :param added_dtype: :epkg:`numpy` dtype specified at conversion type (can be None) + :param added_shape: :epkg:`numpy` shape specified at conversion type + (can be None) """ - def __init__(self, name, dtype=None, shape=None, added_dtype=None): - if dtype is not None: - if isinstance(dtype, (int, Variable, tuple)): - raise TypeError( - "Unexpected type %r for dtype." % type(dtype)) - if added_dtype is not None: - if isinstance(added_dtype, (int, Variable, tuple)): - raise TypeError( - "Unexpected type %r for added_dtype." % type(added_dtype)) + def __init__(self, name, dtype=None, shape=None, added_dtype=None, + added_shape=None): + if (dtype is not None and isinstance( + dtype, (int, Variable, tuple, numpy.ndarray))): + raise TypeError( + "Unexpected type %r for dtype." % type(dtype)) + if (added_dtype is not None and isinstance( + added_dtype, (int, Variable, tuple, numpy.ndarray))): + raise TypeError( + "Unexpected type %r for added_dtype." % type(added_dtype)) + if shape is not None and not isinstance(shape, (tuple, list)): + raise TypeError( + "Unexpected type %r for shape." % type(shape)) + if (added_shape is not None and not isinstance( + added_shape, (tuple, list))): + raise TypeError( + "Unexpected type %r for added_shape." % type(added_shape)) + self.name_ = name self.dtype_ = dtype self.added_dtype_ = added_dtype self.shape_ = shape + self.added_shape_ = added_shape @property def name(self): @@ -105,10 +117,19 @@ def proto_added_type(self): return 0 return numpy_type_prototype(dt) + @property + def proto_added_shape(self): + "Returns the shape for `self.added_shape_` or `self.shape`." + dt = self.added_shape_ or self.shape_ + if dt is None: + return None + return list(dt) + def __repr__(self): "usual" kwargs = dict(dtype=self.dtype_, shape=self.shape_, - added_dtype=self.added_dtype_) + added_dtype=self.added_dtype_, + added_shape=self.added_shape_) kwargs = {k: v for k, v in kwargs.items() if v is not None} if len(kwargs) > 0: msg = ", " + ", ".join("%s=%r" % (k, v) for k, v in kwargs.items()) @@ -134,7 +155,11 @@ def copy_add(self, dtype): if self.added_dtype_ is not None: raise RuntimeError( "Cannot copy as added_dtype is not None.") - return Variable(self.name_, self.dtype_, self.shape_, dtype) + if isinstance(dtype, numpy.ndarray): + dtype, shape = dtype.dtype, dtype.shape + else: + shape = None + return Variable(self.name_, self.dtype_, self.shape_, dtype, shape) def copy_merge(self, var): """ @@ -143,11 +168,12 @@ def copy_merge(self, var): if not isinstance(var, Variable): return self.copy_add(var) res = Variable(self.name_, self.dtype_, - self.shape_, self.added_dtype_) + self.shape_, self.added_dtype_, + self.added_shape_) if self.added_dtype_ is None and var.dtype_ is not None: res.added_dtype_ = var.dtype_ - if self.shape_ is None and var.shape_ is not None: - res.shape_ = var.shape_ + if self.added_shape_ is None and var.shape_ is not None: + res.added_shape_ = var.shape_ return res def __eq__(self, other): diff --git a/mlprodict/onnx_tools/onnx2py_helper.py b/mlprodict/onnx_tools/onnx2py_helper.py index 435840c75..4c1b91090 100644 --- a/mlprodict/onnx_tools/onnx2py_helper.py +++ b/mlprodict/onnx_tools/onnx2py_helper.py @@ -15,8 +15,8 @@ def to_bytes(val): """ Converts an array into protobuf and then into bytes. - @param val array - @return bytes + :param val: array + :return: bytes .. exref:: :title: Converts an array into bytes (serialization) @@ -75,8 +75,8 @@ def from_bytes(b): """ Retrieves an array from bytes then protobuf. - @param b bytes - @return array + :param b: bytes + :return: array .. exref:: :title: Converts bytes into an array (serialization) @@ -410,6 +410,37 @@ def _var_as_dict(var): "Unable to guess which object it is.\n{}\n---".format(var)) +def get_dtype_shape(obj): + """ + Returns the shape of a tensor. + + :param obj: onnx object + :return: `(dtype, shape)` or `(None, None)` if not applicable + """ + if not hasattr(obj, 'type'): + return None + t = obj.type + if not hasattr(t, 'tensor_type'): + return None + t = t.tensor_type + dtype = t.elem_type + if not hasattr(t, 'shape'): + return dtype, None + shape = t.shape + ds = [] + for dim in shape.dim: + d = dim.dim_value + s = dim.dim_param + if d == 0: + if s == '': + ds.append(None) + else: + ds.append(s) + else: + ds.append(d) + return dtype, tuple(ds) + + def onnx_model_opsets(onnx_model): """ Extracts opsets in a dictionary. From 4a2496ed652c16279d804a764368e75bf6fa6b44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Thu, 17 Feb 2022 21:06:35 +0100 Subject: [PATCH 12/13] add scan --- _unittests/ut_npy/test_xop.py | 59 +++++++++++++++++ _unittests/ut_npy/test_xop_doc.py | 6 +- mlprodict/npy/onnx_numpy_compiler.py | 4 +- mlprodict/npy/xop.py | 40 ++++++++++++ mlprodict/npy/xop_ops.py | 95 +++++++++++++++++++--------- 5 files changed, 171 insertions(+), 33 deletions(-) diff --git a/_unittests/ut_npy/test_xop.py b/_unittests/ut_npy/test_xop.py index b58aa96ed..2f40f7318 100644 --- a/_unittests/ut_npy/test_xop.py +++ b/_unittests/ut_npy/test_xop.py @@ -4,6 +4,7 @@ """ import unittest import numpy +from scipy.spatial.distance import squareform, pdist from onnx import TensorProto from pyquickhelper.pycode import ExtTestCase from mlprodict.npy.xop import loadop @@ -265,6 +266,64 @@ def test_onnx_abs_shape_numpy(self): self.assertEqual(dtype, TensorProto.FLOAT) self.assertEqual(shape, (2, )) + def test_scan_pdist(self): + (OnnxSub, OnnxIdentity, OnnxReduceSumSquare, OnnxScan, + OnnxAdd) = loadop('Sub', 'Identity', + 'ReduceSumSquare', 'Scan', 'Add') + + def onnx_squareform_pdist(X, dtype=None, op_version=None, **kwargs): + diff = OnnxSub('next_in', 'next', + op_version=op_version) + id_next = OnnxIdentity('next_in', output_names=['next_out'], + op_version=op_version) + flat = OnnxReduceSumSquare(diff, axes=[1], op_version=op_version, + output_names=['scan_out'], keepdims=0) + scan_body = id_next.to_onnx( + [Variable('next_in', numpy.float32, (None, None)), # tensor_type([None, None])), + Variable('next', numpy.float32, (None, ))], # tensor_type([None]))]), + outputs=[Variable('next_out', numpy.float32, (None, None)), # ([None, None])), + Variable('scan_out', numpy.float32, (None, ))], # tensor_type([None]))], + other_outputs=[flat], + target_opset=op_version) + output_names = [o.name for o in scan_body.graph.output] + self.assertEqual(['next_out', 'scan_out'], output_names) + dtype, shape = get_dtype_shape(scan_body.graph.output[0]) + self.assertEqual(dtype, TensorProto.FLOAT) + self.assertEqual(shape, (None, None)) + dtype, shape = get_dtype_shape(scan_body.graph.output[1]) + self.assertEqual(dtype, TensorProto.FLOAT) + self.assertEqual(shape, (None, )) + + node = OnnxScan(X, X, output_names=['S1', 'S2'], + num_scan_inputs=1, + body=(scan_body.graph, [id_next, flat]), + op_version=op_version, **kwargs) + return node[1] + + x = numpy.array([1, 2, 4, 5, 5, 4]).astype( + numpy.float32).reshape((3, 2)) + cop = OnnxAdd('input', 'input') + cdist = onnx_squareform_pdist(cop, dtype=numpy.float32) + cop2 = OnnxIdentity(cdist, output_names=['cdist']) + + model_def = cop2.to_onnx( + {'input': numpy.float32}, + outputs=[Variable('cdist', numpy.float32)]) + + sess = OnnxInference(model_def) + res = sess.run({'input': x}) + self.assertEqual(list(res.keys()), ['cdist']) + exp = squareform(pdist(x * 2, metric="sqeuclidean")) + self.assertEqualArray(exp, res['cdist']) + + x = numpy.array([1, 2, 4, 5, 5, 4]).astype( + numpy.float32).reshape((2, 3)) + res = sess.run({'input': x}) + self.assertEqual(list(res.keys()), ['cdist']) + exp = squareform(pdist(x * 2, metric="sqeuclidean")) + self.assertEqualArray(exp, res['cdist']) + if __name__ == "__main__": + # TestXOps().test_scan_pdist() unittest.main() diff --git a/_unittests/ut_npy/test_xop_doc.py b/_unittests/ut_npy/test_xop_doc.py index 713283825..4dea23683 100644 --- a/_unittests/ut_npy/test_xop_doc.py +++ b/_unittests/ut_npy/test_xop_doc.py @@ -3,7 +3,7 @@ """ import unittest from pyquickhelper.pycode import ExtTestCase -from mlprodict.npy.xop import _dynamic_class_creation +from mlprodict.npy.xop import _dynamic_class_creation, Xop from mlprodict.npy.xop_auto import get_rst_doc @@ -22,6 +22,10 @@ def test_auto_import(self): from mlprodict.npy.xop_auto_import_ import OnnxAdd # pylint: disable=E0611 self.assertEqual(OnnxAdd.__name__, 'OnnxAdd') + def test_loading_factory(self): + Add = Xop.Add + self.assertEqual(Add.__name__, 'OnnxAdd') + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/mlprodict/npy/onnx_numpy_compiler.py b/mlprodict/npy/onnx_numpy_compiler.py index aba9da898..fd4fc3d9d 100644 --- a/mlprodict/npy/onnx_numpy_compiler.py +++ b/mlprodict/npy/onnx_numpy_compiler.py @@ -7,9 +7,7 @@ import inspect from typing import Any import numpy -from ..tools.ort_wrapper import InferenceSession from ..onnx_tools.optim._main_onnx_optim import onnx_optimisations -from ..onnxrt import OnnxInference from .onnx_version import FctVersion from .onnx_numpy_annotation import get_args_kwargs @@ -437,11 +435,13 @@ def _build_runtime(self, op_version=None, runtime=None, inputs, outputs, _, n_optional, n_variables = self._parse_annotation( signature=signature, version=version) if runtime != 'onnxruntime': + from ..onnxrt import OnnxInference rt = OnnxInference(onx, runtime=runtime) self.rt_fct_ = OnnxNumpyFunctionOnnxInference( self, rt, inputs=inputs, outputs=outputs, n_optional=n_optional, n_variables=n_variables) else: + from ..tools.ort_wrapper import InferenceSession rt = InferenceSession(onx.SerializeToString()) self.rt_fct_ = OnnxNumpyFunctionInferenceSession( self, rt, inputs=inputs, outputs=outputs, diff --git a/mlprodict/npy/xop.py b/mlprodict/npy/xop.py index 66b44ac30..3908ac677 100644 --- a/mlprodict/npy/xop.py +++ b/mlprodict/npy/xop.py @@ -293,3 +293,43 @@ def loadop(*names, cache=False, verbose=0, fLOG=print): if len(res) == 1: return res[0] return res + + +class OnnxLoadFactory: + """ + Automatically creating all operators from onnx packages + takes time. That's why function @see cl loadop only creates + classes for the requested operators. This class does the same + when an attributes is requested. + + :: + + cl = OnnxLoadOperators() + x = cl.Add(...) + + It is equivalent to: + + :: + + OnnxAdd = loadop('Add') + x = OnnxAdd(...) + """ + + def __init__(self): + self._loaded_classes = {} + + def __getattr__(self, name): + """ + + """ + if name == '_loaded_classes': + return self._loaded_classes + if name in self._loaded_classes: + return self._loaded_classes[name] + cl = loadop(name) + self._loaded_classes[name] = cl + self._loaded_classes[cl.__name__] = cl + return cl + + +onnx_load_factory = Xop = OnnxLoadFactory() diff --git a/mlprodict/npy/xop_ops.py b/mlprodict/npy/xop_ops.py index c6b921049..38bbea5eb 100644 --- a/mlprodict/npy/xop_ops.py +++ b/mlprodict/npy/xop_ops.py @@ -32,6 +32,21 @@ def __init__(self, onx_op, index, op_version=None): self.index = index self.op_version = op_version + @property + def inputs(self): + "Returns the only inputs in a list." + inp = self.onx_op.inputs + return [inp[self.index]] + + def add_to(self, builder): + """ + Adds to graph builder. + + :param builder: instance of @see cl _GraphBuilder, + it must have a method `add_node` + """ + self.onx_op.add_to(builder) + def __str__(self): """ usual @@ -253,7 +268,8 @@ def __init__(self, *inputs, op_version=None, output_names=None, self._check() def _check(self): - input_types = (Variable, OnnxOperator, numpy.ndarray) + input_types = (Variable, OnnxOperator, + OnnxOperatorItem, numpy.ndarray) for o in self.inputs: if not isinstance(o, input_types): raise TypeError( @@ -388,6 +404,24 @@ def __getitem__(self, index): """ return OnnxOperatorItem(self, index, self.op_version) + def add_to(self, builder): + """ + Adds to graph builder. + + :param builder: instance of @see cl _GraphBuilder, + it must have a method `add_node` + """ + inputs = builder.get_input_names(self, self.inputs) + n_outputs = ( + self.output_range[0] if self.output_names is None + else len(self.output_names)) + outputs = [builder.get_output_name(self, i) for i in range(n_outputs)] + builder.add_node( + self.operator_name, + builder.get_unique_name('_' + self.operator_name.lower()), + inputs, outputs, domain=self.domain, opset=self.op_version, + **self.kwargs) + def _node_to_graph(self, other_outputs=None, inputs=None, outputs=None): """ Builds a graph as a list of nodes to walk through in that order. @@ -406,7 +440,7 @@ def _preprocess_list(inputs): return new_inputs def _process_input(inputs, set_inputs, inp, new_inputs): - if isinstance(inp, OnnxOperator): + if isinstance(inp, (OnnxOperator, OnnxOperatorItem)): new_stack.append(inp) elif isinstance(inp, Variable): if inp.name in set_inputs: @@ -511,19 +545,30 @@ def _get_type(node, name=None, outputs=None): "Unable to handle outputs=%r." % outputs) # outputs + set_names = set() new_outputs = [] for node in node_outputs: if node.output_names is None: n = self.output_range[0] for i in range(n): to = _get_type(node, outputs=outputs) - res = ('out%d' % i, to) - new_outputs.append(Variable(res[0], added_dtype=to)) + res = 'out%d' % i + var = Variable(res, added_dtype=to) + if var.name in set_names: + raise RuntimeError( + "Duplicated output name var=%r." % var) + set_names.add(var.name) + new_outputs.append(var) else: - for o in self.output_names: + for o in node.output_names: to = _get_type(node, o, outputs=outputs) res = (o, to) - new_outputs.append(o.copy_merge(to)) + var = o.copy_merge(to) + if var.name in set_names: + raise RuntimeError( + "Duplicated output name o=%r var=%r." % (o, var)) + set_names.add(var.name) + new_outputs.append(var) if len(new_outputs) == 0: raise RuntimeError( "No detected outputs inputs=%r outputs=%r." % ( @@ -531,24 +576,6 @@ def _get_type(node, name=None, outputs=None): return nodes, new_inputs, new_outputs - def add_to(self, builder): - """ - Adds to graph builder. - - :param builder: instance of @see cl _GraphBuilder, - it must have a method `add_node` - """ - inputs = builder.get_input_names(self, self.inputs) - n_outputs = ( - self.output_range[0] if self.output_names is None - else len(self.output_names)) - outputs = [builder.get_output_name(self, i) for i in range(n_outputs)] - builder.add_node( - self.operator_name, - builder.get_unique_name('_' + self.operator_name.lower()), - inputs, outputs, domain=self.domain, opset=self.op_version, - **self.kwargs) - def to_onnx(self, inputs=None, outputs=None, other_outputs=None, target_opset=None, verbose=0): @@ -725,7 +752,7 @@ def get_input_names(self, node, inputs): names.append(name) self.names.add(name) elif isinstance(i, OnnxOperatorItem): - name = self.get_output_name(i.onnx_op, i.index) + name = self.get_output_name(i.onx_op, i.index) names.append(name) self.names.add(name) elif isinstance(i, numpy.ndarray): @@ -786,15 +813,17 @@ def _process_io(self, inputs, input_names): if input_names is None: # outputs + set_names = set() input_names = [] for inp in inputs: if isinstance(inp, Variable): + if inp.name in set_names: + raise ValueError( + "Names already taken %r in %r." % ( + inp.name, inputs)) + set_names.add(inp.name) if inp.name in self.output_names_rev: input_names.append(inp) - elif isinstance(inp, tuple) and len(inp) == 2: - var, dtype = inp - if var.name in self.output_names_rev: - input_names.append(Variable(var.name, dtype)) else: raise TypeError( "Unexpected type %r in %r." % (inp, inputs)) @@ -811,13 +840,19 @@ def _process_io(self, inputs, input_names): input_names, inputs)) if isinstance(input_names, list): - d_input_names = {inp.name: inp for inp in input_names} + d_input_names = {} + for inp in input_names: + if inp.name in d_input_names: + raise ValueError( + "Duplicated name %r in %r." % (inp.name, input_names)) + d_input_names[inp.name] = inp elif isinstance(input_names, dict): d_input_names = input_names else: raise TypeError( "Unexpected type for input_names %r (%r)." % ( type(input_names), input_names)) + res = [] for inp in inputs: if not isinstance(inp, Variable): From 252b6f9025061b4c9f53eaab5fcccc01ee17d2fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Fri, 18 Feb 2022 01:42:31 +0100 Subject: [PATCH 13/13] Update test_cli_dynamic_doc.py --- _unittests/ut_cli/test_cli_dynamic_doc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/_unittests/ut_cli/test_cli_dynamic_doc.py b/_unittests/ut_cli/test_cli_dynamic_doc.py index 1253b24ec..581210065 100644 --- a/_unittests/ut_cli/test_cli_dynamic_doc.py +++ b/_unittests/ut_cli/test_cli_dynamic_doc.py @@ -19,7 +19,8 @@ def test_cli_onnx_code(self): st = BufferedPrint() main(args=["dynamic_doc", '--verbose', '1'], fLOG=st.fprint) res = str(st) - self.assertIn("Abs", res) + if len(res) > 0: + self.assertIn("Abs", res) if __name__ == "__main__":