In [1]:
import numpy as np
import scipy.integrate
import scipy.optimize
import pandas as pd
import sksurv
from ruamel.yaml import safe_load

## Data
The `exp1`, `lime` and `heart` are from the SurvSHAP(t) paper (iirc). METABRIC is from the DeepHit paper, and it's also used in the pycox docs.

In [2]:
def exp1(N=1000, seed=42, rho=1.0):
    gen = np.random.default_rng(seed=seed)

    def h0(t):
        return np.exp(-17.8+6.5*t-11*np.sqrt(t)*np.log(t)\
            +9.5*np.sqrt(t))
    
    def h(t, x1, x2, x3, x4, x5):
        return h0(t) * np.exp((-0.9+0.1*t+0.9*np.log(t))*x1*rho\
            +0.5*x2-0.2*x3+0.1*x4+1e-6*x5)
    
    x1 = gen.binomial(n=1, p=0.5, size=N)
    x2 = gen.binomial(n=1, p=0.5, size=N)
    x3 = gen.normal(loc=10, scale=np.sqrt(2), size=N)
    x4 = gen.normal(loc=20, scale=2, size=N)
    x5 = gen.normal(loc=0, scale=1, size=N)

    def S(t, *args):
        H = scipy.integrate.quad(h, 0.0, t, args=args)[0]
        return np.exp(-H)
    
    U = gen.uniform(low=0.0, high=1.0, size=N)

    def _T(u, *args):
        t0, t1 = 1e-16, 1.0
        while S(t1, *args) - u > 0.0:
            t1 = 2.0 * t1
        
        sol = scipy.optimize.root_scalar(
            f=lambda t: S(t, *args) - u,
            bracket=[t0, t1],
        )
        return sol.root
    
    T = np.array([
        _T(U[i], x1[i], x2[i], x3[i], x4[i], x5[i])
        for i in range(N)
    ])

    C_l = gen.uniform(low=11.0, high=16.0, size=N)
    C_r = gen.uniform(low=0.0, high=24.0, size=N)
    y = np.min(np.column_stack((T, C_l, C_r)), axis=1)
    delta = np.where((C_l > T) & (C_r > T), 1, 0)

    return pd.DataFrame.from_dict(dict(
        x1=x1, x2=x2, x3=x3, x4=x4, x5=x5,
        duration=y, event=delta,
    ))

def exp1_csv():
    df = pd.read_csv("data/exp1_data.csv")
    df = df.rename(columns={"time": "duration"})
    return df

# exp1_df = exp1()
exp1_df = exp1_csv()

In [3]:
def sphere(gen: np.random.Generator, center, R, size=None):
    dim = center.shape[0]
    pts = gen.multivariate_normal(
        mean=np.zeros(dim),
        cov=np.eye(dim),
        size=size,
    )
    pts = pts / np.linalg.norm(pts, axis=-1).reshape(-1, 1)
    return center + pts * R

def lime_dataset(center, b, lambda_=1e-5, v=2, N=1000, seed=42):
    gen = np.random.default_rng(seed=seed)

    x = sphere(gen, np.asarray(center), 8.0, size=N)
    U = gen.uniform(low=0.0, high=1.0, size=N)
    y = (-np.log(U)/(lambda_*np.exp(np.dot(x, b))))**(1/v)
    delta = gen.binomial(n=1, p=0.9, size=N)

    return pd.DataFrame.from_dict(dict(
        x0=x[:,0], x1=x[:,1], x2=x[:,2], x3=x[:,3], x4=x[:,4],
        duration=y, event=delta,
    ))

lime_df0 = lime_dataset(
    center=np.zeros(5),
    b=np.array([1e-6, 0.1, -0.15, 1e-6, 1e-6]),
)

lime_df1 = lime_dataset(
    center=np.array([4.0, -8.0, 2.0, 4.0, 2.0]),
    b=np.array([1e-6, -0.15, 1e-6, 1e-6, -0.1]),
)

In [4]:
def heart_failure():
    df = pd.read_csv("data/exp3_heart_failure_dataset.csv")
    df = df.rename(columns={"time": "duration", "DEATH_EVENT": "event"})
    return df

heart_df = heart_failure()

In [5]:
# def metabric():
#     with open("data/METABRIC_features.yml", mode="r") as features_f:
#         features = safe_load(features_f)
    
#     clinical = features["clinical"]
#     duration = features["response"]["duration"]
#     event = features["response"]["event"]

#     df = pd.read_csv(
#         "data/METABRIC_RNA_Mutation.csv",
#         usecols=[*clinical, duration, event],
#     )

#     df = df.rename(columns={duration: "duration", event: "event"})

#     return df

# metabric_df = metabric()

from pycox.datasets import metabric
metabric_df = metabric.read_df()

## Models

In [6]:
from sksurv.datasets import get_x_y
from sklearn.model_selection import train_test_split

df = metabric_df

train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
train_X, train_y = get_x_y(train_df, 
    attr_labels=["event", "duration"], pos_label=1)
test_X, test_y = get_x_y(test_df, 
    attr_labels=["event", "duration"], pos_label=1)

# X, y = get_x_y(df, attr_labels=["event", "duration"], pos_label=1)
# train_X = test_X = X
# train_y = test_y = y

In [7]:
def eval_model(model, name):
    print(f"[{name}] Train score: {model.score(train_X, train_y)}")
    print(f"[{name}] Test score: {model.score(test_X, test_y)}")

### Linear model (`CoxPHSurvivalAnalysis`)

In [8]:
from sksurv.linear_model import CoxPHSurvivalAnalysis

cph = CoxPHSurvivalAnalysis()
cph = cph.fit(train_X, train_y)
eval_model(cph, "cph")

[cph] Train score: 0.6415487599659973
[cph] Test score: 0.6336213958060288


### Random survival forest

In [9]:
from sksurv.ensemble import RandomSurvivalForest

rsf = RandomSurvivalForest(
    random_state=42,
    n_estimators=100,
    min_samples_split=8,
    min_samples_leaf=4,
    max_features=3,
    max_samples=0.8,
)

rsf = rsf.fit(train_X, train_y)
eval_model(rsf, "rsf")

[rsf] Train score: 0.844405286780323
[rsf] Test score: 0.6429595347313237


### DNN (DeepHit)

In [10]:
from pycox.models import DeepHitSingle
from pycox.evaluation import EvalSurv
from sksurv.functions import StepFunction
import torch
import torch.nn as nn
import torch.optim as optim
from dataclasses import dataclass
from typing import Callable, Optional, Any
from scipy.interpolate import interp1d
from warnings import catch_warnings, simplefilter


@dataclass
class TrainConf:
    optimizer_fn: Callable[[nn.Module], optim.Optimizer] \
        = lambda net: optim.Adam(net.parameters(), lr=1e-3)
    device: Optional[torch.device] = None
    batch_size: int = 256
    epochs: int = 1
    callbacks: Any = None
    verbose: bool = False
    num_workers: int = 0
    shuffle: bool = False
    metrics: Any = None
    val_data: Optional[torch.Tensor] = None
    val_batch_size: int = 8224


class DeepHitSingle_:
    """A sksurv-like wrapper around pycox's DeepHitSingle."""

    def __init__(self, net: nn.Module, timestamps=None, alpha=0.2, sigma=0.1):
        """Create an instance.
        :param net: NN at the core of DeepHit. The number of output features
         is the number of cuts/timestamps.
        :param timestamps: (Optional) Predefined timestamps/steps to use. 
        Number of them must be equal to the dimensionality of net's output 
        space. If not provided, default cuts are used (see pycox docs for more 
        details.)
        :param alpha, sigma: Parameters for the pycox.models.DeepHitSingle 
        class."""

        self.net = net
        self.timestamps = None
        self.alpha = alpha
        self.sigma = sigma
    
    def fit(self, X, y, conf: Optional[TrainConf] = None):
        # Get the timestamps/cuts
        if self.timestamps is not None:
            cuts = self.timestamps
        else:
            # If the cuts are unspecified, we guess it by dry-running the net 
            # and checking the dimensionality of the output space.
            with torch.no_grad():
                self.net.eval()
                num_features = X.shape[1]
                dummy = torch.empty((1, num_features), dtype=torch.float32)
                res = self.net(dummy)[0]
                self.net.train()
            cuts = res.shape[0]
        
        self.tf = DeepHitSingle.label_transform(cuts)

        # Encode X and y to a form acceptable for DeepHitSingle.
        def encode(X, y, fit=False):
            input = torch.tensor(np.asarray(X), dtype=torch.float32)
            if fit:
                self.tf.fit(y["duration"], y["event"])
            durations, events = self.tf.transform(y["duration"], y["event"])
            durations = torch.tensor(durations, dtype=torch.int64)
            events = torch.tensor(events, dtype=torch.float32)
            target = (durations, events)
            return input, target
        
        input, target = encode(X, y, fit=True)
        # Save event_times_, just like the sksurv models.
        self.event_times_ = self.tf.cuts

        if conf is None:
            conf = TrainConf()
            
        self._model = DeepHitSingle(
            net=self.net,
            optimizer=conf.optimizer_fn(self.net),
            device=conf.device,
            duration_index=self.tf.cuts,
            alpha=self.alpha,
            sigma=self.sigma,
        )

        if conf.val_data is not None:
            val_X, val_y = conf.val_data
            val_data = encode(val_X, val_y)
        else:
            val_data = None

        self._log = self._model.fit(input, target, conf.batch_size, conf.epochs,
            conf.callbacks, conf.verbose, conf.num_workers, conf.shuffle,
            conf.metrics, val_data, conf.val_batch_size)
        
        return self
    
    def predict_survival_function(self, X, return_array=False):
        """Predict survival function. See sksurv models for more details."""
        
        X = torch.tensor(np.asarray(X), dtype=torch.float32)
        sf_values = self._model.predict_surv(X).numpy()

        if return_array:
            return sf_values
        else:
            return np.array([
                StepFunction(self.event_times_, values)
                for values in sf_values
            ])

    def predict_cumulative_hazard_function(self, X, return_array=False):
        raise NotImplementedError
    
    def score(self, X, y):
        X = torch.tensor(np.asarray(X), dtype=torch.float32)
        surv = self._model.predict_surv_df(X)
        
        durations, events = self.tf.transform(y["duration"], y["event"])

        with catch_warnings():
            simplefilter("ignore")
            eval = EvalSurv(
                surv=surv,
                durations=durations,
                events=events,
                censor_surv="km",
            )
            return eval.concordance_td(method="antolini")

In [11]:
from torchtuples.callbacks import EarlyStopping


deephit_ = DeepHitSingle_(
    net = nn.Sequential(
        torch.nn.Linear(test_X.shape[-1], 32),
        torch.nn.ReLU(),
        torch.nn.BatchNorm1d(32),
        torch.nn.Dropout(0.1),
        torch.nn.Linear(32, 32),
        torch.nn.ReLU(),
        torch.nn.BatchNorm1d(32),
        torch.nn.Dropout(0.1),
        torch.nn.Linear(32, 1000),
    ),
)

_train_X, _val_X, _train_y, _val_y = \
    train_test_split(train_X, train_y, test_size=0.2, random_state=1234)

train_conf = TrainConf(
    epochs=int(100),
    callbacks=[EarlyStopping()],
    val_data=(_val_X, _val_y),
    batch_size=256,
    optimizer_fn=lambda net: optim.Adam(net.parameters(), lr=1e-2)
    # verbose=True,
)

deephit_ = deephit_.fit(_train_X, _train_y, train_conf)
eval_model(deephit_, "deephit")

[deephit] Train score: 0.6031858275335888
[deephit] Test score: 0.5932092184040023


## Attribution methods

In [12]:
from survshap import SurvivalModelExplainer, ModelSurvSHAP, PredictSurvSHAP
from tqdm import tqdm

class SurvShapExplainer:
    """A bit more shap-esque Explainer wrapper for SurvSHAP."""

    def __init__(self, model, data=None, y=None, calculation_method="kernel", aggregation_method="integral", path="average", B=25, random_state=42, pbar=False):
        self.model = model
        self.calculation_method = calculation_method
        self.aggregation_method = aggregation_method
        self.path = path
        self.B = B
        self.random_state = random_state
        self.exp = SurvivalModelExplainer(self.model, data=data, y=y)
        self.pbar = pbar

    def __call__(self, observations: pd.DataFrame, timestamps=None) -> pd.DataFrame:
        """Predict SHAP values for a number of observations. In this case, we 
        deal with survival functions in the form of StepFunction, so likewise 
        the output SHAP values will be step functions.
        :param observations: Dataframe with shape (num_obs, num_features).
        :return: A dataframe with shape (num_obs, num_features), where each 
        "cell" contains a StepFunction being the SHAP attribution for a given 
        observation and a given value."""

        if timestamps is None:
            timestamps = self.model.event_times_

        skip = ["variable_str", "variable_name", "variable_value", "B", "aggregated_change", "index"]

        all_results = []
        
        idx_seq = range(len(observations))
        if self.pbar:
            idx_seq = tqdm(idx_seq)
        
        for idx in idx_seq:
            obs = observations.iloc[[idx]]
            shap = PredictSurvSHAP(
                calculation_method=self.calculation_method,
                aggregation_method=self.aggregation_method,
                path=self.path,
                B=self.B,
                random_state=self.random_state,
            )

            shap.fit(self.exp, obs, timestamps)
            obs_df = shap.result
            obs_df.insert(len(skip)-1, "index", idx)
            all_results.append(obs_df)
        
        res_df = pd.concat(all_results)
            
        g = res_df.groupby(by="variable_name")

        var_attr_values = {}
        for var in g.groups:
            grp: pd.DataFrame = g.get_group(var)
            grp = grp.sort_values(by=["index"])
            attr_values = grp.iloc[:,len(skip):].values
            var_attr_values[var] = [
                StepFunction(timestamps, attr_values_)
                for attr_values_ in attr_values
            ]
            
        res_df = pd.DataFrame(var_attr_values)
        res_df = res_df.set_index(observations.index, drop=True)
        return res_df

In [13]:
from captum.attr import DeepLift
from warnings import catch_warnings, simplefilter

class DeepLiftExplainer:
    """A shap-esque wrapper for DeepLift. See SurvShapExplainer for more
     details."""
     
    def __init__(self, model, data=None, y=None):
        self.model = model
        if data is not None:
            data = torch.tensor(data.values, dtype=torch.float32)
            self.baselines = data.mean(dim=0)
        else:
            self.baselines = None
    
    def __call__(self, X: pd.DataFrame):
        inputs = torch.tensor(X.values, dtype=torch.float32)
        if self.baselines is not None:
            baselines = self.baselines.broadcast_to(inputs.shape)
        else:
            baselines = inputs.mean(dim=0).broadcast_to(inputs.shape)

        deep_lift = DeepLift(self.model.net)

        with catch_warnings():
            simplefilter("ignore")
            attr_values = []
            for idx in range(len(self.model.event_times_)):
                attrs = deep_lift.attribute(inputs, baselines, target=idx)
                attr_values.append(attrs)
            attr_values = torch.stack(attr_values, dim=2).detach().numpy()

        var_attr_values = {}
        for var_idx, var in enumerate(X.columns):
            var_attr_values[var] = [
                StepFunction(
                    x=self.model.event_times_,
                    y=np.cumsum(attr_values[obs_idx,var_idx]),
                ) for obs_idx in range(len(X))
            ]
        
        return pd.DataFrame(var_attr_values)

## Experiments

This is "sort of" incomplete but I don't care. Anyways, if you run these cells multiple times, you can get some very fishy results, which I can't make heads or tails of. I will also say that DeepLift's explanations are sort of meaningless?

In [14]:
import plotly.graph_objects as go

def plot_expl(expl):
    fig = go.Figure()
    for var in expl.index:
        shap_f = expl[var]
        fig.add_trace(go.Scatter(
            x=shap_f.x, y=shap_f.y,
            mode="lines",
            line=dict(shape="hv"),
            name=var,
        ))
    return fig

In [15]:
cph_expl = SurvShapExplainer(cph, test_X, test_y)
cph_expl0 = cph_expl(test_X.iloc[[14]]).iloc[0]
plot_expl(cph_expl0)

In [16]:
rsf_expl = SurvShapExplainer(rsf, test_X, test_y)
rsf_expl0 = rsf_expl(test_X.iloc[[14]]).iloc[0]
plot_expl(rsf_expl0)

In [17]:
dh_expl = SurvShapExplainer(deephit_, test_X, test_y)
dh_expl0 = dh_expl(test_X.iloc[[14]]).iloc[0]
plot_expl(dh_expl0)

In [18]:
dl_expl = DeepLiftExplainer(deephit_, test_X, test_y)
dl_expl0 = dl_expl(test_X.iloc[[14]]).iloc[0]
plot_expl(dl_expl0)