diff --git a/_unittests/ut_onnxrt/test_onnx_inference.py b/_unittests/ut_onnxrt/test_onnx_inference.py index bd5a650e4..53ed163bc 100644 --- a/_unittests/ut_onnxrt/test_onnx_inference.py +++ b/_unittests/ut_onnxrt/test_onnx_inference.py @@ -158,6 +158,8 @@ def test_onnx_inference_verbose_intermediate(self): self.assertIsInstance(inp, list) out = oinf.output_names_shapes self.assertIsInstance(out, list) + out = oinf.output_names_shapes_types + self.assertIsInstance(out, list) if __name__ == "__main__": diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_control_loop.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_control_loop.py index a559d0a28..fc98ab6ba 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_control_loop.py +++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_control_loop.py @@ -26,6 +26,71 @@ def setUp(self): logger = getLogger('skl2onnx') logger.disabled = True + @ignore_warnings(DeprecationWarning) + def test_sequence_insert(self): + + def expect(node, inputs, outputs, name): + ginputs = [ + make_sequence_value_info( + node.input[0], TensorProto.FLOAT, []), # pylint: disable=E1101, + make_sequence_value_info( + node.input[1], TensorProto.FLOAT, []), # pylint: disable=E1101, + ] + if len(node.input) > 2: + ginputs.append( + make_tensor_value_info( + node.input[2], TensorProto.INT64, []), # pylint: disable=E1101 + ) + goutputs = [ + make_sequence_value_info( + node.output[0], TensorProto.FLOAT, []), # pylint: disable=E1101, + ] + model_def = make_model( + opset_imports=[ + make_operatorsetid('', get_opset_number_from_onnx())], + graph=make_graph( + name=name, inputs=ginputs, outputs=goutputs, + nodes=[node])) + oinf = OnnxInference(model_def) + got = oinf.run({n: v for n, v in zip(node.input, inputs)}) + self.assertEqual(len(got), 1) + oseq = got['output_sequence'] + self.assertEqual(len(oseq), len(outputs)) + for e, g in zip(outputs, oseq): + self.assertEqualArray(e, g) + + test_cases = { + 'at_back': [numpy.array([10, 11, 12]).astype(numpy.int64)], + 'at_front': [numpy.array([-2, -1, 0]), + numpy.array([0]).astype(numpy.int64)]} + sequence = [numpy.array([1, 2, 3, 4]).astype(numpy.int64), + numpy.array([5, 6, 7]).astype(numpy.int64), + numpy.array([8, 9]).astype(numpy.int64)] + + for test_name, test_inputs in test_cases.items(): + with self.subTest(test_name=test_name): + tensor = test_inputs[0].astype(numpy.int64) + + if len(test_inputs) > 1: + node = make_node( + 'SequenceInsert', + inputs=['sequence', 'tensor', 'position'], + outputs=['output_sequence']) + position = test_inputs[1] + inserted = self.sequence_insert_reference_implementation( + sequence, tensor, position) + expect(node, inputs=[sequence, tensor, position], outputs=inserted, + name='test_sequence_insert_' + test_name) + else: + node = make_node( + 'SequenceInsert', + inputs=['sequence', 'tensor'], + outputs=['output_sequence']) + inserted = self.sequence_insert_reference_implementation( + sequence, tensor) + expect(node, inputs=[sequence, tensor], outputs=inserted, + name='test_sequence_insert_' + test_name) + @ignore_warnings(DeprecationWarning) def test_loop(self): # Given a tensor x of values [x1, ..., xN], @@ -120,7 +185,7 @@ def test_loop(self): expected = numpy.array([ 1., 1., 2., 1., 2., 3., 1., 2., 3., 4., 1., 2., 3., 4., 5.], dtype=numpy.float32) - for rt in ['onnxruntime1', 'python']: + for rt in ['onnxruntime1', 'python', 'python_compiled']: with self.subTest(rt=rt): oinf = OnnxInference(model_def, runtime=rt) inputs = { @@ -140,6 +205,122 @@ def test_loop(self): continue 
self.assertIsInstance(v, SequenceType) + @ignore_warnings(DeprecationWarning) + def test_loop_additional_input(self): + # Given a tensor x of values [x1, ..., xN], + # Return a sequence of tensors of + # [[x1], [x1, x2], ..., [x1, ..., xN]] + + cond_in = make_tensor_value_info( + 'cond_in', TensorProto.BOOL, []) # pylint: disable=E1101 + cond_out = make_tensor_value_info( + 'cond_out', TensorProto.BOOL, []) # pylint: disable=E1101 + iter_count = make_tensor_value_info( + 'iter_count', TensorProto.INT64, []) # pylint: disable=E1101 + seq_in = make_tensor_sequence_value_info( + 'seq_in', TensorProto.FLOAT, None) # pylint: disable=E1101 + seq_out = make_tensor_sequence_value_info( + 'seq_out', TensorProto.FLOAT, None) # pylint: disable=E1101 + + x = numpy.array([1, 2, 3, 4, 5]).astype(numpy.float32) + + x_const_node = make_node( + 'Constant', inputs=[], outputs=['x'], + value=make_tensor( + name='const_tensor_x', data_type=TensorProto.FLOAT, # pylint: disable=E1101 + dims=x.shape, vals=x.flatten().astype(float))) + + zero_const_node = make_node( + 'Constant', inputs=[], outputs=['slice_start'], + value=make_tensor( + name='const_tensor_zero', data_type=TensorProto.INT64, # pylint: disable=E1101 + dims=(1,), vals=[0])) + + axes_node = make_node( + 'Constant', inputs=[], outputs=['axes'], + value=make_tensor( + name='const_tensor_axes', data_type=TensorProto.INT64, # pylint: disable=E1101 + dims=(), vals=[0])) + + add_node = make_node( + 'Add', inputs=['iter_count', 'XI'], outputs=['slice_end']) + + slice_node = make_node( + 'Slice', inputs=['x', 'slice_start', 'slice_end'], outputs=['slice_out']) + + insert_node = make_node( + 'SequenceInsert', inputs=['seq_in', 'slice_out'], outputs=['seq_out']) + + identity_node = make_node( + 'Identity', inputs=['cond_in'], outputs=['cond_out']) + + loop_body = make_graph( + [identity_node, x_const_node, zero_const_node, add_node, + axes_node, slice_node, insert_node], + 'loop_body', [iter_count, cond_in, seq_in], [cond_out, seq_out]) + + node = make_node( + 'Loop', inputs=['trip_count', 'cond', 'seq_empty'], + outputs=['seq_res'], body=loop_body) + node1 = make_node('Neg', inputs=['XI'], outputs=['Y']) + node_concat = make_node( + 'ConcatFromSequence', inputs=['seq_res'], + outputs=['res'], axis=0, new_axis=0) + + trip_count = numpy.array(5).astype(numpy.int64) + seq_empty = [] # type: List[Any] + cond = numpy.array(1).astype(numpy.bool) + + model_def = make_model( + opset_imports=[ + make_operatorsetid('', get_opset_number_from_onnx())], + graph=make_graph( + name='loop_test', + inputs=[ + make_tensor_value_info( + 'trip_count', TensorProto.INT64, trip_count.shape), # pylint: disable=E1101 + make_tensor_value_info( + 'cond', TensorProto.BOOL, cond.shape), # pylint: disable=E1101 + make_sequence_value_info( + 'seq_empty', TensorProto.FLOAT, []), # pylint: disable=E1101 + make_tensor_value_info( + 'XI', TensorProto.INT64, [])], # pylint: disable=E1101 + outputs=[ + make_tensor_value_info( + 'res', TensorProto.FLOAT, None), # pylint: disable=E1101 + make_tensor_value_info( + 'Y', TensorProto.INT64, [])], # pylint: disable=E1101 + nodes=[node1, node, node_concat])) + + expected = numpy.array([ + 1., 1., 2., 1., 2., 3., 1., 2., + 3., 4., 1., 2., 3., 4., 5.], dtype=numpy.float32) + X = numpy.array([1], dtype=numpy.int64) + for rt in ['python', 'onnxruntime1', 'python_compiled']: + with self.subTest(rt=rt): + oinf = OnnxInference(model_def, runtime=rt) + inputs = { + 'trip_count': trip_count, 'cond': cond, + 'seq_empty': seq_empty, + 'XI': X} + if rt == 
'python_compiled': + code = str(oinf) + self.assertIn("context={'XI': XI}", code) + got = oinf.run(inputs) + self.assertEqualArray(-X, got['Y']) + self.assertEqualArray(expected, got['res']) + if rt == 'python': + siz = oinf.infer_sizes(inputs) + self.assertIsInstance(siz, dict) + typ = oinf.infer_types() + self.assertEqual(typ["trip_count"], numpy.int64) + if 'cond' in typ: + self.assertEqual(typ["cond"], numpy.bool_) + for k, v in typ.items(): + if k in {'trip_count', 'cond', 'Y', 'XI'}: + continue + self.assertIsInstance(v, SequenceType) + def sequence_insert_reference_implementation( self, sequence, tensor, position=None): seq = list(sequence) @@ -150,71 +331,6 @@ def sequence_insert_reference_implementation( seq.append(tensor) return seq - @ignore_warnings(DeprecationWarning) - def test_sequence_insert(self): - - def expect(node, inputs, outputs, name): - ginputs = [ - make_sequence_value_info( - node.input[0], TensorProto.FLOAT, []), # pylint: disable=E1101, - make_sequence_value_info( - node.input[1], TensorProto.FLOAT, []), # pylint: disable=E1101, - ] - if len(node.input) > 2: - ginputs.append( - make_tensor_value_info( - node.input[2], TensorProto.INT64, []), # pylint: disable=E1101 - ) - goutputs = [ - make_sequence_value_info( - node.output[0], TensorProto.FLOAT, []), # pylint: disable=E1101, - ] - model_def = make_model( - opset_imports=[ - make_operatorsetid('', get_opset_number_from_onnx())], - graph=make_graph( - name=name, inputs=ginputs, outputs=goutputs, - nodes=[node])) - oinf = OnnxInference(model_def) - got = oinf.run({n: v for n, v in zip(node.input, inputs)}) - self.assertEqual(len(got), 1) - oseq = got['output_sequence'] - self.assertEqual(len(oseq), len(outputs)) - for e, g in zip(outputs, oseq): - self.assertEqualArray(e, g) - - test_cases = { - 'at_back': [numpy.array([10, 11, 12]).astype(numpy.int64)], - 'at_front': [numpy.array([-2, -1, 0]), - numpy.array([0]).astype(numpy.int64)]} - sequence = [numpy.array([1, 2, 3, 4]).astype(numpy.int64), - numpy.array([5, 6, 7]).astype(numpy.int64), - numpy.array([8, 9]).astype(numpy.int64)] - - for test_name, test_inputs in test_cases.items(): - with self.subTest(test_name=test_name): - tensor = test_inputs[0].astype(numpy.int64) - - if len(test_inputs) > 1: - node = make_node( - 'SequenceInsert', - inputs=['sequence', 'tensor', 'position'], - outputs=['output_sequence']) - position = test_inputs[1] - inserted = self.sequence_insert_reference_implementation( - sequence, tensor, position) - expect(node, inputs=[sequence, tensor, position], outputs=inserted, - name='test_sequence_insert_' + test_name) - else: - node = make_node( - 'SequenceInsert', - inputs=['sequence', 'tensor'], - outputs=['output_sequence']) - inserted = self.sequence_insert_reference_implementation( - sequence, tensor) - expect(node, inputs=[sequence, tensor], outputs=inserted, - name='test_sequence_insert_' + test_name) - if __name__ == "__main__": unittest.main() diff --git a/_unittests/ut_tools/data/bug_graph.onnx b/_unittests/ut_tools/data/bug_graph.onnx new file mode 100644 index 000000000..4dbdf0eec Binary files /dev/null and b/_unittests/ut_tools/data/bug_graph.onnx differ diff --git a/_unittests/ut_tools/data/bug_graph_infinite.onnx b/_unittests/ut_tools/data/bug_graph_infinite.onnx new file mode 100644 index 000000000..228943cd9 Binary files /dev/null and b/_unittests/ut_tools/data/bug_graph_infinite.onnx differ diff --git a/_unittests/ut_tools/test_code_helper.py b/_unittests/ut_tools/test_code_helper.py index 4ca9694c8..759287b9e 100644 --- 
a/_unittests/ut_tools/test_code_helper.py +++ b/_unittests/ut_tools/test_code_helper.py @@ -9,7 +9,8 @@ from scipy.sparse import csr_matrix from pyquickhelper.pycode import ExtTestCase, get_temp_folder from mlprodict.tools.code_helper import ( - debug_print, debug_dump, numpy_min, numpy_max, make_callable) + debug_print, debug_dump, numpy_min, numpy_max, make_callable, + print_code) class TestCodeHelper(ExtTestCase): @@ -84,6 +85,11 @@ def fctf(b=True): self.assertTrue(fct(True)) # pylint: disable=E1102 self.assertFalse(fct(False)) # pylint: disable=E1102 + def test_print_code(self): + code = "a=1\nb=2" + cc = print_code(code) + self.assertEqual(cc, "001 a=1\n002 b=2") + if __name__ == "__main__": unittest.main() diff --git a/_unittests/ut_tools/test_graphs.py b/_unittests/ut_tools/test_graphs.py index 4d1ad4f59..401eb7508 100644 --- a/_unittests/ut_tools/test_graphs.py +++ b/_unittests/ut_tools/test_graphs.py @@ -2,6 +2,7 @@ """ @brief test log(time=3s) """ +import os import unittest import numpy from sklearn.datasets import load_iris @@ -11,6 +12,7 @@ from pyquickhelper.pycode import ExtTestCase from skl2onnx.algebra.onnx_ops import OnnxAdd, OnnxSub # pylint: disable=E0611 from mlprodict.onnx_conv import to_onnx +from mlprodict.onnxrt import OnnxInference from mlprodict.tools import get_opset_number_from_onnx from mlprodict.tools.graphs import onnx2bigraph, BiGraph @@ -86,6 +88,24 @@ def test_pipe_graph_display_text(self): 'inout', 'O0 I0', 'A S']: self.assertIn(c, text) + def test_bug_graph(self): + this = os.path.abspath(os.path.dirname(__file__)) + data = os.path.join(this, "data", "bug_graph.onnx") + oinf = OnnxInference( + data, inside_loop=True, + static_inputs=['StatefulPartitionedCall/Reshape:0']) + text = oinf.to_text(distance=8) + self.assertIn( + "cond___pcen/simple_rnn/while/Identity_graph_outputs_Identity__4:0", + text) + + def test_bug_graph_infinite(self): + this = os.path.abspath(os.path.dirname(__file__)) + data = os.path.join(this, "data", "bug_graph_infinite.onnx") + oinf = OnnxInference(data, inside_loop=True) + text = oinf.to_text(distance=8) + self.assertIn("slice_end", text) + if __name__ == "__main__": unittest.main() diff --git a/mlprodict/__init__.py b/mlprodict/__init__.py index c6ad7fbd2..5a0d9c63f 100644 --- a/mlprodict/__init__.py +++ b/mlprodict/__init__.py @@ -4,7 +4,7 @@ @brief Ways to speed up predictions for a machine learned model. 
""" -__version__ = "0.6.1522" +__version__ = "0.6.1545" __author__ = "Xavier Dupré" diff --git a/mlprodict/onnxrt/onnx_inference.py b/mlprodict/onnxrt/onnx_inference.py index c1c0bac18..9a9a7bd7e 100644 --- a/mlprodict/onnxrt/onnx_inference.py +++ b/mlprodict/onnxrt/onnx_inference.py @@ -9,12 +9,13 @@ from time import perf_counter import warnings import textwrap +import pprint import numpy from scipy.sparse import coo_matrix from onnx import load, load_model, checker, shape_inference from onnx import onnx_pb as onnx_proto from onnx.helper import make_model -from ..tools.code_helper import make_callable +from ..tools.code_helper import make_callable, print_code from ..onnx_tools.onnx2py_helper import ( _var_as_dict, numpy_min, numpy_max, guess_numpy_type_from_string) from ..onnx_tools.onnx_manipulations import ( @@ -59,6 +60,11 @@ class OnnxInference: :param ir_version: if not None, overwrite the default version :param target_opset: used to overwrite *target_opset* :param runtime_options: specific options for the runtime + :param inside_loop: tells the runtime the graph is meant to + be repeated multiple times (in that case, inputs and + outputs may share the same name) + :param static_inputs: Loop can use static variables, + variables from the graph which runs the loop Among the possible runtime_options, there are: * *enable_profiling*: enables profiling for :epkg:`onnxruntime` @@ -71,7 +77,8 @@ def __init__(self, onnx_or_bytes_or_stream, runtime=None, skip_run=False, inplace=True, input_inplace=False, ir_version=None, target_opset=None, runtime_options=None, - session_options=None): + session_options=None, inside_loop=False, + static_inputs=None): if isinstance(onnx_or_bytes_or_stream, bytes): self.obj = load_model(BytesIO(onnx_or_bytes_or_stream)) elif isinstance(onnx_or_bytes_or_stream, BytesIO): @@ -94,6 +101,8 @@ def __init__(self, onnx_or_bytes_or_stream, runtime=None, self.inplace = inplace self.force_target_opset = target_opset self.runtime_options = runtime_options + self.inside_loop = inside_loop + self.static_inputs = static_inputs self._init() def __getstate__(self): @@ -106,7 +115,9 @@ def __getstate__(self): 'skip_run': self.skip_run, 'input_inplace': self.input_inplace, 'inplace': self.inplace, - 'force_target_opset': self.force_target_opset} + 'force_target_opset': self.force_target_opset, + 'static_inputs': self.static_inputs, + 'inside_loop': self.inside_loop} def __setstate__(self, state): """ @@ -120,6 +131,8 @@ def __setstate__(self, state): self.input_inplace = state['input_inplace'] self.inplace = state['inplace'] self.force_target_opset = state['force_target_opset'] + self.static_inputs = state['static_inputs'] + self.inside_loop = state['inside_loop'] self._init() def _init(self): @@ -137,8 +150,9 @@ def _init(self): for xy in ino: shape = xy.type.tensor_type.shape for d in shape.dim: - if d.dim_value == 0 and "0" in str(d): + if d.dim_value == 0 and "0" in str(d) and 'dim_param' not in str(d): # d.dim_value returns 0 whether is is 0 or empty. + # it may be a parameter as well raise RuntimeError( # pragma: no cover "Wrong ONNX file, one input or output has an empty shape: " "{}.".format(xy)) @@ -150,6 +164,7 @@ def _init(self): else: self.target_opset_ = {'': self.force_target_opset} self.ir_version_ = self.graph_['ir_version'] + if not self.skip_run: if self.runtime == 'onnxruntime1': # Loads the onnx with onnxruntime as a single file. 
@@ -161,6 +176,7 @@ def _init(self): else: self.sequence_ = self.graph_['sequence'] self.inits_ = self.graph_['inits'] + self.statics_ = self.graph_['statics'] dtype = self._guess_input_dtype() variables = self.inits_.copy() for node in self.sequence_: @@ -180,6 +196,7 @@ def _init(self): for k, v in node.ops_.typed_outputs_: variables[k] = v self._run = self._run_sequence_runtime + if not self.skip_run and self.runtime in ('python', None): self.shapes_ = self._set_shape_inference_runtime() if self.inplace: @@ -206,7 +223,13 @@ def _run_sequence_runtime_compiled( Every parameter with a default value is ignored. Switch to ``runtime='python'`` to enable those. """ - return self._run_compiled(inputs) # pylint: disable=E1101 + try: + return self._run_compiled(inputs) # pylint: disable=E1101 + except NameError as e: + raise RuntimeError( # pragma: no cover + "Unable to compute prediction due to %r. Code:\n%s" + "" % (e, print_code( + self._run_compiled_code))) from e # pylint: disable=E1101 def _guess_input_dtype(self): for _, v in self.graph_['inputs'].items(): @@ -279,6 +302,17 @@ def input_names_shapes(self): return [(_.name, _var_as_dict(_)['type']['shape']) for _ in self.obj.graph.input if _.name in names] + @staticmethod + def _get_type_property(info, prop): + if prop in info: + return info[prop] + if 'kind' in info and info['kind'] == 'sequence': + if prop == 'shape': + return ('?', ) + raise NotImplementedError( + "Unable to retrieve property %r from %r." + "" % (prop, info)) + @property def input_names_shapes_types(self): """ @@ -289,9 +323,10 @@ def input_names_shapes_types(self): .. versionchanged:: 0.6 The list does not include optional inputs anymore. """ + f = OnnxInference._get_type_property names = set(self.input_names) - return [(_.name, _var_as_dict(_)['type']['shape'], - 'tensor(%s)' % _var_as_dict(_)['type']['elem']) + return [(_.name, f(_var_as_dict(_)['type'], 'shape'), + 'tensor(%s)' % f(_var_as_dict(_)['type'], 'elem')) for _ in self.obj.graph.input if _.name in names] @property @@ -307,9 +342,25 @@ def output_names_shapes(self): Returns the names and shapes of all outputs. This method assumes all outputs are tensors. """ - return [(_.name, _var_as_dict(_)['type'].get('shape', None)) + f = OnnxInference._get_type_property + return [(_.name, f(_var_as_dict(_)['type'], 'shape')) for _ in self.obj.graph.output] + @property + def output_names_shapes_types(self): + """ + Returns the names, shapes, and types of all outputs. + This method assumes all outputs are tensors. + It does not include the optional outputs. + + ..
versionadded:: 0.7 + """ + names = set(self.output_names) + f = OnnxInference._get_type_property + return [(_.name, f(_var_as_dict(_)['type'], 'shape'), + 'tensor(%s)' % f(_var_as_dict(_)['type'], 'elem')) + for _ in self.obj.graph.output if _.name in names] + def global_index(self, name): """ Maps every name to one integer to avoid using dictionaries @@ -364,10 +415,17 @@ def to_sequence(self): variables = {} outputs = {} nodes = {} + statics = {} targets = {} for o in self.obj.opset_import: targets[o.domain] = o.version + # static variables + if self.static_inputs is not None: + for n in self.static_inputs: + statics[n] = {'name': n} + self.global_index(n) + # inputs for obj in self.obj.graph.input: variables[obj.name] = _var_as_dict(obj) @@ -421,6 +479,12 @@ def to_sequence(self): # names names = {} + for k, v in statics.items(): + if (k, 0) in names: + raise RuntimeError( # pragma: no cover + "Static variable '{}' already exists (tag='{}').".format( + k, names[k, 0][0])) + names[k, 0] = ('S', v) for k, v in inits.items(): if (k, 0) in names: raise RuntimeError( # pragma: no cover names[k, 0] = ('I', v) for k, v in outputs.items(): if (k, 0) in names and self.runtime != 'empty': - raise RuntimeError( # pragma: no cover - "Output '{}' already exists (tag='{}').".format( - k, names[k, 0][0])) + if not self.inside_loop or names[k, 0][0] != 'I': + raise RuntimeError( # pragma: no cover + "Output '{}' already exists (tag='{}').".format( + k, names[k, 0][0])) + else: + # When an input and an output share the same name, + # the name is kept marked as an input. + continue names[k, 0] = ('O', v) for k, v in nodes.items(): if (k, 1) in names: raise RuntimeError( # pragma: no cover - "Node '{}' already exists (tag='{}').".format( + "Node '{}' already exists (tag='{}'). " + "Use inside_loop=True to bypass this exception.".format( k, names[k, 0][0])) names[k, 1] = ('N', v) @@ -459,7 +529,7 @@ def to_sequence(self): if (k, 1) in order: # The operator node is already processed. continue - if v[0] in {'I', 'C'}: + if v[0] in {'I', 'C', 'S'}: if (k, 0) not in order: order[k, 0] = len(order) # A data node. modif += 1 @@ -474,7 +544,8 @@ def to_sequence(self): for o in v[1].outputs: if (o, 0) in order: raise RuntimeError( # pragma: no cover - "Two nodes share the same output '{}' or an operator and an output " + "Two nodes share the same output '{}' " + "or an operator and an output " "share the same name. " "(node: {}).".format(o, v[1])) # We add a data node. @@ -516,9 +587,21 @@ def to_sequence(self): for k, ord in last_used.items(): sequence[ord].add_variable_to_clean(k) - return dict(inits=inits, inputs=variables, outputs=outputs, - nodes=nodes, sequence=sequence, intermediate=intermediate, - targets=targets, ir_version=self.obj.ir_version) + results = dict(inits=inits, inputs=variables, outputs=outputs, + nodes=nodes, sequence=sequence, + intermediate=intermediate, + targets=targets, ir_version=self.obj.ir_version, + statics=statics) + if len(sequence) < len(nodes): + # Not all nodes will be executed.
+ raise RuntimeError( + "Unable to run all nodes.\n--Nodes--\n%s\n--Sequence--\n%s" + "\n--Inputs--\n%s\n--Inits--\n%s\n--Statics\n%s" + "" % (pprint.pformat(nodes), pprint.pformat(sequence), + pprint.pformat(list(variables)), + pprint.pformat(list(inits)), + pprint.pformat(list(statics)))) + return results def run(self, inputs, clean_right_away=False, intermediate=False, verbose=0, node_time=False, @@ -673,7 +756,7 @@ def dispsimple(arr): threshold = 8 else: threshold = min( - 50, min(50 // arr.shape[1], 8) * arr.shape[1]) + 50, min(50 // max(arr.shape[1], 1), 8) * arr.shape[1]) if hasattr(arr, 'todense'): fLOG( # pragma: no cover numpy.array2string(arr.todense(), max_line_width=120, @@ -702,8 +785,10 @@ def dispsimple(arr): ' (sparse)' if isinstance(obj, coo_matrix) else '')) elif (isinstance(obj, list) and len(obj) > 0 and not isinstance(obj[0], dict)): # pragma: no cover - fLOG("-kv='{}' list len={} min={} max={}".format( - k, len(obj), min(obj), max(obj))) + fLOG("-kv='{}' list len={}".format(k, len(obj))) + if verbose >= 3 and len(obj) > 0: + fLOG("first={} last={}".format( + obj[0], obj[-1])) else: # pragma: no cover fLOG("-kv='{}' type={}".format(k, type(obj))) @@ -722,10 +807,12 @@ def dispsimple(arr): time=t2 - t)) else: node.run(values) + added = 0 for k in range(len(values)): # pylint: disable=C0200 if values[k] is None: continue if k not in keys and k not in printed: + added += 1 printed.add(k) name = list( name for name in self._global_index # pylint: disable=C0206 @@ -734,7 +821,8 @@ def dispsimple(arr): name = name[0] mini = numpy_min(values[k]) maxi = numpy_max(values[k]) - fLOG("+kr='{}': {} (dtype={} min={} max={}{})".format( + fLOG("+kr{}'{}': {} (dtype={} min={} max={}{})".format( + "=" if len(values[k].shape) == 0 or min(values[k].shape) > 0 else "*", name, values[k].shape, values[k].dtype, mini, maxi, ' sparse' if isinstance(values[k], coo_matrix) else '')) @@ -745,6 +833,8 @@ def dispsimple(arr): name, type(values[k]))) if verbose >= 3: # pragma: no cover dispsimple(values[k]) + if added == 0: + fLOG("? no new result") if intermediate: values = [(v, k, values[v]) for k, v in self._global_index.items()] @@ -982,7 +1072,22 @@ def _set_shape_inference_runtime(self): for k, v in self.inputs_.items(): # The function assumes the first dimension is unknown # and is the batch size. - values[k] = ShapeObject(v, use_n1=True, name=k) + try: + values[k] = ShapeObject(v, use_n1=True, name=k) + except TypeError as e: + raise TypeError( + "Unable to guess shape for %r (shape=%r)." % (k, v)) from e + + impossible = False + for k, v in self.statics_.items(): + # static inputs should be known. 
+ try: + values[k] = ShapeObject(v) + except TypeError: + # default value is wrong + impossible = True + values[k] = None + for k, v in self.inits_.items(): values[k] = ShapeObject(v['value'], name=k) last = None @@ -990,15 +1095,17 @@ def _set_shape_inference_runtime(self): try: s = node._set_shape_inference_runtime(values) last = s - except IndexError as e: # pragma: no cover + except (IndexError, TypeError, KeyError, + AttributeError) as e: # pragma: no cover rows = [] if last is not None: for k, v in last.items(): rows.append("{}: {}".format(k, v)) for k in range(i + 1): rows.append("{} --> {}".format(k, self.sequence_[k])) - raise RuntimeError("Unable to infer shape of node {}\n{}".format( - i, '\n'.join(rows))) from e + if not impossible: + raise RuntimeError("Unable to infer shape of node {}\n{}".format( + i, '\n'.join(rows))) from e return values def infer_shapes(self): @@ -1020,6 +1127,8 @@ def _set_type_inference_runtime(self): "This method only works if the runtime is 'python' not " "'{}'.".format(self.runtime)) values = OrderedDict() + for k, v in self.statics_.items(): + values[k] = None for k, v in self.inputs_.items(): # The function assumes the first dimension is unknown # and is the batch size. @@ -1054,7 +1163,7 @@ def infer_types(self): """ return self._set_type_inference_runtime() - def _set_size_inference_runtime(self, inputs): + def _set_size_inference_runtime(self, inputs, context=None): """ Set sizes allocated during inference relying on the runtime. @@ -1065,11 +1174,17 @@ def _set_size_inference_runtime(self, inputs): "This method only works if the runtime is 'python' not " "'{}'.".format(self.runtime)) values = OrderedDict() + for k, v in self.statics_.items(): + if context is None: + raise RuntimeError( # pragma: no cover + "static variable but context is None.") + values[k] = context[k] for k, v in self.inits_.items(): values[k] = v['value'] for k, v in self.inputs_.items(): if k in inputs: values[k] = inputs[k] + last = None for i, node in enumerate(self.sequence_): try: @@ -1086,14 +1201,14 @@ def _set_size_inference_runtime(self, inputs): i, '\n'.join(rows))) from e return values - def infer_sizes(self, inputs): + def infer_sizes(self, inputs, context=None): """ Computes expected sizes. 
:param inputs: inputs as a dictionary :return: dictionary of dictionary of sizes """ - res = self._set_size_inference_runtime(inputs) + res = self._set_size_inference_runtime(inputs, context=context) return {k: v for k, v in res.items() if k.startswith('#')} def _guess_inplace(self, input_inplace=False): @@ -1127,6 +1242,8 @@ def _guess_inplace(self, input_inplace=False): """ forbid = {} values = OrderedDict() + for k in self.statics_: + values[k] = dict(inplace=False, to=[], fr=[]) for k in self.inputs_: values[k] = dict(inplace=input_inplace, to=[], fr=[]) for k in self.inits_: @@ -1223,24 +1340,41 @@ def clean_name(name): code = ['def compiled_run(dict_inputs):'] if debug: code.append(" printed = {}") + context = {} - for k, v in self.inits_.items(): + + # static variables + for k in sorted(self.statics_): + code.append(" # static: {0}".format(k)) + code.append(" {0} = dict_inputs['{1}']".format( + clean_name(k), k)) + if debug: + code.append( + " debug_print('i.{0}', {1}, printed)".format( + clean_name(k), k)) + + # initializers + for k, v in sorted(self.inits_.items()): if k.startswith("_OPT_"): raise RuntimeError( # pragma: no cover "The runtime cannot handle any constant name " "starting with '_OPT_': '{}'.".format(k)) if k in inputs: - context["_OPT_" + k] = v['value'] - code.append(" # init: _OPT_{0}".format(k)) + context["_OPT_" + clean_name(k)] = v['value'] + code.append(" # init: _OPT_{0} ({1})".format( + clean_name(k), k)) if debug: code.append( - " debug_print('c.[_OPT_{0}]', _OPT_{0}, printed)".format(k)) + " debug_print('c.[_OPT_{0}]', _OPT_{1}, printed)".format( + clean_name(k), k)) else: - context[k] = v['value'] - code.append(" # init: {0}".format(k)) + context[clean_name(k)] = v['value'] + code.append(" # init: {0} ({1})".format( + clean_name(k), k)) if debug: code.append( - " debug_print('c.[{0}]', {0}, printed)".format(k)) + " debug_print('c.[{0}]', {1}, printed)".format( + clean_name(k), k)) # method signature code.append(" # inputs") @@ -1262,10 +1396,20 @@ def clean_name(name): for i, node in enumerate(self.sequence_): name = "n{}_{}".format(i, node.ops_.__class__.__name__.lower()) context[name] = node.ops_._run - code.append(' ({1}, ) = {2}({0})'.format( - ', '.join(map(clean_name, node.inputs)), - ', '.join(map(clean_name, node.outputs)), - name)) + if (node.ops_.__class__.__name__ == 'Loop' and + node.ops_.need_context()): + # Adding context. + ctx = "{%s}" % ", ".join( + "'%s': %s" % (n, n) for n in node.ops_.additional_inputs) + code.append(' ({1}, ) = {2}({0}, context={3})'.format( + ', '.join(map(clean_name, node.inputs)), + ', '.join(map(clean_name, node.outputs)), + name, ctx)) + else: + code.append(' ({1}, ) = {2}({0})'.format( + ', '.join(map(clean_name, node.inputs)), + ', '.join(map(clean_name, node.outputs)), + name)) if debug: code.append(" print('''# {}''')".format(code[-1][4:])) for o in node.outputs: diff --git a/mlprodict/onnxrt/onnx_inference_exports.py b/mlprodict/onnxrt/onnx_inference_exports.py index 3a8aa4874..c59f7eced 100644 --- a/mlprodict/onnxrt/onnx_inference_exports.py +++ b/mlprodict/onnxrt/onnx_inference_exports.py @@ -585,14 +585,16 @@ def clean_args(args): "Unknown extension for file '{}'.".format(k)) return file_data - def to_text(self, recursive=False): + def to_text(self, recursive=False, grid=5, distance=5): """ It calls function @see fn onnx2bigraph to return the ONNX graph as text. 
:param recursive: dig into subgraphs too + :param grid: align text to this grid + :param distance: distance to the text :return: text """ bigraph = onnx2bigraph(self.oinf.obj, recursive=recursive) - graph = bigraph.display_structure() + graph = bigraph.display_structure(grid=grid, distance=distance) return graph.to_text() diff --git a/mlprodict/onnxrt/onnx_inference_node.py b/mlprodict/onnxrt/onnx_inference_node.py index 91bd24ff2..0fe298306 100644 --- a/mlprodict/onnxrt/onnx_inference_node.py +++ b/mlprodict/onnxrt/onnx_inference_node.py @@ -46,6 +46,7 @@ def _init(self, global_index): self.inplaces = [] self.inputs_indices = [global_index(name) for name in self.inputs] self.outputs_indices = [global_index(name) for name in self.outputs] + self._global_index = global_index def set_order(self, order): """ @@ -115,6 +116,29 @@ def setup_runtime(self, runtime=None, variables=None, rt_class=None, options=options if options else None, variables=variables) + @staticmethod + def _find_static_inputs(body): + """ + Determines the static inputs of a subgraph: any result + used as a constant in the subgraph which is neither an input + nor produced by a node of the subgraph. + """ + inputs_set = set(i.name for i in body.input) + for init in body.initializer: + inputs_set.add(init.name) + for node in body.node: + for i in node.output: + inputs_set.add(i) + add_inputs = [] + for node in body.node: + for i in node.input: + if i not in inputs_set: + # neither a subgraph input nor a node output matches, + # so the result must come from the enclosing graph + add_inputs.append(i) + inputs_set.add(i) + return add_inputs + def preprocess_parameters(self, runtime, rt_class, ir_version=None, target_opset=None): """ @@ -131,14 +155,23 @@ def preprocess_parameters(self, runtime, rt_class, ir_version=None, """ if 'atts' not in self.desc: return # pragma: no cover + inside_loop = self.onnx_node.op_type in {'Loop'} for _, v in self.desc['atts'].items(): if 'value' not in v: continue # pragma: no cover value = v['value'] if isinstance(value, onnx_proto.GraphProto): - sess = rt_class(v['value'], runtime=runtime, - ir_version=ir_version, - target_opset=target_opset) + static_inputs = OnnxInferenceNode._find_static_inputs(value) + try: + sess = rt_class(v['value'], runtime=runtime, + ir_version=ir_version, + target_opset=target_opset, + inside_loop=inside_loop, + static_inputs=static_inputs) + except RuntimeError as e: # pragma: no cover + raise RuntimeError( + "Unable to instantiate a node of type %r and name %r." + "" % (self.onnx_node.op_type, self.onnx_node.name)) from e + v['value_rt'] = sess def run(self, values): @@ -154,12 +187,21 @@ def run(self, values): args = list(values[k] for k in self.inputs) else: args = list(values[k] for k in self.inputs_indices) - try: - res = self.ops_.run(*args) + try: + if self.ops_.need_context(): + context = {n: values[self._global_index(n)] + for n in self.ops_.additional_inputs} + res = self.ops_.run(*args, context=context) + else: + res = self.ops_.run(*args) except TypeError as e: raise RuntimeError( # pragma: no cover - "Unable to run operator %r." % type(self.ops_)) from e + "Unable to run operator %r, inputs=%r." + "" % (type(self.ops_), self.inputs)) from e + except OverflowError as e: + raise RuntimeError( # pragma: no cover + "Unable to run operator %r, inputs=%r."
+ "" % (type(self.ops_), self.inputs)) from e if not isinstance(res, tuple): raise RuntimeError( # pragma: no cover @@ -271,12 +313,17 @@ def _set_size_inference_runtime(self, values): """ args = [values[k] for k in self.inputs] try: - res = self.ops_.infer_sizes(*args) + if self.ops_.need_context(): + context = {n: values[n] + for n in self.ops_.additional_inputs} + res = self.ops_.infer_sizes(*args, context=context) + else: + res = self.ops_.infer_sizes(*args) except (TypeError, ValueError) as e: raise TypeError( "Unable to call infer_sizes with {} arguments for class" " '{}' ({})".format(len(args), self.ops_.__class__.__name__, - self.ops_.infer_types)) from e + self.ops_.infer_sizes)) from e if not isinstance(res, tuple): raise RuntimeError( # pragma: no cover "Results of an operator should be a tuple for operator '{}'" diff --git a/mlprodict/onnxrt/ops_cpu/_op.py b/mlprodict/onnxrt/ops_cpu/_op.py index 125c028d0..ff1fdff7f 100644 --- a/mlprodict/onnxrt/ops_cpu/_op.py +++ b/mlprodict/onnxrt/ops_cpu/_op.py @@ -123,6 +123,15 @@ def __init__(self, onnx_node, desc=None, expected_attributes=None, "for node '{}' and options {}.".format( k, onnx_node.op_type, pprint.pformat(options))) + def need_context(self): + """ + Tells the runtime if this node needs the context + (all the results produced so far) as it may silently access + one of them (operator Loop). + The default answer is `False`. + """ + return False + def _find_custom_operator_schema(self, op_name): raise NotImplementedError( # pragma: no cover "This method should be overwritten for operator " @@ -702,7 +711,7 @@ def _run(self, a, b): # pylint: disable=W0221 try: self.numpy_fct(a, b, out=a) return (a, ) - except ValueError: + except (ValueError, TypeError): return (self.numpy_fct(a, b), ) if self.inplaces.get(1, False) and a.size <= b.size: if len(b.shape) == 1 and a.shape == (1, 1): @@ -710,7 +719,7 @@ def _run(self, a, b): # pylint: disable=W0221 try: self.numpy_fct(a, b, out=b) return (b, ) - except ValueError: + except (ValueError, TypeError): return (self.numpy_fct(a, b), ) return (self.numpy_fct(a, b), ) diff --git a/mlprodict/onnxrt/ops_cpu/op_concat_from_sequence.py b/mlprodict/onnxrt/ops_cpu/op_concat_from_sequence.py index 89e97a93e..be30787a3 100644 --- a/mlprodict/onnxrt/ops_cpu/op_concat_from_sequence.py +++ b/mlprodict/onnxrt/ops_cpu/op_concat_from_sequence.py @@ -19,6 +19,9 @@ def __init__(self, onnx_node, desc=None, **options): **options) def _run(self, seq): # pylint: disable=W0221 + if seq is None: + raise RuntimeError( # pragma: no cover + "A sequence cannot be null.") if self.new_axis == 1: seq2 = [s[..., numpy.newaxis] for s in seq] res = numpy.concatenate(seq2, axis=-1) diff --git a/mlprodict/onnxrt/ops_cpu/op_conv.py b/mlprodict/onnxrt/ops_cpu/op_conv.py index 1a68eac99..04c774ae7 100644 --- a/mlprodict/onnxrt/ops_cpu/op_conv.py +++ b/mlprodict/onnxrt/ops_cpu/op_conv.py @@ -40,6 +40,18 @@ def _run(self, X, W, B=None): # pylint: disable=W0221 raise ValueError( # pragma: no cover "X cannot be None for operator %r, ONNX=%r" % ( type(self), self.onnx_node)) + if min(X.shape) == 0: + raise RuntimeError( + "Unable to run operator Conv on an empty matrix. " + "X.shape=%r." % (X.shape, )) + if min(W.shape) == 0: + raise RuntimeError( + "Unable to run operator Conv on an empty matrix. " + "W.shape=%r." % (W.shape, )) + if B is not None and min(B.shape) == 0: + raise RuntimeError( + "Unable to run operator Conv on an empty matrix. " + "B.shape=%r." 
% (B.shape, )) if X.dtype == numpy.float32: return (self.rt32_.compute(X, W, B), ) return (self.rt64_.compute(X, W, B), ) diff --git a/mlprodict/onnxrt/ops_cpu/op_gather.py b/mlprodict/onnxrt/ops_cpu/op_gather.py index 335708043..161265acd 100644 --- a/mlprodict/onnxrt/ops_cpu/op_gather.py +++ b/mlprodict/onnxrt/ops_cpu/op_gather.py @@ -33,7 +33,7 @@ def _run(self, x, indices): # pylint: disable=W0221 return (numpy.empty((0, ), dtype=x.dtype), ) try: return (self.rt_[str(x.dtype)].compute(x, indices), ) - except KeyError: + except (KeyError, ValueError): return (numpy.take(x, indices, axis=self.axis), ) def _infer_shapes(self, x, indices): # pylint: disable=E0202,W0221 diff --git a/mlprodict/onnxrt/ops_cpu/op_loop.py b/mlprodict/onnxrt/ops_cpu/op_loop.py index 53987574b..c677ca212 100644 --- a/mlprodict/onnxrt/ops_cpu/op_loop.py +++ b/mlprodict/onnxrt/ops_cpu/op_loop.py @@ -3,9 +3,12 @@ """ @file @brief Runtime operator. + +.. versionadded:: 0.7 """ import numpy from ._op import OpRun +from ..shape_object import ShapeObject class Loop(OpRun): @@ -26,27 +29,57 @@ def __init__(self, onnx_node, desc=None, **options): self._run_meth = (self.body.run_in_scan if hasattr(self.body, 'run_in_scan') else self.body.run) + self.additional_inputs = self.body.static_inputs + + def need_context(self): + """ + The operator Loop needs to know all results produced + so far, as the loop may silently access one of them. + Some of them are not always referred to in the list of inputs + (a kind of static variable). + """ + return len(self.additional_inputs) > 0 - def _run(self, M, cond, v_initial, *args, callback=None): # pylint: disable=W0221 - inputs = {name: None for name in self.body.input_names} - inputs[self.body.input_names[2]] = v_initial - cond_name = self.body.output_names[1] + def _run(self, M, cond, v_initial, *args, callback=None, context=None): # pylint: disable=W0221 + loop_inputs = self.body.input_names + inputs = {name: None for name in loop_inputs} + inputs[loop_inputs[2]] = v_initial + cond_name = self.body.output_names[0] if len(args) > 0: - begin = len(self.body.input_names) - len(args) - for name, val in zip(self.body.input_names[begin:], args): + begin = len(loop_inputs) - len(args) + all_inputs = loop_inputs[begin:] + for name, val in zip(all_inputs, args): inputs[name] = val + if len(self.additional_inputs) > 0: + if context is None: + raise RuntimeError( + "Additional inputs %r are missing and context is None." + "" % (self.additional_inputs, )) + for a in self.additional_inputs: + if a in context: + inputs[a] = context[a] + else: + raise RuntimeError( + "Additional input %r not found in context\n%s." % ( + a, "\n".join(sorted(map(str, context))))) + it = 0 while cond and it < M: inputs[self.body.input_names[0]] = numpy.array(it, dtype=M.dtype) inputs[self.body.input_names[1]] = cond outputs = self._run_meth(inputs) cond = outputs[cond_name] + if cond is None: + raise RuntimeError( + "Condition %r returned by the subgraph cannot be None."
+ "" % cond_name) for i, o in zip(self.body.input_names[2:], self.body.output_names[1:]): inputs[i] = outputs[o] if callback is not None: - callback(inputs) + callback(inputs, context=context) it += 1 + if it == 0: outputs = {self.body.output_names[1]: cond} for i, o in zip(self.body.input_names[2:], @@ -55,24 +88,38 @@ def _run(self, M, cond, v_initial, *args, callback=None): # pylint: disable=W02 for o in self.body.output_names: if o not in outputs: outputs[o] = numpy.empty(shape=tuple()) - return tuple([outputs[name] for name in self.body.output_names[1:]]) + res = tuple([outputs[name] for name in self.body.output_names[1:]]) + if any(r is None for r in res): + raise TypeError( # pragma: no cover + "Operator Loop produces a None value.") + return res def _infer_shapes(self, M, cond, v_initial, *args): # pylint: disable=W0221 res = self.body._set_shape_inference_runtime() - return tuple([res[name] for name in self.body.output_names[1:]]) + outputs = {k[0]: k[1:] for k in self.body.output_names_shapes_types} + ret = [] + for name in self.body.output_names[1:]: + if name in res: + ret.append(res[name]) + else: + find = outputs[name] + ret.append(ShapeObject(find[0], dtype=find[1])) + return tuple(ret) def _infer_types(self, M, cond, v_initial, *args): # pylint: disable=W0221 res = self.body._set_type_inference_runtime() return tuple([res[name] for name in self.body.output_names[1:]]) - def _infer_sizes(self, M, cond, v_initial, *args): # pylint: disable=W0221 + def _infer_sizes(self, M, cond, v_initial, *args, context=None): # pylint: disable=W0221 store = [] - def callback_(inputs): - res = self.body.infer_sizes(inputs) + def callback_(inputs, context=None): + res = self.body.infer_sizes(inputs, context=context) store.append(res) - res = self._run(M, cond, v_initial, *args, callback=callback_) + res = self._run(M, cond, v_initial, *args, callback=callback_, + context=context) + temp = 0 for v in store: for vv in v.values(): diff --git a/mlprodict/onnxrt/ops_cpu/op_reshape.py b/mlprodict/onnxrt/ops_cpu/op_reshape.py index 4b4f6ab03..372e707cf 100644 --- a/mlprodict/onnxrt/ops_cpu/op_reshape.py +++ b/mlprodict/onnxrt/ops_cpu/op_reshape.py @@ -13,8 +13,16 @@ def reshape_reference_implementation(data, shape): new_shape = numpy.copy(shape) zeros_index = numpy.where(shape == 0) - new_shape[zeros_index] = numpy.array(data.shape)[zeros_index] - reshaped = numpy.reshape(data, new_shape) + if len(data.shape) == 1 and data.shape[0] == 0: + reshaped = numpy.reshape(data, shape) + else: + try: + new_shape[zeros_index] = numpy.array(data.shape)[zeros_index] + except IndexError as e: + raise RuntimeError( + "Unable to reshape from shape %r to shape %r (or %r)." + "" % (data.shape, shape, new_shape)) from e + reshaped = numpy.reshape(data, new_shape) return reshaped diff --git a/mlprodict/onnxrt/ops_cpu/op_sequence_insert.py b/mlprodict/onnxrt/ops_cpu/op_sequence_insert.py index 648ecb957..39563e061 100644 --- a/mlprodict/onnxrt/ops_cpu/op_sequence_insert.py +++ b/mlprodict/onnxrt/ops_cpu/op_sequence_insert.py @@ -3,6 +3,8 @@ """ @file @brief Runtime operator. + +.. versionadded:: 0.7 """ from ._op import OpRun diff --git a/mlprodict/onnxrt/ops_empty/_op.py b/mlprodict/onnxrt/ops_empty/_op.py index d8f82defc..a0e0e7303 100644 --- a/mlprodict/onnxrt/ops_empty/_op.py +++ b/mlprodict/onnxrt/ops_empty/_op.py @@ -200,3 +200,12 @@ def run(self, *args, **kwargs): # inputs = {name: val for name, val in zip(self.inputs, args)} raise RuntimeError( # pragma: no cover "This runtime does nothing. 
Running it is useless.") + + def need_context(self): + """ + Tells the runtime if this node needs the context + (all the results produced so far) as it may silently access + one of them (operator Loop). + The default answer is `False`. + """ + return False diff --git a/mlprodict/onnxrt/ops_onnxruntime/_op.py b/mlprodict/onnxrt/ops_onnxruntime/_op.py index 169318be6..3c641d43f 100644 --- a/mlprodict/onnxrt/ops_onnxruntime/_op.py +++ b/mlprodict/onnxrt/ops_onnxruntime/_op.py @@ -279,3 +279,12 @@ def run(self, *args, **kwargs): list(sorted(inputs)), self.alg_class, dtypes, shapes, exp, exp_types, e, self.onnx_)) from e return tuple(res) + + def need_context(self): + """ + Tells the runtime if this node needs the context + (all the results produced so far) as it may silently access + one of them (operator Loop). + The default answer is `False`. + """ + return False diff --git a/mlprodict/onnxrt/shape_object.py b/mlprodict/onnxrt/shape_object.py index 1017e6e21..4b7a898ff 100644 --- a/mlprodict/onnxrt/shape_object.py +++ b/mlprodict/onnxrt/shape_object.py @@ -486,64 +486,72 @@ def __init__(self, shape, dtype=None, use_n1=False, name=None): self._dtype = dtype else: raise TypeError( # pragma: no cover - "Unexpected type for shape: {}".format(type(shape))) - - if self._dtype is None: - raise ValueError( - "dtype cannot be None, shape type is {}\n{}".format( + "Unexpected type for shape: {}, shape={}".format( type(shape), shape)) - if self._dtype in (float, 'double'): - self._dtype = numpy.float64 - elif self._dtype in ('float32', 'float'): - self._dtype = numpy.float32 - elif self._dtype in (numpy.float16, 'float16'): - self._dtype = numpy.float16 - elif self._dtype in ('int32', ): - self._dtype = numpy.int32 - elif self._dtype in (int, 'int', 'int64'): - self._dtype = numpy.int64 - elif self._dtype in (str, 'str', numpy.str_): - self._dtype = numpy.str_ - elif (hasattr(self._dtype, 'type') and self._dtype.type is numpy.string_): - pass - elif self._dtype in (bool, 'bool', numpy.bool_): - self._dtype = numpy.bool_ - elif self._dtype in (object, numpy.object_): - pass - elif self._dtype in (numpy.int8, 'int8', ): - self._dtype = numpy.int8 - elif self._dtype in (numpy.uint8, 'uint8', ): - self._dtype = numpy.uint8 - elif self._dtype in (numpy.int16, 'int16', ): - self._dtype = numpy.int16 - elif self._dtype in (numpy.uint16, 'uint16', ): - self._dtype = numpy.uint16 - elif self._dtype in (numpy.uint32, 'uint32', ): - self._dtype = numpy.uint32 - elif self._dtype in (numpy.uint64, 'uint64', ): - self._dtype = numpy.uint64 - elif self._dtype not in { - numpy.float32, numpy.float64, numpy.int32, numpy.int64, - numpy.str_, numpy.bool_, numpy.float16, None, - numpy.complex64, numpy.complex128, - 'map', 'sequence'}: - raise ValueError( # pragma: no cover - "dtype has an unexpected value: '{}'.".format(self._dtype)) - if self._shape is not None: - for i, a in enumerate(self._shape): - if not isinstance(a, DimensionObject): - raise TypeError( # pragma: no cover - 'Dimension {} has a wrong type {}'.format( - i, type(a))) - if use_n1: - sh = self._shape[0] if self._shape else None - if isinstance(sh, DimensionObject) and sh._dim is None: - sh._dim = 'n' - if self._shape is not None: - for s in self._shape: - if isinstance(s, int): - raise TypeError( # pragma: no cover - "Unexpected type int in shape %r." 
% self) + + def _dtype_again(): + if self._dtype is None: + raise ValueError( + "dtype cannot be None, shape type is {}\n{}".format( + type(shape), shape)) + if self._dtype in (float, 'double', 'tensor(double)'): + self._dtype = numpy.float64 + elif self._dtype in ('float32', 'float', 'tensor(float)'): + self._dtype = numpy.float32 + elif self._dtype in (numpy.float16, 'float16', 'tensor(float16)'): + self._dtype = numpy.float16 + elif self._dtype in ('int32', 'tensor(int32)'): + self._dtype = numpy.int32 + elif self._dtype in (int, 'int', 'int64', 'tensor(int64)'): + self._dtype = numpy.int64 + elif self._dtype in (str, 'str', numpy.str_, 'tensor(str)'): + self._dtype = numpy.str_ + elif (hasattr(self._dtype, 'type') and self._dtype.type is numpy.string_): + pass + elif self._dtype in (bool, 'bool', numpy.bool_): + self._dtype = numpy.bool_ + elif self._dtype in (object, numpy.object_): + pass + elif self._dtype in (numpy.int8, 'int8', ): + self._dtype = numpy.int8 + elif self._dtype in (numpy.uint8, 'uint8', ): + self._dtype = numpy.uint8 + elif self._dtype in (numpy.int16, 'int16', ): + self._dtype = numpy.int16 + elif self._dtype in (numpy.uint16, 'uint16', ): + self._dtype = numpy.uint16 + elif self._dtype in (numpy.uint32, 'uint32', ): + self._dtype = numpy.uint32 + elif self._dtype in (numpy.uint64, 'uint64', ): + self._dtype = numpy.uint64 + elif self._dtype == "tensor({'kind': 'tensor', 'elem': 'float', 'shape': })": + self._dtype = numpy.float32 + elif self._dtype not in { + numpy.float32, numpy.float64, numpy.int32, numpy.int64, + numpy.str_, numpy.bool_, numpy.float16, None, + numpy.complex64, numpy.complex128, + 'map', 'sequence'}: + raise ValueError( # pragma: no cover + "dtype has an unexpected value: '{}'.".format(self._dtype)) + _dtype_again() + + def _shape_again(): + if self._shape is not None: + for i, a in enumerate(self._shape): + if not isinstance(a, DimensionObject): + raise TypeError( # pragma: no cover + 'Dimension {} has a wrong type {}'.format( + i, type(a))) + if use_n1: + sh = self._shape[0] if self._shape else None + if isinstance(sh, DimensionObject) and sh._dim is None: + sh._dim = 'n' + if self._shape is not None: + for s in self._shape: + if isinstance(s, int): + raise TypeError( # pragma: no cover + "Unexpected type int in shape %r." % self) + _shape_again() def reshape(self, shape): """ diff --git a/mlprodict/tools/code_helper.py b/mlprodict/tools/code_helper.py index ec9dbe9f6..68860c2e1 100644 --- a/mlprodict/tools/code_helper.py +++ b/mlprodict/tools/code_helper.py @@ -204,3 +204,12 @@ def make_callable(fct, obj, code, gl, debug): "{}\n{}\n----\n{}".format( fct, res.__defaults__, defs, "\n".join(lines), code)) # pylint: disable=E1101 return res + + +def print_code(code, begin=1): + """ + Returns the code with line number. + """ + rows = code.split("\n") + return "\n".join("%03d %s" % (i + begin, s) + for i, s in enumerate(rows)) diff --git a/mlprodict/tools/graphs.py b/mlprodict/tools/graphs.py index bf94a7e69..ba2fc5dc2 100644 --- a/mlprodict/tools/graphs.py +++ b/mlprodict/tools/graphs.py @@ -4,6 +4,7 @@ .. versionadded:: 0.7 """ +import pprint import numpy @@ -167,6 +168,11 @@ def __init__(self, v0, v1, edges): continue if a in v1 and b in v0: continue + if b in v1: + # One operator is missing one input. + # We add one. + self.v0[a] = BiGraph.A('ERROR') + continue raise ValueError( "Edges (%r, %r) not found among the vertices." 
% (a, b)) @@ -219,12 +225,20 @@ def order_vertices(self): for v in self.v1: order[v] = 0 modif = 1 + n_iter = 0 while modif > 0: modif = 0 for a, b in self.edges: if order[b] <= order[a]: order[b] = order[a] + 1 modif += 1 + n_iter += 1 + if n_iter > len(order): + break + if modif > 0: + raise RuntimeError( + "The graph has a cycle.\n%s" % pprint.pformat( + self.edges)) return order def adjacency_matrix(self): @@ -250,19 +264,20 @@ def adjacency_matrix(self): matrix[row_id[b], col_id[a]] = 1 return matrix, row, col - def display_structure(self, grid=5): + def display_structure(self, grid=5, distance=5): """ Creates a display structure which contains all the necessary steps to display a graph. :param grid: align text to this grid + :param distance: distance to the text :return: instance of @see cl AdjacencyGraphDisplay """ def adjust(c, way): if way == 1: - d = grid * ((c + grid * 2 - (grid // 2 + 1)) // grid) + d = grid * ((c + distance * 2 - (grid // 2 + 1)) // grid) else: - d = -grid * ((-c + grid * 2 - (grid // 2 + 1)) // grid) + d = -grid * ((-c + distance * 2 - (grid // 2 + 1)) // grid) return d matrix, row, col = self.adjacency_matrix() @@ -366,14 +381,17 @@ def onnx2bigraph(model_onnx, recursive=False): for o in model_onnx.graph.initializer: v0[o.name] = BiGraph.A('Init') for n in model_onnx.graph.node: - v1[n.name] = BiGraph.A(n.op_type) + nname = n.name if len(n.name) > 0 else "id%d" % id(n) + v1[nname] = BiGraph.A(n.op_type) for i, o in enumerate(n.input): c = str(i) if i < 10 else "+" - edges[o, n.name] = BiGraph.A('I%s' % c) + nname = n.name if len(n.name) > 0 else "id%d" % id(n) + edges[o, nname] = BiGraph.A('I%s' % c) for i, o in enumerate(n.output): c = str(i) if i < 10 else "+" if o not in v0: v0[o] = BiGraph.A('inout') - edges[n.name, o] = BiGraph.A('O%s' % c) + nname = n.name if len(n.name) > 0 else "id%d" % id(n) + edges[nname, o] = BiGraph.A('O%s' % c) return BiGraph(v0, v1, edges) diff --git a/requirements.txt b/requirements.txt index b8b8b6adc..99ae0ff4b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -45,6 +45,6 @@ wheel xgboost # onnx -onnx>=1.9.0 +onnx==1.9.0 onnxruntime>=1.8.0 skl2onnx>=1.9.0
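Below, a minimal usage sketch of the features this changeset introduces, mirroring the new unit tests; the model file name is a hypothetical placeholder and the static input name is the one used in test_bug_graph:

    from mlprodict.onnxrt import OnnxInference

    # Load a model whose Loop body silently reads a result produced by the
    # enclosing graph: inside_loop=True lets an input and an output share
    # the same name, static_inputs lists the outer results the subgraph
    # may access.
    oinf = OnnxInference(
        "bug_graph.onnx",  # hypothetical path to the model
        inside_loop=True,
        static_inputs=['StatefulPartitionedCall/Reshape:0'])

    # New introspection helpers added by this changeset.
    print(oinf.output_names_shapes_types)  # [(name, shape, 'tensor(...)'), ...]
    print(oinf.to_text(distance=8))        # textual rendering of the graph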