In [12]:
import numpy as np
import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter
import matplotlib.patches as mpatches
from lifelines.utils import concordance_index
import scipy
from scipy import stats
from scipy.stats import linregress
from scipy.special import gamma, erf
import math

# ---- Plot styling --------------------------------------------------------
# NOTE(review): fontsize, SMALL_SIZE and legend_size are not referenced in the
# cells shown here — confirm they are used further down or remove them.
fontsize = 14
SMALL_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 12
legend_size =14
plt.rc('font', size=MEDIUM_SIZE)  # controls default text sizes (later overridden by font['size'] below)
plt.rc('axes', titlesize=MEDIUM_SIZE)  # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels
plt.rc('xtick', labelsize=MEDIUM_SIZE)  # fontsize of the tick labels
plt.rc('ytick', labelsize=MEDIUM_SIZE)  # fontsize of the tick labels
plt.rc('legend', fontsize=MEDIUM_SIZE)  # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title
# These immediately replace the MEDIUM_SIZE tick label sizes set just above.
plt.rc('xtick', labelsize=20)
plt.rc('ytick', labelsize=20)

# NOTE(review): 'normal' is not one of matplotlib's generic font families
# ('serif', 'sans-serif', 'monospace', ...); matplotlib may warn
# "findfont: Font family ['normal'] not found" — consider 'sans-serif'.
font = {'family': 'normal',
        'weight': 'bold',
        'size': 24}

plt.rc('font', **font)
# Relative sizes; applied by rcParams.update() below, after the plt.rc calls,
# so they replace the absolute legend/label/title/tick sizes configured above.
params = {'legend.fontsize': 'x-large',
          # 'figure.figsize': (15, 5),
          'axes.labelsize': 'x-large',
          'axes.titlesize': 'x-large',
          'xtick.labelsize': 'x-large',
          'ytick.labelsize': 'x-large'}

# Use TrueType (Type 42) font embedding for PDF / PostScript exports.
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42

plt.rcParams.update(params)
# NOTE(review): these imports belong with the other imports at the top of the cell.
import seaborn as sns
import pandas
sns.set_style('white')
sns.set_context('paper')
# NOTE(review): sns.set() applies seaborn's default theme and likely overrides the
# 'white' style and 'paper' context just set, plus several of the rc font sizes
# above — confirm this is the intended final look.
sns.set()
# Seed NumPy's legacy global RNG so any stochastic step in the notebook is reproducible.
np.random.seed(31415)

## Select the best-performing alpha according to the validation set

In [13]:
# Evaluation-run configuration: which results directory to read.
alpha = "100"  # alpha=[0, 0.1, 1, 10, 100]
# Only used when `alpha` is falsy (e.g. ""), to locate '<data>_<split>_split' runs.
# Declared explicitly so this cell no longer depends on a `split` left over in the
# kernel (previously the else-branch raised NameError on a fresh kernel).
split = None

model = 'CSA-INFO'
# Model-family flags derived from the model name: names containing 'SR' are
# treated as non-stochastic, names containing 'AFT' as parametric.
is_stochastic = 'SR' not in model
is_non_param = 'AFT' not in model

#data = 'actg175'
data = 'actg175_simulated'

if alpha:
    main_path = 'analysis/{}/alpha_{}/'.format(data, alpha)
    path = 'analysis/{}/alpha_{}/{}/'.format(data, alpha, model)
else:
    if split is None:
        raise ValueError("alpha is empty: set `split` to choose the "
                         "'analysis/<data>_<split>_split/' results directory")
    main_path = 'analysis/{}/'.format(data + '_{}_split'.format(split))
    path = 'analysis/{}/{}/'.format(data + '_{}_split'.format(split), model)

print(path)
# NOTE(review): `has_cf` and `time` are not used in the cells shown here; `time`
# would also shadow the stdlib module if it were ever imported.
has_cf = True
time = 'days'
name = 'Valid'

# Factual (F) / counterfactual (CF) predictions for both arms (t0, t1).
if is_non_param:
    # Non-parametric models store prediction samples as .npy arrays
    # (presumably one row per subject, one column per sample — e.g. (96, 200)).
    _reader = np.load
    _templates = ('{}_pred_t0_F.npy', '{}_pred_t0_CF.npy',
                  '{}_pred_t1_F.npy', '{}_pred_t1_CF.npy')
else:
    # Parametric models store per-subject distribution parameters as CSV tables.
    _reader = pandas.read_csv
    _templates = ('{}_pred_t0_F.csv', '{}_pred_t0_CF.csv',
                  '{}_pred_t1_F.csv', '{}_pred_t1_CF.csv')

pred_t0_f, pred_t0_cf, pred_t1_f, pred_t1_cf = [
    _reader(path + _template.format(name)) for _template in _templates
]

for _label, _preds in zip(("pred_t0_f: ", "pred_t0_cf: ", "pred_t1_f: ", "pred_t1_cf: "),
                          (pred_t0_f, pred_t0_cf, pred_t1_f, pred_t1_cf)):
    if is_non_param:
        print(_label, _preds.shape)
    else:
        print(_label, _preds.shape, _preds.head())



analysis/actg175_simulated/alpha_100/CSA-INFO/
pred_t0_f:  (96, 200)
pred_t0_cf:  (118, 200)
pred_t1_f:  (118, 200)
pred_t1_cf:  (96, 200)


In [14]:

# Outcomes for the evaluation split, read from its index CSV.
path_factual = 'data/{}/{}_{}_idx.csv'.format(data, data, name.lower())
print("path_factual: ", path_factual)
data_frame = pandas.read_csv(path_factual)

## Factual: observed time, event indicator and treatment assignment as flat 1-D arrays.
y_f, e_f, a = (np.array(data_frame[col]) for col in ('time', 'event', 'treatment'))

## Counter Factual: presumably the outcome under the opposite treatment
## ('nn_' suggests a nearest-neighbour imputation — confirm against the data prep).
y_cf, e_cf = (np.array(data_frame[col]) for col in ('nn_cf_y', 'nn_cf_e'))

print("a: ", a.shape)
print("y_cf: ", y_cf.shape)
print("e_cf: ", e_cf.shape)
print("y_f: ", y_f.shape)
print("e_f: ", e_f.shape)
print("a=1", np.sum(a))
print("y_f[a==1.0].shape", y_f[a==1].shape)
print("y_f[a==0.0].shape", y_f[a==0].shape)

    




path_factual:  data/actg175_simulated/actg175_simulated_valid_idx.csv
a:  (214,)
y_cf:  (214,)
e_cf:  (214,)
y_f:  (214,)
e_f:  (214,)
a=1 118.0
y_f[a==1.0].shape (118,)
y_f[a==0.0].shape (96,)


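# Compute Factual + Counterfactual Loss

For the parametric (log-normal / Weibull) models this is the negative log-likelihood; for the non-parametric models it is the absolute error on observed events plus a hinge penalty on censored samples. The four treatment-arm/outcome group means are averaged with equal weight, so lower values are better.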

In [15]:
def relu(x):
    """Rectified linear unit: keep the positive entries of ``x``, zero the rest.

    Works elementwise on NumPy arrays (and on plain scalars) by multiplying
    ``x`` with the boolean mask ``x > 0``; NaN entries propagate unchanged.
    """
    is_positive = x > 0
    return x * is_positive

def weibull_lik(pred_t, y, e, name):
    """Per-subject negative log-likelihood under a Weibull time-to-event model.

    Parameters
    ----------
    pred_t : DataFrame or mapping
        Must provide ``'logshape_' + name`` (log k) and ``'logscale_' + name``
        (log lambda), aligned with ``y``.
    y : array-like
        Observed event or censoring times.
    e : array-like
        Event indicator (presumably 1 = event observed, 0 = censored).
    name : str
        Arm suffix selecting the parameter columns (e.g. 'one' / 'zero').

    Returns
    -------
    ``e * (-log f(y)) + (1 - e) * (-log S(y))`` elementwise.
    """
    # Small offset so log(y) stays finite at y == 0 (same as lognormal_lik);
    # without it log f is +-inf/nan there and `inf * 0` makes censored rows NaN.
    constant = 1e-8

    log_k = pred_t['logshape_' + name]
    log_lam = pred_t['logscale_' + name]
    k = np.exp(log_k)
    lam = np.exp(log_lam)

    # Cumulative hazard H(y) = (y / lambda)^k, and log S(y) = -H(y).
    cum_hazard = (y / lam) ** k
    log_surv = -cum_hazard
    # log f(y) = log k - log lambda + (k - 1) * (log y - log lambda) - H(y)
    log_pdf = log_k - log_lam + (k - 1) * (np.log(y + constant) - log_lam) - cum_hazard

    return -log_pdf * e + -log_surv * (1 - e)

def lognormal_lik(pred_t, y, e, name):
    """Per-subject negative log-likelihood under a log-normal time-to-event model.

    ``pred_t`` must provide ``'mu_' + name`` (mean of log-time) and
    ``'logvar_' + name`` (log-variance of log-time), aligned with ``y``.
    ``e`` weights the density term, and ``1 - e`` weights the survival term.

    Note: compared with the full log-normal log-density, the density term here
    omits the constant ``-0.5 * log(2 * pi)`` and the ``-log(t)`` Jacobian, so
    its values should not be compared directly with ``weibull_lik``.
    """
    constant = 1e-8  # keeps log(y) and log(survival) finite
    mean_log_t = pred_t['mu_' + name]
    log_var = pred_t['logvar_' + name]
    std_dev = np.exp(log_var * 0.5)

    log_time = np.log(y + constant)
    residual = log_time - mean_log_t

    # Gaussian log-density of log(t), up to the terms noted in the docstring.
    log_pdf = -0.5 * (log_var + np.power(residual, 2) / np.exp(log_var))

    # Survival S(t) = 1 - Phi((log t - mu) / sigma), with Phi expressed via erf.
    z_score = residual / std_dev
    cdf = 0.5 * (1.0 + erf(z_score / math.sqrt(2)))
    log_surv = np.log(1 - cdf + constant)

    return -log_pdf * e + -log_surv * (1 - e)
    

def non_param_lik(pred_t_samples, y, e):
    """Per-subject loss for sample-based (non-parametric) time predictions.

    The samples are averaged along axis 1 to obtain one point prediction per
    subject. Rows with ``e == 1`` contribute the absolute error ``|y - pred|``.
    Rows with ``e == 0`` contribute ``relu(y - pred)``, i.e. only predictions
    below the recorded time are penalised.
    """
    point_pred = np.mean(pred_t_samples, axis=1)
    residual = y - point_pred
    observed_term = np.abs(residual) * e
    censored_term = relu(residual) * (1 - e)
    return observed_term + censored_term
    

# Select the per-subject loss for the model family. The arm suffix selects the
# parametric columns: 'one' is used with t1 predictions and 'zero' with t0.
if is_non_param:
    def _subject_loss(preds, y, e, arm):
        return non_param_lik(preds, y=y, e=e)
elif 'Weibull' not in model:
    def _subject_loss(preds, y, e, arm):
        return lognormal_lik(preds, y=y, e=e, name=arm)
else:
    def _subject_loss(preds, y, e, arm):
        return weibull_lik(preds, y=y, e=e, name=arm)

_treated = a == 1
_control = a == 0

# Treated subjects: factual outcome under t1, counterfactual outcome under t0.
pred_lik_t1_f = _subject_loss(pred_t1_f, y_f[_treated], e_f[_treated], 'one')
pred_lik_t0_cf = _subject_loss(pred_t0_cf, y_cf[_treated], e_cf[_treated], 'zero')
# Control subjects: counterfactual outcome under t1, factual outcome under t0.
pred_lik_t1_cf = _subject_loss(pred_t1_cf, y_cf[_control], e_cf[_control], 'one')
pred_lik_t0_f = _subject_loss(pred_t0_f, y_f[_control], e_f[_control], 'zero')

# Equal-weight average of the four group means (groups are not weighted by size).
pred_lik = (np.mean(pred_lik_t1_f) + np.mean(pred_lik_t0_cf) + np.mean(pred_lik_t1_cf)
            + np.mean(pred_lik_t0_f)) * 0.25
print("pred_lik: ", np.round(pred_lik, 2))


pred_lik:  126.37
