## DRSA : Deep Recurrent Survival Analysis
- paper : https://arxiv.org/pdf/1809.02403.pdf

In [1]:
import sys, os
sys.path.append(os.path.abspath('..'))

In [2]:
from drsa.functions import event_time_loss, event_rate_loss
from drsa.model import DRSA
import torch
import torch.nn as nn
import torch.optim as optim
import pycox.datasets as dt
import numpy as np

In [3]:
# For preprocessing
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper 
from torch.utils.data import Dataset
from pycox.preprocessing import label_transforms

In [4]:
class DRSA_Dataset(Dataset):
    def __init__(self, name, phase='train'):
        super().__init__()
        if name == 'support':
            cols_standardize =  ['x0', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13']
            cols_leave = ['x1', 'x2', 'x3', 'x4', 'x5', 'x6']
            max_seq_len = 25 # will be modified
        else:
            raise NotImplementedError
                        
        self.max_seq_len = max_seq_len
        self.cols_standardize = cols_standardize
        self.cols_leave = cols_leave
        
        cols_tgt = ['duration', 'event']
            
        df_full = eval(f'dt.{name}.read_df()')
        df_test = df_full.sample(frac=0.3)
        df_train = df_full.drop(df_test.index)
        df_val = df_train.sample(frac=0.1)
        df_train = df_train.drop(df_val.index)
        
        # Target info
        y_train = df_train[cols_tgt].values
        y_val = df_val[cols_tgt].values
        y_test = df_test[cols_tgt].values
        
        # Target label preprocessing
        target_parse_fn = DRSA_Dataset.get_target
        labtrans = label_transforms.LabTransDiscreteTime(max_seq_len)
        y_train  = np.c_[labtrans.fit_transform(*target_parse_fn(y_train))]
        y_val    = np.c_[labtrans.transform(*target_parse_fn(y_val))]
        y_test    = np.c_[labtrans.transform(*target_parse_fn(y_test))]
        
        # Input Covariates
        df_train = df_train.drop(cols_tgt, axis=1)
        df_val = df_val.drop(cols_tgt, axis=1)
        df_test = df_test.drop(cols_tgt, axis=1)
        
        # Input data preprocessing
        standardize = [([col], StandardScaler()) for col in cols_standardize]
        leave = [(col, None) for col in cols_leave]
        
        # assume categorical columns located first (for embedding later)
        x_mapper = DataFrameMapper(leave + standardize)
        x_train = x_mapper.fit_transform(df_train).astype('float32')
        x_val = x_mapper.transform(df_val).astype('float32')
        x_test = x_mapper.transform(df_test).astype('float32')
        
        if phase == 'train':
            self.X, self.Y = x_train, y_train
        if phase == 'val':
            self.X, self.Y = x_val, y_val
        if phase == 'test':
            self.X, self.Y = x_test, y_test
        
        setattr(self, 'labtrans', labtrans)
                    
    @property
    def numeric_columns(self):
        return self.cols_standardize
    
    @property
    def categorical_columns(self):
        return self.cols_leave
    
    @property
    def n_features(self):
        return self.X.shape[1]
    
    @property
    def n_embeddings(self):
        return [ len(np.unique(self.X[:, ix])) for ix in range(len(self.categorical_columns)) ]
    
    @staticmethod
    def get_target(y):
        durations, events = y[:, 0], y[:, 1]
        return durations, events
        
    def __getitem__(self, ix):
        x = torch.from_numpy(self.X[ix])
        y = torch.from_numpy(self.Y[ix])
        
        # Add time features on tiled covariates
        xs = x.tile(self.max_seq_len, 1)
        t = torch.arange(self.max_seq_len).unsqueeze(1)
        xs = torch.cat((xs, t), dim=1)
        
        return xs, y
    
    def __len__(self):
        return len(self.X)

In [5]:
from torch.utils.data import DataLoader

name = 'support'

train_ds = DRSA_Dataset(name, 'train')
val_ds = DRSA_Dataset(name, 'val')
test_ds = DRSA_Dataset(name, 'test')

train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=128, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=128, shuffle=False)

## instantiating embedding parameters

In [6]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [7]:
embedding_size = 5

get_embeddings = lambda n_embeddings: [torch.nn.Embedding(nb_emb, embedding_size, device=device) for nb_emb in n_embeddings ]
embeddings = get_embeddings(train_ds.n_embeddings)

## instantiating model

In [8]:
import pandas as pd
import torchtuples as tt
from pycox import models
from pycox.models.utils import pad_col
from pycox.models.interpolation import InterpolatePMF
from pycox.evaluation.concordance import concordance_td
from pycox import utils

In [25]:
class PyCoxWrapper(tt.Model):
    """Wrapper class for pycox API compatibility
    """
    _steps ='post'
    def __init__(self, net, loss=None, optimizer=None, device=None, duration_index=None):
        self.duration_index = duration_index
        super().__init__(net, loss, optimizer, device)

    @property
    def duration_index(self):
        """
        Array of durations that defines the discrete times. This is used to set the index
        of the DataFrame in `predict_surv_df`.
        
        Returns:
            np.array -- Duration index.
        """
        return self._duration_index

    @duration_index.setter
    def duration_index(self, val):
        self._duration_index = val
        
    @property
    def index_surv(self):
        return self.surv.index.values

    @property
    def steps(self):
        """How to handle predictions that are between two indexes in `index_surv`.

        For a visualization, run the following:
            ev = EvalSurv(pd.DataFrame(np.linspace(1, 0, 7)), np.empty(7), np.ones(7), steps='pre')
            ax = ev[0].plot_surv()
            ev.steps = 'post'
            ev[0].plot_surv(ax=ax, style='--')
            ax.legend(['pre', 'post'])
        """
        return self._steps

    @steps.setter
    def steps(self, steps):
        vals = ['post', 'pre']
        if steps not in vals:
            raise ValueError(f"`steps` needs to be {vals}, got {steps}")
        self._steps = steps


    def idx_at_times(self, times):
        """Get the index (iloc) of the `surv.index` closest to `times`.
        I.e. surv.loc[tims] (almost)= surv.iloc[idx_at_times(times)].

        Useful for finding predictions at given durations.
        """
        return utils.idx_at_times(self.index_surv, times, self.steps)
    
    def predict_surv(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False,
                     num_workers=0):
        pmf = self.predict_pmf(input, batch_size, False, eval_, to_cpu, num_workers)
        surv = 1 - pmf.cumsum(1)
        return tt.utils.array_or_tensor(surv, numpy, input)

    def predict_pmf(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False,
                    num_workers=0):
        preds = self.predict(input, batch_size, False, eval_, False, to_cpu, num_workers)
        pmf = pad_col(preds).softmax(1)[:, :-1]
        return tt.utils.array_or_tensor(pmf, numpy, input)

    def predict_surv_df(self, input, batch_size=8224, eval_=True, num_workers=0):
        surv = self.predict_surv(input, batch_size, True, eval_, True, num_workers)
        return pd.DataFrame(surv.transpose(), self.duration_index)

    def interpolate(self, sub=10, scheme='const_pdf', duration_index=None):
        """Use interpolation for predictions.
        There are only one scheme:
            `const_pdf` and `lin_surv` which assumes pice-wise constant pmf in each interval (linear survival).
        
        Keyword Arguments:
            sub {int} -- Number of "sub" units in interpolation grid. If `sub` is 10 we have a grid with
                10 times the number of grid points than the original `duration_index` (default: {10}).
            scheme {str} -- Type of interpolation {'const_hazard', 'const_pdf'}.
                See `InterpolateDiscrete` (default: {'const_pdf'})
            duration_index {np.array} -- Cuts used for discretization. Does not affect interpolation,
                only for setting index in `predict_surv_df` (default: {None})
        
        Returns:
            [InterpolationPMF] -- Object for prediction with interpolation.
        """
        if duration_index is None:
            duration_index = self.duration_index
        return InterpolatePMF(self, scheme, duration_index, sub)
    
    def concordance_td(self, preds, target, method='adj_antolini'):
        """Time dependent concorance index from
        Antolini, L.; Boracchi, P.; and Biganzoli, E. 2005. A time-dependent discrimination
        index for survival data. Statistics in Medicine 24:3927–3944.

        If 'method' is 'antolini', the concordance from Antolini et al. is computed.
    
        If 'method' is 'adj_antolini' (default) we have made a small modifications
        for ties in predictions and event times.
        We have followed step 3. in Sec 5.1. in Random Survival Forests paper, except for the last
        point with "T_i = T_j, but not both are deaths", as that doesn't make much sense.
        See 'metrics._is_concordant'.

        Keyword Arguments:
            method {str} -- Type of c-index 'antolini' or 'adj_antolini' (default {'adj_antolini'}).

        Returns:
            float -- Time dependent concordance index.
        """
        pmf = pad_col(preds.squeeze(2)).softmax(1)[:, :-1]
        pmf = tt.utils.array_or_tensor(pmf, True, pmf)
        
        surv = 1 - pmf.cumsum(1)
        self.surv = pd.DataFrame(surv.transpose(), self.duration_index)
        
        durations, events = np.split(target.cpu().numpy(), 2, axis=1)
        durations = durations.ravel()
        events = events.ravel()
        
        duration_idx = self.idx_at_times(durations) # need to debug
        
        from IPython.core.debugger import set_trace
        set_trace()

        return concordance_td(durations, events, self.surv.values,
                              duration_idx, method)

In [26]:
net = DRSA(
    n_features=train_ds.n_features + 1,  # +1 for time features
    hidden_dim=2,
    n_layers=1,
    embeddings=embeddings,
)
net.to(device)

DRSA(
  (lstm): LSTM(39, 2, batch_first=True)
  (fc): Linear(in_features=2, out_features=1, bias=True)
  (linear_dropout): Dropout(p=0.0, inplace=False)
  (sigmoid): Sigmoid()
  (params_to_train): ModuleList(
    (0): Embedding(2, 5)
    (1): Embedding(10, 5)
    (2): Embedding(6, 5)
    (3): Embedding(2, 5)
    (4): Embedding(2, 5)
    (5): Embedding(3, 5)
  )
)

In [27]:
class LossDRSA(nn.Module):
    def __init__(self, alpha=0.5):
        super().__init__()
        assert (alpha >= 0) and (alpha <= 1), 'Need `alpha` in [0, 1].'
        self.alpha = alpha
        
    def forward(self, preds, target=None):
        # weighted average of event_time_loss and event_rate_loss
        evt_loss = event_time_loss(preds)
        evr_loss = event_rate_loss(preds)
        loss = (self.alpha * evt_loss) + ((1 - self.alpha) * evr_loss)
        return loss    

In [28]:
labtrans = train_ds.labtrans

model = PyCoxWrapper(net, LossDRSA(), tt.optim.Adam(lr=1e-3), 
                     duration_index=labtrans.cuts)

In [29]:
log = model.fit_dataloader(train_loader, 
                           epochs=50, 
                           val_dataloader=test_loader, 
                           metrics={'Ctd': model.concordance_td}
                          )

> [1;32mc:\users\youhs\appdata\local\temp\ipykernel_17804\3556093577.py[0m(127)[0;36mconcordance_td[1;34m()[0m

ipdb> durations
array([ 3.,  7., 13.,  1.,  1.,  6., 20.,  2.,  9.,  1.,  6., 10., 15.,
        1.,  9.,  3.,  5.,  3.,  2.,  5.,  1.,  1., 17., 12.,  3.,  1.,
        7.,  3.,  5., 19.,  4.,  1.,  1.,  1.,  1.,  3.,  9.,  9., 12.,
       17., 20.,  7., 10.,  6., 17.,  5.,  7., 22.,  1., 23.,  1.,  6.,
        2., 11.,  7.,  2.,  3., 15.,  2.,  4.,  4., 16.,  6.,  1.,  4.,
        1.,  4., 18.,  1., 18.,  5.,  5.,  9., 21.,  6.,  6.,  2.,  1.,
        1., 20., 15.,  1., 18.,  1., 11.,  1.,  5.,  1.,  1.,  5.,  1.,
        1., 12.,  2.,  1.,  1.,  6.,  2.,  4.,  2.,  1.,  1.,  1., 20.,
        1.,  6., 18.,  1.,  5.,  6.,  8.,  1.,  1.,  9.,  2.,  9.,  7.,
       23.,  1.,  1.,  5.,  5.,  1.,  2., 22.,  1.,  1.,  8.])
ipdb> events
array([1., 0., 1., 1., 1., 1., 0., 1., 0., 1., 1., 0., 1., 1., 0., 1., 0.,
       1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 0., 0., 1., 1., 1

In [None]:
log.plot()