Skip to content
This repository was archived by the owner on Jan 13, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion _unittests/ut_cli/test_cli_onnx_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.exceptions import ConvergenceWarning
from pyquickhelper.loghelper import BufferedPrint
from pyquickhelper.pycode import ExtTestCase, get_temp_folder
from pyquickhelper.pycode import (
ExtTestCase, get_temp_folder, ignore_warnings)
from mlprodict.__main__ import main
from mlprodict.cli import convert_validate

Expand All @@ -22,6 +24,7 @@ def test_cli_onnx_stats(self):
res = str(st)
self.assertIn("optim", res)

@ignore_warnings(ConvergenceWarning)
def test_onnx_stats(self):
iris = load_iris()
X, y = iris.data, iris.target
Expand All @@ -47,6 +50,24 @@ def test_onnx_stats(self):
res = str(st)
self.assertIn("ninits: 0", res)

st = BufferedPrint()
main(args=["onnx_stats", "--name", outonnx, '--kind', 'io'],
fLOG=st.fprint)
res = str(st)
self.assertIn("input: name", res)

st = BufferedPrint()
main(args=["onnx_stats", "--name", outonnx, '--kind', 'text'],
fLOG=st.fprint)
res = str(st)
self.assertIn("input: name", res)

st = BufferedPrint()
main(args=["onnx_stats", "--name", outonnx, '--kind', 'node'],
fLOG=st.fprint)
res = str(st)
self.assertIn("op_", res)


if __name__ == "__main__":
unittest.main()
15 changes: 14 additions & 1 deletion _unittests/ut_plotting/test_text_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from mlprodict.onnx_conv import to_onnx
from mlprodict.tools.asv_options_helper import get_opset_number_from_onnx
from mlprodict.plotting.plotting import (
onnx_text_plot, onnx_text_plot_tree, onnx_simple_text_plot)
onnx_text_plot, onnx_text_plot_tree, onnx_simple_text_plot,
onnx_text_plot_io)


class TestPlotTextPlotting(ExtTestCase):
Expand Down Expand Up @@ -133,6 +134,18 @@ def test_onnx_simple_text_plot_leaky(self):
""").strip(" \n")
self.assertIn(expected, text)

def test_onnx_text_plot_io(self):
x = OnnxLeakyRelu('X', alpha=0.5, op_version=15,
output_names=['Y'])
onx = x.to_onnx({'X': FloatTensorType()},
outputs={'Y': FloatTensorType()},
target_opset=15)
text = onnx_text_plot_io(onx)
expected = textwrap.dedent("""
input:
""").strip(" \n")
self.assertIn(expected, text)


if __name__ == "__main__":
unittest.main()
4 changes: 4 additions & 0 deletions _unittests/ut_tools/test_optim_onnx_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def test_onnx_remove_identities(self):
self.assertIn('subgraphs', stats)
self.assertGreater(stats['subgraphs'], 1)
self.assertGreater(stats['op_Identity'], 2)
stats = onnx_statistics(model_def, optim=False, node_type=True)
self.assertIn('subgraphs', stats)
self.assertGreater(stats['subgraphs'], 1)
self.assertGreater(stats['op_Identity'], 2)

new_model = onnx_remove_node_identity(model_def)
stats2 = onnx_statistics(new_model, optim=False)
Expand Down
22 changes: 19 additions & 3 deletions mlprodict/cli/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@
import onnx


def onnx_stats(name, optim=False):
def onnx_stats(name, optim=False, kind=None):
"""
Computes statistics on an ONNX model.

:param name: filename
:param optim: computes statistics before an after optimisation was done
:param kind: kind of statistics, if left unknown,
prints out the metadata, possible values:
* `io`: prints input and output name, type, shapes
* `node`: prints the distribution of node types
* `text`: printts a text summary

.. cmdref::
:title: Computes statistics on an ONNX graph
Expand All @@ -20,13 +25,24 @@ def onnx_stats(name, optim=False):

The command computes statistics on an ONNX model.
"""
from ..onnx_tools.optim import onnx_statistics
if not os.path.exists(name):
raise FileNotFoundError( # pragma: no cover
"Unable to find file '{}'.".format(name))
with open(name, 'rb') as f:
model = onnx.load(f)
return onnx_statistics(model, optim=optim)
if kind in (None, ""):
from ..onnx_tools.optim import onnx_statistics
return onnx_statistics(model, optim=optim)
if kind == 'text':
from ..plotting.plotting import onnx_simple_text_plot
return onnx_simple_text_plot(model)
if kind == 'io':
from ..plotting.plotting import onnx_text_plot_io
return onnx_text_plot_io(model)
if kind == 'node':
from ..onnx_tools.optim import onnx_statistics
return onnx_statistics(model, optim=optim, node_type=True)
raise ValueError("Unexpected kind=%r." % kind)


def onnx_optim(name, outfile=None, recursive=True, options=None, verbose=0, fLOG=None):
Expand Down
24 changes: 15 additions & 9 deletions mlprodict/onnx_tools/optim/onnx_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,17 @@
from .onnx_optimisation import onnx_remove_node


def onnx_statistics(onnx_model, recursive=True, optim=True):
def onnx_statistics(onnx_model, recursive=True, optim=True, node_type=False):
"""
Computes statistics on :epkg:`ONNX` models,
extracts informations about the model such as
the number of nodes.

@param onnx_model onnx model
@param recursive looks into subgraphs
@param optim adds statistics because of optimisation
@return dictionary
:param onnx_model: onnx model
:param recursive: looks into subgraphs
:param optim: adds statistics because of optimisation
:param node_type: add distribution of node types
:return: dictionary

.. runpython::
:showcode:
Expand Down Expand Up @@ -98,9 +99,13 @@ def update(sts, st):

# Number of identities
counts = Counter(map(lambda obj: obj.op_type, graph.node))
for op in ['Cast', 'Identity', 'ZipMap', 'Reshape']:
if op in counts:
stats['op_' + op] = counts[op]
if node_type:
for op, v in counts.items():
stats['op_' + op] = v
else:
for op in ['Cast', 'Identity', 'ZipMap', 'Reshape']:
if op in counts:
stats['op_' + op] = counts[op]

# Recursive
if recursive:
Expand All @@ -110,7 +115,8 @@ def update(sts, st):
for att in node.attribute:
if att.name != 'body':
continue
substats = onnx_statistics(att.g, recursive=True, optim=False)
substats = onnx_statistics(
att.g, recursive=True, optim=False, node_type=node_type)
update(stats, {'subgraphs': 1})
update(stats, substats)

Expand Down
4 changes: 3 additions & 1 deletion mlprodict/plotting/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
@brief Shorcuts to plotting functions.
"""

from .text_plot import onnx_text_plot, onnx_text_plot_tree, onnx_simple_text_plot
from .text_plot import (
onnx_text_plot, onnx_text_plot_tree, onnx_simple_text_plot,
onnx_text_plot_io)
from .plotting_benchmark import plot_benchmark_metrics
from .plotting_validate_graph import plot_validate_benchmark
from .plotting_onnx import plot_onnx
165 changes: 109 additions & 56 deletions mlprodict/plotting/text_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,61 @@ def _find_sequence(node_name, known, done):
return new_nodes


def _get_type(obj0):
obj = obj0
if hasattr(obj, 'data_type'):
if (obj.data_type == TensorProto.FLOAT and # pylint: disable=E1101
hasattr(obj, 'float_data')):
return TENSOR_TYPE_TO_NP_TYPE[TensorProto.FLOAT] # pylint: disable=E1101
if (obj.data_type == TensorProto.DOUBLE and # pylint: disable=E1101
hasattr(obj, 'double_data')):
return TENSOR_TYPE_TO_NP_TYPE[TensorProto.DOUBLE] # pylint: disable=E1101
if (obj.data_type == TensorProto.INT64 and # pylint: disable=E1101
hasattr(obj, 'int64_data')):
return TENSOR_TYPE_TO_NP_TYPE[TensorProto.INT64] # pylint: disable=E1101
raise RuntimeError(
"Unable to guess type from %r." % obj0)
if hasattr(obj, 'type'):
obj = obj.type
if hasattr(obj, 'tensor_type'):
obj = obj.tensor_type
if hasattr(obj, 'elem_type'):
return TENSOR_TYPE_TO_NP_TYPE.get(obj.elem_type, '?')
raise RuntimeError(
"Unable to guess type from %r." % obj0)


def _get_shape(obj):
obj0 = obj
if hasattr(obj, 'data_type'):
if (obj.data_type == TensorProto.FLOAT and # pylint: disable=E1101
hasattr(obj, 'float_data')):
return (len(obj.float_data), )
if (obj.data_type == TensorProto.DOUBLE and # pylint: disable=E1101
hasattr(obj, 'double_data')):
return (len(obj.double_data), )
if (obj.data_type == TensorProto.INT64 and # pylint: disable=E1101
hasattr(obj, 'int64_data')):
return (len(obj.int64_data), )
raise RuntimeError(
"Unable to guess type from %r." % obj0)
if hasattr(obj, 'type'):
obj = obj.type
if hasattr(obj, 'tensor_type'):
obj = obj.tensor_type
if hasattr(obj, 'shape'):
obj = obj.shape
dims = []
for d in obj.dim:
if hasattr(d, 'dim_value'):
dims.append(d.dim_value)
else:
dims.append(None)
return tuple(dims)
raise RuntimeError(
"Unable to guess type from %r." % obj0)


def onnx_simple_text_plot(model, verbose=False, att_display=None):
"""
Displays an ONNX graph into text.
Expand Down Expand Up @@ -392,59 +447,6 @@ def onnx_simple_text_plot(model, verbose=False, att_display=None):
'transB',
]

def get_type(obj0):
obj = obj0
if hasattr(obj, 'data_type'):
if (obj.data_type == TensorProto.FLOAT and # pylint: disable=E1101
hasattr(obj, 'float_data')):
return TENSOR_TYPE_TO_NP_TYPE[TensorProto.FLOAT] # pylint: disable=E1101
if (obj.data_type == TensorProto.DOUBLE and # pylint: disable=E1101
hasattr(obj, 'double_data')):
return TENSOR_TYPE_TO_NP_TYPE[TensorProto.DOUBLE] # pylint: disable=E1101
if (obj.data_type == TensorProto.INT64 and # pylint: disable=E1101
hasattr(obj, 'int64_data')):
return TENSOR_TYPE_TO_NP_TYPE[TensorProto.INT64] # pylint: disable=E1101
raise RuntimeError(
"Unable to guess type from %r." % obj0)
if hasattr(obj, 'type'):
obj = obj.type
if hasattr(obj, 'tensor_type'):
obj = obj.tensor_type
if hasattr(obj, 'elem_type'):
return TENSOR_TYPE_TO_NP_TYPE[obj.elem_type]
raise RuntimeError(
"Unable to guess type from %r." % obj0)

def get_shape(obj):
obj0 = obj
if hasattr(obj, 'data_type'):
if (obj.data_type == TensorProto.FLOAT and # pylint: disable=E1101
hasattr(obj, 'float_data')):
return (len(obj.float_data), )
if (obj.data_type == TensorProto.DOUBLE and # pylint: disable=E1101
hasattr(obj, 'double_data')):
return (len(obj.double_data), )
if (obj.data_type == TensorProto.INT64 and # pylint: disable=E1101
hasattr(obj, 'int64_data')):
return (len(obj.int64_data), )
raise RuntimeError(
"Unable to guess type from %r." % obj0)
if hasattr(obj, 'type'):
obj = obj.type
if hasattr(obj, 'tensor_type'):
obj = obj.tensor_type
if hasattr(obj, 'shape'):
obj = obj.shape
dims = []
for d in obj.dim:
if hasattr(d, 'dim_value'):
dims.append(d.dim_value)
else:
dims.append(None)
return tuple(dims)
raise RuntimeError(
"Unable to guess type from %r." % obj0)

def str_node(indent, node):
atts = []
if hasattr(node, 'attribute'):
Expand Down Expand Up @@ -475,11 +477,11 @@ def str_node(indent, node):
# inputs
for inp in model.input:
rows.append("input: name=%r type=%r shape=%r" % (
inp.name, get_type(inp), get_shape(inp)))
inp.name, _get_type(inp), _get_shape(inp)))
# initializer
for init in model.initializer:
rows.append("init: name=%r type=%r shape=%r" % (
init.name, get_type(init), get_shape(init)))
init.name, _get_type(init), _get_shape(init)))

# successors, predecessors
successors = {}
Expand Down Expand Up @@ -558,5 +560,56 @@ def str_node(indent, node):
# outputs
for out in model.output:
rows.append("output: name=%r type=%r shape=%r" % (
out.name, get_type(out), get_shape(out)))
out.name, _get_type(out), _get_shape(out)))
return "\n".join(rows)


def onnx_text_plot_io(model, verbose=False, att_display=None):
"""
Displays information about input and output types.

:param model: ONNX graph
:param verbose: display debugging information
:return: str

An ONNX graph is printed the following way:

.. runpython::
:showcode:
:warningout: DeprecationWarning

import numpy
from sklearn.cluster import KMeans
from mlprodict.plotting.plotting import onnx_text_plot_io
from mlprodict.onnx_conv import to_onnx

x = numpy.random.randn(10, 3)
y = numpy.random.randn(10)
model = KMeans(3)
model.fit(x, y)
onx = to_onnx(model, x.astype(numpy.float32),
target_opset=15)
text = onnx_text_plot_io(onx, verbose=False)
print(text)
"""
rows = []
if hasattr(model, 'opset_import'):
for opset in model.opset_import:
rows.append("opset: domain=%r version=%r" % (
opset.domain, opset.version))
if hasattr(model, 'graph'):
model = model.graph

# inputs
for inp in model.input:
rows.append("input: name=%r type=%r shape=%r" % (
inp.name, _get_type(inp), _get_shape(inp)))
# initializer
for init in model.initializer:
rows.append("init: name=%r type=%r shape=%r" % (
init.name, _get_type(init), _get_shape(init)))
# outputs
for out in model.output:
rows.append("output: name=%r type=%r shape=%r" % (
out.name, _get_type(out), _get_shape(out)))
return "\n".join(rows)