Skip to content
This repository was archived by the owner on Jan 13, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion _doc/examples/plot_op_reducesumsquare.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,6 @@ def torch_sum2(x, y):
df.pivot("fct", "N", "average")



###################################
# Reduction on a particular case RKRK
# +++++++++++++++++++++++++++++++++++
Expand Down
10 changes: 10 additions & 0 deletions _unittests/ut_plotting/test_text_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,16 @@ def test_onnx_simple_text_plot_if(self):
text2 = oinf.to_text(kind="seq")
self.assertEqual(text, text2)

def test_onnx_simple_text_plot_kmeans_links(self):
x = numpy.random.randn(10, 3)
model = KMeans(3)
model.fit(x)
onx = to_onnx(model, x.astype(numpy.float32),
target_opset=15)
text = onnx_simple_text_plot(onx, add_links=True)
self.assertIn("Sqrt(Ad_C0) -> scores <------", text)
self.assertIn("|-|", text)


if __name__ == "__main__":
unittest.main()
86 changes: 85 additions & 1 deletion mlprodict/plotting/text_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,14 +383,16 @@ def _get_shape(obj):
"Unable to guess type from %r." % obj0)


def onnx_simple_text_plot(model, verbose=False, att_display=None):
def onnx_simple_text_plot(model, verbose=False, att_display=None,
add_links=False):
"""
Displays an ONNX graph into text.

:param model: ONNX graph
:param verbose: display debugging information
:param att_display: list of attributes to display, if None,
a default list if used
:param add_links: displays links of the right side
:return: str

An ONNX graph is printed the following way:
Expand All @@ -413,6 +415,26 @@ def onnx_simple_text_plot(model, verbose=False, att_display=None):
text = onnx_simple_text_plot(onx, verbose=False)
print(text)

The same graphs with links.

.. runpython::
:showcode:
:warningout: DeprecationWarning

import numpy
from sklearn.cluster import KMeans
from mlprodict.plotting.plotting import onnx_simple_text_plot
from mlprodict.onnx_conv import to_onnx

x = numpy.random.randn(10, 3)
y = numpy.random.randn(10)
model = KMeans(3)
model.fit(x, y)
onx = to_onnx(model, x.astype(numpy.float32),
target_opset=15)
text = onnx_simple_text_plot(onx, verbose=Falsen add_links=True)
print(text)

Visually, it looks like the following:

.. gdot::
Expand Down Expand Up @@ -483,7 +505,10 @@ def str_node(indent, node):
model = model.graph

# inputs
line_name_new = {}
line_name_in = {}
for inp in model.input:
line_name_new[inp.name] = len(rows)
rows.append("input: name=%r type=%r shape=%r" % (
inp.name, _get_type(inp), _get_shape(inp)))
# initializer
Expand All @@ -492,6 +517,7 @@ def str_node(indent, node):
content = " -- %r" % to_array(init).ravel()
else:
content = ""
line_name_new[init.name] = len(rows)
rows.append("init: name=%r type=%r shape=%r%s" % (
init.name, _get_type(init), _get_shape(init), content))

Expand Down Expand Up @@ -560,6 +586,13 @@ def str_node(indent, node):

if add_break and verbose:
print("[onnx_simple_text_plot] add break")
for n in node.input:
if n in line_name_in:
line_name_in[n].append(len(rows))
else:
line_name_in[n] = [len(rows)]
for n in node.output:
line_name_new[n] = len(rows)
rows.append(str_node(indent, node))
indents[name] = indent

Expand All @@ -572,8 +605,59 @@ def str_node(indent, node):

# outputs
for out in model.output:
if out.name in line_name_in:
line_name_in[out.name].append(len(rows))
else:
line_name_in[out.name] = [len(rows)]
rows.append("output: name=%r type=%r shape=%r" % (
out.name, _get_type(out), _get_shape(out)))

if add_links:

def _mark_link(rows, lengths, r1, r2, d):
maxl = max(lengths[r1], lengths[r2]) + d * 2
maxl = max(maxl, max(len(rows[r]) for r in range(r1, r2 + 1))) + 2

if rows[r1][-1] == '|':
p1, p2 = rows[r1][:lengths[r1] + 2], rows[r1][lengths[r1] + 2:]
rows[r1] = p1 + p2.replace(' ', '-')
rows[r1] += ("-" * (maxl - len(rows[r1]) - 1)) + "+"

if rows[r2][-1] == " ":
rows[r2] += "<"
elif rows[r2][-1] == '|':
if "<" not in rows[r2]:
p = lengths[r2]
rows[r2] = rows[r2][:p] + '<' + rows[r2][p + 1:]
p1, p2 = rows[r2][:lengths[r2] + 2], rows[r2][lengths[r2] + 2:]
rows[r2] = p1 + p2.replace(' ', '-')
rows[r2] += ("-" * (maxl - len(rows[r2]) - 1)) + "+"

for r in range(r1 + 1, r2):
if len(rows[r]) < maxl:
rows[r] += " " * (maxl - len(rows[r]) - 1)
rows[r] += "|"

diffs = []
for n, r1 in line_name_new.items():
if n not in line_name_in:
continue
r2s = line_name_in[n]
for r2 in r2s:
if r1 >= r2:
continue
diffs.append((r2 - r1, (n, r1, r2)))
diffs.sort()
for i in range(len(rows)): # pylint: disable=C0200
rows[i] += " "
lengths = [len(r) for r in rows]

for d, (n, r1, r2) in diffs:
if d == 1 and len(line_name_in[n]) == 1:
# no line for link to the next node
continue
_mark_link(rows, lengths, r1, r2, d)

return "\n".join(rows)


Expand Down