# Example: WeightedSHAP on the fraud dataset

- In this notebook, we introduce WeightedSHAP, a generalized feature attribution method that includes SHAP as a special case, and apply it to the fraud dataset.

In [1]:
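import sys, os
import pickle
import numpy as np
import matplotlib.pyplot as plt  # used for the recovery error plot
np.random.seed(2022)
from analysis_utils import *  # helper functions, e.g., find_optimal_list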

sys.path.append('../weightedSHAP')
import data, train, weightedSHAPEngine

## Load data
- We use the fraud dataset (https://www.openml.org/search?type=data&status=active&id=42397).
- The function `data.load_data` loads the `train`, `val`, `est`, and `test` datasets:
  - `train`: to train the model to be explained
  - `val`: to optimize hyperparameters
  - `est`: to estimate coalition functions
  - `test`: to evaluate the quality of feature attributions
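
- A minimal sketch of such a four-way split, for illustration only: `split_four_ways` is a hypothetical helper, not the implementation of `data.load_data` (whose output below also reports adding noise, which grows the feature count from 30 to 90). It randomly permutes the rows and carves off validation, estimation, and test sets of the requested sizes, leaving the rest for training.

```python
import numpy as np


def split_four_ways(X, y, n_val, n_est, n_test, seed=2022):
    """Hypothetical helper: randomly split (X, y) into train/val/est/test parts."""
    n_train = len(X) - (n_val + n_est + n_test)
    if n_train <= 0:
        raise ValueError("not enough rows for the requested split sizes")
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    # Cut points: the first n_train shuffled rows go to train, then val, est, and test blocks.
    cuts = np.cumsum([n_train, n_val, n_est])
    parts = np.split(idx, cuts)
    return [(X[p], y[p]) for p in parts]


# Toy data: 100 rows, 3 features, binary label.
X = np.random.default_rng(0).normal(size=(100, 3))
y = (X[:, 0] > 0).astype(int)
(X_tr, y_tr), (X_va, y_va), (X_es, y_es), (X_te, y_te) = split_four_ways(X, y, 10, 10, 10)
print(X_tr.shape, X_va.shape, X_es.shape, X_te.shape)  # (70, 3) (10, 3) (10, 3) (10, 3)
```

Keeping the estimation split separate from training means coalition functions are estimated on data the explained model has not been fit to.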

In [2]:
# Load dataset
problem='classification' 
dataset='fraud'
ML_model='boosting' 
(X_train, y_train), (X_val, y_val), (X_est, y_est), (X_test, y_test)=data.load_data(problem, dataset)    

------------------------------
Load a dataset
------------------------------
--------------------------------------------------
Fraud Detection
--------------------------------------------------
------------------------------
Before adding noise
Shape of X_train, X_val, X_est, X_test: (13310, 30), (1902, 30), (1902, 30), (1902, 30)
------------------------------
Rho: 0.0388
After adding noise
Shape of X_train, X_val, X_est, X_test: (13310, 90), (1902, 90), (1902, 90), (1902, 90)
------------------------------


## Train a model to explain
- This step is a standard machine learning routine: given the training and validation datasets, we train a model. Our goal is to interpret this model by examining a particular prediction (i.e., a local attribution problem).

In [3]:
model_to_explain=train.create_model_to_explain(X_train, y_train, X_val, y_val, problem, ML_model)

------------------------------
Train a model
Train a model to explain: Boosting
Training until validation scores don't improve for 25 rounds
Early stopping, best iteration is:
[12]	valid_0's binary_logloss: 0.646885	valid_0's binary_error: 0.0320715
Elapsed time for training a model to explain: 0.54 seconds
------------------------------
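
- The log above reports patience-based early stopping: training continues until the validation score has not improved for 25 rounds, and the model from the best iteration (here, 12) is kept. A generic sketch of that rule follows; `early_stopping` is a hypothetical helper, not the API of the boosting library used inside `train.create_model_to_explain`.

```python
def early_stopping(val_losses, patience=25):
    """Hypothetical helper: return (best_iteration, stop_iteration), both 1-indexed.

    Training stops once the validation loss has not improved for `patience`
    consecutive rounds; the model from the best round is the one kept.
    """
    best_loss, best_iter = float("inf"), 0
    for it, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_iter = loss, it
        elif it - best_iter >= patience:
            return best_iter, it
    # Ran out of rounds without triggering the patience rule.
    return best_iter, len(val_losses)


# Toy loss trace: improves for 12 rounds, then only gets worse.
losses = [1.0 - 0.01 * i for i in range(12)] + [0.9 + 0.001 * i for i in range(40)]
print(early_stopping(losses))  # (12, 37)
```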


## Compute attributions and evaluate their performance
- `weightedSHAPEngine.run_attribution_core` computes weighted versions of SHAP, $\phi_\mathbf{w} (x_i) := \sum_{j=1}^{d} w_j \Delta_{j}(x_i)$ (Equation (5) of the paper), where $\Delta_j(x_i)$ is the average marginal contribution of feature $i$ to coalitions of size $j-1$, for every weight vector $\mathbf{w}$ in a set $\mathcal{W}$ that also includes SHAP. To evaluate each attribution, it computes conditional expectations $\mathbb{E}[f(X) \mid X_S = x_S]$, where $S$ is the set of features selected according to that attribution.
- The results are stored in `exp_dict`, which includes
  - `cond_pred_keep_absolute`: features with large absolute attribution values are added first,
  - `cond_pred_remove_absolute`: features with large absolute attribution values are removed first,
  - `pred_masking`: features with large absolute attribution values are added first, and the other features are masked with zero.
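
- To make Equation (5) concrete, the sketch below computes $\Delta_j(x_i)$, the average marginal contribution of feature $i$ to coalitions of size $j-1$, exactly for a toy model with $d = 3$ features by enumerating all coalitions, then forms $\phi_\mathbf{w}$ for two weight vectors. Uniform weights $w_j = 1/d$ recover the Shapley value (SHAP). Assumptions: the toy model `f`, the point `x`, and the coalition function `v` are hypothetical, and `v` simply sets absent features to zero instead of using the conditional expectation estimated with the surrogate model. Exhaustive enumeration is only feasible for small $d$.

```python
from itertools import combinations
from math import comb

import numpy as np


def marginal_contributions(v, d):
    """delta[i, j - 1] = Delta_j(x_i): the average of v(S | {i}) - v(S)
    over all coalitions S of the other features with |S| = j - 1."""
    delta = np.zeros((d, d))
    for i in range(d):
        others = [k for k in range(d) if k != i]
        for size in range(d):  # size = j - 1 ranges over 0, ..., d - 1
            total = sum(v(frozenset(S) | {i}) - v(frozenset(S))
                        for S in combinations(others, size))
            delta[i, size] = total / comb(d - 1, size)
    return delta


def weighted_attribution(delta, w):
    """phi_w(x_i) = sum_j w_j * Delta_j(x_i) for every feature i."""
    return delta @ np.asarray(w, dtype=float)


# Hypothetical toy model with an interaction between features 0 and 2.
x = np.array([1.0, 1.0, 1.0])


def f(z):
    return z[0] + 2 * z[1] + z[0] * z[2]


def v(S):
    # Assumed coalition function: features outside S are set to zero.
    return f(np.array([x[k] if k in S else 0.0 for k in range(len(x))]))


d = len(x)
delta = marginal_contributions(v, d)
phi_shap = weighted_attribution(delta, np.full(d, 1 / d))  # uniform weights: SHAP
phi_first = weighted_attribution(delta, [1.0, 0.0, 0.0])   # all weight on j = 1
print(phi_shap)   # approximately [1.5, 2.0, 0.5]
print(phi_first)  # [1. 2. 0.]
```

The SHAP values sum to $v([d]) - v(\emptyset) = 4$, splitting the interaction term equally between features 0 and 2, whereas putting all weight on $j = 1$ gives feature 2 no credit at all. Different weights can therefore change which features are added first in the curves above.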

In [None]:
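# The computation below took over an hour in the run logged below, so its results are cached in a pickle file.
if not os.path.exists('fraud_example.pickle'):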
    # Train a surrogate model and generate a coalition function
    # To efficiently estimate a conditional coalition function, we train a surrogate model. 
    conditional_extension=train.generate_coalition_function(model_to_explain, X_train, X_est, problem, ML_model)
    
    # With a surrogate model, we compute conditional expectations
    exp_dict=weightedSHAPEngine.run_attribution_core(problem, ML_model,
                                                     model_to_explain, conditional_extension,
                                                     X_train, y_train,
                                                     X_val, y_val, 
                                                     X_test, y_test)

    with open('fraud_example.pickle', 'wb') as handle:
        pickle.dump(exp_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
else:
    with open('fraud_example.pickle', 'rb') as handle:
        exp_dict = pickle.load(handle)


Elapsed time for training a surrogate model: 548.97 seconds


  0%|                                                                                                                | 0/100 [00:00<?, ?it/s]

Total number of random sets: 900, GR_stat: 1.007274776668469
Total number of random sets: 1000, GR_stat: 1.0067500997071723
Total number of random sets: 1100, GR_stat: 1.0066346271310442
Total number of random sets: 1200, GR_stat: 1.0051719850889846
Total number of random sets: 1300, GR_stat: 1.0059864735184305
Total number of random sets: 1400, GR_stat: 1.0055448719292726
Total number of random sets: 1500, GR_stat: 1.004794489707504
MCI score: 8994, Therehosld: 8991
We have seen 1600 random subsets for each feature.


  1%|█                                                                                                     | 1/100 [00:46<1:17:24, 46.92s/it]

Total number of random sets: 900, GR_stat: 1.007101462204402
Total number of random sets: 1100, GR_stat: 1.0064411589724716
Total number of random sets: 1300, GR_stat: 1.006964299222497
Total number of random sets: 1400, GR_stat: 1.005469781048714
Total number of random sets: 1500, GR_stat: 1.0050134224244465
Total number of random sets: 1600, GR_stat: 1.007086689745792
Total number of random sets: 1700, GR_stat: 1.0041482290171406
MCI score: 8996, Therehosld: 8991
We have seen 1800 random subsets for each feature.


  2%|██                                                                                                    | 2/100 [01:39<1:22:03, 50.24s/it]

Total number of random sets: 800, GR_stat: 1.0116016608201182
Total number of random sets: 1000, GR_stat: 1.0084466638234113
Total number of random sets: 1300, GR_stat: 1.0054095659241913
Total number of random sets: 1400, GR_stat: 1.0049605126227612
MCI score: 8997, Therehosld: 8991
We have seen 1500 random subsets for each feature.


  3%|███                                                                                                   | 3/100 [02:22<1:16:04, 47.06s/it]

Total number of random sets: 1200, GR_stat: 1.0074184719588355
Total number of random sets: 1300, GR_stat: 1.0048106818611944
MCI score: 8994, Therehosld: 8991
We have seen 1400 random subsets for each feature.


  4%|████                                                                                                  | 4/100 [03:03<1:11:00, 44.38s/it]

Total number of random sets: 1000, GR_stat: 1.0078974448414677
Total number of random sets: 1100, GR_stat: 1.0047230983769468
MCI score: 8995, Therehosld: 8991
We have seen 1200 random subsets for each feature.


  5%|█████                                                                                                 | 5/100 [03:37<1:04:27, 40.71s/it]

Total number of random sets: 800, GR_stat: 1.00803748990048
Total number of random sets: 1400, GR_stat: 1.0047963131950424
MCI score: 8994, Therehosld: 8991
We have seen 1500 random subsets for each feature.


  6%|██████                                                                                                | 6/100 [04:20<1:05:10, 41.60s/it]

Total number of random sets: 1100, GR_stat: 1.0048207882301539
MCI score: 8992, Therehosld: 8991
We have seen 1200 random subsets for each feature.


  7%|███████▏                                                                                              | 7/100 [04:54<1:00:38, 39.13s/it]

Total number of random sets: 900, GR_stat: 1.0085050163717155
Total number of random sets: 1100, GR_stat: 1.0059396772246032
Total number of random sets: 1200, GR_stat: 1.0052468816311353
Total number of random sets: 1300, GR_stat: 1.0045262657958098
MCI score: 8996, Therehosld: 8991
We have seen 1400 random subsets for each feature.


  8%|████████▏                                                                                             | 8/100 [05:34<1:00:29, 39.45s/it]

Total number of random sets: 1000, GR_stat: 1.0049833615234542
MCI score: 8996, Therehosld: 8991
We have seen 1100 random subsets for each feature.


  9%|█████████▎                                                                                              | 9/100 [06:05<55:55, 36.87s/it]

Total number of random sets: 1000, GR_stat: 1.0114249651285447
Total number of random sets: 1100, GR_stat: 1.0083409315243033
Total number of random sets: 1200, GR_stat: 1.0097357892547856
Total number of random sets: 1300, GR_stat: 1.0113938536418166
Total number of random sets: 1400, GR_stat: 1.006596022380817
Total number of random sets: 1500, GR_stat: 1.0053342681570259
Total number of random sets: 1600, GR_stat: 1.0071756715193805
Total number of random sets: 1700, GR_stat: 1.0057508285644738
Total number of random sets: 1800, GR_stat: 1.007268489549289
Total number of random sets: 1900, GR_stat: 1.00816040509585
Total number of random sets: 2000, GR_stat: 1.0081773674203363
Total number of random sets: 2100, GR_stat: 1.0073145730361863
Total number of random sets: 2200, GR_stat: 1.0058554864598057
Total number of random sets: 2300, GR_stat: 1.0056892125418884
Total number of random sets: 2400, GR_stat: 1.0079856988839428
Total number of random sets: 2500, GR_stat: 1.0046990101346

 10%|██████████                                                                                           | 10/100 [07:22<1:13:49, 49.21s/it]

Total number of random sets: 800, GR_stat: 1.0089416082261773
Total number of random sets: 900, GR_stat: 1.0087905032361324
Total number of random sets: 1000, GR_stat: 1.0088305513367513
Total number of random sets: 1100, GR_stat: 1.0060630339840113
Total number of random sets: 1200, GR_stat: 1.0054017502068644
Total number of random sets: 1300, GR_stat: 1.0060885363871754
Total number of random sets: 1500, GR_stat: 1.0046194897592617
MCI score: 8993, Therehosld: 8991
We have seen 1600 random subsets for each feature.


 11%|███████████                                                                                          | 11/100 [08:09<1:11:42, 48.34s/it]

Total number of random sets: 700, GR_stat: 1.0077533436389683
Total number of random sets: 1100, GR_stat: 1.0102683814519522
Total number of random sets: 1200, GR_stat: 1.0066323879801786
Total number of random sets: 1300, GR_stat: 1.007311530286495
Total number of random sets: 1400, GR_stat: 1.0066018242839772
Total number of random sets: 1500, GR_stat: 1.0043927120397595
MCI score: 8992, Therehosld: 8991
We have seen 1600 random subsets for each feature.


 12%|████████████                                                                                         | 12/100 [08:55<1:10:03, 47.76s/it]

Total number of random sets: 1300, GR_stat: 1.00554749511505
Total number of random sets: 1500, GR_stat: 1.0061762138723847
Total number of random sets: 1600, GR_stat: 1.0043106287796406
MCI score: 8997, Therehosld: 8991
We have seen 1700 random subsets for each feature.


 13%|█████████████▏                                                                                       | 13/100 [09:44<1:09:56, 48.24s/it]

Total number of random sets: 700, GR_stat: 1.0100307989324342
Total number of random sets: 1100, GR_stat: 1.0078306419865517
Total number of random sets: 1200, GR_stat: 1.006167067828919
Total number of random sets: 1300, GR_stat: 1.0068750390478802
Total number of random sets: 1400, GR_stat: 1.005611606491501
Total number of random sets: 1500, GR_stat: 1.0038825693075275
MCI score: 8992, Therehosld: 8991
We have seen 1600 random subsets for each feature.


 14%|██████████████▏                                                                                      | 14/100 [10:31<1:08:21, 47.69s/it]

Total number of random sets: 800, GR_stat: 1.0123107947827377
Total number of random sets: 1100, GR_stat: 1.0078513717435436
Total number of random sets: 1300, GR_stat: 1.0066345386689064
Total number of random sets: 1400, GR_stat: 1.008144488388134
Total number of random sets: 1500, GR_stat: 1.0063177837627029
Total number of random sets: 1600, GR_stat: 1.0084877288255083
Total number of random sets: 1700, GR_stat: 1.0066976573876958
Total number of random sets: 1800, GR_stat: 1.003335935696763
MCI score: 8996, Therehosld: 8991
We have seen 1900 random subsets for each feature.


 15%|███████████████▏                                                                                     | 15/100 [11:26<1:10:56, 50.07s/it]

Total number of random sets: 1300, GR_stat: 1.0043011174758965
MCI score: 8994, Therehosld: 8991
We have seen 1400 random subsets for each feature.


 16%|████████████████▏                                                                                    | 16/100 [12:07<1:06:02, 47.17s/it]

Total number of random sets: 800, GR_stat: 1.0088814992608206
Total number of random sets: 900, GR_stat: 1.0073084019020526
Total number of random sets: 1100, GR_stat: 1.0044221967501579
MCI score: 8993, Therehosld: 8991
We have seen 1200 random subsets for each feature.


 17%|█████████████████▌                                                                                     | 17/100 [12:41<59:52, 43.28s/it]

Total number of random sets: 800, GR_stat: 1.0062361228450658
Total number of random sets: 900, GR_stat: 1.0083900385561306
Total number of random sets: 1000, GR_stat: 1.0089888084853407
Total number of random sets: 1100, GR_stat: 1.0074111343348238
Total number of random sets: 1200, GR_stat: 1.006269787077985
Total number of random sets: 1300, GR_stat: 1.0069165929138328
Total number of random sets: 1600, GR_stat: 1.0037205579364952
MCI score: 8995, Therehosld: 8991
We have seen 1700 random subsets for each feature.


 18%|██████████████████▏                                                                                  | 18/100 [13:30<1:01:38, 45.10s/it]

Total number of random sets: 1100, GR_stat: 1.005722286866246
Total number of random sets: 1200, GR_stat: 1.0053898359886089
Total number of random sets: 1300, GR_stat: 1.0047727429787403
MCI score: 8996, Therehosld: 8991
We have seen 1400 random subsets for each feature.


 19%|███████████████████▌                                                                                   | 19/100 [14:11<58:58, 43.68s/it]

Total number of random sets: 1100, GR_stat: 1.0058654774879898
Total number of random sets: 1200, GR_stat: 1.0084477111199148
Total number of random sets: 1300, GR_stat: 1.003966253552515
MCI score: 8995, Therehosld: 8991
We have seen 1400 random subsets for each feature.


 20%|████████████████████▌                                                                                  | 20/100 [14:51<56:52, 42.65s/it]

Total number of random sets: 900, GR_stat: 1.0089472226669236
Total number of random sets: 1100, GR_stat: 1.00561813763626
Total number of random sets: 1200, GR_stat: 1.0059517127770687
Total number of random sets: 1300, GR_stat: 1.0068125429660706
Total number of random sets: 1500, GR_stat: 1.0074275173984415
Total number of random sets: 1600, GR_stat: 1.0045087587742487
MCI score: 8995, Therehosld: 8991
We have seen 1700 random subsets for each feature.


 21%|█████████████████████▋                                                                                 | 21/100 [15:40<58:47, 44.65s/it]

Total number of random sets: 900, GR_stat: 1.0075693790466647
Total number of random sets: 1000, GR_stat: 1.0069168787142042
Total number of random sets: 1100, GR_stat: 1.0082297666456115
Total number of random sets: 1200, GR_stat: 1.0057384748994613
Total number of random sets: 1300, GR_stat: 1.0053133340975076
Total number of random sets: 1500, GR_stat: 1.0048531394385691
MCI score: 8994, Therehosld: 8991
We have seen 1600 random subsets for each feature.


 22%|██████████████████████▋                                                                                | 22/100 [16:27<58:43, 45.18s/it]

Total number of random sets: 1100, GR_stat: 1.0052083533227063
Total number of random sets: 1200, GR_stat: 1.0057265505847763
Total number of random sets: 1400, GR_stat: 1.0046734385466016
MCI score: 8994, Therehosld: 8991
We have seen 1500 random subsets for each feature.


 23%|███████████████████████▋                                                                               | 23/100 [17:10<57:14, 44.61s/it]

Total number of random sets: 1000, GR_stat: 1.009601660423236
Total number of random sets: 1200, GR_stat: 1.0065318211282788
Total number of random sets: 1300, GR_stat: 1.0076766851554309
Total number of random sets: 1400, GR_stat: 1.005283747032439
Total number of random sets: 1500, GR_stat: 1.0051048593222507
Total number of random sets: 1600, GR_stat: 1.0052170922771897
Total number of random sets: 1700, GR_stat: 1.0055036067193428
Total number of random sets: 1800, GR_stat: 1.0041734170961198
MCI score: 8998, Therehosld: 8991
We have seen 1900 random subsets for each feature.


 24%|████████████████████████▏                                                                            | 24/100 [18:06<1:00:40, 47.90s/it]

Total number of random sets: 1100, GR_stat: 1.0075444546265124
Total number of random sets: 1300, GR_stat: 1.0061850911746568
Total number of random sets: 1500, GR_stat: 1.0072602117776293
Total number of random sets: 1700, GR_stat: 1.0052035439342615
Total number of random sets: 1800, GR_stat: 1.0045037523177187
MCI score: 8994, Therehosld: 8991
We have seen 1900 random subsets for each feature.


 25%|█████████████████████████▎                                                                           | 25/100 [19:01<1:02:43, 50.18s/it]

Total number of random sets: 1200, GR_stat: 1.004601183953125
MCI score: 8995, Therehosld: 8991
We have seen 1300 random subsets for each feature.


 26%|██████████████████████████▊                                                                            | 26/100 [19:39<57:09, 46.34s/it]

Total number of random sets: 900, GR_stat: 1.0075752224795687
Total number of random sets: 1000, GR_stat: 1.0085860968952842
Total number of random sets: 1100, GR_stat: 1.005364231975834
Total number of random sets: 1400, GR_stat: 1.0067690051336717
Total number of random sets: 1500, GR_stat: 1.0069489324753278
Total number of random sets: 1600, GR_stat: 1.0045403287313615
MCI score: 8992, Therehosld: 8991
We have seen 1700 random subsets for each feature.


 27%|███████████████████████████▊                                                                           | 27/100 [20:28<57:37, 47.36s/it]

Total number of random sets: 900, GR_stat: 1.010033082242018
Total number of random sets: 1100, GR_stat: 1.0070645976368044
Total number of random sets: 1500, GR_stat: 1.0061879406524161
Total number of random sets: 1600, GR_stat: 1.0041084704532455
MCI score: 8994, Therehosld: 8991
We have seen 1700 random subsets for each feature.


 28%|████████████████████████████▊                                                                          | 28/100 [21:18<57:40, 48.06s/it]

Total number of random sets: 1000, GR_stat: 1.0066044763084026
Total number of random sets: 1100, GR_stat: 1.00467821067648
MCI score: 8997, Therehosld: 8991
We have seen 1200 random subsets for each feature.


 29%|█████████████████████████████▊                                                                         | 29/100 [21:52<52:02, 43.98s/it]
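
- The log above reports a `GR_stat` as more random subsets are sampled. This is presumably the Gelman-Rubin statistic, a standard convergence diagnostic that compares the variance between several independent Monte Carlo chains with the variance within them; values close to 1 indicate that the chains agree. In this log, sampling appears to stop once `GR_stat` drops below about 1.005, but that threshold is inferred from the output, not documented here. A minimal sketch of the classic statistic, with `gelman_rubin` as a hypothetical helper:

```python
import numpy as np


def gelman_rubin(chains):
    """Potential scale reduction factor (R-hat) for an (m, n) array of m chains with n draws each."""
    chains = np.asarray(chains, dtype=float)
    n = chains.shape[1]
    between = n * chains.mean(axis=1).var(ddof=1)  # B: variance of the chain means, scaled by n
    within = chains.var(axis=1, ddof=1).mean()     # W: average within-chain variance
    pooled = (n - 1) / n * within + between / n    # pooled estimate of the target variance
    return float(np.sqrt(pooled / within))


rng = np.random.default_rng(0)
agreeing = rng.normal(size=(4, 1000))            # four chains drawn from the same distribution
disagreeing = agreeing + np.arange(4)[:, None]   # each chain shifted to a different mean
print(gelman_rubin(agreeing))     # close to 1.0
print(gelman_rubin(disagreeing))  # well above 1.0
```

A stopping rule then keeps drawing samples until the statistic falls below a threshold slightly above 1.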

## WeightedSHAP and the prediction recovery error curve
- We obtain WeightedSHAP (Equation (6) of the paper):
  \begin{align*}
     \phi_{\mathrm{WeightedSHAP}}(\mathcal{T}, \mathcal{W}) := \phi_{\mathbf{w}^* (\mathcal{T}, \mathcal{W})},
  \end{align*}
  where $\mathbf{w}^* (\mathcal{T}, \mathcal{W}) := \mathrm{argmax}_{\mathbf{w} \in \mathcal{W}} \mathcal{T} (\phi_{\mathbf{w}})$, i.e., the weight in $\mathcal{W}$ whose attribution scores best under the criterion $\mathcal{T}$.
- By default, $\mathcal{T}$ is the negative area under the prediction recovery error curve (AUP, Equation (4) of the paper), computed from `cond_pred_keep_absolute`. `exp_dict` also contains `cond_pred_remove_absolute` and `pred_masking`.
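
- A minimal sketch of the selection step, assuming the criterion $\mathcal{T}$ is the negative AUP computed as a plain sum over steps (unit step width). `select_weights_by_aup` is a hypothetical stand-in for `find_optimal_list` from `analysis_utils`, whose implementation is not shown here. The array layout follows the next cell: conditional predictions of shape `(n_test, n_weights, n_steps)` and one model prediction $f(x)$ per test point.

```python
import numpy as np


def select_weights_by_aup(cond_preds, preds):
    """Hypothetical sketch: pick, for each test point, the weight index maximizing -AUP.

    cond_preds: (n_test, n_weights, n_steps) conditional predictions E[f(X) | X_S = x_S].
    preds: (n_test,) model predictions f(x).
    """
    cond_preds = np.asarray(cond_preds, dtype=float)
    preds = np.asarray(preds, dtype=float).reshape(-1, 1, 1)
    errors = np.abs(cond_preds - preds)   # recovery error curve for each weight
    neg_aup = -errors.sum(axis=2)         # negative area under each curve
    return neg_aup.argmax(axis=1)         # index of w* for each test point


# Two test points, two candidate weights, five steps each (toy numbers).
preds = np.array([0.9, 0.2])
cond_preds = np.array([
    [[0.5, 0.6, 0.7, 0.8, 0.9],      # point 0, weight 0: slow recovery
     [0.5, 0.85, 0.88, 0.9, 0.9]],   # point 0, weight 1: fast recovery
    [[0.2, 0.2, 0.2, 0.2, 0.2],      # point 1, weight 0: recovered immediately
     [0.5, 0.4, 0.3, 0.2, 0.2]],     # point 1, weight 1
])
best = select_weights_by_aup(cond_preds, preds)
print(best)  # [1 0]
# Curves of the selected weights, indexed the same way as in the next cell.
weighted_curves = cond_preds[np.arange(len(best)), best, :]
```

Because the optimal weight is chosen per test point, WeightedSHAP can use different weights for different predictions.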

In [None]:
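y_test=np.array(exp_dict['true_list'])
pred_list=np.array(exp_dict['pred_list'])  # model predictions f(x) for the test points
# Shape (n_test, n_weights, n_steps): conditional predictions for each test point and each weight in W
cond_pred_keep_absolute=np.array(exp_dict['cond_pred_keep_absolute'])
# Find an optimal weight for each test point and construct WeightedSHAP
optimal_ind_keep_absolute_list=find_optimal_list(cond_pred_keep_absolute, pred_list)

# Indices into W: 13 is MCI and 6 is SHAP; the third entry uses each test point's optimal weight (WeightedSHAP)
cond_pred_keep_absolute_short=np.array((cond_pred_keep_absolute[:,13,:],
                                        cond_pred_keep_absolute[:,6,:],
                                        cond_pred_keep_absolute[np.arange(len(optimal_ind_keep_absolute_list)),
                                                                optimal_ind_keep_absolute_list,:])).transpose((1,0,2))
# Mean absolute recovery error over test points; shape (3, n_steps)
recovery_curve_keep_absolute=np.mean(np.abs(cond_pred_keep_absolute_short - pred_list.reshape(-1,1,1)), axis=0)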


- The recovery error curves become smoother as the number of test samples increases, at the cost of additional computation.
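
- Each curve in the plot is a per-step average of per-test-point curves, which is why more test points give a smoother result. A minimal sketch of that averaging for the zero-masking variant (as in `pred_masking`), assuming a hypothetical toy linear model $f(z) = z^\top \beta$: with absent features set to zero, $\beta_i x_i$ is exactly the Shapley value of feature $i$. `masking_recovery_curve` is a hypothetical helper, not part of `weightedSHAPEngine`.

```python
import numpy as np


def masking_recovery_curve(f, x, phi, baseline=0.0):
    """Hypothetical helper: |f(x) - f(z_k)| for k = 0, ..., d features added.

    Features are added in decreasing order of |phi|; features not yet added
    keep the value `baseline` (zero masking by default).
    """
    order = np.argsort(-np.abs(phi))
    z = np.full_like(x, baseline, dtype=float)
    errors = [abs(f(x) - f(z))]
    for k in order:
        z[k] = x[k]
        errors.append(abs(f(x) - f(z)))
    return np.array(errors)


beta = np.array([3.0, -2.0, 1.0, 0.5])


def f(z):
    return float(z @ beta)


X = np.random.default_rng(0).normal(size=(200, 4))   # 200 toy "test points"
curves = np.array([masking_recovery_curve(f, x, beta * x) for x in X])
mean_curve = curves.mean(axis=0)   # average over test points, one value per step
print(mean_curve.shape)  # (5,)
print(mean_curve[-1])    # 0.0: with every feature added, f(x) is recovered exactly
```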

In [None]:
n_features=len(recovery_curve_keep_absolute[0])
n_display_features=int(n_features*0.6)
start=max(1, int(n_features*0.075))  # skip the first few steps
steps=np.arange(n_features)[start:n_display_features]  # number of features added

for curve, label, color in zip(recovery_curve_keep_absolute,
                               ['MCI', 'Shapley', 'WeightedSHAP'],
                               ['blue', 'green', 'red']):
    # Plot against `steps` so that the x-axis tick labels match the plotted values
    plt.plot(steps, curve[start:n_display_features], label=label, color=color, linewidth=2, alpha=0.6)
plt.legend(fontsize=12)
plt.title('Prediction recovery error curve, Dataset: fraud \n the lower, the better', fontsize=15)
tick_stride=max(1, n_display_features//6)  # avoid a zero slice step for small n_features
plt.xticks(steps[::tick_stride])
plt.xlabel('Number of features added', fontsize=15)
plt.ylabel(r'$|f(x)-\mathbb{E}[f(X) \mid X_S = x_S]|$', fontsize=15)
plt.show()