# Quick feature selection through regression on Shapley values

Feature selection for tabular models is a hard problem, and most proposed solutions are computationally expensive. Here we show a heuristic method that is cheap to run, because Shapley values can be computed quickly for tree-based models such as XGBoost, LightGBM, or CatBoost.

For those who haven't come across them before, Shapley values are a way of decomposing a model's output into contributions from the individual feature values, with the convenient property that these contributions, together with a constant base value (the model's average output), are guaranteed to add up to the model output.
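
In symbols: for a model $f$ with $M$ features, the Shapley values $\phi_i(x)$ of a data point $x$ satisfy the following additivity ("local accuracy") property, where $\phi_0$ is the base value, i.e. the model's average output over a background dataset:

$$
f(x) = \phi_0 + \sum_{i=1}^{M} \phi_i(x), \qquad \phi_0 = \mathbb{E}[f(X)]
$$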

The process goes as follows: first, split your dataset into a training and a validation set, and train a tree-based model on the training set using all the available features, ideally with early stopping. If you already have a model fitted this way, you can use that instead.

In the second step, calculate the Shapley values of all the features for that model on the validation set. Now comes the fun part: for every data point in the validation set, the Shapley values (plus the base value) add up, by construction, to the model output for that data point.
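
For tree models this step uses TreeSHAP, e.g. `shap.TreeExplainer(model).shap_values(X_val)` or XGBoost's `Booster.predict(dval, pred_contribs=True)`. To show the additivity check without extra dependencies, the sketch below swaps in a hypothetical linear "model", for which the Shapley values have the closed form $\phi_i(x) = w_i (x_i - \mathbb{E}[x_i])$ when features are independent. The weights, data, and the `model` function are illustrative assumptions, not part of `shap_select`.

```python
import numpy as np

rng = np.random.default_rng(0)
n_train, n_val, n_features = 5000, 1000, 3

# Hypothetical fitted model: a linear function f(x) = b + w . x
w = np.array([3.0, -1.0, 0.5])
b = 2.0


def model(X):
    return b + X @ w


X_train = rng.normal(loc=3.0, size=(n_train, n_features))
X_val = rng.normal(loc=3.0, size=(n_val, n_features))

# Expectations are taken over the training data (the "background" set)
background_mean = X_train.mean(axis=0)
base_value = model(X_train).mean()  # phi_0 = E[f(X)]

# Closed-form Shapley values for a linear model with independent features:
# phi_i(x) = w_i * (x_i - E[x_i]); one row per validation point
shap_val = (X_val - background_mean) * w

# Additivity: base value + row sum of Shapley values = model output
reconstructed = base_value + shap_val.sum(axis=1)
print(np.max(np.abs(reconstructed - model(X_val))))
```

The matrix `shap_val` (one row per validation point, one column per feature) is exactly what the next step uses as its regressors.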

Now you are in linear country. As the next step, run a regression of the target on the Shapley values of the features, on the validation set. If the model were perfect (model output identical to the target), all the regression coefficients would equal 1.0. In practice that will not be the case, and the coefficients of irrelevant features end up either statistically insignificant (because, on average, those features' contributions don't bring the model output closer to the target on the validation set) or negative, indicating that their presence actually harms validation-set performance.
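
A minimal numpy sketch of the regression step. It assumes a hypothetical imperfect linear "model" so that its Shapley values have the closed form $w_i (x_i - \mathbb{E}[x_i])$, and uses a hypothetical helper `ols_with_tvalues`; p-values come from a normal approximation rather than the exact t distribution, which is reasonable at this sample size. The intercept absorbs the base value. You should see a coefficient near 1 for x1 (which the model gets right), near 2/3 for x2 (which the model over-weights by a factor of 1.5), and an insignificant coefficient near 0 for the irrelevant x3.

```python
import math

import numpy as np


def ols_with_tvalues(S, y):
    """Regress y on [1, S] by ordinary least squares.

    Returns coefficients, t-values and two-sided p-values for the columns
    of S (the intercept is fitted but dropped from the output).
    """
    X = np.column_stack([np.ones(len(y)), S])
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    resid = y - X @ coef
    dof = X.shape[0] - X.shape[1]
    sigma2 = resid @ resid / dof
    cov = sigma2 * np.linalg.inv(X.T @ X)
    t = coef / np.sqrt(np.diag(cov))
    # Normal approximation to the t distribution; fine for large samples
    p = np.array([math.erfc(abs(ti) / math.sqrt(2)) for ti in t])
    return coef[1:], t[1:], p[1:]


rng = np.random.default_rng(1)
n = 20000
X = rng.normal(size=(n, 3))
# True target depends on x1 and x2 only; x3 is irrelevant
y = 2.0 * X[:, 0] + 1.0 * X[:, 1] + rng.normal(size=n)

# Hypothetical imperfect linear model: right about x1, over-weights x2,
# and gives some weight to the irrelevant x3
w_model = np.array([2.0, 1.5, 0.3])
# Closed-form Shapley values of a linear model with independent features
shap_val = (X - X.mean(axis=0)) * w_model

coef, t, p = ols_with_tvalues(shap_val, y)
print(np.round(coef, 2), np.round(t, 1), np.round(p, 3))
```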

Our algorithm therefore recommends first discarding all features with negative coefficients, then ranking the rest by their statistical significance, and keeping those whose significance (p-value) falls below a chosen threshold (5% by default).
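
The selection rule itself is only a few lines. The sketch below uses a hypothetical helper, `selection_flags`, which takes a (t-value, significance, coefficient) triple per feature and returns the -1/0/1 flags used in the output tables, sorted by t-value as those tables are. The three example rows are copied from the regression example's output.

```python
from typing import Dict, List, Tuple


def selection_flags(
    stats: Dict[str, Tuple[float, float, float]], threshold: float = 0.05
) -> List[Tuple[str, int]]:
    """stats maps feature name -> (t_value, significance, coefficient)."""
    rows = []
    for name, (t_value, significance, coef) in stats.items():
        if coef < 0:
            flag = -1  # negative coefficient: the feature hurts validation fit
        elif significance < threshold:
            flag = 1  # positive and significant: keep
        else:
            flag = 0  # positive but not significant: drop
        rows.append((name, t_value, flag))
    # Rank by t-value, highest first
    rows.sort(key=lambda row: row[1], reverse=True)
    return [(name, flag) for name, _, flag in rows]


stats = {
    "x6": (2.390868, 0.016827, 1.497983),
    "x7": (0.901098, 0.367558, 2.865508),
    "x9": (-1.607814, 0.107908, -4.537098),
}
print(selection_flags(stats))  # [('x6', 1), ('x7', 0), ('x9', -1)]
```

Tightening the threshold to 0.01 would turn x6 into a 0, since its significance of about 0.017 no longer passes.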

Here's an example on synthetic data:

In [1]:
import os, sys
from typing import List

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

try:
    from shap_select import shap_select
except ModuleNotFoundError:
    # If you're running shap_select from source
    root = os.path.realpath("..")
    sys.path.append(root)
    from shap_select import shap_select



In [2]:
np.random.seed(42)
n_samples = 100000

# Create 9 normally distributed features
X = pd.DataFrame(
    {
        "x1": np.random.normal(size=n_samples),
        "x2": np.random.normal(size=n_samples),
        "x3": np.random.normal(size=n_samples),
        "x4": np.random.normal(size=n_samples),
        "x5": np.random.normal(size=n_samples),
        "x6": np.random.normal(size=n_samples),
        "x7": np.random.normal(size=n_samples),
        "x8": np.random.normal(size=n_samples),
        "x9": np.random.normal(size=n_samples),
    }
)

# Make all the features positive-ish
X += 3

# Define the target: y = 3*x1 + x2*x3 + x4*x5*x6 + noise
# (x7, x8 and x9 do not enter the target at all)
y = (
    3 * X["x1"]
    + X["x2"] * X["x3"]
    + X["x4"] * X["x5"] * X["x6"]
    + 10 * np.random.normal(size=n_samples)  # lots of noise
)
# After computing the target, degrade x6 so the model only sees a noisy proxy of it
X["x6"] *= 0.1
X["x6"] += np.random.normal(size=n_samples)

# Split the dataset into training (90K rows) and validation (10K rows) sets
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.1, random_state=42
)

Let's train, for example, an XGBoost model on the training set, using the validation set for early stopping:

In [3]:
import xgboost as xgb

dtrain = xgb.DMatrix(X_train, label=y_train)
dval = xgb.DMatrix(X_val, label=y_val)
params = {
    "objective": "reg:squarederror",
    "eval_metric": "rmse",
    "verbosity": 0,
}

model = xgb.train(
    params,
    dtrain,
    num_boost_round=1000,
    evals=[(dval, "valid")],
    early_stopping_rounds=50,
)

[0]	valid-rmse:17.78711
[1]	valid-rmse:16.44843
[2]	valid-rmse:15.64895
[3]	valid-rmse:15.19588
[4]	valid-rmse:14.92683
[5]	valid-rmse:14.75290
[6]	valid-rmse:14.65225
[7]	valid-rmse:14.56790
[8]	valid-rmse:14.50784
[9]	valid-rmse:14.46584
[10]	valid-rmse:14.43859
[11]	valid-rmse:14.42790
[12]	valid-rmse:14.41093
[13]	valid-rmse:14.39674
[14]	valid-rmse:14.38603
[15]	valid-rmse:14.38173
[16]	valid-rmse:14.37627
[17]	valid-rmse:14.37386
[18]	valid-rmse:14.36957
[19]	valid-rmse:14.36874
[20]	valid-rmse:14.36958
[21]	valid-rmse:14.37481
[22]	valid-rmse:14.37414
[23]	valid-rmse:14.37449
[24]	valid-rmse:14.37473
[25]	valid-rmse:14.37843
[26]	valid-rmse:14.38056
[27]	valid-rmse:14.38592
[28]	valid-rmse:14.39205
[29]	valid-rmse:14.39171
[30]	valid-rmse:14.38889
[31]	valid-rmse:14.39872
[32]	valid-rmse:14.40221
[33]	valid-rmse:14.40517
[34]	valid-rmse:14.41196
[35]	valid-rmse:14.41776
[36]	valid-rmse:14.41830
[37]	valid-rmse:14.42190
[38]	valid-rmse:14.42338
[39]	valid-rmse:14.42358
...

Now let's generate the feature significance scores. The final column, `selected`, shows whether we suggest selecting the feature: 1 means it is selected, -1 means it is rejected because of a negative regression coefficient, and 0 means it is rejected for not passing the significance threshold.

In [4]:
selected_features_df = shap_select(model, X_val, y_val, task="regression", threshold=0.05)

Condition number: 67.24977


In [5]:
# Let's color the output prettily
def prettify(df: pd.DataFrame, exclude: List[str]):
    """Apply a color gradient to every column except those listed in `exclude`."""
    numeric_cols = [c for c in df.columns if c not in exclude]
    return df.style.background_gradient(
        cmap="coolwarm", subset=pd.IndexSlice[:, numeric_cols]
    )

prettify(selected_features_df, exclude=["feature name"])

|   | feature name | t-value | stat.significance | coefficient | selected |
|---|---|---|---|---|---|
| 0 | x5 | 20.211298 | 0.0 | 1.05203 | 1 |
| 1 | x4 | 18.315144 | 0.0 | 0.952416 | 1 |
| 2 | x3 | 6.83569 | 0.0 | 1.098154 | 1 |
| 3 | x2 | 6.45714 | 0.0 | 1.044842 | 1 |
| 4 | x1 | 5.530556 | 0.0 | 0.917242 | 1 |
| 5 | x6 | 2.390868 | 0.016827 | 1.497983 | 1 |
| 6 | x7 | 0.901098 | 0.367558 | 2.865508 | 0 |
| 7 | x8 | 0.563214 | 0.573302 | 1.933632 | 0 |
| 8 | x9 | -1.607814 | 0.107908 | -4.537098 | -1 |
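
To act on the result, keep the features flagged 1 and retrain on those columns only (e.g. with `X_train[keep]`). A minimal pandas sketch; here the relevant columns of the table are re-typed so the snippet runs on its own, whereas in the notebook you would filter `selected_features_df` directly.

```python
import pandas as pd

# Two columns of the regression example's output, re-typed from the table
selected_features_df = pd.DataFrame(
    {
        "feature name": ["x5", "x4", "x3", "x2", "x1", "x6", "x7", "x8", "x9"],
        "selected": [1, 1, 1, 1, 1, 1, 0, 0, -1],
    }
)

# Keep only the features the method recommends (flag == 1)
keep = selected_features_df.loc[
    selected_features_df["selected"] == 1, "feature name"
].tolist()
print(keep)  # ['x5', 'x4', 'x3', 'x2', 'x1', 'x6']
```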


## What about classifier models?
You'll be happy to hear that the above approach works just fine on classifier models too. There is a slight difference under the hood, described below, but both the function call (apart from the `task` argument) and the interpretation of the output work the same way.

### Technical details for classifier models
The `shap` package automatically recognizes when it is given a classifier model, and in that case calculates the Shapley values for the log odds of a particular outcome rather than for its probability.

In the case of a binary classifier, this means that we now have to run a logistic rather than a linear regression, and then proceed exactly as before with interpreting the coefficients and significances.
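
A minimal numpy sketch of the binary case: a logistic regression of the 0/1 target on log-odds Shapley values, fitted by Newton-Raphson, with Wald z-values from the inverse Hessian. The classifier is a hypothetical linear-logit model (so its log-odds Shapley values are $w_i (x_i - \mathbb{E}[x_i])$) that wrongly leans on an irrelevant x2; the helper `logit_with_zvalues` is also an illustrative assumption. Expect a coefficient near 1 for x1 and an insignificant one near 0 for x2.

```python
import numpy as np


def logit_with_zvalues(S, y, n_iter=25):
    """Logistic regression of 0/1 y on [1, S] via Newton-Raphson.

    Returns coefficients and Wald z-values for the columns of S.
    """
    X = np.column_stack([np.ones(len(y)), S])
    beta = np.zeros(X.shape[1])
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-X @ beta))
        hessian = X.T @ (X * (p * (1 - p))[:, None])
        gradient = X.T @ (y - p)
        beta = beta + np.linalg.solve(hessian, gradient)
    # Standard errors from the inverse Hessian at the solution
    p = 1.0 / (1.0 + np.exp(-X @ beta))
    hessian = X.T @ (X * (p * (1 - p))[:, None])
    se = np.sqrt(np.diag(np.linalg.inv(hessian)))
    z = beta / se
    return beta[1:], z[1:]


rng = np.random.default_rng(2)
n = 20000
X = rng.normal(size=(n, 2))
true_logit = 1.5 * X[:, 0]  # only x1 matters
y = (rng.random(n) < 1.0 / (1.0 + np.exp(-true_logit))).astype(float)

# Hypothetical classifier whose log-odds also depend on the irrelevant x2
w_model = np.array([1.5, 0.4])
shap_logodds = (X - X.mean(axis=0)) * w_model

coef, z = logit_with_zvalues(shap_logodds, y)
print(np.round(coef, 2), np.round(z, 1))
```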

In the case of a multiclass classifier, we get Shapley values for each value (class) of the target; we run a binary logistic regression for each class, then for each feature take the largest t-value across these regressions and calculate the statistical significance from that. Finally, to correct for the multiple-testing effect, we apply the Bonferroni correction by multiplying the resulting significance by the number of classes; this way, you can compare that value directly to the original threshold. This is also why the significance values in the multiclass output can exceed 1.
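
The aggregation step can be sketched as below. The hypothetical helper `multiclass_significance` takes the per-class t-values, keeps the largest per feature, and converts it with a one-sided normal tail probability before the Bonferroni multiplication. The one-sided normal conversion is our assumption, not something stated above; it does, however, appear consistent with the multiclass output, since it maps x7's t-value of 0.773525 to about 0.6588 and x9's -0.206328 to about 1.7452. The other per-class t-values in the example are made up.

```python
import math

import numpy as np


def multiclass_significance(t_values: np.ndarray) -> np.ndarray:
    """t_values has shape (n_classes, n_features), one row per per-class regression."""
    n_classes = t_values.shape[0]
    t_max = t_values.max(axis=0)  # largest t-value per feature across classes
    # ASSUMPTION: one-sided normal tail probability P(Z > t_max)
    p_one_sided = np.array([0.5 * math.erfc(t / math.sqrt(2)) for t in t_max])
    # Bonferroni correction: multiply by the number of tests (classes)
    return p_one_sided * n_classes


# Hypothetical per-class t-values for two features across three classes
t_values = np.array(
    [
        [0.773525, -0.5],
        [0.2, -0.206328],
        [-1.0, -0.9],
    ]
)
print(np.round(multiclass_significance(t_values), 4))
```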

Below is an example of a multiclass classifier.


In [6]:
np.random.seed(42)
n_samples = 100000

# Create 9 normally distributed features
X = pd.DataFrame(
    {
        "x1": np.random.normal(size=n_samples),
        "x2": np.random.normal(size=n_samples),
        "x3": np.random.normal(size=n_samples),
        "x4": np.random.normal(size=n_samples),
        "x5": np.random.normal(size=n_samples),
        "x6": np.random.normal(size=n_samples),
        "x7": np.random.normal(size=n_samples),
        "x8": np.random.normal(size=n_samples),
        "x9": np.random.normal(size=n_samples),
    }
)

# Make all the features positive-ish
X += 3

# Create a multiclass target with 3 classes by cutting the noise-free formula into 3 equal-width bins
y = pd.cut(
    X["x1"] + X["x2"] * X["x3"] + X["x4"] * X["x5"] * X["x6"],
    bins=3,
    labels=[0, 1, 2],
).astype(int)

# Split the dataset into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.1, random_state=42
)

dtrain = xgb.DMatrix(X_train, label=y_train)
dval = xgb.DMatrix(X_val, label=y_val)

params = {
    "objective": "multi:softprob",
    "num_class": 3,
    "eval_metric": "mlogloss",
    "verbosity": 0,
}


evals = [(dval, "valid")]
model = xgb.train(
    params, dtrain, num_boost_round=1000, evals=evals, early_stopping_rounds=50
)



[0]	valid-mlogloss:0.78966
[1]	valid-mlogloss:0.60695
[2]	valid-mlogloss:0.48586
[3]	valid-mlogloss:0.40006
[4]	valid-mlogloss:0.33654
[5]	valid-mlogloss:0.28842
[6]	valid-mlogloss:0.25138
[7]	valid-mlogloss:0.22226
[8]	valid-mlogloss:0.19882
[9]	valid-mlogloss:0.17992
[10]	valid-mlogloss:0.16560
[11]	valid-mlogloss:0.15291
[12]	valid-mlogloss:0.14259
[13]	valid-mlogloss:0.13417
[14]	valid-mlogloss:0.12714
[15]	valid-mlogloss:0.12163
[16]	valid-mlogloss:0.11609
[17]	valid-mlogloss:0.11109
[18]	valid-mlogloss:0.10706
[19]	valid-mlogloss:0.10308
[20]	valid-mlogloss:0.09909
[21]	valid-mlogloss:0.09610
[22]	valid-mlogloss:0.09318
[23]	valid-mlogloss:0.09023
[24]	valid-mlogloss:0.08807
[25]	valid-mlogloss:0.08563
[26]	valid-mlogloss:0.08399
[27]	valid-mlogloss:0.08230
[28]	valid-mlogloss:0.08096
[29]	valid-mlogloss:0.07934
[30]	valid-mlogloss:0.07750
[31]	valid-mlogloss:0.07608
[32]	valid-mlogloss:0.07493
[33]	valid-mlogloss:0.07354
[34]	valid-mlogloss:0.07225
[35]	valid-mlogloss:0.07103
...

[286]	valid-mlogloss:0.03211
[287]	valid-mlogloss:0.03207
[288]	valid-mlogloss:0.03203
[289]	valid-mlogloss:0.03205
[290]	valid-mlogloss:0.03209
[291]	valid-mlogloss:0.03205
[292]	valid-mlogloss:0.03204
[293]	valid-mlogloss:0.03205
[294]	valid-mlogloss:0.03196
[295]	valid-mlogloss:0.03193
[296]	valid-mlogloss:0.03194
[297]	valid-mlogloss:0.03189
[298]	valid-mlogloss:0.03187
[299]	valid-mlogloss:0.03189
[300]	valid-mlogloss:0.03184
[301]	valid-mlogloss:0.03186
[302]	valid-mlogloss:0.03190
[303]	valid-mlogloss:0.03187
[304]	valid-mlogloss:0.03183
[305]	valid-mlogloss:0.03182
[306]	valid-mlogloss:0.03178
[307]	valid-mlogloss:0.03175
[308]	valid-mlogloss:0.03179
[309]	valid-mlogloss:0.03178
[310]	valid-mlogloss:0.03172
[311]	valid-mlogloss:0.03169
[312]	valid-mlogloss:0.03165
[313]	valid-mlogloss:0.03163
[314]	valid-mlogloss:0.03162
[315]	valid-mlogloss:0.03161
[316]	valid-mlogloss:0.03161
[317]	valid-mlogloss:0.03163
[318]	valid-mlogloss:0.03158
[319]	valid-mlogloss:0.03153
...

[569]	valid-mlogloss:0.03016
[570]	valid-mlogloss:0.03014
[571]	valid-mlogloss:0.03011
[572]	valid-mlogloss:0.03015
[573]	valid-mlogloss:0.03013
[574]	valid-mlogloss:0.03016
[575]	valid-mlogloss:0.03021
[576]	valid-mlogloss:0.03017
[577]	valid-mlogloss:0.03018
[578]	valid-mlogloss:0.03017
[579]	valid-mlogloss:0.03018
[580]	valid-mlogloss:0.03023
[581]	valid-mlogloss:0.03018
[582]	valid-mlogloss:0.03018
[583]	valid-mlogloss:0.03018
[584]	valid-mlogloss:0.03019
[585]	valid-mlogloss:0.03016
[586]	valid-mlogloss:0.03014
[587]	valid-mlogloss:0.03017
[588]	valid-mlogloss:0.03019
[589]	valid-mlogloss:0.03016
[590]	valid-mlogloss:0.03013
[591]	valid-mlogloss:0.03014
[592]	valid-mlogloss:0.03014
[593]	valid-mlogloss:0.03011
[594]	valid-mlogloss:0.03010
[595]	valid-mlogloss:0.03013
[596]	valid-mlogloss:0.03015
[597]	valid-mlogloss:0.03016
[598]	valid-mlogloss:0.03016
[599]	valid-mlogloss:0.03017
[600]	valid-mlogloss:0.03016
[601]	valid-mlogloss:0.03014
[602]	valid-mlogloss:0.03014
...

In [7]:
selected_features_df = shap_select(model, X_val, y_val, task="multiclass", threshold=0.05)

prettify(selected_features_df, exclude=["feature name"])

Optimization terminated successfully.
         Current function value: 0.028663
         Iterations 13
Optimization terminated successfully.
         Current function value: 0.066662
         Iterations 11
Optimization terminated successfully.
         Current function value: 0.001348
         Iterations 17
Condition number: 77.557655


|   | feature name | t-value | stat.significance | coefficient | selected |
|---|---|---|---|---|---|
| 0 | x4 | 25.927565 | 0.0 | 1.559384 | 1 |
| 1 | x5 | 25.874027 | 0.0 | 1.571661 | 1 |
| 2 | x6 | 25.782536 | 0.0 | 1.561214 | 1 |
| 3 | x2 | 21.367053 | 0.0 | 1.753463 | 1 |
| 4 | x3 | 21.330803 | 0.0 | 1.79263 | 1 |
| 5 | x1 | 12.835856 | 0.0 | 2.19731 | 1 |
| 6 | x7 | 0.773525 | 0.658817 | 1.901079 | 0 |
| 7 | x9 | -0.206328 | 1.745198 | -0.317295 | -1 |
| 8 | x8 | -0.636902 | 2.213717 | -1.25937 | -1 |
