In [16]:
import numpy as np
import scipy.integrate
import scipy.optimize
import pandas as pd
import sksurv

## Data

In [17]:
def exp1(N=1000, seed=42, rho=1.0):
    gen = np.random.default_rng(seed=seed)

    def h0(t):
        return np.exp(-17.8+6.5*t-11*np.sqrt(t)*np.log(t)\
            +9.5*np.sqrt(t))
    
    def h(t, x1, x2, x3, x4, x5):
        return h0(t) * np.exp((-0.9+0.1*t+0.9*np.log(t))*x1*rho\
            +0.5*x2-0.2*x3+0.1*x4+1e-6*x5)
    
    x1 = gen.binomial(n=1, p=0.5, size=N)
    x2 = gen.binomial(n=1, p=0.5, size=N)
    x3 = gen.normal(loc=10, scale=np.sqrt(2), size=N)
    x4 = gen.normal(loc=20, scale=2, size=N)
    x5 = gen.normal(loc=0, scale=1, size=N)

    def S(t, *args):
        H = scipy.integrate.quad(h, 0.0, t, args=args)[0]
        return np.exp(-H)
    
    U = gen.uniform(low=0.0, high=1.0, size=N)

    def _T(u, *args):
        t0, t1 = 1e-16, 1.0
        while S(t1, *args) - u > 0.0:
            t1 = 2.0 * t1
        
        sol = scipy.optimize.root_scalar(
            f=lambda t: S(t, *args) - u,
            bracket=[t0, t1],
        )
        return sol.root
    
    T = np.array([
        _T(U[i], x1[i], x2[i], x3[i], x4[i], x5[i])
        for i in range(N)
    ])

    C_l = gen.uniform(low=11.0, high=16.0, size=N)
    C_r = gen.uniform(low=0.0, high=24.0, size=N)
    y = np.min(np.column_stack((T, C_l, C_r)), axis=1)
    delta = np.where((C_l > T) & (C_r > T), 1, 0)

    return pd.DataFrame.from_dict(dict(
        x1=x1, x2=x2, x3=x3, x4=x4, x5=x5,
        duration=y, event=delta,
    ))

exp1_df = exp1()

In [18]:
def sphere(gen: np.random.Generator, center, R, size=None):
    dim = center.shape[0]
    pts = gen.multivariate_normal(
        mean=np.zeros(dim),
        cov=np.eye(dim),
        size=size,
    )
    pts = pts / np.linalg.norm(pts, axis=-1).reshape(-1, 1)
    return center + pts * R

def lime_dataset(center, b, lambda_=1e-5, v=2, N=1000, seed=42):
    gen = np.random.default_rng(seed=seed)

    x = sphere(gen, np.asarray(center), 8.0, size=N)
    U = gen.uniform(low=0.0, high=1.0, size=N)
    y = (-np.log(U)/(lambda_*np.exp(np.dot(x, b))))**(1/v)
    delta = gen.binomial(n=1, p=0.9, size=N)

    return pd.DataFrame.from_dict(dict(
        x0=x[:,0], x1=x[:,1], x2=x[:,2], x3=x[:,3], x4=x[:,4],
        duration=y, event=delta,
    ))

lime_df0 = lime_dataset(
    center=np.zeros(5),
    b=np.array([1e-6, 0.1, -0.15, 1e-6, 1e-6]),
)

lime_df1 = lime_dataset(
    center=np.array([4.0, -8.0, 2.0, 4.0, 2.0]),
    b=np.array([1e-6, -0.15, 1e-6, 1e-6, -0.1]),
)

In [19]:
def heart_failure(path="./data/heart_failure.csv"):
    df = pd.read_csv(path)
    df = df.rename(columns={
        "time": "duration",
        "DEATH_EVENT": "event",
    })
    return df

heart_df = heart_failure()

## Models

In [20]:
from sksurv.datasets import get_x_y
from sklearn.model_selection import train_test_split

df = exp1_df
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
train_X, train_y = get_x_y(train_df, 
    attr_labels=["event", "duration"], pos_label=1)
test_X, test_y = get_x_y(test_df, 
    attr_labels=["event", "duration"], pos_label=1)

In [21]:
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.metrics import concordance_index_censored

cph = CoxPHSurvivalAnalysis()
cph = cph.fit(train_X, train_y)
cph.score(test_X, test_y)

0.5632143852801784

In [22]:
from sksurv.ensemble import RandomSurvivalForest

rsf = RandomSurvivalForest(
    random_state=42,
    n_estimators=100,
    min_samples_split=8,
    min_samples_leaf=4,
    max_features=3,
    max_samples=0.8,
)

rsf = rsf.fit(train_X, train_y)
rsf.score(test_X, test_y)

0.5849595762475607

In [23]:
from pycox.models import DeepHitSingle
import torch
import torch.nn as nn
import torch.optim as optim

tf = DeepHitSingle.label_transform(cuts=16)

train_y_pcx = tf.fit_transform(train_y["duration"], train_y["event"])
test_y_pcx = tf.fit_transform(test_y["duration"], test_y["event"])

net = nn.Sequential(
    nn.Linear(test_X.shape[-1], 32),
    nn.ReLU(),
    nn.BatchNorm1d(32),
    nn.Dropout1d(p=0.5),
    nn.Linear(32, 32),
    nn.ReLU(),
    nn.BatchNorm1d(32),
    nn.Dropout1d(p=0.5),
    nn.Linear(32, tf.out_features),
)

deephit = DeepHitSingle(
    net=net,
    optimizer=optim.Adam(net.parameters(), lr=1e-3),
    duration_index=tf.cuts,
)

log = deephit.fit(
    input=train_X.values.astype(np.float32),
    target=train_y_pcx,
    batch_size=64,
    epochs=32,
    verbose=False,
)

In [24]:
from pycox.evaluation import EvalSurv

surv = deephit.interpolate(16).predict_surv_df(test_X.values.astype(np.float32))

eval_ = EvalSurv(
    surv=surv,
    durations=test_y_pcx[0],
    events=test_y_pcx[1],
    censor_surv="km",
)

eval_.concordance_td(method="antolini")

  assert pd.Series(self.index_surv).is_monotonic


0.5900453881067774

In [25]:
from captum.attr import DeepLiftShap

deep_lift = DeepLiftShap(deephit.net)

def deep_lift_results(X: pd.DataFrame):
    inputs = torch.tensor(X.values)
    baselines = inputs.mean(dim=0).broadcast_to(inputs.shape)
    
    attr_values = []
    for idx in range(tf.out_features):
        var_attr_values = deep_lift.attribute(inputs, baselines, target=idx)
        attr_values.append(var_attr_values)
    attr_values = torch.stack(attr_values, dim=2)

    return attr_values.detach().numpy()

deep_attr = deep_lift_results(test_X.iloc[:10].astype(np.float32))

               activations. The hooks and attributes will be removed
            after the attribution is finished


In [26]:
from survshap import SurvivalModelExplainer, ModelSurvSHAP
from scipy.interpolate import interp1d
from scipy.integrate import quad

class PycoxAdapter:
    def __init__(self, model, cuts):
        self.model = model
        self.cuts = cuts

    def predict_survival_function(self, X):
        X = np.asarray(X)
        values = self.model.predict_surv(X)
        return [interp1d(self.cuts, v) for v in values]
    
    def predict_cumulative_hazard_function(self, X):
        X = np.asarray(X)
        values = self.model.predict_hazard(X)
        return [lambda t: quad(interp1d(self.cuts, v), 0, t) for v in values]

def survshap_results(X: pd.DataFrame, y):
    exp = SurvivalModelExplainer(
        model=PycoxAdapter(deephit, tf.cuts),
        data=X,
        y=y,
    )

    shap = ModelSurvSHAP()
    shap.fit(exp, timestamps=tf.cuts)
    res_df: pd.DataFrame = shap.full_result

    g = res_df.groupby(by="variable_name")
    skip = ["variable_str", "variable_name", "variable_value", "B", "aggregated_change", "index"]

    attr_values = {}
    for var in g.groups:
        grp: pd.DataFrame = g.get_group(var)
        grp = grp.sort_values(by=["index"])
        var_attr_values = grp.iloc[:,len(skip):].values
        attr_values[var] = var_attr_values
    
    attr_values = [attr_values[var] for var in X.columns]
    attr_values = np.stack(attr_values, axis=1)
    return attr_values
    

sshap_attrs = survshap_results(test_X.iloc[:10].astype(np.float32), test_y[:10])

100%|██████████| 10/10 [00:02<00:00,  4.34it/s]
