In [None]:
import os

import lifelines
import numpy as np
import pandas as pd
from lifelines import CoxPHFitter

from generate_data import load_data

# Dataset tag: selects the pickle to load and is embedded in the output file names.
data = 'rr_nl_nph'
df = load_data('~/Research/Cox/data/' + data + '.pkl')

# Time grid on which the predicted survival curves are evaluated: steps of
# `time_step` starting at the smallest observed duration and stopping short of
# the largest one.
# The bounds are computed once here; evaluating min(df['duration']) inside the
# comprehension would rescan the whole column for every grid point.
# NOTE(review): Series.min()/max() skip NaN, unlike the builtin min()/max();
# assumes 'duration' has no missing values — confirm.
time_step = 0.01
t_min = df['duration'].min()
t_max = df['duration'].max()
n_steps = int((t_max - t_min) / time_step)
time = [t_min + i * time_step for i in range(n_steps)]

train_frac = 0.8  # fraction of rows used for fitting in each random split
alpha = 0.95      # placeholder only: rebound by the `for alpha in alphas` sweep later in this cell
epochs = 100      # number of random train/test splits evaluated per alpha

# Levels to evaluate. For each level, the first grid time at which a test
# subject's predicted survival drops to <= 1 - alpha is taken as an upper bound
# on that subject's true duration; coverage is the fraction of test subjects
# whose 'duration_true' does not exceed the bound.
alphas = [0.6,0.7,0.8,0.9,0.95]

# Make sure the output folder exists before any model is fitted, so a missing
# directory cannot discard the results of a finished sweep.
os.makedirs('./output', exist_ok=True)

# Split point between training and test rows; identical for every epoch.
n_train = int(train_frac*len(df))

for alpha in alphas:
    coverage = []      # per-epoch coverage for this alpha
    interval_len = []  # per-epoch mean predicted bound for this alpha
    for epoch in range(epochs):
        # Seeding with the epoch number makes every split reproducible and
        # reuses the same splits for every alpha.
        rng = np.random.RandomState(epoch)
        shuffle_idx = rng.permutation(len(df))
        train_idx = shuffle_idx[:n_train]
        test_idx = shuffle_idx[n_train:]

        # 'duration_true' is removed from both frames before fitting/predicting
        # and kept aside for scoring only.
        df_train = df.iloc[train_idx,:].drop(columns=['duration_true'])
        df_test = df.iloc[test_idx,:]
        duration_true = df_test['duration_true']
        df_test = df_test.drop(columns=['duration_true'])

        cph = CoxPHFitter()
        cph.fit(df_train,duration_col = 'duration',event_col = 'event')

        # NOTE(review): the first three columns are presumably the covariates —
        # confirm against the frame returned by load_data.
        surv = cph.predict_survival_function(df_test.iloc[:,:3],times=time)

        # `surv` is indexed by grid time with one column per test subject;
        # transpose the mask so each row holds one subject's curve.
        index = np.array(surv.index)
        below = (surv<=1-alpha).to_numpy().T

        # First grid time at which survival is <= 1 - alpha; subjects that never
        # reach the threshold on the grid get the last grid time.
        # Selecting with the boolean mask (rather than multiplying the times by
        # a 0/1 indicator and treating 0 as "not reached") keeps a legitimate
        # grid time of exactly 0 from being mistaken for "not reached".
        t_predict = np.where(below,index,np.max(index)).min(axis = 1)
        diff_predict_true = np.subtract(t_predict,np.array(duration_true))

        # Fraction of test subjects whose true duration is within the bound.
        cover = np.mean(diff_predict_true>=0)

        coverage.append(cover)
        print('[%d]\t%.3f'%(epoch,cover))
        interval_len.append(np.mean(t_predict))
    print('Total Coverage Statistics:\t [Mean]%.3f\t[Std.]%.3f\t[Max]%.3f\t[Min]%.3f'%(np.mean(coverage),np.std(coverage),np.max(coverage),np.min(coverage)))

    # File names are kept exactly as before: `epochs` and `alpha` are
    # concatenated without a separator (e.g. '..._1000.6.txt').
    np.savetxt('./output/cox_reg_coverage_'+data+'_'+str(epochs)+str(alpha)+'.txt',np.array(coverage))
    np.savetxt('./output/cox_reg_interval_'+data+'_'+str(epochs)+str(alpha)+'.txt',np.array(interval_len))



[0]	0.612
[1]	0.614
[2]	0.617
[3]	0.637
[4]	0.625
[5]	0.630
[6]	0.586
[7]	0.611
[8]	0.609
[9]	0.596
[10]	0.618
[11]	0.600
[12]	0.604
[13]	0.608
[14]	0.639
[15]	0.623
[16]	0.617
[17]	0.600
[18]	0.596
[19]	0.594
[20]	0.616
[21]	0.600
[22]	0.597
[23]	0.601
[24]	0.615
[25]	0.603
[26]	0.626
[27]	0.622
[28]	0.627
[29]	0.625
[30]	0.617
[31]	0.627
[32]	0.620
[33]	0.616
[34]	0.625
[35]	0.611
[36]	0.620
[37]	0.623
[38]	0.626
[39]	0.619
[40]	0.624
[41]	0.615
[42]	0.619
[43]	0.607
[44]	0.624
[45]	0.604
[46]	0.616
[47]	0.622
[48]	0.613
[49]	0.618
[50]	0.611
[51]	0.590
[52]	0.611
[53]	0.604
[54]	0.623
[55]	0.591
[56]	0.622
[57]	0.607
[58]	0.619
[59]	0.617
[60]	0.602
[61]	0.602
[62]	0.625
[63]	0.603
[64]	0.610
[65]	0.615
[66]	0.619
[67]	0.623
[68]	0.634
[69]	0.614
[70]	0.631
[71]	0.600
[72]	0.619
[73]	0.611
[74]	0.596
[75]	0.619
[76]	0.617
[77]	0.627
[78]	0.620
[79]	0.605
[80]	0.612
[81]	0.631
[82]	0.608
[83]	0.608
[84]	0.610
[85]	0.614
[86]	0.625
[87]	0.617
[88]	0.595
[89]	0.621
[90]	0.620
[91]	0.61

In [3]:
# Size of the last evaluated test split. `diff_predict_true` is whatever the
# final iteration of the sweep cell left behind, so that cell must be run first.
diff_predict_true.shape[0]

2000