
[MRG+1] feature: add beta-threshold early stopping for decision tree growth #6954

Merged: 16 commits into scikit-learn:master on Jul 27, 2016

Conversation

nelson-liu
Contributor

Reference Issue

Proposed here: #6557 (comment)

What does this implement/fix? Explain your changes.

Implements a stopping criterion for decision tree growth by checking whether the impurity of a node is less than a user-defined threshold, beta. If it is, that node is set as a leaf, and no further splits are made on it. Also adds a test.

Any other comments?

I'm not sure if my test is proper. Right now, I create a tree with min_samples_split = 2, min_samples_leaf = 1, and beta undefined (so 0 by default) and fit it on data. I then assert that all the leaves have an impurity of 0, as they should given those values of min_samples_split and min_samples_leaf.

To test beta, I do the same thing (including using the above values of min_samples_split and min_samples_leaf), but pass a value for the beta parameter at tree construction and instead check that the impurity of all the leaves lies within [0, beta).
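
A minimal sketch of that test logic, mirroring the assertions described above. It uses iris for concreteness and assumes the parameter's final name, min_impurity_split (this PR initially called it beta); the keyword only exists in scikit-learn versions that shipped this PR, before its later deprecation in favor of min_impurity_decrease:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

def leaf_impurities(tree):
    # Leaves are the nodes with no left child in the fitted tree arrays.
    return tree.impurity[tree.children_left == -1]

# Default threshold: min_samples_split=2 and min_samples_leaf=1 grow the
# tree out fully, so every leaf should be pure (impurity == 0).
est = DecisionTreeClassifier(min_samples_split=2, min_samples_leaf=1,
                             random_state=0).fit(X, y)
assert np.all(leaf_impurities(est.tree_) == 0.0)

# Positive threshold: growth stops early, so every leaf's impurity
# should lie in [0, beta).
beta = 0.2
est = DecisionTreeClassifier(min_samples_split=2, min_samples_leaf=1,
                             min_impurity_split=beta,  # "beta" in this PR
                             random_state=0).fit(X, y)
assert np.all(leaf_impurities(est.tree_) < beta)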

@nelson-liu
Contributor Author

@lesteve can you help me clear the Travis cache?

@nelson-liu
Contributor Author

ping @glouppe @raghavrv @jmschrei / anyone else for reviews? :)

# verify leaf nodes without beta have impurity 0
est = TreeEstimator(max_leaf_nodes=max_leaf_nodes,
                    random_state=0)
est.fit(X, y)
Member

Maybe test for the expected value of beta (=0 right?)

Contributor Author

Yup, I'll do that.

@raghavrv
Member

raghavrv commented Jul 3, 2016

You don't seem to validate that beta is within the allowed range...

@raghavrv
Member

raghavrv commented Jul 3, 2016

And you should add the beta parameter to the ensemble methods too.
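
For example, the forest classes forward their tree-growth parameters to each underlying tree, so exposing the threshold there would look roughly like this (a sketch assuming the final parameter name min_impurity_split, which only exists in scikit-learn versions that include this PR):

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True)

# The forest forwards the threshold to every tree it grows.
forest = RandomForestClassifier(n_estimators=10, min_impurity_split=0.2,
                                random_state=0).fit(X, y)
print([est.tree_.node_count for est in forest.estimators_])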

@nelson-liu
Contributor Author

@raghavrv by "validate that beta is within the allowed range" do you mean validate that the input is a non-negative float?

@raghavrv
Member

raghavrv commented Jul 3, 2016

the input is a non-negative float?

yes. beta has a range [0, 1]

@raghavrv
Member

raghavrv commented Jul 3, 2016

And thanks for the PR!

@nelson-liu
Contributor Author

nelson-liu commented Jul 3, 2016

yes. beta has a range [0, 1]

Can't beta be greater than 1, since the possible impurity values can be greater than 1 (in the case of regression)?

@raghavrv
Member

raghavrv commented Jul 3, 2016

Yes. Sorry for not being clear. You'll have to consider classification and regression separately and validate them separately, I think...

@raghavrv
Member

raghavrv commented Jul 3, 2016

Also, wait: entropy can be greater than 1. The current change is good. Leave it as such.

@raghavrv
Member

raghavrv commented Jul 3, 2016

Argh, sorry, entropy is never greater than one. I was thinking that Gini impurity can be greater than one, but since we use the Gini coefficient it will also be within the range [0, 1]. So you should indeed special-case classification and validate beta for it separately.

@raghavrv
Member

raghavrv commented Jul 3, 2016

Also, sorry for focusing on trivialities. BTW, don't we need a warm start with this early stopping method? One where you can reset beta and continue splitting the nodes that were not split fully?

@nelson-liu
Contributor Author

So you should indeed special-case classification and validate beta for it separately.

Done; sorry I didn't understand what you meant the first time around. I also added some tests to verify that they properly throw errors.

@jmschrei
Member

jmschrei commented Jul 4, 2016

Warm start would be a good addition in combination with this, but I think that should be a separate PR.

@jmschrei
Member

jmschrei commented Jul 4, 2016

I changed my mind: I think this should actually be min_impurity_split. There's no reason to name it a random Greek character when we can explicitly name it what it is.

@@ -805,6 +805,10 @@ class RandomForestClassifier(ForestClassifier):
         If None then unlimited number of leaf nodes.
         If not None then ``max_depth`` will be ignored.

+    beta : float, optional (default=0.)
+        Threshold for early stopping in tree growth. If the impurity
+        of a node is below the threshold, the node is a leaf.
Member

Might want to be more explicit here, saying that a node will split if its impurity is above min_impurity_split, and otherwise is a leaf.

@raghavrv
Member

raghavrv commented Jul 4, 2016

There's no reason to name it a random Greek character when we can explicitly name it what it is.

+1 This would align well with our existing stopping criteria params...

    raise ValueError("beta must be a float")
if is_classification:
    if not 0. <= beta <= 1.:
        raise ValueError("beta must be in range [0., 1.] "
Member

It is true that classification shouldn't be above 1.0, but entropy has a stricter bound depending on the number of classes. I can't remember if they are reweighted to scale to 1.0, though. It might be better to just take in a positive number and let users figure it out.
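
For reference on the bound question, a quick check of the criterion maxima under a uniform distribution over k classes, assuming the standard (unrescaled) definitions: Gini stays strictly below 1, while entropy in bits exceeds 1 as soon as there are more than two classes.

import numpy as np

def max_gini(k):
    # Gini impurity 1 - sum_i p_i**2 peaks at the uniform distribution
    # over k classes: 1 - 1/k, always strictly below 1.
    return 1.0 - 1.0 / k

def max_entropy(k):
    # Entropy -sum_i p_i * log2(p_i) also peaks at the uniform
    # distribution, giving log2(k) bits: 1.0 for k=2, above 1 for k > 2.
    return np.log2(k)

for k in (2, 3, 10):
    print(k, max_gini(k), max_entropy(k))
# 2 0.5 1.0
# 3 0.666... 1.584...
# 10 0.9 3.321...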

@raghavrv
Member

raghavrv commented Jul 4, 2016

Also, please add a whatsnew entry under new features.

@nelson-liu
Contributor Author

@raghavrv done! AppVeyor seems to be failing in an odd way, so I repushed to trigger another build...

@nelson-liu
Contributor Author

Hmm, seems like the AppVeyor failures are related to #4016.

@nelson-liu
Contributor Author

Hmm... I just noticed that the AppVeyor tests on GitHub redirect to https://ci.appveyor.com/project/agramfort/scikit-learn/build/1.0.276, which is on @agramfort's account. Is there any reason why we aren't using the sklearn-ci account (the tests pass there)? https://ci.appveyor.com/project/sklearn-ci/scikit-learn/build/1.0.6961

@agramfort
Member

agramfort commented Jul 5, 2016 via email

@raghavrv
Member

I feel this is good to go, thanks. @glouppe, a second review and merge?

-            is_leaf = is_leaf or (impurity <= MIN_IMPURITY_SPLIT)
+            is_leaf = (is_leaf or
+                       (impurity <= MIN_IMPURITY_SPLIT) or
+                       (impurity < min_impurity_split))
Contributor

This is clearly confusing.

Contributor Author

I agree! I renamed it to LEAF_MIN_IMPURITY, but I think that's also a little bit confusing. Do you have any suggestions for suitable names?

@raghavrv
Member

@glouppe @jmschrei BTW, do you think warm start would be a good thing to work on next?

@nelson-liu
Contributor Author

@glouppe is there anything else that needs to be done on this PR?

@@ -1437,7 +1442,7 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin):
     def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100,
                  subsample=1.0, criterion='friedman_mse', min_samples_split=2,
                  min_samples_leaf=1, min_weight_fraction_leaf=0.,
-                 max_depth=3, init=None, random_state=None,
+                 max_depth=3, min_impurity_split=0.,init=None, random_state=None,
Contributor

and here

Contributor

@glouppe glouppe Jul 27, 2016

also, missing space after the comma

@glouppe
Contributor

glouppe commented Jul 27, 2016

LGTM once the defaults are properly changed to 1e-7.

@glouppe
Contributor

glouppe commented Jul 27, 2016

Thanks, Nelson! I'll wait for our friend Travis to arrive, and I'll merge.

@nelson-liu
Contributor Author

Sorry for that oversight! Not quite sure what I was thinking, changing the docstrings and not the actual code 😝 Thanks again for taking a look, @glouppe.

@jmschrei
Member

👍

@glouppe glouppe merged commit 376aa50 into scikit-learn:master Jul 27, 2016
@glouppe
Contributor

glouppe commented Jul 27, 2016

Bim! Happy to see GSoC efforts materialize :)

@@ -805,6 +805,10 @@ class RandomForestClassifier(ForestClassifier):
         If None then unlimited number of leaf nodes.
         If not None then ``max_depth`` will be ignored.

+    min_impurity_split : float, optional (default=1e-7)
+        Threshold for early stopping in tree growth. A node will split
+        if its impurity is above the threshold, otherwise it is a leaf.
Member

versionadded is missing.

@amueller
Member

This is great, thanks! It would be awesome if there were an example, though; also, please add the "versionadded" tag to all the docstrings.

@nelson-liu
Contributor Author

Oops, didn't realize the need for the "versionadded" tags, thanks. What sort of example were you thinking of? An inline one in the docs, or a full-fledged example in the examples/ directory? I'm thinking of adding one to show how changing the value of the parameter affects the number of nodes in the tree; was that what you had in mind?

I'll go ahead and add these in a new PR

@amueller
Member

Maybe an example that discusses the many pre-pruning options and how they change the tree? I think a full-fledged example on pruning would be good, in particular if we get post-pruning at some point.

@nelson-liu
Contributor Author

nelson-liu commented Jul 27, 2016

@amueller what pre-pruning methods in particular were you thinking about? The ones I'm thinking of are min_impurity_split, max_leaf_nodes, and max_depth?

@amueller
Member

Yeah. Maybe also min_samples_leaf?
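
A sketch of the kind of comparison being discussed, showing how each pre-pruning knob changes the node count (parameter values are arbitrary illustrations, and min_impurity_split is only accepted by scikit-learn versions that include this PR):

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Each setting pre-prunes growth along a different axis.
settings = [
    {},                           # no pre-pruning: grow out fully
    {"min_impurity_split": 0.2},  # stop splitting nearly-pure nodes
    {"max_depth": 3},             # cap the depth of the tree
    {"min_samples_leaf": 10},     # require a minimum leaf size
    {"max_leaf_nodes": 8},        # best-first growth with a node budget
]

for params in settings:
    est = DecisionTreeClassifier(random_state=0, **params).fit(X, y)
    print(params, "->", est.tree_.node_count, "nodes")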

@nelson-liu
Contributor Author

@amueller I wrote a preliminary version of what could become an example as a GSoC blog post; could you take a quick look and let me know what you think / what extra content you think should be added for an example? Link: http://blog.nelsonliu.me/2016/08/06/gsoc-week-10-pr-6954-prepruning-decision-trees/

olologin pushed a commit to olologin/scikit-learn that referenced this pull request on Aug 24, 2016: feature: add beta-threshold early stopping for decision tree growth (scikit-learn#6954)

* feature: add beta-threshold early stopping for decision tree growth

* check if value of beta is greater than or equal to 0

* test if default value of beta is 0 and edit input validation error message

* feature: separately validate beta for reg. and clf., and add tests for it

* feature: add beta to forest-based ensemble methods

* feature: add separate condition to determine that beta is float

* feature: add beta to gradient boosting estimators

* rename parameter to min_impurity_split, edit input validation and associated tests

* chore: fix spacing in forest and force recompilation of grad boosting extension

* remove trivial comment in grad boost and add whats new

* edit wording in test comment / rebuild

* rename constant with the same name as our parameter

* edit line length for what's new

* remove constant and set min_impurity_split to 1e-7 by default

* fix docstrings for new default

* fix defaults in gradientboosting and forest classes
@arjoly
Member

arjoly commented Sep 15, 2016

Great, thanks @nelson-liu

TomDLT pushed a commit to TomDLT/scikit-learn that referenced this pull request on Oct 3, 2016: feature: add beta-threshold early stopping for decision tree growth (scikit-learn#6954)