In [7]:
import sksurv as sks
import sksurv.preprocessing
import sksurv.metrics
import sksurv.datasets
import sksurv.linear_model
import sksurv.ensemble

from pathlib import Path
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import numpy as np
import sklearn as skl
import scipy.stats as sp

import pymc as pm
import pymc_bart as pmb
import pandas as pd

import importlib
import mlflow as ml

In [None]:
# Turn off matplotlib's interactive mode for this session.
plt.ioff()
# Seed NumPy's legacy global random state; sim_surv draws its uniform and
# exponential variates through np.random.* and therefore depends on this call.
np.random.seed(99)

In [1]:
import simsurv_func

# Objective
Simulations of Cox proportional-hazards and non-proportional-hazards survival models.  
Validation of BART (PyMC-BART) against the Cox proportional hazards (CPH) model and the random survival forest (RSF) model.

Validation metrics:
    2 bool
    5 bool
    2 bool + 1 linear
    complex combination



In [4]:
def sim_surv(N=100,
             T=100,
             x_vars=1,
             lambda_f=None,
             a=2,
             alpha_f=None,
             seed=999,
             cens_ind=True,
             cens_scale=20,
             err_ind=False):
    """Simulate Weibull survival data driven by Bernoulli(0.5) covariates.

    Survival is S(t) = exp(-(t / lambda) ** a); latent event times are drawn by
    inverse-transform sampling, lambda * (-log(U)) ** (1 / a) with U ~ Uniform(0, 1).

    Parameters
    ----------
    N : int
        Number of subjects.
    T : int
        Number of grid points at which the true survival curve is evaluated.
    x_vars : int
        Number of binary covariates (columns of ``x_mat``).
    lambda_f : str or None
        Python expression for the Weibull scale, evaluated with ``eval`` inside
        this function (it can reference ``x_mat``, ``N``, ``T`` and ``np``).
        When None, ``exp(2 + 0.3*(x0 + x1) + x2)`` is used; covariates beyond
        ``x_vars`` are treated as 0 so the defaults work for any ``x_vars``.
    a : scalar
        Weibull shape shared by all subjects (used when ``alpha_f`` is None).
    alpha_f : str or None
        Python expression for a per-subject shape, evaluated with ``eval``.
    seed : int
        Currently ignored: no reseeding happens here, all draws come from the
        global ``np.random`` state so repeated calls give different data.
    cens_ind : bool
        Whether to apply independent exponential censoring.
    cens_scale : float
        Scale (mean) of the exponential censoring distribution.
    err_ind : bool
        Whether to add Normal(0, 0.5) noise to the scale.

    Returns
    -------
    sv_mat : ndarray, shape (N, T)
        True survival probabilities on the grid ``np.linspace(0, T, T)``.
    x_mat : ndarray, shape (N, x_vars)
    lmbda : ndarray, shape (N,)
    a : ndarray, shape (N,)
    tlat : ndarray, shape (N,)
        Latent (uncensored) event times.
    cens : ndarray, shape (N,)
        Censoring times (all zeros when ``cens_ind`` is False).
    t_event : ndarray, shape (N,)
        Observed time: min(censoring time, ceil(latent time)).
    status : ndarray of int, shape (N,)
        1 when the event was observed, 0 when censored.
    """
    # Covariates are drawn one column at a time to keep the random stream order.
    x_mat = np.zeros((N, x_vars))
    for col in range(x_vars):
        x_mat[:, col] = sp.bernoulli.rvs(.5, size=N)

    # Weibull scale.
    if lambda_f is None:
        # The default formula uses three covariates; pad the missing ones with
        # zeros instead of indexing past the end of x_mat.
        padded = np.zeros((N, 3))
        n_used = min(x_vars, 3)
        padded[:, :n_used] = x_mat[:, :n_used]
        lmbda = np.exp(2 + 0.3 * (padded[:, 0] + padded[:, 1]) + padded[:, 2])
    else:
        # SECURITY: eval executes arbitrary code; only pass trusted expressions.
        lmbda = eval(lambda_f)
    # Accept scalar expressions too (e.g. "10") by broadcasting to one per subject.
    lmbda = np.broadcast_to(np.asarray(lmbda), (N,)).copy()

    # Weibull shape.
    if alpha_f is None:
        a = np.repeat(a, N)
    else:
        # SECURITY: same caveat as lambda_f.
        a = np.broadcast_to(np.asarray(eval(alpha_f)), (N,)).copy()

    # Optional noise on the scale.
    if err_ind:
        error = sp.norm.rvs(0, .5, size=N)
        lmbda = lmbda + error

    # Evaluation grid for the true curves.
    # NOTE(review): linspace(0, T, T) yields T points with spacing T/(T-1), i.e.
    # not the integer times 0..T-1 that the plots use on the x-axis - confirm.
    t = np.linspace(0, T, T)

    # True survival curves, one row per subject.
    sv_mat = np.exp(-1 * np.power(t[None, :] / lmbda[:, None], a[:, None]))

    # Latent event times. A single size-N draw consumes the global random
    # stream exactly like N successive size-1 draws.
    unif = np.random.uniform(size=N)
    tlat = lmbda * np.power(-1 * np.log(unif), 1 / a)

    if cens_ind:
        # Integer-valued censoring times.
        cens = np.ceil(np.random.exponential(size=N, scale=cens_scale))
        # Observed time is whichever comes first.
        t_event = np.minimum(cens, np.ceil(tlat))
        status = (tlat <= cens) * 1
    else:
        cens = np.zeros(N)
        t_event = np.ceil(tlat)
        # int dtype to match the censored branch
        status = np.ones(N, dtype=int)

    return sv_mat, x_mat, lmbda, a, tlat, cens, t_event, status

In [None]:
# def get_x_info(x_mat):
#     x_comb_unique = np.unique(x_mat, axis=0)
#     return x_comb_unique

def get_x_info(x_mat):
    """Summarise the distinct covariate rows of ``x_mat``.

    Returns a tuple ``(rows, first_index, counts)``: the unique rows (sorted),
    the index of the first occurrence of each in ``x_mat``, and how many times
    each appears.
    """
    uniq_rows, first_seen, n_seen = np.unique(
        x_mat, axis=0, return_index=True, return_counts=True
    )
    return uniq_rows, first_seen, n_seen


def get_status_perc(status):
    """Return ``(event_fraction, censored_fraction)`` for a 0/1 status array."""
    event_frac = status.sum() / status.shape[0]
    return event_frac, 1 - event_frac

def get_event_time_metric(t_event):
    """Return the mean and the maximum of the observed times, in that order."""
    return t_event.mean(), t_event.max()

def get_train_matrix(x_mat, t_event, status):
    """Assemble the training frame: ``status``, ``time``, then one column per covariate.

    Covariate columns keep pandas' default integer labels (0, 1, ...).
    """
    outcome = pd.DataFrame({"status": status, "time": t_event})
    covariates = pd.DataFrame(x_mat)
    return pd.concat([outcome, covariates], axis=1)

def get_y_sklearn(status, t_event):
    """Build the structured outcome array with fields ``Status`` (bool) and
    ``Survival_in_days`` (float64), one record per subject."""
    record_dtype = [("Status", "?"), ("Survival_in_days", "<f8")]
    event_flags = np.array(status, dtype="bool")
    records = [(flag, when) for flag, when in zip(event_flags, t_event)]
    return np.array(records, dtype=record_dtype)


In [None]:
# Plot Function
def plot_sv(x_mat, sv_mat, t, title="TITLE", save=False, dir=".", show=False):
    """Plot one survival step curve per distinct covariate pattern.

    Parameters
    ----------
    x_mat : array-like, shape (n, p)
        Covariate matrix; its unique rows define the curves and their labels.
    sv_mat : array-like
        Survival probabilities, either one row per row of ``x_mat`` or one row
        per unique covariate pattern (in ``np.unique`` order). A list of
        equal-length arrays is accepted.
    t : int
        Number of time points per curve; the x-axis is ``np.arange(t)``.
    title : str
        Figure title, also used as the PNG file stem when saving.
    save : bool
        Write ``{dir}/{title}.png``.
    dir : str
        Output directory; created if missing when ``save`` is True.
    show : bool
        Call ``plt.show()`` after drawing.
    """
    # Predictions arrive as a list of arrays; normalise so .shape and fancy
    # indexing work.
    sv_mat = np.asarray(sv_mat)
    dist_x, dist_idx = np.unique(x_mat, axis=0, return_index=True)

    # A row count different from the number of patterns is taken to mean
    # "one row per subject": keep the first subject of each pattern.
    # NOTE(review): if every subject has a unique pattern the two layouts are
    # indistinguishable and rows are assumed to already be in dist_x order.
    if sv_mat.shape[0] != dist_idx.shape[0]:
        curves = sv_mat[dist_idx]
    else:
        curves = sv_mat

    # Create the figure before the try so the finally clause can always close it.
    fig = plt.figure()
    try:
        ax = fig.gca()
        steps = np.arange(t)
        for pattern, curve in zip(dist_x, curves):
            ax.step(steps, curve, label=str(pattern))
        if len(curves):
            ax.legend()
        ax.set_title(title)

        # Save before show: some backends clear the figure once show() returns.
        if save:
            Path(dir).mkdir(parents=True, exist_ok=True)
            fig.savefig(f"{dir}/{title}.png")
        if show:
            plt.show()
    finally:
        plt.close(fig)

In [None]:
def surv_pre_train(data_x_n, data_y, X_TIME=True):
    """Expand subject-level survival data into long (discrete-time) form.

    Each subject contributes one row for every distinct observed time up to and
    including their own; the event indicator is 0 on all of those rows except
    the last, which carries the subject's ``Status``.

    Parameters
    ----------
    data_x_n : DataFrame
        Covariates, one row per subject (accessed positionally with ``iloc``).
    data_y : structured array
        Must expose the fields ``"Status"`` and ``"Survival_in_days"``.
    X_TIME : bool
        When True the time is prepended as the first covariate column.

    Returns
    -------
    (t_out, delta_out, pat_x_out) : the stacked times, event indicators and
    covariate rows.
    """
    obs_times = data_y["Survival_in_days"]
    grid = np.unique(obs_times)
    # position of each distinct time within the sorted grid
    position = {time_val: pos for pos, time_val in enumerate(grid)}
    events = np.array(data_y["Status"], dtype="int")

    time_parts, x_parts, delta_parts = [], [], []
    for row, time_val in enumerate(obs_times):
        n_steps = position[time_val] + 1

        # times at risk for this subject
        time_parts.append(grid[:n_steps])

        # covariates repeated once per at-risk time
        x_parts.append(np.tile(data_x_n.iloc[row].to_numpy(), (n_steps, 1)))

        # event flag only on the final interval
        flags = np.zeros(shape=n_steps, dtype=int)
        flags[-1] = events[row]
        delta_parts.append(flags)

    t_out = np.concatenate(time_parts)
    delta_out = np.concatenate(delta_parts)
    pat_x_out = np.concatenate(x_parts)
    if X_TIME:
        pat_x_out = np.column_stack((t_out, pat_x_out))
    return t_out, delta_out, pat_x_out

def surv_pre_test(data_x_n, data_y, X_TIME=True):
    """Build the long-form prediction grid: every subject at every distinct time.

    Parameters
    ----------
    data_x_n : DataFrame
        Covariates, one row per subject (accessed positionally with ``iloc``).
    data_y : structured array
        Must expose the field ``"Survival_in_days"``.
    X_TIME : bool
        When True the time is prepended as the first covariate column.

    Returns
    -------
    (t_out, pat_x_out) : the stacked times and covariate rows.
    """
    obs_times = data_y["Survival_in_days"]
    grid = np.unique(obs_times)
    n_times = grid.shape[0]

    time_parts = []
    x_parts = []
    for row, _ in enumerate(obs_times):
        # the full time grid for every subject
        time_parts.append(grid)
        # covariates repeated once per grid time
        x_parts.append(np.tile(data_x_n.iloc[row].to_numpy(), (n_times, 1)))

    t_out = np.concatenate(time_parts)
    pat_x_out = np.concatenate(x_parts)
    if X_TIME:
        pat_x_out = np.column_stack((t_out, pat_x_out))
    return t_out, pat_x_out

def get_bart_test(x_out, T):
    """Build the full prediction grid of time x binary covariate combinations.

    Parameters
    ----------
    x_out : ndarray, shape (k, p)
        Only the column count ``p`` (number of covariates) is used.
    T : int
        Last time point; times ``0..T`` inclusive are generated.

    Returns
    -------
    ndarray of int, shape ((T + 1) * 2**p, p + 1)
        Column 0 is the time, columns 1..p the covariates. Rows follow
        ``np.meshgrid``'s default 'xy' ordering.

    Notes
    -----
    Covariates are assumed to be binary (levels 0 and 1), matching the
    Bernoulli covariates simulated in this notebook. The earlier version used
    ``np.arange(p)`` as the covariate levels and hard-coded two covariate axes,
    so it was only correct for ``p == 2``; for that case the output here is
    identical.
    """
    n_covariates = x_out.shape[1]
    time_grid = np.arange(T + 1)
    levels = np.arange(2)  # binary covariate levels

    mesh = np.meshgrid(time_grid, *([levels] * n_covariates))
    return np.stack(mesh, -1).reshape(-1, n_covariates + 1)


In [None]:
# Set experiment
    # - each simulation parms is a new experiment

# Simulation loop
    # Create run
        # Simulate data
        # log param alpha, lambda
        # log param N
        # log param T (number of time points at which probabilities are generated)
        # log param x_info
        # log param censoring percent (calculated)
        # log param event percent (calculated)
        # log param t_event mean
        # log param t_event max
        # log artifact train dataset
        # log artifact plot curves

        # model cph
        # log metric coef
        # log metric exp(coef)
        # log artifact plot curves
        # log model cph
        
        # model rsf
        # log artifact plot curves
        # log model rsf

        # transform data to long form
        # model bart
        # transform to survival
        # log artifact plot curves
        # log model bart

        # get metrics rmse, bias
        # log metric cph_rmse
        # log metric cph_bias
        # log metric rsf_rmse
        # log metric rsf_bias
        # log metric bart_rmse
        # log metric bart_bias
    
    # End run





In [None]:
# Set experiment
#   - each set of simulation params is a new experiment
# set_experiment() creates the experiment when it does not exist yet AND makes
# it the active one. create_experiment() only registers it (the run below
# would be filed under the default experiment) and raises when the notebook is
# re-run and the name already exists.
ml.set_experiment("test_sim")

# Simulation loop

# Create run
with ml.start_run() as run:
    OUTPUTS = "output"
    ALPHA = 3
    LAMBDA = "np.exp(2 + 0.3*(x_mat[:,0] + x_mat[:,1]))"
    N = 100
    T = 30
    X_VARS = 2
    CENS_SCALE = 60

    # every artifact of this run is written below OUTPUTS
    Path(OUTPUTS).mkdir(parents=True, exist_ok=True)

    # Simulate data
    sv_mat, x_mat, lmbda, a, tlat, cens, t_event, status = sim_surv(N=N,
                    T=T,
                    x_vars=X_VARS,
                    a = ALPHA,
                    lambda_f = LAMBDA,
                    cens_scale=CENS_SCALE,
                    err_ind = False)

    # log param alpha
    ml.log_param("alpha", ALPHA)
    # log param lambda
    ml.log_param("lambda", LAMBDA)
    # log param N
    ml.log_param("N", N)
    # log param T (# timepoint probabilities generated)
    ml.log_param("T", T)
    # log param X_VARS
    ml.log_param("X_VARS", X_VARS)
    # log param CENS_SCALE
    ml.log_param("CENS_SCALE", CENS_SCALE)

    # log param x_info (distinct covariate patterns with their counts)
    x_out, x_idx, x_cnt = get_x_info(x_mat)
    ml.log_param("X_INFO", str(list(zip(x_out, x_cnt))))

    # log metric event / censoring fraction
    event_calc, cens_calc = get_status_perc(status)
    ml.log_metric("EVENT_PERC", event_calc)
    ml.log_metric("CENS_PERC", cens_calc)

    # log metric t_event mean / max
    t_mean, t_max = get_event_time_metric(t_event)
    ml.log_metric("T_EVENT_MEAN", t_mean)
    ml.log_metric("T_EVENT_MAX", t_max)

    # log artifact train dataset
    # log_artifact() takes a path to an existing local file, so the frame is
    # written to disk first.
    train = get_train_matrix(x_mat, t_event, status)
    train_path = Path(OUTPUTS) / "train.csv"
    train.to_csv(train_path, index=False)
    ml.log_artifact(str(train_path))

    # log artifact plot curves
    title = "actual_survival"
    plot_sv(x_mat, sv_mat, T, title=title, save=True, dir=OUTPUTS)
    ml.log_artifact(f"{OUTPUTS}/{title}.png")

    # get sklearn components
    y_sk = get_y_sklearn(status, t_event)
    x_sk = train.iloc[:,2:]

    # model cph
    cph = sksurv.linear_model.CoxPHSurvivalAnalysis()
    cph.fit(x_sk, y_sk)
    # log metric coef and exp(coef)
    # log_metric() stores one scalar per key, so each coefficient gets its own key.
    for j, coef in enumerate(cph.coef_):
        ml.log_metric(f"cph_coef_{j}", float(coef))
        ml.log_metric(f"cph_exp_coef_{j}", float(np.exp(coef)))
    # predict cph
    # NOTE(review): the returned step functions may reject times outside the
    # range seen in training - confirm np.arange(T) stays inside it.
    cph_surv = cph.predict_survival_function(pd.DataFrame(x_out))
    # plot_sv needs an array (it reads .shape), not a list of arrays
    cph_sv_val = np.array([sf(np.arange(T)) for sf in cph_surv])
    # log artifact plot curves
    title = "cph_surv_pred"
    plot_sv(x_mat, cph_sv_val, T, title=title, save=True, dir=OUTPUTS)
    ml.log_artifact(f"{OUTPUTS}/{title}.png")
    # log model cph
    # TODO: not implemented yet

    # model rsf
    rsf = sksurv.ensemble.RandomSurvivalForest(
        n_estimators=1000, min_samples_split=10, min_samples_leaf=15, n_jobs=-1, random_state=20
    )
    rsf.fit(x_sk, y_sk)
    # predict rsf
    rsf_surv = rsf.predict_survival_function(pd.DataFrame(x_out))
    rsf_sv_val = np.array([sf(np.arange(T)) for sf in rsf_surv])
    # log artifact plot curves
    title = "rsf_surv_pred"
    plot_sv(x_mat, rsf_sv_val, T, title=title, save=True, dir=OUTPUTS)
    ml.log_artifact(f"{OUTPUTS}/{title}.png")
    # log model rsf
    # TODO: not implemented yet

    # transform data to long form
    b_tr_t, b_tr_delta, b_tr_x = surv_pre_train(x_sk, y_sk)
    # b_te_t, b_te_x = surv_pre_test(x_sk, y_sk)
    b_te_x = get_bart_test(x_out, T)
    # probit offset: centres the latent BART function on the overall event rate
    off = sp.norm.ppf(np.mean(b_tr_delta))
    # model bart
    with pm.Model() as bart:
        x_data = pm.MutableData("x", b_tr_x)
        f = pmb.BART("f", X=x_data, y=b_tr_delta, m=50)
        z = pm.Deterministic("z", f + off)
        mu = pm.Deterministic("mu", pm.math.invprobit(z))
        # The likelihood needs its probability and the observed indicators;
        # shape follows mu so the data container can be swapped for prediction.
        y_pred = pm.Bernoulli("y_pred", p=mu, observed=b_tr_delta, shape=mu.shape)

    # TODO: sample the model, then
    # transform to survival
    # log artifact plot curves
    # log model bart

    # get metrics rmse, bias
    # log metric cph_rmse
    # log metric cph_bias
    # log metric rsf_rmse
    # log metric rsf_bias
    # log metric bart_rmse
    # log metric bart_bias

# End run (defaults when using with/ block)

