In [1]:
from sklearn.model_selection import train_test_split
import lightgbm as lgb
import shap

# print the JS visualization code to the notebook
shap.initjs()

In [2]:
X,y = shap.datasets.adult()
X_display,y_display = shap.datasets.adult(display=True)

# create a train/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
d_train = lgb.Dataset(X_train, label=y_train)
d_test = lgb.Dataset(X_test, label=y_test)

In [3]:
params = {
    "max_bin": 512,
    "learning_rate": 0.05,
    "boosting_type": "gbdt",
    "objective": "binary",
    "metric": "binary_logloss",
    "num_leaves": 10,
    "verbose": -1,
    "min_data": 100,
    "boost_from_average": True
}

model = lgb.train(params, d_train, 10000, valid_sets=[d_test], early_stopping_rounds=50, verbose_eval=1000)


Training until validation scores don't improve for 50 rounds
Early stopping, best iteration is:
[644]	valid_0's binary_logloss: 0.278029


In [4]:
%%time 
explainer = shap.TreeExplainer(model)



CPU times: user 232 ms, sys: 3.86 ms, total: 236 ms
Wall time: 237 ms


In [5]:
%%time
shap_values = explainer.shap_values(X)

CPU times: user 34 s, sys: 65.2 ms, total: 34.1 s
Wall time: 10.3 s


LightGBM binary classifier with TreeExplainer shap values output has changed to a list of ndarray


In [6]:
shap.force_plot(explainer.expected_value[1], shap_values[1][0,:], X_display.iloc[0,:])

In [7]:
shap.force_plot(explainer.expected_value[1], shap_values[1][:1000,:], X_display.iloc[:1000,:])

In [8]:
type(explainer.expected_value[1])

numpy.float64

In [9]:
type(shap_values[1][:1000,:])

numpy.ndarray

In [12]:
type(shap_values)

list

In [13]:
shap_values

[array([[-4.90548484e-01,  2.75392657e-01, -5.84700009e-01, ...,
          2.49908066e-02,  1.24473086e-01, -1.64659509e-02],
        [-1.00895059e+00,  3.29944730e-01, -7.85310857e-01, ...,
          7.53558627e-02,  1.13838141e+00, -1.50493404e-02],
        [-4.37856520e-01, -1.98391007e-02,  3.08878503e-01, ...,
          2.06305188e-02,  2.67436729e-02, -3.85132536e-03],
        ...,
        [-8.13230626e-01,  2.03190278e-02,  3.91046300e-01, ...,
          3.04790494e-02,  1.54367909e-02, -1.16830782e-02],
        [ 1.89759917e+00,  2.31645411e-02,  2.65437686e-01, ...,
          3.25766309e-02,  1.22408956e+00,  3.59659083e-04],
        [-1.03679176e+00, -2.16312813e-01,  4.57916496e-01, ...,
          2.36833985e-02, -1.14525043e-01, -1.11461339e-02]]),
 array([[ 4.90548484e-01, -2.75392657e-01,  5.84700009e-01, ...,
         -2.49908066e-02, -1.24473086e-01,  1.64659509e-02],
        [ 1.00895059e+00, -3.29944730e-01,  7.85310857e-01, ...,
         -7.53558627e-02, -1.13838141e

In [10]:
type(X_display.iloc[:1000,:])

pandas.core.frame.DataFrame

In [11]:
X_display.iloc[:1000,:].shape

(1000, 12)