In [1]:
import sys
# import viz
import torch
from torch import nn
import survival_analysis
import numpy as np
import pandas as pd
import network
from torch.utils.data import TensorDataset, Dataset
import torch.utils.data.dataloader as dataloader
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

# event_col is the header in the df that represents the 'Event / Status' indicator
# time_col is the header in the df that represents the event time
def dataframe_to_deepsurv_ds(df, event_col = 'fstat', time_col = 'lenfol'):
    # Extract the event and time columns as numpy arrays
    e = df[event_col].values.astype(np.int32)
    t = df[time_col].values.astype(np.float32)

    # Extract the patient's covariates as a numpy array
    x_df = df.drop([event_col, time_col], axis = 1)
    x = x_df.values.astype(np.float32)
    
    # Return the deep surv dataframe
    return {
        'x' : x,
        'e' : e,
        't' : t
    }

class MyTrainDataset(Dataset):
    
    def __init__(self, dataframe, standardize=False):
        self.train_df = dataframe
#         print(len(self.train_df.index))
        # If the headers of the csv change, you can replace the values of 
        # 'event_col' and 'time_col' with the names of the new headers
        # You can also use this function on your training dataset, validation dataset, and testing dataset
        train_data = dataframe_to_deepsurv_ds(self.train_df, event_col = 'fstat', time_col= 'lenfol')

        self.x, self.e, self.t = train_data['x'], train_data['e'], train_data['t']
        
        if standardize:
            offset = self.x.mean(axis = 0)
            scale = self.x.std(axis = 0)
            self.x = (self.x - offset) / scale
        
        # Sort Training Data for Accurate Likelihood
        sort_idx = np.argsort(self.t)[::-1]
        self.x = self.x[sort_idx]
        self.e = self.e[sort_idx]
        self.t = self.t[sort_idx]
        
        self.processed_count = 1
                
    def __len__(self):
        return len(self.train_df.index)
    
    def __getitem__(self, i):
        self.processed_count += 1
        return self.x[i], self.e[i], self.t[i]
#         return (torch.from_numpy(x), torch.from_numpy(e), torch.from_numpy(t))

ds = pd.read_csv('whas500.csv',sep=',')
train = ds[:400]
validation = ds[400:]
train_ds = MyTrainDataset(train,True)
validation_ds = MyTrainDataset(validation,True)

train_loader = dataloader.DataLoader(train_ds, shuffle=False, batch_size=len(train_ds))
validation_loader = dataloader.DataLoader(validation_ds, shuffle=False, batch_size=len(validation_ds))
print(len(train_ds),len(validation_ds))

400 100


# Transform the dataset to "DeepSurv" format
DeepSurv expects a dataset to be in the form:

    {
        'x': numpy array of float32
        'e': numpy array of int32
        't': numpy array of float32
        'hr': (optional) numpy array of float32
    }
    
You are providing me a csv, which I read in as a pandas dataframe. Then I convert the pandas dataframe into the DeepSurv dataset format above. 

In [11]:
def init_weights(m):
    if type(m) == nn.Linear:
        #torch.nn.init.xavier_normal_(m.weight.data)
        m.weight.data.fill_(0)
def init_weights_for_cox(m):
    if type(m) == nn.Linear:
        m.weight.data.fill_(0)
        m.bias.data.fill_(0)

In [18]:
n_epochs = 50
L2_reg = 5e-05
batch_norm = True
dropout = 0.147
hidden_layers_sizes = [48, 48]
learning_rate = 1e-03
lr_decay = 6.494e-4
momentum = 0.863
n_in = train_ds.x.shape[1]
standardize = True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("DeepSurv model")
my_network = network.DeepSurv(n_in, hidden_layers_sizes=hidden_layers_sizes, dropout=dropout, batch_norm=batch_norm, momentum=0.1)
my_network.apply(init_weights)
# network.load_state_dict(torch.load("model_99.pt"))

optimizer = optimizer = torch.optim.Adam(my_network.parameters(), lr=learning_rate, weight_decay=L2_reg)
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,1,gamma=lr_decay,last_epoch=-1)
my_network.train()

my_network.to(device)

# If you have validation data, you can add it as the valid_dataloader parameter to the function
metrics = survival_analysis.train(my_network, train_loader,device,optimizer,exp_lr_scheduler,validation_loader, n_epochs,True)
print()
# print(my_network.layers[0].weight)
# print(my_network.layers[0].bias)

###################################################################################################################

# For CPH, set cox argument as True
print("CPH model")
my_network = network.DeepSurv(n_in, hidden_layers_sizes=[], dropout=dropout, batch_norm=batch_norm, momentum=0.1, cox=True)
my_network.apply(init_weights_for_cox)
# network.load_state_dict(torch.load("model_99.pt"))

optimizer = optimizer = torch.optim.SGD(my_network.parameters(), lr=learning_rate, momentum=momentum, weight_decay=L2_reg, nesterov=True)
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,1,gamma=lr_decay,last_epoch=-1)
my_network.train()

my_network.to(device)

# If you have validation data, you can add it as the valid_dataloader parameter to the function
metrics = survival_analysis.train(my_network, train_loader, device, optimizer, exp_lr_scheduler, validation_loader, n_epochs,True)
print()

print("Done")

DeepSurv model
neg_likelihood:  tensor(957.1607, device='cuda:0', grad_fn=<NegBackward>)
957.1607055664062
Finished Training with 1 iterations in 0.04s
neg_likelihood:  tensor(957.1608, device='cuda:0', grad_fn=<NegBackward>)
957.1608276367188
Finished Training with 2 iterations in 0.07s
neg_likelihood:  tensor(957.1608, device='cuda:0', grad_fn=<NegBackward>)
957.1608276367188
Finished Training with 3 iterations in 0.11s
neg_likelihood:  tensor(957.1608, device='cuda:0', grad_fn=<NegBackward>)
957.1607666015625
Finished Training with 4 iterations in 0.15s
neg_likelihood:  tensor(957.1608, device='cuda:0', grad_fn=<NegBackward>)
957.1607666015625
neg_likelihood:  tensor(174.3156, device='cuda:0', grad_fn=<NegBackward>)
valid_loss:  tensor(174.3156, device='cuda:0', grad_fn=<NegBackward>)
Finished Training with 5 iterations in 0.22s
neg_likelihood:  tensor(957.1608, device='cuda:0', grad_fn=<NegBackward>)
957.1607666015625
Finished Training with 6 iterations in 0.25s
neg_likelihood:  te

907.8980712890625
neg_likelihood:  tensor(161.9753, device='cuda:0', grad_fn=<NegBackward>)
valid_loss:  tensor(161.9753, device='cuda:0', grad_fn=<NegBackward>)
Finished Training with 10 iterations in 0.20s
neg_likelihood:  tensor(907.8981, device='cuda:0', grad_fn=<NegBackward>)
907.8980712890625
Finished Training with 11 iterations in 0.21s
neg_likelihood:  tensor(907.8981, device='cuda:0', grad_fn=<NegBackward>)
907.8980712890625
Finished Training with 12 iterations in 0.23s
neg_likelihood:  tensor(907.8981, device='cuda:0', grad_fn=<NegBackward>)
907.8980712890625
Finished Training with 13 iterations in 0.25s
neg_likelihood:  tensor(907.8981, device='cuda:0', grad_fn=<NegBackward>)
907.8980712890625
Finished Training with 14 iterations in 0.27s
neg_likelihood:  tensor(907.8981, device='cuda:0', grad_fn=<NegBackward>)
907.8980712890625
neg_likelihood:  tensor(161.9753, device='cuda:0', grad_fn=<NegBackward>)
valid_loss:  tensor(161.9753, device='cuda:0', grad_fn=<NegBackward>)
Fini

In [19]:
# Print the final metrics
print('Train C-Index:', metrics['c-index'])
# print('Valid C-Index: ',metrics['valid_c-index'][-1])

# Plot the training / validation curves
# viz.plot_log(metrics)

Train C-Index: [0.5, 0.7602671593008384, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911, 0.7603077610182911]