Merged
1 change: 1 addition & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
0.2.0
+++++

* :pr:`71`: adds tools to compare two onnx graphs
* :pr:`61`: adds function to plot onnx model as graphs
* :pr:`60`: supports translation of local functions
* :pr:`59`: add methods to update nodes in GraphAPI
30 changes: 30 additions & 0 deletions _doc/api/reference.rst
@@ -5,3 +5,33 @@ ExtendedReferenceEvaluator
++++++++++++++++++++++++++

.. autoclass:: onnx_array_api.reference.ExtendedReferenceEvaluator
    :members:

ResultType
++++++++++

.. autoclass:: onnx_array_api.reference.ResultType
    :members:

ResultExecution
+++++++++++++++

.. autoclass:: onnx_array_api.reference.ResultExecution
    :members:

YieldEvaluator
++++++++++++++

.. autoclass:: onnx_array_api.reference.YieldEvaluator
    :members:

DistanceExecution
+++++++++++++++++

.. autoclass:: onnx_array_api.reference.DistanceExecution
    :members:

compare_onnx_execution
++++++++++++++++++++++

.. autofunction:: onnx_array_api.reference.compare_onnx_execution
52 changes: 52 additions & 0 deletions _doc/command_lines.rst
@@ -0,0 +1,52 @@
=============
command lines
=============

compare
=======

The command compares the execution of two onnx models.

::

    python -m compare -m1 model1.onnx -m2 model2.onnx -v 1

Output example::

    [compare_onnx_execution] got 2 inputs
    [compare_onnx_execution] execute first model
    [compare_onnx_execution] got 5 results
    [compare_onnx_execution] execute second model
    [compare_onnx_execution] got 5 results
    [compare_onnx_execution] compute edit distance
    [compare_onnx_execution] got 4 pairs
    [compare_onnx_execution] done
    = | INPUT float32 5x6 AAAA X | INPUT float32 5x6 AAAA X
    = | INPUT float32 5x6 AAAA Y | INPUT float32 5x6 AAAA Y
    = | RESULT float32 5x6 AABB Add res | RESULT float32 5x6 AABB Add res
    = | RESULT float32 5x6 AAAA Cos Z | RESULT float32 5x6 AAAA Cos Z

.. runpython::

    from onnx_array_api._command_lines_parser import get_parser_compare
    get_parser_compare().print_help()

See function :func:`onnx_array_api.reference.compare_onnx_execution`.
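
The same comparison can also be run directly from Python. Below is a minimal
sketch, assuming the two models to compare are stored in ``model1.onnx`` and
``model2.onnx`` (placeholder names)::

    import onnx
    from onnx_array_api.reference import compare_onnx_execution

    # load both models (the file names are placeholders)
    model1 = onnx.load("model1.onnx")
    model2 = onnx.load("model2.onnx")

    # res1, res2 hold the intermediate results of each model,
    # align is the alignment found by the edit distance,
    # dc is the object holding the distance computation
    res1, res2, align, dc = compare_onnx_execution(model1, model2, verbose=1)
    print(dc.to_str(res1, res2, align))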

translate
=========

The command converts an onnx file into some code.

::

    python -m translate ...

Output example::

    not yet ready

.. runpython::

    from onnx_array_api._command_lines_parser import get_parser_translate
    get_parser_translate().print_help()
68 changes: 68 additions & 0 deletions _doc/examples/plot_onnx_diff.py
@@ -0,0 +1,68 @@
"""

.. _l-onnx-diff-example:

Compares the conversions of the same model with different options
=================================================================

The script compares two onnx models obtained from the same trained
scikit-learn model but converted with different options.

A model
+++++++
"""

from sklearn.mixture import GaussianMixture
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from skl2onnx import to_onnx
from onnx_array_api.reference import compare_onnx_execution
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot


data = load_iris()
X_train, X_test = train_test_split(data.data)
model = GaussianMixture()
model.fit(X_train)

#################################
# Conversion to onnx
# ++++++++++++++++++

onx = to_onnx(
    model, X_train[:1], options={id(model): {"score_samples": True}}, target_opset=12
)

print(onnx_simple_text_plot(onx))

##################################
# Conversion to onnx without ReduceLogSumExp
# ++++++++++++++++++++++++++++++++++++++++++

onx2 = to_onnx(
    model,
    X_train[:1],
    options={id(model): {"score_samples": True}},
    black_op={"ReduceLogSumExp"},
    target_opset=12,
)

print(onnx_simple_text_plot(onx2))


#############################################
# Differences
# +++++++++++
#
# Function :func:`onnx_array_api.reference.compare_onnx_execution`
# compares the intermediate results of two onnx models. Then it finds
# the best alignment between the two models using an edit distance.

res1, res2, align, dc = compare_onnx_execution(onx, onx2, verbose=1)
print("------------")
text = dc.to_str(res1, res2, align)
print(text)

###############################
# The display shows that ReduceSumSquare was replaced by Mul + ReduceSum,
# and ReduceLogSumExp by ReduceMax + Sub + Exp + Log + Add.
1 change: 1 addition & 0 deletions _doc/index.rst
@@ -36,6 +36,7 @@ The objective is to speed up the implementation of converter libraries.
tutorial/index
api/index
tech/index
command_lines
auto_examples/index

.. toctree::
1 change: 1 addition & 0 deletions _doc/tutorial/index.rst
@@ -10,4 +10,5 @@ Tutorial
graph_api
light_api
numpy_api
tools
benchmarks
20 changes: 20 additions & 0 deletions _doc/tutorial/tools.rst
@@ -0,0 +1,20 @@
=====
Tools
=====

Some useful tools.

Text representation
===================

Plotting a graph is convenient but the result is difficult to read
when the graph is big, and rendering it is slow.
:func:`onnx_array_api.plotting.text_plot.onnx_simple_text_plot`
prints out a text representation.
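
A minimal sketch, assuming a small model built with ``onnx.helper``
(every name below is only illustrative)::

    from onnx import TensorProto
    from onnx.helper import (
        make_graph,
        make_model,
        make_node,
        make_tensor_value_info,
        make_opsetid,
    )
    from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

    # a tiny graph computing Z = X + Y
    model = make_model(
        make_graph(
            [make_node("Add", ["X", "Y"], ["Z"])],
            "example",
            [
                make_tensor_value_info("X", TensorProto.FLOAT, None),
                make_tensor_value_info("Y", TensorProto.FLOAT, None),
            ],
            [make_tensor_value_info("Z", TensorProto.FLOAT, None)],
        ),
        opset_imports=[make_opsetid("", 18)],
    )
    print(onnx_simple_text_plot(model))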

Differences between two models
==============================

How can we understand the differences between two models,
assuming they produce the same outputs?
Example :ref:`l-onnx-diff-example` shows how to do it.
26 changes: 25 additions & 1 deletion _unittests/ut_reference/test_array_tensor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import unittest
import numpy as np
from onnx import TensorProto
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
from onnx.helper import (
    make_graph,
    make_model,
    make_node,
    make_tensor_value_info,
    make_opsetid,
)
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.reference import (
    to_array_extended,
@@ -51,6 +57,24 @@ def make_model_f8(fr, to):
back = from_array_extended(got, "a")
self.assertEqual(to, back.data_type)

    def test_fused_matmul(self):
        # FusedMatMul belongs to the com.microsoft domain, hence the extra opset import.
        model = make_model(
            make_graph(
                [make_node("FusedMatMul", ["X", "Y"], ["Z"], domain="com.microsoft")],
                "name",
                [
                    make_tensor_value_info("X", TensorProto.FLOAT, None),
                    make_tensor_value_info("Y", TensorProto.FLOAT, None),
                ],
                [make_tensor_value_info("Z", TensorProto.FLOAT, None)],
            ),
            opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
        )
        ref = ExtendedReferenceEvaluator(model)
        a = np.arange(4).reshape(-1, 2)
        got = ref.run(None, {"X": a, "Y": a})
        self.assertEqualArray(a @ a, got[0])


if __name__ == "__main__":
    unittest.main(verbosity=2)