diff --git a/_unittests/ut_cli/test_cli_onnx_stats.py b/_unittests/ut_cli/test_cli_onnx_stats.py index d847e2f87..f93dfdd76 100644 --- a/_unittests/ut_cli/test_cli_onnx_stats.py +++ b/_unittests/ut_cli/test_cli_onnx_stats.py @@ -8,8 +8,10 @@ from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from sklearn.linear_model import LogisticRegression +from sklearn.exceptions import ConvergenceWarning from pyquickhelper.loghelper import BufferedPrint -from pyquickhelper.pycode import ExtTestCase, get_temp_folder +from pyquickhelper.pycode import ( + ExtTestCase, get_temp_folder, ignore_warnings) from mlprodict.__main__ import main from mlprodict.cli import convert_validate @@ -22,6 +24,7 @@ def test_cli_onnx_stats(self): res = str(st) self.assertIn("optim", res) + @ignore_warnings(ConvergenceWarning) def test_onnx_stats(self): iris = load_iris() X, y = iris.data, iris.target @@ -47,6 +50,24 @@ def test_onnx_stats(self): res = str(st) self.assertIn("ninits: 0", res) + st = BufferedPrint() + main(args=["onnx_stats", "--name", outonnx, '--kind', 'io'], + fLOG=st.fprint) + res = str(st) + self.assertIn("input: name", res) + + st = BufferedPrint() + main(args=["onnx_stats", "--name", outonnx, '--kind', 'text'], + fLOG=st.fprint) + res = str(st) + self.assertIn("input: name", res) + + st = BufferedPrint() + main(args=["onnx_stats", "--name", outonnx, '--kind', 'node'], + fLOG=st.fprint) + res = str(st) + self.assertIn("op_", res) + if __name__ == "__main__": unittest.main() diff --git a/_unittests/ut_plotting/test_text_plotting.py b/_unittests/ut_plotting/test_text_plotting.py index 57096680a..76fa3241d 100644 --- a/_unittests/ut_plotting/test_text_plotting.py +++ b/_unittests/ut_plotting/test_text_plotting.py @@ -16,7 +16,8 @@ from mlprodict.onnx_conv import to_onnx from mlprodict.tools.asv_options_helper import get_opset_number_from_onnx from mlprodict.plotting.plotting import ( - onnx_text_plot, onnx_text_plot_tree, 
onnx_simple_text_plot) + onnx_text_plot, onnx_text_plot_tree, onnx_simple_text_plot, + onnx_text_plot_io) class TestPlotTextPlotting(ExtTestCase): @@ -133,6 +134,18 @@ def test_onnx_simple_text_plot_leaky(self): """).strip(" \n") self.assertIn(expected, text) + def test_onnx_text_plot_io(self): + x = OnnxLeakyRelu('X', alpha=0.5, op_version=15, + output_names=['Y']) + onx = x.to_onnx({'X': FloatTensorType()}, + outputs={'Y': FloatTensorType()}, + target_opset=15) + text = onnx_text_plot_io(onx) + expected = textwrap.dedent(""" + input: + """).strip(" \n") + self.assertIn(expected, text) + if __name__ == "__main__": unittest.main() diff --git a/_unittests/ut_tools/test_optim_onnx_identity.py b/_unittests/ut_tools/test_optim_onnx_identity.py index 4c12572c6..76eba3457 100644 --- a/_unittests/ut_tools/test_optim_onnx_identity.py +++ b/_unittests/ut_tools/test_optim_onnx_identity.py @@ -41,6 +41,10 @@ def test_onnx_remove_identities(self): self.assertIn('subgraphs', stats) self.assertGreater(stats['subgraphs'], 1) self.assertGreater(stats['op_Identity'], 2) + stats = onnx_statistics(model_def, optim=False, node_type=True) + self.assertIn('subgraphs', stats) + self.assertGreater(stats['subgraphs'], 1) + self.assertGreater(stats['op_Identity'], 2) new_model = onnx_remove_node_identity(model_def) stats2 = onnx_statistics(new_model, optim=False) diff --git a/mlprodict/cli/optimize.py b/mlprodict/cli/optimize.py index 6cae7206a..74c750070 100644 --- a/mlprodict/cli/optimize.py +++ b/mlprodict/cli/optimize.py @@ -6,12 +6,17 @@ import onnx -def onnx_stats(name, optim=False): +def onnx_stats(name, optim=False, kind=None): """ Computes statistics on an ONNX model. 
:param name: filename :param optim: computes statistics before an after optimisation was done + :param kind: kind of statistics, if left unknown, + prints out the metadata, possible values: + * `io`: prints input and output name, type, shapes + * `node`: prints the distribution of node types + * `text`: prints a text summary .. cmdref:: :title: Computes statistics on an ONNX graph @@ -20,13 +25,24 @@ def onnx_stats(name, optim=False): The command computes statistics on an ONNX model. """ - from ..onnx_tools.optim import onnx_statistics if not os.path.exists(name): raise FileNotFoundError( # pragma: no cover "Unable to find file '{}'.".format(name)) with open(name, 'rb') as f: model = onnx.load(f) - return onnx_statistics(model, optim=optim) + if kind in (None, ""): + from ..onnx_tools.optim import onnx_statistics + return onnx_statistics(model, optim=optim) + if kind == 'text': + from ..plotting.plotting import onnx_simple_text_plot + return onnx_simple_text_plot(model) + if kind == 'io': + from ..plotting.plotting import onnx_text_plot_io + return onnx_text_plot_io(model) + if kind == 'node': + from ..onnx_tools.optim import onnx_statistics + return onnx_statistics(model, optim=optim, node_type=True) + raise ValueError("Unexpected kind=%r." % kind) def onnx_optim(name, outfile=None, recursive=True, options=None, verbose=0, fLOG=None): diff --git a/mlprodict/onnx_tools/optim/onnx_helper.py b/mlprodict/onnx_tools/optim/onnx_helper.py index 7cd0cd816..706899c98 100644 --- a/mlprodict/onnx_tools/optim/onnx_helper.py +++ b/mlprodict/onnx_tools/optim/onnx_helper.py @@ -10,16 +10,17 @@ from .onnx_optimisation import onnx_remove_node -def onnx_statistics(onnx_model, recursive=True, optim=True): +def onnx_statistics(onnx_model, recursive=True, optim=True, node_type=False): """ Computes statistics on :epkg:`ONNX` models, extracts informations about the model such as the number of nodes. 
- @param onnx_model onnx model - @param recursive looks into subgraphs - @param optim adds statistics because of optimisation - @return dictionary + :param onnx_model: onnx model + :param recursive: looks into subgraphs + :param optim: adds statistics because of optimisation + :param node_type: add distribution of node types + :return: dictionary .. runpython:: :showcode: @@ -98,9 +99,13 @@ def update(sts, st): # Number of identities counts = Counter(map(lambda obj: obj.op_type, graph.node)) - for op in ['Cast', 'Identity', 'ZipMap', 'Reshape']: - if op in counts: - stats['op_' + op] = counts[op] + if node_type: + for op, v in counts.items(): + stats['op_' + op] = v + else: + for op in ['Cast', 'Identity', 'ZipMap', 'Reshape']: + if op in counts: + stats['op_' + op] = counts[op] # Recursive if recursive: @@ -110,7 +115,8 @@ def update(sts, st): for att in node.attribute: if att.name != 'body': continue - substats = onnx_statistics(att.g, recursive=True, optim=False) + substats = onnx_statistics( + att.g, recursive=True, optim=False, node_type=node_type) update(stats, {'subgraphs': 1}) update(stats, substats) diff --git a/mlprodict/plotting/plotting.py b/mlprodict/plotting/plotting.py index e9f087819..be26854cf 100644 --- a/mlprodict/plotting/plotting.py +++ b/mlprodict/plotting/plotting.py @@ -4,7 +4,9 @@ @brief Shorcuts to plotting functions. 
""" -from .text_plot import onnx_text_plot, onnx_text_plot_tree, onnx_simple_text_plot +from .text_plot import ( + onnx_text_plot, onnx_text_plot_tree, onnx_simple_text_plot, + onnx_text_plot_io) from .plotting_benchmark import plot_benchmark_metrics from .plotting_validate_graph import plot_validate_benchmark from .plotting_onnx import plot_onnx diff --git a/mlprodict/plotting/text_plot.py b/mlprodict/plotting/text_plot.py index 8e1d88cf8..52c8acb48 100644 --- a/mlprodict/plotting/text_plot.py +++ b/mlprodict/plotting/text_plot.py @@ -320,6 +320,61 @@ def _find_sequence(node_name, known, done): return new_nodes +def _get_type(obj0): + obj = obj0 + if hasattr(obj, 'data_type'): + if (obj.data_type == TensorProto.FLOAT and # pylint: disable=E1101 + hasattr(obj, 'float_data')): + return TENSOR_TYPE_TO_NP_TYPE[TensorProto.FLOAT] # pylint: disable=E1101 + if (obj.data_type == TensorProto.DOUBLE and # pylint: disable=E1101 + hasattr(obj, 'double_data')): + return TENSOR_TYPE_TO_NP_TYPE[TensorProto.DOUBLE] # pylint: disable=E1101 + if (obj.data_type == TensorProto.INT64 and # pylint: disable=E1101 + hasattr(obj, 'int64_data')): + return TENSOR_TYPE_TO_NP_TYPE[TensorProto.INT64] # pylint: disable=E1101 + raise RuntimeError( + "Unable to guess type from %r." % obj0) + if hasattr(obj, 'type'): + obj = obj.type + if hasattr(obj, 'tensor_type'): + obj = obj.tensor_type + if hasattr(obj, 'elem_type'): + return TENSOR_TYPE_TO_NP_TYPE.get(obj.elem_type, '?') + raise RuntimeError( + "Unable to guess type from %r." 
% obj0) + + +def _get_shape(obj): + obj0 = obj + if hasattr(obj, 'data_type'): + if (obj.data_type == TensorProto.FLOAT and # pylint: disable=E1101 + hasattr(obj, 'float_data')): + return (len(obj.float_data), ) + if (obj.data_type == TensorProto.DOUBLE and # pylint: disable=E1101 + hasattr(obj, 'double_data')): + return (len(obj.double_data), ) + if (obj.data_type == TensorProto.INT64 and # pylint: disable=E1101 + hasattr(obj, 'int64_data')): + return (len(obj.int64_data), ) + raise RuntimeError( + "Unable to guess type from %r." % obj0) + if hasattr(obj, 'type'): + obj = obj.type + if hasattr(obj, 'tensor_type'): + obj = obj.tensor_type + if hasattr(obj, 'shape'): + obj = obj.shape + dims = [] + for d in obj.dim: + if hasattr(d, 'dim_value'): + dims.append(d.dim_value) + else: + dims.append(None) + return tuple(dims) + raise RuntimeError( + "Unable to guess type from %r." % obj0) + + def onnx_simple_text_plot(model, verbose=False, att_display=None): """ Displays an ONNX graph into text. @@ -392,59 +447,6 @@ def onnx_simple_text_plot(model, verbose=False, att_display=None): 'transB', ] - def get_type(obj0): - obj = obj0 - if hasattr(obj, 'data_type'): - if (obj.data_type == TensorProto.FLOAT and # pylint: disable=E1101 - hasattr(obj, 'float_data')): - return TENSOR_TYPE_TO_NP_TYPE[TensorProto.FLOAT] # pylint: disable=E1101 - if (obj.data_type == TensorProto.DOUBLE and # pylint: disable=E1101 - hasattr(obj, 'double_data')): - return TENSOR_TYPE_TO_NP_TYPE[TensorProto.DOUBLE] # pylint: disable=E1101 - if (obj.data_type == TensorProto.INT64 and # pylint: disable=E1101 - hasattr(obj, 'int64_data')): - return TENSOR_TYPE_TO_NP_TYPE[TensorProto.INT64] # pylint: disable=E1101 - raise RuntimeError( - "Unable to guess type from %r." % obj0) - if hasattr(obj, 'type'): - obj = obj.type - if hasattr(obj, 'tensor_type'): - obj = obj.tensor_type - if hasattr(obj, 'elem_type'): - return TENSOR_TYPE_TO_NP_TYPE[obj.elem_type] - raise RuntimeError( - "Unable to guess type from %r." 
% obj0) - - def get_shape(obj): - obj0 = obj - if hasattr(obj, 'data_type'): - if (obj.data_type == TensorProto.FLOAT and # pylint: disable=E1101 - hasattr(obj, 'float_data')): - return (len(obj.float_data), ) - if (obj.data_type == TensorProto.DOUBLE and # pylint: disable=E1101 - hasattr(obj, 'double_data')): - return (len(obj.double_data), ) - if (obj.data_type == TensorProto.INT64 and # pylint: disable=E1101 - hasattr(obj, 'int64_data')): - return (len(obj.int64_data), ) - raise RuntimeError( - "Unable to guess type from %r." % obj0) - if hasattr(obj, 'type'): - obj = obj.type - if hasattr(obj, 'tensor_type'): - obj = obj.tensor_type - if hasattr(obj, 'shape'): - obj = obj.shape - dims = [] - for d in obj.dim: - if hasattr(d, 'dim_value'): - dims.append(d.dim_value) - else: - dims.append(None) - return tuple(dims) - raise RuntimeError( - "Unable to guess type from %r." % obj0) - def str_node(indent, node): atts = [] if hasattr(node, 'attribute'): @@ -475,11 +477,11 @@ def str_node(indent, node): # inputs for inp in model.input: rows.append("input: name=%r type=%r shape=%r" % ( - inp.name, get_type(inp), get_shape(inp))) + inp.name, _get_type(inp), _get_shape(inp))) # initializer for init in model.initializer: rows.append("init: name=%r type=%r shape=%r" % ( - init.name, get_type(init), get_shape(init))) + init.name, _get_type(init), _get_shape(init))) # successors, predecessors successors = {} @@ -558,5 +560,56 @@ def str_node(indent, node): # outputs for out in model.output: rows.append("output: name=%r type=%r shape=%r" % ( - out.name, get_type(out), get_shape(out))) + out.name, _get_type(out), _get_shape(out))) + return "\n".join(rows) + + +def onnx_text_plot_io(model, verbose=False, att_display=None): + """ + Displays information about input and output types. + + :param model: ONNX graph + :param verbose: display debugging information + :return: str + + An ONNX graph is printed the following way: + + .. 
runpython:: + :showcode: + :warningout: DeprecationWarning + + import numpy + from sklearn.cluster import KMeans + from mlprodict.plotting.plotting import onnx_text_plot_io + from mlprodict.onnx_conv import to_onnx + + x = numpy.random.randn(10, 3) + y = numpy.random.randn(10) + model = KMeans(3) + model.fit(x, y) + onx = to_onnx(model, x.astype(numpy.float32), + target_opset=15) + text = onnx_text_plot_io(onx, verbose=False) + print(text) + """ + rows = [] + if hasattr(model, 'opset_import'): + for opset in model.opset_import: + rows.append("opset: domain=%r version=%r" % ( + opset.domain, opset.version)) + if hasattr(model, 'graph'): + model = model.graph + + # inputs + for inp in model.input: + rows.append("input: name=%r type=%r shape=%r" % ( + inp.name, _get_type(inp), _get_shape(inp))) + # initializer + for init in model.initializer: + rows.append("init: name=%r type=%r shape=%r" % ( + init.name, _get_type(init), _get_shape(init))) + # outputs + for out in model.output: + rows.append("output: name=%r type=%r shape=%r" % ( + out.name, _get_type(out), _get_shape(out))) return "\n".join(rows)