
Addressing issue #1802 (Gradient Boosting Out-of-bag Estimates) #1806

Closed
wants to merge 1 commit into from

5 participants

yanirs Gilles Louppe Andreas Mueller Peter Prettenhofer Olivier Grisel
yanirs

Changed oob_score_ to be calculated based only on trees where the OOB instances weren't used for training.
Not entirely sure about the correctness of doing it this way, but it seems to make the OOB scores more reliable (based on local testing).
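The tree-skipping idea can be sketched with plain NumPy (a toy simulation, not the actual patch: the per-stage tree contributions and subsample masks below are made up):

```python
import numpy as np

rng = np.random.RandomState(0)
n_samples, n_stages, subsample = 100, 50, 0.5

# Stand-ins for real fitted trees: per-stage contributions and in-bag masks.
stage_pred = rng.randn(n_stages, n_samples) * 0.1
in_bag = rng.rand(n_stages, n_samples) < subsample

y_pred = np.zeros(n_samples)      # ordinary cumulative prediction
y_pred_oob = np.zeros(n_samples)  # accumulates a stage only where the sample was out of bag

for i in range(n_stages):
    y_pred += stage_pred[i]
    oob = ~in_bag[i]
    y_pred_oob[oob] += stage_pred[i][oob]

# y_pred_oob[j] now sums only the stages whose bag excluded sample j,
# which is what the patch uses when computing oob_score_.
```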

yanirs yanirs Addressing issue #1802 (Gradient Boosting Out-of-bag Estimates): changed oob_score_ to be calculated based only on trees where the OOB instances weren't used for training
c28aed7
Gilles Louppe
Owner
yanirs

Do you have a better idea?

yanirs yanirs closed this
yanirs yanirs reopened this
Gilles Louppe
Owner

Don't take it the wrong way. Thank you for your contribution!

The thing is, I am not sure we can have non-biased oob estimates at all from boosting algorithms. We should dig in the literature to see whether it can be done.

yanirs

No worries, no offence taken :-)

I completely agree that if there is a better way to generate OOB estimates, it should be used. I couldn't find anything in the literature, but admittedly I mostly skimmed the relevant (major) publications.
In any case, with the current implementation the OOB score seems to improve indefinitely (like the in-bag score), at least in my limited testing. So it's not very useful as a way of tuning n_estimators.

Andreas Mueller
Owner

I think the current implementation should definitely be removed. There are no "out of bag" estimates here as there are no bags imho. But maybe there is a way to get some estimate via the subsampling.

Gilles Louppe
Owner

I think the current implementation should definitely be removed. There are no "out of bag" estimates here as there are no bags imho. But maybe there is a way to get some estimate via the subsampling.

Agree.

Peter Prettenhofer
Owner

@amueller I disagree, there are "bags" - still, the oob estimates might be fundamentally flawed due to the fact that we already peeked at the out-of-bag samples (in our earlier iterations).

thinking aloud: I totally see a problem with using OOB estimates as a proxy for generalization error - but they might be interesting for model selection -- in particular, for selecting the number of boosting iterations or for early stopping.

as @yanirs already said, there is no account of out-of-bag estimates in the primary literature, but it's featured in R's gbm and ada vignettes.

Peter Prettenhofer
Owner

even Wikipedia discusses OOB for gradient boosting; it references the Stochastic Gradient Boosting paper by Friedman, which doesn't discuss OOB as far as I know.

http://en.wikipedia.org/wiki/Gradient_boosting

Andreas Mueller
Owner

@pprett Maybe I should have said "gradient boosting is not bagging".

Bagging = bootstrap aggregation, i.e. averaging over bootstrap samples.

If you don't do bagging, I don't think it is appropriate to talk about bags ;)

Andreas Mueller
Owner

(apparently the author of gbm disagrees with me on the last point and calls the subsample drawn in each round a bag).

Andreas Mueller
Owner

So how is it implemented in gbm? Just using the trees that don't use the sample?

Peter Prettenhofer
Owner

in gbm, the oob estimate of the i-th iteration is just the loss (= deviance) of the out-of-bag samples of the current (the i-th) tree. this is the same as what we have in sklearn.
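In code, this convention looks roughly as follows (a sketch with a squared-error loss; the "fitted tree" is replaced by a crude stand-in that just shrinks the residual):

```python
import numpy as np

rng = np.random.RandomState(1)
n_samples, n_stages = 200, 30
y = rng.randn(n_samples)

y_pred = np.zeros(n_samples)
oob_score = np.empty(n_stages)

for i in range(n_stages):
    in_bag = rng.rand(n_samples) < 0.5   # this stage's "bag"
    stage_pred = 0.1 * (y - y_pred)      # stand-in for a tree fit on the in-bag half
    y_pred += stage_pred
    # OOB score of iteration i: loss of the *cumulative* model on the samples
    # left out of this stage only -- earlier stages may well have trained on
    # them, which is the "peeking" discussed above.
    oob_score[i] = np.mean((y[~in_bag] - y_pred[~in_bag]) ** 2)
```

Because every stage shrinks the residual on all samples, this score keeps improving with the training score, illustrating why it behaves like an in-bag estimate.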

Andreas Mueller
Owner

That sounds like a horrible idea. Interesting.

Gilles Louppe
Owner

@pprett @amueller Are we keeping things as they are (and as they also are in gbm) or shall we do anything?

My opinion would be to keep things as they are now and close this issue.

Peter Prettenhofer
Owner
Gilles Louppe glouppe closed this
yanirs

Seriously? After agreeing that the current OOB implementation is broken you're just going to leave it as it is? At the very least, it's worth noting this issue in the documentation to avoid misleading users. I spent quite a bit of time digging through the code to figure out why the OOB scores never stop improving, and I would like other users to avoid going down the same path.
If I had time, I would run some more experiments to demonstrate that my solution works (I only tested it on one dataset: https://www.kaggle.com/c/bluebook-for-bulldozers/forums/t/4368/congratulations-to-the-preliminary-winners/23125#post23125), which I think is a more reasonable way forward than copying the bugs in R's gbm implementation.

Gilles Louppe glouppe reopened this
Gilles Louppe
Owner

Skipping trees still sounds like a terrible idea to me, in the sense that I really don't understand what your "oob estimates" actually represent. Even if it is biased, the current oob implementation at least makes sense.

But okay, let's reopen this and continue the discussion.

Peter Prettenhofer
Owner

I think the best way to move forward is to compile a test suite with benchmark datasets (e.g. hastie 10_2, spam, cal_housing) and compare the OOB estimator for n_estimators against a CV estimator.

If I understand @yanirs correctly, the issue is that the OOB score always improves and thus over-estimates the optimal number of boosting iterations. G. Ridgeway actually claims the opposite: "OOB underestimates the optimal number of iterations".

Gilles Louppe
Owner

In boosting, the only reliable way I see to have a non-biased approximation of the generalization error is to put aside a validation set and evaluate the model on it after training.

For a given sample x_i, it makes no sense in my opinion to only consider those trees that have not been trained on x_i. The prediction is a sequential process, not a vote. I don't understand why you could skip steps in that process.
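The validation-set approach is straightforward with `staged_predict` (a sketch; the dataset and sizes here are arbitrary, not from the discussion):

```python
import numpy as np
from sklearn.datasets import make_friedman1
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import train_test_split

X, y = make_friedman1(n_samples=1200, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.25,
                                                  random_state=0)

est = GradientBoostingRegressor(n_estimators=200, subsample=0.5, random_state=0)
est.fit(X_train, y_train)

# Held-out deviance after each boosting iteration; its argmin is the
# validation-set pick for n_estimators.
val_score = [np.mean((y_val - pred) ** 2) for pred in est.staged_predict(X_val)]
best_n = int(np.argmin(val_score)) + 1
```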

yanirs

Yes @pprett, from my testing oob_score_ behaves like train_score_, and having oob_score_ is misleading because it's not really oob.

I agree with @glouppe that the best approach is to use an external validation set, but this is not always practical. Without further testing, I'm not confident that the approach of skipping trees is always correct, but it makes more sense to me than the current oob_score_ implementation. So I think the minimal solution is to note this issue in the documentation until further testing is performed (I may end up running more experiments, but I'm flooded with work at the moment).

Gilles Louppe
Owner

So I think that the minimal solution is to note this issue in the documentation

I agree. We should indeed at least say that it is heavily biased, nearly as much as the training error.

Peter Prettenhofer
Owner

I had a closer look at the issue - in particular, the GBM implementation of OOB estimates to compute the "optimal" number of iterations.

GBM does not record the OOB score per se but the improvement in deviance on the OOB samples from adding the predictions of the current tree to the predictions we have so far.
To estimate the optimal number of iterations, GBM smooths the improvement scores using LOESS, computes the cumulative sum, and picks the max.**

The relevant functions are src/distribution.h:BagImprovement and gbm.perf.R:gbm.perf#20 .

** GBM actually picks the min of the negative cumsum (see code referenced above)
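GBM's selection rule can be sketched like this, with a plain moving average standing in for LOESS and made-up per-iteration improvement scores (this is not GBM's actual code):

```python
import numpy as np

rng = np.random.RandomState(2)
n_iter = 100
# Made-up per-iteration OOB improvements: large and positive early,
# drifting slightly negative late, plus noise.
improvement = 1.0 / (1 + np.arange(n_iter)) - 0.02 + 0.05 * rng.randn(n_iter)

# Smooth the noisy improvements (GBM uses LOESS; a moving average stands in) ...
k = 5
smoothed = np.convolve(improvement, np.ones(k) / k, mode="same")

# ... then take the cumulative sum and pick its maximum (equivalently,
# the min of the negative cumsum, as GBM does).
best_iter = int(np.argmax(np.cumsum(smoothed))) + 1
```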

Andreas Mueller
Owner

We should definitely warn. Maybe we should even rename? This really is no oob estimate (with either "method").

I have to read up on what Peter said to understand what is going on in gbm.

Peter Prettenhofer
Owner

@amueller I'll work on this in the next few days

Peter Prettenhofer
Owner

Pull request submitted for OOB improvements - seems to work for selecting n_estimators #2188 .

I'll close this one as I think using the OOB improvements is the correct way to do it (re-open if you disagree)
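With the `oob_improvement_` attribute that #2188 introduced, selecting `n_estimators` follows the same cumulative-sum recipe. A usage sketch on a toy dataset (sizes are arbitrary; `subsample < 1.0` is required for `oob_improvement_` to be populated):

```python
import numpy as np
from sklearn.datasets import make_hastie_10_2
from sklearn.ensemble import GradientBoostingClassifier

X, y = make_hastie_10_2(n_samples=2000, random_state=0)
est = GradientBoostingClassifier(n_estimators=100, subsample=0.5, random_state=0)
est.fit(X, y)

# oob_improvement_[i] is the improvement in OOB loss from adding stage i;
# the argmax of its cumulative sum is the OOB-based pick for n_estimators.
cum_improvement = np.cumsum(est.oob_improvement_)
oob_best = int(np.argmax(cum_improvement)) + 1
```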

Peter Prettenhofer pprett closed this
yanirs

I finally got around to testing my tree-skipping approach on multiple datasets: http://yanirseroussi.com/2014/12/29/stochastic-gradient-boosting-choosing-the-best-number-of-iterations/

I agree with the decision not to merge the original pull request, but I was thinking that it may be useful to others if I contributed my experimental code to extend the example in http://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_oob.html to real datasets. Thoughts?

Peter Prettenhofer
Owner
yanirs

Thanks Peter! Good point about adding a notebook to the post. Notebooks do tend to be more readable than plain Python files.

Olivier Grisel
Owner

Your blog engine might also have a plugin to directly turn notebooks into blog posts:

Or you can use nbconvert to script the conversion:

Commits on Mar 24, 2013
  1. yanirs authored: Addressing issue #1802 (Gradient Boosting Out-of-bag Estimates): changed oob_score_ to be calculated based only on trees where the OOB instances weren't used for training
Showing 15 additions and 10 deletions in sklearn/ensemble/gradient_boosting.py.
@@ -138,7 +138,7 @@ def negative_gradient(self, y, y_pred, **kargs):
The predictions.
"""
- def update_terminal_regions(self, tree, X, y, residual, y_pred,
+ def update_terminal_regions(self, tree, X, y, residual, y_pred, y_pred_oob,
sample_mask, learning_rate=1.0, k=0):
"""Update the terminal regions (=leaves) of the given tree and
updates the current predictions of the model. Traverses tree
@@ -171,8 +171,9 @@ def update_terminal_regions(self, tree, X, y, residual, y_pred,
y_pred[:, k])
# update predictions (both in-bag and out-of-bag)
- y_pred[:, k] += (learning_rate
- * tree.value[:, 0, 0].take(terminal_regions, axis=0))
+ tree_pred = tree.value[:, 0, 0].take(terminal_regions, axis=0)
+ y_pred[:, k] += learning_rate * tree_pred
+ y_pred_oob[~sample_mask, k] += learning_rate * tree_pred[~sample_mask]
@abstractmethod
def _update_terminal_region(self, tree, terminal_regions, leaf, X, y,
@@ -202,14 +203,16 @@ def __call__(self, y, pred):
def negative_gradient(self, y, pred, **kargs):
return y - pred.ravel()
- def update_terminal_regions(self, tree, X, y, residual, y_pred,
+ def update_terminal_regions(self, tree, X, y, residual, y_pred, y_pred_oob,
sample_mask, learning_rate=1.0, k=0):
"""Least squares does not need to update terminal regions.
But it has to update the predictions.
"""
# update predictions
- y_pred[:, k] += learning_rate * tree.predict(X).ravel()
+ tree_pred = tree.predict(X).ravel()
+ y_pred[:, k] += learning_rate * tree_pred
+ y_pred_oob[~sample_mask, k] += learning_rate * tree_pred[~sample_mask]
def _update_terminal_region(self, tree, terminal_regions, leaf, X, y,
residual, pred):
@@ -454,7 +457,7 @@ def __init__(self, loss, learning_rate, n_estimators, min_samples_split,
self.verbose = verbose
self.estimators_ = np.empty((0, 0), dtype=np.object)
- def _fit_stage(self, i, X, X_argsorted, y, y_pred, sample_mask,
+ def _fit_stage(self, i, X, X_argsorted, y, y_pred, y_pred_oob, sample_mask,
random_state):
"""Fit another stage of ``n_classes_`` trees to the boosting model. """
loss = self.loss_
@@ -480,7 +483,8 @@ def _fit_stage(self, i, X, X_argsorted, y, y_pred, sample_mask,
# update tree leaves
loss.update_terminal_regions(tree.tree_, X, y, residual, y_pred,
- sample_mask, self.learning_rate, k=k)
+ y_pred_oob, sample_mask,
+ self.learning_rate, k=k)
# add tree to ensemble
self.estimators_[i, k] = tree
@@ -566,6 +570,7 @@ def fit(self, X, y):
# init predictions
y_pred = self.init_.predict(X)
+ y_pred_oob = self.init_.predict(X)
self.estimators_ = np.empty((self.n_estimators, self.loss_.K),
dtype=np.object)
@@ -584,15 +589,15 @@ def fit(self, X, y):
sample_mask = _random_sample_mask(n_samples, n_inbag,
random_state)
# fit next stage of trees
- y_pred = self._fit_stage(i, X, X_argsorted, y, y_pred, sample_mask,
- random_state)
+ y_pred = self._fit_stage(i, X, X_argsorted, y, y_pred, y_pred_oob,
+ sample_mask, random_state)
# track deviance (= loss)
if self.subsample < 1.0:
self.train_score_[i] = self.loss_(y[sample_mask],
y_pred[sample_mask])
self.oob_score_[i] = self.loss_(y[~sample_mask],
- y_pred[~sample_mask])
+ y_pred_oob[~sample_mask])
if self.verbose > 1:
print("built tree %d of %d, train score = %.6e, "
"oob score = %.6e" % (i + 1, self.n_estimators,