diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst
index bf650629e43fc..e6e0e6cf620dc 100644
--- a/doc/whats_new/v0.21.rst
+++ b/doc/whats_new/v0.21.rst
@@ -44,6 +44,13 @@ Changelog
   `drop` parameter was not reflected in `get_feature_names`. :pr:`13894`
   by :user:`James Myatt `.
 
+:mod:`sklearn.tree`
+................................
+
+- |Fix| Fixed an issue with :func:`plot_tree` where it displayed
+  entropy calculations even for the `gini` criterion in DecisionTreeClassifiers.
+  :pr:`13947` by :user:`Frank Hoang `.
+
 :mod:`sklearn.utils.sparsefuncs`
 ................................
 
@@ -965,8 +972,8 @@ Baibak, daten-kieker, Denis Kataev, Didi Bar-Zev, Dillon Gardner, Dmitry Mottl,
 Dmitry Vukolov, Dougal J. Sutherland, Dowon, drewmjohnston, Dror Atariah,
 Edward J Brown, Ekaterina Krivich, Elizabeth Sander, Emmanuel Arias, Eric
 Chang, Eric Larson, Erich Schubert, esvhd, Falak, Feda Curic, Federico Caselli,
-Fibinse Xavier`, Finn O'Shea, Gabriel Marzinotto, Gabriel Vacaliuc, Gabriele
-Calvo, Gael Varoquaux, GauravAhlawat, Giuseppe Vettigli, Greg Gandenberger,
+Frank Hoang, Fibinse Xavier`, Finn O'Shea, Gabriel Marzinotto, Gabriel Vacaliuc,
+Gabriele Calvo, Gael Varoquaux, GauravAhlawat, Giuseppe Vettigli, Greg Gandenberger,
 Guillaume Fournier, Guillaume Lemaitre, Gustavo De Mari Pereira, Hanmin Qin,
 haroldfox, hhu-luqi, Hunter McGushion, Ian Sanders, JackLangerman, Jacopo
 Notarstefano, jakirkham, James Bourbeau, Jan Koch, Jan S, janvanrijn, Jarrod
diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py
index 636ef03689a79..c012632d455a5 100644
--- a/sklearn/tree/export.py
+++ b/sklearn/tree/export.py
@@ -547,16 +547,16 @@ def __init__(self, max_depth=None, feature_names=None,
 
         self.arrow_args = dict(arrowstyle="<-")
 
-    def _make_tree(self, node_id, et, depth=0):
+    def _make_tree(self, node_id, et, criterion, depth=0):
         # traverses _tree.Tree recursively, builds intermediate
         # "_reingold_tilford.Tree" object
-        name = self.node_to_str(et, node_id, criterion='entropy')
+        name = self.node_to_str(et, node_id, criterion=criterion)
         if (et.children_left[node_id] != _tree.TREE_LEAF
                 and (self.max_depth is None or depth <= self.max_depth)):
             children = [self._make_tree(et.children_left[node_id], et,
-                                        depth=depth + 1),
+                                        criterion, depth=depth + 1),
                         self._make_tree(et.children_right[node_id], et,
-                                        depth=depth + 1)]
+                                        criterion, depth=depth + 1)]
         else:
             return Tree(name, node_id)
         return Tree(name, node_id, *children)
@@ -568,7 +568,8 @@ def export(self, decision_tree, ax=None):
             ax = plt.gca()
             ax.clear()
             ax.set_axis_off()
-        my_tree = self._make_tree(0, decision_tree.tree_)
+        my_tree = self._make_tree(0, decision_tree.tree_,
+                                  decision_tree.criterion)
         draw_tree = buchheim(my_tree)
 
         # important to make sure we're still
diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py
index eed9be7bcb5d9..06ca9e446fdcc 100644
--- a/sklearn/tree/tests/test_export.py
+++ b/sklearn/tree/tests/test_export.py
@@ -399,12 +399,12 @@ def test_export_text():
     assert export_text(reg, decimals=1, show_weights=True) == expected_report
 
 
-def test_plot_tree(pyplot):
+def test_plot_tree_entropy(pyplot):
     # mostly smoke tests
-    # Check correctness of export_graphviz
+    # Check correctness of plot_tree for criterion = entropy
     clf = DecisionTreeClassifier(max_depth=3,
                                  min_samples_split=2,
-                                 criterion="gini",
+                                 criterion="entropy",
                                  random_state=2)
     clf.fit(X, y)
 
@@ -412,7 +412,26 @@ def test_plot_tree(pyplot):
     feature_names = ['first feat', 'sepal_width']
     nodes = plot_tree(clf, feature_names=feature_names)
     assert len(nodes) == 3
-    assert nodes[0].get_text() == ("first feat <= 0.0\nentropy = 0.5\n"
+    assert nodes[0].get_text() == ("first feat <= 0.0\nentropy = 1.0\n"
                                    "samples = 6\nvalue = [3, 3]")
     assert nodes[1].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [3, 0]"
     assert nodes[2].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [0, 3]"
+
+
+def test_plot_tree_gini(pyplot):
+    # mostly smoke tests
+    # Check correctness of plot_tree for criterion = gini
+    clf = DecisionTreeClassifier(max_depth=3,
+                                 min_samples_split=2,
+                                 criterion="gini",
+                                 random_state=2)
+    clf.fit(X, y)
+
+    # Test export code
+    feature_names = ['first feat', 'sepal_width']
+    nodes = plot_tree(clf, feature_names=feature_names)
+    assert len(nodes) == 3
+    assert nodes[0].get_text() == ("first feat <= 0.0\ngini = 0.5\n"
+                                   "samples = 6\nvalue = [3, 3]")
+    assert nodes[1].get_text() == "gini = 0.0\nsamples = 3\nvalue = [3, 0]"
+    assert nodes[2].get_text() == "gini = 0.0\nsamples = 3\nvalue = [0, 3]"
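
For reference, a minimal sketch (not part of the patch) of how the fix can be exercised from user code: with the changes above, the impurity label drawn by plot_tree follows the classifier's fitted criterion instead of always reading "entropy". It assumes scikit-learn >= 0.21 and matplotlib are available; the iris dataset, tree parameters, and headless Agg backend are illustrative choices, not part of the patch.

# Verification sketch (assumption: scikit-learn >= 0.21 with this fix applied,
# matplotlib installed; data and parameters are illustrative only).
import matplotlib
matplotlib.use("Agg")  # headless backend so the snippet runs without a display

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(criterion="gini", max_depth=2,
                             random_state=0).fit(X, y)

# plot_tree returns the list of matplotlib annotations it drew; after the fix,
# the impurity shown in each node matches the fitted criterion ("gini" here).
nodes = plot_tree(clf)
print(nodes[0].get_text())  # e.g. "... gini = 0.667 ... samples = 150 ..."
plt.close("all")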