From 8382877dae4e0014fce0808cba9a83156ec50e53 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 31 Mar 2023 19:41:50 +0200 Subject: [PATCH 1/2] improve tree plotting --- _unittests/ut_plotting/test_text_plot.py | 12 ++++++++++++ onnx_array_api/plotting/text_plot.py | 6 +++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/_unittests/ut_plotting/test_text_plot.py b/_unittests/ut_plotting/test_text_plot.py index 4ebdcc8..8f677aa 100644 --- a/_unittests/ut_plotting/test_text_plot.py +++ b/_unittests/ut_plotting/test_text_plot.py @@ -65,6 +65,18 @@ def test_onnx_text_plot_tree_cls(self): self.assertIn(" T y=", res) self.assertIn("n_classes=3", res) + def test_onnx_text_plot_tree_cls_2(self): + iris = load_iris() + X_train, y_train = iris.data.astype(numpy.float32), iris.target + clr = DecisionTreeClassifier() + clr.fit(X_train, y_train) + model_def = to_onnx( + clr, X_train.astype(numpy.float32), options={"zipmap": False} + ) + res = onnx_text_plot_tree(model_def.graph.node[0]) + self.assertIn("n_classes=3", res) + print(res) + @ignore_warnings((UserWarning, FutureWarning)) def test_onnx_simple_text_plot_kmeans(self): x = numpy.random.randn(10, 3) diff --git a/onnx_array_api/plotting/text_plot.py b/onnx_array_api/plotting/text_plot.py index 669b1ab..5dd84ca 100644 --- a/onnx_array_api/plotting/text_plot.py +++ b/onnx_array_api/plotting/text_plot.py @@ -114,9 +114,9 @@ def process_tree(atts, treeid): for i in range(len(short[f"{prefix}_treeids"])): idn = short[f"{prefix}_nodeids"][i] node = nodes[idn] - node.target_nodeids = idn - node.target_ids = short[f"{prefix}_ids"][i] - node.target_weights = short[f"{prefix}_weights"][i] + node.append_target( + id=short[f"{prefix}_ids"][i], weight=short[f"{prefix}_weights"][i] + ) def iterate(nodes, node, depth=0, true_false=""): node.depth = depth From c28aeb234bbaa7c2f80c815934cddb8fcaba862d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sat, 1 Apr 2023 14:03:29 +0200 Subject: [PATCH 2/2] fix onnx_text_plot_tree --- .../data/onnx_text_plot_tree_cls_2.onnx | Bin 0 -> 1410 bytes _unittests/ut_plotting/test_text_plot.py | 41 +++++++++++++---- onnx_array_api/plotting/text_plot.py | 42 +++++++++++++----- 3 files changed, 64 insertions(+), 19 deletions(-) create mode 100644 _unittests/ut_plotting/data/onnx_text_plot_tree_cls_2.onnx diff --git a/_unittests/ut_plotting/data/onnx_text_plot_tree_cls_2.onnx b/_unittests/ut_plotting/data/onnx_text_plot_tree_cls_2.onnx new file mode 100644 index 0000000000000000000000000000000000000000..6efc78ecd49992e3840da68810251f19a0b54f4d GIT binary patch literal 1410 zcmb_cOKaOe5Z<*NTOL1R>riM&DEO56V4B$U5is63aR?zewDb^qF_N>2m&j7k%59JN z8>LWu>d)x0^pIQcrSz9{XC+%UsUd-C*wJb@asH4qn)5k^_+p?pen^g<{AT4u{oA`GmM&*9RC`dd{fIb{3xc9YE-<#(s;IA2)1+5EZKL5}l1{&bvQb7ghUo6XNmS$tcR5%pmCA4B=t zmm${Mla3#ZYnlF8zc;1nJ%|TgW9*unXZ`;9_8}h%DSQv7S~AU^zN_K{ECLwR%mRUw jk+G`~PecQ