In [1]:
import torch
import torch.nn as nn
from torchdiffeq import odeint
# from torchdiffeq import odeint_adjoint as odeint
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
# from sksurv.datasets import load_flchain
import torch.multiprocessing as mp

from pycox.evaluation import EvalSurv
from ray import tune
from sklearn.preprocessing import StandardScaler

# random_seed = 1991

# torch.manual_seed(random_seed)
# np.random.seed(random_seed)

# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
# torch.backends.cudnn.benchmark = True
# device = 'cpu'

print(device)

cuda:0


In [2]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, name, patience=7, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.name = name
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.name+'_checkpoint.pt')
        self.val_loss_min = val_loss

In [3]:
import torch
import torch.nn as nn
from torchdiffeq import odeint
# from torchdiffeq import odeint_adjoint as odeint
import numpy as np
import pandas as pd
     
        
class Encoder(nn.Module):
    """
        Encoding of the initial values of the memory states
        Input: 
            - number of covariates
            - number of memory states
            - hidden layer neurons, given as array (e.g. [10,10,10])
            - dropout for hidden layers, given as array (e.g. [0.2,0.3,0.4])
    """
    def __init__(self,num_in,num_latent,layers,p_dropout):
        super(Encoder, self).__init__()
        self.net = nn.Sequential(*((nn.Linear(num_in,layers[0]), nn.ReLU(), nn.Dropout(p_dropout[0])) + tuple(tup for element in tuple(((nn.Linear(layers[i],layers[i+1]), nn.ReLU(), nn.Dropout(p_dropout[i+1])) for i in range(len(layers)-1))) for tup in element) + (nn.Linear(layers[-1],num_latent),)))
        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.1)
                nn.init.constant_(m.bias, val=0)
    def forward(self, x):
        return self.net(x)


class ODEFunc(nn.Module):
    """
        KFE_KBE function to calculate the derivatives in the ODEsolver
        Input: 
            - transition matrix giving possible transitions (e.g. [[NA,1,1],[NA,NA,1],[NA,NA,NA]] for the illness-death model)
            - number of covariates
            - number of memory states
            - hidden layer neurons, given as array (e.g. [10,10,10])
            - dropout for hidden layers, given as array (e.g. [0.2,0.3,0.4])
            - softplus parameter (should be left at 1.)
    """
    def __init__(self,transition_matrix,num_in,num_latent,layers,softplus_beta=1.):
        super(ODEFunc, self).__init__()
        
        self.softplus_beta = softplus_beta
        self.transition_matrix = transition_matrix
        self.trans_dim = transition_matrix.shape[0]
        self.num_latent = num_latent
        self.number_of_hazards = int(np.nansum(transition_matrix.flatten().cpu()))
        self.num_probs = np.prod(transition_matrix.shape)
        # use this NN if covariates are to be included
#         self.net = nn.Sequential(*((nn.Linear(2*self.num_probs+self.number_of_hazards+num_latent+num_in+1,layers[0]), nn.Tanh()) + tuple(tup for element in tuple(((nn.Linear(layers[i],layers[i+1]), nn.Tanh()) for i in range(len(layers)-1))) for tup in element) + (nn.Linear(layers[-1],self.number_of_hazards+num_latent),)))
        self.net = nn.Sequential(*((nn.Linear(2*self.num_probs+self.number_of_hazards+2*num_latent+1,layers[0]), nn.Tanh()) + tuple(tup for element in tuple(((nn.Linear(layers[i],layers[i+1]), nn.Tanh()) for i in range(len(layers)-1))) for tup in element) + (nn.Linear(layers[-1],self.number_of_hazards+num_latent),)))

        count = 0
        length = len(list(self.net.modules()))
        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.1)
                nn.init.constant_(m.bias, val=0.)
            if count==length-1:
                nn.init.normal_(m.weight, mean=0, std=0)
            count += 1
        self.num_in = num_in
    
#     def set_x(self, x):
#         self.x = x

    def set_y0(self, y0):
        self.y0 = y0

    # KFE_KBE function
    def forward(self, t, y):
        # pass values through NN
        out = self.net(torch.cat((y,self.y0,torch.tensor([t],device=y.device).repeat((y.shape[0],1))),1))
        
        # build Q matrix from output
        qvec = torch.nn.functional.softplus(out[:,:self.number_of_hazards],beta=self.softplus_beta)
        q = torch.zeros(self.trans_dim, self.trans_dim,device=y.device).repeat((y.shape[0],1,1))
        q[self.transition_matrix.repeat((y.shape[0],1,1))==1] = qvec.flatten()
        q[torch.eye(self.trans_dim, self.trans_dim,device=y.device).repeat((y.shape[0],1,1)) == 1] = -torch.sum(q,2).flatten()
        
        # get P matrix
        P = torch.reshape(y[:,:self.num_probs],(y.shape[0],self.trans_dim,self.trans_dim))
        P_back = torch.reshape(y[:,self.num_probs:(2*self.num_probs)],(y.shape[0],self.trans_dim,self.trans_dim))
        # calculate right side of KFE and KBE
        Pprime = torch.bmm(P, q)
        Pprime_back = -torch.bmm(q, P_back)
        return torch.cat((Pprime.reshape(y.shape[0],self.num_probs),Pprime_back.reshape(y.shape[0],self.num_probs),qvec,out[:,self.number_of_hazards:]),1)
     
class ODEBlock(nn.Module):
    """
        Helper Function to define the initial value problem
    """
    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.num_probs = odefunc.num_probs
        self.trans_dim = odefunc.trans_dim
        self.number_of_hazards = odefunc.number_of_hazards
        self.transition_matrix = odefunc.transition_matrix

    def forward(self, y0, x, tinterval):
#         self.odefunc.set_x(x)
        self.odefunc.set_y0(y0)
        p0 = torch.eye(self.trans_dim,device=y0.device).reshape(self.num_probs).repeat((y0.shape[0],1))
        Q0 = torch.zeros(self.number_of_hazards,device=x.device).repeat((y0.shape[0],1))
        yin = torch.cat((p0,p0,Q0,y0),1)
        out = odeint(self.odefunc, yin, tinterval, method="dopri5", atol=1e-8, rtol=1e-8)
        return out       
        
class SurvNODE(nn.Module):
    """
        SurvNODE class, 
    """
    def __init__(self,odeblock,encoder):
        super(SurvNODE, self).__init__()
        self.odeblock = odeblock
        self.encoder = encoder
        self.num_probs = odeblock.num_probs
        self.trans_dim = odeblock.trans_dim
        self.number_of_hazards = odeblock.number_of_hazards
        self.transition_matrix = odeblock.transition_matrix

    def forward(self, x, tstart, tstop, from_state, to_state):
        # get P_ij(0,t) and P_ij(t,0) at all batch times
        out = self.odeblock(self.encoder(x),x,torch.unique(torch.cat([torch.tensor([0.],device=x.device),tstart,tstop])))
        
        # get P_ij(s,t) through Kolmogorov backward equation
        tstart_indices = torch.flatten(torch.cat([(torch.unique(torch.cat([torch.tensor([0.],device=x.device),tstart,tstop])) == time).nonzero() for time in tstart]))
        Ttstartinv = torch.cat([out[tstart_indices[i],i,self.num_probs:(self.num_probs*2)] for i in range(len(tstart))]).reshape((tstart.shape[0],self.trans_dim,self.trans_dim))
        tstop_indices = torch.flatten(torch.cat([(torch.unique(torch.cat([torch.tensor([0.],device=x.device),tstart,tstop])) == time).nonzero() for time in tstop]))
        Ttstop = torch.cat([out[tstop_indices[i],i,:self.num_probs] for i in range(len(tstop))]).reshape((tstop.shape[0],self.trans_dim,self.trans_dim))
        S = torch.bmm(Ttstartinv,Ttstop)
        S = torch.cat([S[i:i+1,from_state[i]-1,from_state[i]-1] for i in range(len(from_state))])
        
        
#         # get P_ij(s,0) by inverting P_ij(0,s)
#         tstart_indices = torch.flatten(torch.cat([(torch.unique(torch.cat([torch.tensor([0.],device=x.device),tstart,tstop])) == time).nonzero() for time in tstart]))
#         Ttstart = out[tstart_indices,[i for i in range(tstart.shape[0])],:self.num_probs].reshape((tstart.shape[0],self.trans_dim,self.trans_dim))
#         Ttstartinv = torch.inverse(Ttstart)
#         # # inverse with conditioning (?)
#         # Ttstartinv = torch.inverse(Ttstart+1e-5*torch.eye(Ttstart.shape[1],device=x.device).flatten().repeat(Ttstart.shape[0]).reshape(Ttstart.shape))
#         tstop_indices = torch.flatten(torch.cat([(torch.unique(torch.cat([torch.tensor([0.],device=x.device),tstart,tstop])) == time).nonzero() for time in tstop]))
#         Ttstop = out[tstop_indices,[i for i in range(tstop.shape[0])],:self.num_probs].reshape((tstop.shape[0],self.trans_dim,self.trans_dim))
#         S = torch.bmm(Ttstartinv,Ttstop)
#         S = torch.cat([S[i:i+1,from_state[i]-1,from_state[i]-1] for i in range(len(from_state))])
        
        # get lambda at tstop
        net_in = torch.cat((torch.cat([out[tstop_indices[i],i:i+1,:] for i in range(len(tstop))]),self.encoder(x),tstop.reshape(-1,1)),1)
        qvec = torch.nn.functional.softplus(self.odeblock.odefunc.net(net_in)[:,:self.number_of_hazards],beta=self.odeblock.odefunc.softplus_beta)
        q = torch.zeros(self.trans_dim, self.trans_dim,device=x.device).repeat((x.shape[0],1,1))
        q[self.transition_matrix.repeat((x.shape[0],1,1))==1] = qvec.flatten()
        lam = torch.cat([q[t:t+1,from_state[t]-1,to_state[t]-1] for t in range(len(from_state))])
        # get all augmented hazards at the final time (t=multiplier) for loss term
        net_in = torch.cat((out[-1,:,:],self.encoder(x),torch.tensor([max(tstop)],device=x.device).repeat(tstop.reshape(-1,1).shape)),1)
        out = self.odeblock.odefunc.net(net_in)
        all_hazards_T = torch.cat((torch.nn.functional.softplus(out[:,:self.number_of_hazards],beta=self.odeblock.odefunc.softplus_beta),out[:,self.number_of_hazards:]),-1)
        
        return (S,lam,all_hazards_T)
    
    def predict(self,x,tvec):
        """
            Prediction of survival based on covariates x at times in tvec.
            This function returns the transition matrix P_ij(0,t) at every t in tvec.
        """
        with torch.no_grad():
            out = self.odeblock(self.encoder(x),x,tvec.float().to(x.device))
            T = out[:,:,:self.odeblock.odefunc.num_probs].reshape((tvec.shape[0],x.shape[0],self.trans_dim,self.trans_dim))
        return T
    
    def predict_hazard(self,x,tvec):
        """
            Predict cause specific hazard function based on covariates x at times in tvec.
            This function returns the matrix Q of instantaneous hazards over time.
        """
        with torch.no_grad():
            tvec = tvec.float().to(x.device)
            out = self.odeblock(self.encoder(x),x,tvec)
            Qvec = torch.zeros((tvec.shape[0],x.shape[0],self.trans_dim,self.trans_dim),device=x.device)
            for i in range(tvec.shape[0]):
                net_in = torch.cat((out[i,:,:],tvec[i].repeat(x.shape)),1)
                temp = self.odeblock.odefunc.net(net_in)
                qvec = torch.nn.functional.softplus(temp[:,:self.number_of_hazards],beta=self.odeblock.odefunc.softplus_beta)
                Q = torch.zeros(self.trans_dim, self.trans_dim,device=x.device).repeat((x.shape[0],1,1))
                Q[self.transition_matrix.repeat((x.shape[0],1,1))==1] = qvec.flatten()
                Q[torch.eye(self.trans_dim, self.trans_dim,device=x.device).repeat((x.shape[0],1,1)) == 1] = -torch.sum(Q,2).flatten()
                Qvec[i,:,:,:] = Q
        return Qvec
    
    def predict_cumhazard(self,x,tvec):
        """
            Predict cumulative hazard function based on covariates x at times in tvec.
            The cumulative cause specific hazards are given as the integral from 0 to t over the cause specific hazards.
            This function returns a vector of cause specific cumulative hazards over time.
        """
        with torch.no_grad():
            tvec = tvec.float().to(x.device)
            tvec = torch.unique(torch.cat([torch.tensor([0.],device=x.device),tvec]))
            out = self.odeblock(self.encoder(x),x,tvec)
            qvec = out[:,:,(2*self.num_probs):(2*self.num_probs+self.number_of_hazards)]
            Qvec = torch.zeros((tvec.shape[0],x.shape[0],self.trans_dim,self.trans_dim),device=x.device)
            Qvec[self.transition_matrix.repeat((tvec.shape[0],x.shape[0],1,1))==1] = qvec.flatten()
        return Qvec

            
def loss(odesurv,x,Tstart,Tstop,From,To,trans,status,mu=1e-4):
    """
        Loss function
        Parameter mu regulates the influence of the Lyapunov loss
    """
    trans_exist = torch.tensor([odesurv.transition_matrix[From[i]-1,To[i]-1] for i in range(len(From))])
    trans_exist = torch.where(trans_exist==1)
    x = x[trans_exist]
    Tstart = Tstart[trans_exist]
    Tstop = Tstop[trans_exist]
    From = From[trans_exist]
    To = To[trans_exist]
    status = status[trans_exist]
    
    S,lam,all_h_T = odesurv(x,Tstart,Tstop,From,To)
    loglik = -(status*torch.log(lam)+torch.log(S)).mean()
    reg = torch.norm(all_h_T,2,dim=1).mean()

    return (loglik + mu*reg), loglik, reg

In [5]:
def measures(odesurv,initial,x,Tstart,Tstop,From,To,trans,status, multiplier=1.,points=500):
    with torch.no_grad():
        time_grid = np.linspace(0, multiplier, points)
        pvec = torch.zeros((points,x.shape[0]))
        surv_ode = odesurv.predict(x,torch.from_numpy(np.linspace(0,multiplier,points)).float().to(x.device))
        pvec = torch.einsum("ilkj,k->ilj",(surv_ode[:,:,:,:],initial))[:,:,0].cpu()
        pvec = np.array(pvec.cpu().detach())
        surv_ode_df = pd.DataFrame(pvec)
        surv_ode_df.loc[:,"time"] = np.linspace(0,multiplier,points)
        surv_ode_df = surv_ode_df.set_index(["time"])
        ev_ode = EvalSurv(surv_ode_df, np.array(Tstop.cpu()), np.array(status.cpu()), censor_surv='km')
        conc = ev_ode.concordance_td('antolini')
        ibs = ev_ode.integrated_brier_score(time_grid)
        inbll = ev_ode.integrated_nbll(time_grid)
    return conc,ibs,inbll

# Metabrics

In [8]:
from sklearn_pandas import DataFrameMapper
import pandas as pd

def make_dataloader(df,Tmax,batchsize):
    cols_standardize = ['x0', 'x1', 'x2', 'x3', 'x8']
    cols_leave = ['x4', 'x5', 'x6', 'x7']

    standardize = [([col], StandardScaler()) for col in cols_standardize]
    leave = [(col, None) for col in cols_leave]

    x_mapper = DataFrameMapper(standardize + leave)
    X = x_mapper.fit_transform(df).astype('float32')
    
    X = torch.from_numpy(X).float().to(device)
    T = torch.from_numpy(df[["duration"]].values).float().flatten().to(device)
    T = T/Tmax
    T[T==0] = 1e-8
    E = torch.from_numpy(df[["event"]].values).float().flatten().to(device)

    Tstart = torch.from_numpy(np.array([0 for i in range(T.shape[0])])).float().to(device)
    From = torch.tensor([1],device=device).repeat((T.shape))
    To = torch.tensor([2],device=device).repeat((T.shape))
    trans = torch.tensor([1],device=device).repeat((T.shape))

    dataset = TensorDataset(X,Tstart,T,From,To,trans,E)
    loader = DataLoader(dataset, batch_size=batchsize, shuffle=True)
    return loader

In [13]:
from sklearn.model_selection import train_test_split

def odesurv_manual_benchmark(df_train, df_test,config,name):
    torch.cuda.empty_cache()
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.loc[:,"event"])
    
    Tmax = df_train["duration"].max()
    
    train_loader = make_dataloader(df_train,Tmax/config["multiplier"],int(len(df_train)*config["batch_size"]))
    val_loader = make_dataloader(df_val,Tmax/config["multiplier"],len(df_val))
    test_loader = make_dataloader(df_test,Tmax/config["multiplier"],len(df_test))
    
    num_in = 9
    num_latent = config["num_latent"]
    layers_encoder =  [config["encoder_neurons"]]*config["num_encoder_layers"]
    dropout_encoder = [config["encoder_dropout"]]*config["num_encoder_layers"]
    layers_odefunc =  [config["odefunc_neurons"]]*config["num_odefunc_layers"]

    trans_matrix = torch.tensor([[np.nan,1],[np.nan,np.nan]]).to(device)

    encoder = Encoder(num_in,num_latent,layers_encoder, dropout_encoder).to(device)
    odefunc = ODEFunc(trans_matrix,num_in,num_latent,layers_odefunc,config["softplus_beta"]).to(device)
    block = ODEBlock(odefunc).to(device)
    odesurv = SurvNODE(block,encoder).to(device)

    optimizer = torch.optim.Adam(odesurv.parameters(), weight_decay = config["weight_decay"], lr=config["lr"])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=config["scheduler_gamma"], patience=config["scheduler_epoch"])

    early_stopping = EarlyStopping(name=name,patience=config["patience"], verbose=True)
    for i in tqdm(range(1000)):
        odesurv.train()
        for mini,ds in enumerate(train_loader):
            myloss,_,_ = loss(odesurv,*ds,mu=config["mu"])
            optimizer.zero_grad()
            myloss.backward()    
            optimizer.step()

        
        odesurv.eval()
        with torch.no_grad():
            lossval,conc,ibs,ibnll = 0., 0., 0., 0.
            for _,ds in enumerate(val_loader):
                t1,_,_ = loss(odesurv,*ds,mu=config["mu"])
                lossval += t1.item()
                t1,t2,t3 = measures(odesurv,torch.tensor([1.,0.],device=device),*ds,multiplier=config["multiplier"])
                conc += t1
                ibs += t2
                ibnll += t3
            early_stopping(lossval/len(val_loader), odesurv)
            scheduler.step(lossval/len(val_loader))
            
            conc_test,ibs_test,ibnll_test = 0., 0., 0.
            print("it: "+str(i)+", validation loss="+str(lossval/len(val_loader))+", c="+str(conc/len(val_loader))+", ibs="+str(ibs/len(val_loader))+", ibnll="+str(ibnll/len(val_loader)))

        if early_stopping.early_stop:
            print("Early stopping")
            break

    odesurv.load_state_dict(torch.load(name+'_checkpoint.pt'))

    odesurv.eval()
    with torch.no_grad():
        conc,ibs,ibnll = 0., 0., 0.
        for _,ds in enumerate(test_loader):
            t1,t2,t3 = measures(odesurv,torch.tensor([1.,0.],device=device),*ds,multiplier=config["multiplier"])
            conc += t1
            ibs += t2
            ibnll += t3
    return conc/len(test_loader), ibs/len(test_loader), ibnll/len(test_loader)

In [18]:
from sklearn.model_selection import StratifiedKFold
from pycox import datasets

kfold = StratifiedKFold(5, shuffle=True)
df_all = datasets.metabric.read_df()
gen = kfold.split(df_all.iloc[:,df_all.columns.values!="event"],df_all.loc[:,"event"])

config = {
    "lr": 1e-4,
    "weight_decay": 1e-3,
    "num_latent": 25,
    "encoder_neurons": 800,
    "num_encoder_layers": 2,
    "encoder_dropout": 0.1,
    "odefunc_neurons": 1000,
    "num_odefunc_layers": 3,
    "batch_size": 1/3,
    "multiplier": 3.,
    "mu": 1e-4,
    "softplus_beta": 1.,
    "scheduler_epoch": 50,
    "scheduler_gamma": 0.1,
    "patience": 20
}

odesurv_bench_vals = []
for g in gen:
    df_train = df_all.iloc[g[0]]
    df_test =  df_all.iloc[g[1]]
    conc, ibs, ibnll = odesurv_manual_benchmark(df_train,df_test,config,"metabrick_test")
    odesurv_bench_vals.append([conc,ibs,ibnll])

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

Validation loss decreased (inf --> 0.886843).  Saving model ...
it: 0, validation loss=0.8868432641029358, c=0.6578368497691854, ibs=0.1642798065425157, ibnll=0.498292686670822
Validation loss decreased (0.886843 --> 0.879550).  Saving model ...
it: 1, validation loss=0.8795498013496399, c=0.6590141586888496, ibs=0.1642293558857981, ibnll=0.49958194464109623
Validation loss decreased (0.879550 --> 0.873906).  Saving model ...
it: 2, validation loss=0.8739055395126343, c=0.6512377234563311, ibs=0.16426601813039846, ibnll=0.5002501762430546
Validation loss decreased (0.873906 --> 0.865098).  Saving model ...
it: 3, validation loss=0.8650981187820435, c=0.6407348886203799, ibs=0.1624251734706775, ibnll=0.4957305165253631
Validation loss decreased (0.865098 --> 0.853708).  Saving model ...
it: 4, validation loss=0.8537078499794006, c=0.627691545063048, ibs=0.15906297169157135, ibnll=0.4869142390563532
Validation loss decreased (0.853708 --> 0.841377).  Saving model ...
it: 5, validation lo

EarlyStopping counter: 19 out of 20
it: 50, validation loss=0.8102954626083374, c=0.6921337175078229, ibs=0.15710639792846828, ibnll=0.47550219767116625
EarlyStopping counter: 20 out of 20
it: 51, validation loss=0.8004456162452698, c=0.6918238993710691, ibs=0.15013440134176134, ibnll=0.45619407074613494
Early stopping



HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

Validation loss decreased (inf --> 0.969190).  Saving model ...
it: 0, validation loss=0.9691903591156006, c=0.6380169867060561, ibs=0.19695217075825364, ibnll=0.5743237405096879
Validation loss decreased (0.969190 --> 0.951664).  Saving model ...
it: 1, validation loss=0.9516641497612, c=0.6406326932545544, ibs=0.19180906068787804, ibnll=0.5619475653307368
Validation loss decreased (0.951664 --> 0.938198).  Saving model ...
it: 2, validation loss=0.938197910785675, c=0.6402941900541606, ibs=0.1880617959061868, ibnll=0.5531715458664861
Validation loss decreased (0.938198 --> 0.925869).  Saving model ...
it: 3, validation loss=0.9258686304092407, c=0.6354012801575578, ibs=0.18510361817857857, ibnll=0.5461399774097812
Validation loss decreased (0.925869 --> 0.914031).  Saving model ...
it: 4, validation loss=0.9140310883522034, c=0.6324778434268833, ibs=0.18269263148253603, ibnll=0.540119031461439
Validation loss decreased (0.914031 --> 0.903365).  Saving model ...
it: 5, validation loss

EarlyStopping counter: 17 out of 20
it: 50, validation loss=0.8534154295921326, c=0.6904849827671098, ibs=0.17198598286196362, ibnll=0.5093434665665905
EarlyStopping counter: 18 out of 20
it: 51, validation loss=0.8541889786720276, c=0.6878692762186115, ibs=0.17345559406606692, ibnll=0.5124162198150235
EarlyStopping counter: 19 out of 20
it: 52, validation loss=0.8682430386543274, c=0.6866999015263417, ibs=0.17774390440251786, ibnll=0.527768964546573
EarlyStopping counter: 20 out of 20
it: 53, validation loss=0.8698647618293762, c=0.6857459379615952, ibs=0.17917397360328216, ibnll=0.5279673268337748
Early stopping



HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

Validation loss decreased (inf --> 0.925974).  Saving model ...
it: 0, validation loss=0.9259738922119141, c=0.6282823867361913, ibs=0.16744495257328829, ibnll=0.505397543270754
Validation loss decreased (0.925974 --> 0.913842).  Saving model ...
it: 1, validation loss=0.9138419032096863, c=0.6349642489149779, ibs=0.1656934450120007, ibnll=0.5026219763640555
Validation loss decreased (0.913842 --> 0.902492).  Saving model ...
it: 2, validation loss=0.9024922847747803, c=0.6380866144190839, ibs=0.16407102801977297, ibnll=0.49945339384006915
Validation loss decreased (0.902492 --> 0.888333).  Saving model ...
it: 3, validation loss=0.8883326649665833, c=0.6362444187716614, ibs=0.1612390228089666, ibnll=0.4925897473305442
Validation loss decreased (0.888333 --> 0.872614).  Saving model ...
it: 4, validation loss=0.8726138472557068, c=0.6355262747057171, ibs=0.1577597809511768, ibnll=0.48357102070836033
Validation loss decreased (0.872614 --> 0.857884).  Saving model ...
it: 5, validation 

EarlyStopping counter: 16 out of 20
it: 49, validation loss=0.8077951669692993, c=0.6795516283136104, ibs=0.15263318620534955, ibnll=0.46805898388930967
EarlyStopping counter: 17 out of 20
it: 50, validation loss=0.8085388541221619, c=0.6780528928716395, ibs=0.15282452749143882, ibnll=0.47017549135525577
EarlyStopping counter: 18 out of 20
it: 51, validation loss=0.8105988502502441, c=0.6769288412901614, ibs=0.15177586051466127, ibnll=0.4698440423952807
EarlyStopping counter: 19 out of 20
it: 52, validation loss=0.8149831891059875, c=0.6768351703250383, ibs=0.15557607267929938, ibnll=0.47574658550635157
EarlyStopping counter: 20 out of 20
it: 53, validation loss=0.8202458620071411, c=0.6808942454803759, ibs=0.15694789578665178, ibnll=0.47847663070005925
Early stopping



HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

Validation loss decreased (inf --> 0.883227).  Saving model ...
it: 0, validation loss=0.8832269906997681, c=0.61139815249726, ibs=0.17484855283106804, ibnll=0.5226598268959193
Validation loss decreased (0.883227 --> 0.874571).  Saving model ...
it: 1, validation loss=0.8745711445808411, c=0.6167840926882731, ibs=0.1730970687156704, ibnll=0.5190862526259866
Validation loss decreased (0.874571 --> 0.866605).  Saving model ...
it: 2, validation loss=0.8666046857833862, c=0.6199154532644434, ibs=0.17134838994084223, ibnll=0.5151493836007667
Validation loss decreased (0.866605 --> 0.856726).  Saving model ...
it: 3, validation loss=0.8567256927490234, c=0.6223266009080946, ibs=0.168623848345707, ibnll=0.5083971347664455
Validation loss decreased (0.856726 --> 0.844075).  Saving model ...
it: 4, validation loss=0.8440749645233154, c=0.6246438077344606, ibs=0.16475663342694616, ibnll=0.49825757935017645
Validation loss decreased (0.844075 --> 0.833154).  Saving model ...
it: 5, validation lo

EarlyStopping counter: 17 out of 20
it: 50, validation loss=0.8597412109375, c=0.6462188821042744, ibs=0.18639451274299226, ibnll=0.544071290281756
EarlyStopping counter: 18 out of 20
it: 51, validation loss=0.8254054188728333, c=0.648254266478785, ibs=0.17547225914576578, ibnll=0.5196281598252479
EarlyStopping counter: 19 out of 20
it: 52, validation loss=0.8574323654174805, c=0.6446532018161891, ibs=0.18644548010390183, ibnll=0.5431163717446457
EarlyStopping counter: 20 out of 20
it: 53, validation loss=0.8328456282615662, c=0.6471269766713638, ibs=0.17899096251831356, ibnll=0.5258723474377706
Early stopping



HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

Validation loss decreased (inf --> 0.920997).  Saving model ...
it: 0, validation loss=0.9209967255592346, c=0.5847460225008431, ibs=0.1707706373454809, ibnll=0.5122529842311946
Validation loss decreased (0.920997 --> 0.920639).  Saving model ...
it: 1, validation loss=0.9206393361091614, c=0.5893442874222127, ibs=0.17054967250163733, ibnll=0.5115441566002966
Validation loss decreased (0.920639 --> 0.919600).  Saving model ...
it: 2, validation loss=0.9196001291275024, c=0.5894975629195917, ibs=0.17014371968946895, ibnll=0.5101185053152638
Validation loss decreased (0.919600 --> 0.918890).  Saving model ...
it: 3, validation loss=0.9188899993896484, c=0.5935746911498728, ibs=0.16983432993018863, ibnll=0.5088977251952463
Validation loss decreased (0.918890 --> 0.917291).  Saving model ...
it: 4, validation loss=0.917290985584259, c=0.5968547867937831, ibs=0.16939227077640998, ibnll=0.5074904377499211
Validation loss decreased (0.917291 --> 0.913012).  Saving model ...
it: 5, validation 

Validation loss decreased (0.821971 --> 0.819276).  Saving model ...
it: 49, validation loss=0.8192763924598694, c=0.662518009870942, ibs=0.1528380616550121, ibnll=0.463059390873329
EarlyStopping counter: 1 out of 20
it: 50, validation loss=0.8198210597038269, c=0.6629778363630789, ibs=0.1546636756482239, ibnll=0.4680279871580439
EarlyStopping counter: 2 out of 20
it: 51, validation loss=0.8274588584899902, c=0.6650623831274333, ibs=0.15843708098244075, ibnll=0.47819486065775324
EarlyStopping counter: 3 out of 20
it: 52, validation loss=0.8418700695037842, c=0.6621194935777567, ibs=0.16489022098532138, ibnll=0.4937049073224103
EarlyStopping counter: 4 out of 20
it: 53, validation loss=0.8457547426223755, c=0.6598816713160234, ibs=0.1670082862698515, ibnll=0.49814121308549714
EarlyStopping counter: 5 out of 20
it: 54, validation loss=0.8321608304977417, c=0.6551914410962264, ibs=0.15903152442515037, ibnll=0.47798056855422705
EarlyStopping counter: 6 out of 20
it: 55, validation loss=0.8

In [19]:
print("c="+str(np.mean(np.array(odesurv_bench_vals)[:,0]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,0])))
print("ibs="+str(np.mean(np.array(odesurv_bench_vals)[:,1]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,1])))
print("ibnll="+str(np.mean(np.array(odesurv_bench_vals)[:,2]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,2])))

c=0.6558965225605643+-0.012243617513931153
ibs=0.16124182424646324+-0.008117888646458552
ibnll=0.48242733142488453+-0.022724151456051847


#### Tune

In [10]:
from datetime import datetime
from sklearn.model_selection import train_test_split
import os

def odesurv_benchmark(df_train, df_test, name, fold_num):
    torch.cuda.empty_cache()
    df_train, df_val = train_test_split(df_train, test_size=0.2)
    
    Tmax = df_train["duration"].max()*1.2
    
    def train_model(config):
        train_loader = make_dataloader(df_train,int(len(df_train)*config["batch_size"]))
        val_loader = make_dataloader(df_val,len(df_val))
        
        num_in = 9
        num_latent = config["num_latent"]
        layers_encoder =  [config["encoder_neurons"]]*config["num_encoder_layers"]
        dropout_encoder = [config["encoder_dropout"]]*config["num_encoder_layers"]
        layers_odefunc =  [config["odefunc_neurons"]]*config["num_odefunc_layers"]
        dropout_odefunc = []
        
        trans_matrix = torch.tensor([[np.nan,1],[np.nan,np.nan]]).to(device)

        encoder = Encoder(num_in,num_latent,layers_encoder, dropout_encoder).to(device)
        odefunc = ODEFunc(trans_matrix,num_in,num_latent,layers_odefunc,dropout_odefunc,config["softplus_beta"]).to(device)
        block = ODEBlock(odefunc).to(device)
        odesurv = ODEsurv(block,encoder,Tmax,config["multiplier"]).to(device)

        optimizer = torch.optim.Adam(odesurv.parameters(), weight_decay = config["weight_decay"], lr=config["lr"])
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, config["scheduler_epoch"], gamma=config["scheduler_gamma"])
#         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=config["scheduler_gamma"], patience=config["scheduler_epoch"])

        loss_min = np.inf
        for i in range(1000):
            odesurv.train()
            for mini,ds in enumerate(train_loader):
                myloss,_,_ = loss(odesurv,*ds,mu=config["mu"])
                optimizer.zero_grad()
                myloss.backward()    
                optimizer.step()
       
            odesurv.eval()
            with torch.no_grad():
                train_loss = 0.
                for mini,ds in enumerate(train_loader):
                    myloss,_,_ = loss(odesurv,*ds,mu=config["mu"])
                    train_loss += myloss.item()
                lossval,conc,ibs,ibnll = 0., 0., 0., 0.
                for _,ds in enumerate(val_loader):
                    t1,_,_ = loss(odesurv,*ds,mu=config["mu"])
                    lossval += t1.item()
                    t1,t2,t3 = measures(odesurv,torch.tensor([1.,0.],device=device),*ds)
                    conc += t1
                    ibs += t2
                    ibnll += t3
                track.log(concordance=conc/len(val_loader),int_brier_score=ibs/len(val_loader),int_bin_nll = ibnll/len(val_loader), loss=lossval/len(val_loader), train_loss=train_loss/len(val_loader))
                scheduler.step(lossval/len(val_loader))
                if lossval/len(val_loader) < loss_min:
                    torch.save(odesurv.state_dict(), 'checkpoint.pt')
                    
    
    dt_string = name+"_fold_"+str(fold_num)+datetime.now().strftime("_date_%d_%m_%Y_time_%H_%M_%S")
    analysis = tune.run(train_model, name=dt_string, num_samples=4, stop={"training_iteration": 100}, config=search_space, resources_per_trial={"cpu": 1, "gpu": 1}, verbose=1, raise_on_failed_trial=False)
#     analysis = tune.run(train_model, name=dt_string, num_samples=300, scheduler=ASHAScheduler(metric="loss", mode="min"), config=search_space, resources_per_trial={"cpu": 1, "gpu": 0.25}, verbose=1, raise_on_failed_trial=False)
#     analysis = tune.run(train_model, name=dt_string, num_samples=50, scheduler=ASHAScheduler(metric="concordance", mode="max", max_t=80, grace_period=20, reduction_factor=2, brackets=3), config=search_space, resources_per_trial={"cpu": 1, "gpu": 0.25}, verbose=1, raise_on_failed_trial=False)
    best_config = analysis.get_best_config(metric="concordance", mode="max")
    logdir = analysis.get_best_logdir("concordance", mode="max")
    
    num_in = 9
    num_latent = best_config["num_latent"]
    layers_encoder =  [best_config["encoder_neurons"]]*best_config["num_encoder_layers"]
    dropout_encoder = [best_config["encoder_dropout"]]*best_config["num_encoder_layers"]
    layers_odefunc =  [best_config["odefunc_neurons"]]*best_config["num_odefunc_layers"]
    dropout_odefunc = []

    trans_matrix = torch.tensor([[np.nan,1],[np.nan,np.nan]]).to(device)

    encoder = Encoder(num_in,num_latent,layers_encoder, dropout_encoder).to(device)
    odefunc = ODEFunc(trans_matrix,num_in,num_latent,layers_odefunc,dropout_odefunc,best_config["softplus_beta"]).to(device)
    block = ODEBlock(odefunc).to(device)
    odesurv = ODEsurv(block,encoder,Tmax,best_config["multiplier"]).to(device)
    
    odesurv.load_state_dict(torch.load(os.path.join(logdir, "checkpoint.pt")))

    loader = make_dataloader(df_test,len(df_test))
    odesurv.eval()
    with torch.no_grad():
        conc,ibs,ibnll = 0., 0., 0.
        for _,ds in enumerate(loader):
            t1,t2,t3 = measures(odesurv,torch.tensor([1.,0.],device=device),*ds)
            conc += t1
            ibs += t2
            ibnll += t3
    return conc/len(loader), ibs/len(loader), ibnll/len(loader)

In [11]:
from ray.tune import track
from ray.tune.schedulers import ASHAScheduler

# search_space = {
#     "lr": tune.loguniform(5e-5,1e-3),
#     "weight_decay": tune.loguniform(1e-7,1e-2),
#     "num_latent": tune.randint(10,150),
#     "encoder_neurons": 20,
#     "num_encoder_layers": 1,
#     "encoder_dropout": 0.2,
#     "odefunc_neurons": tune.randint(100,1000),
#     "num_odefunc_layers": tune.randint(1,3),
#     "batch_size": 512,
#     "multiplier": 1.,
#     "mu": 1e-4,
#     "softplus_beta": 1.,
#     "scheduler_epoch": 10,
#     "scheduler_gamma": tune.uniform(0.1,1.)
# }

# search_space = {
#     "lr": 4e-4,
#     "weight_decay": 1e-7,
#     "num_latent": 200,
#     "encoder_neurons": 50,
#     "num_encoder_layers": 2,
#     "encoder_dropout": 0.2,
#     "odefunc_neurons": tune.grid_search([400,800]),
#     "num_odefunc_layers": tune.grid_search([1,2]),
#     "batch_size": 1/3,
#     "multiplier": 3.,
#     "mu": 0.1,
#     "softplus_beta": 1.,
#     "scheduler_epoch": 20,
#     "scheduler_gamma": 0.5
# }

search_space = {
    "lr": tune.grid_search([1e-4,2e-4]),
    "weight_decay": 1e-3,
    "num_latent": 200,
    "encoder_neurons": 50,
    "num_encoder_layers": 2,
    "encoder_dropout": 0.,
    "odefunc_neurons": 800,
    "num_odefunc_layers": 2,
    "batch_size": 1/4,
    "multiplier": 1.,
    "mu": 0.1,
    "softplus_beta": 1.,
    "scheduler_epoch": 20,
    "scheduler_gamma": 0.1
}

In [12]:
import ray
ray.shutdown()

In [13]:
from sklearn.model_selection import KFold
from pycox import datasets

kfold = KFold(5, shuffle=True)
df_all = datasets.metabric.read_df()
gen = kfold.split(df_all)


ray.init(webui_host='0.0.0.0')

odesurv_bench_vals = []
fold_num = 1
for g in gen:
    df_train = df_all.iloc[g[0]]
    df_test =  df_all.iloc[g[1]]
    conc, ibs, ibnll = odesurv_benchmark(df_train,df_test,"metabric_stratkfold_test",fold_num)
    fold_num+=1
    odesurv_bench_vals.append([conc,ibs,ibnll])

Trial name,status,loc,lr,iter,total time (s)
train_model_00000,RUNNING,10.142.0.7:5997,0.0001,0.0,22.5935
train_model_00001,RUNNING,,0.0002,,
train_model_00002,RUNNING,,0.0001,,
train_model_00003,RUNNING,,0.0002,,
train_model_00004,PENDING,,0.0001,,
train_model_00005,PENDING,,0.0002,,
train_model_00006,PENDING,,0.0001,,
train_model_00007,PENDING,,0.0002,,


KeyboardInterrupt: 

In [57]:
print("c="+str(np.mean(np.array(odesurv_bench_vals)[:,0]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,0])))
print("ibs="+str(np.mean(np.array(odesurv_bench_vals)[:,1]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,1])))
print("ibnll="+str(np.mean(np.array(odesurv_bench_vals)[:,2]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,2])))

c=0.6620480281168408+-0.014956263887575002
ibs=0.16545908680706367+-0.006979003894020888
ibnll=0.4956326341833181+-0.017867351361639697


In [188]:
print("c="+str(np.mean(np.array(odesurv_bench_vals)[:,0]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,0])))
print("ibs="+str(np.mean(np.array(odesurv_bench_vals)[:,1]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,1])))
print("ibnll="+str(np.mean(np.array(odesurv_bench_vals)[:,2]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,2])))

c=0.6672544577555329+-0.018223778302013317
ibs=0.15723787183862553+-0.013663084535737445
ibnll=0.47655865936706887+-0.04510651962063296


In [115]:
print("c="+str(np.mean(np.array(odesurv_bench_vals)[:,0]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,0])))
print("ibs="+str(np.mean(np.array(odesurv_bench_vals)[:,1]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,1])))
print("ibnll="+str(np.mean(np.array(odesurv_bench_vals)[:,2]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,2])))

c=0.6683081896353393+-0.01586286779252497
ibs=0.157976523120051+-0.007545165226842307
ibnll=0.4741855078212528+-0.02476666721566905


# ROT

In [189]:
from sklearn_pandas import DataFrameMapper
import pandas as pd

def make_dataloader(df,Tmax,batchsize):
    cols_standardize = ['x3', 'x4', 'x5', 'x6']
    cols_leave = ["x0","x1","x2"]

    standardize = [([col], StandardScaler()) for col in cols_standardize]
    leave = [(col, None) for col in cols_leave]

    x_mapper = DataFrameMapper(standardize + leave)
    X = x_mapper.fit_transform(df).astype('float32')
    
    X = torch.from_numpy(X).float().to(device)
    T = torch.from_numpy(df[["duration"]].values).float().flatten().to(device)
    Tmax = torch.tensor(Tmax).to(device)
    T = T/Tmax
    T[T==0] = 1e-8
    E = torch.from_numpy(df[["event"]].values).float().flatten().to(device)

    Tstart = torch.from_numpy(np.array([0 for i in range(T.shape[0])])).float().to(device)
    From = torch.tensor([1],device=device).repeat((T.shape))
    To = torch.tensor([2],device=device).repeat((T.shape))
    trans = torch.tensor([1],device=device).repeat((T.shape))

    dataset = TensorDataset(X,Tstart,T,From,To,trans,E)
    loader = DataLoader(dataset, batch_size=batchsize, shuffle=True)
    return loader

In [190]:
from ray.tune import track
from ray.tune.schedulers import ASHAScheduler

# search_space = {
#     "lr": tune.loguniform(5e-5,1e-3),
#     "weight_decay": tune.loguniform(1e-7,1e-2),
#     "num_latent": tune.randint(10,150),
#     "encoder_neurons": 20,
#     "num_encoder_layers": 1,
#     "encoder_dropout": 0.2,
#     "odefunc_neurons": tune.randint(100,1000),
#     "num_odefunc_layers": tune.randint(1,3),
#     "batch_size": 512,
#     "multiplier": 1.,
#     "mu": 1e-4,
#     "softplus_beta": 1.,
#     "scheduler_epoch": 10,
#     "scheduler_gamma": tune.uniform(0.1,1.)
# }

search_space = {
    "lr": tune.grid_search([2e-4,1e-4,8e-5]),
    "weight_decay": 1e-2,
    "num_latent": 50,
    "encoder_neurons": 50,
    "num_encoder_layers": 2,
    "encoder_dropout": 0.2,
    "odefunc_neurons": tune.grid_search([400,800]),
    "num_odefunc_layers": tune.grid_search([1,2]),
    "batch_size": 512,
    "multiplier": 2.,
    "mu": 1e-4,
    "softplus_beta": 1.,
    "scheduler_epoch": 10,
    "scheduler_gamma": 0.3
}

In [191]:
from datetime import datetime
from sklearn.model_selection import train_test_split
import os

def odesurv_benchmark(df_train, df_test, name):
    torch.cuda.empty_cache()
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.loc[:,"event"])
    
    Tmax = df_train["duration"].max()*1.2
    
    def train_model(config):
        train_loader = make_dataloader(df_train,Tmax/config["multiplier"],int(len(df_train)/3))
        val_loader = make_dataloader(df_val,Tmax/config["multiplier"],len(df_val))
        
        num_in = 7
        num_latent = config["num_latent"]
        layers_encoder =  [config["encoder_neurons"]]*config["num_encoder_layers"]
        dropout_encoder = [config["encoder_dropout"]]*config["num_encoder_layers"]
        layers_odefunc =  [config["odefunc_neurons"]]*config["num_odefunc_layers"]
        dropout_odefunc = []
        
        trans_matrix = torch.tensor([[np.nan,1],[np.nan,np.nan]]).to(device)

        encoder = Encoder(num_in,num_latent,layers_encoder, dropout_encoder).to(device)
        odefunc = ODEFunc(trans_matrix,num_in,num_latent,layers_odefunc,dropout_odefunc,config["softplus_beta"]).to(device)
        block = ODEBlock(odefunc).to(device)
        odesurv = ODEsurv(block,encoder).to(device)

        optimizer = torch.optim.Adam(odesurv.parameters(), weight_decay = config["weight_decay"], lr=config["lr"])
#         scheduler = torch.optim.lr_scheduler.StepLR(optimizer, config["scheduler_epoch"], gamma=config["scheduler_gamma"])
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=config["scheduler_gamma"], patience=config["scheduler_epoch"])

        conc_max = - np.inf
        for i in range(1000):
            odesurv.train()
            for mini,ds in enumerate(train_loader):
                myloss,_,_ = loss(odesurv,*ds,mu=config["mu"])
                optimizer.zero_grad()
                myloss.backward()    
                optimizer.step()
       
            odesurv.eval()
            with torch.no_grad():
                train_loss = 0.
                for mini,ds in enumerate(train_loader):
                    myloss,_,_ = loss(odesurv,*ds,mu=config["mu"])
                    train_loss += myloss.item()
                lossval,conc,ibs,ibnll = 0., 0., 0., 0.
                for _,ds in enumerate(val_loader):
                    t1,_,_ = loss(odesurv,*ds,mu=config["mu"])
                    lossval += t1.item()
                    t1,t2,t3 = measures(odesurv,torch.tensor([1.,0.],device=device),Tmax,*ds,multiplier=config["multiplier"])
                    conc += t1
                    ibs += t2
                    ibnll += t3
                track.log(concordance=conc/len(val_loader),int_brier_score=ibs/len(val_loader),int_bin_nll = ibnll/len(val_loader), loss=lossval/len(val_loader), train_loss=train_loss/len(val_loader))
                scheduler.step(lossval/len(val_loader))
                if conc/len(val_loader) > conc_max:
                    torch.save(odesurv.state_dict(), 'checkpoint.pt')
                    
    
    dt_string = name+datetime.now().strftime("_date_%d_%m_%Y_time_%H_%M_%S")
    analysis = tune.run(train_model, name=dt_string, num_samples=1, stop={"training_iteration": 200}, config=search_space, resources_per_trial={"cpu": 1, "gpu": 0.33}, verbose=1, raise_on_failed_trial=False)
#     analysis = tune.run(train_model, name=dt_string, num_samples=300, scheduler=ASHAScheduler(metric="loss", mode="min"), config=search_space, resources_per_trial={"cpu": 1, "gpu": 0.25}, verbose=1, raise_on_failed_trial=False)
#     analysis = tune.run(train_model, name=dt_string, num_samples=50, scheduler=ASHAScheduler(metric="concordance", mode="max", max_t=80, grace_period=20, reduction_factor=2, brackets=3), config=search_space, resources_per_trial={"cpu": 1, "gpu": 0.25}, verbose=1, raise_on_failed_trial=False)
    best_config = analysis.get_best_config(metric="concordance", mode="max")
    logdir = analysis.get_best_logdir("concordance", mode="max")
    
    num_in = 7
    num_latent = best_config["num_latent"]
    layers_encoder =  [best_config["encoder_neurons"]]*best_config["num_encoder_layers"]
    dropout_encoder = [best_config["encoder_dropout"]]*best_config["num_encoder_layers"]
    layers_odefunc =  [best_config["odefunc_neurons"]]*best_config["num_odefunc_layers"]
    dropout_odefunc = []

    trans_matrix = torch.tensor([[np.nan,1],[np.nan,np.nan]]).to(device)

    encoder = Encoder(num_in,num_latent,layers_encoder, dropout_encoder).to(device)
    odefunc = ODEFunc(trans_matrix,num_in,num_latent,layers_odefunc,dropout_odefunc,best_config["softplus_beta"]).to(device)
    block = ODEBlock(odefunc).to(device)
    odesurv = ODEsurv(block,encoder).to(device)
    
    odesurv.load_state_dict(torch.load(os.path.join(logdir, "checkpoint.pt")))

    loader = make_dataloader(df_test,Tmax/best_config["multiplier"],len(df_test))
    odesurv.eval()
    with torch.no_grad():
        conc,ibs,ibnll = 0., 0., 0.
        for _,ds in enumerate(loader):
            t1,t2,t3 = measures(odesurv,torch.tensor([1.,0.],device=device),Tmax,*ds,multiplier=best_config["multiplier"])
            conc += t1
            ibs += t2
            ibnll += t3
    return conc/len(loader), ibs/len(loader), ibnll/len(loader)

In [149]:
import ray
ray.shutdown()

In [192]:
from sklearn.model_selection import KFold
from pycox import datasets

kfold = KFold(5, shuffle=True)
df_all = datasets.gbsg.read_df()
gen = kfold.split(df_all)

# ray.init(webui_host='0.0.0.0')

odesurv_bench_vals_rot = []
for g in gen:
    df_train = df_all.iloc[g[0]]
    df_test =  df_all.iloc[g[1]]
    conc, ibs, ibnll = odesurv_benchmark(df_train,df_test,"rot")
    odesurv_bench_vals_rot.append([conc,ibs,ibnll])

Trial name,status,loc,lr,num_latent,num_odefunc_layers,odefunc_neurons,scheduler_gamma,weight_decay,iter,total time (s)
train_model_00000,TERMINATED,,0.000743329,19,2,818,0.862176,3.69822e-07,100,3934.91
train_model_00001,TERMINATED,,0.000511077,111,1,989,0.595397,0.00495276,16,771.29
train_model_00002,TERMINATED,,0.000203526,136,2,233,0.713797,1.86807e-07,1,125.353
train_model_00003,TERMINATED,,0.000209287,134,1,508,0.89779,9.93566e-07,1,123.428
train_model_00004,TERMINATED,,0.000722869,25,1,547,0.603554,1.28015e-06,100,4216.59
train_model_00005,TERMINATED,,0.000790787,145,1,917,0.936546,1.20681e-07,4,255.307
train_model_00006,TERMINATED,,9.29906e-05,107,2,198,0.260943,3.31935e-07,1,105.203
train_model_00007,TERMINATED,,0.000123262,26,1,436,0.828129,1.01854e-06,1,99.3676
train_model_00008,TERMINATED,,8.4564e-05,114,2,388,0.587675,2.52835e-06,1,114.197
train_model_00009,TERMINATED,,9.5973e-05,61,1,800,0.355959,1.65136e-06,1,111.309


In [193]:
print("c="+str(np.mean(np.array(odesurv_bench_vals_rot)[:,0]))+"+-"+str(np.std(np.array(odesurv_bench_vals_rot)[:,0])))
print("ibs="+str(np.mean(np.array(odesurv_bench_vals_rot)[:,1]))+"+-"+str(np.std(np.array(odesurv_bench_vals_rot)[:,1])))
print("ibnll="+str(np.mean(np.array(odesurv_bench_vals_rot)[:,2]))+"+-"+str(np.std(np.array(odesurv_bench_vals_rot)[:,2])))

c=0.6793196362142764+-0.009407195337106437
ibs=0.17484367616111518+-0.0026219616036148308
ibnll=0.5192838730250453+-0.007654624145028555


# Support

In [4]:
from datetime import datetime
from sklearn.model_selection import train_test_split
import os

def odesurv_benchmark(df_train, df_test, name, fold_num):
    torch.cuda.empty_cache()
    df_train, df_val = train_test_split(df_train, test_size=0.2)
    
    Tmax = df_train["duration"].max()*1.2
    
    def train_model(config):
        train_loader = make_dataloader(df_train,int(len(df_train)*config["batch_size"]))
        val_loader = make_dataloader(df_val,len(df_val))
        
        num_in = 14
        num_latent = config["num_latent"]
        layers_encoder =  [config["encoder_neurons"]]*config["num_encoder_layers"]
        dropout_encoder = [config["encoder_dropout"]]*config["num_encoder_layers"]
        layers_odefunc =  [config["odefunc_neurons"]]*config["num_odefunc_layers"]
        dropout_odefunc = []
        
        trans_matrix = torch.tensor([[np.nan,1],[np.nan,np.nan]]).to(device)

        encoder = Encoder(num_in,num_latent,layers_encoder, dropout_encoder).to(device)
        odefunc = ODEFunc(trans_matrix,num_in,num_latent,layers_odefunc,dropout_odefunc,config["softplus_beta"]).to(device)
        block = ODEBlock(odefunc).to(device)
        odesurv = ODEsurv(block,encoder,Tmax,config["multiplier"]).to(device)

        optimizer = torch.optim.Adam(odesurv.parameters(), weight_decay = config["weight_decay"], lr=config["lr"])
#         scheduler = torch.optim.lr_scheduler.StepLR(optimizer, config["scheduler_epoch"], gamma=config["scheduler_gamma"])
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=config["scheduler_gamma"], patience=config["scheduler_epoch"])

        conc_max = -np.inf
        for i in range(1000):
            odesurv.train()
            for mini,ds in enumerate(train_loader):
                myloss,_,_ = loss(odesurv,*ds,mu=config["mu"])
                optimizer.zero_grad()
                myloss.backward()    
                optimizer.step()
       
            odesurv.eval()
            with torch.no_grad():
                train_loss = 0.
                for mini,ds in enumerate(train_loader):
                    myloss,_,_ = loss(odesurv,*ds,mu=config["mu"])
                    train_loss += myloss.item()
                lossval,conc,ibs,ibnll,reg,loglik = 0., 0., 0., 0., 0., 0.
                for _,ds in enumerate(val_loader):
                    t1,t2,t3 = loss(odesurv,*ds,mu=config["mu"])
                    lossval += t1.item()
                    loglik += t2.item()
                    reg += t3.item()
                    t1,t2,t3 = measures(odesurv,torch.tensor([1.,0.],device=device),*ds)
                    conc += t1
                    ibs += t2
                    ibnll += t3
                track.log(concordance=conc/len(val_loader),int_brier_score=ibs/len(val_loader),int_bin_nll = ibnll/len(val_loader), loss=lossval/len(val_loader), loglik = loglik/len(val_loader), reg=reg/len(val_loader), train_loss=train_loss/len(val_loader))
                scheduler.step(lossval/len(val_loader))
                if conc/len(val_loader) > conc_max:
                    torch.save(odesurv.state_dict(), 'checkpoint.pt')
                    
    
    dt_string = name+"_fold_"+str(fold_num)+datetime.now().strftime("_date_%d_%m_%Y_time_%H_%M_%S")
    analysis = tune.run(train_model, name=dt_string, num_samples=1, stop={"training_iteration": 110}, config=search_space, resources_per_trial={"cpu": 1, "gpu": 1}, verbose=1, raise_on_failed_trial=False)
#     analysis = tune.run(train_model, name=dt_string, num_samples=300, scheduler=ASHAScheduler(metric="loss", mode="min"), config=search_space, resources_per_trial={"cpu": 1, "gpu": 0.25}, verbose=1, raise_on_failed_trial=False)
#     analysis = tune.run(train_model, name=dt_string, num_samples=50, scheduler=ASHAScheduler(metric="concordance", mode="max", max_t=80, grace_period=20, reduction_factor=2, brackets=3), config=search_space, resources_per_trial={"cpu": 1, "gpu": 0.25}, verbose=1, raise_on_failed_trial=False)
    best_config = analysis.get_best_config(metric="concordance", mode="max")
    logdir = analysis.get_best_logdir("concordance", mode="max")
    
    num_in = 14
    num_latent = best_config["num_latent"]
    layers_encoder =  [best_config["encoder_neurons"]]*best_config["num_encoder_layers"]
    dropout_encoder = [best_config["encoder_dropout"]]*best_config["num_encoder_layers"]
    layers_odefunc =  [best_config["odefunc_neurons"]]*best_config["num_odefunc_layers"]
    dropout_odefunc = []

    trans_matrix = torch.tensor([[np.nan,1],[np.nan,np.nan]]).to(device)

    encoder = Encoder(num_in,num_latent,layers_encoder, dropout_encoder).to(device)
    odefunc = ODEFunc(trans_matrix,num_in,num_latent,layers_odefunc,dropout_odefunc,best_config["softplus_beta"]).to(device)
    block = ODEBlock(odefunc).to(device)
    odesurv = ODEsurv(block,encoder,Tmax,best_config["multiplier"]).to(device)
    
    odesurv.load_state_dict(torch.load(os.path.join(logdir, "checkpoint.pt")))

    loader = make_dataloader(df_test,len(df_test))
    odesurv.eval()
    with torch.no_grad():
        conc,ibs,ibnll = 0., 0., 0.
        for _,ds in enumerate(loader):
            t1,t2,t3 = measures(odesurv,torch.tensor([1.,0.],device=device),*ds)
            conc += t1
            ibs += t2
            ibnll += t3
    return conc/len(loader), ibs/len(loader), ibnll/len(loader)

In [5]:
from sklearn_pandas import DataFrameMapper
import pandas as pd

def make_dataloader(df,batchsize):
    cols_standardize = ['x0', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13']
    cols_leave = ['x1', 'x2', 'x3', 'x4', 'x5', 'x6']

    standardize = [([col], StandardScaler()) for col in cols_standardize]
    leave = [(col, None) for col in cols_leave]

    x_mapper = DataFrameMapper(standardize + leave)
    X = x_mapper.fit_transform(df).astype('float32')
    
    X = torch.from_numpy(X).float().to(device)
    T = torch.from_numpy(df[["duration"]].values).float().flatten().to(device)
    T[T==0] = 1e-8
    E = torch.from_numpy(df[["event"]].values).float().flatten().to(device)

    Tstart = torch.from_numpy(np.array([0 for i in range(T.shape[0])])).float().to(device)
    From = torch.tensor([1],device=device).repeat((T.shape))
    To = torch.tensor([2],device=device).repeat((T.shape))
    trans = torch.tensor([1],device=device).repeat((T.shape))

    dataset = TensorDataset(X,Tstart,T,From,To,trans,E)
    loader = DataLoader(dataset, batch_size=batchsize, shuffle=True)
    return loader

In [6]:
from ray.tune import track
from ray.tune.schedulers import ASHAScheduler

# search_space = {
#     "lr": tune.grid_search([6e-4,8e-4,1e-3,2e-3]),
#     "weight_decay": 1e-4,
#     "num_latent": 5,
#     "encoder_neurons": 50,
#     "num_encoder_layers": 1,
#     "encoder_dropout": 0.,
#     "odefunc_neurons": 800,
#     "num_odefunc_layers": 2,
#     "batch_size": 1/16,
#     "multiplier": 2.,
#     "mu": 0.1,
#     "softplus_beta": 1.,
#     "scheduler_epoch": 100,
#     "scheduler_gamma": 0.1
# }

search_space = {
    "lr": tune.grid_search([1e-4,5e-4,1e-3]),
    "weight_decay": tune.grid_search([1e-7,1e-5,1e-3]),
    "num_latent": 200,
    "encoder_neurons": 400,
    "num_encoder_layers": 2,
    "encoder_dropout": 0.,
    "odefunc_neurons": tune.grid_search([400,1000]),
    "num_odefunc_layers": tune.grid_search([2,4]),
    "batch_size": 1/16,
    "multiplier": 2.,
    "mu": 1e-4,
    "softplus_beta": 1.,
    "scheduler_epoch": 20,
    "scheduler_gamma": 0.1
}

In [7]:
import ray
ray.shutdown()

In [None]:
from sklearn.model_selection import KFold
from pycox import datasets

kfold = KFold(5, shuffle=True)
df_all = datasets.support.read_df()
gen = kfold.split(df_all)


ray.init(webui_host='0.0.0.0')

odesurv_bench_vals = []
fold_num = 1
for g in gen:
    df_train = df_all.iloc[g[0]]
    df_test =  df_all.iloc[g[1]]
    conc, ibs, ibnll = odesurv_benchmark(df_train,df_test,"support",fold_num)
    fold_num+=1
    odesurv_bench_vals.append([conc,ibs,ibnll])

Trial name,status,loc,lr,num_odefunc_layers,odefunc_neurons,weight_decay,iter,total time (s)
train_model_00010,ERROR,,0.0005,4,1000,1e-07,18.0,2266.83
train_model_00017,PENDING,,0.001,4,400,1e-05,,
train_model_00018,PENDING,,0.0001,2,1000,1e-05,,
train_model_00019,PENDING,,0.0005,2,1000,1e-05,,
train_model_00020,PENDING,,0.001,2,1000,1e-05,,
train_model_00021,PENDING,,0.0001,4,1000,1e-05,,
train_model_00022,PENDING,,0.0005,4,1000,1e-05,,
train_model_00023,PENDING,,0.001,4,1000,1e-05,,
train_model_00024,PENDING,,0.0001,2,400,0.001,,
train_model_00013,RUNNING,10.142.0.7:19457,0.0005,2,400,1e-05,70.0,7477.84

Trial name,# failures,error file
train_model_00010,1,"/home/jupyter/ray_results/support_fold_1_date_30_05_2020_time_05_40_31/train_model_10_lr=0.0005,num_odefunc_layers=4,odefunc_neurons=1000,weight_decay=1e-07_2020-05-30_12-58-21tln5yyf8/error.txt"


In [None]:
print("c="+str(np.mean(np.array(odesurv_bench_vals)[:,0]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,0])))
print("ibs="+str(np.mean(np.array(odesurv_bench_vals)[:,1]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,1])))
print("ibnll="+str(np.mean(np.array(odesurv_bench_vals)[:,2]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,2])))

In [91]:
print("c="+str(np.mean(np.array(odesurv_bench_vals)[:,0]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,0])))
print("ibs="+str(np.mean(np.array(odesurv_bench_vals)[:,1]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,1])))
print("ibnll="+str(np.mean(np.array(odesurv_bench_vals)[:,2]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,2])))

c=0.6153642201152687+-0.00585844798470677
ibs=0.19639579818750258+-0.003522886894064154
ibnll=0.5746934067995662+-0.008562732425823203


In [103]:
print("c="+str(np.mean(np.array(odesurv_bench_vals)[:,0]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,0])))
print("ibs="+str(np.mean(np.array(odesurv_bench_vals)[:,1]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,1])))
print("ibnll="+str(np.mean(np.array(odesurv_bench_vals)[:,2]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,2])))

c=0.6132351299915724+-0.010722358185681976
ibs=0.19424970554238233+-0.0031933519251837066
ibnll=0.5706360183291317+-0.007300024663354327


In [116]:
print("c="+str(np.mean(np.array(odesurv_bench_vals)[:,0]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,0])))
print("ibs="+str(np.mean(np.array(odesurv_bench_vals)[:,1]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,1])))
print("ibnll="+str(np.mean(np.array(odesurv_bench_vals)[:,2]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,2])))

c=0.618216622635997+-0.006597502544798807
ibs=0.1935314575770763+-0.004279757496571185
ibnll=0.5689731250634114+-0.009453219670385243


In [24]:
print("c="+str(np.mean(np.array(odesurv_bench_vals)[:,0]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,0])))
print("ibs="+str(np.mean(np.array(odesurv_bench_vals)[:,1]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,1])))
print("ibnll="+str(np.mean(np.array(odesurv_bench_vals)[:,2]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,2])))

c=0.6102313589577636+-0.0038118403064905135
ibs=0.19567585951137095+-0.002301237997946347
ibnll=0.5741010303897506+-0.00575819840005718


In [57]:
print("c="+str(np.mean(np.array(odesurv_bench_vals)[:,0]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,0])))
print("ibs="+str(np.mean(np.array(odesurv_bench_vals)[:,1]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,1])))
print("ibnll="+str(np.mean(np.array(odesurv_bench_vals)[:,2]))+"+-"+str(np.std(np.array(odesurv_bench_vals)[:,2])))

c=0.622009475948327+-0.009656146122673258
ibs=0.19717123718515026+-0.00618661731145133
ibnll=0.5829257759083994+-0.019397555262266865
