Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG+1] EHN add decimals parameter for export_graphviz #8698

Merged
merged 7 commits into from Apr 29, 2017

Conversation

@glemaitre
Copy link
Contributor

@glemaitre glemaitre commented Apr 4, 2017

Reference Issue

Fixes #8662

What does this implement/fix? Explain your changes.

Introduce a parameter decimals to control decimal precision when displaying threshold, impurity, and value parameters in export_graphviz.

Any other comments?

@@ -397,6 +405,14 @@ def recurse(tree, node_id, criterion, parent=None, depth=0):
return_string = True
out_file = six.StringIO()

if isinstance(decimals, Integral):
if decimals < 0:
raise ValueError("'decimal' should be greater or equal to 0."

This comment has been minimized.

@MechCoder

MechCoder Apr 21, 2017
Member

nitpick: This should be decimal(s)

dot_data = StringIO()
export_graphviz(clf, out_file=dot_data, decimals=decimals)
# check value
for finding in finditer("nvalue = \d+\.\d+", dot_data.getvalue()):

This comment has been minimized.

@MechCoder

MechCoder Apr 21, 2017
Member

Why (n) value?

This comment has been minimized.

@glemaitre

glemaitre Apr 22, 2017
Author Contributor

my mistake, I did not notice it was a \n

clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0,
max_depth=1)
clf.fit(X, y_reg)
for decimals, nb_len in zip((4, 3), (5, 4)):

This comment has been minimized.

@MechCoder

MechCoder Apr 21, 2017
Member

Sorry for being dumb, but what is nb_len? Or why are you checking that there should be 3 decimals when there can be a maximum of 4.

This comment has been minimized.

@glemaitre

glemaitre Apr 22, 2017
Author Contributor

this is the opposite. The length will be decimals + 1 since I will try to match a .. But this is true that nb_len is useless.

export_graphviz(clf, out_file=dot_data, decimals=decimals)
# check value
for finding in finditer("nvalue = \d+\.\d+", dot_data.getvalue()):
assert_equal(len(search("\.\d+", finding.group()).group()), nb_len)

This comment has been minimized.

@MechCoder

MechCoder Apr 21, 2017
Member

I'm pretty sure this should be assert_less(...., decimals)

This comment has been minimized.

@glemaitre

glemaitre Apr 22, 2017
Author Contributor

It could be less_equal. But with this random seed, we should get exactly decimals + 1. Making it less_equal can make the test too much permissive

y_reg = rng.random_sample((5, ))

# regression case
clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0,

This comment has been minimized.

@MechCoder

MechCoder Apr 21, 2017
Member

Can you add a test for a classifer as well?

This comment has been minimized.

glemaitre added 2 commits Apr 22, 2017
Copy link
Member

@MechCoder MechCoder left a comment

That's it from me.

clf.fit(X, y_reg)
for decimals in (4, 3):
dot_data = StringIO()
export_graphviz(clf, out_file=dot_data, decimals=decimals)

This comment has been minimized.

@MechCoder

MechCoder Apr 22, 2017
Member

You can use out_file=None directly since #7390

# check impurity
for finding in finditer("friedman_mse = \d+\.\d+",
dot_data.getvalue()):
assert_equal(len(search("\.\d+", finding.group()).group()),

This comment has been minimized.

@MechCoder

MechCoder Apr 22, 2017
Member

I agree that assert_less makes this more permissive. In that case, this warrants a comment that this is particular for this random seed as you mention or you can use toy data where you know that the number of decimals in impurity etc are greater than the provided value.

clf.fit(X, y_cla)
for decimals in (4, 3):
dot_data = StringIO()
export_graphviz(clf, out_file=dot_data, decimals=decimals)

This comment has been minimized.

@MechCoder

MechCoder Apr 22, 2017
Member

Setting proportion=True over here might be a good idea.

@MechCoder
Copy link
Member

@MechCoder MechCoder commented Apr 22, 2017

LGTM pending whatsnew

@glemaitre
Copy link
Contributor Author

@glemaitre glemaitre commented Apr 23, 2017

@MechCoder Done

Copy link
Member

@MechCoder MechCoder left a comment

LGTM

@MechCoder MechCoder changed the title [MRG] EHN add decimals parameter for export_graphviz [MRG+1] EHN add decimals parameter for export_graphviz Apr 23, 2017
@glemaitre
Copy link
Contributor Author

@glemaitre glemaitre commented Apr 25, 2017

@raghavrv ping Could you give a look?

Copy link
Member

@raghavrv raghavrv left a comment

Thanks!! A few questions. Otherwise LGTM

+1 for merge after refactoring the tests into a loop (if possible)

y_reg = rng.random_sample((5, ))

# regression case
clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0,

This comment has been minimized.

@@ -72,7 +74,7 @@ def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None,
feature_names=None, class_names=None, label='all',
filled=False, leaves_parallel=False, impurity=True,
node_ids=False, proportion=False, rotate=False,
rounded=False, special_characters=False):
rounded=False, special_characters=False, decimals=3):

This comment has been minimized.

@raghavrv

raghavrv Apr 27, 2017
Member

Should it be precision or should we leave it as decimals to conform to the assert_* syntax?

@@ -142,6 +144,10 @@ def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None,
When set to ``False``, ignore special characters for PostScript
compatibility.
decimals : int, optional (default=3)
The number of decimal reported of the ``impurity``,

This comment has been minimized.

@raghavrv

raghavrv Apr 27, 2017
Member

Desired precision in the values of impurity, threshold and value attributes of each node.

assert_equal(len(search("\.\d+", finding.group()).group()),
decimals + 1)

# classification case

This comment has been minimized.

@raghavrv

raghavrv Apr 27, 2017
Member

Is it possible to make a loop for reg/classification?

@jmschrei
Copy link
Member

@jmschrei jmschrei commented Apr 28, 2017

I can see a case being made for precision since that's what numpy uses in set_printoptions. I am neutral on it, as this isn't a big enough feature to warrant nit-picking. Once the checks are fixed, I am +1 as well.

@glemaitre
Copy link
Contributor Author

@glemaitre glemaitre commented Apr 28, 2017

I first thought about decimals as in around, but the point of @jmschrei and @raghavrv is pertinent.

I made the changes

@jnothman
Copy link
Member

@jnothman jnothman commented Apr 29, 2017

@glemaitre
Copy link
Contributor Author

@glemaitre glemaitre commented Apr 29, 2017

@jnothman ok so definitely precision makes sense. Then, it should be OK to be merged.

@raghavrv raghavrv merged commit 00da9cc into scikit-learn:master Apr 29, 2017
5 checks passed
5 checks passed
ci/circleci Your tests passed on CircleCI!
Details
codecov/patch 100% of diff hit (target 95.52%)
Details
codecov/project 95.52% (+<.01%) compared to ee82c3f
Details
continuous-integration/appveyor/pr AppVeyor build succeeded
Details
continuous-integration/travis-ci/pr The Travis CI build passed
Details
@raghavrv
Copy link
Member

@raghavrv raghavrv commented Apr 29, 2017

Thanks a lot... I'm merging this in!!

@ryanrozich
Copy link

@ryanrozich ryanrozich commented Apr 29, 2017

Thanks!

Sundrique added a commit to Sundrique/scikit-learn that referenced this pull request Jun 14, 2017
…8698)

* EHN add decimals parameter for export_graphviz

* FIX address comments

* TST add test for classification

* TST/FIX address comments

* FIX comments raghav
dmohns added a commit to dmohns/scikit-learn that referenced this pull request Aug 7, 2017
…8698)

* EHN add decimals parameter for export_graphviz

* FIX address comments

* TST add test for classification

* TST/FIX address comments

* FIX comments raghav
dmohns added a commit to dmohns/scikit-learn that referenced this pull request Aug 7, 2017
…8698)

* EHN add decimals parameter for export_graphviz

* FIX address comments

* TST add test for classification

* TST/FIX address comments

* FIX comments raghav
NelleV added a commit to NelleV/scikit-learn that referenced this pull request Aug 11, 2017
…8698)

* EHN add decimals parameter for export_graphviz

* FIX address comments

* TST add test for classification

* TST/FIX address comments

* FIX comments raghav
paulha added a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017
…8698)

* EHN add decimals parameter for export_graphviz

* FIX address comments

* TST add test for classification

* TST/FIX address comments

* FIX comments raghav
AishwaryaRK added a commit to AishwaryaRK/scikit-learn that referenced this pull request Aug 29, 2017
…8698)

* EHN add decimals parameter for export_graphviz

* FIX address comments

* TST add test for classification

* TST/FIX address comments

* FIX comments raghav
maskani-moh added a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017
…8698)

* EHN add decimals parameter for export_graphviz

* FIX address comments

* TST add test for classification

* TST/FIX address comments

* FIX comments raghav
jwjohnson314 pushed a commit to jwjohnson314/scikit-learn that referenced this pull request Dec 18, 2017
…8698)

* EHN add decimals parameter for export_graphviz

* FIX address comments

* TST add test for classification

* TST/FIX address comments

* FIX comments raghav
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
6 participants