[MRG+2] Sample weights for gradient boosting #3224

Merged: 19 commits, merged Sep 22, 2014

8 participants
Owner

pprett commented May 30, 2014

and piggy-backed an 'exponential' loss for binary classification (= AdaBoost).

I didn't implement sample weights for 'lad', 'huber', and 'quantile' since they would require a weighted median / weighted percentiles. I will either raise a warning or an exception in this case: which do you prefer?

cc @glouppe

TODO: benchmarks and some regression tests

Owner

ndawe commented May 30, 2014

Well done! You made my day.

Owner

glouppe commented May 31, 2014

Great job Peter! I am quite busy at the moment, but I'll try to find time in the next days for reviewing this.

I will either raise a warning or exception in this case: what do you prefer?

I would go for an exception. Since sample weights are not supported for these losses, it is better in my opinion to raise an exception than to compute something as if everything were OK.

Owner

larsmans commented Jun 1, 2014

Exception for the reason given by Gilles and because adding the weighting later would change results.

Owner

pprett commented Jun 1, 2014

@glouppe @larsmans I now raise a ValueError if sample_weight is given and we use robust regression ('lad', 'huber', 'quantile')

Owner

pprett commented Jun 2, 2014

Here are some benchmarks to test for performance regressions:

Sklearn is a version from Jan/Feb, after Lars and Gilles worked on cache optimization and sorting (I was too lazy to rerun them). Sklearn-sw is this branch. XGBoost and gbm are given as reference. XGBoost uses 4 threads; both sklearn and xgboost build complete trees (xgboost has a slightly different stopping criterion).

Errors (lower is better)

[benchmark plot: benchmark_sw_err -- test error per dataset]

Sklearn versions show little difference

Training time

[benchmark plot: benchmark_sw_train -- training time per dataset]

A slight increase in training time (< 2%). Training times for XGBoost are astonishing -- look at bioresponse (large n_features).

Testing time

[benchmark plot: benchmark_sw_test -- testing time per dataset]

A slight increase in testing time (< 5%).

Owner

glouppe commented Jun 2, 2014

Thanks for the benchmarks! The slight increase is fine with me.

Owner

pprett commented Jun 2, 2014

Here is an accuracy benchmark using random sample weights on some synthetic data -- gbm vs. this pr:

[benchmark plot: benchmark_sw_gbm -- accuracy with random sample weights, gbm vs. this PR]

Results look good (both bernoulli and least-squares).

Owner

arjoly commented Jun 2, 2014

Training times for XGboost are astonishing -- look at bioresponse (large n_features).

Any ideas where the difference comes from?
The default parameter values are defined at https://github.com/tqchen/xgboost/blob/master/booster/tree/xgboost_tree_model.h#L417.

Owner

pprett commented Jun 2, 2014

@arjoly some differences:

  • multi-threaded training (over features in each split point; I used 4 threads)
  • tree-growing stopping criterion based on min improvement (trees don't look too different though)
  • slightly different algorithm / data structures to look for best split (sample-mask?)

@arjoly arjoly and 1 other commented on an outdated diff Jun 2, 2014

sklearn/ensemble/gradient_boosting.py
 class PriorProbabilityEstimator(BaseEstimator):
     """An estimator predicting the probability of each
     class in the training data.
     """
-    def fit(self, X, y):
-        class_counts = np.bincount(y)
-        self.priors = class_counts / float(y.shape[0])
+    def fit(self, X, y, sample_weight=None):
+        if sample_weight is None:
+            class_counts = np.bincount(y)
+            priors = class_counts / float(y.shape[0])
+        else:
+            classes = np.unique(y)
+            n_classes = classes.shape[0]
+            priors = np.zeros(n_classes, dtype=np.float64)
+            for c in classes:
+                mask = y == c
+                priors[c] = np.sum(sample_weight[mask])
+            priors /= priors.sum()
@arjoly

arjoly Jun 2, 2014

Owner

np.bincount supports a weight argument.

@pprett

pprett Jun 2, 2014

Owner

great -- didn't know that -- thanks @arjoly
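For reference, a minimal sketch of the simplified prior computation using the `weights` argument of `np.bincount` (toy data; values are hypothetical):

```python
import numpy as np

# hypothetical toy data: two classes, non-uniform weights
y = np.array([0, 0, 1, 1, 1])
sample_weight = np.array([1.0, 1.0, 2.0, 2.0, 2.0])

# weighted class counts in a single call -- no explicit loop over classes
class_counts = np.bincount(y, weights=sample_weight)
priors = class_counts / class_counts.sum()
print(priors)  # [0.25 0.75]
```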

@arjoly arjoly commented on the diff Jun 2, 2014

sklearn/ensemble/gradient_boosting.py
 @@ -219,14 +248,19 @@ class LeastSquaresError(RegressionLossFunction):
     def init_estimator(self):
         return MeanEstimator()

-    def __call__(self, y, pred):
-        return np.mean((y - pred.ravel()) ** 2.0)
+    def __call__(self, y, pred, sample_weight=None):
+        if sample_weight is None:
+            return np.mean((y - pred.ravel()) ** 2.0)
+        else:
+            return (1.0 / sample_weight.sum()) * \
+                np.sum(sample_weight * ((y - pred.ravel()) ** 2.0))
@arjoly

arjoly Jun 2, 2014

Owner

Here, you can use np.dot.

@arjoly

arjoly Jun 2, 2014

Owner

Or even np.average, which would work for both versions.

np.average(((y_pred - y_true) ** 2), weights=sample_weight)
@larsmans

larsmans Jun 2, 2014

Owner

average is slow:

>>> y = np.random.randn(10000)
>>> pred = np.random.randn(10000)
>>> sample_weight = np.random.randn(10000)
>>> %timeit np.average((y - pred) ** 2, weights=sample_weight)
1000 loops, best of 3: 306 us per loop
>>> %timeit np.sum(sample_weight * ((y - pred) ** 2))
10000 loops, best of 3: 174 us per loop

np.einsum is also an option:

>>> %timeit d = y - pred; np.einsum('i,i,i', sample_weight, d, d)
10000 loops, best of 3: 91.3 us per loop
@pprett

pprett Sep 7, 2014

Owner

Thanks, but the code here is not performance critical and I'd rather go for readability.

@ogrisel

ogrisel Sep 18, 2014

Owner

I find einsum readable enough in this case. But fair enough.

@arjoly arjoly commented on the diff Jun 2, 2014

sklearn/ensemble/gradient_boosting.py
 @@ -244,8 +278,12 @@ class LeastAbsoluteError(RegressionLossFunction):
     def init_estimator(self):
         return QuantileEstimator(alpha=0.5)

-    def __call__(self, y, pred):
-        return np.abs(y - pred.ravel()).mean()
+    def __call__(self, y, pred, sample_weight=None):
+        if sample_weight is None:
+            return np.abs(y - pred.ravel()).mean()
+        else:
+            return (1.0 / sample_weight.sum()) * \
+                np.sum(sample_weight * np.abs(y - pred.ravel()))
@arjoly

arjoly Jun 2, 2014

Owner

Here you could use,

np.average(np.abs(y_pred - y_true), weights=sample_weight)

@arjoly arjoly commented on the diff Jun 2, 2014

sklearn/ensemble/gradient_boosting.py
 @@ -68,8 +68,11 @@ def predict(self, X):
 class MeanEstimator(BaseEstimator):
     """An estimator predicting the mean of the training targets."""

-    def fit(self, X, y):
-        self.mean = np.mean(y)
+    def fit(self, X, y, sample_weight=None):
+        if sample_weight is None:
+            self.mean = np.mean(y)
+        else:
+            self.mean = np.average(y, weights=sample_weight)
@arjoly

arjoly Jun 2, 2014

Owner

self.mean = np.average(y, weights=sample_weight) would work for both.

Owner

arjoly commented Jun 2, 2014

Why not implement a weighted percentile / median?

This corresponds to a few lines of numpy (see http://stackoverflow.com/questions/20601872/numpy-or-scipy-to-calculate-weighted-median and https://github.com/nudomarinero/wquantiles).
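A weighted percentile along the lines of the wquantiles approach does fit in a few lines of numpy. The sketch below is only an illustration of the idea (the function name and conventions are mine, not the PR's final implementation):

```python
import numpy as np

def weighted_percentile(array, sample_weight, percentile=50):
    """Weighted percentile via the cumulative weight distribution (sketch)."""
    sorted_idx = np.argsort(array)
    weight_cdf = np.cumsum(sample_weight[sorted_idx])
    # first index where the cumulative weight reaches the target fraction
    threshold = (percentile / 100.0) * weight_cdf[-1]
    idx = np.searchsorted(weight_cdf, threshold)
    return array[sorted_idx[idx]]

# with unit weights this matches the plain median of [1, 2, 3, 4, 5]
data = np.array([3.0, 1.0, 4.0, 5.0, 2.0])
print(weighted_percentile(data, np.ones(5)))  # 3.0
```

Zero-weight samples are effectively ignored: with weights [0, 0, 1, 1, 1] on the data above, the weighted median is the median of the remaining samples.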

Owner

pprett commented Jun 2, 2014

@arjoly agreed - wquantiles looks good - will do that!

Owner

arjoly commented Jun 2, 2014

Should the negative_gradient take into account sample_weights?

Owner

pprett commented Jun 2, 2014

@arjoly none of the existing loss functions need to do so since we weight the negative gradients when we fit the regression trees
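That division of labor can be sketched as follows -- a toy least-squares boosting stage where the negative gradient ignores the weights and the tree fit applies them (all names here are hypothetical, not the PR's actual code):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

def fit_stage(X, y, y_pred, sample_weight, learning_rate=0.1):
    """One least-squares boosting stage (sketch)."""
    # negative gradient of 0.5 * (y - f)^2 -- no sample_weight involved here
    residual = y - y_pred
    # the weights enter when fitting the regression tree to the residuals
    tree = DecisionTreeRegressor(max_depth=3, random_state=0)
    tree.fit(X, residual, sample_weight=sample_weight)
    return y_pred + learning_rate * tree.predict(X)

rng = np.random.RandomState(0)
X = rng.rand(50, 2)
y = X[:, 0] + 0.1 * rng.rand(50)
y_pred = np.full(50, y.mean())      # init: constant mean prediction
y_pred = fit_stage(X, y, y_pred, np.ones(50))
```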

Coverage Status

Changes Unknown when pulling 4a5437a on pprett:gbrt-sample-weight into scikit-learn:master.

Owner

arjoly commented Jun 2, 2014

@arjoly none of the existing loss functions need to do so since we weight the negative gradients when we fit the regression trees

thanks !

@arjoly arjoly and 1 other commented on an outdated diff Jun 2, 2014

sklearn/ensemble/gradient_boosting.py
 class PriorProbabilityEstimator(BaseEstimator):
     """An estimator predicting the probability of each
     class in the training data.
     """
-    def fit(self, X, y):
-        class_counts = np.bincount(y)
-        self.priors = class_counts / float(y.shape[0])
+    def fit(self, X, y, sample_weight=None):
+        if sample_weight is None:
+            class_counts = np.bincount(y)
+            priors = class_counts / float(y.shape[0])
+        else:
+            weighted_class_counts = np.bincount(y, weights=sample_weight)
+            priors = weighted_class_counts / sample_weight.sum()
+        self.priors = priors
@arjoly

arjoly Jun 2, 2014

Owner

why not

class_counts = np.bincount(y, weights=sample_weight)
self.priors = class_counts / class_counts.sum()

?

Coverage Status

Changes Unknown when pulling f3a3b0f on pprett:gbrt-sample-weight into scikit-learn:master.

Owner

ogrisel commented Jun 9, 2014

@pprett is this WIP or MRG? If WIP, what remains to do before reaching MRG state?

pprett changed the title from Sample weights for gradient boosting to [MRG] Sample weights for gradient boosting Jun 17, 2014

Owner

pprett commented Jun 17, 2014

@ogrisel sorry, forgot about that -- yeah, it's in MRG state.

Owner

arjoly commented Jun 25, 2014

There is an implementation of the weighted median in weight_boosting

    def _get_median_predict(self, X, limit):
        # Evaluate predictions of all estimators
        predictions = np.array([
            est.predict(X) for est in self.estimators_[:limit]]).T

        # Sort the predictions
        sorted_idx = np.argsort(predictions, axis=1)

        # Find index of median prediction for each sample
        weight_cdf = self.estimator_weights_[sorted_idx].cumsum(axis=1)
        median_or_above = weight_cdf >= 0.5 * weight_cdf[:, -1][:, np.newaxis]
        median_idx = median_or_above.argmax(axis=1)

        median_estimators = sorted_idx[np.arange(X.shape[0]), median_idx]

        # Return median predictions
        return predictions[np.arange(X.shape[0]), median_estimators]
Owner

pprett commented Sep 7, 2014

@arjoly I rebased the PR. What do you think is still missing to get this into master? I'd rather postpone sample_weights for quantile and robust regression for now.

Owner

pprett commented Sep 7, 2014

@arjoly but if you feel strongly about it then I highly appreciate a patch ;)

Coverage Status

Coverage increased (+0.0%) when pulling 1813e1e on pprett:gbrt-sample-weight into 57f67d0 on scikit-learn:master.

Owner

arjoly commented Sep 14, 2014

Ok, without quantile and robust regression.

This will move things forward.

@ogrisel ogrisel and 1 other commented on an outdated diff Sep 15, 2014

sklearn/ensemble/gradient_boosting.py
@@ -753,6 +899,21 @@ def fit(self, X, y, monitor=None):
         # Check input
         X, y = check_X_y(X, y, dtype=DTYPE)
         n_samples, n_features = X.shape
+        if sample_weight is None:
+            sample_weight = np.ones(n_samples, dtype=np.float32)
+        else:
+            sample_weight = column_or_1d(sample_weight, warn=True)
+            if self.loss in ('lad', 'huber', 'quantile'):
+                raise ValueError('sample_weight not supported for loss=%r' %
+                                 self.loss)
@ogrisel

ogrisel Sep 15, 2014

Owner

If you plan to add support for sample_weight for those losses at some point in the future, I would rather raise NotImplementedError, to signal that those are not mathematically invalid constructor parameters but could be implemented in a future version of scikit-learn.

@ogrisel ogrisel commented on the diff Sep 15, 2014

sklearn/ensemble/gradient_boosting.py
@@ -753,6 +899,21 @@ def fit(self, X, y, monitor=None):
         # Check input
         X, y = check_X_y(X, y, dtype=DTYPE)
         n_samples, n_features = X.shape
+        if sample_weight is None:
+            sample_weight = np.ones(n_samples, dtype=np.float32)
+        else:
+            sample_weight = column_or_1d(sample_weight, warn=True)
+            if self.loss in ('lad', 'huber', 'quantile'):
+                raise ValueError('sample_weight not supported for loss=%r' %
+                                 self.loss)
+
+        if y.shape[0] != n_samples:
+            raise ValueError('Shape mismatch of X and y: %d != %d' %
+                             (n_samples, y.shape[0]))
+        if n_samples != sample_weight.shape[0]:
+            raise ValueError('Shape mismatch of sample_weight: %d != %d' %
+                             (sample_weight.shape[0], n_samples))
@ogrisel

ogrisel Sep 15, 2014

Owner

Note: at some point we could use sklearn.utils.validation.check_consistent_length, but I'd rather use your manual check in the meantime, as the error messages are more informative.

@pprett

pprett Sep 15, 2014

Owner

Makes sense -- check_consistent_length should then allow for custom messages.

@ogrisel ogrisel and 1 other commented on an outdated diff Sep 15, 2014

doc/modules/ensemble.rst
@@ -649,6 +649,10 @@ the parameter ``loss``:
     prior probability of each class. At each iteration ``n_classes``
     regression trees have to be constructed which makes GBRT rather
     inefficient for data sets with a large number of classes.
+  * Exponential loss (``'exponential'``): The same loss function
+    as :class:`AdaBoostClassifier`. Less robust to mislabeled
+    examples than ``'deviance'``; cannot provide probability
+    estimates and can only be used for binary classification.
@ogrisel

ogrisel Sep 15, 2014

Owner

Is there any reason to use GradientBoostingClassifier(loss='exponential') over AdaBoostClassifier? Or is the goal of introducing this loss to make it easier to do a sanity / correctness check by comparing the two implementations?

If one implementation is always better than the other (e.g. from a CPU or RAM usage perspective), I think this should be made more explicit here.

@pprett

pprett Sep 15, 2014

Owner

The loss was added for completeness -- from an algorithmic point of view I'd suggest that users rather use AdaBoostClassifier; however, I haven't benchmarked the two implementations against each other. I should do that...
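For context, the exponential loss and its negative gradient, with y in {0, 1} mapped to {-1, +1} as in AdaBoost. This is a sketch of the math being discussed, not the PR's exact code:

```python
import numpy as np

def exponential_loss(y, pred, sample_weight=None):
    """Average of exp(-y_signed * pred), optionally weighted (sketch)."""
    y_signed = 2.0 * y - 1.0  # map {0, 1} -> {-1, +1}
    return np.average(np.exp(-y_signed * pred), weights=sample_weight)

def negative_gradient(y, pred):
    """-dL/dpred = y_signed * exp(-y_signed * pred)."""
    y_signed = 2.0 * y - 1.0
    return y_signed * np.exp(-y_signed * pred)

y = np.array([0.0, 1.0, 1.0])
pred = np.zeros(3)
print(exponential_loss(y, pred))  # 1.0 at pred == 0
```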

@ogrisel

ogrisel Sep 15, 2014

Owner

Looking forward to seeing the results ;)

@ogrisel ogrisel commented on an outdated diff Sep 15, 2014

sklearn/ensemble/gradient_boosting.py
@@ -261,7 +291,15 @@ def _update_terminal_region(self, tree, terminal_regions, leaf, X, y,
 class HuberLossFunction(RegressionLossFunction):
-    """Loss function for least absolute deviation (LAD) regression. """
+    """Huber loss function forrobust regression.
@ogrisel

ogrisel Sep 15, 2014

Owner

typo: missing space.

@ogrisel ogrisel and 1 other commented on an outdated diff Sep 15, 2014

sklearn/ensemble/gradient_boosting.py
@@ -348,7 +384,25 @@ def _update_terminal_region(self, tree, terminal_regions, leaf, X, y,
         tree.value[leaf, 0] = val

-class BinomialDeviance(LossFunction):
+class ClassificationLossFunction(six.with_metaclass(ABCMeta, LossFunction)):
+    """Base class for classification loss functions. """
+
+    def _score_to_proba(self, score):
+        """Template method to convert scores to probabilities.
+
+        If the loss does not support probabilities raises AttributeError.
+        """
+        raise AttributeError('Loss does not support predict_proba')
@ogrisel

ogrisel Sep 15, 2014

Owner

Maybe this should be a TypeError and the error message should reflect the name of the concrete class of self.

raise TypeError('%s does not support predict_proba' % type(self).__name__)
@pprett

pprett Sep 15, 2014

Owner

good point

check

@ogrisel ogrisel and 1 other commented on an outdated diff Sep 15, 2014

sklearn/ensemble/gradient_boosting.py
         """Compute the deviance (= 2 * negative log-likelihood). """
         # logaddexp(0, v) == log(1.0 + exp(v))
         pred = pred.ravel()
-        return -2.0 * np.mean((y * pred) - np.logaddexp(0.0, pred))
+        if sample_weight is None:
+            return -2.0 * np.mean((y * pred) - np.logaddexp(0.0, pred))
+        else:
+            return (-2.0 / sample_weight.sum()) * \
+                np.sum(sample_weight * ((y * pred) - np.logaddexp(0.0, pred)))
@ogrisel

ogrisel Sep 15, 2014

Owner

This could be rewritten without the \ by wrapping the whole expression in parentheses, since the grouping of * and / doesn't change the result:

        return (-2.0 / sample_weight.sum() *
                np.sum(sample_weight * ((y * pred) - np.logaddexp(0.0, pred))))

@ogrisel ogrisel and 1 other commented on an outdated diff Sep 15, 2014

sklearn/ensemble/tests/test_gradient_boosting.py
+    gbrt = GradientBoostingClassifier(n_estimators=100, min_samples_split=1,
+                                      max_depth=1, loss=loss,
+                                      learning_rate=1.0, random_state=0)
+    gbrt.fit(X_train, y_train)
+    error_rate = (1.0 - gbrt.score(X_test, y_test))
+    assert error_rate < 0.09, \
+        "GB(loss={}) failed with error {}".format(loss, error_rate)
+
+    gbrt = GradientBoostingClassifier(n_estimators=200, min_samples_split=1,
+                                      max_depth=1,
+                                      learning_rate=1.0, subsample=0.5,
+                                      random_state=0)
+    gbrt.fit(X_train, y_train)
+    error_rate = (1.0 - gbrt.score(X_test, y_test))
+    assert error_rate < 0.08, \
+        "Stochastic GB(loss={} failed with error {}".format(loss,
@ogrisel

ogrisel Sep 15, 2014

Owner

unbalanced parens in error message.

@ogrisel ogrisel and 1 other commented on an outdated diff Sep 15, 2014

...semble/tests/test_gradient_boosting_loss_functions.py
+    pred = rng.rand(100)
+
+    # least squares
+    loss = LeastSquaresError(1)
+    loss_wo_sw = loss(y, pred)
+    loss_w_sw = loss(y, pred, np.ones(pred.shape[0], dtype=np.float32))
+    assert_almost_equal(loss_wo_sw, loss_w_sw)
+
+
+def test_sample_weight_init_estimators():
+    """Smoke test for init estimators with sample weights. """
+    rng = check_random_state(13)
+    X = rng.rand(100, 2)
+    sample_weight = np.ones(100)
+    reg_y = rng.rand(100)
+    #reg_pred = rng.rand(100)
@ogrisel

ogrisel Sep 15, 2014

Owner

please remove this line

@ogrisel ogrisel and 1 other commented on an outdated diff Sep 15, 2014

...semble/tests/test_gradient_boosting_loss_functions.py
+    loss = LeastSquaresError(1)
+    loss_wo_sw = loss(y, pred)
+    loss_w_sw = loss(y, pred, np.ones(pred.shape[0], dtype=np.float32))
+    assert_almost_equal(loss_wo_sw, loss_w_sw)
+
+
+def test_sample_weight_init_estimators():
+    """Smoke test for init estimators with sample weights. """
+    rng = check_random_state(13)
+    X = rng.rand(100, 2)
+    sample_weight = np.ones(100)
+    reg_y = rng.rand(100)
+    #reg_pred = rng.rand(100)
+
+    clf_y = rng.randint(0, 2, size=100)
+    #clf_pred = rng.randint(0, 1, size=100)
@ogrisel

ogrisel Sep 15, 2014

Owner

and this one too.

Owner

ogrisel commented Sep 15, 2014

Other than previous comments, +1 on my side.

Coverage Status

Coverage decreased (-0.0%) when pulling 2439ea8 on pprett:gbrt-sample-weight into 38f5b6c on scikit-learn:master.

Coverage Status

Coverage decreased (-0.01%) when pulling 8c1a95f on pprett:gbrt-sample-weight into 38f5b6c on scikit-learn:master.

Owner

pprett commented Sep 17, 2014

@arjoly @ogrisel added sample weights for robust regression and quantile regression

Coverage Status

Coverage decreased (-0.01%) when pulling 2d50328 on pprett:gbrt-sample-weight into 12f63da on scikit-learn:master.

Owner

pprett commented Sep 17, 2014

@ogrisel here is a quick benchmark with GradientBoostingClassifier vs. AdaBoostClassifier on covertype:

Classifier   train-time   test-time   error-rate
------------------------------------------------
AdaBoost     919.2482s    1.3197s     0.1530
GBRT         423.9343s    0.2883s     0.1561

'GBRT': GradientBoostingClassifier(n_estimators=100, loss='exponential',
                                   max_leaf_nodes=6, learning_rate=1.0,
                                   min_samples_leaf=5, verbose=1),
'AdaBoost': AdaBoostClassifier(n_estimators=100,
                               base_estimator=DecisionTreeClassifier(
                                   min_samples_leaf=5, max_leaf_nodes=6)),
Owner

ogrisel commented Sep 17, 2014

Interesting -- someone in the Higgs Boson kaggle challenge reported that they used AdaBoost because sklearn GBRT was comparatively slower (with another loss, though, obviously).

How do you explain the large difference at test time? The generated trees should have the same size no?

Owner

ogrisel commented Sep 17, 2014

Here is the post I referenced in my previous comment: https://www.kaggle.com/c/higgs-boson/forums/t/10344/winning-methodology-sharing

Owner

pprett commented Sep 17, 2014

GradientBoosting doesn't use the dt.predict method but a C routine that operates on the dt.tree_ extension type directly -- I optimized that once; I think it used to be even faster. It shouldn't be a big difference for large batches, though... it matters a lot if you make predictions on single data points.


Owner

ogrisel commented Sep 17, 2014

Thanks, good to know.

Owner

pprett commented Sep 17, 2014

@ogrisel @arjoly what do you think -- good to go?

@ogrisel ogrisel commented on an outdated diff Sep 18, 2014

sklearn/ensemble/gradient_boosting.py
@@ -50,6 +50,18 @@
from ._gradient_boosting import _random_sample_mask
+def _weighted_percentile(arr, sample_weight, percentile=50):
+ """Compute the weighted ``percentile`` of ``arr`` with ``sample_weight``. """
@ogrisel

ogrisel Sep 18, 2014

Owner

I don't like variable names like arr much. I prefer real words like data or array.

@ogrisel ogrisel commented on the diff Sep 18, 2014

sklearn/ensemble/gradient_boosting.py
@@ -79,12 +97,20 @@ def predict(self, X):
 class LogOddsEstimator(BaseEstimator):
     """An estimator predicting the log odds ratio."""

-    def fit(self, X, y):
-        n_pos = np.sum(y)
-        n_neg = y.shape[0] - n_pos
-        if n_neg == 0 or n_pos == 0:
+    scale = 1.0
+
+    def fit(self, X, y, sample_weight=None):
+        # pre-cond: pos, neg are encoded as 1, 0
+        if sample_weight is None:
+            pos = np.sum(y)
+            neg = y.shape[0] - pos
+        else:
+            pos = np.sum(sample_weight * y)
+            neg = np.sum(sample_weight * (1 - y))
@ogrisel

ogrisel Sep 18, 2014

Owner

np.average might spare a memory allocation. Not a big deal, though.

@ogrisel

ogrisel Sep 18, 2014

Owner

Actually scratch that, we want the sum, not the mean. That would require multiplying by the sum of the weights... Keep the current code that is simpler.

@arjoly

arjoly Sep 18, 2014

Owner

Why not using np.dot?

@pprett

pprett Sep 18, 2014

Owner

Because (a) the performance difference is negligible compared to the other computations, and (b) readability.

@ogrisel ogrisel commented on the diff Sep 18, 2014

sklearn/ensemble/tests/test_gradient_boosting.py
-        assert_raises(ValueError, clf.predict, boston.data)
-        clf.fit(boston.data, boston.target)
-        y_pred = clf.predict(boston.data)
-        mse = mean_squared_error(boston.target, y_pred)
-        assert mse < 6.0, "Failed with loss %s and " \
-            "mse = %.4f" % (loss, mse)
+        assert_raises(ValueError, clf.predict, boston.data)
+        clf.fit(boston.data, boston.target,
+                sample_weight=sample_weight)
+        y_pred = clf.predict(boston.data)
+        mse = mean_squared_error(boston.target, y_pred)
+        assert mse < 6.0, "Failed with loss %s and " \
+            "mse = %.4f" % (loss, mse)
@ogrisel

ogrisel Sep 18, 2014

Owner

Actually it would even be better to check that sample_weight = {None, np.ones(n_samples), 2 * np.ones(n_samples)} leads to the exact same predictions (for a fixed random_state).

@ogrisel

ogrisel Sep 18, 2014

Owner

Also I don't see any test that checks for non-uniform sample weights. Maybe a toy edge case like:

X = [[1, 0],
     [1, 0],
     [1, 0],
     [0, 1]]
y = [0, 0, 1, 0]
# ignore the first 2 training samples by setting their weight to 0
sample_weight = [0, 0, 1, 1]
gb = GradientBoostingClassifier(n_estimators=5)
gb.fit(X, y, sample_weight=sample_weight)
assert_array_equal(gb.predict([[1, 0]]), [1])
@glouppe

glouppe Sep 18, 2014

Owner

Actually it would even be better to check that sample_weight = {None, np.ones(n_samples), 2 * np.ones(n_samples)} leads to the exact same predictions (for a fixed random_state).

+1

@pprett

pprett Sep 18, 2014

Owner

Added 2 * ones and compared all three to make sure we create the same predictions; huber loss failed, so I assume I have an error in the weight updates. I adapted that formula without calculating the update myself -- will do.
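The invariance check discussed above can be sketched as a standalone test (hypothetical test code, not the PR's exact version):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier

X, y = make_classification(n_samples=100, random_state=0)
n_samples = X.shape[0]

# fit with no weights, unit weights, and scaled unit weights
preds = []
for sw in (None, np.ones(n_samples), 2 * np.ones(n_samples)):
    clf = GradientBoostingClassifier(n_estimators=10, random_state=0)
    clf.fit(X, y, sample_weight=sw)
    preds.append(clf.predict(X))

# uniform weights, whatever their scale, must not change predictions
assert np.array_equal(preds[0], preds[1])
assert np.array_equal(preds[1], preds[2])
```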

Owner

ogrisel commented Sep 18, 2014

It would be great to add class_weight=None and support for class_weight="auto" as constructor parameters of GradientBoostingClassifier, and use the same utils as SGDClassifier to generate sample weights related to the inverted class frequencies. If this is implemented, it could be tested by checking that class_weight='auto' yields an improved cross-validated f1_score on a seriously imbalanced dataset vs class_weight=None.

This could be done in a separate PR though.
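The inverse-class-frequency weighting described here can be sketched as follows. The helper name and the n_samples / (n_classes * count) normalization are my assumptions about what class_weight='auto' is meant to do, not code from this PR:

```python
import numpy as np

def inverse_frequency_weights(y):
    """Sample weights inversely proportional to class frequencies,
    using the n_samples / (n_classes * count) convention (assumed)."""
    classes, counts = np.unique(y, return_counts=True)
    per_class = len(y) / (len(classes) * counts.astype(float))
    lookup = dict(zip(classes, per_class))
    return np.array([lookup[label] for label in y])

# imbalanced toy labels: the rare class gets the larger weight
y = np.array([0, 0, 0, 1])
print(inverse_frequency_weights(y))  # class 0 -> 2/3, class 1 -> 2.0
```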

Owner

ogrisel commented Sep 18, 2014

Other than my comment on testing non-uniform sample_weight, +1 for merge on my side.

@glouppe glouppe and 1 other commented on an outdated diff Sep 18, 2014

sklearn/ensemble/gradient_boosting.py
         gamma_mask = np.abs(diff) <= gamma
-        sq_loss = np.sum(0.5 * diff[gamma_mask] ** 2.0)
-        lin_loss = np.sum(gamma * (np.abs(diff[~gamma_mask]) - gamma / 2.0))
-        return (sq_loss + lin_loss) / y.shape[0]
+        if sample_weight is None:
+            sq_loss = np.sum(0.5 * diff[gamma_mask] ** 2.0)
+            lin_loss = np.sum(gamma * (np.abs(diff[~gamma_mask]) - gamma / 2.0))
+        else:
+            sq_loss = np.sum(0.5 * sample_weight[gamma_mask] * diff[gamma_mask] ** 2.0)
+            lin_loss = np.sum(gamma * sample_weight[~gamma_mask] *
+                              (np.abs(diff[~gamma_mask]) - gamma / 2.0))
+        return (sq_loss + lin_loss) / sample_weight.sum()
@glouppe

glouppe Sep 18, 2014

Owner

Unless I am missing something, how can this work when sample_weight=None? The return statement will fail, won't it?

@pprett

pprett Sep 18, 2014

Owner

It does fail... thanks @glouppe -- added a test for all loss functions to make sure that deviance supports both sample weights and None.

Owner

glouppe commented Sep 18, 2014

I have been using this branch for the past few weeks and it has been working well so far.

I am +1 for merge once Olivier's comments are addressed.

pprett added some commits Sep 18, 2014

05fbc64 -- fix: sample_weight is None in Huber.deviance;
           fix: sample_weight multiplication in Huber leaf updates;
           more tests (non-uniform weights, weight consistency and
           invariance to scaling, deviance consistency)
4966b12 -- cosmit: arr -> array

Coverage Status

Coverage increased (+0.01%) when pulling 4966b12 on pprett:gbrt-sample-weight into 12f63da on scikit-learn:master.

Owner

pprett commented Sep 18, 2014

@ogrisel @glouppe thanks for the heads-up -- I fixed two important issues in the Huber loss and added more tests.

Owner

ogrisel commented Sep 18, 2014

I fixed two important issues in the Huber loss and added more tests

Nice to hear that code reviews help kill real bugs :)

ogrisel changed the title from [MRG] Sample weights for gradient boosting to [MRG+2] Sample weights for gradient boosting Sep 18, 2014

@ogrisel ogrisel and 1 other commented on an outdated diff Sep 18, 2014

sklearn/ensemble/tests/test_gradient_boosting.py
+    clf.fit(X, y)
+    assert_array_equal(clf.predict(T), true_result)
+    assert_raises(TypeError, clf.predict_proba, T)
+    assert_raises(TypeError, lambda: next(clf.staged_predict_proba(T)))
+
+
+def test_non_uniform_weights_toy_edge_case():
+    X = [[1, 0],
+         [1, 0],
+         [1, 0],
+         [0, 1]]
+    y = [0, 0, 1, 0]
+    # ignore the first 2 training samples by setting their weight to 0
+    sample_weight = [0, 0, 1, 1]
+    gb = GradientBoostingClassifier(n_estimators=5)
@ogrisel

ogrisel Sep 18, 2014

Owner

Maybe you can loop over all the classification losses. It might also be good to add a similar edge-case test for the regression losses.

Owner

glouppe commented Sep 19, 2014

Anything left to merge this?

Owner

GaelVaroquaux commented Sep 19, 2014

Anything left to merge this?

I don't know, I haven't looked at the code. But it seems that @ogrisel
has given a +1 and you too. So that should tell us to merge the code. Two
+1s from core contributors are clearly a merge to me. Especially since
this is not controversial.

Owner

ogrisel commented Sep 19, 2014

I am still +1 to merge as it is but I would like to give @pprett a chance to tackle my last comment on testing with more losses: #3224 (comment) if he wishes to prior to merging.

Coverage Status

Coverage increased (+0.02%) when pulling e6fa800 on pprett:gbrt-sample-weight into 12f63da on scikit-learn:master.

Owner

pprett commented Sep 21, 2014

@ogrisel added tests and also added probabilistic outputs for exponential loss.

Coverage Status

Coverage increased (+0.01%) when pulling de11662 on pprett:gbrt-sample-weight into 12f63da on scikit-learn:master.

ogrisel added a commit that referenced this pull request Sep 22, 2014

7dec87c -- Merge pull request #3224 from pprett/gbrt-sample-weight
           [MRG+2] Sample weights for gradient boosting

@ogrisel ogrisel merged commit 7dec87c into scikit-learn:master Sep 22, 2014

1 check passed

continuous-integration/travis-ci The Travis CI build passed
Details
Owner

ogrisel commented Sep 22, 2014

Thanks @pprett, merged!

Owner

arjoly commented Sep 22, 2014

Thanks @pprett !!!

Owner

GaelVaroquaux commented Sep 22, 2014

Thanks @pprett !!!

As always, good job team! On the coding and on the review side!

Owner

pprett commented Sep 22, 2014

Woohoo -- thanks guys! @ogrisel @arjoly @glouppe

Owner

glouppe commented Sep 22, 2014

🍻

IssamLaradji added a commit to IssamLaradji/scikit-learn that referenced this pull request Oct 13, 2014

f55c8ca -- Merge pull request #3224 from pprett/gbrt-sample-weight
           [MRG+2] Sample weights for gradient boosting