[MRG+1] EHN add decimals parameter for export_graphviz #8698
@@ -397,6 +405,14 @@ def recurse(tree, node_id, criterion, parent=None, depth=0):
        return_string = True
        out_file = six.StringIO()

    if isinstance(decimals, Integral):
        if decimals < 0:
            raise ValueError("'decimal' should be greater or equal to 0."
MechCoder
Apr 21, 2017
Member
nitpick: This should be decimal(s)
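For readers following the hunk above, here is a minimal standalone sketch of the kind of validation it introduces. The helper name `check_decimals` and the exact message wording are assumptions made for illustration; only the `isinstance(..., Integral)` test, the negative-value check, and the start of the first message appear in the diff.

```python
from numbers import Integral


def check_decimals(decimals):
    """Validate a ``decimals`` argument (hypothetical helper) and return it."""
    if isinstance(decimals, Integral):
        if decimals < 0:
            # Message text beyond the first sentence is an assumption.
            raise ValueError("'decimals' should be greater or equal to 0. "
                             "Got {} instead.".format(decimals))
    else:
        # This branch is not visible in the hunk; it is assumed here.
        raise ValueError("'decimals' should be an integer. "
                         "Got {} instead.".format(type(decimals)))
    return decimals
```

Using `numbers.Integral` rather than `int` means any type registered as an integral number passes the check, while floats such as `1.5` are rejected.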
dot_data = StringIO()
export_graphviz(clf, out_file=dot_data, decimals=decimals)
# check value
for finding in finditer("nvalue = \d+\.\d+", dot_data.getvalue()):
MechCoder
Apr 21, 2017
Member
Why (n) value?
glemaitre
Apr 22, 2017
Author
Contributor
my mistake, I did not notice it was a \n
clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0,
                            max_depth=1)
clf.fit(X, y_reg)
for decimals, nb_len in zip((4, 3), (5, 4)):
MechCoder
Apr 21, 2017
Member
Sorry for being dumb, but what is nb_len? Or why are you checking that there should be 3 decimals when there can be a maximum of 4?
glemaitre
Apr 22, 2017
Author
Contributor
It is the opposite. The length will be decimals + 1 since the pattern also matches the "." character. But it is true that nb_len is useless.
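The two points raised above (the leading `n` in the pattern and the `decimals + 1` length) can be checked on a small example. The node label below is made up for illustration; it assumes the DOT text separates label lines with a literal backslash followed by `n`, which is why the pattern in the test begins with `n`.

```python
import re

decimals = 4

# Hypothetical node label as it might appear in DOT text. The raw string keeps
# "\n" as two characters (backslash + n), not a Python newline.
label = r"X[0] <= 0.5123\nfriedman_mse = 0.0867\nsamples = 5\nvalue = 0.4321"

# The leading "n" in the pattern is the second character of the literal "\n".
matches = [m.group() for m in re.finditer(r"nvalue = \d+\.\d+", label)]

# The inner search matches the dot plus the digits after it,
# hence a length of decimals + 1.
fraction = re.search(r"\.\d+", matches[0]).group()
```

Here `fraction` includes the dot itself, so its length is one more than the number of digits after the decimal point.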
export_graphviz(clf, out_file=dot_data, decimals=decimals)
# check value
for finding in finditer("nvalue = \d+\.\d+", dot_data.getvalue()):
    assert_equal(len(search("\.\d+", finding.group()).group()), nb_len)
MechCoder
Apr 21, 2017
Member
I'm pretty sure this should be assert_less(...., decimals)
glemaitre
Apr 22, 2017
Author
Contributor
It could be less_equal. But with this random seed, we should get exactly decimals + 1. Making it less_equal would make the test too permissive.
y_reg = rng.random_sample((5, ))

# regression case
clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0,
MechCoder
Apr 21, 2017
Member
Can you add a test for a classifier as well?
raghavrv
Apr 27, 2017
Member
+1
That's it from me.
clf.fit(X, y_reg)
for decimals in (4, 3):
    dot_data = StringIO()
    export_graphviz(clf, out_file=dot_data, decimals=decimals)
    # check impurity
    for finding in finditer("friedman_mse = \d+\.\d+",
                            dot_data.getvalue()):
        assert_equal(len(search("\.\d+", finding.group()).group()),
MechCoder
Apr 22, 2017
Member
I agree that assert_less makes this more permissive. In that case, this warrants a comment that this is particular to this random seed, as you mention, or you can use toy data where you know that the number of decimals in impurity etc. is greater than the provided value.
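The trade-off between an exact check and a less-or-equal check can be seen in a short sketch. It assumes the exported numbers are produced by Python's built-in `round` followed by `str`, which is not shown in the hunks here; `fraction_digits` is a hypothetical helper name.

```python
def fraction_digits(value, decimals):
    """Count digits after the dot once ``value`` is rounded and stringified.

    Assumes formatting via round() and str(); very small values that print
    in scientific notation are outside the scope of this sketch.
    """
    text = str(round(value, decimals))
    return len(text.partition(".")[2])


# A value with more digits than requested is cut to exactly ``decimals`` digits.
long_value = fraction_digits(0.123456, 4)

# A value with fewer digits keeps its shorter form: no trailing zeros are added.
short_value = fraction_digits(0.25, 4)
```

With data like `0.25`, an equality assertion on the digit count would fail even though the output respects the requested precision, which is why an exact check only holds for data (or a seed) known to produce enough digits.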
clf.fit(X, y_cla)
for decimals in (4, 3):
    dot_data = StringIO()
    export_graphviz(clf, out_file=dot_data, decimals=decimals)
MechCoder
Apr 22, 2017
Member
Setting proportion=True over here might be a good idea.
LGTM pending whatsnew
@MechCoder Done
LGTM
@raghavrv ping Could you give a look?
Thanks!! A few questions. Otherwise LGTM +1 for merge after refactoring the tests into a loop (if possible)
@@ -72,7 +74,7 @@ def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None,
                     feature_names=None, class_names=None, label='all',
                     filled=False, leaves_parallel=False, impurity=True,
                     node_ids=False, proportion=False, rotate=False,
-                    rounded=False, special_characters=False):
+                    rounded=False, special_characters=False, decimals=3):
raghavrv
Apr 27, 2017
Member
Should it be precision, or should we leave it as decimals to conform to the assert_* syntax?
@@ -142,6 +144,10 @@ def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None,
        When set to ``False``, ignore special characters for PostScript
        compatibility.
    decimals : int, optional (default=3)
        The number of decimal reported of the ``impurity``,
raghavrv
Apr 27, 2017
Member
Desired precision in the values of impurity, threshold and value attributes of each node.
assert_equal(len(search("\.\d+", finding.group()).group()),
             decimals + 1)

# classification case
raghavrv
Apr 27, 2017
Member
Is it possible to make a loop for reg/classification?
I can see a case being made for
pandas used precision too
…On 28 Apr 2017 11:28 pm, "Guillaume Lemaitre" ***@***.***> wrote:
I first thought about decimals as in around, but the point of @jmschrei
<https://github.com/jmschrei> and @raghavrv <https://github.com/raghavrv>
is pertinent.
I made the changes
@jnothman ok so definitely
00da9cc
into
scikit-learn:master
Thanks a lot... I'm merging this in!!
Thanks!
…8698) * EHN add decimals parameter for export_graphviz * FIX address comments * TST add test for classification * TST/FIX address comments * FIX comments raghav
Reference Issue
Fixes #8662
What does this implement/fix? Explain your changes.
Introduce a parameter decimals to control decimal precision when displaying threshold, impurity, and value parameters in export_graphviz.
Any other comments?