Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
39 changes: 37 additions & 2 deletions _unittests/ut_plotting/test_text_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_onnx_text_plot_tree_reg(self):
onx = to_onnx(clr, X)
res = onnx_text_plot_tree(onx.graph.node[0])
self.assertIn("treeid=0", res)
self.assertIn(" T y=", res)
self.assertIn(" +f", res)

def test_onnx_text_plot_tree_cls(self):
iris = load_iris()
Expand All @@ -62,9 +62,44 @@ def test_onnx_text_plot_tree_cls(self):
onx = to_onnx(clr, X)
res = onnx_text_plot_tree(onx.graph.node[0])
self.assertIn("treeid=0", res)
self.assertIn(" T y=", res)
self.assertIn(" +f 0:", res)
self.assertIn("n_classes=3", res)

def test_onnx_text_plot_tree_cls_2(self):
this = os.path.join(
os.path.dirname(__file__), "data", "onnx_text_plot_tree_cls_2.onnx"
)
with open(this, "rb") as f:
model_def = load(f)
res = onnx_text_plot_tree(model_def.graph.node[0])
self.assertIn("n_classes=3", res)
expected = textwrap.dedent(
"""
n_classes=3
n_trees=1
----
treeid=0
n X2 <= 2.4499998
-n X3 <= 1.75
-n X2 <= 4.85
-f 0:0 1:0 2:1
+n X0 <= 5.95
-f 0:0 1:0 2:1
+f 0:0 1:1 2:0
+n X2 <= 4.95
-n X3 <= 1.55
-n X0 <= 6.95
-f 0:0 1:0 2:1
+f 0:0 1:1 2:0
+f 0:0 1:0 2:1
+n X3 <= 1.65
-f 0:0 1:0 2:1
+f 0:0 1:1 2:0
+f 0:1 1:0 2:0
"""
).strip(" \n\r")
self.assertEqual(expected, res.strip(" \n\r"))

@ignore_warnings((UserWarning, FutureWarning))
def test_onnx_simple_text_plot_kmeans(self):
x = numpy.random.randn(10, 3)
Expand Down
46 changes: 34 additions & 12 deletions onnx_array_api/plotting/text_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ def _rule(r):
raise ValueError(f"Unexpected rule {r!r}.")


def _number2str(i):
if isinstance(i, int):
return str(i)
if int(i) == i:
return str(int(i))
return f"{i:1.2f}"


def onnx_text_plot_tree(node):
"""
Gives a textual representation of a tree ensemble.
Expand Down Expand Up @@ -61,18 +69,32 @@ def __init__(self, i, atts):
setattr(self, k, v[i])
self.depth = 0
self.true_false = ""
self.targets = []

def append_target(self, tid, weight):
self.targets.append(dict(target_id=tid, weight=weight))

def process_node(self):
"node to string"
if self.nodes_modes == "LEAF":
text = "%s y=%r f=%r i=%r" % (
self.true_false,
self.target_weights,
self.target_ids,
self.target_nodeids,
)
if len(self.targets) == 0:
text = f"{self.true_false}f"
elif len(self.targets) == 1:
t = self.targets[0]
text = (
f"{self.true_false}f "
f"{t['target_id']}:{_number2str(t['weight'])}"
)
else:
ts = " ".join(
map(
lambda t: f"{t['target_id']}:{_number2str(t['weight'])}",
self.targets,
)
)
text = f"{self.true_false}f {ts}"
else:
text = "%s X%d %s %r" % (
text = "%sn X%d %s %r" % (
self.true_false,
self.nodes_featureids,
_rule(self.nodes_modes),
Expand Down Expand Up @@ -114,9 +136,9 @@ def process_tree(atts, treeid):
for i in range(len(short[f"{prefix}_treeids"])):
idn = short[f"{prefix}_nodeids"][i]
node = nodes[idn]
node.target_nodeids = idn
node.target_ids = short[f"{prefix}_ids"][i]
node.target_weights = short[f"{prefix}_weights"][i]
node.append_target(
tid=short[f"{prefix}_ids"][i], weight=short[f"{prefix}_weights"][i]
)

def iterate(nodes, node, depth=0, true_false=""):
node.depth = depth
Expand All @@ -127,14 +149,14 @@ def iterate(nodes, node, depth=0, true_false=""):
nodes,
nodes[node.nodes_falsenodeids],
depth=depth + 1,
true_false="F",
true_false="-",
):
yield n
for n in iterate(
nodes,
nodes[node.nodes_truenodeids],
depth=depth + 1,
true_false="T",
true_false="+",
):
yield n

Expand Down