
[MRG+2] ENH/FIX Introduce min_impurity_decrease param for early stopping based on impurity; Deprecate min_impurity_split #8449

Merged
merged 18 commits into scikit-learn:master from raghavrv:min_impurity_decrease on Apr 3, 2017

Conversation

@raghavrv (Member) commented Feb 24, 2017

Fixes #8400

Also ref Gilles' comment

This PR stops splitting a node if the weighted impurity decrease from a potential split is not above a user-given threshold (a usage sketch follows below)...

@amueller Can you try this on your use cases and see if it gives better control than min_impurity_split?

@jnothman @glouppe @nelson-liu @glemaitre @jmschrei
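For context, a minimal usage sketch of the proposed parameter (not code from this PR; the 0.01 threshold is arbitrary):

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=10000, random_state=42)

# Without the parameter, the tree grows until the other stopping criteria kick in.
full_tree = DecisionTreeClassifier(random_state=0).fit(X, y)

# With min_impurity_decrease, a node is split only if the best split's
# weighted impurity decrease exceeds the threshold.
pruned_tree = DecisionTreeClassifier(min_impurity_decrease=0.01,
                                     random_state=0).fit(X, y)

print(full_tree.tree_.node_count, pruned_tree.tree_.node_count)
```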

imp_right = est.tree_.impurity[right]
weighted_n_right = est.tree_.weighted_n_node_samples[right]

actual_decrease = (est.tree_.impurity[node] -

@raghavrv (Author, Member) commented Feb 24, 2017

# TODO this is an incorrect comparison. The actual decrease should additionally be multiplied by the fractional weight of the parent node...
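For concreteness, a sketch of the corrected comparison as a hypothetical helper (est is a fitted tree estimator; node, left and right are node indices, as in the test snippet above):

```python
def weighted_impurity_decrease(est, node, left, right):
    """Weighted impurity decrease of splitting `node` into `left`/`right`,
    computed from a fitted tree's arrays (sketch, not the PR's test code)."""
    tree = est.tree_
    # Fractional weight of the parent node relative to the root.
    fractional_node_weight = (tree.weighted_n_node_samples[node] /
                              tree.weighted_n_node_samples[0])
    weighted_n_left = tree.weighted_n_node_samples[left]
    weighted_n_right = tree.weighted_n_node_samples[right]
    # Local decrease: parent impurity minus the weighted mean child impurity.
    local_decrease = (tree.impurity[node] -
                      (weighted_n_left * tree.impurity[left] +
                       weighted_n_right * tree.impurity[right]) /
                      (weighted_n_left + weighted_n_right))
    return fractional_node_weight * local_decrease
```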

raghavrv added 2 commits Feb 24, 2017
@@ -446,7 +454,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder):

if not is_leaf:
splitter.node_split(impurity, &split, &n_constant_features)
is_leaf = is_leaf or (split.pos >= end)
is_leaf = (is_leaf or split.pos >= end or
split.improvement + EPSILON < min_impurity_decrease)

@jmschrei (Member) commented Feb 25, 2017

What's the need for epsilon here?

@raghavrv (Author, Member) commented Feb 27, 2017

I did this to avoid floating-point precision inconsistencies affecting the split... I'll explain clearly in a subsequent comment...

@raghavrv (Author, Member) commented Mar 7, 2017

So I did this to avoid not splitting when split.improvement is almost equal to min_impurity_decrease within machine precision. For instance, if you set min_impurity_decrease to 1e-7, the tree is not built completely because the improvement is sometimes almost equal to 1e-7...

And I added it to the left-hand side and not the right so that splitting gets the benefit of the doubt (as opposed to not splitting)...
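A small illustration of the floating-point issue being guarded against (EPSILON here is taken as the double machine epsilon, in the spirit of the tolerance used in the tree code; the numbers are made up):

```python
import numpy as np

EPSILON = np.finfo(np.double).eps
min_impurity_decrease = 1e-7

# Mathematically this is exactly 1e-7, but floating-point arithmetic leaves it
# a hair below the threshold.
improvement = 0.3 - 0.1 - 0.2 + 1e-7

print(improvement < min_impurity_decrease)            # True  -> node wrongly becomes a leaf
print(improvement + EPSILON < min_impurity_decrease)  # False -> split is allowed
```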

@raghavrv (Author, Member) commented Mar 24, 2017

To clarify further: setting it to 1e-7, as is done for other stopping parameters to denote eps, would not let the tree grow fully and would produce trees dissimilar to master...

@MechCoder (Member) commented Mar 31, 2017

Add this as an inline comment, then.

@@ -272,10 +275,23 @@ def fit(self, X, y, sample_weight=None, check_input=True,
min_weight_leaf = (self.min_weight_fraction_leaf *
np.sum(sample_weight))

if self.min_impurity_split < 0.:
if self.min_impurity_split is not None:

@jmschrei (Member) commented Feb 25, 2017

Is there a deprecation decorator which can be used? I know there is one for deprecated functions, but I'm not sure about parameters.

@raghavrv (Author, Member) commented Feb 27, 2017

I think we typically use our deprecated decorator for attributes, not parameters... But I'm unsure... @amueller, thoughts?
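For reference, scikit-learn parameter deprecations are usually handled with an explicit warning at fit time rather than a decorator; a rough sketch of that pattern (not the exact code from this PR; the class name is hypothetical):

```python
import warnings

class TreeEstimatorSketch:
    """Sketch of deprecating a constructor parameter: it defaults to None
    and a warning is raised only when the user actually sets it."""

    def __init__(self, min_impurity_decrease=0.0, min_impurity_split=None):
        self.min_impurity_decrease = min_impurity_decrease
        self.min_impurity_split = min_impurity_split

    def fit(self, X, y):
        if self.min_impurity_split is not None:
            warnings.warn("The min_impurity_split parameter is deprecated; "
                          "use min_impurity_decrease instead.",
                          DeprecationWarning)
        # ... build the tree using self.min_impurity_decrease ...
        return self
```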

@jmschrei (Member) commented Feb 25, 2017

In general this looks good. I didn't check your test though to make sure it was correct.

@raghavrv (Member, Author) commented Feb 27, 2017

Thanks a lot @jmschrei for the review!

@raghavrv (Member, Author) commented Feb 27, 2017

Others @glouppe @amueller Reviews please :)

@nelson-liu (Contributor) commented Feb 27, 2017

Functionality-wise this looks good to me, pending that comment about the deprecation decorator. Good work @raghavrv

@raghavrv (Member, Author) commented Mar 7, 2017

Thanks @nelson-liu and @jmschrei. Andy or Gilles??

@raghavrv (Member, Author) commented Mar 14, 2017

Or maybe @glemaitre / @ogrisel have some time for reviews?

@glemaitre (Contributor) commented Mar 14, 2017

Should you mention in the docstring that min_impurity_split will be deprecated?

Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold, otherwise it is a leaf.
min_impurity_decrease : float, optional (default=0.)
Threshold for early stopping in tree growth. A node will be split

@glemaitre (Contributor) commented Mar 14, 2017

I would change it to:

A node will be split if this split induces a decrease of the impurity
greater than or equal to this value.
.. versionadded:: 0.18
The impurity decrease due to a potential split is the difference in the

@glemaitre (Contributor) commented Mar 14, 2017

I would remove "due to a potential split"

Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold, otherwise it is a leaf.
min_impurity_decrease : float, optional (default=0.)
Threshold for early stopping in tree growth. A node will be split

@glemaitre (Contributor) commented Mar 14, 2017

Same changes as in RandomForestClassifier

Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold, otherwise it is a leaf.
min_impurity_decrease : float, optional (default=0.)
Threshold for early stopping in tree growth. A node will be split

@glemaitre (Contributor) commented Mar 14, 2017

Same changes as in RandomForestClassifier

Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold, otherwise it is a leaf.
min_impurity_decrease : float, optional (default=0.)
Threshold for early stopping in tree growth. A node will be split

@glemaitre (Contributor) commented Mar 14, 2017

Same changes as in RandomForestClassifier

@@ -1406,7 +1417,8 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin):
def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100,
subsample=1.0, criterion='friedman_mse', min_samples_split=2,
min_samples_leaf=1, min_weight_fraction_leaf=0.,
max_depth=3, min_impurity_split=1e-7, init=None,
max_depth=3, min_impurity_decrease=0.,

@glemaitre (Contributor) commented Mar 14, 2017

min_impurity_decrease is documented as 1e-7 in the docstring above.

@raghavrv (Author, Member) commented Mar 24, 2017

Thanks for the catch. I changed the doc to 0... I'm using 0 because of the EPSILON added as described here...

min_impurity_split : float, optional (default=1e-7)
Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold, otherwise it is a leaf.
min_impurity_decrease : float, optional (default=1e-7)

@glemaitre (Contributor) commented Mar 14, 2017

Check the default value

@@ -1790,7 +1811,8 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin):
def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100,
subsample=1.0, criterion='friedman_mse', min_samples_split=2,
min_samples_leaf=1, min_weight_fraction_leaf=0.,
max_depth=3, min_impurity_split=1e-7, init=None, random_state=None,
max_depth=3, min_impurity_decrease=0.,

@glemaitre (Contributor) commented Mar 14, 2017

check the default value

@raghavrv (Author, Member) commented Mar 24, 2017

(Same as above)

Threshold for early stopping in tree growth. If the impurity
of a node is below the threshold, the node is a leaf.
min_impurity_decrease : float, optional (default=0.)
Threshold for early stopping in tree growth. A node will be split

@glemaitre (Contributor) commented Mar 14, 2017

Same changes as in RandomForestClassifier

Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold, otherwise it is a leaf.
min_impurity_decrease : float, optional (default=0.)
Threshold for early stopping in tree growth. A node will be split

@glemaitre (Contributor) commented Mar 14, 2017

Same changes as in RandomForestClassifier

@raghavrv (Member, Author) commented Mar 24, 2017

Should you mention in the docstring that min_impurity_split will be deprecated?

Generally we don't mention that in the docstring. We deprecate it and remove the documentation for that parameter...

Thanks for the review. Have addressed it :) Another round?

@jnothman Could you take a look at this too?

@raghavrv raghavrv force-pushed the raghavrv:min_impurity_decrease branch from 4775b93 to 0ca3a4e Mar 24, 2017
@MechCoder (Member) left a comment

Some minor comments, looks fine otherwise.

.. versionadded:: 0.18
The impurity decrease is the difference in the parent node's impurity

@MechCoder (Member) commented Mar 31, 2017

I would prefer the easier-to-follow definition over here (https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_criterion.pyx#L177).

Also, there seems to be an extra term outside the bracket (N_parent / N_total) from your tests here. (https://github.com/scikit-learn/scikit-learn/pull/8449/files#diff-c3874016cfa1f9bc378d573240ff0502R890)


fractional_node_weight = (
est.tree_.weighted_n_node_samples[node] /
est.tree_.weighted_n_node_samples[0])

@MechCoder (Member) commented Mar 31, 2017

nitpick: Can you replace the denominator by just X.shape[0]?

est.tree_.impurity[node] -
(weighted_n_left * imp_left +
weighted_n_right * imp_right) /
(weighted_n_left + weighted_n_right)))

@MechCoder (Member) commented Mar 31, 2017

It might be simpler to write (N_parent * Imp_parent - N_left * imp_left - N_right * imp_right) / N
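The two forms are algebraically equivalent when N_left + N_right = N_parent; a quick numeric check (the node statistics below are arbitrary, for illustration only):

```python
# Arbitrary node statistics for illustration.
N, N_parent, N_left, N_right = 100.0, 40.0, 25.0, 15.0
imp_parent, imp_left, imp_right = 0.5, 0.3, 0.2

# Form used in the test: fractional parent weight times the local decrease.
form_a = (N_parent / N) * (
    imp_parent - (N_left * imp_left + N_right * imp_right) / (N_left + N_right))

# Simpler form suggested in this comment.
form_b = (N_parent * imp_parent - N_left * imp_left - N_right * imp_right) / N

assert abs(form_a - form_b) < 1e-12
```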

def test_min_impurity_decrease():
    # Test that min_impurity_decrease ensures a split is made only if
    # the impurity decrease is at least that value.
    X, y = datasets.make_classification(n_samples=10000, random_state=42)

@MechCoder (Member) commented Mar 31, 2017

You should test regressors also, no?

@raghavrv (Author, Member) commented Mar 31, 2017

Yes! ALL_TREES[...] contains regressors too... It's just that I use the same classification data to test the regressors as well...
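A sketch of such a loop over both classifiers and regressors (the ALL_TREES mapping below is a local stand-in, not necessarily the exact fixture from the test suite):

```python
from sklearn.datasets import make_classification
from sklearn.tree import (DecisionTreeClassifier, DecisionTreeRegressor,
                          ExtraTreeClassifier, ExtraTreeRegressor)

# Stand-in for the ALL_TREES mapping used in the tests.
ALL_TREES = {
    "DecisionTreeClassifier": DecisionTreeClassifier,
    "DecisionTreeRegressor": DecisionTreeRegressor,
    "ExtraTreeClassifier": ExtraTreeClassifier,
    "ExtraTreeRegressor": ExtraTreeRegressor,
}

X, y = make_classification(n_samples=10000, random_state=42)

for name, TreeEstimator in ALL_TREES.items():
    # The regressors simply treat the 0/1 labels as real-valued targets.
    est = TreeEstimator(min_impurity_decrease=0.1, random_state=0).fit(X, y)
    print(name, est.tree_.node_count)
```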

@@ -446,7 +454,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder):

if not is_leaf:
splitter.node_split(impurity, &split, &n_constant_features)
is_leaf = is_leaf or (split.pos >= end)
is_leaf = (is_leaf or split.pos >= end or
split.improvement + EPSILON < min_impurity_decrease)

@MechCoder (Member) commented Mar 31, 2017

Add this as an inline comment, then.

# Test if min_impurity_split of base estimators is set
# Regression test for #8006
X, y = datasets.make_hastie_10_2(n_samples=100, random_state=1)
all_estimators = [GradientBoostingRegressor,

@MechCoder (Member) commented Mar 31, 2017

You need to test for random forests also?

@raghavrv (Author, Member) commented Mar 31, 2017

Thanks! Done in the latest commit...

@MechCoder (Member) commented Mar 31, 2017

I agree that the behaviour of min_impurity_decrease is much more intuitive than min_impurity_split.

raghavrv added 2 commits Mar 31, 2017
@MechCoder (Member) commented Mar 31, 2017

It's the same expression: yours with the "fractional_weight" and the one documented in the criterion file. It's just that I find the latter easier to read, but it's fine. (I meant that having the extra term is right, and it wasn't reflected in the documentation.)

@MechCoder (Member) commented Mar 31, 2017

LGTM!

@MechCoder MechCoder changed the title [MRG] ENH/FIX Introduce min_impurity_decrease param for early stopping based on impurity; Deprecate min_impurity_split [MRG+1] ENH/FIX Introduce min_impurity_decrease param for early stopping based on impurity; Deprecate min_impurity_split Mar 31, 2017
.. versionadded:: 0.18
The weighted impurity decrease equation is the following:

@glemaitre (Contributor) commented Apr 1, 2017

Are we using the ::math environment in the docstring?

@jmschrei (Member) commented Apr 3, 2017

@raghavrv will the math display correctly from lines 815-816? The `` tag will work properly, but does indenting alone work as intended?

@jmschrei (Member) left a comment

LGTM. If you can address the one typesetting comment I'll go ahead and merge it.

@raghavrv (Member, Author) commented Apr 3, 2017

@jmschrei @glemaitre Thanks for pointing that out! It was not displaying correctly before, but after the latest commit it should look like this:

[screenshot of the rendered weighted impurity decrease formula]
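For reference, the weighted impurity decrease documented for min_impurity_decrease in the merged docstring is:

    N_t / N * (impurity - N_t_R / N_t * right_impurity
                        - N_t_L / N_t * left_impurity)

where N is the total (weighted) number of samples, N_t the (weighted) number of samples at the current node, and N_t_L / N_t_R those in the left and right children.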

@jmschrei jmschrei changed the title [MRG+1] ENH/FIX Introduce min_impurity_decrease param for early stopping based on impurity; Deprecate min_impurity_split [MRG+2] ENH/FIX Introduce min_impurity_decrease param for early stopping based on impurity; Deprecate min_impurity_split Apr 3, 2017
@jmschrei jmschrei merged commit fc2f249 into scikit-learn:master Apr 3, 2017
5 checks passed
ci/circleci: Your tests passed on CircleCI!
codecov/patch: 100% of diff hit (target 95.49%)
codecov/project: 95.5% (+0.01%) compared to 38adb27
continuous-integration/appveyor/pr: AppVeyor build succeeded
continuous-integration/travis-ci/pr: The Travis CI build passed
@raghavrv (Member, Author) commented Apr 3, 2017

Yohoo!! Thanks for the reviews and merge @jmschrei @MechCoder and @glemaitre :)

@raghavrv raghavrv deleted the raghavrv:min_impurity_decrease branch Apr 3, 2017
@glouppe (Member) commented Apr 4, 2017

Nice :)

@amueller (Member) commented Apr 5, 2017

Sweet, thanks!
Can I haz example?

massich added a commit to massich/scikit-learn that referenced this pull request Apr 26, 2017
Sundrique added a commit to Sundrique/scikit-learn that referenced this pull request Jun 14, 2017
NelleV added a commit to NelleV/scikit-learn that referenced this pull request Aug 11, 2017
paulha added a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017
sebp added a commit to sebp/scikit-survival that referenced this pull request Oct 16, 2017
Requires scikit-learn >= 0.19

See scikit-learn/scikit-learn#8449

Fixes #11
sebp added a commit to sebp/scikit-survival that referenced this pull request Oct 30, 2017
maskani-moh added a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017
sebp added a commit to sebp/scikit-survival that referenced this pull request Nov 18, 2017
jwjohnson314 pushed a commit to jwjohnson314/scikit-learn that referenced this pull request Dec 18, 2017