# Code

**Team:** Kamil Grudzień, Krystian Sztenderski, Jakub Bednarz.

**Project:** (V) Time-dependent explanations of neural networks for survival analysis.

## Checkpoint I

**What was done?** We've:

1. Learned how to use the `pycox` package, in particular the DeepHit survival analysis model.
2. Learned how to use the `sksurv` package.
3. Wrapped various models (Cox proportional hazards, Random Survival Forest and DeepHit) into a uniform interface to make training and evaluation easier.
4. Read the [SurvSHAP(t) paper](https://arxiv.org/abs/2208.11080) and the [implementation provided by the authors](https://github.com/MI2DataLab/survshap).
5. Learned how to use various NN-specific explainability techniques from the `captum` library; in particular, we've adapted the `DeepLift`, `DeepLiftShap` and `IntegratedGradients` methods to provide explanations for DeepHit analogous to those given by SurvSHAP(t).
6. Wrapped all of them into a single interface to compare them on equal footing.
7. Replicated an experiment from the SurvSHAP(t) paper to verify that we're using the library correctly. Beyond that, we've also trained and evaluated DeepHit on the same dataset, obtained ground-truth explanations with SurvSHAP(t), and ran DeepLift, DeepLiftShap and Integrated Gradients to see how the NN-specific explanations compare with SurvSHAP(t).
8. Performed a preliminary experiment on a real-world dataset (METABRIC) in a similar fashion to the one described in (7).
9. Made a "coarse-grained" analysis of the results of (7) and (8), i.e. plotted the SHAP values at given time points and evaluated them qualitatively.
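
Items 3 and 6 amount to giving every model the same sksurv-like surface. A minimal sketch of that shared model interface, written as a `typing.Protocol`, is below; the names `SurvivalModel` and `ConstantModel` are hypothetical and only mirror the method and attribute names our wrappers use (`fit`, `predict_survival_function`, `score`, `event_times_`).

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class SurvivalModel(Protocol):
    """Hypothetical protocol for the uniform, sksurv-like model interface."""

    event_times_: list  # time grid on which survival curves are evaluated

    def fit(self, X, y): ...

    def predict_survival_function(self, X, return_array: bool = False): ...

    def score(self, X, y) -> float: ...


class ConstantModel:
    """Toy model (hypothetical): the same survival curve for every observation."""

    def __init__(self):
        self.event_times_ = [1.0, 2.0, 3.0]

    def fit(self, X, y):
        return self

    def predict_survival_function(self, X, return_array: bool = False):
        # One row per observation, one column per time point.
        return [[0.9, 0.6, 0.3] for _ in X]

    def score(self, X, y) -> float:
        return 0.5  # a constant risk gives chance-level concordance


model = ConstantModel().fit(X=[[0.0], [1.0]], y=None)
# runtime_checkable only checks that the names exist, not their signatures.
print(isinstance(model, SurvivalModel))
```

Because the check is purely structural, any wrapper (CoxPH, RSF, DeepHit) that exposes these names can be passed to the same explainer code.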

**What are the difficulties?**

1. The NN-specific explanations do not *seem* to correlate with the ground truth at all, so further analysis is needed.
2. Although we evaluate the models quantitatively (via the concordance index), we still don't know for sure whether the models we've trained on these datasets give "reasonable" results. If a model does not perform well, its explanations are meaningless, so it would be wise to rule out this source of uncertainty.
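
To make the concordance index from point 2 concrete, here is a minimal pure-Python sketch of Harrell's C-index: a pair is comparable when the shorter duration ended with an observed event, and tied risk scores count as half-concordant. It is an illustration only. Our DeepHit wrapper scores with Antolini's time-dependent concordance from `pycox` (`concordance_td(method="antolini")`), while, to our understanding, sksurv's `score` is based on Harrell's index, so the scores reported below are not strictly comparable across model types.

```python
def harrell_c_index(durations, events, risk_scores):
    """Harrell's concordance index (sketch).

    A pair (i, j) is comparable when subject i had an observed event and
    durations[i] < durations[j]. It is concordant when subject i also has the
    higher risk score; tied risk scores count as 0.5. Tied durations are
    simply skipped in this simplified version.
    """
    concordant = 0.0
    comparable = 0
    for i in range(len(durations)):
        if not events[i]:
            continue  # a censored subject cannot anchor a comparable pair
        for j in range(len(durations)):
            if durations[i] < durations[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


durations = [1.0, 2.0, 3.0, 4.0]
events = [1, 1, 0, 1]
# Higher risk for earlier events: every comparable pair is concordant.
print(harrell_c_index(durations, events, [4.0, 3.0, 2.0, 1.0]))  # 1.0
```

A constant predictor scores exactly 0.5, which is a useful floor when judging whether a trained model is "reasonable".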

**What will be done next?**

1. Adding quantitative metrics for comparing the explanations given by SurvSHAP(t) and other methods.
2. Performing deeper analysis of the trained models and the explanations.
3. (Possibly) Testing NN-based survival models other than DeepHit.
4. Adding measurement of execution time.
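
For point 1, one candidate metric (our assumption, nothing is decided yet) is to compare feature rankings at each time point: take the absolute attributions from two explainers on the same time grid, rank the features, and compute the Pearson correlation of the ranks (Spearman's rho when there are no ties). The name `rank_agreement_over_time` is hypothetical; inputs are arrays of shape `(n_features, n_times)`, e.g. the `y` values of the `StepFunction`s stored in an `Explanation`.

```python
import numpy as np


def _ranks(values):
    # Rank 0 = smallest value; ties are broken by position, not averaged.
    return np.argsort(np.argsort(values, kind="stable"), kind="stable").astype(float)


def rank_agreement_over_time(attr_a, attr_b):
    """Per-time-point rank correlation of |attributions| from two explainers.

    attr_a, attr_b: arrays of shape (n_features, n_times) on the same time grid.
    Returns an array of shape (n_times,) with values in [-1, 1].
    """
    a = np.abs(np.asarray(attr_a, dtype=float))
    b = np.abs(np.asarray(attr_b, dtype=float))
    if a.shape != b.shape or a.ndim != 2 or a.shape[0] < 2:
        raise ValueError("expected two arrays of equal shape (n_features >= 2, n_times)")
    out = np.empty(a.shape[1])
    for t in range(a.shape[1]):
        ra = _ranks(a[:, t])
        rb = _ranks(b[:, t])
        ra -= ra.mean()
        rb -= rb.mean()
        out[t] = ra @ rb / np.sqrt((ra @ ra) * (rb @ rb))
    return out


# 3 features, 2 time points.
shap_like = np.array([[0.3, 0.2], [0.1, 0.1], [-0.5, -0.4]])
other = np.array([[0.1, 0.05], [0.2, 0.3], [0.6, 0.2]])
print(rank_agreement_over_time(shap_like, other))  # [ 0.5 -0.5]
```

Averaging the returned array over time gives a single agreement score per observation. Taking absolute values deliberately ignores the sign of a contribution; comparing signed values instead would be a stricter variant.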

In [1]:
import numpy as np
import pandas as pd
from ruamel.yaml import safe_load
import scipy.integrate
import scipy.optimize
from scipy.interpolate import interp1d

from warnings import catch_warnings, simplefilter
from dataclasses import dataclass
from typing import Callable, Optional, Any, Union
from tqdm import tqdm

import plotly.express as px
import plotly.graph_objects as go
import plotly.offline as py
import plotly.io as pio
# pio.renderers.default = "jpeg"

import torch
import torch.nn as nn
import torch.optim as optim
import torchtuples as tt
from torchtuples.practical import MLPVanilla

import sksurv
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.ensemble import RandomSurvivalForest
from sksurv.datasets import get_x_y
from sksurv.functions import StepFunction
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper 
from sklearn.model_selection import train_test_split

from pycox.datasets import metabric
from pycox.evaluation import EvalSurv
import pycox.models as pycox_models

from survshap import SurvivalModelExplainer, ModelSurvSHAP, PredictSurvSHAP
from survshap.model_explanations.utils import aggregate_change

from captum.attr import DeepLift, IntegratedGradients, DeepLiftShap

import quantus
import quantus.functions.perturb_func as perturb_func

## DeepHit adapter

In [2]:
@dataclass
class TrainConf:
    optimizer_fn: Callable[[nn.Module], optim.Optimizer] \
        = lambda net: optim.Adam(net.parameters(), lr=1e-3)
    device: Optional[torch.device] = None
    batch_size: int = 256
    epochs: int = 1
    callbacks: Any = None
    verbose: bool = False
    num_workers: int = 0
    shuffle: bool = False
    metrics: Any = None
    val_data: Optional[tuple] = None  # (X, y) pair used for validation
    val_batch_size: int = 8224


class DeepHitSingle:
    """A sksurv-like wrapper around pycox's DeepHitSingle."""

    def __init__(self, net: nn.Module, timestamps=None, alpha=0.2, sigma=0.1):
        """Create an instance.
        :param net: NN at the core of DeepHit. The number of output features
         is the number of cuts/timestamps.
        :param timestamps: (Optional) Predefined timestamps/steps to use. 
        Number of them must be equal to the dimensionality of net's output 
        space. If not provided, default cuts are used (see pycox docs for more 
        details.)
        :param alpha, sigma: Parameters for the pycox.models.DeepHitSingle 
        class."""

        self.net = net
        self.timestamps = timestamps
        self.alpha = alpha
        self.sigma = sigma
    
    def encode(self, X, y, fit=False):
        """Convert (X, y) into the (input, (durations, events)) tensors pycox
        expects. With fit=True, the label transform is fitted on y first."""
        inputs = torch.tensor(np.asarray(X), dtype=torch.float32)
        if fit:
            self.tf.fit(y["duration"], y["event"])
        durations, events = self.tf.transform(y["duration"], y["event"])
        durations = torch.tensor(durations, dtype=torch.int64)
        events = torch.tensor(events, dtype=torch.float32)
        target = (durations, events)
        return inputs, target
    
    def fit(self, X, y, conf: Optional[TrainConf] = None):
        if conf is None:
            conf = TrainConf()

        # Get the timestamps/cuts
        if self.timestamps is not None:
            cuts = self.timestamps
        else:
            # If the cuts are unspecified, we infer their number by dry-running
            # the net and checking the dimensionality of the output space.
            with torch.no_grad():
                self.net.eval()
                num_features = X.shape[1]
                dummy = torch.empty((1, num_features), dtype=torch.float32)
                res = self.net(dummy)[0]
                self.net.train()
            cuts = res.shape[0]
        
        self.tf = pycox_models.DeepHitSingle.label_transform(cuts)

        # Encode X and y to a form acceptable for DeepHitSingle.
        
        input, target = self.encode(X, y, fit=True)
            
        self._model = pycox_models.DeepHitSingle(
            net=self.net,
            optimizer=conf.optimizer_fn(self.net),
            device=conf.device,
            duration_index=self.tf.cuts,
            alpha=self.alpha,
            sigma=self.sigma,
        )

        if conf.val_data is not None:
            val_X, val_y = conf.val_data
            val_data = self.encode(val_X, val_y)
        else:
            val_data = None

        self._log = self._model.fit(input, target, conf.batch_size, conf.epochs,
            conf.callbacks, conf.verbose, conf.num_workers, conf.shuffle,
            conf.metrics, val_data, conf.val_batch_size)
        
        self.event_times_ = self.tf.cuts
        
        return self
    
    def predict_surv_df(self, X):
        X = torch.tensor(np.asarray(X), dtype=torch.float32)
        return self._model.predict_surv_df(X).astype("float32")
    
    def predict_survival_function(self, X, return_array=False):
        """Predict survival function. See sksurv models for more details."""

        surv_df = self.predict_surv_df(X)

        event_times_ = surv_df.index.values
        sf_values = surv_df.T.values

        if return_array:
            return sf_values
        else:
            return np.array([
                StepFunction(event_times_, values)
                for values in sf_values
            ])

    def predict_cumulative_hazard_function(self, X, return_array=False):
        raise NotImplementedError
    
    def score(self, X, y):
        surv = self.predict_surv_df(X)

        with catch_warnings():
            simplefilter("ignore")
            ev = EvalSurv(
                surv=surv,
                durations=y["duration"],
                events=y["event"],
                censor_surv="km",
            )
            # Antolini's time-dependent concordance index.
            return ev.concordance_td(method="antolini")
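
The wrapper hands time discretization to `pycox_models.DeepHitSingle.label_transform(cuts)` and, when no timestamps are given, infers the number of cuts from the net's output size (91 in the experiments below). The idea behind that transform can be sketched as follows, assuming an equidistant grid from 0 to the largest observed duration and a simplified "first cut at or after the duration" rule. Both helper names are hypothetical, and pycox's own transform may place grid points and handle censored observations differently.

```python
import numpy as np


def equidistant_cuts(durations, num_cuts):
    """Evenly spaced time grid from 0 to the largest duration (assumed scheme)."""
    return np.linspace(0.0, float(np.max(durations)), num_cuts)


def discretize(durations, cuts):
    """Index of the first cut >= each duration (simplified; ignores censoring)."""
    idx = np.searchsorted(cuts, durations, side="left")
    return np.clip(idx, 0, len(cuts) - 1)


cuts = equidistant_cuts([1.0, 2.0, 4.0], num_cuts=5)
print(cuts)                                    # [0. 1. 2. 3. 4.]
print(discretize([0.5, 1.0, 3.2, 4.0], cuts))  # [1 1 4 4]
```

Each discretized index corresponds to one output unit of the net, which is why the number of cuts must match the net's output dimensionality.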

## Attribution methods

In [3]:
class Explanation:
    def __init__(self, var_attrs: pd.DataFrame, experiment_name: str):
        self.var_attrs = var_attrs
        self.experiment_name = experiment_name
    
    def __len__(self):
        return len(self.var_attrs)

    def __getitem__(self, idx):
        slice_ = self.var_attrs.iloc[idx]
        if isinstance(slice_, pd.Series):
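            # A single row comes back as a Series; turn it into a one-row frame
            # with variables as columns (without .T they would become rows).
            slice_ = slice_.to_frame().T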
        return Explanation(slice_, self.experiment_name)
    
    def plot(self):
        if len(self) != 1:
            raise RuntimeError("Plotting (at the moment) works only for single examples.")
        attrs = self.var_attrs.iloc[0]
        
        fig = go.Figure()
        for var in sorted(attrs.index):
            var_attr = attrs[var]
            fig.add_trace(go.Scatter(
                x=var_attr.x, y=var_attr.y,
                mode="lines",
                line=dict(shape="hv"),
                name=var,
            )) 
        
        fig.update_layout(
            xaxis_title="Time",
            yaxis_title="Contribution",
            title=self.experiment_name,
        )
        
        return fig
    
    def __str__(self):
        return str(self.var_attrs)

### SurvSHAP(t) Adapter

In [4]:
class SurvShapExplainer:
    """A bit more shap-esque Explainer wrapper for SurvSHAP."""

    def __init__(self, model, data=None, y=None, calculation_method="kernel", aggregation_method="integral", path="average", B=25, random_state=42, pbar=False):
        self.model = model
        self.calculation_method = calculation_method
        self.aggregation_method = aggregation_method
        self.path = path
        self.B = B
        self.random_state = random_state
        self.exp = SurvivalModelExplainer(self.model, data=data, y=y)
        self.pbar = pbar
        self.experiment_name = type(self).__name__ + " " + self.model.__class__.__name__

    def __call__(self, observations: pd.DataFrame, timestamps=None):
        if timestamps is None:
            timestamps = self.model.event_times_

        skip = ["variable_str", "variable_name", "variable_value", "B", "aggregated_change", "index"]

        all_results = []
        
        idx_seq = range(len(observations))
        if self.pbar:
            idx_seq = tqdm(idx_seq)
        
        for idx in idx_seq:
            obs = observations.iloc[[idx]]
            shap = PredictSurvSHAP(
                calculation_method=self.calculation_method,
                aggregation_method=self.aggregation_method,
                path=self.path,
                B=self.B,
                random_state=self.random_state,
            )

            shap.fit(self.exp, obs, timestamps)
            obs_df = shap.result
            obs_df.insert(len(skip)-1, "index", idx)
            all_results.append(obs_df)
        
        res_df = pd.concat(all_results)
        self.result = res_df
            
        g = res_df.groupby(by="variable_name")

        var_attr_values = {}
        for var in g.groups:
            grp: pd.DataFrame = g.get_group(var)
            grp = grp.sort_values(by=["index"])
            attr_values = grp.iloc[:,len(skip):].values
            var_attr_values[var] = [
                StepFunction(timestamps, attr_values_)
                for attr_values_ in attr_values
            ]
            
        res_df = pd.DataFrame(var_attr_values)
        res_df = res_df.set_index(observations.index, drop=True)
        return Explanation(res_df, self.experiment_name)
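
`SurvShapExplainer` returns one `StepFunction` per variable, and its `aggregation_method="integral"` argument is passed straight through to SurvSHAP(t) to rank variables by overall importance. As a rough illustration of what such an aggregate can look like, the hypothetical helper below averages |φ(t)| over the time span, assuming each value is held until the next time point (the "hv" lines drawn by `Explanation.plot`). SurvSHAP(t)'s own integral aggregation may weight or normalize differently.

```python
import numpy as np


def aggregate_importance(times, attributions):
    """Time-averaged |phi(t)| for step-shaped attributions (sketch).

    times: shape (n_times,), strictly increasing.
    attributions: shape (n_times,) or (n_features, n_times).
    Returns a scalar or an array of shape (n_features,).
    """
    t = np.asarray(times, dtype=float)
    phi = np.abs(np.asarray(attributions, dtype=float))
    # Each value is held until the next time point, so a left-endpoint sum
    # integrates the step function exactly over [t[0], t[-1]].
    area = np.sum(phi[..., :-1] * np.diff(t), axis=-1)
    return area / (t[-1] - t[0])


times = [0.0, 1.0, 2.0]
attrs = [[1.0, 1.0, 1.0],     # constant contribution
         [-2.0, -4.0, 0.0]]   # strong negative contribution early on
print(aggregate_importance(times, attrs))  # [1. 3.]
```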

### DeepLift Adapter

In [5]:
from contextlib import contextmanager


@contextmanager
def eval_ctx(net: nn.Module):
    prev_val = net.training
    net.train(mode=False)
    try:
        yield net
    finally:
        net.train(mode=prev_val)


class DeepHit_SurvOut(nn.Module):
    def __init__(self, deephit: DeepHitSingle):
        super().__init__()
        self.deephit = deephit
    
    def forward(self, x):
        with eval_ctx(self.deephit.net):
            preds = self.deephit.net(x)
            preds = preds.view(len(preds), -1)
            zero_pad = preds.new_zeros((len(preds), 1))  # same dtype/device as preds
            pmf = torch.cat([preds, zero_pad], dim=1).softmax(dim=1)[:, :-1]
            sf = 1.0 - pmf.cumsum(dim=1)
            return sf


class DeepLiftExplainer:
    """A shap-esque wrapper for DeepLift. See SurvShapExplainer for more
     details."""
     
    def __init__(self, model: DeepHitSingle, data=None, y=None, baseline="mean"):
        self.model = model
        self.baselines = None
        if data is not None and baseline is not None:
            data = torch.tensor(data.values, dtype=torch.float32)
            if baseline == 'zero':
                self.baselines = torch.zeros(size=data.shape[1:])
            elif baseline == 'mean':
                self.baselines = data.mean(dim=0)
            elif baseline == 'median':
                self.baselines = data.median(dim=0).values
            else:
                raise ValueError(f"Unknown baseline method: {baseline!r}")

        self.experiment_name = type(self).__name__ + " " + self.model.__class__.__name__
    
    def __call__(self, X: pd.DataFrame):
        inputs = torch.tensor(X.values, dtype=torch.float32)
        if self.baselines is not None:
            baselines = self.baselines.broadcast_to(inputs.shape)
        else:
            baselines = None

        deep_lift = DeepLift(DeepHit_SurvOut(self.model))

        with catch_warnings():
            simplefilter("ignore")
            attr_values = []
            for idx in range(len(self.model.event_times_)):
                attrs = deep_lift.attribute(inputs, baselines, target=idx)
                attr_values.append(attrs)
            attr_values = torch.stack(attr_values, dim=2).detach().numpy()

        var_attr_values = {}
        for var_idx, var in enumerate(X.columns):
            var_attr_values[var] = [
                StepFunction(
                    x=self.model.event_times_,
                    y=attr_values[obs_idx,var_idx],
                ) for obs_idx in range(len(X))
            ]
        
        return Explanation(pd.DataFrame(var_attr_values), self.experiment_name)
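
`DeepHit_SurvOut` turns the raw DeepHit outputs into survival probabilities, so captum attributes changes in S(t) rather than in logits. The same transformation in plain NumPy (the function name is ours), handy for checking the wrapper's outputs by hand: a zero logit is appended as an extra "beyond the last cut" class, a softmax gives the probability mass function over the cuts, and S(t_k) = 1 - sum of pmf_i for i <= k.

```python
import numpy as np


def logits_to_survival(logits):
    """DeepHit-style logits (n_obs, n_cuts) -> survival curves (n_obs, n_cuts)."""
    logits = np.atleast_2d(np.asarray(logits, dtype=float))
    # Extra zero logit: probability mass for surviving past the last cut.
    padded = np.concatenate([logits, np.zeros((logits.shape[0], 1))], axis=1)
    padded -= padded.max(axis=1, keepdims=True)  # stabilize the softmax
    probs = np.exp(padded)
    probs /= probs.sum(axis=1, keepdims=True)
    pmf = probs[:, :-1]  # drop the "beyond the grid" class
    return 1.0 - np.cumsum(pmf, axis=1)


# Equal logits: each of the 4 classes gets probability 0.25.
print(logits_to_survival([[0.0, 0.0, 0.0]]))
```

Because the padded class always keeps some probability mass, S(t) at the last cut stays strictly above zero, which is what the `[:, :-1]` slice in `DeepHit_SurvOut.forward` achieves.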

### DeepLiftShap Adapter

In [6]:
class DeepLiftShapExplainer:
    """A shap-esque wrapper for DeepLiftShap. See SurvShapExplainer for more
     details."""
     
    def __init__(self, model: DeepHitSingle, data=None, y=None):
        self.model = model
        if data is not None:
            self.baselines = torch.tensor(data.values, dtype=torch.float32)
        else:
            self.baselines = None
        self.experiment_name = type(self).__name__ + " " + self.model.__class__.__name__
    
    def __call__(self, X: pd.DataFrame):
        inputs = torch.tensor(X.values, dtype=torch.float32)
        if self.baselines is not None:
            baselines = self.baselines
        else:
            baselines = inputs

        deep_lift_shap = DeepLiftShap(DeepHit_SurvOut(self.model))

        with catch_warnings():
            simplefilter("ignore")
            attr_values = []
            for idx in range(len(self.model.event_times_)):
                attrs = deep_lift_shap.attribute(
                    inputs, baselines, target=idx)
                attr_values.append(attrs)
            attr_values = torch.stack(attr_values, dim=2).detach().numpy()

        var_attr_values = {}
        for var_idx, var in enumerate(X.columns):
            var_attr_values[var] = [
                StepFunction(
                    x=self.model.event_times_,
                    y=attr_values[obs_idx,var_idx],
                ) for obs_idx in range(len(X))
            ]
        
        return Explanation(pd.DataFrame(var_attr_values), self.experiment_name)
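
DeepLiftShap, as documented by captum, runs DeepLift against every baseline in the provided set and averages the resulting attributions. For a purely linear model f(x) = w·x + c, DeepLift's attributions reduce to w_i (x_i - b_i), which makes the averaging easy to see. The sketch below uses that closed form (our simplification, not captum's implementation) and checks completeness: the attributions sum to f(x) minus the average of f over the baselines.

```python
import numpy as np


def deeplift_linear(w, x, baseline):
    # For a linear model, DeepLift's multipliers are just the weights.
    return np.asarray(w, float) * (np.asarray(x, float) - np.asarray(baseline, float))


def deeplift_shap_linear(w, x, baselines):
    # Average the single-baseline attributions over the baseline set.
    return np.mean([deeplift_linear(w, x, b) for b in baselines], axis=0)


w, c = np.array([1.0, 2.0]), 0.5


def f(v):
    return float(np.dot(w, v) + c)


x = np.array([3.0, 4.0])
baselines = np.array([[0.0, 0.0], [2.0, 2.0]])

attrs = deeplift_shap_linear(w, x, baselines)
print(attrs)  # [2. 6.]
print(attrs.sum(), f(x) - np.mean([f(b) for b in baselines]))  # both 8.0
```

The same closed form shows a caveat of `DeepLiftShapExplainer`: when `data` is None and a single observation is explained, the fallback `baselines = inputs` makes x - b zero, so every attribution would vanish.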

### Integrated Gradients (IG) Adapter

In [7]:
class IGExplainer:
    """A shap-esque wrapper for Integrated Gradients. See SurvShapExplainer for
     more details."""
    def __init__(self, model: DeepHitSingle, data=None, y=None):
        self.model = model
        if data is not None:
            data = torch.tensor(data.values, dtype=torch.float32)
            self.baselines = data.mean(dim=0)
        else:
            self.baselines = None
        self.experiment_name = type(self).__name__ + " " + self.model.__class__.__name__
    
    def __call__(self, X: pd.DataFrame):
        inputs = torch.tensor(X.values, dtype=torch.float32)
        if self.baselines is not None:
            baselines = self.baselines.broadcast_to(inputs.shape)
        else:
            baselines = inputs.mean(dim=0).broadcast_to(inputs.shape)

        ig = IntegratedGradients(DeepHit_SurvOut(self.model))

        with catch_warnings():
            simplefilter("ignore")
            attr_values = []
            for idx in range(len(self.model.event_times_)):
                attrs = ig.attribute(inputs, baselines, target=idx)
                attr_values.append(attrs)
            attr_values = torch.stack(attr_values, dim=2).detach().numpy()

        var_attr_values = {}
        for var_idx, var in enumerate(X.columns):
            var_attr_values[var] = [
                StepFunction(
                    x=self.model.event_times_,
                    y=attr_values[obs_idx,var_idx],
                ) for obs_idx in range(len(X))
            ]
        
        return Explanation(pd.DataFrame(var_attr_values), self.experiment_name)
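
Integrated Gradients attributes f(x) - f(b) by integrating the gradient along the straight path from the baseline b to x: IG_i = (x_i - b_i) times the integral over alpha in [0, 1] of df/dx_i(b + alpha (x - b)). A minimal NumPy sketch with a midpoint Riemann sum and an analytic gradient is below; `integrated_gradients` and the toy function are ours. captum's `IntegratedGradients` obtains the gradients with autograd and, as far as we know, defaults to a Gauss-Legendre rule with 50 steps.

```python
import numpy as np


def integrated_gradients(grad_fn, x, baseline, steps=50):
    """IG via a midpoint Riemann sum over alpha in (0, 1)."""
    x = np.asarray(x, dtype=float)
    baseline = np.asarray(baseline, dtype=float)
    alphas = (np.arange(steps) + 0.5) / steps
    grad_sum = np.zeros_like(x)
    for alpha in alphas:
        grad_sum += grad_fn(baseline + alpha * (x - baseline))
    return (x - baseline) * grad_sum / steps


def f(v):
    return v[0] ** 2 + 3.0 * v[1]


def grad_f(v):
    return np.array([2.0 * v[0], 3.0])


x, b = np.array([2.0, 1.0]), np.zeros(2)
attrs = integrated_gradients(grad_f, x, b)
print(attrs)                     # approximately [4. 3.]
print(attrs.sum(), f(x) - f(b))  # completeness: both approximately 7
```

`IGExplainer` uses the training-set mean as b, so its attributions explain the change in S(t) relative to an "average" patient.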

## Checking that the SurvSHAP adapter works correctly

### Data - `exp1_data.csv`

In [8]:
def exp1_csv():
    df = pd.read_csv("data/exp1_data.csv")
    df = df.rename(columns={"time": "duration"})
    return df

exp1_df = exp1_csv()

In [9]:
df = exp1_df

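X, y = get_x_y(df, attr_labels=["event", "duration"], pos_label=1)
# Train and test sets are intentionally the same here (the split below is
# commented out), so the reported "train" and "test" scores coincide.
train_X = test_X = X
train_y = test_y = y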

# train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
# train_X, train_y = get_x_y(train_df, 
#     attr_labels=["event", "duration"], pos_label=1)
# test_X, test_y = get_x_y(test_df, 
#     attr_labels=["event", "duration"], pos_label=1)

In [10]:
def eval_model(model, name):
    print(f"[{name}] Train score: {model.score(train_X, train_y)}")
    print(f"[{name}] Test score: {model.score(test_X, test_y)}")

### Linear model (`CoxPHSurvivalAnalysis`)

In [11]:
cph = CoxPHSurvivalAnalysis()
cph = cph.fit(train_X, train_y)
eval_model(cph, "cph")

[cph] Train score: 0.622925665249838
[cph] Test score: 0.622925665249838


### Random survival forest

In [12]:
rsf = RandomSurvivalForest(
    random_state=42,
    n_estimators=100,
    min_samples_split=8,
    min_samples_leaf=4,
    max_features=3,
    max_samples=0.8,
)

rsf = rsf.fit(train_X, train_y)
eval_model(rsf, "rsf")

[rsf] Train score: 0.8095397887950072
[rsf] Test score: 0.8095397887950072


### DeepHit

In [13]:
net = MLPVanilla(train_X.shape[1], [32, 32], 91, True, 0.1)
deephit_ = DeepHitSingle(net, alpha=0.2, sigma=0.1)

x_train, x_val, y_train, y_val = \
    train_test_split(train_X, train_y, test_size=0.2, random_state=42)

conf = TrainConf(
    optimizer_fn=lambda net: torch.optim.Adam(net.parameters(), lr=1e-2),
    epochs=100,
    callbacks=[tt.callbacks.EarlyStopping()],
    batch_size=256,
    val_data=(x_val, y_val),
    verbose=True,
)

deephit_ = deephit_.fit(x_train, y_train, conf=conf)
eval_model(deephit_, "deephit")

0:	[0s / 0s],		train_loss: 0.9295,	val_loss: 1.0969
1:	[0s / 0s],		train_loss: 0.8599,	val_loss: 1.0366
2:	[0s / 0s],		train_loss: 0.8452,	val_loss: 1.0123
3:	[0s / 0s],		train_loss: 0.8162,	val_loss: 0.9940
4:	[0s / 0s],		train_loss: 0.8010,	val_loss: 0.9795
5:	[0s / 0s],		train_loss: 0.7840,	val_loss: 0.9701
6:	[0s / 0s],		train_loss: 0.7694,	val_loss: 0.9676
7:	[0s / 0s],		train_loss: 0.7783,	val_loss: 0.9690
8:	[0s / 0s],		train_loss: 0.7574,	val_loss: 0.9713
9:	[0s / 0s],		train_loss: 0.7441,	val_loss: 0.9740
10:	[0s / 0s],		train_loss: 0.7416,	val_loss: 0.9887
11:	[0s / 0s],		train_loss: 0.7441,	val_loss: 0.9912
12:	[0s / 0s],		train_loss: 0.7182,	val_loss: 0.9888
13:	[0s / 0s],		train_loss: 0.7166,	val_loss: 0.9930
14:	[0s / 0s],		train_loss: 0.7095,	val_loss: 0.9948
15:	[0s / 0s],		train_loss: 0.7799,	val_loss: 1.0140
16:	[0s / 0s],		train_loss: 0.7082,	val_loss: 1.1291
[deephit] Train score: 0.6492804001537288
[deephit] Test score: 0.6492804001537288


### Plots 

In [14]:
from contextlib import contextmanager
import time

exec_times = {}

@contextmanager
def measure_exec_time(experiment_name: str):
    try:
        before = time.perf_counter()
        yield
    finally:
        after = time.perf_counter()
        if experiment_name not in exec_times:
            exec_times[experiment_name] = []
        exec_times[experiment_name].append(after - before)
        print(f"{experiment_name} --- Exec time: {after - before:.2f}s")


def run_explainer_n_times(explainer, observations: pd.DataFrame, gt_duration: float, dataset_name: str, n: int = 10):
    for _ in range(n):
        with measure_exec_time(dataset_name + " " + explainer.experiment_name):
            explanation = explainer(observations)
    
    fig = explanation.plot()
    fig.add_vline(x=gt_duration, line_dash="dash", annotation_text=f"{gt_duration:.3f}")
    fig.write_image(f"imgs/{dataset_name}_{explainer.experiment_name}.png")
    fig.show()

    return explanation
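
`measure_exec_time` appends one duration per run to `exec_times`, so each experiment ends up with ten samples. A small stdlib-only helper (hypothetical) can condense those lists into mean ± standard deviation for reporting:

```python
import statistics


def summarize_exec_times(exec_times):
    """Map {experiment name: [seconds, ...]} to {name: (mean, sample stdev)}."""
    summary = {}
    for name, times in exec_times.items():
        mean = statistics.mean(times)
        std = statistics.stdev(times) if len(times) > 1 else 0.0
        summary[name] = (mean, std)
    return summary


demo = {"exp1 A": [1.0, 3.0], "exp1 B": [0.5]}
for name, (mean, std) in sorted(summarize_exec_times(demo).items()):
    print(f"{name}: {mean:.2f} ± {std:.2f} s")
```

In the notebook, calling it on `exec_times` after the experiment cells would give one line per dataset/explainer/model combination.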

In [15]:
cph_expl = SurvShapExplainer(cph, test_X, test_y)

cph_expl0 = run_explainer_n_times(cph_expl, test_X.iloc[[690]], test_y[690]["duration"], "exp1", n=10)

exp1 SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 2.56s
exp1 SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 2.52s
exp1 SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 2.58s
exp1 SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 2.58s
exp1 SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 2.59s
exp1 SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 2.66s
exp1 SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 2.48s
exp1 SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 2.47s
exp1 SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 2.47s
exp1 SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 2.59s


In [16]:
cph_expl0.var_attrs.head()

| | x1 | x2 | x3 | x4 | x5 |
|---|---|---|---|---|---|
| 690 | `StepFunction(x=array([1.26006939e-02, 2.576424...` | `StepFunction(x=array([1.26006939e-02, 2.576424...` | `StepFunction(x=array([1.26006939e-02, 2.576424...` | `StepFunction(x=array([1.26006939e-02, 2.576424...` | `StepFunction(x=array([1.26006939e-02, 2.576424...` |


In [17]:
rsf_expl = SurvShapExplainer(rsf, test_X, test_y)

rsf_expl0 = run_explainer_n_times(rsf_expl, test_X.iloc[[690]], test_y[690]["duration"], "exp1", n=10)

exp1 SurvShapExplainer RandomSurvivalForest --- Exec time: 10.59s
exp1 SurvShapExplainer RandomSurvivalForest --- Exec time: 10.32s
exp1 SurvShapExplainer RandomSurvivalForest --- Exec time: 10.30s
exp1 SurvShapExplainer RandomSurvivalForest --- Exec time: 10.23s
exp1 SurvShapExplainer RandomSurvivalForest --- Exec time: 10.71s
exp1 SurvShapExplainer RandomSurvivalForest --- Exec time: 10.53s
exp1 SurvShapExplainer RandomSurvivalForest --- Exec time: 10.25s
exp1 SurvShapExplainer RandomSurvivalForest --- Exec time: 10.27s
exp1 SurvShapExplainer RandomSurvivalForest --- Exec time: 10.54s
exp1 SurvShapExplainer RandomSurvivalForest --- Exec time: 10.25s


In [18]:
dh_expl = SurvShapExplainer(deephit_, test_X, test_y)

dh_expl0 = run_explainer_n_times(dh_expl, test_X.iloc[[690]], test_y[690]["duration"], "exp1", n=10)

exp1 SurvShapExplainer DeepHitSingle --- Exec time: 1.62s
exp1 SurvShapExplainer DeepHitSingle --- Exec time: 1.48s
exp1 SurvShapExplainer DeepHitSingle --- Exec time: 1.52s
exp1 SurvShapExplainer DeepHitSingle --- Exec time: 1.46s
exp1 SurvShapExplainer DeepHitSingle --- Exec time: 1.97s
exp1 SurvShapExplainer DeepHitSingle --- Exec time: 1.39s
exp1 SurvShapExplainer DeepHitSingle --- Exec time: 1.34s
exp1 SurvShapExplainer DeepHitSingle --- Exec time: 1.38s
exp1 SurvShapExplainer DeepHitSingle --- Exec time: 1.36s
exp1 SurvShapExplainer DeepHitSingle --- Exec time: 1.42s


In [19]:
lift_expl = DeepLiftExplainer(deephit_, test_X, test_y, "mean")

lift_expl0 = run_explainer_n_times(lift_expl, test_X.iloc[[690]], test_y[690]["duration"], "exp1_mean", n=10)

exp1_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.39s
exp1_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.48s
exp1_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.38s
exp1_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.46s
exp1_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.37s
exp1_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.39s
exp1_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.42s
exp1_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.43s
exp1_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.39s
exp1_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.50s


In [20]:
lift_expl = DeepLiftExplainer(deephit_, test_X, test_y, "median")

lift_expl0 = run_explainer_n_times(lift_expl, test_X.iloc[[690]], test_y[690]["duration"], "exp1_median", n=10)

exp1_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.45s
exp1_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.35s
exp1_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.45s
exp1_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.42s
exp1_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.43s
exp1_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.42s
exp1_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.43s
exp1_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.59s
exp1_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.60s
exp1_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.59s


In [21]:
lift_expl = DeepLiftExplainer(deephit_, test_X, test_y, "zero")

lift_expl0 = run_explainer_n_times(lift_expl, test_X.iloc[[690]], test_y[690]["duration"], "exp1_zero", n=10)

Not implemented baseline method!
exp1_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.61s
exp1_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.51s
exp1_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.67s
exp1_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.57s
exp1_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.77s
exp1_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.42s
exp1_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.38s
exp1_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.36s
exp1_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.35s
exp1_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.34s


In [22]:
dls_expl = DeepLiftShapExplainer(deephit_, test_X, test_y)

dls_expl0 = run_explainer_n_times(dls_expl, test_X.iloc[[690]], test_y[690]["duration"], "exp1", n=10)

exp1 DeepLiftShapExplainer DeepHitSingle --- Exec time: 2.30s
exp1 DeepLiftShapExplainer DeepHitSingle --- Exec time: 3.32s
exp1 DeepLiftShapExplainer DeepHitSingle --- Exec time: 2.97s
exp1 DeepLiftShapExplainer DeepHitSingle --- Exec time: 2.84s
exp1 DeepLiftShapExplainer DeepHitSingle --- Exec time: 2.56s
exp1 DeepLiftShapExplainer DeepHitSingle --- Exec time: 2.13s
exp1 DeepLiftShapExplainer DeepHitSingle --- Exec time: 2.21s
exp1 DeepLiftShapExplainer DeepHitSingle --- Exec time: 3.31s
exp1 DeepLiftShapExplainer DeepHitSingle --- Exec time: 2.58s
exp1 DeepLiftShapExplainer DeepHitSingle --- Exec time: 2.65s


In [23]:
ig_expl = IGExplainer(deephit_, test_X, test_y)

ig_expl0 = run_explainer_n_times(ig_expl, test_X.iloc[[690]], test_y[690]["duration"], "exp1", n=10)

exp1 IGExplainer DeepHitSingle --- Exec time: 1.06s
exp1 IGExplainer DeepHitSingle --- Exec time: 0.90s
exp1 IGExplainer DeepHitSingle --- Exec time: 1.03s
exp1 IGExplainer DeepHitSingle --- Exec time: 0.91s
exp1 IGExplainer DeepHitSingle --- Exec time: 0.74s
exp1 IGExplainer DeepHitSingle --- Exec time: 0.85s
exp1 IGExplainer DeepHitSingle --- Exec time: 0.84s
exp1 IGExplainer DeepHitSingle --- Exec time: 0.93s
exp1 IGExplainer DeepHitSingle --- Exec time: 1.45s
exp1 IGExplainer DeepHitSingle --- Exec time: 0.87s


## Real-world case: METABRIC dataset

Note: the data and training protocol are taken from the [example pycox notebook](https://nbviewer.org/github/havakv/pycox/blob/master/examples/deephit.ipynb).

In [24]:
np.random.seed(1234)
_ = torch.manual_seed(123)

### Data

In [25]:
df_train = metabric.read_df()
df_train = df_train.rename(columns={
    "x0": "MK167",
    "x1": "EGFR",
    "x2": "PGR",
    "x3": "ERBB2",
    "x4": "hormone_therapy",
    "x5": "radiotherapy",
    "x6": "chemotherapy",
    "x7": "ER_positive",
    "x8": "age_at_diagnosis",
})

df_test = df_train.sample(frac=0.2)
df_train = df_train.drop(df_test.index)
df_val = df_train.sample(frac=0.2)
df_train = df_train.drop(df_val.index)
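
The split above first holds out 20% of METABRIC as the test set and then 20% of the remainder as the validation set, so the effective proportions are roughly 64% / 16% / 20% (0.8 × 0.8 = 0.64 for training). A quick way to compute the resulting sizes, assuming `DataFrame.sample(frac=...)` rounds `frac * len(df)` to the nearest integer (the helper name is ours):

```python
def split_sizes(n_rows, test_frac=0.2, val_frac=0.2):
    """(train, val, test) sizes for a test split followed by a val split of the rest."""
    n_test = round(n_rows * test_frac)
    n_rest = n_rows - n_test
    n_val = round(n_rest * val_frac)
    return n_rest - n_val, n_val, n_test


print(split_sizes(1000))  # (640, 160, 200)
```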

In [26]:
cols_standardize = ['MK167', 'EGFR', 'PGR', 'ERBB2', 'age_at_diagnosis']
cols_leave = ['hormone_therapy', 'radiotherapy', 'chemotherapy', 'ER_positive']

standardize = [([col], StandardScaler()) for col in cols_standardize]
leave = [(col, None) for col in cols_leave]

x_mapper = DataFrameMapper(standardize + leave)

def pycox_get_x_y(df):
    values = x_mapper.transform(df).astype("float32")
    X = pd.DataFrame(values, columns=x_mapper.transformed_names_)
    _, y = get_x_y(df, ["event", "duration"], 1)
    return X, y
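
`x_mapper` standardizes the five continuous columns with `StandardScaler` and passes the four binary indicators through unchanged. It is fitted on `df_train` only and then reused for the validation and test frames, so no test statistics leak into training. The same idea in plain NumPy, as a hypothetical stand-in rather than the `sklearn_pandas` API, with a boolean mask selecting which columns get scaled:

```python
import numpy as np


class MaskedStandardizer:
    """Standardize selected columns using statistics from the training data only."""

    def __init__(self, mask):
        self.mask = np.asarray(mask, dtype=bool)  # True = standardize this column

    def fit(self, X):
        X = np.asarray(X, dtype=float)
        self.mean_ = X[:, self.mask].mean(axis=0)
        self.scale_ = X[:, self.mask].std(axis=0)  # population std, as in StandardScaler
        self.scale_[self.scale_ == 0.0] = 1.0      # leave constant columns unscaled
        return self

    def transform(self, X):
        X = np.array(X, dtype=float)  # copy, so the caller's array is not modified
        X[:, self.mask] = (X[:, self.mask] - self.mean_) / self.scale_
        return X


train = np.array([[0.0, 1.0], [2.0, 0.0]])  # column 0 continuous, column 1 binary
test = np.array([[1.0, 1.0], [3.0, 0.0]])
scaler = MaskedStandardizer(mask=[True, False]).fit(train)
# Column 0 is scaled with the *training* mean (1) and std (1).
print(scaler.transform(test))
```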

In [27]:
x_mapper.fit(df_train)

x_train, y_train = pycox_get_x_y(df_train)
x_val, y_val = pycox_get_x_y(df_val)
x_test, y_test = pycox_get_x_y(df_test)

In [28]:
net = MLPVanilla(x_train.shape[1], [32, 32], 91, True, 0.1)
deephit_ = DeepHitSingle(net, alpha=0.2, sigma=0.1)

In [29]:
conf = TrainConf(
    optimizer_fn=lambda net: torch.optim.Adam(net.parameters(), lr=1e-2),
    epochs=100,
    callbacks=[tt.callbacks.EarlyStopping()],
    batch_size=256,
    val_data=(x_val, y_val),
    verbose=True,
)

deephit_ = deephit_.fit(x_train, y_train, conf=conf)

0:	[0s / 0s],		train_loss: 0.8832,	val_loss: 0.7864
1:	[0s / 0s],		train_loss: 0.8432,	val_loss: 0.7837
2:	[0s / 0s],		train_loss: 0.8172,	val_loss: 0.7836
3:	[0s / 0s],		train_loss: 0.7960,	val_loss: 0.7887
4:	[0s / 0s],		train_loss: 0.7858,	val_loss: 0.7919
5:	[0s / 0s],		train_loss: 0.7728,	val_loss: 0.8005
6:	[0s / 0s],		train_loss: 0.7624,	val_loss: 0.8068
7:	[0s / 1s],		train_loss: 0.7568,	val_loss: 0.8110
8:	[0s / 1s],		train_loss: 0.7483,	val_loss: 0.8168
9:	[0s / 1s],		train_loss: 0.7416,	val_loss: 0.8229
10:	[0s / 1s],		train_loss: 0.7371,	val_loss: 0.8288
11:	[0s / 1s],		train_loss: 0.7287,	val_loss: 0.8369
12:	[0s / 2s],		train_loss: 0.7215,	val_loss: 0.8377


In [30]:
deephit_.score(x_test, y_test)

0.6863156556021874

### Baseline models: CoxPH and random survival forest

In [31]:
cph = CoxPHSurvivalAnalysis()
cph.fit(x_train, y_train)
cph.score(x_test, y_test)

0.6503878926618339

In [32]:
rsf = RandomSurvivalForest(
    random_state=42,
    n_estimators=100,
    min_samples_split=8,
    min_samples_leaf=4,
    max_features=3,
    max_samples=0.8,
)
rsf = rsf.fit(x_train, y_train)
rsf.score(x_test, y_test)

0.6550722794522871

### Explanations

In [33]:
rng = np.random.default_rng(119)

# Select only uncensored observations (event observed).
finished_ids = np.where(y_test["event"] == 1)[0]

sample_ids = rng.choice(finished_ids, size=10, replace=False)
sample_ids

array([244,  57, 305, 289, 117, 356, 272, 212, 161, 222])

In [34]:
cph_expl = SurvShapExplainer(cph, x_train, y_train)

cph_expl0 = run_explainer_n_times(cph_expl, x_test.iloc[[sample_ids[0]]], y_test[sample_ids[0]]["duration"], "metabric", n=10)

metabric SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 75.14s
metabric SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 83.16s
metabric SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 86.07s
metabric SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 85.98s
metabric SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 85.55s
metabric SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 83.21s
metabric SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 81.11s
metabric SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 81.30s
metabric SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 81.18s
metabric SurvShapExplainer CoxPHSurvivalAnalysis --- Exec time: 81.21s


In [35]:
rsf_expl = SurvShapExplainer(rsf, x_train, y_train)

rsf_expl0 = run_explainer_n_times(rsf_expl, x_test.iloc[[sample_ids[0]]], y_test[sample_ids[0]]["duration"], "metabric", n=10)

metabric SurvShapExplainer RandomSurvivalForest --- Exec time: 177.30s
metabric SurvShapExplainer RandomSurvivalForest --- Exec time: 176.21s
metabric SurvShapExplainer RandomSurvivalForest --- Exec time: 176.32s
metabric SurvShapExplainer RandomSurvivalForest --- Exec time: 186.59s
metabric SurvShapExplainer RandomSurvivalForest --- Exec time: 193.36s
metabric SurvShapExplainer RandomSurvivalForest --- Exec time: 189.92s
metabric SurvShapExplainer RandomSurvivalForest --- Exec time: 190.15s
metabric SurvShapExplainer RandomSurvivalForest --- Exec time: 188.12s
metabric SurvShapExplainer RandomSurvivalForest --- Exec time: 189.36s
metabric SurvShapExplainer RandomSurvivalForest --- Exec time: 188.56s


In [36]:
dh_expl = SurvShapExplainer(deephit_, x_train, y_train)

dh_expl0 = run_explainer_n_times(dh_expl, x_test.iloc[[sample_ids[0]]], y_test[sample_ids[0]]["duration"], "metabric", n=10)

metabric SurvShapExplainer DeepHitSingle --- Exec time: 22.59s
metabric SurvShapExplainer DeepHitSingle --- Exec time: 22.65s
metabric SurvShapExplainer DeepHitSingle --- Exec time: 22.91s
metabric SurvShapExplainer DeepHitSingle --- Exec time: 22.91s
metabric SurvShapExplainer DeepHitSingle --- Exec time: 22.60s
metabric SurvShapExplainer DeepHitSingle --- Exec time: 22.92s
metabric SurvShapExplainer DeepHitSingle --- Exec time: 22.77s
metabric SurvShapExplainer DeepHitSingle --- Exec time: 22.68s
metabric SurvShapExplainer DeepHitSingle --- Exec time: 22.93s
metabric SurvShapExplainer DeepHitSingle --- Exec time: 22.94s


In [37]:
lift_expl = DeepLiftExplainer(deephit_, x_train, y_train, "mean")

lift_expl0 = run_explainer_n_times(lift_expl, x_test.iloc[[sample_ids[0]]], y_test[sample_ids[0]]["duration"], "metabric_mean", n=10)

metabric_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.76s
metabric_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.53s
metabric_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.45s
metabric_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.42s
metabric_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.42s
metabric_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.42s
metabric_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.38s
metabric_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.38s
metabric_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.47s
metabric_mean DeepLiftExplainer DeepHitSingle --- Exec time: 0.51s


In [38]:
lift_expl = DeepLiftExplainer(deephit_, x_train, y_train, "median")

lift_expl0 = run_explainer_n_times(lift_expl, x_test.iloc[[sample_ids[0]]], y_test[sample_ids[0]]["duration"], "metabric_median", n=10)

metabric_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.39s
metabric_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.49s
metabric_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.49s
metabric_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.50s
metabric_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.36s
metabric_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.38s
metabric_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.45s
metabric_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.40s
metabric_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.50s
metabric_median DeepLiftExplainer DeepHitSingle --- Exec time: 0.41s


In [39]:
lift_expl = DeepLiftExplainer(deephit_, x_train, y_train, "zero")

lift_expl0 = run_explainer_n_times(lift_expl, x_test.iloc[[sample_ids[0]]], y_test[sample_ids[0]]["duration"], "metabric_zero", n=10)

Not implemented baseline method!
metabric_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.59s
metabric_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.50s
metabric_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.42s
metabric_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.48s
metabric_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.37s
metabric_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.19s
metabric_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.53s
metabric_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.37s
metabric_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.41s
metabric_zero DeepLiftExplainer DeepHitSingle --- Exec time: 0.41s


In [40]:
dls_expl = DeepLiftShapExplainer(deephit_, x_train, y_train)

dls_expl0 = run_explainer_n_times(dls_expl, x_test.iloc[[sample_ids[0]]], y_test[sample_ids[0]]["duration"], "metabric", n=10)

metabric DeepLiftShapExplainer DeepHitSingle --- Exec time: 3.00s
metabric DeepLiftShapExplainer DeepHitSingle --- Exec time: 3.04s
metabric DeepLiftShapExplainer DeepHitSingle --- Exec time: 2.89s
metabric DeepLiftShapExplainer DeepHitSingle --- Exec time: 2.97s
metabric DeepLiftShapExplainer DeepHitSingle --- Exec time: 2.86s
metabric DeepLiftShapExplainer DeepHitSingle --- Exec time: 2.93s
metabric DeepLiftShapExplainer DeepHitSingle --- Exec time: 2.88s
metabric DeepLiftShapExplainer DeepHitSingle --- Exec time: 2.89s
metabric DeepLiftShapExplainer DeepHitSingle --- Exec time: 2.86s
metabric DeepLiftShapExplainer DeepHitSingle --- Exec time: 2.81s


In [41]:
ig_expl = IGExplainer(deephit_, x_train, y_train)

ig_expl0 = run_explainer_n_times(ig_expl, x_test.iloc[[sample_ids[0]]], y_test[sample_ids[0]]["duration"], "metabric", n=10)

metabric IGExplainer DeepHitSingle --- Exec time: 0.90s
metabric IGExplainer DeepHitSingle --- Exec time: 0.84s
metabric IGExplainer DeepHitSingle --- Exec time: 0.99s
metabric IGExplainer DeepHitSingle --- Exec time: 0.86s
metabric IGExplainer DeepHitSingle --- Exec time: 0.83s
metabric IGExplainer DeepHitSingle --- Exec time: 0.77s
metabric IGExplainer DeepHitSingle --- Exec time: 0.82s
metabric IGExplainer DeepHitSingle --- Exec time: 0.81s
metabric IGExplainer DeepHitSingle --- Exec time: 0.84s
metabric IGExplainer DeepHitSingle --- Exec time: 0.88s
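The cells above call `run_explainer_n_times`, which is defined outside this excerpt. Its log lines and the column names of `exec_times_df` further down suggest that it runs an explainer `n` times on the same input and records each wall-clock time under a key of the form `"{dataset} {explainer class} {model class}"`. Below is a minimal sketch of such a helper. The name `time_n_runs`, its signature and the `exec_times` layout are our assumptions, not the notebook's actual code.

```python
import time
from collections import defaultdict

# Hypothetical global store: key -> list of per-run durations in seconds.
exec_times = defaultdict(list)


def time_n_runs(fn, key, n=10, verbose=True):
    """Call fn() n times, record each wall-clock duration under `key`,
    and return the result of the last call."""
    result = None
    for _ in range(n):
        start = time.perf_counter()  # monotonic clock suited to short intervals
        result = fn()
        elapsed = time.perf_counter() - start
        exec_times[key].append(elapsed)
        if verbose:
            print(f"{key} --- Exec time: {elapsed:.2f}s")
    return result


# Usage with a toy workload in place of an explainer call.
total = time_n_runs(lambda: sum(range(10_000)), "toy sum builtin", n=3)
```

Repeating each run smooths out one-off effects such as warm-up and caching. This matters for the sub-second DeepLift and Integrated Gradients timings.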


## Evaluation measures

In [42]:
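def aggregate_step_function(step_function: StepFunction, method: str):
    """Collapse a time-dependent attribution into a single scalar."""
    if method == "sum_of_squares":
        return np.sum(np.square(step_function.y))
    elif method == "max":
        return np.max(step_function.y)
    elif method == "mean":
        return np.mean(step_function.y)
    elif method == "integral":
        # Trapezoidal rule over the step function's support points.
        return np.trapz(step_function.y, step_function.x)
    else:
        raise ValueError(f"Unknown aggregation method: {method!r}")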

def aggregate_explanation_single(attr_vars: pd.Series, method: str):
    aggregated_attr = {}
    for var in attr_vars.index:
        aggregated_attr[var] = aggregate_step_function(attr_vars[var], method)
    return aggregated_attr

def aggregate_explanation(explanation: Explanation, method: str):
    aggregated_attr = {var: [] for var in sorted(explanation.var_attrs.iloc[0].index)}
    for _, row in explanation.var_attrs.iterrows():
        for var, value in aggregate_explanation_single(row, method).items():
            aggregated_attr[var].append(value)
    return pd.DataFrame.from_dict(aggregated_attr)
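`aggregate_step_function` reduces a time-dependent attribution to one number per variable. As a small worked example, consider a hypothetical attribution sampled at t = 0, 1, 2, 3 with values 0, 2, -1, 1, stored as plain NumPy arrays instead of a `StepFunction`. `np.trapz` applies the trapezoidal rule, so it integrates the piecewise-linear interpolation of the points rather than the step function itself. The sketch computes both variants so the difference is visible. The exact step integral assumes that each value y[k] is held on [x[k], x[k+1]), as in scikit-survival's `StepFunction`. The function name `aggregate_curve` is ours.

```python
import numpy as np


def aggregate_curve(x, y, method):
    """Collapse an attribution curve y(x) into a scalar (sketch mirroring
    `aggregate_step_function`, on plain arrays)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if method == "sum_of_squares":
        return float(np.sum(np.square(y)))
    if method == "max":
        return float(np.max(y))
    if method == "mean":
        return float(np.mean(y))
    if method == "integral":
        # Trapezoidal rule, the quantity np.trapz computes.
        return float(np.sum((y[1:] + y[:-1]) / 2 * np.diff(x)))
    if method == "step_integral":
        # Exact integral of a step function holding y[k] on [x[k], x[k+1]).
        return float(np.sum(y[:-1] * np.diff(x)))
    raise ValueError(f"Unknown aggregation method: {method!r}")


x = [0.0, 1.0, 2.0, 3.0]
y = [0.0, 2.0, -1.0, 1.0]
for m in ["sum_of_squares", "max", "mean", "integral", "step_integral"]:
    print(m, aggregate_curve(x, y, m))
# sum_of_squares 6.0, max 2.0, mean 0.5, integral 1.5, step_integral 1.0
```

`max` in the notebook is a signed maximum, so a variable with large negative contributions receives a low score. `sum_of_squares` treats positive and negative contributions symmetrically.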

In [43]:
@contextmanager
def train_ctx(net: nn.Module):
    prev_val = net.training
    net.train(mode=True)
    try:
        yield net
    finally:
        net.train(mode=prev_val)

class DeepHitForwardWrapper(nn.Module):
    def __init__(self, deephit):
        super().__init__()
        self.deephit = deephit

    def forward(self, x):
        with train_ctx(self.deephit.net):
            preds = self.deephit.net(x)
            return preds
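`train_ctx` temporarily switches the wrapped network into training mode. Because of the `try`/`finally`, it restores the previous mode even if the forward pass raises. The pure-Python sketch below reproduces this pattern with a stand-in object. The `FakeModule` class is hypothetical: it only exposes the `training` attribute and `train(mode)` method of a `torch.nn.Module` so that the example runs without PyTorch. In training mode, dropout and batch-normalisation layers behave differently from evaluation mode. If the DeepHit network contains such layers, its outputs inside this context may therefore differ from evaluation-mode predictions.

```python
from contextlib import contextmanager


class FakeModule:
    """Hypothetical stand-in for torch.nn.Module's train/eval switch."""

    def __init__(self):
        self.training = False

    def train(self, mode=True):
        self.training = mode
        return self


@contextmanager
def train_ctx(net):
    prev_mode = net.training
    net.train(mode=True)
    try:
        yield net
    finally:
        # Runs on normal exit and on exceptions alike.
        net.train(mode=prev_mode)


net = FakeModule()
with train_ctx(net):
    assert net.training
assert not net.training

try:
    with train_ctx(net):
        raise RuntimeError("forward pass failed")
except RuntimeError:
    pass
assert not net.training  # previous (eval) mode restored despite the error
```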

In [44]:
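def run_evaluation(metric, aggregation_method: str, deephit: DeepHitSingle, x_batch: pd.DataFrame, y_batch: np.ndarray, explainer):
    columns = x_batch.columns

    # Adapter with the (model, inputs, targets, **kwargs) signature used for
    # Quantus's `explain_func`: rebuild a DataFrame, run the explainer and
    # collapse every time-dependent attribution into one scalar per feature.
    def explain_func(model, inputs, targets, **kwargs):
        inputs = pd.DataFrame(data=inputs, columns=columns)
        explanation = explainer(inputs)
        aggregated = aggregate_explanation(explanation, aggregation_method)
        # aggregate_explanation sorts variables alphabetically; restore the
        # input column order so attributions line up with the features.
        return aggregated[list(columns)].values

    model = DeepHitForwardWrapper(deephit)

    x_batch = x_batch.values
    # The encoded (durations, events) are currently unused.
    _, y = deephit.encode(x_batch, y_batch, fit=True)
    durations, events = y

    res = metric(
        model=model,
        x_batch=x_batch,
        y_batch=y_batch,
        explain_func=explain_func,
    )
    return res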

In [45]:
lift_expl = DeepLiftExplainer(deephit_, x_train, y_train)
metric = quantus.MaxSensitivity(
    nr_samples=5,
    lower_bound=0.2,
)
lift_max_sens = run_evaluation(metric, "integral", deephit_, x_test.iloc[sample_ids], y_test[sample_ids], lift_expl)

 (1) The Max Sensitivity metric is likely to be sensitive to the choice of amount of noise added 'lower_bound' and 'upper_bound', the number of samples iterated over 'nr_samples', the function to perturb the input 'perturb_func', the similarity metric 'similarity_func' as well as norm calculations on the numerator and denominator of the sensitivity equation i.e., 'norm_numerator' and 'norm_denominator'.  
 (2) If attributions are normalised or their absolute values are taken it may destroy or skew information in the explanation and as a result, affect the overall evaluation outcome.
 (3) Make sure to validate the choices for hyperparameters of the metric (by calling .get_params of the metric instance).
 (4) For further information, see original publication: Yeh, Chih-Kuan, et al. 'On the (in) fidelity and sensitivity for explanations.' arXiv preprint arXiv:1901.09392 (2019).




The settings for perturbing input e.g., 'perturb_func' didn't cause change in input. Reconsider the parameter settings.
(The warning above was printed 5 times.)



In [47]:
surv_expl = SurvShapExplainer(deephit_, x_train, y_train)
metric = quantus.MaxSensitivity(
    nr_samples=5,
    lower_bound=0.2,
)
surv_max_sens = run_evaluation(metric, "integral", deephit_, x_test.iloc[sample_ids], y_test[sample_ids], surv_expl)

 (1) The Max Sensitivity metric is likely to be sensitive to the choice of amount of noise added 'lower_bound' and 'upper_bound', the number of samples iterated over 'nr_samples', the function to perturb the input 'perturb_func', the similarity metric 'similarity_func' as well as norm calculations on the numerator and denominator of the sensitivity equation i.e., 'norm_numerator' and 'norm_denominator'.  
 (2) If attributions are normalised or their absolute values are taken it may destroy or skew information in the explanation and as a result, affect the overall evaluation outcome.
 (3) Make sure to validate the choices for hyperparameters of the metric (by calling .get_params of the metric instance).
 (4) For further information, see original publication: Yeh, Chih-Kuan, et al. 'On the (in) fidelity and sensitivity for explanations.' arXiv preprint arXiv:1901.09392 (2019).




The settings for perturbing input e.g., 'perturb_func' didn't cause change in input. Reconsider the parameter settings.
(The warning above was printed 5 times.)



In [49]:
max_sens = pd.DataFrame({"DeepLift": lift_max_sens, "SurvSHAP": surv_max_sens})

fig = px.box(max_sens, points="all")
fig = fig.update_layout(
    xaxis_title="Explainer",
    yaxis_title="Max sensitivity",
    title="Max sensitivity of explainers on Metabric dataset for 10 samples",
)
fig.write_image("imgs/max_sens_metabric.png")
fig.show()
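Max-Sensitivity, from the Yeh et al. (2019) paper cited in the Quantus warning, measures how much an explanation can change under small input perturbations. It is defined as the maximum of ‖e(x') - e(x)‖ over inputs x' within radius r of x. The NumPy sketch below estimates this quantity by Monte Carlo sampling, using uniform noise in an L∞ ball of radius r (our choice). The explainer is a toy gradient-times-input attribution for a linear model. The function names are ours. Quantus adds normalisation and configurable norms, so its numbers are not directly comparable with this sketch. The sketch also flags the situation behind the repeated warning: if a perturbation leaves the input unchanged, that sample contributes no information.

```python
import numpy as np


def max_sensitivity(explain, x, radius, nr_samples=5, rng=None):
    """Monte Carlo estimate of max over ||x' - x||_inf <= radius of ||e(x') - e(x)||_2."""
    rng = np.random.default_rng(rng)
    x = np.asarray(x, dtype=float)
    base = explain(x)
    worst = 0.0
    unchanged = 0
    for _ in range(nr_samples):
        x_pert = x + rng.uniform(-radius, radius, size=x.shape)
        if np.allclose(x_pert, x):
            unchanged += 1  # perturbation had no effect on the input
            continue
        worst = max(worst, float(np.linalg.norm(explain(x_pert) - base)))
    if unchanged:
        print(f"warning: {unchanged}/{nr_samples} perturbations left the input unchanged")
    return worst


# Toy explainer: gradient-times-input for f(x) = w . x, i.e. e(x) = w * x.
w = np.array([1.0, -2.0, 0.5])


def explain(x):
    return w * x


x = np.array([0.3, 1.0, -0.7])
print(max_sensitivity(explain, x, radius=0.2, nr_samples=5, rng=0))
```

For this linear toy explainer, the change e(x') - e(x) equals w * (x' - x). The estimate is therefore bounded by 0.2 · ‖w‖₂, and with a radius of 0 every perturbation is flagged and the result is 0.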

## Estimation error

In [50]:
survshap_expl = SurvShapExplainer(deephit_, x_train, y_train)
lift_expl = DeepLiftExplainer(deephit_, x_train, y_train)

survshap_expl0 = survshap_expl(x_test.iloc[sample_ids])
lift_expl0 = lift_expl(x_test.iloc[sample_ids])

### Aggregated measures

In [67]:
def plot_explanation(explanation: Explanation, method: str, title: str):
    aggregated = aggregate_explanation(explanation, method)
    x_len = len(aggregated.columns)

    fig = px.box(
        x=np.repeat(np.arange(x_len), len(aggregated)),
        y=aggregated.values.flatten("F"),
        points=False,
    )
    
    for i in range(len(aggregated)):
        fig.add_trace(go.Scatter(x=np.arange(x_len) - (0.3 + i / (3 * x_len)), y=aggregated.iloc[i].values, mode='markers', name=f"Sample {i}"))

    fig = fig.update_layout(
        xaxis_title="Variable",
        yaxis_title=f"Aggregated explanation {method}",
        title=title,
        font=dict(size=25),
    )


    fig = fig.update_xaxes(type="linear", ticktext=aggregated.columns, tickvals=np.arange(x_len))
    return fig

In [68]:
for method in ["sum_of_squares", "max", "mean", "integral"]:
    fig = plot_explanation(survshap_expl0, method, f"SurvSHAP {method} on Metabric dataset for 10 samples")
    fig.write_image(f"imgs/survshap_{method}_metabric.png", width=1400, height=1000)
    fig.show()

In [69]:
for method in ["sum_of_squares", "max", "mean", "integral"]:
    fig = plot_explanation(lift_expl0, method, f"DeepLift {method} on Metabric dataset for 10 samples")
    fig.write_image(f"imgs/deeplift_{method}_metabric.png", width=1400, height=1000)
    fig.show()

### Per-variable explanations

In [54]:
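def create_comparison(expl1, expl2, transform, aggregate):
    """Compare two explanations variable by variable.

    `transform(y1, y2)` is applied to the attribution values of each sample
    (y1 from `expl1`, y2 from `expl2`); `aggregate` optionally reduces the
    resulting list over samples (e.g. to its mean).
    """
    diffs = {}
    if aggregate is None:
        aggregate = lambda x: x

    for var in expl1.var_attrs.iloc[0].index:
        val = []
        for sample_idx in range(len(expl1)):
            val.append(
                transform(expl1.var_attrs.iloc[sample_idx][var].y, expl2.var_attrs.iloc[sample_idx][var].y)
            )
        val = aggregate(val)
        # Assumes all samples share the time grid of the first one.
        time_grid = expl1.var_attrs.iloc[0][var].x
        diffs[var] = [StepFunction(x=time_grid, y=v) for v in val]

    diff_expl = pd.DataFrame(diffs)
    return diff_expl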


diff_expl_lift0_survshap0 = create_comparison(lift_expl0, survshap_expl0, lambda y1, y2: np.abs(y1 - y2), None)
diff_expl_lift0_survshap0

| | MK167 | EGFR | PGR | ERBB2 | age_at_diagnosis | hormone_therapy | radiotherapy | chemotherapy | ER_positive |
|---|---|---|---|---|---|---|---|---|---|
| 0 | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` |
| 1 | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` |
| 2 | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` |
| 3 | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` |
| 4 | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` |
| 5 | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` |
| 6 | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` |
| 7 | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` |
| 8 | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` |
| 9 | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` | `StepFunction(x=array([ 0. , 3.946666...` |


In [70]:
fig = Explanation(diff_expl_lift0_survshap0.iloc[[0]], f"Absolute contribution difference").plot()
fig = fig.update_layout(
    yaxis_title="Contribution difference",
)
fig.write_image(f"imgs/abs_distance_metabric.png", width=1400, height=1000)
fig.show()

In [71]:
diff_expl_lift0_survshap0_mean = create_comparison(lift_expl0, survshap_expl0, lambda x1, x2: np.abs(x1 - x2), lambda x: [np.mean(x, axis=0)])
fig = Explanation(diff_expl_lift0_survshap0_mean, f"Mean absolute contribution difference").plot()
fig.write_image(f"imgs/abs_distance_mean_metabric.png", width=1400, height=1000)
fig.show()
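With `transform = |y1 - y2|` and a mean over samples, `create_comparison` produces, for each variable, the curve t ↦ (1/n) Σᵢ |φ¹ᵢ(t) - φ²ᵢ(t)|. If both explanations are evaluated on the same time grid, this reduces to a single vectorised NumPy expression. The sketch below uses hypothetical arrays of shape (samples, time points) for one variable. It then integrates the resulting curve as a step function to obtain one summary number per variable. The function name and the data are illustrative only.

```python
import numpy as np


def mean_abs_difference(attr_a, attr_b):
    """Per-time-point mean absolute difference between two explanations of
    one variable. Both inputs have shape (n_samples, n_times) and are assumed
    to be evaluated on the same time grid."""
    attr_a = np.asarray(attr_a, dtype=float)
    attr_b = np.asarray(attr_b, dtype=float)
    if attr_a.shape != attr_b.shape:
        raise ValueError("explanations must share samples and time grid")
    return np.mean(np.abs(attr_a - attr_b), axis=0)


times = np.array([0.0, 1.0, 2.0])
deeplift = np.array([[0.1, 0.2, 0.3],
                     [0.0, -0.1, 0.2]])
survshap = np.array([[0.1, 0.0, 0.1],
                     [0.2, 0.1, 0.2]])

curve = mean_abs_difference(deeplift, survshap)
# Integral of the curve as a step function holding curve[k] on [t_k, t_{k+1}).
area = float(np.sum(curve[:-1] * np.diff(times)))
print(curve, area)
```

The absolute value prevents positive and negative deviations from cancelling out. Sign disagreements between the two explainers therefore show up as large differences.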

In [57]:
def compare_features(expl1: Explanation, expl2: Explanation, gt: np.ndarray, idx: int):
    gt_val = gt["duration"][idx]
    for var in expl1.var_attrs.iloc[idx].index:
        df = pd.DataFrame({expl1.experiment_name: [expl1.var_attrs.iloc[idx][var]], expl2.experiment_name: [expl2.var_attrs.iloc[idx][var]]})
        fig = Explanation(df, var).plot()
        fig.add_vline(x=gt_val, line_dash="dash", annotation_text=f"{gt_val:.3f}")
        fig.write_image(f"imgs/feature_{var}_metabric.png", width=1400, height=1000)
        fig.show()

compare_features(lift_expl0, survshap_expl0, y_test[sample_ids], 2)

## Execution times

In [58]:
exec_times_df = pd.DataFrame.from_dict(exec_times)
exec_times_df.to_csv("exec_times.csv")
exec_times_df

| Run | exp1 SurvShapExplainer CoxPHSurvivalAnalysis | exp1 SurvShapExplainer RandomSurvivalForest | exp1 SurvShapExplainer DeepHitSingle | exp1_mean DeepLiftExplainer DeepHitSingle | exp1_median DeepLiftExplainer DeepHitSingle | exp1_zero DeepLiftExplainer DeepHitSingle | exp1 DeepLiftShapExplainer DeepHitSingle | exp1 IGExplainer DeepHitSingle | metabric SurvShapExplainer CoxPHSurvivalAnalysis | metabric SurvShapExplainer RandomSurvivalForest | metabric SurvShapExplainer DeepHitSingle | metabric_mean DeepLiftExplainer DeepHitSingle | metabric_median DeepLiftExplainer DeepHitSingle | metabric_zero DeepLiftExplainer DeepHitSingle | metabric DeepLiftShapExplainer DeepHitSingle | metabric IGExplainer DeepHitSingle |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2.561369 | 10.585525 | 1.623114 | 0.385092 | 0.451977 | 0.607862 | 2.302496 | 1.055441 | 75.135091 | 177.302558 | 22.591857 | 0.756137 | 0.393561 | 0.590465 | 2.997532 | 0.895035 |
| 1 | 2.520087 | 10.316211 | 1.482779 | 0.481881 | 0.350663 | 0.512398 | 3.319178 | 0.902482 | 83.164011 | 176.205935 | 22.646538 | 0.531502 | 0.493903 | 0.495049 | 3.035184 | 0.838788 |
| 2 | 2.580824 | 10.295043 | 1.516404 | 0.384888 | 0.448958 | 0.671747 | 2.972099 | 1.025255 | 86.067206 | 176.32306 | 22.911996 | 0.453498 | 0.488762 | 0.416241 | 2.89188 | 0.994426 |
| 3 | 2.576405 | 10.234228 | 1.456215 | 0.462075 | 0.416439 | 0.566846 | 2.843245 | 0.910263 | 85.98029 | 186.588958 | 22.91097 | 0.416207 | 0.50114 | 0.480252 | 2.972674 | 0.856005 |
| 4 | 2.585601 | 10.714177 | 1.968473 | 0.369551 | 0.428007 | 0.767199 | 2.558227 | 0.741806 | 85.55326 | 193.35728 | 22.596443 | 0.419481 | 0.359435 | 0.371863 | 2.864392 | 0.833929 |
| 5 | 2.657259 | 10.529349 | 1.39198 | 0.385591 | 0.415472 | 0.415298 | 2.128949 | 0.84961 | 83.21137 | 189.915866 | 22.918879 | 0.41688 | 0.377752 | 0.193326 | 2.931528 | 0.773493 |
| 6 | 2.484493 | 10.245734 | 1.341215 | 0.423269 | 0.433153 | 0.381952 | 2.210467 | 0.844175 | 81.11306 | 190.153652 | 22.770606 | 0.378079 | 0.445929 | 0.533699 | 2.877982 | 0.820858 |
| 7 | 2.469136 | 10.274786 | 1.383913 | 0.426226 | 0.588979 | 0.363744 | 3.310922 | 0.92627 | 81.30336 | 188.120775 | 22.678848 | 0.379842 | 0.404061 | 0.371138 | 2.892711 | 0.805587 |
| 8 | 2.471727 | 10.539805 | 1.364512 | 0.385686 | 0.597103 | 0.349649 | 2.580234 | 1.454991 | 81.183632 | 189.361949 | 22.931849 | 0.472052 | 0.499057 | 0.407826 | 2.856459 | 0.843061 |
| 9 | 2.590123 | 10.245479 | 1.422351 | 0.49719 | 0.591546 | 0.342797 | 2.647926 | 0.865444 | 81.209492 | 188.561412 | 22.935561 | 0.508191 | 0.413716 | 0.411744 | 2.811036 | 0.875379 |


In [63]:
def plot_exec_times(dataset_name: str, dataset_desc: str):
    cols = exec_times_df.columns
    cols_dataset = [col for col in cols if dataset_name in col]
    exec_times_df_copy = exec_times_df[cols_dataset].copy()
    exec_times_df_copy = exec_times_df_copy.drop(columns=[
        f"{dataset_name} SurvShapExplainer RandomSurvivalForest",
        f"{dataset_name} SurvShapExplainer CoxPHSurvivalAnalysis",
        f"{dataset_name}_zero DeepLiftExplainer DeepHitSingle",
        f"{dataset_name}_median DeepLiftExplainer DeepHitSingle"
    ])
    exec_times_df_copy.columns = [col.split(" ")[1] for col in exec_times_df_copy.columns]

    fig = px.box(exec_times_df_copy, points="all")
    fig = fig.update_layout(
        xaxis_title="Explainer",
        yaxis_title="Time (s)",
        title=f"Execution time of explainer methods on a single sample for {dataset_desc}",
    )
    fig.write_image(f"imgs/exec_times_{dataset_name}.png")
    fig.show()


In [65]:
plot_exec_times("exp1", "synthetic dataset")

In [66]:
plot_exec_times("metabric", "Metabric dataset")
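The box plots can be complemented by a numeric summary: the median time per explainer and its speed-up relative to SurvSHAP(t) on the same model. The standard-library sketch below applies this to two columns copied from the Metabric part of `exec_times_df`. These columns are SurvSHAP(t) and DeepLift with the mean baseline, both explaining DeepHit. On this data, DeepLift is roughly fifty times faster per sample. The helper name `summarise` is ours.

```python
from statistics import median

# Per-run wall-clock times (s) copied from the "metabric" columns of exec_times_df.
survshap_deephit = [22.591857, 22.646538, 22.911996, 22.91097, 22.596443,
                    22.918879, 22.770606, 22.678848, 22.931849, 22.935561]
deeplift_mean_deephit = [0.756137, 0.531502, 0.453498, 0.416207, 0.419481,
                         0.41688, 0.378079, 0.379842, 0.472052, 0.508191]


def summarise(name, times, reference=None):
    """Print the median run time and, optionally, the speed-up over `reference`."""
    med = median(times)
    line = f"{name}: median {med:.2f}s over {len(times)} runs"
    if reference is not None:
        line += f", {median(reference) / med:.1f}x faster than the reference"
    print(line)
    return med


ref = summarise("SurvSHAP(t)", survshap_deephit)
fast = summarise("DeepLift (mean baseline)", deeplift_mean_deephit, survshap_deephit)
```

The median is used instead of the mean because the first DeepLift run (0.76 s) is noticeably slower than the rest, likely due to warm-up.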