In [4]:
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.datasets import load_diabetes
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import train_test_split
from sklearn.base import BaseEstimator, RegressorMixin
from cvxopt import matrix, solvers

In [5]:
import sys
sys.path.append("../src/tinyshap/")

from explainer import SHAPExplainer

In [6]:
dataset = load_diabetes()
X = pd.DataFrame(dataset["data"], columns=dataset["feature_names"])
y = dataset["target"]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

model = GradientBoostingRegressor()
model.fit(X_train, y_train)

X_train_summary = pd.DataFrame(KMeans(n_clusters=10, n_init="auto").fit(X_train).cluster_centers_, columns=X.columns)


In [7]:
x_test = X_test.iloc[0].values

In [18]:
model.predict(x_test.reshape(1, -1))



array([257.48348281])

In [9]:
explainer = SHAPExplainer(model.predict, X=X_train_summary)

In [10]:
shap_values = explainer._explain_sample(x_test)



In [11]:
shap_values

age                -0.741638
sex                -5.232452
bmi                77.990100
bp                 34.279944
s1                 -4.645594
s2                 10.445669
s3                 12.686111
s4                 -5.413604
s5                 -5.836733
s6                  4.386970
avg_prediction    139.564710
dtype: float64

In [12]:
shap_values.sum()

257.4834828115392

In [19]:
shap_values = explainer.shap_values(X_test)
shap_values.head()



Unnamed: 0,age,sex,bmi,bp,s1,s2,s3,s4,s5,s6,avg_prediction
362,-14.356042,-14.160712,76.601739,30.029894,3.643864,12.739738,-0.362683,7.247268,-4.966307,5.042116,156.024609
249,7.801589,-9.392816,21.806386,15.5882,-6.595684,-11.534782,15.996917,-1.138181,34.769792,-1.324831,159.35525
271,3.395596,-4.04276,31.701555,14.764363,8.814374,-2.45811,7.938095,1.642688,-14.942494,-5.312649,138.210414
435,-5.715132,15.493892,-8.059393,4.604459,6.524089,-8.04701,8.490407,5.359144,-28.770435,2.027765,138.180347
400,8.400164,9.140222,39.392627,38.220929,17.985413,-2.395535,-9.981999,-4.340782,-34.894754,9.121141,136.015214


In [23]:
np.allclose(shap_values.sum(axis=1).values, y_test, rtol=0.1)

False

In [25]:
y_test

array([321., 215., 127.,  64., 175., 275., 179., 232., 142.,  99., 252.,
       174., 129.,  74., 264.,  49.,  86.,  75., 101., 155., 170., 276.,
       110., 136.,  68., 128., 103.,  93., 191., 196., 217., 181., 168.,
       200., 219., 281., 151., 257.,  49., 198.,  96., 179.,  95., 198.,
       244.,  89., 214., 182.,  84., 270., 156., 138., 113., 131., 195.,
       171., 122.,  61., 230., 235.,  52., 121., 144., 107., 132., 302.,
        53., 317., 137.,  57.,  98., 170.,  88.,  90.,  67., 163., 104.,
       186., 180., 283., 141., 150.,  47., 297., 104.,  49., 103., 142.,
        59.])

In [24]:
shap_values.sum(axis=1).values

array([257.48348281, 225.33184141, 179.7110708 , 130.08813324,
       206.66263917, 245.42009698, 108.23077693, 211.7493365 ,
       113.89937094, 243.6955665 , 185.85944825, 165.03938356,
       118.64869397,  98.66999232, 284.78320515,  90.99789137,
       139.87519692,  72.07460863, 110.95165125, 229.85424766,
       194.34119765, 124.39092218, 176.45207515, 144.24191718,
       220.29055134, 194.1217516 , 132.67355455,  60.39661767,
       245.2668384 , 157.64369639, 198.85050859,  92.91202402,
       145.56633086, 155.55454263, 137.40616108, 166.10247174,
       161.96030631, 137.06516734,  84.94203291, 197.01436314,
       106.43775976, 164.09099565, 129.56177404, 187.73745607,
       167.72124052,  76.69148178, 110.44764268, 109.94160467,
        93.26044539, 274.63869156, 137.11408148,  62.53084065,
       160.13117285, 176.51831866, 242.84044873, 174.72853813,
       192.27590283, 122.47316401,  90.8436274 , 175.25427935,
       237.73911273, 144.85224532, 121.57362986,  97.54