diff --git a/_unittests/ut_plotting/data/onnx_text_plot_tree_cls_2.onnx b/_unittests/ut_plotting/data/onnx_text_plot_tree_cls_2.onnx new file mode 100644 index 0000000..6efc78e Binary files /dev/null and b/_unittests/ut_plotting/data/onnx_text_plot_tree_cls_2.onnx differ diff --git a/_unittests/ut_plotting/test_text_plot.py b/_unittests/ut_plotting/test_text_plot.py index 4ebdcc8..c505e7d 100644 --- a/_unittests/ut_plotting/test_text_plot.py +++ b/_unittests/ut_plotting/test_text_plot.py @@ -52,7 +52,7 @@ def test_onnx_text_plot_tree_reg(self): onx = to_onnx(clr, X) res = onnx_text_plot_tree(onx.graph.node[0]) self.assertIn("treeid=0", res) - self.assertIn(" T y=", res) + self.assertIn(" +f", res) def test_onnx_text_plot_tree_cls(self): iris = load_iris() @@ -62,9 +62,44 @@ def test_onnx_text_plot_tree_cls(self): onx = to_onnx(clr, X) res = onnx_text_plot_tree(onx.graph.node[0]) self.assertIn("treeid=0", res) - self.assertIn(" T y=", res) + self.assertIn(" +f 0:", res) self.assertIn("n_classes=3", res) + def test_onnx_text_plot_tree_cls_2(self): + this = os.path.join( + os.path.dirname(__file__), "data", "onnx_text_plot_tree_cls_2.onnx" + ) + with open(this, "rb") as f: + model_def = load(f) + res = onnx_text_plot_tree(model_def.graph.node[0]) + self.assertIn("n_classes=3", res) + expected = textwrap.dedent( + """ + n_classes=3 + n_trees=1 + ---- + treeid=0 + n X2 <= 2.4499998 + -n X3 <= 1.75 + -n X2 <= 4.85 + -f 0:0 1:0 2:1 + +n X0 <= 5.95 + -f 0:0 1:0 2:1 + +f 0:0 1:1 2:0 + +n X2 <= 4.95 + -n X3 <= 1.55 + -n X0 <= 6.95 + -f 0:0 1:0 2:1 + +f 0:0 1:1 2:0 + +f 0:0 1:0 2:1 + +n X3 <= 1.65 + -f 0:0 1:0 2:1 + +f 0:0 1:1 2:0 + +f 0:1 1:0 2:0 + """ + ).strip(" \n\r") + self.assertEqual(expected, res.strip(" \n\r")) + @ignore_warnings((UserWarning, FutureWarning)) def test_onnx_simple_text_plot_kmeans(self): x = numpy.random.randn(10, 3) diff --git a/onnx_array_api/plotting/text_plot.py b/onnx_array_api/plotting/text_plot.py index 669b1ab..dfb9be0 100644 --- a/onnx_array_api/plotting/text_plot.py +++ b/onnx_array_api/plotting/text_plot.py @@ -24,6 +24,14 @@ def _rule(r): raise ValueError(f"Unexpected rule {r!r}.") +def _number2str(i): + if isinstance(i, int): + return str(i) + if int(i) == i: + return str(int(i)) + return f"{i:1.2f}" + + def onnx_text_plot_tree(node): """ Gives a textual representation of a tree ensemble. @@ -61,18 +69,32 @@ def __init__(self, i, atts): setattr(self, k, v[i]) self.depth = 0 self.true_false = "" + self.targets = [] + + def append_target(self, tid, weight): + self.targets.append(dict(target_id=tid, weight=weight)) def process_node(self): "node to string" if self.nodes_modes == "LEAF": - text = "%s y=%r f=%r i=%r" % ( - self.true_false, - self.target_weights, - self.target_ids, - self.target_nodeids, - ) + if len(self.targets) == 0: + text = f"{self.true_false}f" + elif len(self.targets) == 1: + t = self.targets[0] + text = ( + f"{self.true_false}f " + f"{t['target_id']}:{_number2str(t['weight'])}" + ) + else: + ts = " ".join( + map( + lambda t: f"{t['target_id']}:{_number2str(t['weight'])}", + self.targets, + ) + ) + text = f"{self.true_false}f {ts}" else: - text = "%s X%d %s %r" % ( + text = "%sn X%d %s %r" % ( self.true_false, self.nodes_featureids, _rule(self.nodes_modes), @@ -114,9 +136,9 @@ def process_tree(atts, treeid): for i in range(len(short[f"{prefix}_treeids"])): idn = short[f"{prefix}_nodeids"][i] node = nodes[idn] - node.target_nodeids = idn - node.target_ids = short[f"{prefix}_ids"][i] - node.target_weights = short[f"{prefix}_weights"][i] + node.append_target( + tid=short[f"{prefix}_ids"][i], weight=short[f"{prefix}_weights"][i] + ) def iterate(nodes, node, depth=0, true_false=""): node.depth = depth @@ -127,14 +149,14 @@ def iterate(nodes, node, depth=0, true_false=""): nodes, nodes[node.nodes_falsenodeids], depth=depth + 1, - true_false="F", + true_false="-", ): yield n for n in iterate( nodes, nodes[node.nodes_truenodeids], depth=depth + 1, - true_false="T", + true_false="+", ): yield n