In [16]:
import torch
CUDA = torch.cuda.is_available()
print(CUDA)

False


In [17]:
import time
from lifelines.utils import concordance_index 
import sys
from torch import nn
import survival_analysis_chirag
import numpy as np
import pandas as pd
import network
from torch.utils.data import TensorDataset, Dataset
import torch.utils.data.dataloader as dataloader
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline


ds = pd.read_csv('./datasets/whas500.csv',sep=',')
train = ds[:400]
validation = ds[400:]

x = train[['age', 'gender', 'bmi', 'chf', 'miord']]
e = train['fstat']
t = train['lenfol']

x = torch.from_numpy(x.as_matrix()).float()
e = torch.from_numpy(e.as_matrix()).float()
t = torch.from_numpy(t.as_matrix())

In [18]:
if CUDA:
    x = x.cuda()
    e = e.cuda()
    t = t.cuda()

In [19]:
def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_normal_(m.weight.data)
#         m.weight.data.fill_(0)
#         m.bias.data.fill_(1)

def init_weights_for_cox(m):
    if type(m) == nn.Linear:
        m.weight.data.fill_(0)
        m.bias.data.fill_(0)

In [44]:
def train(model, x, e, t, CUDA, optimizer, n_epochs):

    # Initialize Metrics
    c_index = []
    valid_c_index = []

    risk_set = []
    for i in range(len(t)):
        temp = []
        for j in range(len(t)):
            if t[j] >= t[i]:
                temp.append(j)
        risk_set.append(temp)
        
    start = time.time()
    for epoch in range(n_epochs):

        optimizer.zero_grad()
        # print("x: ", x)
        outputs = model(x)

        loss = negative_log_likelihood(outputs, e, risk_set, CUDA)
        loss.backward()
        optimizer.step()


        print(loss.cpu().data.numpy())
        
        ci_train = get_concordance_index(outputs, t, e)
        c_index.append(ci_train)
        torch.cuda.empty_cache()
                   
        print('Finished Training with %d iterations in %0.2fs' % (epoch + 1, time.time() - start))
    
    metrics = {}
    metrics['c-index'] = c_index
    return metrics


In [77]:
def negative_log_likelihood(risk, E, risk_set, CUDA):
    
    import numpy as np
    
#     new_risk = []
#     for i in range(len(risk_set)):
#         new_risk.append(risk[risk_set[i]])
        
#     log_risk = []
#     for i in range(len(new_risk)):
#         temp = torch.logsumexp(new_risk[i], 0)
#         log_risk.append(temp)

    lnumerator = risk
    
    idxs = range(risk.shape[0])
    
    
    ldenominator = []
    
    for i in range(len(idxs)):
        
        ldenominator.append(torch.logsumexp(risk[risk_set[i]], dim=0))
    
    ldenominator = torch.cat(ldenominator).reshape(-1, 1)
    
    
    likelihoods = lnumerator - ldenominator
    
    E =  np.where(E.cpu().data.numpy()==1)[0]
    

    
    neg_likelihood = - torch.sum(likelihoods[E])
    


    return neg_likelihood


In [78]:
def get_concordance_index(x, t, e, **kwargs):
    x = x.detach().cpu().numpy()
    t = t.detach().cpu().numpy()
    e = e.detach().cpu().numpy()
    computed_hazard = np.exp(x)

    return concordance_index(t,-1*computed_hazard,e)


In [101]:

# For CPH, set cox argument as True
print("CPH model")
n_in = x.shape[1]

layers_sizes = [n_in, 1]

# Construct Neural Network
layers = []
for i in range(len(layers_sizes)-2):
    layers.append(nn.Linear(layers_sizes[i],layers_sizes[i+1]))
    layers.append(nn.ReLU())

layers.append(nn.Linear(layers_sizes[-2], layers_sizes[-1]))
my_network = nn.Sequential(*layers)
#my_network.apply(init_weights)

#optimizer = optimizer = torch.optim.SGD(my_network.parameters(), lr=learning_rate, momentum=momentum, weight_decay=L2_reg, nesterov=True)

optimizer = torch.optim.Adam(my_network.parameters(), lr=1e-3)
my_network.train()
if CUDA:
    my_network.cuda()

# If you have validation data, you can add it as the valid_dataloader parameter to the function
n_epochs = 1000
metrics = train(my_network, x, e, t, CUDA, optimizer, n_epochs)
print()

print("Done")

CPH model
1775.405
Finished Training with 1 iterations in 0.03s
1767.8845
Finished Training with 2 iterations in 0.05s
1760.3849
Finished Training with 3 iterations in 0.08s
1752.9061
Finished Training with 4 iterations in 0.11s
1745.4496
Finished Training with 5 iterations in 0.13s
1738.0154
Finished Training with 6 iterations in 0.16s
1730.6035
Finished Training with 7 iterations in 0.19s
1723.2148
Finished Training with 8 iterations in 0.22s
1715.8499
Finished Training with 9 iterations in 0.26s
1708.5093
Finished Training with 10 iterations in 0.28s
1701.1935
Finished Training with 11 iterations in 0.31s
1693.9028
Finished Training with 12 iterations in 0.34s
1686.6375
Finished Training with 13 iterations in 0.37s
1679.3989
Finished Training with 14 iterations in 0.39s
1672.1866
Finished Training with 15 iterations in 0.42s
1665.0007
Finished Training with 16 iterations in 0.45s
1657.8429
Finished Training with 17 iterations in 0.49s
1650.713
Finished Training with 18 iterations in

1029.7874
Finished Training with 152 iterations in 5.20s
1027.7001
Finished Training with 153 iterations in 5.24s
1025.6427
Finished Training with 154 iterations in 5.27s
1023.615
Finished Training with 155 iterations in 5.30s
1021.61676
Finished Training with 156 iterations in 5.34s
1019.6474
Finished Training with 157 iterations in 5.37s
1017.7069
Finished Training with 158 iterations in 5.41s
1015.79456
Finished Training with 159 iterations in 5.45s
1013.91046
Finished Training with 160 iterations in 5.49s
1012.0539
Finished Training with 161 iterations in 5.52s
1010.22504
Finished Training with 162 iterations in 5.55s
1008.42285
Finished Training with 163 iterations in 5.59s
1006.6474
Finished Training with 164 iterations in 5.62s
1004.8984
Finished Training with 165 iterations in 5.66s
1003.1753
Finished Training with 166 iterations in 5.70s
1001.47784
Finished Training with 167 iterations in 5.73s
999.8059
Finished Training with 168 iterations in 5.76s
998.15906
Finished Training

902.73145
Finished Training with 302 iterations in 10.52s
902.4758
Finished Training with 303 iterations in 10.56s
902.223
Finished Training with 304 iterations in 10.59s
901.9734
Finished Training with 305 iterations in 10.63s
901.72675
Finished Training with 306 iterations in 10.66s
901.48303
Finished Training with 307 iterations in 10.70s
901.2424
Finished Training with 308 iterations in 10.73s
901.0046
Finished Training with 309 iterations in 10.77s
900.7696
Finished Training with 310 iterations in 10.81s
900.5374
Finished Training with 311 iterations in 10.84s
900.30835
Finished Training with 312 iterations in 10.87s
900.0815
Finished Training with 313 iterations in 10.91s
899.85767
Finished Training with 314 iterations in 10.95s
899.6367
Finished Training with 315 iterations in 10.98s
899.4182
Finished Training with 316 iterations in 11.02s
899.20215
Finished Training with 317 iterations in 11.05s
898.9889
Finished Training with 318 iterations in 11.09s
898.7779
Finished Training

885.1395
Finished Training with 446 iterations in 15.63s
885.09485
Finished Training with 447 iterations in 15.67s
885.05066
Finished Training with 448 iterations in 15.70s
885.0069
Finished Training with 449 iterations in 15.74s
884.96387
Finished Training with 450 iterations in 15.77s
884.9212
Finished Training with 451 iterations in 15.81s
884.87897
Finished Training with 452 iterations in 15.85s
884.8374
Finished Training with 453 iterations in 15.88s
884.7964
Finished Training with 454 iterations in 15.92s
884.75586
Finished Training with 455 iterations in 15.95s
884.7156
Finished Training with 456 iterations in 15.99s
884.6761
Finished Training with 457 iterations in 16.02s
884.63696
Finished Training with 458 iterations in 16.06s
884.5982
Finished Training with 459 iterations in 16.09s
884.56
Finished Training with 460 iterations in 16.13s
884.5222
Finished Training with 461 iterations in 16.16s
884.485
Finished Training with 462 iterations in 16.20s
884.44824
Finished Training 

882.04205
Finished Training with 590 iterations in 20.73s
882.03357
Finished Training with 591 iterations in 20.77s
882.0246
Finished Training with 592 iterations in 20.79s
882.01624
Finished Training with 593 iterations in 20.82s
882.0078
Finished Training with 594 iterations in 20.85s
881.9994
Finished Training with 595 iterations in 20.87s
881.9912
Finished Training with 596 iterations in 20.90s
881.9827
Finished Training with 597 iterations in 20.93s
881.9747
Finished Training with 598 iterations in 20.96s
881.96655
Finished Training with 599 iterations in 20.99s
881.95886
Finished Training with 600 iterations in 21.02s
881.9507
Finished Training with 601 iterations in 21.04s
881.94293
Finished Training with 602 iterations in 21.07s
881.93524
Finished Training with 603 iterations in 21.10s
881.92755
Finished Training with 604 iterations in 21.12s
881.91974
Finished Training with 605 iterations in 21.15s
881.91235
Finished Training with 606 iterations in 21.19s
881.9047
Finished Tra

881.2794
Finished Training with 734 iterations in 24.80s
881.2759
Finished Training with 735 iterations in 24.83s
881.2724
Finished Training with 736 iterations in 24.86s
881.269
Finished Training with 737 iterations in 24.89s
881.26575
Finished Training with 738 iterations in 24.91s
881.26245
Finished Training with 739 iterations in 24.94s
881.25903
Finished Training with 740 iterations in 24.97s
881.25574
Finished Training with 741 iterations in 24.99s
881.25226
Finished Training with 742 iterations in 25.03s
881.249
Finished Training with 743 iterations in 25.06s
881.2457
Finished Training with 744 iterations in 25.08s
881.2425
Finished Training with 745 iterations in 25.11s
881.2391
Finished Training with 746 iterations in 25.14s
881.23596
Finished Training with 747 iterations in 25.16s
881.2327
Finished Training with 748 iterations in 25.19s
881.2293
Finished Training with 749 iterations in 25.22s
881.22614
Finished Training with 750 iterations in 25.25s
881.22284
Finished Trainin

880.8879
Finished Training with 877 iterations in 28.90s
880.8857
Finished Training with 878 iterations in 28.93s
880.8835
Finished Training with 879 iterations in 28.95s
880.88135
Finished Training with 880 iterations in 28.98s
880.8791
Finished Training with 881 iterations in 29.01s
880.8766
Finished Training with 882 iterations in 29.03s
880.8745
Finished Training with 883 iterations in 29.06s
880.87244
Finished Training with 884 iterations in 29.09s
880.8704
Finished Training with 885 iterations in 29.12s
880.8682
Finished Training with 886 iterations in 29.15s
880.86584
Finished Training with 887 iterations in 29.18s
880.86365
Finished Training with 888 iterations in 29.20s
880.8616
Finished Training with 889 iterations in 29.23s
880.8595
Finished Training with 890 iterations in 29.26s
880.8573
Finished Training with 891 iterations in 29.29s
880.85516
Finished Training with 892 iterations in 29.31s
880.853
Finished Training with 893 iterations in 29.35s
880.8511
Finished Training 

In [90]:
my_network

Sequential(
  (0): Linear(in_features=5, out_features=1, bias=True)
)

In [102]:
print("Train C-Index:", metrics['c-index'])

('Train C-Index:', [0.2941594429444365, 0.294687265271321, 0.2951338841633001, 0.29568200734891087, 0.2958850159361741, 0.296067723664711, 0.2966158468503218, 0.2970015631661219, 0.297630889786638, 0.2982399155484277, 0.2986459327229542, 0.29911285247365965, 0.2994985687894598, 0.2997218782354494, 0.30037150571469173, 0.30089932804157615, 0.3011835400637447, 0.30146775208591325, 0.3021782821413346, 0.3027873079031243, 0.3032948293712824, 0.3037211474045352, 0.3042489697314196, 0.30475649119957776, 0.3052437118090095, 0.305649728983536, 0.30601514444060984, 0.30672567449603116, 0.30747680626890517, 0.3080046285957896, 0.308654256075032, 0.3091820784019164, 0.3098520067398851, 0.31082644795874864, 0.31107005826346457, 0.3116790840252543, 0.3123084106457703, 0.31273472867902313, 0.31379037333279197, 0.3143993990945817, 0.3149881239976451, 0.3156783531943401, 0.3162670780974035, 0.31695730729409854, 0.3176475364907936, 0.3182159605351306, 0.31878438457946773, 0.3196370206459733, 0.31988063