# Exact explainer - inner working

In this notebook, we take a look at the inner working of the `ExactExplainer` to understand how it calculates the SHAP  values. 

In [1]:
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

from xgboost import XGBClassifier

import shap

## Load dataset

In [2]:
# Load iris dataset

X, y = load_iris(return_X_y=True, as_frame=True)

X.head()

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
0,5.1,3.5,1.4,0.2
1,4.9,3.0,1.4,0.2
2,4.7,3.2,1.3,0.2
3,4.6,3.1,1.5,0.2
4,5.0,3.6,1.4,0.2


We split the data leaving only **four** samples in the training set to better understand how `ExactExplainer` works.

In [3]:
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.97, random_state=0)

X_train.shape, X_test.shape

((4, 4), (146, 4))

## Train the model to explain

In [4]:
gbm = XGBClassifier(random_state=3)

gbm.fit(X_train, y_train)

In [5]:
gbm.score(X_test, y_test)

0.3287671232876712

## Under the hood of the exact explainer

In [6]:
# set up an explainer

exp = shap.Explainer(gbm.predict, X_train)

In [7]:
# this is the class that does the masking

from shap.utils._masked_model import MaskedModel

# the point to explain
sample = np.array([0.1, 0.2, 0.3, 0.4])

# the function that makes the explaining

fm = MaskedModel(exp.model, exp.masker, exp.link, exp.linearize_link, sample)

In [8]:
extended_delta_indexes = np.zeros(2**len(X_train), dtype=int)

extended_delta_indexes

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [9]:
# the samples generated for the coalitions

masked_inputs, varying_rows = fm.masker(extended_delta_indexes, sample)

masked_inputs

(    sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)
 0                 0.1               2.9                5.6               1.8
 1                 0.1               2.7                4.1               1.0
 2                 0.1               3.8                6.7               2.2
 3                 0.1               3.2                1.4               0.2
 4                 6.3               2.9                5.6               1.8
 ..                ...               ...                ...               ...
 59                0.1               3.2                1.4               0.2
 60                6.3               2.9                5.6               1.8
 61                5.8               2.7                4.1               1.0
 62                7.7               3.8                6.7               2.2
 63                4.6               3.2                1.4               0.2
 
 [64 rows x 4 columns],)

In [10]:
len(X_train) * 16 # n_samples * n_coalitions

64

In [11]:
# inds = fm.varying_inputs()
# inds

In [12]:
# delta_indexes = exp._cached_gray_codes(len(inds))
# delta_indexes

In [13]:
# extended_delta_indexes = np.zeros(2**len(inds), dtype=int)
# extended_delta_indexes

In [14]:
# MaskedModel.delta_mask_noop_value

In [15]:
# for i in range(2**len(inds)):
#     if delta_indexes[i] == MaskedModel.delta_mask_noop_value:
#         extended_delta_indexes[i] = delta_indexes[i]
#     else:
#         extended_delta_indexes[i] = inds[delta_indexes[i]]
        
# extended_delta_indexes

In [16]:
# outputs = fm(extended_delta_indexes, zero_index=0, batch_size=10)

# outputs

In [17]:
# df = pd.DataFrame([0.1, 0.2, 0.3, 0.4]).T
# df.columns = X_train.columns
# df

In [18]:
# gbm.predict(df)

In [19]:
# getattr(fm.masker, "supports_delta_masking", False)

In [20]:
# fm._delta_masking_call(extended_delta_indexes, zero_index=0, batch_size=10)

## The coalitions

In [21]:
# the empty set
# all values differ from sample to analyse

masked_inputs[0].loc[0:3]

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
0,0.1,2.9,5.6,1.8
1,0.1,2.7,4.1,1.0
2,0.1,3.8,6.7,2.2
3,0.1,3.2,1.4,0.2


In [22]:
# feature 4 switched on, all the rest off

masked_inputs[0].loc[4:7]

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
4,6.3,2.9,5.6,1.8
5,5.8,2.7,4.1,1.0
6,7.7,3.8,6.7,2.2
7,4.6,3.2,1.4,0.2


In [23]:
# features 3 and 4 on, 1 and 2 off

masked_inputs[0].loc[8:11]

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
8,0.1,2.9,5.6,1.8
9,0.1,2.7,4.1,1.0
10,0.1,3.8,6.7,2.2
11,0.1,3.2,1.4,0.2


In [24]:
masked_inputs[0].loc[12:15]

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
12,6.3,2.9,5.6,1.8
13,5.8,2.7,4.1,1.0
14,7.7,3.8,6.7,2.2
15,4.6,3.2,1.4,0.2


In [25]:
masked_inputs[0].loc[16:19]

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
16,0.1,2.9,5.6,1.8
17,0.1,2.7,4.1,1.0
18,0.1,3.8,6.7,2.2
19,0.1,3.2,1.4,0.2


In [26]:
masked_inputs[0].loc[20:23]

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
20,6.3,2.9,5.6,1.8
21,5.8,2.7,4.1,1.0
22,7.7,3.8,6.7,2.2
23,4.6,3.2,1.4,0.2


In [27]:
masked_inputs[0].loc[24:27]

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
24,0.1,2.9,5.6,1.8
25,0.1,2.7,4.1,1.0
26,0.1,3.8,6.7,2.2
27,0.1,3.2,1.4,0.2


In [28]:
masked_inputs[0].loc[28:31]

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
28,6.3,2.9,5.6,1.8
29,5.8,2.7,4.1,1.0
30,7.7,3.8,6.7,2.2
31,4.6,3.2,1.4,0.2


In [29]:
masked_inputs[0].loc[32:35]

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
32,0.1,2.9,5.6,1.8
33,0.1,2.7,4.1,1.0
34,0.1,3.8,6.7,2.2
35,0.1,3.2,1.4,0.2


In [30]:
masked_inputs[0].loc[36:39]

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
36,6.3,2.9,5.6,1.8
37,5.8,2.7,4.1,1.0
38,7.7,3.8,6.7,2.2
39,4.6,3.2,1.4,0.2


In [31]:
masked_inputs[0].loc[40:43]

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
40,0.1,2.9,5.6,1.8
41,0.1,2.7,4.1,1.0
42,0.1,3.8,6.7,2.2
43,0.1,3.2,1.4,0.2


In [32]:
masked_inputs[0].loc[44:47]

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
44,6.3,2.9,5.6,1.8
45,5.8,2.7,4.1,1.0
46,7.7,3.8,6.7,2.2
47,4.6,3.2,1.4,0.2


In [33]:
masked_inputs[0].loc[48:51]

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
48,0.1,2.9,5.6,1.8
49,0.1,2.7,4.1,1.0
50,0.1,3.8,6.7,2.2
51,0.1,3.2,1.4,0.2


In [34]:
masked_inputs[0].loc[52:55]

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
52,6.3,2.9,5.6,1.8
53,5.8,2.7,4.1,1.0
54,7.7,3.8,6.7,2.2
55,4.6,3.2,1.4,0.2


In [35]:
masked_inputs[0].loc[56:59]

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
56,0.1,2.9,5.6,1.8
57,0.1,2.7,4.1,1.0
58,0.1,3.8,6.7,2.2
59,0.1,3.2,1.4,0.2


In [36]:
masked_inputs[0].loc[60:63]

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
60,6.3,2.9,5.6,1.8
61,5.8,2.7,4.1,1.0
62,7.7,3.8,6.7,2.2
63,4.6,3.2,1.4,0.2


## The predictions

In [37]:
gbm.predict(masked_inputs[0].loc[0:3])

array([2, 2, 2, 2], dtype=int64)

In [38]:
gbm.predict(masked_inputs[0].loc[4:7])

array([2, 2, 2, 2], dtype=int64)

In [39]:
gbm.predict(masked_inputs[0].loc[8:11])

array([2, 2, 2, 2], dtype=int64)

In [40]:
gbm.predict(masked_inputs[0].loc[44:47])

array([2, 2, 2, 2], dtype=int64)

### All outputs for 16 coalitions

In [41]:
outputs = fm(extended_delta_indexes, zero_index=0, batch_size=10)

outputs

array([2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.])

## The masks

In [42]:
inds = fm.varying_inputs()
inds

array([0, 1, 2, 3], dtype=int64)

In [43]:
from shap.utils._general import shapley_coefficients

coeff = shapley_coefficients(len(X_train))
coeff

array([0.25      , 0.08333333, 0.08333333, 0.25      ])

In [44]:
row_values = np.zeros((len(fm),) + outputs.shape[1:])
row_values

array([0., 0., 0., 0.])

In [45]:
mask = np.zeros(len(fm), dtype=bool)
mask

array([False, False, False, False])

In [46]:
set_size = 0
M = len(X_train)
print(mask)

for i in range(2**M):
    print(i)

    # update the mask
    delta_ind = extended_delta_indexes[i]
    
    if delta_ind != MaskedModel.delta_mask_noop_value:
        mask[delta_ind] = ~mask[delta_ind]
        print(mask)
        if mask[delta_ind]:
            set_size += 1
        else:
            set_size -= 1
        print(set_size)

    # update the output row values
    on_coeff = coeff[set_size-1]
    
    if set_size < M:
        off_coeff = coeff[set_size]

    out = outputs[i]
    for j in inds:
        if mask[j]:
            row_values[j] += out * on_coeff
        else:
            row_values[j] -= out * off_coeff
    print()

[False False False False]
0
[ True False False False]
1

1
[False False False False]
0

2
[ True False False False]
1

3
[False False False False]
0

4
[ True False False False]
1

5
[False False False False]
0

6
[ True False False False]
1

7
[False False False False]
0

8
[ True False False False]
1

9
[False False False False]
0

10
[ True False False False]
1

11
[False False False False]
0

12
[ True False False False]
1

13
[False False False False]
0

14
[ True False False False]
1

15
[False False False False]
0

