# DSM on SUPPORT Dataset

The SUPPORT dataset comes from the Vanderbilt University study
to estimate survival for seriously ill hospitalized adults.
(Refer to http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc
for the original data source.)

In this notebook, we will apply Deep Survival Machines for survival prediction on the SUPPORT data.

### Load the SUPPORT Dataset

The package includes helper functions to load the dataset.

`features` holds the covariates (X), `outcomes.time` holds the
event/censoring times (T), and `outcomes.event` is the event
indicator (E; 1 means the event was observed).

In [7]:
import sys
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Put '../' (relative to the working directory) on the import path so the
# auton_survival package can be imported.
sys.path.append("../")
# NOTE: this rebinds `datasets`, shadowing the torchvision `datasets` imported above.
from auton_survival import datasets
outcomes, features = datasets.load_support()

In [2]:
from auton_survival.preprocessing import Preprocessor
# Categorical and numerical SUPPORT columns handed to the Preprocessor.
cat_feats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']
num_feats = [
    'age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp',
    'temp', 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph',
    'glucose', 'bun', 'urine', 'adlp', 'adls',
]

features = Preprocessor().fit_transform(
    features,
    cat_feats=cat_feats,
    num_feats=num_feats,
)

### Compute horizons at which we evaluate the performance of DSM

Survival predictions are issued at certain time horizons. Here we evaluate how well
DSM predicts at the 25th, 50th and 75th quantiles of the observed event times, as is standard practice in survival analysis.

In [3]:
import numpy as np
# Evaluation times: quantiles of the observed (event == 1) event times.
horizons = [0.25, 0.5, 0.75]
times = np.quantile(
    outcomes.time[outcomes.event == 1],
    horizons,
).tolist()

### Splitting the data into train, test and validation sets

We will train DSM on 70% of the data, use a 10% validation set for model selection, and report performance on the remaining 20% held-out test set.

In [5]:
x, t, e = features.values, outcomes.time.values, outcomes.event.values

n = len(x)

tr_size = int(n*0.70)
vl_size = int(n*0.10)
te_size = int(n*0.20)

# Contiguous split in the existing row order (no shuffling): the first 70% of
# rows are used for training, the next 10% for validation and the last 20% for
# testing. Because each size is truncated with int(), a few rows lying between
# the validation and test slices may go unused.
# The test slice starts at ``n - te_size`` rather than ``-te_size`` because
# ``x[-0:]`` would return the *whole* array if te_size rounded down to 0.
x_train, x_test, x_val = x[:tr_size], x[n - te_size:], x[tr_size:tr_size+vl_size]
t_train, t_test, t_val = t[:tr_size], t[n - te_size:], t[tr_size:tr_size+vl_size]
e_train, e_test, e_val = e[:tr_size], e[n - te_size:], e[tr_size:tr_size+vl_size]

In [6]:
# Number of training rows.
L = x_train.shape[0]
print(x_train.shape)

(6373, 38)


### Setting the parameter grid

Let's set up the parameter grid to tune hyper-parameters. We will tune the number of underlying survival distributions
($K$), the distribution choice (Log-Normal or Weibull), the learning rate for the Adam optimizer ($1\times10^{-3}$ or $1\times10^{-4}$), and the number of hidden layers ($0$, $1$ or $2$).

In [6]:
# from sklearn.model_selection import ParameterGrid

In [7]:
# param_grid = {'k' : [3, 4, 6],
#               'distribution' : ['LogNormal', 'Weibull'],
#               'learning_rate' : [ 1e-4, 1e-3],
#               'layers' : [ [], [100], [100, 100] ]
#              }
# params = ParameterGrid(param_grid)

In [None]:
# Latent code: `latent_dim` categorical variables with `categorical_dim` classes each.
latent_dim, categorical_dim = 10, 4
# Gumbel-softmax temperature floor and annealing rate used while training.
temp_min, ANNEAL_RATE = 0.5, 0.00003

### Model Training and Selection

In [8]:
# from auton_survival.models.dsm import DeepSurvivalMachines


In [None]:
def sample_gumbel(shape, eps=1e-20, device=None):
    """Draw i.i.d. standard Gumbel(0, 1) noise of the given shape.

    Uses the inverse-CDF transform ``-log(-log(U))`` with ``U ~ Uniform[0, 1)``;
    ``eps`` keeps both logarithms finite when ``U`` is (close to) 0.

    Args:
        shape: Output shape (e.g. a ``torch.Size``).
        eps: Small constant added inside each logarithm.
        device: Device to allocate the sample on; ``None`` uses PyTorch's
            default device. (The previous ``args.cuda`` check referenced an
            undefined ``args`` object and raised NameError.)

    Returns:
        Tensor of Gumbel noise with shape ``shape``.
    """
    u = torch.rand(shape, device=device)
    return -torch.log(-torch.log(u + eps) + eps)


def gumbel_softmax_sample(logits, temperature):
    """Return a relaxed one-hot sample over the last dimension of ``logits``.

    Computes ``softmax((logits + g) / temperature)`` where ``g`` is Gumbel
    noise drawn on the same device as ``logits``.
    """
    noise = sample_gumbel(logits.size(), device=logits.device)
    return F.softmax((logits + noise) / temperature, dim=-1)


def gumbel_softmax(logits, temperature, hard=False):
    """Straight-through Gumbel-softmax.

    Args:
        logits: Unnormalised class scores of shape ``[*, n_vars, n_class]``.
        temperature: Softmax temperature; lower values give samples closer
            to one-hot.
        hard: If True, the forward value is an exact one-hot vector per
            variable while gradients flow through the soft sample
            (straight-through estimator).

    Returns:
        Tensor of shape ``[-1, n_vars * n_class]``: the last two dimensions
        flattened into one row per leading index.
    """
    y = gumbel_softmax_sample(logits, temperature)
    # Width of one flattened (n_vars, n_class) block, taken from the input
    # itself rather than the global latent_dim/categorical_dim constants.
    flat_width = y.shape[-2] * y.shape[-1]

    if not hard:
        return y.reshape(-1, flat_width)

    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    # Straight-through: forward value equals y_hard, gradient is that of y.
    y_hard = (y_hard - y).detach() + y
    return y_hard.reshape(-1, flat_width)

In [None]:
class VAE_gumbel(nn.Module):
    """Variational autoencoder with a Gumbel-softmax (categorical) latent code.

    The encoder maps each input row to ``latent_dim * categorical_dim`` scores
    (module-level constants). ``forward`` reshapes them into ``latent_dim``
    categorical variables with ``categorical_dim`` classes each, draws a
    relaxed one-hot sample with ``gumbel_softmax`` and decodes it back to the
    input width.

    Args:
        temp: Accepted for backward compatibility but not used by the
            constructor; the temperature is passed to ``forward`` instead.
        n_features: Input / reconstruction width. Defaults to 38, the width of
            this notebook's preprocessed SUPPORT feature matrix.
    """

    def __init__(self, temp, n_features=38):
        super().__init__()
        code_width = latent_dim * categorical_dim

        # Encoder: n_features -> 512 -> 256 -> code_width
        self.fc1 = nn.Linear(n_features, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, code_width)

        # Decoder: code_width -> 256 -> 512 -> n_features
        self.fc4 = nn.Linear(code_width, 256)
        self.fc5 = nn.Linear(256, 512)
        self.fc6 = nn.Linear(512, n_features)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        """Return class scores whose last dimension is latent_dim * categorical_dim.

        The final ReLU clamps negative scores to 0 before they reach the softmax.
        """
        h1 = self.relu(self.fc1(x))
        h2 = self.relu(self.fc2(h1))
        return self.relu(self.fc3(h2))

    def decode(self, z):
        """Map a flattened latent sample back to ``n_features`` values in [0, 1]."""
        # NOTE(review): the sigmoid output is scored with binary cross-entropy,
        # which expects targets in [0, 1]; confirm the preprocessed features
        # are scaled to that range (standardised features would not be).
        h4 = self.relu(self.fc4(z))
        h5 = self.relu(self.fc5(h4))
        return self.sigmoid(self.fc6(h5))

    def forward(self, x, temp, hard):
        """Encode, sample and decode ``x``.

        Args:
            x: Batch of input rows; first dimension is the batch.
            temp: Gumbel-softmax temperature.
            hard: Whether to use straight-through one-hot samples.

        Returns:
            ``(reconstruction, qy)`` where ``qy`` is the per-variable softmax of
            the encoder scores, flattened back to the encoder-output shape.
        """
        q = self.encode(x)
        q_y = q.view(q.size(0), latent_dim, categorical_dim)
        z = gumbel_softmax(q_y, temp, hard)
        return self.decode(z), F.softmax(q_y, dim=-1).reshape(*q.size())

In [9]:
# models = []
# for param in params:
#     model = DeepSurvivalMachines(k = param['k'],
#                                  distribution = param['distribution'],
#                                  layers = param['layers'])
#     # The fit method is called to train the model
#     model.fit(x_train, t_train, e_train, iters = 100, learning_rate = param['learning_rate'])
#     models.append([[model.compute_nll(x_val, t_val, e_val), model]])
# best_model = min(models)
# model = best_model[0][1]
# Build the Gumbel-softmax VAE and an Adam optimiser over its parameters (lr = 1e-3).
model = VAE_gumbel(temp=1.0)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def loss_function(recon_x, x, qy, n_categories=None):
    """Return the per-sample negative ELBO of the Gumbel-softmax VAE.

    Args:
        recon_x: Decoder output with values in [0, 1] (required by
            ``F.binary_cross_entropy``).
        x: Reconstruction targets, same shape as ``recon_x``; dim 0 is the batch.
        qy: Posterior class probabilities, one row per sample, with the
            categorical variables' probabilities concatenated along dim -1.
        n_categories: Classes per categorical variable (the uniform prior is
            ``1 / n_categories``). Defaults to the global ``categorical_dim``.

    Returns:
        Scalar tensor: summed BCE divided by the batch size, plus the batch-mean
        KL divergence from the uniform categorical prior.
    """
    if n_categories is None:
        n_categories = categorical_dim

    # Sum over all elements, then average over the batch. ``reduction='sum'``
    # replaces the deprecated ``size_average=False`` (same reduction).
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum') / x.shape[0]

    # KL(q || Uniform) = sum q * log(q * K); the 1e-20 guards against log(0).
    log_ratio = torch.log(qy * n_categories + 1e-20)
    kld = torch.sum(qy * log_ratio, dim=-1).mean()

    return bce + kld

def train(epoch, x_train):
    """Run one training epoch of the global ``model`` with the global ``optimizer``.

    Rows are consumed in order (no shuffling) in mini-batches of 128. The
    Gumbel-softmax temperature restarts at 1.0 every epoch and, after each
    batch whose index satisfies ``batch_idx % 10 == 1``, is multiplied by
    ``exp(-ANNEAL_RATE * batch_idx)`` and clipped from below at ``temp_min``.

    Args:
        epoch: Epoch number; only used in the progress message.
        x_train: 2-D array-like of training rows (e.g. a numpy array); it is
            converted to a tensor matching the model's parameter dtype/device.

    Returns:
        The average per-sample training loss for the epoch.

    Raises:
        ValueError: If ``x_train`` has no rows.
    """
    model.train()

    # nn.Module layers need tensors, not numpy arrays; match the model's
    # parameter dtype and device so the forward pass works on CPU or GPU.
    ref_param = next(model.parameters())
    data = torch.as_tensor(x_train, dtype=ref_param.dtype, device=ref_param.device)
    n_samples = data.shape[0]
    if n_samples == 0:
        raise ValueError("x_train is empty")

    train_loss = 0.0
    temp = 1.0
    bs = 128
    for batch_idx, start in enumerate(range(0, n_samples, bs)):
        xb = data[start:start + bs]
        optimizer.zero_grad()
        recon_batch, qy = model(xb, temp, False)
        loss = loss_function(recon_batch, xb, qy)
        loss.backward()
        # loss is a per-sample average; weight by batch size to get a sum.
        train_loss += loss.item() * len(xb)
        optimizer.step()
        if batch_idx % 10 == 1:
            temp = np.maximum(temp * np.exp(-ANNEAL_RATE * batch_idx), temp_min)

    avg_loss = train_loss / n_samples
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, avg_loss))
    return avg_loss

# Train for 10 epochs; train() requires the training matrix as its second argument.
for epoch in range(1, 10 + 1):
    train(epoch, x_train)



 12%|█▏        | 1190/10000 [00:01<00:13, 674.38it/s]
100%|██████████| 100/100 [00:08<00:00, 11.88it/s]
 12%|█▏        | 1190/10000 [00:01<00:11, 736.51it/s]
 59%|█████▉    | 59/100 [00:05<00:03, 11.64it/s]
 12%|█▏        | 1190/10000 [00:01<00:11, 737.15it/s]
100%|██████████| 100/100 [00:10<00:00,  9.63it/s]
 12%|█▏        | 1190/10000 [00:01<00:12, 733.51it/s]
 23%|██▎       | 23/100 [00:02<00:08,  9.04it/s]
 12%|█▏        | 1190/10000 [00:01<00:11, 734.29it/s]
100%|██████████| 100/100 [00:12<00:00,  7.81it/s]
 12%|█▏        | 1190/10000 [00:01<00:11, 742.63it/s]
 11%|█         | 11/100 [00:01<00:12,  6.94it/s]
 12%|█▏        | 1190/10000 [00:01<00:11, 736.05it/s]
100%|██████████| 100/100 [00:09<00:00, 10.46it/s]
 12%|█▏        | 1190/10000 [00:01<00:11, 740.17it/s]
 59%|█████▉    | 59/100 [00:05<00:04, 10.22it/s]
 12%|█▏        | 1190/10000 [00:01<00:12, 715.29it/s]
 93%|█████████▎| 93/100 [00:11<00:00,  8.37it/s]
 12%|█▏        | 1190/10000 [00:01<00:12, 725.80it/s]
 23%|██▎       

### Inference

In [10]:
# NOTE(review): VAE_gumbel only defines encode/decode/forward, so these
# predict_risk / predict_survival calls (the DeepSurvivalMachines API from the
# commented-out grid search) will fail on it. Confirm which model should be
# evaluated here.
out_risk = model.predict_risk(x_test, times)  # risk at each horizon in `times`
out_survival = model.predict_survival(x_test, times)  # survival at each horizon

### Evaluation

We evaluate the discriminative ability of DSM (time-dependent concordance index and cumulative dynamic AUC) as well as its Brier score.

In [11]:
from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc

In [12]:
# Structured (event, time) arrays in the format expected by sksurv metrics.
et_train = np.array(list(zip(e_train, t_train)), dtype=[('e', bool), ('t', float)])
et_test = np.array(list(zip(e_test, t_test)), dtype=[('e', bool), ('t', float)])
et_val = np.array(list(zip(e_val, t_val)), dtype=[('e', bool), ('t', float)])

# IPCW concordance index at every evaluation horizon.
cis = [
    concordance_index_ipcw(et_train, et_test, out_risk[:, i], horizon_time)[0]
    for i, horizon_time in enumerate(times)
]
# brier_score returns (times, scores); keep only the scores array.
brs = [brier_score(et_train, et_test, out_survival, times)[1]]
# cumulative_dynamic_auc returns (auc_per_time, mean_auc); keep the per-time array.
roc_auc = [
    cumulative_dynamic_auc(et_train, et_test, out_risk[:, i], horizon_time)[0]
    for i, horizon_time in enumerate(times)
]

for idx, quantile in enumerate(horizons):
    print(f"For {quantile} quantile,")
    print("TD Concordance Index:", cis[idx])
    print("Brier Score:", brs[0][idx])
    print("ROC AUC ", roc_auc[idx][0], "\n")

For 0.25 quantile,
TD Concordance Index: 0.7654588597145041
Brier Score: 0.11137220428760089
ROC AUC  0.7726522677974235 

For 0.5 quantile,
TD Concordance Index: 0.7028085372828209
Brier Score: 0.18272355422012257
ROC AUC  0.7233134711382971 

For 0.75 quantile,
TD Concordance Index: 0.6598328655895858
Brier Score: 0.2213274872450867
ROC AUC  0.7150287756709275 

