# SurvSHAP(t): Time-Dependent Explanations Of Machine Learning Survival Models
### M. Krzyziński, M. Spytek, H. Baniecki, P. Biecek
## Experiment 2: Comparison to SurvLIME

#### Imports

In [None]:
import pandas as pd
import numpy as np
import pickle
from sksurv.util import Surv

#### Preparing data and models 

In [None]:
# Load the dataset0 train/test splits. The first 5 columns are used as the
# feature matrix; the "event" and "time" columns form the survival target.
dataset0_train = pd.read_csv("data/exp2_dataset0_train.csv")
dataset0_test = pd.read_csv("data/exp2_dataset0_test.csv")
X_train0 = dataset0_train.iloc[:, :5]
X_test0 = dataset0_test.iloc[:, :5]
y_train0 = Surv.from_dataframe("event", "time", dataset0_train)
y_test0 = Surv.from_dataframe("event", "time", dataset0_test)

In [None]:
# Load the dataset1 train/test splits. The first 5 columns are used as the
# feature matrix; the "event" and "time" columns form the survival target.
dataset1_train = pd.read_csv("data/exp2_dataset1_train.csv")
dataset1_test = pd.read_csv("data/exp2_dataset1_test.csv")
X_train1 = dataset1_train.iloc[:, :5]
X_test1 = dataset1_test.iloc[:, :5]
y_train1 = Surv.from_dataframe("event", "time", dataset1_train)
y_test1 = Surv.from_dataframe("event", "time", dataset1_test)

In [None]:
from sksurv.linear_model import CoxPHSurvivalAnalysis
# Cox proportional hazards model (default settings) fitted on dataset0; its
# coefficients are displayed and used for the "GT CPH" rankings further down.
cph_dataset0 = CoxPHSurvivalAnalysis()
cph_dataset0.fit(X_train0, y_train0)

In [None]:
# Cox proportional hazards model (default settings) fitted on dataset1.
cph_dataset1 = CoxPHSurvivalAnalysis()
cph_dataset1.fit(X_train1, y_train1)

#### Reading explanations
##### SurvLIME dataset0

In [None]:
# Load the precomputed SurvLIME explanations for dataset0 (CPH and RSF models);
# they are used below as sequences of explanation objects.
# NOTE(review): pickle.load can execute arbitrary code — only load these files
# from a trusted source.
with open("pickles/exp2_survlime_dataset0_cph", "rb") as f:
    exp2_survlime_dataset0_cph = pickle.load(f)
with open("pickles/exp2_survlime_dataset0_rsf", "rb") as f:
    exp2_survlime_dataset0_rsf = pickle.load(f)

##### SurvLIME dataset1 

In [None]:
# Load the precomputed SurvLIME explanations for dataset1 (CPH and RSF models).
# NOTE(review): pickle.load can execute arbitrary code — only load these files
# from a trusted source.
with open("pickles/exp2_survlime_dataset1_cph", "rb") as f:
    exp2_survlime_dataset1_cph = pickle.load(f)
with open("pickles/exp2_survlime_dataset1_rsf", "rb") as f:
    exp2_survlime_dataset1_rsf = pickle.load(f)

##### SurvSHAP(t) dataset0

In [None]:
# Load the precomputed SurvSHAP(t) explanations for dataset0 (CPH and RSF models).
# NOTE(review): pickle.load can execute arbitrary code — only load these files
# from a trusted source.
with open("pickles/exp2_survshap_dataset0_cph", "rb") as f:
    exp2_survshap_dataset0_cph = pickle.load(f)
with open("pickles/exp2_survshap_dataset0_rsf", "rb") as f:
    exp2_survshap_dataset0_rsf = pickle.load(f)

##### SurvSHAP(t) dataset1

In [None]:
# Load the precomputed SurvSHAP(t) explanations for dataset1 (CPH and RSF models).
# (This cell names the handle `file`; the other loading cells use `f`.)
# NOTE(review): pickle.load can execute arbitrary code — only load these files
# from a trusted source.
with open("pickles/exp2_survshap_dataset1_cph", "rb") as file:
    exp2_survshap_dataset1_cph = pickle.load(file)
with open("pickles/exp2_survshap_dataset1_rsf", "rb") as file:
    exp2_survshap_dataset1_rsf = pickle.load(file)

#### Local accuracy

In [None]:
def get_local_accuracy_from_shap_explanations(all_explanations, method_label, cluster_label, model_label, last_index=None):
    """Relative local-accuracy error of SurvSHAP(t) explanations over time.

    For every time point t the error is

        sigma(t) = sqrt(mean_k r_k(t)^2) / sqrt(mean_k f_k(t)^2),
        r_k(t)   = f_k(t) - b_k(t) - sum_j phi_kj(t),

    where, for explanation k, f_k is ``predicted_function``, b_k is
    ``baseline_function`` and phi_kj(t) are the values in columns 5 onward of
    ``result`` (one row per variable, one column per time point).

    Parameters
    ----------
    all_explanations : sequence
        SurvSHAP(t) explanation objects; the first one supplies ``timestamps``.
    method_label, cluster_label, model_label
        Constants copied into the ``method``, ``cluster`` and ``model`` columns.
    last_index : int, optional
        Only the first ``last_index`` time points are used (default: all).

    Returns
    -------
    pandas.DataFrame
        Columns ``time``, ``sigma``, ``method``, ``cluster``, ``model``.
    """
    timestamps = all_explanations[0].timestamps
    cut = len(timestamps) if last_index is None else last_index

    predictions = np.array([expl.predicted_function[:cut] for expl in all_explanations])
    residuals = np.array([
        expl.predicted_function[:cut]
        - expl.baseline_function[:cut]
        - np.array(expl.result.iloc[:, 5:].sum(axis=0))[:cut]
        for expl in all_explanations
    ])

    residual_rms = np.sqrt(np.mean(residuals ** 2, axis=0))
    prediction_rms = np.sqrt(np.mean(predictions ** 2, axis=0))
    return pd.DataFrame({
        "time": timestamps[:cut],
        "sigma": residual_rms / prediction_rms,
        "method": method_label,
        "cluster": cluster_label,
        "model": model_label,
    })

In [None]:
def get_local_accuracy_from_lime_explanations(all_explanations, method_label, cluster_label, model_label, last_index=None):
    """Relative local-accuracy error of SurvLIME explanations over time.

    For every time point t the error is

        sigma(t) = sqrt(mean_k (S_k(t) - L_k(t))^2) / sqrt(mean_k S_k(t)^2),

    where, for explanation k, S_k is ``predicted_sf`` (the explained model's
    survival function) and L_k is ``survlime_sf`` (the SurvLIME surrogate).

    Parameters
    ----------
    all_explanations : sequence
        SurvLIME explanation objects; the first one supplies ``timestamps``.
    method_label, cluster_label, model_label
        Constants copied into the ``method``, ``cluster`` and ``model`` columns.
    last_index : int, optional
        Only the first ``last_index`` time points are used (default: all).

    Returns
    -------
    pandas.DataFrame
        Columns ``time``, ``sigma``, ``method``, ``cluster``, ``model``.
    """
    timestamps = all_explanations[0].timestamps
    cut = len(timestamps) if last_index is None else last_index

    predicted = np.array([expl.predicted_sf[:cut] for expl in all_explanations])
    errors = np.array([
        expl.predicted_sf[:cut] - np.array(expl.survlime_sf[:cut])
        for expl in all_explanations
    ])

    error_rms = np.sqrt(np.mean(errors ** 2, axis=0))
    predicted_rms = np.sqrt(np.mean(predicted ** 2, axis=0))
    return pd.DataFrame({
        "time": timestamps[:cut],
        "sigma": error_rms / predicted_rms,
        "method": method_label,
        "cluster": cluster_label,
        "model": model_label,
    })

In [None]:
# Local accuracy over time of SurvSHAP(t) and SurvLIME for the CPH model on
# dataset0, saved as one long-format CSV.
# NOTE(review): to_csv is called without index=False, so the (duplicated)
# index of the concatenated frames is written as an extra column.
local_accuracy_shap_cph_cluster_0 = get_local_accuracy_from_shap_explanations(exp2_survshap_dataset0_cph, "shap", "0", "cph")
local_accuracy_lime_cph_cluster_0 = get_local_accuracy_from_lime_explanations(exp2_survlime_dataset0_cph, "lime", "0", "cph")

pd.concat([local_accuracy_shap_cph_cluster_0, local_accuracy_lime_cph_cluster_0]).to_csv("results/exp2_local_accuracy_cph_dataset0.csv")

In [None]:
# Local accuracy over time of SurvSHAP(t) and SurvLIME for the CPH model on
# dataset1, saved as one long-format CSV (index column included).
local_accuracy_shap_cph_cluster_1 = get_local_accuracy_from_shap_explanations(exp2_survshap_dataset1_cph, "shap", "1", "cph")
local_accuracy_lime_cph_cluster_1 = get_local_accuracy_from_lime_explanations(exp2_survlime_dataset1_cph, "lime", "1", "cph")

pd.concat([local_accuracy_shap_cph_cluster_1, local_accuracy_lime_cph_cluster_1]).to_csv("results/exp2_local_accuracy_cph_dataset1.csv")

In [None]:
# Local accuracy over time of SurvSHAP(t) and SurvLIME for the RSF model on
# dataset0, saved as one long-format CSV (index column included).
local_accuracy_shap_rsf_cluster_0 = get_local_accuracy_from_shap_explanations(exp2_survshap_dataset0_rsf, "shap", "0", "rsf")
local_accuracy_lime_rsf_cluster_0 = get_local_accuracy_from_lime_explanations(exp2_survlime_dataset0_rsf, "lime", "0", "rsf")

pd.concat([local_accuracy_shap_rsf_cluster_0, local_accuracy_lime_rsf_cluster_0]).to_csv("results/exp2_local_accuracy_rsf_dataset0.csv")

In [None]:
# Local accuracy over time of SurvSHAP(t) and SurvLIME for the RSF model on
# dataset1, saved as one long-format CSV (index column included).
local_accuracy_shap_rsf_cluster_1 = get_local_accuracy_from_shap_explanations(exp2_survshap_dataset1_rsf, "shap", "1", "rsf")
local_accuracy_lime_rsf_cluster_1 = get_local_accuracy_from_lime_explanations(exp2_survlime_dataset1_rsf, "lime", "1", "rsf")

pd.concat([local_accuracy_shap_rsf_cluster_1, local_accuracy_lime_rsf_cluster_1]).to_csv("results/exp2_local_accuracy_rsf_dataset1.csv")

#### Importance rankings

In [None]:
def get_orderings_and_ranks_lime(explanations):
    """Rank features by the magnitude of their SurvLIME impact, per explanation.

    The impact of a feature is ``variable_value * coefficient`` taken from the
    explanation's ``result`` table (one row per feature).

    Parameters
    ----------
    explanations : iterable
        SurvLIME explanation objects exposing a ``result`` DataFrame with
        ``variable_value`` and ``coefficient`` columns. The objects are not
        modified.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        ``orderings``: row k lists the ``result`` index labels of explanation k
        from largest to smallest absolute impact.
        ``ranks``: row k holds each feature's rank (1 = largest absolute
        impact, ties averaged), in ``result`` row order.
    """
    importance_orderings = []
    importance_ranks = []
    for explanation in explanations:
        # Copy so the caller's (e.g. unpickled) explanation objects are not
        # mutated; get_orderings_and_ranks_shap does the same.
        df = explanation.result.copy()
        df["impact"] = df["variable_value"] * df["coefficient"]
        importance_orderings.append(df.sort_values(by="impact", key=lambda x: -abs(x)).index.to_list())
        importance_ranks.append(np.abs(df["impact"]).rank(ascending=False).to_list())
    return pd.DataFrame(importance_orderings), pd.DataFrame(importance_ranks)

from scipy.integrate import trapezoid


def get_orderings_and_ranks_shap(explanations):
    """Rank features by time-integrated absolute SurvSHAP(t) value, per explanation.

    For each explanation, the columns from position 5 onward of ``result`` are
    treated as the SurvSHAP(t) values at ``timestamps``. Their absolute value is
    integrated over time with the trapezoidal rule, giving one aggregated score
    per feature (row).

    Parameters
    ----------
    explanations : iterable
        SurvSHAP(t) explanation objects exposing ``result`` and ``timestamps``.
        The objects are not modified.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        ``orderings``: row k lists the ``result`` index labels of explanation k
        from largest to smallest aggregated score.
        ``ranks``: row k holds each feature's rank (1 = largest score, ties
        averaged), in ``result`` row order.
    """
    ordering_rows = []
    rank_rows = []
    for expl in explanations:
        table = expl.result.copy()
        shap_values = np.abs(table.iloc[:, 5:].values)
        table["aggregated_change"] = trapezoid(shap_values, expl.timestamps)

        by_importance = table.sort_values(by="aggregated_change", key=lambda col: -abs(col))
        ordering_rows.append(by_importance.index.to_list())

        abs_change = np.abs(table.aggregated_change)
        rank_rows.append(abs_change.rank(ascending=False).to_list())
    return pd.DataFrame(ordering_rows), pd.DataFrame(rank_rows)

from scipy.stats import weightedtau


def mean_weighted_tau(ranks1, ranks2, n_rows=None):
    """Average weighted Kendall's tau between two tables of importance ranks.

    Parameters
    ----------
    ranks1, ranks2 : pandas.DataFrame
        One row per observation and one column per feature. Values are
        importance ranks where 1 denotes the most important feature, as
        produced by the ``get_orderings_and_ranks_*`` helpers. Rows are paired
        by position.
    n_rows : int, optional
        Number of leading rows to compare. Defaults to every row present in
        both tables (this used to be hard-coded to 100). Pass ``n_rows=100``
        to reproduce the old row count.

    Returns
    -------
    float
        Mean of the per-row weighted tau statistics. NaN (with a numpy
        RuntimeWarning) if no rows are compared.
    """
    if n_rows is None:
        n_rows = min(len(ranks1), len(ranks2))
    taus = []
    for i in range(n_rows):
        # scipy's weightedtau gives the largest weight to the *largest* scores,
        # whereas rank 1 is the most important feature here. Negate the ranks so
        # the top-ranked features dominate the weighting. Unweighted concordance
        # is unchanged because both inputs are negated.
        scores1 = -np.asarray(ranks1.iloc[i], dtype=float)
        scores2 = -np.asarray(ranks2.iloc[i], dtype=float)
        tau, _ = weightedtau(scores1, scores2)
        taus.append(tau)
    return np.mean(taus)

def prepare_ranking_summary_long(ordering):
    """Count how often each feature occupies each importance position, in long format.

    Parameters
    ----------
    ordering : pandas.DataFrame
        One row per observation; column ``p`` (0-based) holds the label of the
        feature placed at importance position ``p`` (0 = most important), as
        returned by the ``get_orderings_and_ranks_*`` helpers. Feature labels
        are assumed to be the integers ``0 .. n_features - 1``, where
        ``n_features`` is the number of columns (5 in this experiment).

    Returns
    -------
    pandas.DataFrame
        Columns ``importance_ranking`` (1-based position), ``variable``
        (``"x1"``, ``"x2"``, ...) and ``value`` (number of observations with
        that feature at that position; NaN when it never occurs there).
    """
    n_positions = ordering.shape[1]
    labels = list(range(n_positions))
    var_names = [f"x{label + 1}" for label in labels]
    # Build the wide position-by-feature count table in a single step. Repeatedly
    # concatenating onto an empty object-dtype frame is deprecated in
    # pandas >= 2.1, and its result dtype depends on the pandas version.
    counts = pd.DataFrame.from_dict(
        {pos + 1: ordering[pos].value_counts().to_dict() for pos in range(n_positions)},
        orient="index",
    ).reindex(columns=labels)
    counts.columns = var_names
    counts = counts.rename_axis("importance_ranking").reset_index()
    return counts.melt(id_vars=["importance_ranking"], value_vars=var_names)

##### dataset0
- $\beta^T = [10^{-6}, 0.1, -0.15, 10^{-6}, 10^{-6}]$
- importance ranking by feature index, from least to most important: [0/3/4, 1, 2]

In [None]:
# Display the fitted CPH coefficients for dataset0 (compare with the beta^T listed above).
cph_dataset0.coef_

###### CPH

In [None]:
# SurvLIME importance orderings and ranks for the CPH model on dataset0.
dataset0_cph_survlime_orderings, dataset0_cph_survlime_ranks = get_orderings_and_ranks_lime(exp2_survlime_dataset0_cph)

In [None]:
# How often each feature index lands at the last (column 4), second (column 1)
# and first (column 0) importance position of the SurvLIME orderings
# (dataset0, CPH). Expected indices are shown in parentheses.
print("The least important (0/3/4)")
print(dataset0_cph_survlime_orderings[4].value_counts())

print("The second most important (1)")
print(dataset0_cph_survlime_orderings[1].value_counts())

print("The most important (2)")
print(dataset0_cph_survlime_orderings[0].value_counts())

# Save the full position-by-feature counts in long format.
prepare_ranking_summary_long(dataset0_cph_survlime_orderings).to_csv("results/exp2_survlime_orderings_cph_dataset0.csv", index=False)

In [None]:
# SurvSHAP(t) importance orderings and ranks for the CPH model on dataset0.
dataset0_cph_survshap_orderings, dataset0_cph_survshap_ranks = get_orderings_and_ranks_shap(exp2_survshap_dataset0_cph)

In [None]:
# How often each feature index lands at the last (column 4), second (column 1)
# and first (column 0) importance position of the SurvSHAP(t) orderings
# (dataset0, CPH). Expected indices are shown in parentheses.
print("The least important (0/3/4)")
print(dataset0_cph_survshap_orderings[4].value_counts())

print("The second most important (1)")
print(dataset0_cph_survshap_orderings[1].value_counts())

print("The most important (2)")
print(dataset0_cph_survshap_orderings[0].value_counts())

# Save the full position-by-feature counts in long format.
prepare_ranking_summary_long(dataset0_cph_survshap_orderings).to_csv("results/exp2_survshap_orderings_cph_dataset0.csv", index=False)

In [None]:
# GT CPH: reference importance ranks derived from the fitted CPH model on
# dataset0. For each test observation the impact of feature j is x_j * coef_j
# (the same impact definition get_orderings_and_ranks_lime uses for SurvLIME).
# Features are ranked by decreasing |impact| (1 = most important, ties averaged).
# All rows are computed at once instead of looping with iterrows(); converting
# through .to_numpy() keeps the original layout (RangeIndex rows, columns 0..4).
dataset0_cph_true_ranks = pd.DataFrame(
    np.abs(X_test0 * cph_dataset0.coef_).rank(axis=1, ascending=False).to_numpy()
)

In [None]:
# Mean weighted Kendall's tau: SurvLIME vs SurvSHAP(t) ranks (dataset0, CPH).
mean_weighted_tau(dataset0_cph_survlime_ranks, dataset0_cph_survshap_ranks)

In [None]:
# Mean weighted Kendall's tau: SurvLIME vs CPH-coefficient (GT) ranks (dataset0).
mean_weighted_tau(dataset0_cph_survlime_ranks, dataset0_cph_true_ranks)

In [None]:
# Mean weighted Kendall's tau: SurvSHAP(t) vs CPH-coefficient (GT) ranks (dataset0).
mean_weighted_tau(dataset0_cph_survshap_ranks, dataset0_cph_true_ranks)

###### RSF

In [None]:
# SurvLIME importance orderings and ranks for the RSF model on dataset0.
dataset0_rsf_survlime_orderings, dataset0_rsf_survlime_ranks = get_orderings_and_ranks_lime(exp2_survlime_dataset0_rsf)

In [None]:
# How often each feature index lands at the last (column 4), second (column 1)
# and first (column 0) importance position of the SurvLIME orderings
# (dataset0, RSF). Expected indices are shown in parentheses.
print("The least important (0/3/4)")
print(dataset0_rsf_survlime_orderings[4].value_counts())

print("The second most important (1)")
print(dataset0_rsf_survlime_orderings[1].value_counts())

print("The most important (2)")
print(dataset0_rsf_survlime_orderings[0].value_counts())

# Save the full position-by-feature counts in long format.
prepare_ranking_summary_long(dataset0_rsf_survlime_orderings).to_csv("results/exp2_survlime_orderings_rsf_dataset0.csv", index=False)

In [None]:
# SurvSHAP(t) importance orderings and ranks for the RSF model on dataset0.
dataset0_rsf_survshap_orderings, dataset0_rsf_survshap_ranks = get_orderings_and_ranks_shap(exp2_survshap_dataset0_rsf)

In [None]:
# How often each feature index lands at the last (column 4), second (column 1)
# and first (column 0) importance position of the SurvSHAP(t) orderings
# (dataset0, RSF), then save the full counts in long format.
print("The least important (0/3/4)")
print(dataset0_rsf_survshap_orderings[4].value_counts())

print("The second most important (1)")
print(dataset0_rsf_survshap_orderings[1].value_counts())

print("The most important (2)")
print(dataset0_rsf_survshap_orderings[0].value_counts())
prepare_ranking_summary_long(dataset0_rsf_survshap_orderings).to_csv("results/exp2_survshap_orderings_rsf_dataset0.csv", index=False)

##### dataset1
- $\beta^T = [10^{-6}, -0.15, 10^{-6}, 10^{-6}, -0.1]$
- importance ranking by feature index, from least to most important: [0/2/3, 4, 1]

In [None]:
# Display the fitted CPH coefficients for dataset1 (compare with the beta^T listed above).
cph_dataset1.coef_

###### CPH

In [None]:
# SurvLIME importance orderings and ranks for the CPH model on dataset1.
dataset1_cph_survlime_orderings, dataset1_cph_survlime_ranks = get_orderings_and_ranks_lime(exp2_survlime_dataset1_cph)

In [None]:
# How often each feature index lands at the last (column 4), second (column 1)
# and first (column 0) importance position of the SurvLIME orderings
# (dataset1, CPH). Expected indices are shown in parentheses.
print("The least important (0/2/3)")
print(dataset1_cph_survlime_orderings[4].value_counts())

print("The second most important (4)")
print(dataset1_cph_survlime_orderings[1].value_counts())

print("The most important (1)")
print(dataset1_cph_survlime_orderings[0].value_counts())

# Save the full position-by-feature counts in long format.
prepare_ranking_summary_long(dataset1_cph_survlime_orderings).to_csv("results/exp2_survlime_orderings_cph_dataset1.csv", index=False)

In [None]:
# SurvSHAP(t) importance orderings and ranks for the CPH model on dataset1.
dataset1_cph_survshap_orderings, dataset1_cph_survshap_ranks = get_orderings_and_ranks_shap(exp2_survshap_dataset1_cph)

In [None]:
# How often each feature index lands at the last (column 4), second (column 1)
# and first (column 0) importance position of the SurvSHAP(t) orderings
# (dataset1, CPH). Expected indices are shown in parentheses.
print("The least important (0/2/3)")
print(dataset1_cph_survshap_orderings[4].value_counts())

print("The second most important (4)")
print(dataset1_cph_survshap_orderings[1].value_counts())

print("The most important (1)")
print(dataset1_cph_survshap_orderings[0].value_counts())

# Save the full position-by-feature counts in long format.
prepare_ranking_summary_long(dataset1_cph_survshap_orderings).to_csv("results/exp2_survshap_orderings_cph_dataset1.csv", index=False)

In [None]:
# GT CPH: reference importance ranks derived from the fitted CPH model on
# dataset1. For each test observation the impact of feature j is x_j * coef_j
# (the same impact definition get_orderings_and_ranks_lime uses for SurvLIME).
# Features are ranked by decreasing |impact| (1 = most important, ties averaged).
# All rows are computed at once instead of looping with iterrows(); converting
# through .to_numpy() keeps the original layout (RangeIndex rows, columns 0..4).
dataset1_cph_true_ranks = pd.DataFrame(
    np.abs(X_test1 * cph_dataset1.coef_).rank(axis=1, ascending=False).to_numpy()
)

In [None]:
# Mean weighted Kendall's tau: SurvLIME vs SurvSHAP(t) ranks (dataset1, CPH).
mean_weighted_tau(dataset1_cph_survlime_ranks, dataset1_cph_survshap_ranks)

In [None]:
# Mean weighted Kendall's tau: SurvLIME vs CPH-coefficient (GT) ranks (dataset1).
mean_weighted_tau(dataset1_cph_survlime_ranks, dataset1_cph_true_ranks)

In [None]:
# Mean weighted Kendall's tau: SurvSHAP(t) vs CPH-coefficient (GT) ranks (dataset1).
mean_weighted_tau(dataset1_cph_survshap_ranks, dataset1_cph_true_ranks)

###### RSF

In [None]:
# SurvLIME importance orderings and ranks for the RSF model on dataset1.
dataset1_rsf_survlime_orderings, dataset1_rsf_survlime_ranks = get_orderings_and_ranks_lime(exp2_survlime_dataset1_rsf)

In [None]:
# How often each feature index lands at the last (column 4), second (column 1)
# and first (column 0) importance position of the SurvLIME orderings
# (dataset1, RSF). Expected indices are shown in parentheses.
print("The least important (0/2/3)")
print(dataset1_rsf_survlime_orderings[4].value_counts())

print("The second most important (4)")
print(dataset1_rsf_survlime_orderings[1].value_counts())

print("The most important (1)")
print(dataset1_rsf_survlime_orderings[0].value_counts())

# Save the full position-by-feature counts in long format.
prepare_ranking_summary_long(dataset1_rsf_survlime_orderings).to_csv("results/exp2_survlime_orderings_rsf_dataset1.csv", index=False)

In [None]:
# SurvSHAP(t) importance orderings and ranks for the RSF model on dataset1.
dataset1_rsf_survshap_orderings, dataset1_rsf_survshap_ranks = get_orderings_and_ranks_shap(exp2_survshap_dataset1_rsf)

In [None]:
# How often each feature index lands at the last (column 4), second (column 1)
# and first (column 0) importance position of the SurvSHAP(t) orderings
# (dataset1, RSF). Expected indices are shown in parentheses.
print("The least important (0/2/3)")
print(dataset1_rsf_survshap_orderings[4].value_counts())

print("The second most important (4)")
print(dataset1_rsf_survshap_orderings[1].value_counts())

print("The most important (1)")
print(dataset1_rsf_survshap_orderings[0].value_counts())

# Save the full position-by-feature counts in long format.
prepare_ranking_summary_long(dataset1_rsf_survshap_orderings).to_csv("results/exp2_survshap_orderings_rsf_dataset1.csv", index=False)