Skip to content
This repository was archived by the owner on Jan 13, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions _unittests/ut_onnxrt/test_onnxrt_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,24 @@ def test_onnxt_dot(self):
self.assertIn('Ad_Addcst -> Ad_Add;', dot)
self.assertIn('Ad_Add1 -> Y;', dot)

def test_onnxt_text(self):
idi = numpy.identity(2)
idi2 = numpy.identity(2) * 2
onx = OnnxAdd(
OnnxAdd('X', idi, op_version=get_opset_number_from_onnx()),
idi2, output_names=['Y'],
op_version=get_opset_number_from_onnx())
model_def = onx.to_onnx({'X': idi.astype(numpy.float32)},
target_opset=get_opset_number_from_onnx())
oinf = OnnxInference(model_def)
text = oinf.to_text()
self.assertIn('Init', text)
self.assertIn('Input-0', text)
self.assertIn('Output-0', text)
self.assertIn('inout', text)
self.assertIn('O0 I0', text)
self.assertIn('Ad_Addcst', text)

def test_onnxt_dot_onnx(self):
idi = numpy.identity(2)
idi2 = numpy.identity(2) * 2
Expand Down
91 changes: 91 additions & 0 deletions _unittests/ut_tools/test_graphs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-
"""
@brief test log(time=3s)
"""
import unittest
import numpy
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from pyquickhelper.pycode import ExtTestCase
from skl2onnx.algebra.onnx_ops import OnnxAdd, OnnxSub # pylint: disable=E0611
from mlprodict.onnx_conv import to_onnx
from mlprodict.tools import get_opset_number_from_onnx
from mlprodict.tools.graphs import onnx2bigraph, BiGraph


class TestGraphs(ExtTestCase):

def fit(self, model):
data = load_iris()
X, y = data.data, data.target
model.fit(X, y)
return model

def test_exc(self):
self.assertRaise(lambda: BiGraph([], [], []), TypeError)
self.assertRaise(lambda: BiGraph({}, [], []), TypeError)
self.assertRaise(lambda: BiGraph({}, {}, []), TypeError)
self.assertRaise(
lambda: BiGraph({'a': None}, {'b': None}, {('a', 'a'): None}),
ValueError)
self.assertRaise(
lambda: BiGraph({'a': None}, {'a': None}, {('a', 'a'): None}),
ValueError)

def test_pipe_graph(self):
model = self.fit(
make_pipeline(StandardScaler(), LogisticRegression()))
onx = to_onnx(model, numpy.zeros((3, 4), dtype=numpy.float64))
bigraph = onnx2bigraph(onx)
text = str(bigraph)
self.assertEqual(text, "BiGraph(19 v., 12 v., 30 edges)")
obj = list(bigraph)
self.assertEqual(len(obj), 61)
for o in obj:
self.assertEqual(len(o), 3)
self.assertIn(o[0], {-1, 0, 1})
self.assertIsInstance(o[1], (str, tuple))
self.assertStartsWith("A(", str(o[-1]))

def test_pipe_graph_order(self):
model = self.fit(
make_pipeline(StandardScaler(), LogisticRegression()))
onx = to_onnx(model, numpy.zeros((3, 4), dtype=numpy.float64))
bigraph = onnx2bigraph(onx)
order = bigraph.order_vertices()
self.assertEqual(len(order), 31)
self.assertIsInstance(order, dict)
for k in order:
self.assertIsInstance(bigraph[k], BiGraph.A)
ed = list(bigraph.edges)[0]
self.assertIsInstance(bigraph[ed], BiGraph.A)

def test_pipe_graph_display(self):
model = self.fit(
make_pipeline(StandardScaler(), LogisticRegression()))
onx = to_onnx(model, numpy.zeros((3, 4), dtype=numpy.float64))
bigraph = onnx2bigraph(onx)
graph = bigraph.display_structure()
text = str(graph)
self.assertIn("AdjacencyGraphDisplay(", text)
self.assertIn("Action(", text)

def test_pipe_graph_display_text(self):
idi = numpy.identity(2)
opv = get_opset_number_from_onnx()
A = OnnxAdd('X', idi, op_version=opv)
B = OnnxSub(A, 'W', output_names=['Y'], op_version=opv)
onx = B.to_onnx({'X': idi.astype(numpy.float32),
'W': idi.astype(numpy.float32)})
bigraph = onnx2bigraph(onx)
graph = bigraph.display_structure()
text = graph.to_text()
for c in ['Input-1', 'Input-0', 'Output-0', 'W', 'W', 'I0', 'I1',
'inout', 'O0 I0', 'A S']:
self.assertIn(c, text)


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions mlprodict/onnxrt/onnx_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def _init(self):
self.to_json = self.exporters_.to_json
self.to_dot = self.exporters_.to_dot
self.to_python = self.exporters_.to_python
self.to_text = self.exporters_.to_text

if self.runtime in ('python_compiled', 'python_compiled_debug'):
# switch the inference method to the compiled one
Expand Down
13 changes: 13 additions & 0 deletions mlprodict/onnxrt/onnx_inference_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import textwrap
from onnx import numpy_helper
from ..onnx_tools.onnx2py_helper import _var_as_dict, _type_to_string
from ..tools.graphs import onnx2bigraph


class OnnxInferenceExport:
Expand Down Expand Up @@ -583,3 +584,15 @@ def clean_args(args):
raise NotImplementedError( # pragma: no cover
"Unknown extension for file '{}'.".format(k))
return file_data

def to_text(self, recursive=False):
"""
It calls function @see fn onnx2bigraph to return
the ONNX graph as text.

:param recursive: dig into subgraphs too
:return: text
"""
bigraph = onnx2bigraph(self.oinf.obj, recursive=recursive)
graph = bigraph.display_structure()
return graph.to_text()
Loading