
Conversation

@pprett
Member

@pprett pprett commented Nov 24, 2011

This is a PR for Gradient Boosted Regression Trees [1] (aka Gradient Boosting, MART, TreeNet).

GBRTs have been advertised as one of the best off-the-shelf data-mining procedures; they share many properties with random forests, including little need for tuning and data preprocessing as well as high predictive accuracy. GBRTs have been used very successfully in areas such as learning of ranking functions and ecology.

This should be an alternative to R's 'gbm' package [2]. Currently, it features three loss functions (binary classification, least-squares regression, and robust regression), stochastic gradient boosting, and variable importance.

I've benchmarked the code against R's gbm package (via rpy2) using a variety of datasets (four classification and three regression datasets) - the results are remarkably similar; gbm, however, is usually a bit faster for least-squares regression.

Some features are still on my TODO list:

* Multi-class classification (done thanks to @scottblanc)
* Partial dependence plots
* Quantile loss function for robust regression

I haven't benchmarked it against OpenCV's implementation.

[1] http://en.wikipedia.org/wiki/Gradient_boosting
[2] http://cran.r-project.org/web/packages/gbm/index.html

Member Author

I've included bench_tree.py by accident - I'll kick it out in the next commit

@paolo-losi
Member

Peter,

amazing work! thank you very much for the contribution.

Paolo

On Thu, Nov 24, 2011 at 9:37 AM, Peter Prettenhofer
reply@reply.github.com
wrote:

This is a PR for Gradient Boosted Regression Trees [1](aka Gradient Boosting, MART, TreeNet).

GBRTs have been advertised as one of the best off-the-shelf data-mining procedures; they share many properties with random forests, including little need for tuning and data preprocessing as well as high predictive accuracy. GBRTs have been used very successfully in areas such as learning of ranking functions and ecology.

This should be an alternative to R's 'gbm' package [2]. Currently, it features three loss functions (binary classification, least-squares regression, and robust regression), stochastic gradient boosting, and variable importance.

The PR lacks documentation (there's only a stub in ensemble.rst), I'll update it ASAP.

I've benchmarked the code against R's gbm package (via rpy2) using a variety of datasets (about 4 classification and 3 regression datasets) - the results are remarkably similar; gbm, however, is usually a bit faster (I'll do some further profiling in the coming days).

Some features are still on my TODO list:

   * Multi-class classification
   * Partial dependency plots
   * Quantile loss function for robust regression

I haven't benchmarked it against OpenCV's implementation - that's also on my TODO list.

[1] http://en.wikipedia.org/wiki/Gradient_boosting
[2] http://cran.r-project.org/web/packages/gbm/index.html

You can merge this Pull Request by running:

 git pull https://github.com/pprett/scikit-learn gradient_boosting

Or you can view, comment on it, or merge it online at:

 #448

-- Commit Summary --

  • initial checkin of gradient boosting
  • GBRT benchmark from ELSII Example 10.2
  • added GBRT regressor + classifier classes; added shrinkage
  • use super in DecisionTree subclasses
  • first work on various loss functions for gradient boosting.
  • added store_sample_mask flag to build_tree
  • implemented lad and binomial deviance - still a bug in binomial deviance -> mapping to {-1,1} or {0,1} ?
  • updated benchmark script for gbrt.
  • some debug stmts
  • new benchmarks for gbrt classification
  • fix: MSE criterion was wrong (don't weight variance!)
  • more benchmarks
  • binomial deviance now works!!!!!
  • add gradient boosting to covtype benchmark
  • add documentation to GB
  • timeit stmts in boosting procedure.
  • add previously rm c code
  • updated tree
  • hopefully the last bugfix in MSE
  • new params in gbrt benchmark and comment out debug output
  • make Node an extension type + change class label indexing.
  • predict_proba now returns an array w/ as many cols as classes.
  • cosmit: tidyed up RegressionCriterion
  • added VariableImportance visitor and variable_importance property
  • minor changes to benchmark scripts
  • use np.take if possible, added monitor object to fit method for algorithm introspection.
  • cosmit
  • choose left branch if smaller or equal to threshold; add epsilon to find_larger_than.
  • compiled changes for last commit
  • cosmit
  • some tweaks and debug msg in tree to spot numerical difficulties.
  • added TimSC tree fix
  • changed from node.error to node.initial_error in graphviz exporter
  • recompiled cython code after rebase
  • fix: _tree.Node
  • comment out HuberLoss and comment in benchmarks
  • changed from y in {-1,1} to {0,1}
  • cosmit: beautified RegressionCriterion (sum and sq_sum instead of mean).
  • rename node.sample_mask to node.terminal_region
  • fix: Node.reduce
  • fix init predictor for binomial loss
  • performance enh: update predictions during update_terminal_regions
  • fix: samplemask
  • added timing info
  • use new tree repr; adapt gradient boosting for new tree repr.
  • Merge branch 'master' into gradient_boosting
  • cythonized tree (still broken)
  • clear tree.py
  • updated _tree.c
  • updated GradientBoosting with current master
  • fix: update variable importance
  • added gradient boosting regression example
  • added test deviance to GBRT example
  • updated TODO in module doc
  • Merge branch 'master' into gradient_boosting
  • fix: make GradientBoostingBase clonable.
  • added unit tests for gradient boosting (coverage ~95%)
  • better test coverage
  • store loss object in estimator
  • Merge branch 'master' into gradient_boosting
  • stub for gradient boosting documentation

-- File Changes --

A benchmarks/bench_gbrt.py (258)
M benchmarks/bench_tree.py (198)
M doc/modules/ensemble.rst (22)
A examples/ensemble/plot_gradient_boosting_regression.py (74)
M sklearn/ensemble/__init__.py (3)
A sklearn/ensemble/gradient_boosting.py (541)
A sklearn/ensemble/tests/test_gradient_boosting.py (183)
M sklearn/tree/_tree.c (519)
M sklearn/tree/tree.py (39)

-- Patch Links --

 https://github.com/scikit-learn/scikit-learn/pull/448.patch
 https://github.com/scikit-learn/scikit-learn/pull/448.diff


Reply to this email directly or view it on GitHub:
#448

Member Author

I'd appreciate your feedback on the monitor argument. The main intention is to have a way to introspect the model while it's being trained. The monitor is basically something that can be called with the current state of the model (i.e. the current model at iteration i). Such a feature can be useful for: a) custom termination criteria (e.g. the error/deviance on a held-out set increases or stalls) and b) introspecting the model for model selection.

An alternative to the monitor would be to store the progress of the model in dedicated attributes (such as self.train_deviance) and do introspection / early stopping after the model has been trained. AFAIK that's the way R's gbm package does it.
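A minimal sketch of how such a monitor hook could look (all names here - `fit_stages`, `stalled` - are hypothetical, not the PR's actual API): the boosting loop calls the monitor each iteration, and a truthy return value triggers early stopping.

```python
# Hypothetical sketch of the proposed ``monitor`` hook: a callable invoked
# with the iteration index and the current model state; returning True
# stops training early (e.g. when held-out deviance stalls).

def fit_stages(n_estimators, train_step, monitor=None):
    """Run up to n_estimators boosting iterations.

    train_step(i) performs one iteration and returns the current
    held-out deviance; monitor(i, deviance) may request early stopping.
    """
    history = []
    for i in range(n_estimators):
        deviance = train_step(i)
        history.append(deviance)
        if monitor is not None and monitor(i, deviance):
            break  # custom termination criterion fired
    return history

# Toy usage: deviance decreases then stalls; stop once improvement < 1e-3.
deviances = iter([1.0, 0.6, 0.45, 0.4499, 0.449])
last = [float("inf")]

def stalled(i, dev):
    stop = last[0] - dev < 1e-3
    last[0] = dev
    return stop

history = fit_stages(5, lambda i: next(deviances), monitor=stalled)
print(len(history))  # 4 - the fifth iteration is never run
```

The alternative described below (dedicated attributes like `self.train_deviance`) trades this flexibility for a simpler API, since introspection then happens after `fit` returns.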

Member

We indeed need such a pattern in scikit-learn (I needed something similar to debug convergence issues in minibatch k-means and power iteration clustering, for instance). I will try to give it a deeper look this weekend.

@bdholt1
Member

bdholt1 commented Nov 24, 2011

This is great work @pprett! I look forward to reviewing it in greater detail.

@glouppe
Contributor

glouppe commented Nov 24, 2011

I don't have much time now but I'll review your work as best as I can starting from next week.

Member

This is really useful, and we might consider moving it to the tree.py module. One of my colleagues has been working on autonomous driving, and the trouble he found with the standard mean predictor (bagging) is that it tends to smooth the output and underestimate extremal values. Using a median predictor went a long way towards resolving this.

@ogrisel
Member

ogrisel commented Nov 24, 2011

@pprett can you please either rebase or merge master into this branch so that github marks it as green? Better to review something that does not conflict with the current master.

@pprett
Member Author

pprett commented Nov 24, 2011

@ogrisel updated - two lines of cython code can be quite a mess in the generated c code...

@ALL Thanks for your feedback and review commitments!

@ogrisel concerning the monitor - we should discuss this very openly - the way I implemented it was quite ad hoc... any thoughts/criticism much appreciated

@pprett
Member Author

pprett commented Dec 1, 2011

I did some more benchmarks against R's gbm package. I measured error rate, training time and testing time on four classification datasets and MSE, training time and testing time on four regression datasets.

Classification - decision stumps - 200 iterations
(benchmark chart)

When using decision stumps, the implementation in this PR is competitive with R's gbm for binary classification. Gbm has a hard time on arcene - maybe some numerical issues (it spits out warnings too).

Classification - depth 3 trees - 200 iterations
(benchmark chart)

If we use more complex weak learners, we see that the training time of this PR grows compared to gbm.

Regression - stumps - 100 iterations (normalized such that GBM's score is 1.0)
(benchmark chart)

For regression, gbm is much faster than this PR. The internal data structure used by gbm's tree implementation might be more efficient than our sample-mask approach, but this shouldn't affect decision stumps too much.

Regression - depth 4 trees - 100 iterations (normalized such that GBM's score is 1.0)
(benchmark chart)

The differences in effectiveness are interesting: for stumps the results were pretty much equal, but not for deeper trees - gbm might use some heuristics to control tree growing...

here's the album link:
https://picasaweb.google.com/112318904407774814180/Misc#5681158907337974802

@mblondel
Member

mblondel commented Dec 6, 2011

Thanks for the awesome work Peter! I was not familiar with gradient boosting, so I had a quick look at the Wikipedia page. It reminds me a lot of the matching pursuit family of algorithms (but replacing the residuals with pseudo-residuals so as to support arbitrary loss functions) and of forward stepwise regression algorithms like LARS. Quick questions:

  • In the Wikipedia article, they mention that the step length can be optimized with a line search. If I'm not mistaken, you use a user-defined learning rate in this implementation. Is it correct? I've been looking into line search methods for setting the step length automatically lately but it seems fairly difficult to do so in a generic way (so as to support arbitrary loss functions). Moreover, most methods require smoothness of the function to optimize, thus ruling out the hinge loss.
  • As far as I see, the algorithm is fairly general (can be used with other learners than trees). Would it be difficult to make a base class which is learner agnostic? (__init__(estimator, ...))

@pprett
Member Author

pprett commented Dec 6, 2011

@mblondel : thanks!

@step-length: gradient tree boosting (as introduced by Friedman [1]) involves two "learning rates": the first is for the "functional gradient descent" (i.e. fitting the pseudo-residuals) - this is where the line search is involved; the second is a regularization heuristic proposed by Friedman to tackle overfitting. The learning_rate parameter refers to the latter (Friedman calls it the "learning rate" parameter v in [1], after Eq. (36)). In gbm it is called shrinkage - probably a better name.

In some cases the line search can be done in closed form. In the case of regression trees as base learners, you actually add J separate basis functions in each iteration instead of a single one (where J is the number of leaves), so you have to do J line searches per iteration. This is implemented in the _update_terminal_region methods. For the 'least squares' and 'least absolute deviation' loss functions, the line searches can be computed in closed form. For 'binomial deviance', the line search cannot be computed in closed form - [1] uses the same strategy as LogitBoost and approximates the result with a single Newton-Raphson step.

[1] Friedman, J. H. "Greedy Function Approximation: A Gradient Boosting Machine." (February 1999)
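The per-leaf line search described above can be sketched in a few lines (hypothetical helper names, not the PR's actual `_update_terminal_region` code): for least squares the closed-form leaf value is the mean of the residuals in the region, for least absolute deviation it is the median, and the shrinkage factor then scales every leaf value.

```python
# Minimal sketch of the per-leaf line search: with J leaves, each terminal
# region gets its own closed-form step, all scaled by the shrinkage
# (learning_rate). Helper names here are illustrative only.

from statistics import mean, median

def leaf_value_ls(residuals):
    return mean(residuals)    # closed-form line search for least squares

def leaf_value_lad(residuals):
    return median(residuals)  # closed-form line search for LAD

def update_terminal_regions(regions, learning_rate, leaf_value):
    # regions: one list of residuals per terminal region (J leaves)
    return [learning_rate * leaf_value(r) for r in regions]

regions = [[1.0, 2.0, 9.0], [-1.0, 0.0, 1.0]]
print(update_terminal_regions(regions, 0.1, leaf_value_ls))   # [0.4, 0.0]
print(update_terminal_regions(regions, 0.1, leaf_value_lad))  # [0.2, 0.0]
```

Note how the outlier 9.0 pulls the least-squares leaf value up while the LAD (median) leaf value ignores it - which is exactly why LAD is the robust choice.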

@learner-agnostic: This is correct; on the other hand, trees are AFAIK the most popular base learners, and gradient tree boosting has some distinct characteristics (e.g. it fits J basis functions instead of just one). Furthermore, by assuming trees as base learners we can gain some efficiency because they can share certain data structures (e.g. the sorted features array). R's gbm package is also hard-wired to trees as base learners and is quite efficient because of that. R's mboost package [2] supports both trees and generalized linear models, but I have to admit I don't know how mboost compares to gbm in terms of efficiency.

[2] Bühlmann, P. and Hothorn, T. "Boosting Algorithms: Regularization, Prediction and Model Fitting"

@mblondel
Member

mblondel commented Dec 6, 2011

I hadn't realized that you handle the loss-specific line-search code in the loss functions; I wanted to suggest using the closed-form solution for the squared loss. For other smooth loss functions, the number of Newton-Raphson steps could be an option (defaulting to 1).

Regarding learning_rate vs shrinkage, I guess I prefer the former because we are actually talking about reducing the step size.

@pprett
Member Author

pprett commented Dec 7, 2011

I just realized that sklearn's max_depth parameter and gbm's interaction.depth mean different things - this is most likely the reason why this PR is much slower than gbm for max_depth larger than 1. It does not explain the poor performance for stumps and least-squares regression (3rd plot) - I have to investigate further.

I'll push an update in the coming days which is able to build trees with exactly J terminal regions.

@glouppe
Contributor

glouppe commented Dec 18, 2011

Hi Peter,

I have started reviewing your code. This looks great! To make things easier, I will directly pull my changes to your branch during the sprint.

@pprett
Member Author

pprett commented Dec 19, 2011

Hi Gilles,

thanks!

BTW: I had a hard time figuring out how to build a tree with exactly J leaves (e.g. depth-first search, breadth-first search, some greedy heuristic based on improvement?), so I simply provide a max_depth parameter to control the number of leaves. This is different from the interaction.depth parameter in GBM - it seems that GBM creates one path of depth interaction.depth based on some improvement heuristic, but I cannot find any description in the literature. Any comments, thoughts, or literature pointers on this are much appreciated!
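One leaf-limited alternative that fits the "greedy heuristic based on improvement" idea mentioned above can be sketched as follows (hypothetical names throughout; gbm's actual heuristic is unknown here): grow the tree "best-first" by always expanding the candidate split with the largest improvement, stopping once exactly J leaves exist.

```python
# Hedged sketch of best-first, leaf-limited tree growth: expand the most
# promising candidate split until the leaf budget is exhausted, instead
# of capping depth (which yields up to 2**max_depth leaves).

import heapq

def best_first_leaves(candidates, max_leaves):
    # candidates: (improvement, node_id) pairs for every possible split
    heap = [(-imp, node) for imp, node in candidates]
    heapq.heapify(heap)
    leaves, expanded = 1, []
    while heap and leaves < max_leaves:
        imp, node = heapq.heappop(heap)
        expanded.append(node)
        leaves += 1  # splitting one leaf adds one net leaf
    return expanded

# With a 4-leaf budget, only the three most promising splits are made;
# the weak "R" split (improvement 0.1) never happens.
print(best_first_leaves([(0.9, "root"), (0.5, "L"), (0.1, "R"),
                         (0.3, "LL")], max_leaves=4))
```

In a real implementation the children of each expanded node would be scored and pushed back onto the heap; the sketch just shows the selection order.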

@glouppe
Contributor

glouppe commented Dec 21, 2011

Peter, sorry for replying so late. As you may have seen, I have two pending pull requests concerning the tree and forest modules. I would like to have them merged before helping you with your branch.

I have had the opportunity to read your request though, and I think we should first aim at making a boosting algorithm that is more generic, in the spirit of @jaberg's recent PR, and then optimize for trees - not the other way around. Among other things, I have for instance added in my PRs a trick similar to what you used to avoid the recomputation of X_argsorted and to reuse sample_mask, without having to use the private function _build_tree: namely, I put them as default parameters into the fit method. This, for instance, would help optimize the use of trees without having to touch the inner code of the tree module, which, in my opinion, should not be used by modules other than itself. What I want to tell you is that similar refactorings could be used to treat trees as black boxes as much as possible, without losing much in terms of speed. As a second example, I think it would be better as well if boosting inherited from the BaseEnsemble class and thereby had an API similar to the other ensemble techniques.

Anyway, don't take this the wrong way - that is not my intention. I really want to help you with boosting. This is one of the most important algorithms in machine learning and I really want to have it in the project.

I will address your PR in more details with more direct suggestions at my return in Belgium (I have to catch my plane in a few hours).

@pprett
Member Author

pprett commented Dec 21, 2011

@glouppe no worries - this sounds convincing! I definitely agree that if we have both a generic gradient boosting and a specialized gradient tree boosting, the latter should be a subclass of the former.

@glouppe @jaberg We should definitely combine our efforts - maybe the best starting point is to make a list of requirements: what functionality (e.g. loss functions, parameters) and quality attributes (performance, extensibility, ...) should the component have? I have to admit that I'm a bit biased towards GBM; it's a very well crafted implementation - both flexible (various loss functions for regression and classification) and performant - having something competitive in the scikit would be great!

@mblondel
Member

mblondel commented Jan 9, 2012

I'm interested in seeing this PR merged soon. What remains to be done?

@pprett
Member Author

pprett commented Jan 9, 2012

@mblondel I did some performance enhancements yesterday - basically reducing the difference to R's gbm (now it's competitive for classification but still a bit slower for (least-squares) regression). I started working on the narrative documentation. There are some open issues:

  1. derive GradientBoosting from base ensemble class
    just a minor issue
  2. I've merged my gradient_boosting branch with the closed properties PR and I've to undo these changes
    also just a minor issue
  3. Maybe another example that shows how to do model selection with out-of-bag estimates vs. cross-validation.

I still haven't merged it with James' initial draft for generic functional gradient boosting.
The major reason is that I have no experience with generic functional gradient boosting, and James' draft currently only supports regression. We should definitely compare it to mboost [1] before merging it into master. Furthermore, the gradient tree boosting code has been benchmarked extensively on 5 classification and 4 regression datasets, so it's relatively mature.

[1] Bühlmann, P. and Hothorn, T. "Boosting Algorithms: Regularization, Prediction and Model Fitting"

@ogrisel
Member

ogrisel commented Jan 9, 2012

This looks very promising, thanks for the update. IMO we should target the merge of this PR before the 0.10 release and open a new PR for @jaberg's generic implementation that reuses/refactors @pprett's code, targeting it for 0.11.

@mblondel
Member

mblondel commented Jan 9, 2012

@ogrisel's plan sounds good: let's keep this PR's TODO list reasonable. Thanks @pprett for the update :)

@pprett
Member Author

pprett commented Jan 9, 2012

OK - now here's the updated version. I've added two benchmark scripts, benchmarks/bench_gbrt.py and benchmarks/bench_gbm.py - they are not meant to be merged but allow us to compare this version to R's gbm package. You just need R (with gbm) and rpy2.

Here's a summary of the changes I needed to make in other modules:

  • tree:
    • I moved the feature importances computation into a function - it seems that Breiman and Hastie use slightly different formulas;
    • I've added an additional payload to tree.Tree: terminal_region, which has the same shape as y and stores the node_id of the terminal node of each sample (if there is none, it's -1). This is only stored if store_terminal_region == True.
    • Furthermore, recursive_partition keeps track of the sample indices (= locations in the original y array). This allows setting Tree.terminal_region in the presence of fancy indexing.
  • _tree:
    • Criterion classes now give access to the init values via criterion.init_value(); this spares us the recomputation of the mean (regression) or bincount (classification).

I'll continue with the narrative documentation...

Contributor

Do you have any reference for that method? It seems counter-intuitive to me not to weight by the number of samples in the node. This indeed means that a feature used near the leaves might be as important as a feature used at the root, even if the former is used to separate significantly fewer samples (recall that best_error and init_error are already normalized by the number of samples in the node).

Contributor

Bump.

Member Author

@glouppe I coded up the feature importance procedure from ESLII (2nd edition). Your argument sounds convincing though - I'll look into the formulas more carefully; maybe I missed a normalization step. Could you pass me the reference for the 'gini' method (is it Breiman's random forest paper?)
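A toy sketch of the two variants under discussion (hypothetical names, not the actual sklearn code): an unweighted sum of per-split error improvements versus a Breiman-style version that additionally weights each split by the number of samples it separates.

```python
# Hedged sketch of the importance debate above: since best_error and
# init_error are per-sample-normalized, summing raw improvements lets a
# tiny split near a leaf rival the root split; weighting by n_samples
# restores the root split's dominance.

def importance(splits, weighted):
    # splits: list of (feature, n_samples, init_error, best_error) tuples
    imp = {}
    for feature, n_samples, init_error, best_error in splits:
        gain = init_error - best_error
        if weighted:
            gain *= n_samples  # Breiman-style sample weighting
        imp[feature] = imp.get(feature, 0.0) + gain
    return imp

splits = [("x0", 100, 0.50, 0.30),  # root split, many samples
          ("x1", 5, 0.40, 0.10)]    # split near a leaf, few samples
print(importance(splits, weighted=False))  # x1 looks more important
print(importance(splits, weighted=True))   # x0 dominates
```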

@mblondel
Member

When you get time, could you merge master, @pprett? I wanted to play with this branch but it currently has merge conflicts. (It's low priority, I just wanted to play around :))

Regarding the generic and tree-specialized implementations, instead of using inheritance, could we just enable the tree-optimized routines if isinstance(base_estimator, TreeClassifier)? (Does the specialized implementation need constructor parameters that won't exist in the generic implementation?)

Member

node_id seems to be discarded, so maybe just leave it out?

Member Author

good catch - thx

Member

Could you explain this a bit more? I would have thought the class prior is np.mean(y) (if y is 0 or 1), and if I understand the code correctly, self.prior is what I would call the log-odds or log probability ratio. Do I misunderstand something, or are we just using different words? In the multi-class case it seems to do something different, though. Also, in the multi-class case, wouldn't it be easier to use LabelBinarizer and np.mean?
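For concreteness, here is a pure-Python sketch (hypothetical helper names) of the two quantities being contrasted in this question - the class prior and its log-odds:

```python
# Sketch of the distinction raised above: the class prior is mean(y) for
# y in {0, 1}, while a gbm-style initial prediction is the log-odds of
# that prior. Inverting the logistic link recovers the prior.

from math import exp, log

def prior_probability(y):
    return sum(y) / len(y)

def log_odds_init(y):
    p = prior_probability(y)
    return log(p / (1.0 - p))

y = [0, 0, 0, 1]
p = prior_probability(y)   # 0.25
f0 = log_odds_init(y)      # log(0.25 / 0.75), about -1.0986
print(abs(1.0 / (1.0 + exp(-f0)) - p) < 1e-12)  # True
```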

Member Author

You are correct - I used to use np.mean(y), but I changed to the same init as gbm; I forgot to update the docstring.

Member Author

I'll look into the multi-class init tomorrow; I have to double-check with Scott.
BTW: Thanks for your thorough review - I really appreciate it!

Member

You're very welcome. Well, it got more superficial towards the end as I got a bit tired ;)
Btw, I'm leaving Vienna in ~2 weeks. Wanna go for a drink before?

Oh, and if you like, I'd appreciate it if you looked at the 5 lines of code I deleted here: #707 ;)

Member Author

Absolutely - I'll send you an email immediately.

Regarding the PR - sure thing - I had it on my radar but I've been a bit busy lately.

Member

Is there a particular reason to make this a function instead of a static member variable?

Member Author

Nope - it should be a member (inherited from MulticlassLossFunction?) - I'll look into that tomorrow.

@amueller
Member

Thanks for addressing my comments. Pretty impressive work :)

I have not fully gone through the algorithm but I'm pretty sure it's correct by now ;)
As I said above, it would be cool if you could explain the two ClassPriorPredictors to me.

About the label binarization: I think you cannot avoid a unique, and if the classes are non-consecutive numbers you need to get the inverse indices or something. I think it would be great if the scikit were as tolerant as possible when it comes to class labels. Therefore I think it would make sense to see whether the classes can be binarized, maybe in the fit method.

Thanks for all the good work!
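The relabeling idea discussed above can be sketched in pure Python (hypothetical helper; `np.unique(y, return_inverse=True)` does the same thing vectorized): map arbitrary class labels to consecutive ints 0..K-1 so counts can then be taken with a bincount-style accumulation.

```python
# Sketch of tolerant class-label handling: accept any hashable labels
# (non-consecutive ints, etc.) and encode them as 0..K-1.

def relabel(y):
    classes = sorted(set(y))
    index = {c: i for i, c in enumerate(classes)}
    return classes, [index[label] for label in y]

classes, encoded = relabel([3, 7, 3, -1, 7])
print(classes)  # [-1, 3, 7]
print(encoded)  # [1, 2, 1, 0, 2]
```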

@amueller
Member

Oh, I saw you just edited the docstrings for the ClassPriorPredictors :)
Good night ;)

@pprett
Copy link
Member Author

pprett commented Mar 22, 2012

@amueller regarding ClassPriorPredictor - I renamed it to PriorProbabilityEstimator; since we re-label the classes to consecutive ints at the beginning of GradientBoostingClassifier, it now uses np.bincount.

BTW: I renamed all *Predictor classes to *Estimator.

@amueller
Member

Oh I missed the relabeling of classes last night. np.bincount is probably the best choice then. 👍 for merge from my side!

@pprett
Copy link
Member Author

pprett commented Mar 22, 2012

Here are the new benchmark results: sklearn vs. GBM using decision stumps and 250 base learners.

Classification

(benchmark chart)

Regression

Values are normalized such that GBM is 1.0

(benchmark chart)

Lower is better. Sklearn is competitive with GBM for classification, but for least-squares regression GBM is significantly faster (I don't know why, to be honest). AFAIK GBM does not support multi-class classification. Test times for sklearn are usually better (except for the Boston dataset).

@agramfort
Member

Pardon my stupidity, but how do you explain the perf difference? Is there any implementation diff or just the python overhead?

@mblondel
Member

pardon my stupidity but how do you explain the perf difference? is

Don't be too harsh on yourself Alex ;-)

@agramfort
Member

pardon my stupidity but how do you explain the perf difference? is

Don't be too harsh on yourself Alex ;-)

I am fighting against my ego :)

btw @pprett congrats on the amazing job!

@pprett
Member Author

pprett commented Mar 22, 2012

@agramfort I don't think it's due to any overhead w.r.t. Python itself - I think it's due to a number of subtle differences that "add up". One interesting aspect, though, is the effect of the number of features on performance. The datasets differ considerably w.r.t. the number of features:

Dataset        n_fx   n_train   n_test
Example 10.2     10      2000    10000
Spam             57      1536     3065
Madelon         500      2000      600
Arcene        10000       100      100
Boston           13       455       50
Friedman #1      10       200     1000
Friedman #2       4       200     1000
Friedman #3       4       200     1000

It seems that the more features, the better off we are - e.g. GBM completely fails on Arcene, which has quite a lot of features.

@pprett pprett merged commit 923a759 into scikit-learn:master Mar 22, 2012
@amueller
Member

WOOOT!!! :)

@ndawe
Member

ndawe commented Mar 24, 2012

Great work! Just a minor suggestion: shouldn't the classifier and regressor be named GradientBoostedClassifier and GradientBoostedRegressor instead of GradientBoostingClassifier and GradientBoostingRegressor?
