[MRG+1] feature: add beta-threshold early stopping for decision tree growth #6954

Merged
merged 16 commits into scikit-learn:master on Jul 27, 2016

7 participants
@nelson-liu
Contributor

nelson-liu commented Jul 3, 2016

Reference Issue

Proposed here: #6557 (comment)

What does this implement/fix? Explain your changes.

Implements a stopping criterion for decision tree growth by checking whether the impurity of a node is less than a user-defined threshold, beta. If it is, that node is set as a leaf and no further splits are made on it. Also adds a test.

Any other comments?

I'm not sure if my test is proper. Right now, I create a tree with min_samples_split = 2, min_samples_leaf = 1, and beta undefined (so 0 by default) and fit it on data. I then assert that all the leaves have an impurity of 0, as they should given those values of min_samples_split and min_samples_leaf.

To test beta, I do the same thing (including the above values of min_samples_split and min_samples_leaf), but pass a value for the beta parameter at tree construction and instead check whether the impurity of all the leaves lies within [0, beta).
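A minimal sketch of those two tests (illustrative only: it assumes the iris dataset and this PR's pre-rename `beta` keyword, and is not the PR's actual test code):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data, iris.target

# beta left at its default of 0: fully grown leaves should be pure
est = DecisionTreeClassifier(min_samples_split=2, min_samples_leaf=1,
                             random_state=0).fit(X, y)
leaves = est.tree_.children_left == -1  # -1 (TREE_LEAF) marks leaf nodes
assert np.all(est.tree_.impurity[leaves] == 0.)

# beta > 0: every leaf's impurity should fall in [0, beta)
beta = 0.1  # hypothetical threshold
est = DecisionTreeClassifier(min_samples_split=2, min_samples_leaf=1,
                             beta=beta, random_state=0).fit(X, y)
leaves = est.tree_.children_left == -1
assert np.all(est.tree_.impurity[leaves] >= 0.)
assert np.all(est.tree_.impurity[leaves] < beta)
```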

@nelson-liu


Contributor

nelson-liu commented Jul 3, 2016

@lesteve can you help me clear the Travis cache?

@nelson-liu


Contributor

nelson-liu commented Jul 3, 2016

ping @glouppe @raghavrv @jmschrei / anyone else for reviews? :)

# verify leaf nodes without beta have impurity 0
est = TreeEstimator(max_leaf_nodes=max_leaf_nodes,
                    random_state=0)
est.fit(X, y)


@raghavrv

raghavrv Jul 3, 2016

Member

Maybe test for the expected value of beta (=0 right?)


@nelson-liu

nelson-liu Jul 3, 2016

Contributor

yup, i'll do that.

@raghavrv


Member

raghavrv commented Jul 3, 2016

You don't seem to validate that beta is within the allowed range...

@raghavrv


Member

raghavrv commented Jul 3, 2016

And you should add the beta parameter to the ensemble methods too.

@nelson-liu


Contributor

nelson-liu commented Jul 3, 2016

@raghavrv by "validate if the beta is within the allowed range" do you mean validate that the input is a non-negative float?

@raghavrv


Member

raghavrv commented Jul 3, 2016

the input is a non-negative float?

yes. beta has a range [0, 1]

@raghavrv


Member

raghavrv commented Jul 3, 2016

And thanks for the PR!

@nelson-liu


Contributor

nelson-liu commented Jul 3, 2016

yes. beta has a range [0, 1]

Can't beta be greater than 1, since the possible impurity values can be greater than 1 (in the case of regression)?

@raghavrv


Member

raghavrv commented Jul 3, 2016

Yes. Sorry for not being clear. You'll have to consider classification and regression separately and validate them separately I think...

@raghavrv


Member

raghavrv commented Jul 3, 2016

Also wait, entropy can be greater than 1. Current change is good. Leave it as such.

@raghavrv


Member

raghavrv commented Jul 3, 2016

Argh, sorry: entropy is never greater than one. I was thinking that gini impurity can be greater than one, but since we use the gini coefficient it will also be within the range [0, 1]. So you should indeed special-case classification and validate beta for it separately.

@raghavrv


Member

raghavrv commented Jul 3, 2016

Also, sorry for focusing on trivialities. BTW, don't we need a warm start with this early stopping method? One where you can reset beta and continue splitting the nodes that were not split fully?

@nelson-liu


Contributor

nelson-liu commented Jul 3, 2016

So you should indeed special case classification and validate beta for it separately.

Done, sorry I didn't understand what you meant the first time around. Also added some tests to verify they properly throw errors.
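For reference, a rough sketch of what such per-task validation could look like (the `_check_beta` helper name is hypothetical; the PR inlines these checks in `fit()`):

```python
import numbers

def _check_beta(beta, is_classification):
    # hypothetical helper mirroring the validation discussed above
    if not isinstance(beta, numbers.Real):
        raise ValueError("beta must be a float")
    if is_classification:
        # per the discussion above, classification impurities were assumed
        # to lie in [0, 1] at this point in the review
        if not 0. <= beta <= 1.:
            raise ValueError("beta must be in range [0., 1.] "
                             "for classification")
    elif beta < 0.:
        # regression impurities (e.g. MSE) are unbounded above
        raise ValueError("beta must be >= 0. for regression")
```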

@jmschrei


Member

jmschrei commented Jul 4, 2016

Warm start would be a good addition in combination with this, but I think that should be a separate PR.

@jmschrei


Member

jmschrei commented Jul 4, 2016

I changed my mind: I think this should actually be min_impurity_split. There's no reason to name it a random Greek character when we can explicitly name it what it is.

@@ -805,6 +805,10 @@ class RandomForestClassifier(ForestClassifier):
If None then unlimited number of leaf nodes.
If not None then ``max_depth`` will be ignored.
beta : float, optional (default=0.)
Threshold for early stopping in tree growth. If the impurity
of a node is below the threshold, the node is a leaf.


@jmschrei

jmschrei Jul 4, 2016

Member

Might want to be more explicit here, saying that a node will split if its impurity is above min_impurity_split, and otherwise is a leaf.

@@ -150,6 +153,7 @@ def fit(self, X, y, sample_weight=None, check_input=True,
"""
random_state = check_random_state(self.random_state)
beta = self.beta


@jmschrei

jmschrei Jul 4, 2016

Member

I don't think there is a reason to unpack beta here, but that's a minor style thing

@raghavrv


Member

raghavrv commented Jul 4, 2016

There's no reason to name it a random greek character when we can explicitly name it what it is.

+1 This would align well with our existing stopping criteria params...

        raise ValueError("beta must be a float")
    if is_classification:
        if not 0. <= beta <= 1.:
            raise ValueError("beta must be in range [0., 1.] "


@jmschrei

jmschrei Jul 4, 2016

Member

It is true that classification impurity shouldn't be above 1.0, but entropy has a bound that depends on the number of classes. I can't remember if the values are reweighted to scale to 1.0, though? It might be better to just take in a positive number and let users figure it out.
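A quick numeric check of the bound under discussion, assuming the base-2 entropy that scikit-learn's entropy criterion computes: with k equiprobable classes the entropy is log2(k), so it exceeds 1 once k > 2.

```python
import numpy as np

p = np.full(4, 0.25)             # 4 equiprobable classes
print(-(p * np.log2(p)).sum())   # 2.0, i.e. greater than 1
```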

@@ -35,7 +35,7 @@ ctypedef np.npy_intp SIZE_t
# constant to mark tree leafs
cdef int LEAF = -1
# trivial comment to force recompilation


@raghavrv

raghavrv Jul 4, 2016

Member

(Just adding a review comment so this doesn't get merged with this scaffolding comment left in.)


@nelson-liu

nelson-liu Jul 4, 2016

Contributor

i removed it, thanks 👍

@raghavrv


Member

raghavrv commented Jul 4, 2016

Also please add a whatsnew entry under new features.

@nelson-liu


Contributor

nelson-liu commented Jul 4, 2016

@raghavrv done! appveyor seems to be failing in an odd way so i repushed to trigger another build...

@nelson-liu


Contributor

nelson-liu commented Jul 5, 2016

hmm, seems like appveyor failures are related to #4016

@nelson-liu


Contributor

nelson-liu commented Jul 5, 2016

hmm... i just noticed that the appveyor tests on github redirect to https://ci.appveyor.com/project/agramfort/scikit-learn/build/1.0.276, which is on @agramfort 's account. Is there any reason why we aren't using the sklearn-ci account (it passes tests there)? https://ci.appveyor.com/project/sklearn-ci/scikit-learn/build/1.0.6961

@agramfort


Member

agramfort commented Jul 5, 2016

@raghavrv


Member

raghavrv commented Jul 16, 2016

I feel this is good to go. Thanks. @glouppe a second review and merge?

is_leaf = is_leaf or (impurity <= MIN_IMPURITY_SPLIT)
is_leaf = (is_leaf or
           (impurity <= MIN_IMPURITY_SPLIT) or
           (impurity < min_impurity_split))


@glouppe

glouppe Jul 16, 2016

Member

This is clearly confusing.


@nelson-liu

nelson-liu Jul 17, 2016

Contributor

I agree! I renamed it to LEAF_MIN_IMPURITY, but I think that's also a little bit confusing. Do you have any suggestions for suitable names?
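For reference, a later commit in this PR ("remove constant and set min_impurity_split to 1e-7 by default") drops the constant entirely, so the check presumably collapses to a single comparison; sketched here in Python rather than the builder's Cython:

```python
# single threshold; min_impurity_split defaults to 1e-7
is_leaf = is_leaf or (impurity <= min_impurity_split)
```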

@raghavrv


Member

raghavrv commented Jul 16, 2016

@glouppe @jmschrei BTW do you think warm start would be a good thing to work on next?

@nelson-liu


Contributor

nelson-liu commented Jul 22, 2016

@glouppe is there anything else that needs to be done on this PR?

@glouppe


Member

glouppe commented Jul 27, 2016

LGTM once the defaults are properly changed to 1e-7.

@glouppe


Member

glouppe commented Jul 27, 2016

Thanks Nelson! I'll wait for our friend Travis to arrive, and I'll merge.

@nelson-liu


Contributor

nelson-liu commented Jul 27, 2016

Sorry for that oversight! Not quite sure what I was thinking, changing the docstrings and not the actual code 😝 Thanks again for taking a look @glouppe

@jmschrei


Member

jmschrei commented Jul 27, 2016

👍

@glouppe glouppe merged commit 376aa50 into scikit-learn:master Jul 27, 2016

2 of 4 checks passed

ci/circleci: CircleCI is running your tests
continuous-integration/appveyor/pr: Waiting for AppVeyor build to complete
continuous-integration/travis-ci/pr: The Travis CI build passed
coverage/coveralls: Coverage increased (+0.005%) to 94.493%
@glouppe


Member

glouppe commented Jul 27, 2016

Bim! Happy to see GSoC efforts materialize :)

@@ -805,6 +805,10 @@ class RandomForestClassifier(ForestClassifier):
If None then unlimited number of leaf nodes.
If not None then ``max_depth`` will be ignored.
min_impurity_split : float, optional (default=1e-7)
Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold, otherwise it is a leaf.


@amueller

amueller Jul 27, 2016

Member

versionadded is missing.

@@ -1001,6 +1007,10 @@ class RandomForestRegressor(ForestRegressor):
If None then unlimited number of leaf nodes.
If not None then ``max_depth`` will be ignored.
min_impurity_split : float, optional (default=1e-7)


@amueller

amueller Jul 27, 2016

Member

versionadded is missing.

@@ -1160,6 +1172,10 @@ class ExtraTreesClassifier(ForestClassifier):
If None then unlimited number of leaf nodes.
If not None then ``max_depth`` will be ignored.
min_impurity_split : float, optional (default=1e-7)


@amueller

amueller Jul 27, 2016

Member

versionadded is missing.

@@ -1355,6 +1373,10 @@ class ExtraTreesRegressor(ForestRegressor):
If None then unlimited number of leaf nodes.
If not None then ``max_depth`` will be ignored.
min_impurity_split : float, optional (default=1e-7)


@amueller

amueller Jul 27, 2016

Member

versionadded is missing.

@@ -1500,6 +1524,10 @@ class RandomTreesEmbedding(BaseForest):
If None then unlimited number of leaf nodes.
If not None then ``max_depth`` will be ignored.
min_impurity_split : float, optional (default=1e-7)


@amueller

amueller Jul 27, 2016

Member

versionadded is missing.

@@ -1358,6 +1359,10 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin):
If None then unlimited number of leaf nodes.
If not None then ``max_depth`` will be ignored.
min_impurity_split : float, optional (default=1e-7)


@amueller

amueller Jul 27, 2016

Member

versionadded is missing.

@@ -1711,6 +1718,10 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin):
Best nodes are defined as relative reduction in impurity.
If None then unlimited number of leaf nodes.
min_impurity_split : float, optional (default=1e-7)


@amueller

amueller Jul 27, 2016

Member

versionadded is missing.

@@ -63,7 +64,6 @@ TREE_UNDEFINED = -2
cdef SIZE_t _TREE_LEAF = TREE_LEAF
cdef SIZE_t _TREE_UNDEFINED = TREE_UNDEFINED
cdef SIZE_t INITIAL_STACK_SIZE = 10
cdef DTYPE_t MIN_IMPURITY_SPLIT = 1e-7


@amueller

amueller Jul 27, 2016

Member

Was that not documented beforehand? I feel like that should have been in the docs.


@nelson-liu

nelson-liu Jul 27, 2016

Contributor

nope, it wasn't documented beforehand as far as I saw.

@@ -608,6 +615,10 @@ class DecisionTreeClassifier(BaseDecisionTree, ClassifierMixin):
If None, the random number generator is the RandomState instance used
by `np.random`.
min_impurity_split : float, optional (default=1e-7)


@amueller

amueller Jul 27, 2016

Member

versionadded is missing.

@@ -848,6 +861,10 @@ class DecisionTreeRegressor(BaseDecisionTree, RegressorMixin):
If None, the random number generator is the RandomState instance used
by `np.random`.
min_impurity_split : float, optional (default=1e-7)


@amueller

amueller Jul 27, 2016

Member

versionadded is missing.

@amueller


Member

amueller commented Jul 27, 2016

This is great, thanks! It would be awesome if there were an example, though. And please add the "versionadded" tag to all the docstrings.

@nelson-liu


Contributor

nelson-liu commented Jul 27, 2016

Oops, didn't realize the need for the "versionadded" tags, thanks. What sort of example were you thinking of? An inline one in the docs, or a full-fledged example in the examples/ directory? I'm thinking of adding one to show how changing the value of the parameter affects the number of nodes in the tree; is that what you had in mind?

I'll go ahead and add these in a new PR

@amueller


Member

amueller commented Jul 27, 2016

Maybe an example that discusses the many pre-pruning options and how they change the tree? I think a full-fledged example on pruning would be good, in particular if we get post-pruning at some point.

@nelson-liu


Contributor

nelson-liu commented Jul 27, 2016

@amueller what pre-pruning methods in particular were you thinking about? The ones I'm thinking of are min_impurity_split, max_leaf_nodes, and max_depth?

@amueller


Member

amueller commented Jul 27, 2016

yeah. Maybe also min_samples_leaf?
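A minimal sketch of the kind of example discussed here, comparing tree size across a few thresholds (dataset and threshold values are illustrative):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
for threshold in [1e-7, 0.05, 0.1, 0.2]:
    est = DecisionTreeClassifier(min_impurity_split=threshold,
                                 random_state=0).fit(iris.data, iris.target)
    # higher thresholds stop growth earlier, yielding smaller trees
    print(threshold, est.tree_.node_count)
```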

@nelson-liu


Contributor

nelson-liu commented Aug 6, 2016

@amueller I wrote a preliminary version of what could become an example as a GSoC blog post; could you take a quick look and let me know what you think / what extra content should be added for an example? Link: http://blog.nelsonliu.me/2016/08/06/gsoc-week-10-pr-6954-prepruning-decision-trees/

olologin added a commit to olologin/scikit-learn that referenced this pull request Aug 24, 2016

[MRG+1] feature: add beta-threshold early stopping for decision tree growth (scikit-learn#6954)

* feature: add beta-threshold early stopping for decision tree growth

* check if value of beta is greater than or equal to 0

* test if default value of beta is 0 and edit input validation error message

* feature: separately validate beta for reg. and clf., and add tests for it

* feature: add beta to forest-based ensemble methods

* feature: add separate condition to determine that beta is float

* feature: add beta to gradient boosting estimators

* rename parameter to min_impurity_split, edit input validation and associated tests

* chore: fix spacing in forest and force recompilation of grad boosting extension

* remove trivial comment in grad boost and add whats new

* edit wording in test comment / rebuild

* rename constant with the same name as our parameter

* edit line length for what's new

* remove constant and set min_impurity_split to 1e-7 by default

* fix docstrings for new default

* fix defaults in gradientboosting and forest classes
@arjoly


Member

arjoly commented Sep 15, 2016

Great, thanks @nelson-liu

TomDLT added a commit to TomDLT/scikit-learn that referenced this pull request Oct 3, 2016

[MRG+1] feature: add beta-threshold early stopping for decision tree growth (scikit-learn#6954)