In [1]:
import lime
import lime.lime_tabular
import numpy as np
import pandas as pd
import shap
import sklearn
import sklearn.cluster
from sklearn.ensemble import RandomForestClassifier

from utils.abundance_filtering import prevalence_filter

# Interpretability Comparison

In [2]:
# helper functions
def extract_lime_val(lime_row):
    """Turn one LIME explanation into a ``{feature_name: weight}`` dict.

    Parameters
    ----------
    lime_row : iterable of (str, number)
        Pairs of a textual rule and its weight. The feature name is taken to
        be the first space-separated token of the rule that contains an
        underscore.

    Returns
    -------
    dict
        Weight per feature name, in the order the pairs were given. A later
        pair with the same feature name overwrites an earlier one.

    Raises
    ------
    IndexError
        If a rule has no token containing an underscore.
    """
    weights_by_feature = {}
    for rule, weight in lime_row:
        # Numeric bounds and comparison operators carry no underscore, so the
        # first underscore-bearing token is used as the feature name.
        underscore_tokens = [token for token in rule.split(' ') if '_' in token]
        weights_by_feature[underscore_tokens[0]] = weight
    return weights_by_feature

def generate_lime_df(X_test,clf,explainer,num_samples=5):
    """Explain every row of ``X_test`` with LIME and collect the weights.

    Parameters
    ----------
    X_test : pandas.DataFrame
        Samples to explain; its index supplies the row labels of the result.
    clf : object
        Fitted classifier; its ``predict_proba`` is handed to LIME.
    explainer : object
        Explainer exposing ``explain_instance(row, predict_fn, num_features=...,
        num_samples=...)`` whose result has ``as_list()``.
    num_samples : int, default 5
        Number of perturbed samples LIME draws per explanation. The default
        keeps this notebook's previous hard-coded value; it is far below the
        library default (5000 per the LIME docs — confirm for your version),
        so raise it for less noisy explanations.

    Returns
    -------
    pandas.DataFrame
        One row per row of ``X_test`` (same order, same index labels) and one
        column per feature name found in the explanations.
    """
    n_features = X_test.shape[1]
    feature_matrix = X_test.values  # hoisted: avoids re-materialising per row

    # Rows are accumulated in lists rather than a dict keyed by sample id, so
    # samples sharing an index label are not silently overwritten.
    sample_ids = []
    weight_rows = []
    for position, sample_id in enumerate(X_test.index):
        exp = explainer.explain_instance(
            feature_matrix[position],
            clf.predict_proba,
            num_features=n_features,  # ask for a weight on every feature
            num_samples=num_samples,
        )
        sample_ids.append(sample_id)
        weight_rows.append(extract_lime_val(exp.as_list()))

    lime_df = pd.DataFrame(weight_rows, index=sample_ids)

    return lime_df

Source: Elshawi et al (2019) https://github.com/DataSystemsGroupUT/Interpretability-comparison

### Load Dataset

In [3]:
# Path to the input CSV (bacterial relative abundances, per the file name).
bacteria_ab_path = 'data/bacteria_relative_abundance_concat.csv'
# Threshold arguments handed to prevalence_filter in the loading cell; their exact
# semantics are defined in utils.abundance_filtering (not shown in this notebook).
prevalent_filter_thr = 0.90
rel_ab_threshold= 1e-5

In [4]:
# Read the abundance table from the configured path (instead of repeating the
# literal); the first CSV column becomes the row index.
bacteria_ab = pd.read_csv(bacteria_ab_path,index_col=0)

# Split into the feature matrix, the 'CRC' target column and the study of origin.
bacteria_ab_x = bacteria_ab.drop(['CRC', 'study_name'],axis=1)
bacteria_ab_y = bacteria_ab['CRC']
study_names_df = bacteria_ab[['study_name']]

# preprocess: filter features with the project's prevalence_filter helper
bacteria_ab_x_preval = prevalence_filter(bacteria_ab_x,rel_ab_threshold , prevalent_filter_thr)

# print() is needed for the shape because only the last expression is displayed.
print(bacteria_ab_x_preval.shape)
bacteria_ab_x_preval.head(2)

(691, 226)


Unnamed: 0,Bacteroides_plebeius,Bacteroides_dorei,Faecalibacterium_prausnitzii,Eubacterium_eligens,Bacteroides_ovatus,Parabacteroides_distasonis,Ruminococcus_gnavus,Phascolarctobacterium_faecium,Bacteroides_uniformis,Bifidobacterium_longum,...,Enterococcus_faecalis,Coprobacillus_cateniformis,Oxalobacter_formigenes,Butyrivibrio_crossotus,Dialister_sp_CAG_357,Catenibacterium_mitsuokai,Christensenella_minuta,Lachnoclostridium_sp_An131,Clostridium_sp_CAG_167,Roseburia_sp_CAG_309
SAMD00114722,0.000131,0.040444,0.006118,0.00954,0.009849,0.044138,0.080406,0.0,0.03294,0.006659,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SAMD00114724,7.7e-05,0.238579,0.011717,0.0,0.001122,0.009492,0.142257,0.0,0.014983,0.005242,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


## 1. Identity

- Identity: This metric states that if two instances are identical, then their explanations must be identical.
- For every pair of instances in the testing data set, if the distance between the two instances is zero (identical), then the distance between their explanations should also be zero (ElShawi et al., 2020).
- In this notebook the check is implemented by explaining the same test set twice, with two separately constructed explainers, and comparing the two explanations row by row.

In [5]:
def calc_identity(exp1, exp2):
    """Percentage of rows whose two explanations coincide.

    Parameters
    ----------
    exp1, exp2 : numpy.ndarray
        2-D arrays of equal shape; row ``k`` of each is an explanation of the
        same instance.

    Returns
    -------
    float
        Share of rows (0-100) whose Euclidean distance between ``exp1[k]`` and
        ``exp2[k]`` is below 1e-8.
    """
    row_distances = np.linalg.norm(exp1 - exp2, axis=1)
    n_rows = row_distances.shape[0]
    # Rows are treated as identical when their distance is numerically zero.
    n_identical = int(np.count_nonzero(np.abs(row_distances) < 1e-8))

    # Fraction of explanations that violate the identity axiom.
    violated_fraction = (n_rows - n_identical) / n_rows

    # Percentage of explanations that satisfy the axiom.
    return (1 - violated_fraction) * 100

### SHAP

In [6]:
# Every study is held out once as the test set.
study_names_unique = study_names_df['study_name'].unique()

identity_shap = {}

for study in study_names_unique:
    # Leave-one-study-out split: the current study is tested, the rest train.
    train_idx = study_names_df.loc[study_names_df['study_name'] != study].index
    test_idx = study_names_df.loc[study_names_df['study_name'] == study].index

    X_train, y_train = bacteria_ab_x_preval.loc[train_idx], bacteria_ab_y.loc[train_idx]
    X_test, y_test = bacteria_ab_x_preval.loc[test_idx], bacteria_ab_y.loc[test_idx]

    clf = RandomForestClassifier(n_jobs=-1, random_state=0).fit(X_train.values, y_train)

    # SHAP: two explainers built the same way explain the same test set, and
    # the two results are then compared row by row.
    shap_explainer_1, shap_explainer_2 = (
        shap.TreeExplainer(clf, data=X_test),
        shap.TreeExplainer(clf, data=X_test),
    )

    # NOTE(review): [1] presumably selects the values for class index 1; this
    # depends on shap returning one array per class - confirm for your version.
    shap_explanations_1 = shap_explainer_1.shap_values(X_test)[1]
    shap_explanations_2 = shap_explainer_2.shap_values(X_test)[1]

    shap_identity = calc_identity(shap_explanations_1, shap_explanations_2)
    identity_shap[study] = shap_identity

In [7]:
# Collect the per-study SHAP identity scores into a one-column table
# (rows = held-out studies), save it as CSV and display it.
identity_shap_df = pd.DataFrame.from_dict(identity_shap,orient='index',columns=['Identity'])
identity_shap_df.to_csv('output/SHAP_LIME_compare/identity_shap.csv')
identity_shap_df

Unnamed: 0,Identity
YachidaS_2019,100.0
ZellerG_2014,100.0
WirbelJ_2018,100.0
YuJ_2015,100.0
VogtmannE_2016,100.0


### LIME

In [8]:
study_names_unique = study_names_df['study_name'].unique()
identity_lime = {}

for study in study_names_unique:
    # Leave-one-study-out split: the current study is tested, the rest train.
    train_idx = study_names_df[study_names_df['study_name']!=study].index
    test_idx = study_names_df[study_names_df['study_name']==study].index

    X_train = bacteria_ab_x_preval.loc[train_idx]
    y_train = bacteria_ab_y.loc[train_idx]
    X_test = bacteria_ab_x_preval.loc[test_idx]
    y_test = bacteria_ab_y.loc[test_idx]

    clf = RandomForestClassifier(n_jobs=-1,random_state=0)
    clf.fit(X_train.values,y_train)

    # Two explainers built the same way; no random_state is set, so each draws
    # its own perturbation samples.
    lime_explainer_1 = lime.lime_tabular.LimeTabularExplainer(X_train.values, feature_names=X_train.columns, discretize_continuous=True)
    lime_explainer_2 = lime.lime_tabular.LimeTabularExplainer(X_train.values, feature_names=X_train.columns, discretize_continuous=True)

    lime_explanations_1 = generate_lime_df(X_test=X_test, clf=clf,explainer=lime_explainer_1)
    lime_explanations_2 = generate_lime_df(X_test=X_test, clf=clf,explainer=lime_explainer_2)

    # Align the second frame to the first BY FEATURE NAME. The column order of
    # each frame follows the order in which features appear in its LIME
    # explanations, which is not guaranteed to match between the two runs.
    # Overwriting `.columns` would only relabel the columns and leave the
    # values misaligned. Rows need no alignment: both frames are produced in
    # X_test row order.
    lime_explanations_2 = lime_explanations_2.reindex(columns=lime_explanations_1.columns)

    lime_explanations_1 = lime_explanations_1.values
    lime_explanations_2 = lime_explanations_2.values

    lime_identity = calc_identity(lime_explanations_1,lime_explanations_2)
    identity_lime[study]= lime_identity
    

In [9]:
# Collect the per-study LIME identity scores into a one-column table
# (rows = held-out studies), save it as CSV and display it.
identity_lime_df = pd.DataFrame.from_dict(identity_lime,orient='index',columns=['Identity'])
identity_lime_df.to_csv('output/SHAP_LIME_compare/identity_lime.csv')
identity_lime_df

Unnamed: 0,Identity
YachidaS_2019,0.0
ZellerG_2014,0.0
WirbelJ_2018,0.0
YuJ_2015,0.0
VogtmannE_2016,0.0


# 2. Separability

- Separability: This metric states that if two instances are dissimilar, then they must have dissimilar explanations.
- This metric assumes that the model has no degrees of freedom, i.e. that all the features used in the model are relevant to the prediction (ElShawi et al., 2020).
 

In [10]:
def calc_separability(exp):
    """Percentage of explanation pairs that differ from each other.

    Every ordered pair ``(i, j)`` with ``i != j`` is checked; a pair violates
    separability when the two explanation rows are exactly equal. The function
    only sees the explanations, so it assumes the underlying instances are
    themselves pairwise distinct.

    Parameters
    ----------
    exp : array-like, shape (n_explanations, n_features)
        One explanation per row.

    Returns
    -------
    float
        ``100 * (1 - violated_pairs / total_pairs)``, in [0, 100]. With fewer
        than two explanations there are no pairs and 100.0 is returned.
    """
    exp = np.asarray(exp)
    n_explanations = exp.shape[0]
    total_pairs = n_explanations * (n_explanations - 1)  # ordered pairs, i != j
    if total_pairs == 0:
        return 100.0

    violated = 0
    for i in range(n_explanations):
        # Equality is symmetric: visit each unordered pair once, count it twice.
        for j in range(i + 1, n_explanations):
            if np.array_equal(exp[i], exp[j]):
                violated += 2

    # Normalise by the number of pairs; the raw count alone is not a fraction
    # and would give a negative "percentage" as soon as a duplicate exists.
    satisfied = (1 - violated / total_pairs) * 100

    return satisfied

## SHAP

In [11]:
# Every study is held out once as the test set.
study_names_unique = study_names_df['study_name'].unique()

separability_shap = {}

for study in study_names_unique:
    # Leave-one-study-out split: the current study is tested, the rest train.
    train_idx = study_names_df.loc[study_names_df['study_name'] != study].index
    test_idx = study_names_df.loc[study_names_df['study_name'] == study].index

    X_train, y_train = bacteria_ab_x_preval.loc[train_idx], bacteria_ab_y.loc[train_idx]
    X_test, y_test = bacteria_ab_x_preval.loc[test_idx], bacteria_ab_y.loc[test_idx]

    clf = RandomForestClassifier(n_jobs=-1, random_state=0).fit(X_train.values, y_train)

    # SHAP explanations of the held-out study, with that study as background data.
    shap_explainer_1 = shap.TreeExplainer(clf, data=X_test)
    # NOTE(review): [1] presumably selects the values for class index 1 - confirm.
    shap_explanations_1 = shap_explainer_1.shap_values(X_test)[1]

    shap_separability = calc_separability(shap_explanations_1)
    separability_shap[study] = shap_separability

In [12]:
# Collect the per-study SHAP separability scores into a one-column table
# (rows = held-out studies), save it as CSV and display it.
separability_shap_df = pd.DataFrame.from_dict(separability_shap,orient='index',columns=['Separability'])
separability_shap_df.to_csv('output/SHAP_LIME_compare/separability_shap.csv')
separability_shap_df

Unnamed: 0,Separability
YachidaS_2019,100
ZellerG_2014,100
WirbelJ_2018,100
YuJ_2015,100
VogtmannE_2016,100


## LIME

In [13]:
# Every study is held out once as the test set.
study_names_unique = study_names_df['study_name'].unique()
separability_lime = {}

for study in study_names_unique:
    # Leave-one-study-out split: the current study is tested, the rest train.
    train_idx = study_names_df.loc[study_names_df['study_name'] != study].index
    test_idx = study_names_df.loc[study_names_df['study_name'] == study].index

    X_train, y_train = bacteria_ab_x_preval.loc[train_idx], bacteria_ab_y.loc[train_idx]
    X_test, y_test = bacteria_ab_x_preval.loc[test_idx], bacteria_ab_y.loc[test_idx]

    clf = RandomForestClassifier(n_jobs=-1, random_state=0).fit(X_train.values, y_train)

    # LIME explainer fitted on the training features; no random_state is set,
    # so its explanations vary from run to run.
    lime_explainer_1 = lime.lime_tabular.LimeTabularExplainer(
        X_train.values,
        feature_names=X_train.columns,
        discretize_continuous=True,
    )
    lime_explanations_1 = generate_lime_df(
        X_test=X_test, clf=clf, explainer=lime_explainer_1
    ).values

    lime_separability = calc_separability(lime_explanations_1)
    separability_lime[study] = lime_separability

In [14]:
# Collect the per-study LIME separability scores into a one-column table
# (rows = held-out studies), save it as CSV and display it.
separability_lime_df = pd.DataFrame.from_dict(separability_lime,orient='index',columns=['Separability'])
separability_lime_df.to_csv('output/SHAP_LIME_compare/separability_lime.csv')
separability_lime_df

Unnamed: 0,Separability
YachidaS_2019,100
ZellerG_2014,100
WirbelJ_2018,100
YuJ_2015,100
VogtmannE_2016,100


# 3. Stability

- Stability: This metric states that instances belonging to the same class must have comparable explanations.
- The stability metric is measured by clustering the explanations of all instances in the testing data set with the K-means algorithm, with the number of clusters equal to the number of labels in the data set. For each instance in the testing data set, we compare the cluster label assigned to its explanation with the instance’s predicted class label; if they match, the explanation satisfies the stability metric (ElShawi et al., 2020).

In [15]:
def calc_stability(exp, labels):
    """Percentage of explanations whose k-means cluster matches their label.

    The explanations are clustered with k-means using one cluster per distinct
    label. Cluster ``k`` is seeded with the mean explanation of the ``k``-th
    sorted label value, so cluster indices can be translated back to label
    values before comparing.

    Parameters
    ----------
    exp : array-like, shape (n_samples, n_features)
        One explanation per row.
    labels : array-like, shape (n_samples,)
        Class label of each explained instance, in the same row order.

    Returns
    -------
    float
        Agreement between cluster assignment and label, as a percentage
        (the fraction is rounded to three decimals before scaling).
    """
    exp = np.asarray(exp)
    # Positional array: avoids any index-based alignment of a pandas Series.
    labels = np.asarray(labels)

    total = labels.shape[0]
    label_values = np.unique(labels)
    n_clusters = label_values.shape[0]

    # One seed centroid per class. vstack always yields (n_clusters, n_features),
    # including the single-class and single-feature cases where a squeeze()
    # would collapse a dimension.
    init = np.vstack([np.average(exp[labels == value], axis=0) for value in label_values])
    ct = sklearn.cluster.KMeans(n_clusters = n_clusters, random_state=1, n_init=1,init = init)
    ct.fit(exp)

    # Compare in label space and count mismatches. Subtracting cluster indices
    # from raw labels is only a mismatch count when labels are exactly 0/1.
    cluster_as_label = label_values[ct.labels_]
    error = np.sum(cluster_as_label != labels)

    # For two classes, an error rate above one half is read as the two clusters
    # having swapped roles, so the complement is used (kept from the original
    # method). NOTE(review): this bounds the binary score at 50% - confirm it
    # is intended. The flip has no meaning for more than two classes.
    if n_clusters == 2 and error/total > 0.5:
        error = total-error

    score = 1 - (error/total)
    score = round(score,3)*100
    return score

## SHAP

In [16]:
# Every study is held out once as the test set.
study_names_unique = study_names_df['study_name'].unique()

stability_shap = {}

for study in study_names_unique:
    # Leave-one-study-out split: the current study is tested, the rest train.
    train_idx = study_names_df.loc[study_names_df['study_name'] != study].index
    test_idx = study_names_df.loc[study_names_df['study_name'] == study].index

    X_train, y_train = bacteria_ab_x_preval.loc[train_idx], bacteria_ab_y.loc[train_idx]
    X_test, y_test = bacteria_ab_x_preval.loc[test_idx], bacteria_ab_y.loc[test_idx]

    clf = RandomForestClassifier(n_jobs=-1, random_state=0).fit(X_train.values, y_train)

    # SHAP explanations of the held-out study, with that study as background data.
    shap_explainer_1 = shap.TreeExplainer(clf, data=X_test)
    # NOTE(review): [1] presumably selects the values for class index 1 - confirm.
    shap_explanations_1 = shap_explainer_1.shap_values(X_test)[1]

    # NOTE(review): the labels passed here are the true labels (y_test), while
    # the Stability description above refers to the predicted class label -
    # confirm which one is intended.
    shap_stability = calc_stability(shap_explanations_1, labels=y_test)
    stability_shap[study] = shap_stability

In [17]:
# Collect the per-study SHAP stability scores into a one-column table
# (rows = held-out studies), save it as CSV and display it.
stability_shap_df = pd.DataFrame.from_dict(stability_shap,orient='index',columns=['Stability'])
stability_shap_df.to_csv('output/SHAP_LIME_compare/stability_shap.csv')
stability_shap_df

Unnamed: 0,Stability
YachidaS_2019,74.1
ZellerG_2014,62.3
WirbelJ_2018,75.2
YuJ_2015,65.6
VogtmannE_2016,63.5


In [18]:
# Mean SHAP stability over the held-out studies (column-wise mean).
stability_shap_df.mean()

Stability    68.14
dtype: float64

## LIME

In [19]:
# Every study is held out once as the test set.
study_names_unique = study_names_df['study_name'].unique()
stability_lime = {}

for study in study_names_unique:
    # Leave-one-study-out split: the current study is tested, the rest train.
    train_idx = study_names_df.loc[study_names_df['study_name'] != study].index
    test_idx = study_names_df.loc[study_names_df['study_name'] == study].index

    X_train, y_train = bacteria_ab_x_preval.loc[train_idx], bacteria_ab_y.loc[train_idx]
    X_test, y_test = bacteria_ab_x_preval.loc[test_idx], bacteria_ab_y.loc[test_idx]

    clf = RandomForestClassifier(n_jobs=-1, random_state=0).fit(X_train.values, y_train)

    # LIME explainer fitted on the training features; no random_state is set,
    # so its explanations vary from run to run.
    lime_explainer_1 = lime.lime_tabular.LimeTabularExplainer(
        X_train.values,
        feature_names=X_train.columns,
        discretize_continuous=True,
    )
    lime_explanations_1 = generate_lime_df(
        X_test=X_test, clf=clf, explainer=lime_explainer_1
    ).values

    # NOTE(review): the labels passed here are the true labels (y_test), while
    # the Stability description above refers to the predicted class label -
    # confirm which one is intended.
    lime_stability = calc_stability(lime_explanations_1, labels=y_test)
    stability_lime[study] = lime_stability

In [20]:
# Collect the per-study LIME stability scores into a one-column table
# (rows = held-out studies), save it as CSV and display it.
stability_lime_df = pd.DataFrame.from_dict(stability_lime,orient='index',columns=['Stability'])
stability_lime_df.to_csv('output/SHAP_LIME_compare/stability_lime.csv')
stability_lime_df

Unnamed: 0,Stability
YachidaS_2019,70.0
ZellerG_2014,69.3
WirbelJ_2018,77.6
YuJ_2015,78.1
VogtmannE_2016,63.5


In [21]:
# Mean LIME stability over the held-out studies (column-wise mean).
stability_lime_df.mean()

Stability    71.7
dtype: float64

# Summary

In [22]:
# Put the three SHAP metric tables side by side (one column per metric,
# rows = held-out studies), export to Excel and display.
shap_res = [identity_shap_df, separability_shap_df, stability_shap_df]
shap_res_df = pd.concat(shap_res,axis=1)

shap_res_df.to_excel('output/SHAP_LIME_compare/shap_result.xlsx')
shap_res_df

Unnamed: 0,Identity,Separability,Stability
YachidaS_2019,100.0,100,74.1
ZellerG_2014,100.0,100,62.3
WirbelJ_2018,100.0,100,75.2
YuJ_2015,100.0,100,65.6
VogtmannE_2016,100.0,100,63.5


In [23]:
# Put the three LIME metric tables side by side (one column per metric,
# rows = held-out studies), export to Excel and display.
lime_res = [identity_lime_df, separability_lime_df, stability_lime_df]
lime_res_df = pd.concat(lime_res,axis=1)

lime_res_df.to_excel('output/SHAP_LIME_compare/lime_result.xlsx')
lime_res_df

Unnamed: 0,Identity,Separability,Stability
YachidaS_2019,0.0,100,70.0
ZellerG_2014,0.0,100,69.3
WirbelJ_2018,0.0,100,77.6
YuJ_2015,0.0,100,78.1
VogtmannE_2016,0.0,100,63.5


In [24]:
# Element-wise difference, LIME minus SHAP: positive values mean LIME scored
# higher on that metric for that study.
lime_res_df-shap_res_df

Unnamed: 0,Identity,Separability,Stability
YachidaS_2019,-100.0,0,-4.1
ZellerG_2014,-100.0,0,7.0
WirbelJ_2018,-100.0,0,2.4
YuJ_2015,-100.0,0,12.5
VogtmannE_2016,-100.0,0,0.0
