In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from lifelines import CoxPHFitter
from lifelines import KaplanMeierFitter

from pycox.evaluation import EvalSurv
from sklearn.preprocessing import StandardScaler

In [2]:
# Seed NumPy and PyTorch with the same fixed value.
random_seed = 137
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Use the first CUDA device when one is available, otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

print(device)

cuda:0


In [3]:
# Early stopping class from https://github.com/Bjarten/early-stopping-pytorch
from SurvNODE.EarlyStopping import EarlyStopping
from SurvNODE.SurvNODE_x_ranking import *

In [4]:
def measures(odesurv,initial,x,Tstart,Tstop,From,To,trans,status, multiplier=1.,points=500):
    """Score the model's predicted curves with pycox's ``EvalSurv``.

    The positional layout mirrors one batch of the data loader so a batch can
    be splatted in with ``*ds``; ``Tstart``, ``From``, ``To`` and ``trans`` are
    accepted for that reason only and are not used here.

    Parameters
    ----------
    odesurv : model exposing ``predict(x, times)``; the result is consumed as a
        rank-4 tensor by the ``"ilkj,k->ilj"`` contraction below.
    initial : 1-D tensor contracted against the third axis of the prediction.
    x : tensor handed to ``odesurv.predict``; its device is used for the time grid.
    Tstop : tensor passed to ``EvalSurv`` as the durations.
    status : tensor passed to ``EvalSurv`` as the event indicators.
    multiplier : upper end of the evaluation time grid (the grid starts at 0).
    points : number of equally spaced grid points.

    Returns
    -------
    tuple
        ``(concordance_td('antolini'), integrated_brier_score, integrated_nbll)``
        as returned by ``EvalSurv``.
    """
    with torch.no_grad():
        # Build the grid once so the model call, the frame index and the
        # integrated scores are guaranteed to use exactly the same time points.
        time_grid = np.linspace(0, multiplier, points)
        times = torch.from_numpy(time_grid).float().to(x.device)

        surv_ode = odesurv.predict(x, times)
        # Contract axis k with `initial`, then keep entry 0 of the last axis.
        # NOTE(review): presumably this is the probability of still being in the
        # first state, i.e. the survival curve - confirm against SurvNODE.predict.
        pvec = torch.einsum("ilkj,k->ilj", surv_ode, initial)[:, :, 0]
        pvec = pvec.detach().cpu().numpy()

        # EvalSurv expects one row per time point, indexed by time.
        surv_ode_df = pd.DataFrame(pvec, index=pd.Index(time_grid, name="time"))

        ev_ode = EvalSurv(surv_ode_df, np.array(Tstop.cpu()), np.array(status.cpu()), censor_surv='km')
        conc = ev_ode.concordance_td('antolini')
        ibs = ev_ode.integrated_brier_score(time_grid)
        inbll = ev_ode.integrated_nbll(time_grid)
    return conc,ibs,inbll

In [5]:
from sklearn_pandas import DataFrameMapper
import pandas as pd

def make_dataloader(df,Tmax,batchsize):
    """Wrap a data frame in a shuffling ``DataLoader`` of seven aligned tensors.

    Reads the columns ``x0`` .. ``x8``, ``duration`` and ``event`` from ``df``
    and places every tensor on the module-level ``device``. The feature
    columns are used as they are; no standardisation is applied here.

    Each dataset item is ``(X, Tstart, T, From, To, trans, E)`` where

    * ``X``      - float features, one row per sample,
    * ``Tstart`` - zeros,
    * ``T``      - ``duration / Tmax``, with exact zeros replaced by ``1e-8``,
    * ``From``   - constant 1 (integer),
    * ``To``     - constant 2 (integer),
    * ``trans``  - constant 1 (integer),
    * ``E``      - the ``event`` column as float.

    Parameters
    ----------
    df : pandas.DataFrame holding the columns listed above.
    Tmax : scalar the durations are divided by.
    batchsize : batch size handed to ``DataLoader``.
    """
    def as_float_vector(column):
        # Single-column slice -> flat float tensor on the global device.
        return torch.from_numpy(df[[column]].values).float().flatten().to(device)

    feature_columns = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8']
    features = torch.from_numpy(df[feature_columns].values).float().to(device)

    durations = as_float_vector("duration") / Tmax
    # Only durations that are exactly zero are nudged to a tiny positive value.
    durations[durations == 0] = 1e-8
    events = as_float_vector("event")

    # Constant per-sample columns, all shaped like the duration vector.
    start_times = torch.zeros_like(durations)
    state_from = torch.ones_like(durations, dtype=torch.long)
    state_to = torch.full_like(durations, 2, dtype=torch.long)
    transition = torch.ones_like(durations, dtype=torch.long)

    samples = TensorDataset(features, start_times, durations, state_from, state_to, transition, events)
    return DataLoader(samples, batch_size=batchsize, shuffle=True)

In [6]:
from sklearn.model_selection import train_test_split

def odesurv_manual_benchmark(df_train, df_test,config,name):
    """Train a SurvNODE model on one split and score the best checkpoint on the test set.

    Steps:

    1. Hold out 20% of ``df_train`` (stratified on ``event``) as a validation set.
    2. Rescale durations by ``Tmax / config["multiplier"]``, where ``Tmax`` is
       the largest duration in the remaining training rows.
    3. Train for at most 1000 epochs with Adam and ``ReduceLROnPlateau``;
       ``EarlyStopping`` watches the mean validation loss.
    4. Reload ``name + '_checkpoint.pt'`` and evaluate it on ``df_test``.

    Parameters
    ----------
    df_train, df_test : pandas.DataFrame
        Frames with the columns expected by ``make_dataloader``.
    config : dict
        Keys read here: ``multiplier``, ``batch_size`` (fraction of the training
        rows per mini-batch), ``num_latent``, ``encoder_neurons``,
        ``num_encoder_layers``, ``encoder_dropout``, ``odefunc_neurons``,
        ``num_odefunc_layers``, ``softplus_beta``, ``weight_decay``, ``lr``,
        ``scheduler_gamma``, ``scheduler_epoch``, ``patience``, ``mu``,
        ``alpha``, ``sigma``.
    name : str
        Passed to ``EarlyStopping`` and used as the checkpoint file prefix.

    Returns
    -------
    tuple
        ``(concordance, integrated Brier score, integrated NBLL)`` averaged
        over the test batches.
    """
    torch.cuda.empty_cache()
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.loc[:,"event"])

    # The time scale comes from the training rows only.
    Tmax = df_train["duration"].max()
    time_scale = Tmax/config["multiplier"]

    # int() truncates, so a small training set would otherwise yield a batch
    # size of 0, which DataLoader rejects.
    train_batch_size = max(1, int(len(df_train)*config["batch_size"]))
    train_loader = make_dataloader(df_train,time_scale,train_batch_size)
    val_loader = make_dataloader(df_val,time_scale,len(df_val))
    test_loader = make_dataloader(df_test,time_scale,len(df_test))

    num_in = 9  # number of feature columns produced by make_dataloader (x0..x8)
    num_latent = config["num_latent"]
    layers_encoder =  [config["encoder_neurons"]]*config["num_encoder_layers"]
    dropout_encoder = [config["encoder_dropout"]]*config["num_encoder_layers"]
    layers_odefunc =  [config["odefunc_neurons"]]*config["num_odefunc_layers"]

    # 2x2 matrix whose only non-NaN entry is [0, 1].
    # NOTE(review): presumably this marks the single allowed transition
    # (state 1 -> state 2) - confirm against ODEFunc.
    trans_matrix = torch.tensor([[np.nan,1],[np.nan,np.nan]]).to(device)

    encoder = Encoder(num_in,num_latent,layers_encoder, dropout_encoder).to(device)
    odefunc = ODEFunc(trans_matrix,num_in,num_latent,layers_odefunc,config["softplus_beta"]).to(device)
    block = ODEBlock(odefunc).to(device)
    odesurv = SurvNODE(block,encoder).to(device)

    optimizer = torch.optim.Adam(odesurv.parameters(), weight_decay = config["weight_decay"], lr=config["lr"])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=config["scheduler_gamma"], patience=config["scheduler_epoch"])

    # Initial state vector handed to `measures`; identical for every batch.
    initial = torch.tensor([1.,0.],device=device)

    early_stopping = EarlyStopping(name=name,patience=config["patience"], verbose=True)
    for i in tqdm(range(1000)):
        odesurv.train()
        for ds in train_loader:
            myloss,_,_ = loss(odesurv,*ds,mu=config["mu"],alpha=config["alpha"],sigma=config["sigma"])
            optimizer.zero_grad()
            myloss.backward()
            optimizer.step()

        odesurv.eval()
        with torch.no_grad():
            lossval,conc,ibs,ibnll = 0., 0., 0., 0.
            # Single pass over the validation loader: loss and metrics are
            # computed on the same batch before moving on.
            for ds in val_loader:
                batch_loss,_,_ = loss(odesurv,*ds,mu=config["mu"],alpha=config["alpha"],sigma=config["sigma"])
                lossval += batch_loss.item()
                batch_conc,batch_ibs,batch_ibnll = measures(odesurv,initial,*ds,multiplier=config["multiplier"])
                conc += batch_conc
                ibs += batch_ibs
                ibnll += batch_ibnll
            n_val = len(val_loader)
            early_stopping(lossval/n_val, odesurv)
            scheduler.step(lossval/n_val)

            # The reported train loss is that of the last mini-batch of the epoch.
            print("it: "+str(i)+", train loss="+str(myloss.item())+", validation loss="+str(lossval/n_val)+", c="+str(conc/n_val)+", ibs="+str(ibs/n_val)+", ibnll="+str(ibnll/n_val))

        if early_stopping.early_stop:
            print("Early stopping")
            break

    # NOTE(review): the checkpoint is presumably written by EarlyStopping under
    # this name - verify against SurvNODE.EarlyStopping. map_location keeps the
    # load working if the weights were saved from a different device.
    odesurv.load_state_dict(torch.load(name+'_checkpoint.pt', map_location=device))

    odesurv.eval()
    with torch.no_grad():
        conc,ibs,ibnll = 0., 0., 0.
        for ds in test_loader:
            batch_conc,batch_ibs,batch_ibnll = measures(odesurv,initial,*ds,multiplier=config["multiplier"])
            conc += batch_conc
            ibs += batch_ibs
            ibnll += batch_ibnll
    n_test = len(test_loader)
    return conc/n_test, ibs/n_test, ibnll/n_test

In [None]:
from sklearn.model_selection import StratifiedKFold
from pycox import datasets

# Stratified 5-fold cross-validation on the METABRIC frame provided by pycox.
kfold = StratifiedKFold(n_splits=5, shuffle=True)
df_all = datasets.metabric.read_df()
# Every column except "event" acts as features; "event" is the stratification target.
gen = kfold.split(df_all.drop(columns="event"), df_all["event"])

config = {
    # optimiser
    "lr": 1e-3,
    "weight_decay": 1e-3,
    # network sizes
    "num_latent": 200,
    "encoder_neurons": 50,
    "num_encoder_layers": 4,
    "encoder_dropout": 0.,
    "odefunc_neurons": 800,
    "num_odefunc_layers": 2,
    # data handling: batch_size is a fraction of the training rows
    "batch_size": 1/3,
    "multiplier": 1.,
    # loss and ODE function settings
    "mu": 1e-1,
    "alpha": 0.7,
    "sigma": 0.1,
    "softplus_beta": 1.,
    # learning-rate schedule and early stopping
    "scheduler_epoch": 10,
    "scheduler_gamma": 0.1,
    "patience": 20
}

# One [concordance, ibs, ibnll] row per fold.
odesurv_bench_vals = []
for train_idx, test_idx in gen:
    df_train, df_test = df_all.iloc[train_idx], df_all.iloc[test_idx]
    conc, ibs, ibnll = odesurv_manual_benchmark(df_train, df_test, config, "metabric_test")
    odesurv_bench_vals.append([conc, ibs, ibnll])

  0%|          | 0/1000 [00:00<?, ?it/s]

Validation loss decreased (inf --> 0.161452).  Saving model ...
it: 0, train loss=0.7617104649543762, validation loss=0.16145195066928864, c=0.5354734029015017, ibs=0.181363949950862, ibnll=0.5383163511823669
EarlyStopping counter: 1 out of 20
it: 1, train loss=0.5374631881713867, validation loss=0.24406588077545166, c=0.5509671672181217, ibs=0.18134758631652412, ibnll=0.5395845425523068
EarlyStopping counter: 2 out of 20
it: 2, train loss=0.39296647906303406, validation loss=0.3268198072910309, c=0.5474357342835328, ibs=0.18127056306803127, ibnll=0.5390520197884567
EarlyStopping counter: 3 out of 20
it: 3, train loss=0.16378892958164215, validation loss=0.207983136177063, c=0.5682425553575974, ibs=0.18010064108857227, ibnll=0.5366584958962884
EarlyStopping counter: 4 out of 20
it: 4, train loss=0.19430136680603027, validation loss=0.1715567409992218, c=0.5686243318910664, ibs=0.18165707215164492, ibnll=0.5396820362383445
EarlyStopping counter: 5 out of 20
it: 5, train loss=0.171629250

In [None]:
# Mean and standard deviation across folds, one line per metric column.
bench = np.array(odesurv_bench_vals)
for column, label in enumerate(("c", "ibs", "ibnll")):
    scores = bench[:, column]
    print(label+"="+str(np.mean(scores))+"+-"+str(np.std(scores)))