Skip to content

Commit

Permalink
FIX plot_tree now displays correct criterion (#13947)
Browse files Browse the repository at this point in the history
  • Loading branch information
fhoang7 authored and jnothman committed May 27, 2019
1 parent db48ebc commit fa383a4
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 11 deletions.
11 changes: 9 additions & 2 deletions doc/whats_new/v0.21.rst
Expand Up @@ -44,6 +44,13 @@ Changelog
`drop` parameter was not reflected in `get_feature_names`. :pr:`13894`
by :user:`James Myatt <jamesmyatt>`.

:mod:`sklearn.tree`
................................

- |Fix| Fixed an issue with :func:`plot_tree` where it displayed
entropy calculations even for `gini` criterion in DecisionTreeClassifiers.
:pr:`13947` by :user:`Frank Hoang <fhoang7>`.

:mod:`sklearn.utils.sparsefuncs`
................................

Expand Down Expand Up @@ -965,8 +972,8 @@ Baibak, daten-kieker, Denis Kataev, Didi Bar-Zev, Dillon Gardner, Dmitry Mottl,
Dmitry Vukolov, Dougal J. Sutherland, Dowon, drewmjohnston, Dror Atariah,
Edward J Brown, Ekaterina Krivich, Elizabeth Sander, Emmanuel Arias, Eric
Chang, Eric Larson, Erich Schubert, esvhd, Falak, Feda Curic, Federico Caselli,
Fibinse Xavier`, Finn O'Shea, Gabriel Marzinotto, Gabriel Vacaliuc, Gabriele
Calvo, Gael Varoquaux, GauravAhlawat, Giuseppe Vettigli, Greg Gandenberger,
Frank Hoang, Fibinse Xavier`, Finn O'Shea, Gabriel Marzinotto, Gabriel Vacaliuc,
Gabriele Calvo, Gael Varoquaux, GauravAhlawat, Giuseppe Vettigli, Greg Gandenberger,
Guillaume Fournier, Guillaume Lemaitre, Gustavo De Mari Pereira, Hanmin Qin,
haroldfox, hhu-luqi, Hunter McGushion, Ian Sanders, JackLangerman, Jacopo
Notarstefano, jakirkham, James Bourbeau, Jan Koch, Jan S, janvanrijn, Jarrod
Expand Down
11 changes: 6 additions & 5 deletions sklearn/tree/export.py
Expand Up @@ -547,16 +547,16 @@ def __init__(self, max_depth=None, feature_names=None,

self.arrow_args = dict(arrowstyle="<-")

def _make_tree(self, node_id, et, criterion, depth=0):
    """Recursively traverse ``et`` (a fitted ``_tree.Tree``) from ``node_id``
    and build the intermediate ``_reingold_tilford.Tree`` used for layout.

    ``criterion`` is the estimator's impurity criterion (e.g. ``"gini"`` or
    ``"entropy"``); it is threaded through to ``node_to_str`` so each node
    label reports the criterion actually used, instead of a hard-coded
    ``'entropy'`` (bug fixed in #13947).
    """
    name = self.node_to_str(et, node_id, criterion=criterion)
    # Recurse only into internal nodes within the depth limit; leaves
    # (and nodes beyond max_depth) become childless Tree nodes.
    if (et.children_left[node_id] != _tree.TREE_LEAF
            and (self.max_depth is None or depth <= self.max_depth)):
        children = [self._make_tree(et.children_left[node_id], et,
                                    criterion, depth=depth + 1),
                    self._make_tree(et.children_right[node_id], et,
                                    criterion, depth=depth + 1)]
    else:
        return Tree(name, node_id)
    return Tree(name, node_id, *children)
Expand All @@ -568,7 +568,8 @@ def export(self, decision_tree, ax=None):
ax = plt.gca()
ax.clear()
ax.set_axis_off()
my_tree = self._make_tree(0, decision_tree.tree_)
my_tree = self._make_tree(0, decision_tree.tree_,
decision_tree.criterion)
draw_tree = buchheim(my_tree)

# important to make sure we're still
Expand Down
27 changes: 23 additions & 4 deletions sklearn/tree/tests/test_export.py
Expand Up @@ -399,20 +399,39 @@ def test_export_text():
assert export_text(reg, decimals=1, show_weights=True) == expected_report


def test_plot_tree_entropy(pyplot):
    # Mostly a smoke test, plus a check that plot_tree labels nodes with
    # the tree's actual criterion ("entropy" here), not a hard-coded one.
    clf = DecisionTreeClassifier(max_depth=3,
                                 min_samples_split=2,
                                 criterion="entropy",
                                 random_state=2)
    clf.fit(X, y)

    # Test export code
    feature_names = ['first feat', 'sepal_width']
    nodes = plot_tree(clf, feature_names=feature_names)
    assert len(nodes) == 3
    assert nodes[0].get_text() == ("first feat <= 0.0\nentropy = 1.0\n"
                                   "samples = 6\nvalue = [3, 3]")
    assert nodes[1].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [3, 0]"
    assert nodes[2].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [0, 3]"


def test_plot_tree_gini(pyplot):
    # Smoke test: with criterion="gini", plot_tree must label nodes with
    # "gini = ..." impurity values rather than entropy values.
    params = dict(max_depth=3, min_samples_split=2,
                  criterion="gini", random_state=2)
    clf = DecisionTreeClassifier(**params).fit(X, y)

    nodes = plot_tree(clf, feature_names=['first feat', 'sepal_width'])

    expected = [
        "first feat <= 0.0\ngini = 0.5\nsamples = 6\nvalue = [3, 3]",
        "gini = 0.0\nsamples = 3\nvalue = [3, 0]",
        "gini = 0.0\nsamples = 3\nvalue = [0, 3]",
    ]
    assert len(nodes) == len(expected)
    for node, text in zip(nodes, expected):
        assert node.get_text() == text

0 comments on commit fa383a4

Please sign in to comment.