# WeightedSHAP on the EEG dataset

In this notebook, we use the WeightedSHAP feature attribution method to determine which features are most informative for predicting the binary classes Read and Speak.

In [1]:
%pip install pandas==1.5.1

import sys, os
import numpy as np
import pandas as pd
sys.path.append('../')

Note: you may need to restart the kernel to use updated packages.


## Load data

In [2]:
# Load the EEG table from its pickle file and dump it as plain text.
# NOTE: unpickling can execute arbitrary code, so only load files from a trusted source.
eeg_pickle_path = '../eeg_dataset.pkl'
df = pd.read_pickle(eeg_pickle_path)
print(df)

     subject train_test     labels        F0        F1        F2        F3  \
0       CF60       test  WORD_BASS  0.000331 -0.000242  0.000033  0.000084   
1       CF60       test   WORD_BOG  0.000338 -0.000174  0.000187  0.000090   
2       CF60       test   WORD_BAR  0.000318  0.000161  0.000156  0.000127   
3       CF60       test       READ  0.000349  0.000043 -0.000017  0.000005   
4       CF60       test  WORD_BEAD  0.000327 -0.000270 -0.000057  0.000024   
...      ...        ...        ...       ...       ...       ...       ...   
3447    CM29      train       READ -0.007828  0.000395  0.000300  0.000199   
3448    CM29      train       READ -0.007828  0.000395  0.000300  0.000199   
3449    CM29      train  WORD_BATH  0.010894 -0.000039 -0.000098  0.000111   
3450    CM29      train       READ  0.010875 -0.000102 -0.000098  0.000093   
3451    CM29      train       READ -0.007827  0.000392  0.000289  0.000213   

            F4        F5        F6  ...          F170          

## We will now split the dataset into train, val, est, and test datasets

In [3]:
import weightedSHAP

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.svm import SVC

from torch.utils.data import DataLoader

import torch


# --- Configuration -----------------------------------------------------------
debug = True  # when True, stop after the first held-out subject
SEED = 0  # fixes the random splits so a rerun reproduces the same partitions
# Feature columns are selected by position: the first three columns of `df`
# are subject, train_test and labels, so features start at position 3.
FEATURE_COLS = slice(3, 184)

# Seed the global numpy and torch RNGs too; downstream code may draw from them
# (not visible from this cell), and seeding is harmless if it does not.
np.random.seed(SEED)
torch.manual_seed(SEED)


def _to_tensors(frame):
    """Turn a slice of the EEG table into model inputs.

    Parameters
    ----------
    frame : pandas.DataFrame
        Rows of `df`; must contain a "labels" column and the feature columns
        at the positions given by FEATURE_COLS.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        float32 feature tensor and int64 label tensor, where the label is
        0 for "READ" and 1 for every other label value.
    """
    features = torch.tensor(frame.iloc[:, FEATURE_COLS].values, dtype=torch.float32)
    targets = torch.tensor((frame["labels"] != "READ").astype("int64").to_numpy())
    return features, targets


# Global train/test partitions, kept for any later cell that reads them.
# `train_all` is overwritten per subject inside the loop below.
train_all = df[df["train_test"] == "train"]
test_all = df[df["train_test"] == "test"]

subj_list = np.unique(df["subject"])

for test_subj in subj_list:
    # Held-out subject: only its "test" rows are used, as the test set.
    # NOTE(review): validation rows currently come from the other subjects'
    # training pool, not from the held-out subject -- confirm this is intended.
    held_data = df[df["subject"] == test_subj]
    test = held_data[held_data["train_test"] == "test"]

    # Every other subject contributes its non-"test" rows to the training pool.
    train_data = df[df["subject"] != test_subj]
    train_all = train_data[train_data["train_test"] != "test"]

    # 80% train; the remaining 20% is halved into est (10%) and val (10%).
    train, est = train_test_split(train_all, test_size=0.2, random_state=SEED)
    est, val = train_test_split(est, test_size=0.5, random_state=SEED)

    # Build feature / label tensors for each partition.
    X_train, y_train = _to_tensors(train)
    X_est, y_est = _to_tensors(est)
    X_val, y_val = _to_tensors(val)
    X_test, y_test = _to_tensors(test)

    # Create and train the model to be explained.
    svc = SVC(kernel='rbf', C=1.0, class_weight='balanced')
    svc.fit(X_train, y_train)

    # Generate a conditional coalition function from the train and est sets.
    conditional_extension = weightedSHAP.generate_coalition_function(
        svc, X_train, X_est, 'classification', 'eeg'
    )

    # With the conditional coalition function, compute attributions.
    exp_dict = weightedSHAP.compute_attributions(
        'classification', 'eeg',
        svc, conditional_extension,
        X_train, y_train,
        X_val, y_val,
        X_test, y_test,
    )

    print(exp_dict)

    if debug:
        break
    

Elapsed time for training a surrogate model: 14.58 seconds


  model_condi_wrapper = lambda x, S: surrogate_model((torch.tensor(x, dtype=torch.float32, device=device),


Total number of random sets: 200, GR_stat: 1.0519176483168087
Total number of random sets: 300, GR_stat: 1.03052441506388
Total number of random sets: 400, GR_stat: 1.0180687799266361
Total number of random sets: 500, GR_stat: 1.0308962365809942
Total number of random sets: 600, GR_stat: 1.0126624737275611
Total number of random sets: 700, GR_stat: 1.0099584960472079
Total number of random sets: 800, GR_stat: 1.0124462180458087
Total number of random sets: 900, GR_stat: 1.0101696846747243
Total number of random sets: 1000, GR_stat: 1.0109577092508033
Total number of random sets: 1100, GR_stat: 1.0083569753985577
Total number of random sets: 1200, GR_stat: 1.0104025034046982
Total number of random sets: 1300, GR_stat: 1.0083584812976214
Total number of random sets: 1400, GR_stat: 1.0077709102186847
Total number of random sets: 1500, GR_stat: 1.0057143665420596
Total number of random sets: 1600, GR_stat: 1.0043506021238449
Therehosld: 17982
We have seen 1700 random subsets for each featu

  model_condi_wrapper = lambda x, S: surrogate_model((torch.tensor(x, dtype=torch.float32, device=device),


Total number of random sets: 200, GR_stat: 1.0334566973440984
Total number of random sets: 300, GR_stat: 1.0316911900421741
Total number of random sets: 400, GR_stat: 1.0245169214042817
Total number of random sets: 500, GR_stat: 1.0166815717869575
Total number of random sets: 600, GR_stat: 1.0156415955039557
Total number of random sets: 700, GR_stat: 1.0139871466006347
Total number of random sets: 800, GR_stat: 1.0112438530437728
Total number of random sets: 900, GR_stat: 1.014034306036496
Total number of random sets: 1000, GR_stat: 1.0087343343051578
Total number of random sets: 1100, GR_stat: 1.006878374364216
Total number of random sets: 1200, GR_stat: 1.0107319022490104
Total number of random sets: 1300, GR_stat: 1.0056131082514161
Total number of random sets: 1400, GR_stat: 1.0068430154387145
Total number of random sets: 1500, GR_stat: 1.006082883460184
Total number of random sets: 1600, GR_stat: 1.0074599842108274
Total number of random sets: 1700, GR_stat: 1.0045625104210314
The

  model_condi_wrapper = lambda x, S: surrogate_model((torch.tensor(x, dtype=torch.float32, device=device),


Total number of random sets: 200, GR_stat: 1.051299791950762
Total number of random sets: 300, GR_stat: 1.0282104657051427
Total number of random sets: 400, GR_stat: 1.0194696646592145
Total number of random sets: 500, GR_stat: 1.0287588270703216
Total number of random sets: 600, GR_stat: 1.0180336527768956
Total number of random sets: 700, GR_stat: 1.008676052573093
Total number of random sets: 800, GR_stat: 1.0093567837744055
Total number of random sets: 900, GR_stat: 1.0118620776033205
Total number of random sets: 1000, GR_stat: 1.0099675125955474
Total number of random sets: 1100, GR_stat: 1.0092586860536505
Total number of random sets: 1200, GR_stat: 1.0053827151917623
Total number of random sets: 1300, GR_stat: 1.0070315559295842
Total number of random sets: 1400, GR_stat: 1.0065817104575419
Total number of random sets: 1500, GR_stat: 1.0051946086146835
Total number of random sets: 1600, GR_stat: 1.0043505121143939
Therehosld: 17982
We have seen 1700 random subsets for each featu

  model_condi_wrapper = lambda x, S: surrogate_model((torch.tensor(x, dtype=torch.float32, device=device),


Total number of random sets: 200, GR_stat: 1.0369111163410978
Total number of random sets: 300, GR_stat: 1.023885106480971
Total number of random sets: 400, GR_stat: 1.0172107268235526
Total number of random sets: 500, GR_stat: 1.0128307396001095
Total number of random sets: 600, GR_stat: 1.0110187156350547
Total number of random sets: 700, GR_stat: 1.0098963339612061
Total number of random sets: 800, GR_stat: 1.0070159252299458
Total number of random sets: 900, GR_stat: 1.0066331483720958
Total number of random sets: 1000, GR_stat: 1.007544011809485
Total number of random sets: 1100, GR_stat: 1.00784057738517
Total number of random sets: 1200, GR_stat: 1.0067772425992145
Total number of random sets: 1300, GR_stat: 1.006467493858842
Total number of random sets: 1400, GR_stat: 1.0063147208245902
Total number of random sets: 1500, GR_stat: 1.0065081399062048
Total number of random sets: 1600, GR_stat: 1.0084734733945566
Total number of random sets: 1700, GR_stat: 1.0058123072743097
Total

  model_condi_wrapper = lambda x, S: surrogate_model((torch.tensor(x, dtype=torch.float32, device=device),


Total number of random sets: 200, GR_stat: 1.062917966060532
Total number of random sets: 300, GR_stat: 1.0342190176756851
Total number of random sets: 400, GR_stat: 1.0167600566606732
Total number of random sets: 500, GR_stat: 1.0202233617427123
Total number of random sets: 600, GR_stat: 1.014766867708206
Total number of random sets: 700, GR_stat: 1.0120042948946657
Total number of random sets: 800, GR_stat: 1.0080595811488249
Total number of random sets: 900, GR_stat: 1.0105332577994444
Total number of random sets: 1000, GR_stat: 1.0093763125601982
Total number of random sets: 1100, GR_stat: 1.0083074114043433
Total number of random sets: 1200, GR_stat: 1.0093070132600523
Total number of random sets: 1300, GR_stat: 1.0111324146810394
Total number of random sets: 1400, GR_stat: 1.0083817648262072
Total number of random sets: 1500, GR_stat: 1.0079661061880831
Total number of random sets: 1600, GR_stat: 1.0053690369166306
Total number of random sets: 1700, GR_stat: 1.006325764977977
Tot

  model_condi_wrapper = lambda x, S: surrogate_model((torch.tensor(x, dtype=torch.float32, device=device),


Total number of random sets: 200, GR_stat: 1.0431416568831196
Total number of random sets: 300, GR_stat: 1.0388543832938024
Total number of random sets: 400, GR_stat: 1.0160472587558582
Total number of random sets: 500, GR_stat: 1.0155417915827623
Total number of random sets: 600, GR_stat: 1.0120887982688056
Total number of random sets: 700, GR_stat: 1.0117377595496968
Total number of random sets: 800, GR_stat: 1.0133775943349101
Total number of random sets: 900, GR_stat: 1.0081531691445524
Total number of random sets: 1000, GR_stat: 1.0084163104307402
Total number of random sets: 1100, GR_stat: 1.0106016977899503
Total number of random sets: 1200, GR_stat: 1.0077569337482852
Total number of random sets: 1300, GR_stat: 1.0060792118269157
Total number of random sets: 1400, GR_stat: 1.004846157796246
Therehosld: 17982
We have seen 1500 random subsets for each feature.


  model_condi_wrapper = lambda x, S: surrogate_model((torch.tensor(x, dtype=torch.float32, device=device),


Total number of random sets: 200, GR_stat: 1.0345782701444544
Total number of random sets: 300, GR_stat: 1.0288086026899885
Total number of random sets: 400, GR_stat: 1.0259474532683146
Total number of random sets: 500, GR_stat: 1.0194575303675377
Total number of random sets: 600, GR_stat: 1.0167930752405048
Total number of random sets: 700, GR_stat: 1.0136472905555052
Total number of random sets: 800, GR_stat: 1.0106661567669581
Total number of random sets: 900, GR_stat: 1.0090940606748686
Total number of random sets: 1000, GR_stat: 1.0065507980487394
Total number of random sets: 1100, GR_stat: 1.0077392142525867
Total number of random sets: 1200, GR_stat: 1.0076826954450406
Total number of random sets: 1300, GR_stat: 1.0096060007204413
Total number of random sets: 1400, GR_stat: 1.0062879513433827
Total number of random sets: 1500, GR_stat: 1.009081356797844
Total number of random sets: 1600, GR_stat: 1.006982486698088
Total number of random sets: 1700, GR_stat: 1.0048080009754572
Th

  model_condi_wrapper = lambda x, S: surrogate_model((torch.tensor(x, dtype=torch.float32, device=device),


Total number of random sets: 200, GR_stat: 1.0387274146632044
Total number of random sets: 300, GR_stat: 1.0222136349229287
Total number of random sets: 400, GR_stat: 1.019510189572663
Total number of random sets: 500, GR_stat: 1.0119685231731461
Total number of random sets: 600, GR_stat: 1.0181684695446631
Total number of random sets: 700, GR_stat: 1.0119365365821975
Total number of random sets: 800, GR_stat: 1.0093277061135875
Total number of random sets: 900, GR_stat: 1.0108412476150885
Total number of random sets: 1000, GR_stat: 1.0073192777390618
Total number of random sets: 1100, GR_stat: 1.0090025692674918
Total number of random sets: 1200, GR_stat: 1.0075149981505975
Total number of random sets: 1300, GR_stat: 1.0082627855101163
Total number of random sets: 1400, GR_stat: 1.0059423774031955
Total number of random sets: 1500, GR_stat: 1.0056446530384353
Total number of random sets: 1600, GR_stat: 1.0051771370141318
Total number of random sets: 1700, GR_stat: 1.0058557096374907
T

  model_condi_wrapper = lambda x, S: surrogate_model((torch.tensor(x, dtype=torch.float32, device=device),


Total number of random sets: 200, GR_stat: 1.037381756450761
Total number of random sets: 300, GR_stat: 1.0374229970159572
Total number of random sets: 400, GR_stat: 1.0318820729930323
Total number of random sets: 500, GR_stat: 1.0225074782600574
Total number of random sets: 600, GR_stat: 1.0148426829346098
Total number of random sets: 700, GR_stat: 1.0133309448679784
Total number of random sets: 800, GR_stat: 1.0124146321545924
Total number of random sets: 900, GR_stat: 1.009013904006124
Total number of random sets: 1000, GR_stat: 1.0086379226412272
Total number of random sets: 1100, GR_stat: 1.0077581051553894
Total number of random sets: 1200, GR_stat: 1.0080018760324152
Total number of random sets: 1300, GR_stat: 1.0080239305621923
Total number of random sets: 1400, GR_stat: 1.0065789024878102
Total number of random sets: 1500, GR_stat: 1.0067465706292291
Total number of random sets: 1600, GR_stat: 1.0059706567198177
Total number of random sets: 1700, GR_stat: 1.0074011996327084
To

  model_condi_wrapper = lambda x, S: surrogate_model((torch.tensor(x, dtype=torch.float32, device=device),


Total number of random sets: 200, GR_stat: 1.0436695965478107
Total number of random sets: 300, GR_stat: 1.0294509706222297
Total number of random sets: 400, GR_stat: 1.0246236167268719
Total number of random sets: 500, GR_stat: 1.0221544827592
Total number of random sets: 600, GR_stat: 1.0145885888793178
Total number of random sets: 700, GR_stat: 1.011716329664492
Total number of random sets: 800, GR_stat: 1.0115048753094438
Total number of random sets: 900, GR_stat: 1.009440068715953
Total number of random sets: 1000, GR_stat: 1.01013114200031
Total number of random sets: 1100, GR_stat: 1.0064870446539629
Total number of random sets: 1200, GR_stat: 1.0070785544474627
Total number of random sets: 1300, GR_stat: 1.0081033233353869
Total number of random sets: 1400, GR_stat: 1.0068593245440685
Total number of random sets: 1500, GR_stat: 1.0061481412326647
Total number of random sets: 1600, GR_stat: 1.0056210389564055
Total number of random sets: 1700, GR_stat: 1.0048478589915562
Thereho

  model_condi_wrapper = lambda x, S: surrogate_model((torch.tensor(x, dtype=torch.float32, device=device),


Total number of random sets: 200, GR_stat: 1.068498385970198
Total number of random sets: 300, GR_stat: 1.0287271228228294
Total number of random sets: 400, GR_stat: 1.0195145506176402
Total number of random sets: 500, GR_stat: 1.0229980535766179
Total number of random sets: 600, GR_stat: 1.0162598668761837
Total number of random sets: 700, GR_stat: 1.0102421540852342
Total number of random sets: 800, GR_stat: 1.0091880245029292
Total number of random sets: 900, GR_stat: 1.0101618017495666
Total number of random sets: 1000, GR_stat: 1.007218776549582
Total number of random sets: 1100, GR_stat: 1.0062617421024016
Total number of random sets: 1200, GR_stat: 1.0060467818450431
Total number of random sets: 1300, GR_stat: 1.0076333065614185
Total number of random sets: 1400, GR_stat: 1.007375906118512
Total number of random sets: 1500, GR_stat: 1.004981534232028
Therehosld: 17982
We have seen 1600 random subsets for each feature.


  model_condi_wrapper = lambda x, S: surrogate_model((torch.tensor(x, dtype=torch.float32, device=device),


Total number of random sets: 200, GR_stat: 1.044201618700027
Total number of random sets: 300, GR_stat: 1.0369610140656869
Total number of random sets: 400, GR_stat: 1.0194746740614624
Total number of random sets: 500, GR_stat: 1.0149158320606193
Total number of random sets: 600, GR_stat: 1.0113737514444203
Total number of random sets: 700, GR_stat: 1.0119037656224186
Total number of random sets: 800, GR_stat: 1.0090001157237958
Total number of random sets: 900, GR_stat: 1.0074565625787768
Total number of random sets: 1000, GR_stat: 1.0074719732315705
Total number of random sets: 1100, GR_stat: 1.00748514482312
Total number of random sets: 1200, GR_stat: 1.0050620338490397
Total number of random sets: 1300, GR_stat: 1.0056020800359864
Total number of random sets: 1400, GR_stat: 1.004665139692604
Therehosld: 17982
We have seen 1500 random subsets for each feature.


  model_condi_wrapper = lambda x, S: surrogate_model((torch.tensor(x, dtype=torch.float32, device=device),


Total number of random sets: 200, GR_stat: 1.0472020571811813
Total number of random sets: 300, GR_stat: 1.0266778139078196
Total number of random sets: 400, GR_stat: 1.0207717312625535
Total number of random sets: 500, GR_stat: 1.018852911166702
Total number of random sets: 600, GR_stat: 1.0118929473957088
Total number of random sets: 700, GR_stat: 1.0152679089047238
Total number of random sets: 800, GR_stat: 1.009942250499335
Total number of random sets: 900, GR_stat: 1.0113128960300213
Total number of random sets: 1000, GR_stat: 1.0103997248790837
Total number of random sets: 1100, GR_stat: 1.0093136682576558
Total number of random sets: 1200, GR_stat: 1.0071191342365913
Total number of random sets: 1300, GR_stat: 1.0080474348796173
Total number of random sets: 1400, GR_stat: 1.008292074939134
Total number of random sets: 1500, GR_stat: 1.0076393425166128
Total number of random sets: 1600, GR_stat: 1.0040426886511165
Therehosld: 17982
We have seen 1700 random subsets for each featur

  model_condi_wrapper = lambda x, S: surrogate_model((torch.tensor(x, dtype=torch.float32, device=device),


Total number of random sets: 200, GR_stat: 1.0568420073223777
Total number of random sets: 300, GR_stat: 1.029132300043103
Total number of random sets: 400, GR_stat: 1.0237186234107774
Total number of random sets: 500, GR_stat: 1.0164021641889887
Total number of random sets: 600, GR_stat: 1.0160524675028881
Total number of random sets: 700, GR_stat: 1.0115316180499683
Total number of random sets: 800, GR_stat: 1.012249417878942
Total number of random sets: 900, GR_stat: 1.0106649036258544
Total number of random sets: 1000, GR_stat: 1.0079813339102865
Total number of random sets: 1100, GR_stat: 1.0077861431366215
Total number of random sets: 1200, GR_stat: 1.0088677301893438
Total number of random sets: 1300, GR_stat: 1.0085909782489677
Total number of random sets: 1400, GR_stat: 1.0048090225965216
Therehosld: 17982
We have seen 1500 random subsets for each feature.


  model_condi_wrapper = lambda x, S: surrogate_model((torch.tensor(x, dtype=torch.float32, device=device),


Total number of random sets: 200, GR_stat: 1.0491306758970473


In [None]:
import pickle