
[MRG] Normalization fix for gbm ensemble feature importances #11176

Merged: 9 commits merged into scikit-learn:master on Jun 21, 2018

Conversation

5 participants
gforsyth (Contributor) commented May 31, 2018:

Reference Issues/PRs

What does this implement/fix? Explain your changes.

This makes a small change to the BaseDecisionTree and
DecisionTreeRegressor, adding a normalize_feature_importances kwarg
that defaults to True (matching current behavior) but is set to False
for GBMs.

Tree-level feature importances are calculated as the proportional reduction of
that tree's root node entropy (see _tree.pyx) and ensemble-level importances are
averaged across trees.

Because the feature importances are (currently, by default) normalized and then
averaged, feature importances from later stages are overweighted.

GBMs now defer the normalization step until after the per-tree importances have
been calculated (for the DecisionTreeRegressor used in
ensemble.gradient_boosting.BaseGradientBoosting).
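To make the overweighting concrete, here is a small illustrative sketch with made-up per-tree gains (hypothetical numbers, not taken from any real model):

import numpy as np

# Hypothetical raw impurity reductions for two features over two boosting
# stages. The early tree splits on feature 0 and removes far more impurity
# than the late tree, which splits on feature 1.
early_tree = np.array([1000.0, 0.0])
late_tree = np.array([0.0, 10.0])

# Current behavior: normalize each tree to sum to 1, then average.
# Both trees get equal weight regardless of how much impurity they removed,
# so the late stage is overweighted.
per_tree_normalized = np.mean(
    [t / t.sum() for t in (early_tree, late_tree)], axis=0)
print(per_tree_normalized)   # [0.5 0.5]

# Fixed behavior: average the raw importances, then normalize once.
total = np.mean([early_tree, late_tree], axis=0)
print(total / total.sum())   # [0.990... 0.009...]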

Any other comments?

h/t to Joel Danke for spotting this and writing up the issue.

gforsyth changed the title from Normalization fix for gbm ensemble feature importances to [MRG] Normalization fix for gbm ensemble feature importances on May 31, 2018

Gil Forsyth added a commit: Normalization fix for gbm ensemble feature importances

Signed-off-by: Gil Forsyth <Gilbert.Forsyth@capitalone.com>
jnothman (Member) reviewed:

Thanks for the fully fledged PR. You don't seem to test the effect of changing the new parameter as such.

Outdated diff on sklearn/ensemble/tests/test_gradient_boosting.py:
@@ -1234,6 +1235,7 @@ def feature_importances_(self):
         total_sum += stage_sum
         importances = total_sum / len(self.estimators_)
+        importances /= importances.sum()

jnothman (Member) commented Jun 1, 2018:

Please clarify how this relates to the new parameter.

gforsyth (Contributor) commented Jun 1, 2018:

The previous behavior normalized by the sum of importances, but did this on a per-tree basis before adding up all of the importances across trees. The new parameter stops the per-tree normalization, so it is added back in here.

The general idea is to switch from a normalize-then-sum workflow to a sum-then-normalize one, in order to avoid overweighting features that contribute less to the total Gini importance.

glemaitre (Contributor) commented Jun 18, 2018:

But when are we actually using normalize_feature_importances? It seems that it is never used.

gforsyth (Contributor) commented Jun 18, 2018:

It's never used in GBM, but it is set to True (by default) when used by random forests.

glemaitre (Contributor) commented Jun 18, 2018:

I got that, but I don't see a statement like:

if normalize_feature_importances:
    importance /= ...

gforsyth (Contributor) commented Jun 18, 2018:

Ahhh, gotcha. Ok, I'll add that in. Thanks!

gforsyth (Contributor) commented Jun 18, 2018:

Actually -- sorry, I think I misunderstood.

It's only used in the call that instantiates the DecisionTreeRegressor, so as to avoid adding a totally separate class.

glemaitre (Contributor) commented Jun 18, 2018:

Oh, sorry, I see. The computation of the importance is in the Cython file, right?

gforsyth (Contributor) commented Jun 18, 2018:

Right. For GBMs only, we skip the normalization part of the computation in the Cython file and then do the normalization here (which is reasonably cheap, since it's just a single NumPy operation no matter the number of trees).

gforsyth (Contributor) commented Jun 1, 2018:

Re: the effect of changing the parameter: it is going to change the feature importance orderings. For the example here: http://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_regression.html
both the relative importances and the order of the features will be different.

gforsyth (Contributor) commented Jun 1, 2018:

Thanks for looking it over!

jnothman (Member) commented Jun 2, 2018.

gforsyth (Contributor) commented Jun 2, 2018:

It's a bug that should be fixed.

gforsyth (Contributor) commented Jun 6, 2018:

Hey @jnothman -- just a gentle ping on this. Is there anything I can add to help get this merged in?

jnothman (Member) commented Jun 6, 2018:

Thanks for the ping.

If it's definitely a bug, is there any benefit to providing an option to switch it? On the other hand, how much user code do we break by changing it outright? (Perhaps it's not so bad if we're only breaking feature importances and not predictions?)

I'm basically unsure about the correct way to go on this, and would appreciate others' opinions and expertise on gradient boosting.

ahmadia commented Jun 8, 2018:

(This comment has been edited lightly for clarity)

Hey there! Let me try to provide a bit more context on this from the modeling angle, and on how it ties into the software changes in Gil's patch.

The problem arises because there are multiple modeling approaches that use decision tree regressors; however, scikit-learn's decision tree regressor implementation did not provide the unnormalized importance information back to the calling method (gradient boosting), so gradient boosting could not correctly report feature importance. Because gradient boosting was reporting feature importance normalized per tree, it was overweighting features in later stages of the boosting (which contribute less, overall, to the model's predictive power). Gil fixes this by normalizing across trees. The right way to think of this is that earlier trees in gradient boosting might have feature importance values like 1000/900, while later trees might have feature importance values like 10/9. If we normalize the feature importance values before summing them together, as scikit-learn was doing before this patch, we'll be misled into treating the trees' contributions as equal.

R's GBM implementation reports feature importance correctly.

The scikit-learn example for calculating feature importance on the Boston dataset (Gil's test checks for a regression against this) shows how calculating the tree-normalized feature importance in gradient boosting leads to incorrect feature importances: http://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_regression.html. In the plotted example, AGE is overweighted and NOX is underweighted in terms of their overall predictive power.

ahmadia commented Jun 8, 2018:

Sorry for the wall of text; here's an academic reference for the issue we're discussing:

In Friedman's "Greedy Function Approximation" (Annals of Statistics, 2001, https://projecteuclid.org/download/pdf_1/euclid.aos/1013203451), the relative importance of input variables is described in section 8.1. Equation 44 (from Breiman, Friedman, Olshen & Stone, 1983) shows that a feature's relative importance in a tree is the total improvement in squared error over all nodes splitting on that feature -- not normalized or proportional -- with equation 45 computing the feature's relative importance to the GBM by taking the average over all trees of the sum (again, not the average over proportions).

The PR is "canonical" and consistent with other implementations, e.g., R's gbm (and others).

(Hat tip again to Joel Danke at Capital One for originally spotting this and helping Gil and me with the context we're providing here.)
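For reference, a rough transcription of those two equations; the notation below is paraphrased from section 8.1 of the paper rather than copied verbatim:

% Approximate transcription of Friedman (2001), eqs. 44-45.
% Squared relative importance of variable x_j in a single tree T with J leaves:
\hat{I}_j^2(T) = \sum_{t=1}^{J-1} \hat{\imath}_t^2 \, \mathbb{1}\left( v(t) = j \right)
% where \hat{\imath}_t^2 is the improvement in squared error achieved by the
% split at internal node t, and v(t) is the variable split on at that node.
%
% For an additive expansion of M trees, average the (unnormalized) per-tree sums:
\hat{I}_j^2 = \frac{1}{M} \sum_{m=1}^{M} \hat{I}_j^2(T_m)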

gforsyth (Contributor) commented Jun 13, 2018:

Hey @jnothman -- thanks again for sticking with us on this. Do you have any suggestions on who among the maintainers has the requisite expertise on gradient boosting to look this over?

jnothman (Member) commented Jun 13, 2018:

@glemaitre, can you chime in on whether you think this should be treated as a bug fix or an option?

jnothman added the Bug label on Jun 18, 2018

jnothman added this to the 0.20 milestone on Jun 18, 2018

glemaitre (Contributor) commented Jun 18, 2018:

This looks like a bug, from what I can read in Friedman's paper.

glemaitre (Contributor) commented Jun 18, 2018:

I don't think that this fix is the best one (API-related). Introducing a parameter in the tree which is not related to the tree itself, or not useful to the user, is a drawback (we already have a lot of parameters to understand in the trees).

However, I think that we could call compute_feature_importances directly in the GradientBoosting.

glemaitre (Contributor) commented Jun 18, 2018:

So, modifying this function:

@property
def feature_importances_(self):
    """Return the feature importances (the higher, the more important the
    feature).

    Returns
    -------
    feature_importances_ : array, shape = [n_features]
    """
    self._check_initialized()
    total_sum = np.zeros((self.n_features_, ), dtype=np.float64)
    for stage in self.estimators_:
        stage_sum = sum(tree.feature_importances_
                        for tree in stage) / len(stage)
        total_sum += stage_sum
    importances = total_sum / len(self.estimators_)
    return importances

glemaitre (Contributor) commented Jun 18, 2018:

So for me the fix should be:

stage_sum = sum(tree.tree_.compute_feature_importances(normalize=False)
                for tree in stage) / len(stage)
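Putting that suggestion into the feature_importances_ property quoted above gives roughly the following sketch (an illustration of the proposed change as it would sit inside BaseGradientBoosting; the exact merged code may differ):

import numpy as np

@property
def feature_importances_(self):
    """Return the feature importances (the higher, the more important the
    feature).

    Returns
    -------
    feature_importances_ : array, shape = [n_features]
    """
    self._check_initialized()
    total_sum = np.zeros((self.n_features_, ), dtype=np.float64)
    for stage in self.estimators_:
        # Raw, unnormalized importances straight from the Cython tree.
        stage_sum = sum(tree.tree_.compute_feature_importances(normalize=False)
                        for tree in stage) / len(stage)
        total_sum += stage_sum
    # Normalize once, at the ensemble level.
    importances = total_sum / total_sum.sum()
    return importances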
gforsyth (Contributor) commented Jun 18, 2018:

Hey @glemaitre -- thanks for the review! I've reverted the API changes and added in the feature importance fix as you suggested. If it looks good to you, I'll make sure CI passes, then squash down the changes and fix the original commit message to match the actual fix implemented.

glemaitre (Contributor) commented Jun 18, 2018:

> then squash down the changes and fix the original commit message to match the actual fix implemented.

Do not squash. It will mess up the comments on GitHub. We are going to squash through the GitHub interface and take care of the history at merge time ;)

glemaitre (Contributor) reviewed:

Some nitpicks.

I am not sure about the regression test.
@jnothman, do you have any other way to ensure the feature importance order?

gforsyth (Contributor) commented Jun 18, 2018:

Hey @glemaitre -- I've made the changes you requested. The style stuff in the regression test (and the non-standard train-test split) was pulled from this example in the docs.

I'm happy to also update the example code, but that seems like it should be in a separate PR...?

glemaitre (Contributor) commented Jun 18, 2018:

> I'm happy to also update the example code, but that seems like it should be in a separate PR...?

Yes, in a separate PR that we can review in parallel. It will be merged much faster, since it is not controversial.

glemaitre (Contributor) commented Jun 18, 2018:

The test failure in Python 2 is, however, an issue. It might mean that the test is flaky, but it would be interesting to see what the difference is between the results in the different environments. It could be something fishy.

jnothman (Member) reviewed:

Otherwise LGTM.

(Outdated diff on doc/whats_new/v0.20.rst)
gforsyth (Contributor) commented Jun 19, 2018:

Hey @glemaitre -- I've made the changes you requested, and the regression test is passing on Travis now. There is a spurious failure on Travis that seems to come from a 404 for an external data source that one of the doc pages needs.

@jnothman -- thanks for the review; I've updated the whats_new entry with your suggestions.

glemaitre (Contributor) commented Jun 20, 2018:

So here is what I would consider as a test:

import numpy as np
from sklearn.datasets import fetch_california_housing
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

california = fetch_california_housing()
X, y = california.data, california.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Gradient boosting
reg = GradientBoostingRegressor(loss='huber', learning_rate=0.1,
                                max_leaf_nodes=6, n_estimators=800,
                                random_state=0)
reg.fit(X_train, y_train)
sorted_idx = np.argsort(reg.feature_importances_)[::-1]
print('Gradient boosting')
print([california.feature_names[s] for s in sorted_idx])
print(reg.feature_importances_[sorted_idx])

# Random forest
reg = RandomForestRegressor(n_estimators=100, n_jobs=-1, min_samples_leaf=5,
                            random_state=0)
reg.fit(X_train, y_train)
sorted_idx = np.argsort(reg.feature_importances_)[::-1]
print('Random forest')
print([california.feature_names[s] for s in sorted_idx])
print(reg.feature_importances_[sorted_idx])

Results on the master branch:

Gradient boosting
['Latitude', 'Longitude', 'MedInc', 'AveOccup', 'AveRooms', 'AveBedrms', 'HouseAge', 'Population']
[0.19931333 0.19822656 0.14355326 0.13547336 0.10937617 0.09139491
 0.06798544 0.05467696]
Random forest
['MedInc', 'AveOccup', 'Latitude', 'Longitude', 'HouseAge', 'AveRooms', 'Population', 'AveBedrms']
[0.56970374 0.13632459 0.08061255 0.08030935 0.05368477 0.03888769
 0.02031201 0.02016529]

[bar chart: gradient boosting feature importances, master branch]

Results with the current PR:

Gradient boosting
['MedInc', 'Longitude', 'AveOccup', 'Latitude', 'HouseAge', 'AveRooms', 'AveBedrms', 'Population']
[0.58500419 0.12820293 0.11039006 0.10886243 0.03395709 0.02155092
 0.00812921 0.00390317]
Random forest
['MedInc', 'AveOccup', 'Latitude', 'Longitude', 'HouseAge', 'AveRooms', 'Population', 'AveBedrms']
[0.56970374 0.13632459 0.08061255 0.08030935 0.05368477 0.03888769
 0.02031201 0.02016529]

[bar chart: gradient boosting feature importances, this PR]

Results for the random forest:

[bar chart: random forest feature importances]

Results reported in Hastie et al.:

[figure: relative variable importances, The Elements of Statistical Learning]

glemaitre (Contributor) commented Jun 20, 2018:

So it seems that this is a good fix, even if the relative importances are still not the same as in the book; the trend is the same, though.

So I would rework the test based on the snippet that I showed (we might not need 800 estimators). As test conditions, I would check that MedInc > Longitude > HouseAge > Population.

And I would add a comment at the beginning of the test referring to the experiment in Hastie et al., p. 373.
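A minimal sketch of such a test, assuming the ordering check suggested above (the estimator settings are borrowed from the snippet earlier in the thread, trimmed to 100 estimators; this is not necessarily the test as merged):

from sklearn.datasets import fetch_california_housing
from sklearn.ensemble import GradientBoostingRegressor


def test_gbm_feature_importance_order():
    # Based on the variable-importance experiment on the California housing
    # data in Hastie et al., "The Elements of Statistical Learning", p. 373.
    california = fetch_california_housing()
    X, y = california.data, california.target

    reg = GradientBoostingRegressor(loss='huber', learning_rate=0.1,
                                    max_leaf_nodes=6, n_estimators=100,
                                    random_state=0)
    reg.fit(X, y)

    imp = dict(zip(california.feature_names, reg.feature_importances_))
    # This ordering holds in the book and with the fix, but not on master:
    assert (imp['MedInc'] > imp['Longitude'] >
            imp['HouseAge'] > imp['Population'])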

glemaitre (Contributor) commented Jun 20, 2018:

@jnothman Do you agree with those?

jnothman (Member) commented Jun 20, 2018:

I don't know how brittle that ordering is, but it's certainly comforting that this PR brings us closer. Yes, it seems an appropriate test.

gforsyth (Contributor) commented Jun 20, 2018:

> So it seems that this is a good fix, even if the relative importances are still not the same as in the book; the trend is the same.

Hey @glemaitre -- I was working on the same example yesterday afternoon and got the same results. I'll experiment a little bit, but looking at the AAE plot from the original example in ESL, I think that 200 estimators should be more than enough to match the expected results.

Putting this together now; I will report back with results.

Gil Forsyth and others added some commits Jun 20, 2018

Gil Forsyth added a commit: Replace regression test with example from ESL

Replacing the regression test with a GBM example from The Elements of Statistical Learning, p. 373.
@@ -12,7 +12,7 @@
from sklearn import datasets

ahmadia commented Jun 21, 2018:

Is this line still necessary?

ogrisel (Member) commented Jun 21, 2018:

Yes, it's used elsewhere in the test file.

(Four outdated diff threads on sklearn/ensemble/tests/test_gradient_boosting.py)
ogrisel (Member) commented Jun 21, 2018:

Thanks for the non-regression test, @gforsyth. I found it a bit too slow. I pushed a change to relax it a bit, while checking that I could still reproduce the issue on master with only 100 trees (instead of 200).

ogrisel (Member) commented Jun 21, 2018:

BTW @gforsyth, a remark for later: you should not use your master branch to submit PRs, otherwise you cannot work on several PRs in parallel.

ogrisel (Member) reviewed:

LGTM if CI is still green with the last commit.

ogrisel (Member) commented Jun 21, 2018:

@ahmadia @gforsyth in the plots reported by @glemaitre there seems to be a difference in the relative feature importances w.r.t. the plot in ESLII. It would be great to try to understand the source of that remaining discrepancy (although I think only the order really matters in practice). If you have spare cycles, feel free to open another issue with the results of your investigations.

gforsyth (Contributor) commented Jun 21, 2018:

> I found it a bit too slow. I pushed a change to relax it a bit while checking that I could still reproduce the issue on master with only 100 trees (instead of 200).

Sounds good to me, @ogrisel -- thanks for the fix.

> you should not use your master branch to submit PRs, otherwise you cannot work on several PRs in parallel.

Can't I just open PRs from arbitrary branches?

ogrisel (Member) commented Jun 21, 2018:

> Can't I just open PRs from arbitrary branches?

You can, but it's confusing to have a branch named "master" that does not actually track the content of the upstream master branch.

jnothman (Member) reviewed:

Thanks and good work, everyone.

glemaitre (Contributor) commented Jun 21, 2018:

> in the plots reported by @glemaitre there seems to be a difference in the relative feature importances w.r.t. the plot in ESLII. It would be great to try to understand the source of that remaining discrepancy.

I would even assume that this is something shared by the GradientBoosting and RandomForest. I would check whether there is not something fishy with some normalization when computing the Gini importance while descending each tree.

gforsyth (Contributor) commented Jun 21, 2018:

CI passed! Thanks for the reviews and help, all!

jnothman merged commit 08f04d9 into scikit-learn:master on Jun 21, 2018

5 of 6 checks passed:

- continuous-integration/appveyor/pr -- Waiting for AppVeyor build to complete
- LGTM analysis: Python -- No alert changes
- ci/circleci: deploy -- Your tests passed on CircleCI!
- ci/circleci: python2 -- Your tests passed on CircleCI!
- ci/circleci: python3 -- Your tests passed on CircleCI!
- continuous-integration/travis-ci/pr -- The Travis CI build passed