[MRG+2] Add class_weight support to the forests & trees #3961

Merged
merged 10 commits into from Jan 13, 2015


5 participants

@trevorstephens
Contributor

This PR adds support for specifying class_weight in the forest classifier constructors.

  • Right now it only supports single-output classification problems. It would be possible, however, to accept a list of dicts (one per target), or just the preset string 'auto' overall, and multiply the expanded weight vectors so that a sample carrying two minority classes becomes even more important, etc. I couldn't find any precedent in other classifiers to guide this.
  • Should an exception, warning or tantrum be thrown for an 'auto' class_weight used with the warm_start option?

Any other comments ?

@coveralls

Coverage Status

Coverage increased (+0.01%) when pulling de35bea on trevorstephens:rf-class_weight into 6dab7c5 on scikit-learn:master.

@trevorstephens
Contributor

Referencing #2129 and #64

@trevorstephens trevorstephens commented on an outdated diff Dec 13, 2014
sklearn/ensemble/forest.py
@@ -388,7 +399,19 @@ def _validate_y(self, y):
self.classes_.append(classes_k)
self.n_classes_.append(classes_k.shape[0])
- return y
+ if self.class_weight is not None:
+ if self.n_outputs_ == 1:
+ cw = compute_class_weight(self.class_weight,
+ self.classes_[0],
+ y_org[:, 0])
+ cw = cw[np.searchsorted(self.classes_[0], y_org[:, 0])]
+ else:
+ raise NotImplementedError('class_weights are not supported '
+ 'for multi-output. You may use '
+ 'sample_weights in the fit method '
+ 'to weight by sample.')
+
@trevorstephens
trevorstephens Dec 13, 2014 Contributor

Further to my comment about multiplying the output's weights for multi-output, this is what I was thinking for replacing #L403-412:

cw = []
for k in range(self.n_outputs_):
    cw_part = compute_class_weight(self.class_weight,
                                   self.classes_[k],
                                   y_org[:, k])
    cw_part = cw_part[np.searchsorted(self.classes_[k],
                                      y_org[:, k])]
    cw.append(cw_part)
cw = np.prod(cw, axis=0, dtype=np.float64)

Another option would be to perform this action at the bootstrap level for 'auto' class_weights in the _parallel_build_trees method. However, this would make 'auto' act differently to user-defined class_weights, as the user-defined dict would be applied the same way regardless of the bootstrap sample.

@trevorstephens
Contributor

Added support for multi-output through the presets, or through a list of dicts for user-defined weights, i.e. [{-1: 0.5, 1: 1.}, {0: 1., 1: 1., 2: 2.}].

In deciding whether to implement the "auto" weights at the full dataset level or in the bootstrapping, I added another option: the user can pass class_weight="bootstrap", which does the weighting based on the make-up of each bootstrap sample, while the "auto" option bases weights on the full dataset (and thus saves a bit of time by only doing it once).

These presets, as with the user-defined dicts, are multiplied together for multi-output in this implementation.
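The multi-output multiplication described above can be sketched in plain NumPy (an illustrative toy, not the PR's actual code — `y` and the weight dicts here are made up):

```python
import numpy as np

# Hypothetical two-output target and user-supplied per-output weight dicts
y = np.array([[-1, 0],
              [ 1, 2],
              [-1, 2],
              [ 1, 1]])
class_weight = [{-1: 0.5, 1: 1.0}, {0: 1.0, 1: 1.0, 2: 2.0}]

# Expand each dict to a per-sample weight vector, then multiply across outputs,
# so a sample in the minority class of both outputs is up-weighted twice over
expanded = np.ones(len(y))
for k, cw in enumerate(class_weight):
    expanded *= np.array([cw[label] for label in y[:, k]])

print(expanded)  # [0.5 2.  1.  1. ]
```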

@trevorstephens
Contributor

Anyone have any thoughts on this so far?

@amueller amueller commented on an outdated diff Dec 17, 2014
sklearn/ensemble/forest.py
@@ -89,6 +87,27 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
sample_counts = np.bincount(indices, minlength=n_samples)
curr_sample_weight *= sample_counts
+ if class_weight == 'bootstrap':
@amueller
amueller Dec 17, 2014 Member

Shouldn't this be implemented in the tree? If that gets the auto class weight it can compute it itself, right?

@amueller
Member

Thanks for your contribution.
I think as much as possible should be implemented in the DecisionTreeClassifier/Regressor. I think the only case where you need to do something in the forest is 'auto' where you want to use the whole training set, right?

@trevorstephens
Contributor

I will check it out, but I think that the tree will not be able to see the bootstrap sample directly: the over/under-sampling is done by adjusting each sample's count via its weight (+2, +1, 0, etc.), which is then multiplied by any sample_weight passed to fit, and that could mask some of the counts if computed at the tree level. The entire, un-bootstrapped y array (transformed away from the original labels) is still passed to the tree.

The "auto" and user-defined weights could be done in the tree perhaps, but that would require re-computing the same weights for every tree, which seems like redundant work.

I'll undertake a bit of an investigation all the same.
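The masking concern above can be sketched as follows (a simplified re-creation of how `_parallel_build_trees` encodes the bootstrap sample as weights; variable names follow the diff, the rest is illustrative):

```python
import numpy as np

rng = np.random.RandomState(0)
n_samples = 8

# Bootstrap draw with replacement: each sample's draw count becomes its weight
indices = rng.randint(0, n_samples, n_samples)
sample_counts = np.bincount(indices, minlength=n_samples)  # 0 = out of bag

curr_sample_weight = np.ones(n_samples)
curr_sample_weight *= sample_counts  # any user sample_weight is folded in too

# The tree receives the full y plus these weights, so from inside the tree a
# user-supplied weight of 0 is indistinguishable from an out-of-bag sample.
```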

@glouppe
Member
glouppe commented Dec 19, 2014

Hi Trevor! First of all, sorry for the lack of feedback on our side. It seems like all tree-huggers are quite busy with other things.

Regarding the implementation of class weights, there are basically two different approaches, as I explained some time ago in #2129. The approach that you chose, exploiting sample weights only, or the one using class weights as priors in the computation of the impurity criteria. It should be checked but I am not sure both are exactly equivalent. Anyhow, the sample-weight based implementation is certainly the simplest to implement. So for the record, I am +1 for implementing this algorithm rather than modifying all criteria.

@glouppe
Member
glouppe commented Dec 19, 2014

Thanks for your contribution.
I think as much as possible should be implemented in the DecisionTreeClassifier/Regressor. I think the only case where you need to do something in the forest is 'auto' where you want to use the whole training set, right?

I agree. In general, whatever option we have in forests, we also have in single trees. This should be the case as well for this feature I believe.

Note that from an implementation point of view, redundant work can be avoided easily. E.g. by forcing from the forest class_weight=None in the trees that are built but passing the correctly balanced sample weights.

@trevorstephens
Contributor

Hey Gilles, thanks for the input, and no worries on delays, understand that everyone here is pretty heavily loaded.

I'm going to take a look through the tree code this weekend and see how class_weight could be implemented there as well, and how/if it can interact nicely with the forests.

I think we're on the same page, but the redundant work I mentioned was just that the class_weight computed across all samples (the 'auto' and user-defined modes) will be the same for every tree; throwing class_weight at the tree would repeat the same calculation for every estimator in the ensemble, which seems like unnecessary overhead, though it could be a 'nicer' paradigm perhaps. I certainly agree that class_weight should be implemented for DecisionTreeClassifier though, whether or not it is used by the forests.

Anyhow, I'll probably check back in after my lovely 20-hour commute back home to Australia this evening :)

@trevorstephens trevorstephens changed the title from [WIP] Add support for class_weight to the forests to [WIP] Add class_weight support to the forests & trees Dec 23, 2014
@trevorstephens
Contributor

Added class_weight, using much the same code, to DecisionTreeClassifier except without the 'bootstrap' option. In terms of farming out the calculation of the class_weight to the trees though, I see a few issues:

  • User-defined weights are represented as a dict of labels and weights, these labels change as they are passed into the tree estimators (to 0, 1, 2, ...) so the dict would have to be transformed to match the new labels if it is to be understood by the individual tree estimators. Additionally, the calculation is the same for every estimator, so performing once at the ensemble level saves a touch of overhead in building each tree.
  • 'auto' weights could be passed to the trees, but as with the user-defined weights, these would be the same for each tree, and so doing it at the ensemble level saves doing the same calculation multiple times.
  • 'bootstrap' weights are done for each tree, but the bootstrap sample is represented as weights on the original y(s), which are multiplied by any sample_weight passed to the fit method. Thus, it could be hard to unravel what the bootstrap indices are once in the tree.

So I think keeping the class_weight calcs at the ensemble level and passing it to the individual trees as a re-balanced sample_weight makes more sense for all these cases. Interested to know what others think about the implementation though.
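The first bullet point — that the trees see re-encoded labels, so a user dict keyed on the original labels would have to be remapped — can be illustrated like this (toy example, not the PR's code):

```python
import numpy as np

# The ensemble re-encodes labels to 0..K-1 before handing y to each estimator
y = np.array([-1, 1, -1, 1])
user_weight = {-1: 0.5, 1: 1.0}   # keyed on the original labels

classes, y_encoded = np.unique(y, return_inverse=True)
# classes == [-1, 1]; the trees only ever see y_encoded == [0, 1, 0, 1]

# The dict a tree would need, re-keyed on the internal labels
remapped = {i: user_weight[c] for i, c in enumerate(classes)}
# remapped == {0: 0.5, 1: 1.0}
```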

@glouppe
Member
glouppe commented Dec 27, 2014

So I think keeping the class_weight calcs at the ensemble level and
passing it to the individual trees as a re-balanced sample_weight
makes more sense for all these cases. Interested to know what others think
about the implementation though.

This is what I had in mind. +1 for implementing this this way.

@trevorstephens
Contributor

Great, thanks @glouppe

I'm working on validation of the class_weight parameter to raise more informative errors when appropriate. What are your thoughts on the warm_start option? I've noticed it's possible to pass a new X for additional estimators when using this option. BaseSGDClassifier raises a ValueError here for its partial_fit method. Should an error or warning be issued for warm_start when using the 'auto' or 'bootstrap' presets in the forests?

@glouppe
Member
glouppe commented Dec 28, 2014

Should an error or warning be issued for warm_start when using the 'auto' or 'bootstrap' presets in the forests?

I have no strong opinion on this, at the very least we should mention what happens in the docstring. Consistency with SGDClassifier is a good thing though.

@glouppe
Member
glouppe commented Jan 3, 2015

Is this still a work in progress? I can start reviewing the PR if you feel this is more or less ready.

@trevorstephens trevorstephens changed the title from [WIP] Add class_weight support to the forests & trees to [MRG] Add class_weight support to the forests & trees Jan 3, 2015
@trevorstephens
Contributor

That would be great @glouppe , thanks.

Added tests for the class_weight parameter that raise errors when warranted. This is mostly for multi-output and valid strings, as the compute_class_weight function does a number of checks that I am currently deferring to. I can add explicit tests for these cases if you think it is necessary.

I warn for the string presets combined with warm_start, as this appears to have slightly different usage from SGDClassifier, going by the relevant discussions in the original warm_start issues and PR.
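A minimal sketch of what such a warning check might look like (the function name and message here are assumptions for illustration, not the merged implementation):

```python
import warnings

def check_class_weight_warm_start(class_weight, warm_start):
    """Warn when a preset class_weight is combined with warm_start."""
    if warm_start and class_weight in ("auto", "bootstrap"):
        warnings.warn('class_weight preset "auto" or "bootstrap" with '
                      "warm_start: weights are computed from the data seen "
                      "on fit and are not re-estimated on subsequent fits.")
        return True
    return False
```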

@glouppe glouppe commented on the diff Jan 5, 2015
sklearn/utils/estimator_checks.py
@@ -737,6 +737,8 @@ def check_class_weight_classifiers(name, Classifier):
classifier = Classifier(class_weight=class_weight)
if hasattr(classifier, "n_iter"):
classifier.set_params(n_iter=100)
+ if hasattr(classifier, "min_weight_fraction_leaf"):
+ classifier.set_params(min_weight_fraction_leaf=0.01)
@glouppe
glouppe Jan 5, 2015 Member

Why this?

@trevorstephens
trevorstephens Jan 5, 2015 Contributor

(attached plot omitted)

It's an extremely noisy dataset that the trees seem to have a hard time fitting, as the test, I believe, is written for linear models. The change forces it to learn a very small model. I could also get it to pass by forcing a decision stump here with max_depth=1, for instance.

@amueller
amueller Jan 5, 2015 Member

maybe the test could also be reworked. The test passed for rbf svms, too, I think.
I wrote some of the class_weight tests, but I feel these are very hard to do. There are some examples with carefully crafted datasets, maybe these should be used instead?

@amueller
amueller Jan 5, 2015 Member

Actually, this test seems to be fine to me. As long as there is any regularization, it should work. I guess the bootstrapping alone doesn't help enough as there are very few trees in the ensemble by default.

@trevorstephens
trevorstephens Jan 6, 2015 Contributor

Yup, there are also only 2 features in this dataset, so the randomized feature selection doesn't have much of a chance to generalize. Happy to make any mods that are deemed necessary, though. I admit this is a bit of a hack to get a passing grade from Travis, though the module(s) do get a harder look through my new Iris tests in tree and forest.

@amueller
amueller Jan 6, 2015 Member

I think your fix is ok.

@trevorstephens
Contributor

Any other comments on the PR @amueller & @glouppe ?

@glouppe glouppe commented on an outdated diff Jan 8, 2015
sklearn/ensemble/forest.py
@@ -377,8 +406,9 @@ def _set_oob_score(self, X, y):
self.oob_score_ = oob_score / self.n_outputs_
- def _validate_y(self, y):
- y = np.copy(y)
+ def _validate_y_cw(self, y_org):
@glouppe
glouppe Jan 8, 2015 Member

Can you be explicit and call the variable y_original?

@glouppe
Member
glouppe commented Jan 8, 2015

Besides the small cosmit, I am +1 for merge. Tests are thorough. Could you also add an entry in the whatsnew file?

Thanks for your work @trevorstephens! This is a very helpful addition :)

(and sorry for the lack of responsiveness these days...)

@trevorstephens
Contributor

Thanks @glouppe ! Requested changes are made. Second reviewer anyone (while I figure out how to get around a lovely 11th hour merge conflict) ... ?

@trevorstephens trevorstephens changed the title from [MRG] Add class_weight support to the forests & trees to [MRG+1] Add class_weight support to the forests & trees Jan 8, 2015
@trevorstephens
Contributor

Travis seems to have failed for Py2.7 on an unrelated doctest for model_selection.rst, and failed to complete too. Can it be restarted?

@amueller
Member
amueller commented Jan 8, 2015

This is very unusual..... There really shouldn't be any travis failures any more :-/ ping @ogrisel.

@amueller amueller commented on an outdated diff Jan 8, 2015
sklearn/ensemble/forest.py
+ '"auto" weights, use compute_class_weight("auto", '
+ 'classes, y). In place of y you can use a large '
+ 'enough sample of the full training set target to '
+ 'properly estimate the class frequency '
+ 'distributions. Pass the resulting weights as the '
+ 'class_weight parameter.')
+ elif self.n_outputs_ > 1:
+ if not hasattr(self.class_weight, "__iter__"):
+ raise ValueError("For multi-output, class_weight should "
+ "be a list of dicts, or a valid string.")
+ elif len(self.class_weight) != self.n_outputs_:
+ raise ValueError("For multi-output, number of elements "
+ "in class_weight should match number of "
+ "outputs.")
+
+ if self.class_weight != 'bootstrap' or not self.bootstrap:
@amueller
amueller Jan 8, 2015 Member

could / should this be refactored for trees and forests?

@amueller
Member
amueller commented Jan 8, 2015

Looks good to me apart from maybe a refactoring of the class weight computation.

@trevorstephens
Contributor

@amueller is there something you are specifically concerned about? The explicit for loop is simply over the number of outputs, which will usually be a small number, most of the time just 1. Or are you referring to class_weight_k being different for different options?

@amueller
Member
amueller commented Jan 9, 2015

I was not concerned about performance but code duplication.

@trevorstephens
Contributor

Ah. I made a few comments about this a couple of weeks back: #3961 (comment)

TL;DR: It may be possible, but it would result in identical calculations for every tree, whereas this way we only do it once.

@amueller
Member
amueller commented Jan 9, 2015

never mind then, 👍 from me.

@trevorstephens trevorstephens changed the title from [MRG+1] Add class_weight support to the forests & trees to [MRG+2] Add class_weight support to the forests & trees Jan 9, 2015
@trevorstephens
Contributor

Great, thanks!

Any way to kick off Travis again?

@GaelVaroquaux
Member

I restarted the failed job.

@trevorstephens
Contributor

@GaelVaroquaux thanks!

@trevorstephens
Contributor

Thanks for the reviews @amueller and @glouppe , looks like Travis is happy now, we good to merge?

@GaelVaroquaux GaelVaroquaux commented on an outdated diff Jan 10, 2015
sklearn/ensemble/forest.py
@@ -211,11 +232,17 @@ def fit(self, X, y, sample_weight=None):
self.n_outputs_ = y.shape[1]
- y = self._validate_y(y)
+ y, cw = self._validate_y_cw(y)
@GaelVaroquaux
GaelVaroquaux Jan 10, 2015 Member

Could we have a more explicit name for what "cw" means (for both the variable name and the function name)?

@GaelVaroquaux GaelVaroquaux commented on an outdated diff Jan 10, 2015
sklearn/ensemble/forest.py
if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
y = np.ascontiguousarray(y, dtype=DOUBLE)
+ if cw is not None:
+ if sample_weight is not None:
+ sample_weight *= cw
@GaelVaroquaux
GaelVaroquaux Jan 10, 2015 Member

Don't we risk modifying an input argument here?

@trevorstephens
Contributor

@GaelVaroquaux , thanks for looking it over! Let me know what you think of the latest.

@trevorstephens
Contributor

@amueller @glouppe @GaelVaroquaux & others

I've been thinking about rolling class_weight out to the other ensembles in a future PR (so as not to undo the much-appreciated reviews so far) and have an API question. Forests use a bootstrap sample, but GradientBoostingClassifier uses sampling without replacement, while the BaggingClassifier allows both bootstrapping and sampling without replacement (AdaBoostClassifier doesn't implement any of these, so no issue there). Should the current class_weight='bootstrap' option for RandomForestClassifier be renamed to something less specific for consistency across ensemble classifiers? Maybe 'subsample', 'sample', 'estimator' or something else?

Sorry to bring this up after flipping from [WIP], but it seems important to decide on before merge given the other estimators in the module.

@GaelVaroquaux
Member

@trevorstephens : that's a very good comment. Maybe I would use 'subsample', which I find is less confusing than 'estimator'.

@GaelVaroquaux GaelVaroquaux commented on an outdated diff Jan 12, 2015
sklearn/ensemble/forest.py
if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
y = np.ascontiguousarray(y, dtype=DOUBLE)
+ if expanded_class_weight is not None:
+ if sample_weight is not None:
+ sample_weight = np.copy(sample_weight) * expanded_class_weight
@GaelVaroquaux
GaelVaroquaux Jan 12, 2015 Member

The 'np.copy' doesn't seem necessary above.

@GaelVaroquaux GaelVaroquaux and 1 other commented on an outdated diff Jan 12, 2015
sklearn/ensemble/forest.py
@@ -89,6 +87,29 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
sample_counts = np.bincount(indices, minlength=n_samples)
curr_sample_weight *= sample_counts
+ if class_weight == 'bootstrap':
+ expanded_class_weight = [curr_sample_weight]
+ for k in range(y.shape[1]):
+ y_full = y[:, k]
+ classes_full = np.unique(y_full)
+ y_boot = y_full[indices]
+ classes_boot = np.unique(y_boot)
+ # Get class weights for the bootstrap sample
+ weight_k = compute_class_weight('auto', classes_boot, y_boot)
+ # Expand class weights to cover all classes in original y
+ # (in case some were missing from the bootstrap sample)
+ weight_k = np.array([weight_k[np.where(classes_boot == c)][0]
+ if c in classes_boot
+ else 0.
+ for c in classes_full])
@GaelVaroquaux
GaelVaroquaux Jan 12, 2015 Member

A clever use of np.choose should enable removing this for loop and making the corresponding code much faster (hint: use mode='clip').

@trevorstephens
trevorstephens Jan 12, 2015 Contributor

@GaelVaroquaux are you referring to the for k in range(y.shape[1]): loop or the list comprehension? I can make the weight_k calc a lot prettier with:

# Get class weights for the bootstrap sample
weight_k = np.choose(classes_full,
                     compute_class_weight('auto', np.unique(y_boot), y_boot), 
                     mode='clip')

Using this in place of L98-104 certainly makes the code nicer. Is this close to what you were aiming for?

@GaelVaroquaux
GaelVaroquaux Jan 12, 2015 Member

Is this close to what you were aiming for?

Yes. It should also be significantly faster.

@trevorstephens
trevorstephens Jan 13, 2015 Contributor

I had to also use np.searchsorted here in order for it to be properly interpreted for all cases, but it looks a lot nicer than the list comprehension. Thanks for the tip!
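The combined np.searchsorted expansion described above could look roughly like this (a hypothetical re-creation, not the merged code; the helper `balanced_weights` stands in for compute_class_weight's 'auto' behavior):

```python
import numpy as np

def balanced_weights(classes, y):
    # 'auto'-style weights: n_samples / (n_classes * per-class count)
    counts = np.bincount(np.searchsorted(classes, y))
    return len(y) / (len(classes) * counts.astype(float))

classes_full = np.array([0, 1, 2])
y_boot = np.array([0, 0, 1, 1, 1, 1])   # class 2 missing from this bootstrap
classes_boot = np.unique(y_boot)

weight_boot = balanced_weights(classes_boot, y_boot)   # [1.5, 0.75]

# Map each full class onto a bootstrap-class slot (clipped, like mode='clip'),
# then zero out classes absent from the bootstrap sample
pos = np.searchsorted(classes_boot, classes_full).clip(max=len(classes_boot) - 1)
weight_k = weight_boot[pos]
weight_k[~np.isin(classes_full, classes_boot)] = 0.0
# weight_k -> [1.5, 0.75, 0.0]
```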

@trevorstephens
Contributor

Thanks for the comments @GaelVaroquaux . Renamed the class_weight='bootstrap' option to 'subsample' and implemented the other changes you suggested. Let me know what you think.

@coveralls

Coverage Status

Coverage increased (+0.02%) when pulling 35c2535 on trevorstephens:rf-class_weight into 57544ae on scikit-learn:master.

@GaelVaroquaux
Member

LGTM.

Two 👍 and my review. I think that this can be merged.

Merging. Thanks!

@GaelVaroquaux GaelVaroquaux merged commit 527ecf5 into scikit-learn:master Jan 13, 2015

1 check passed

continuous-integration/travis-ci The Travis CI build passed
Details
@trevorstephens trevorstephens deleted the trevorstephens:rf-class_weight branch Jan 13, 2015
@trevorstephens
Contributor

Thanks! Cheers!
