# Survival Analyses 
Comparison of scikit-survival and pymc-BART

Objectives:
- Compare models: scikit-survival Cox proportional hazards (CPH), random survival forest, and pymc-BART
- Use the same validation metrics for each model
- Test on multiple datasets
- Extend to R-BART


## Imports

In [0]:
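import sksurv as sks
import sksurv.preprocessing
import sksurv.metrics
import sksurv.datasets
import sksurv.linear_model   # needed for sks.linear_model.CoxPHSurvivalAnalysis
import sksurv.nonparametric  # needed for the IPCW and Kaplan-Meier helpers
import sksurv.functions      # needed for sks.functions.StepFunction

from pathlib import Path
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pymc_bart as pmb
import importlib
import sklearn as skl
import scipy.stats as sp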

In [0]:
from sklearn import set_config
set_config(display="text")

## Dataset

In [0]:
# get lung dataset
data_x, data_y = sks.datasets.load_veterans_lung_cancer()
print(data_x)
print(data_y)


In [0]:
# onehotencode the data_x
data_x_n = sks.preprocessing.OneHotEncoder().fit_transform(data_x)
data_x_n.head()

## COX Model

In [0]:
# load and fit cox model
cph = sks.linear_model.CoxPHSurvivalAnalysis()
cph.fit(data_x_n, data_y)

In [0]:
# show the coef
pd.Series(cph.coef_, index = data_x_n.columns)

In [0]:
# create predictions for synthetic patients over combinations of cell type (squamous, smallcell) and treatment (standard, test)
# age 65, Karnofsky score 60, 1 month from diagnosis and no prior therapy are held constant
x_new = pd.DataFrame.from_dict(
    {
        1: [65,0,0,1,60,1,0,1],
        2: [65,0,0,1,60,1,0,0],
        3: [65,0,1,0,60,1,0,0],
        4: [65,0,1,0,60,1,0,1]
    },
    columns=data_x_n.columns,
    orient = "index"
)
x_new

In [0]:
# get predictions
# predictions are returned as an array of StepFunction objects (one per synthetic patient)
pred_surv = cph.predict_survival_function(x_new)
pred_surv

In [0]:
# plot 
time_points = np.arange(1,1000)
for i, surv_func in enumerate(pred_surv):
    plt.step(time_points, surv_func(time_points), where="post", label=f"Sample {i + 1}")
plt.ylabel("est. Surv")
plt.xlabel("time")
plt.legend(loc="best")

NOTE: these curves cover only the chosen combinations of cell type and treatment, conditional on the fixed values of the other covariates.

In [0]:
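# predicted survival probabilities for every patient at each unique observed time
cph_surv_arr = cph.predict_survival_function(data_x_n, return_array=True)
cph_surv_mean = np.mean(cph_surv_arr, axis=0)

# risk scores (linear predictor): one value per patient, higher = higher risk
cph_risk = cph.predict(data_x_n)
cph_surv_mean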

In [0]:
plt.step(np.unique(data_y["Survival_in_days"]), cph_surv_mean)

## Metrics
- Harrell's C-index
    - biased upwards as the amount of censoring increases
- Uno's C-index
    - uses IPCW to reduce the upward bias under heavier censoring
    - when both the dataset and the amount of censoring are large, the IPCW concordance index is more stable
- Time-dependent ROC
    - ROC compares the false positive rate (1 - specificity) against the true positive rate (sensitivity)
    - considers cumulative cases and dynamic controls at time t
        - cumulative cases are individuals who experience the event prior to or at time t (t_i <= t)
        - dynamic controls are individuals with t_i > t
    - determines how well a model can distinguish patients who fail by time t (t_i <= t) from patients who fail after time t (t_i > t)
    - useful if one is predicting the occurrence of an event in a period up to time t, rather than at one specific time point t
        e.g.
        ```
        t ----t1----t2----t3-----t4   
        1 --e--|-----|-----|------| 
        2 -----|-e---|-----|------|
        3 -----|-e---|-----|------| 
        4 -----|-----|-----|---e--|
        ```
        at t2, patients 1, 2, 3 are cumulative cases and 4 is a dynamic control (e marks the event)
    
    - requires that the test data time range lies within the observed range of the training data
    - the last evaluation time t must still allow some patients to be censored after it, i.e. P(C > t) > 0
    - see the custom AUC function for more details
- Brier Score (Time-Dependent)
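
For reference, the discrimination metrics listed above can be written out as follows, in our own notation: $y_i$ is patient $i$'s observed time, $\delta_i$ the event indicator, $\eta_i$ the predicted risk score, $\hat G$ the Kaplan-Meier estimate of the censoring survival function $P(C > t)$, and $\tau$ a truncation time. These are the textbook forms as we understand them from the scikit-survival documentation, not a transcription of its source; the AUC expression is the one the custom AUC function below implements (tied scores count as concordant there).

$$
C_{\text{Harrell}} = \frac{\sum_{i \neq j} \delta_i \, I(y_i < y_j)\,\big[I(\eta_i > \eta_j) + \tfrac{1}{2} I(\eta_i = \eta_j)\big]}{\sum_{i \neq j} \delta_i \, I(y_i < y_j)}
$$

$$
C_{\text{Uno}} = \frac{\sum_{i \neq j} \delta_i \, \hat G(y_i)^{-2}\, I(y_i < y_j,\ y_i < \tau)\, I(\eta_i > \eta_j)}{\sum_{i \neq j} \delta_i \, \hat G(y_i)^{-2}\, I(y_i < y_j,\ y_i < \tau)}
$$

$$
\widehat{\mathrm{AUC}}(t) = \frac{\sum_{i=1}^{n}\sum_{j=1}^{n} I(y_j > t)\, I(y_i \le t)\, \omega_i\, I(\eta_j \le \eta_i)}{\Big(\sum_{j=1}^{n} I(y_j > t)\Big)\Big(\sum_{i=1}^{n} I(y_i \le t)\, \omega_i\Big)}, \qquad \omega_i = \frac{\delta_i}{\hat G(y_i)}
$$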
         

### AUC
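sksurv AUC

- `cumulative_dynamic_auc` accepts either one risk score per patient or a matrix of time-dependent risk scores (one column per evaluation time); the time-dependent risk scores still need to be implemented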

In [0]:

times = np.percentile(data_y["Survival_in_days"], np.linspace(5,81,15))
scores = np.array(data_x_n["Age_in_years"])
cph_auc, cph_auc_mean = sks.metrics.cumulative_dynamic_auc(data_y, data_y, cph_risk, times)

print(cph_auc)
plt.plot(times, cph_auc)
plt.axhline(cph_auc_mean, linestyle="--")

# data_x_n

### IPCW
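IPCW (inverse probability of censoring weighting) gives each patient with an observed event the weight 1/G(t_i), where G is the Kaplan-Meier estimate of the *censoring* survival function (censorings are treated as the events); censored patients get weight 0.

It is easy to compute: fit the reverse Kaplan-Meier estimator, get the censoring survival estimates, replace any 0 with infinity (so the weight becomes 0 instead of a division by zero), and finally invert with 1/surv.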



- Required for the custom AUC; the sksurv AUC computes it internally

- There are two main ways to get the IPCW in sksurv
    - directly with `ipc_weights`
    - through the class `CensoringDistributionEstimator`

- `ipc_weights` is a little more direct and lightweight

- A custom implementation might be useful for further understanding
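
As a rough illustration of the custom route, the reverse Kaplan-Meier estimate and the resulting weights can be sketched in plain numpy. The helper names `censoring_km_np` and `ipcw_weights_np` are hypothetical, and the tie convention (observed events leave the risk set before censorings at the same time) is an assumption about how sksurv handles ties rather than something verified here.

```python
import numpy as np


def censoring_km_np(event, time):
    """Reverse Kaplan-Meier estimate of the censoring survival G(t) = P(C > t).

    Censorings play the role of events. At tied times, observed events are
    assumed to leave the risk set before censorings (an assumption).
    Returns the sorted unique observed times and G at each of them.
    """
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    uniq = np.unique(time)
    n_at_risk = np.array([(time >= u).sum() for u in uniq])
    n_events = np.array([((time == u) & event).sum() for u in uniq])
    n_cens = np.array([((time == u) & ~event).sum() for u in uniq])
    remaining = n_at_risk - n_events
    # if nobody remains after the events, no censoring can occur: factor of 1
    factor = np.where(remaining > 0, 1.0 - n_cens / np.maximum(remaining, 1), 1.0)
    return uniq, np.cumprod(factor)


def ipcw_weights_np(event, time):
    """IPCW per patient: 1 / G(t_i) for observed events, 0 for censored patients."""
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    uniq, g = censoring_km_np(event, time)
    g_at_t = g[np.searchsorted(uniq, time)]        # G at each patient's own time
    g_at_t = np.where(g_at_t > 0, g_at_t, np.inf)  # replace 0 with inf so 1/G -> 0
    return np.where(event, 1.0 / g_at_t, 0.0)


# toy example: 5 patients, censored at t=2 and t=4
w = ipcw_weights_np([1, 0, 1, 1, 0], [1, 2, 2, 3, 4])
print(w)
```

If that tie convention matches sksurv's, `ipcw_weights_np(data_y["Status"], data_y["Survival_in_days"])` should agree with `sks.nonparametric.ipc_weights` in the next cell.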




In [0]:
import sksurv.nonparametric

# direct
ipcw_direct = sks.nonparametric.ipc_weights(data_y["Status"], data_y["Survival_in_days"])
print(ipcw_direct)

# class
cde = sks.nonparametric.CensoringDistributionEstimator()
cde.fit(data_y)
ipcw2 = cde.predict_ipcw(data_y)
print(ipcw2)

### AUC Custom
- Requires
    - IPCW of the training data
    - time points to evaluate (must not extend to the maximum censoring time)
    - a risk score: sksurv accepts either one score per patient or time-dependent scores, but the custom function below expects one score per patient
    - the original observed times
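
A vectorized sketch of the same estimator, assuming a single risk score per patient, replaces the double loop over patients with a pairwise comparison matrix. `cumulative_dynamic_auc_np` is a hypothetical helper. One deliberate difference from `get_auc` below: tied risk scores count as one half rather than as concordant, which is our understanding of how sksurv's trapezoidal ROC integration treats ties.

```python
import numpy as np


def cumulative_dynamic_auc_np(time, risk, weights, eval_times):
    """IPCW cumulative/dynamic AUC(t) for one time-independent risk score per patient.

    time: observed (event or censoring) time per patient
    risk: risk score per patient, higher = higher risk
    weights: IPCW per patient (0 for censored patients, so they never count as cases)
    eval_times: time points t, assumed to leave at least one case and one control
    """
    time = np.asarray(time, dtype=float)
    risk = np.asarray(risk, dtype=float)
    weights = np.asarray(weights, dtype=float)
    # score[i, j]: 1 if patient i ranks above patient j, 0.5 on ties
    score = (risk[:, None] > risk[None, :]) + 0.5 * (risk[:, None] == risk[None, :])
    auc = np.empty(len(eval_times))
    for k, t in enumerate(eval_times):
        case_w = weights * (time <= t)      # weighted cumulative cases
        control = (time > t).astype(float)  # dynamic controls
        auc[k] = (case_w @ score @ control) / (case_w.sum() * control.sum())
    return auc


toy_time = [1, 2, 3, 4]
toy_weights = [1.0, 1.0, 1.0, 0.0]  # patient 4 treated as censored
print(cumulative_dynamic_auc_np(toy_time, [4, 3, 2, 1], toy_weights, [1.5, 2.5]))
```

With one risk score per patient (e.g. `cph.predict(data_x_n)`) and the `ipcw` and `times_cph` arrays defined further down, the result would be expected to match `get_auc` whenever no case and control share the same score.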

In [0]:


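def get_auc(event, time, hz, w):
    """IPCW cumulative/dynamic AUC at each time point.

    event: observed (event or censoring) time per patient
    time: evaluation time points
    hz: one risk score per patient (higher = higher risk)
    w: IPCW per patient (0 for censored patients)
    """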
    t_out = np.zeros(shape=time.shape)
    for ti in range(0,len(time)):
        #set time
        t = time[ti]
        # print("t", t)

        # get numerator
        auc_num = 0
        pat_rng = range(0, len(hz))
        # sum over patients: sum_rank_ipcw
        # where sum_rank_ipcw: rank_indicator * ipcw weight
        # where rank_indicator: 1 if the other patient has event after time, current patient has event before/at time and other patients hazard is < current patients
            # eg. if time is T = 10, current patient event is 5 and hz is 3, and other patients events are 15,20,1; hz 0,1,4
            # then for current patient iterations we would see
                # iter | ot_pat > t | cur_pat <= t | hz(ot_pat) <= hz(cur_pat) | indicator | weight | out | cumsum
                # 1    | 1          | 1            | 1                         | 1         | 1.3    | 1.3 |  1.3
                # 2    | 1          | 1            | 1                         | 1         | 1.3    | 1.3 |  2.6
                # 3    | 0          | 1            | 0                         | 0         | 1.3    | 0   |  2.6 ***
        # The cumsum for each patient will be added together to form the numerator
        for i in pat_rng:
            auc_num1 = 0
            for j in pat_rng:
                # 1 if iter pat time > t
                i1 = 1 if event[j] > t else 0
                # 1 if hold pat time <= t
                i2 = 1 if event[i] <= t else 0
                # 1 if hz iter pat <= hz hold pat
                i3 = 1 if hz[j] <= hz[i] else 0
                # if iter pat is > t, hold pat <= t and iter pat hz <= hz hold pat then all are 1 and you adjust by the ipcw 
                auc_num1 += i1 * i2 * w[i] * i3 
            # print("auc_num1", auc_num1)
            auc_num += auc_num1

        # denominator
        # the denominator is sum_j I(t_j > t) * sum_i I(t_i <= t) * w_i
        auc_denom = 0
        den_i1 = 0
        for i in pat_rng:
            i1 = 1 if event[i] > t else 0
            den_i1 += i1

        den_i2 = 0
        for i in pat_rng:
            i1 = 1 if event[i] <= t else 0
            # added in w to the equation
            den_i2 += i1 * w[i]

        # eval auc
        # AUC at this time point: weighted count of concordant case-control pairs over the total weighted number of case-control pairs
        auc = auc_num / (den_i1 * den_i2)
        t_out[ti] = auc

    return t_out



In [0]:
# Implementation
# Requires
    # ipcw of training data
    # time range to calculate on (must not extend to the max censoring time)
    # a risk score, one value per patient
    # the original times data

# get ipcw
ipcw = sks.nonparametric.ipc_weights(data_y["Status"], data_y["Survival_in_days"])

# get times to quantify
times_cph = np.percentile(data_y["Survival_in_days"], np.linspace(5,81, 15))

# get auc
auc_cust = get_auc(event = data_y["Survival_in_days"], time = times_cph, hz = cph_risk, w = ipcw)

In [0]:
# plot auc
plt.plot(times_cph, auc_cust)

In [0]:
# compare the cust auc and the sksurv auc
print(auc_cust)
print(cph_auc)
# Looks like it works correctly!

### C-Index

It would be worth working out these details in a custom function
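
As a starting point for such a custom function, Harrell's C-index can be sketched with pairwise comparisons in numpy. `harrell_c_np` is a hypothetical helper; the comparability rule (an event is also comparable to a patient censored at the same time) and the half credit for tied scores are our reading of how sksurv counts pairs, so small differences from `concordance_index_censored` are possible (sksurv also uses a tolerance when detecting tied scores).

```python
import numpy as np


def harrell_c_np(event, time, risk):
    """Harrell's concordance index; tied risk scores count as one half.

    Pair (i, j) is comparable when patient i has an observed event and either
    time[i] < time[j], or the times are equal and patient j is censored.
    It is concordant when risk[i] > risk[j].
    """
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    risk = np.asarray(risk, dtype=float)
    earlier = time[:, None] < time[None, :]
    tied_vs_censored = (time[:, None] == time[None, :]) & ~event[None, :]
    comparable = event[:, None] & (earlier | tied_vs_censored)
    score = (risk[:, None] > risk[None, :]) + 0.5 * (risk[:, None] == risk[None, :])
    return float((score * comparable).sum() / comparable.sum())


print(harrell_c_np([1, 1, 0, 1], [1, 2, 3, 4], [4, 3, 2, 1]))
```

`harrell_c_np(data_y["Status"], data_y["Survival_in_days"], cph.predict(data_x_n))` can then be compared against `h_cindex_cph[0]` below. The pairwise matrices use O(n^2) memory, which is fine for the 137 patients here.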

In [0]:
# get c-index of the trained model on the training dataset

# prediction on train
pred_cph = cph.predict(data_x_n)

# Harrell's c-index
h_cindex_cph = sks.metrics.concordance_index_censored(data_y["Status"], data_y["Survival_in_days"], pred_cph)
print("Harrell's c-index", h_cindex_cph[0])

# Uno's ipcw c-index
ipcw_cindex_cph = sks.metrics.concordance_index_ipcw(data_y, data_y, pred_cph)
print("ipcw c-index", ipcw_cindex_cph[0])


In [0]:
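# # c-index custom (adapted from sksurv's internal concordance estimator)
# # note: _iter_comparable is a private helper in sksurv.metrics
# def c_index(event_indicator, event_time, estimate, weights, tied_tol=1e-8):
#     order = np.argsort(event_time)

#     concordant = 0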
#     discordant = 0
#     tied_risk = 0
#     numerator = 0.0
#     denominator = 0.0
#     for ind, mask, tied_time in _iter_comparable(event_indicator, event_time, order):
#         est_i = estimate[order[ind]]
#         event_i = event_indicator[order[ind]]
#         w_i = weights[order[ind]]

#         est = estimate[order[mask]]

#         assert event_i, f"got censored sample at index {order[ind]}, but expected uncensored"

#         ties = np.absolute(est - est_i) <= tied_tol
#         n_ties = ties.sum()
#         # an event should have a higher score
#         con = est < est_i
#         n_con = con[~ties].sum()

#         numerator += w_i * n_con + 0.5 * w_i * n_ties
#         denominator += w_i * mask.sum()

#         tied_risk += n_ties
#         concordant += n_con
#         discordant += est.size - n_con - n_ties

#     if tied_time is None:
#         raise NoComparablePairException("Data has no comparable pairs, cannot estimate concordance index.")

#     cindex = numerator / denominator
#     return cindex, concordant, discordant, tied_risk, tied_time


### Brier Score (time-dependent)

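- Mean squared error between the observed status at time t (1 if still event-free, 0 if the event has occurred) and the predicted survival probability, adapted to right-censored data
- Depends on the predicted survival probabilities and on the IPCW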
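
Written out, the time-dependent Brier score with IPCW and its integrated version look as follows, where $y_i$ is the observed time, $\delta_i$ the event indicator, $\hat S(t \mid x_i)$ the predicted survival probability, $\hat G$ the Kaplan-Meier estimate of the censoring survival function, and $t_1, \dots, t_K$ the evaluation times. This is our transcription of what `get_brier` below computes; the code approximates the integral with the trapezoidal rule.

$$
\mathrm{BS}(t) = \frac{1}{n} \sum_{i=1}^{n} \left[ \frac{I(y_i \le t,\ \delta_i = 1)\,\big(0 - \hat S(t \mid x_i)\big)^2}{\hat G(y_i)} + \frac{I(y_i > t)\,\big(1 - \hat S(t \mid x_i)\big)^2}{\hat G(t)} \right]
$$

$$
\mathrm{IBS} = \frac{1}{t_K - t_1} \int_{t_1}^{t_K} \mathrm{BS}(t)\, dt
$$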

In [0]:
# Brier score

# get inputs
y_times = data_y["Survival_in_days"]
times = np.percentile(y_times, np.linspace(10,90,10))
print(times)

cph_surv = cph.predict_survival_function(data_x_n)
cph_surv1 = np.stack([fn(times) for fn in cph_surv])

sks_cph_brier = sks.metrics.brier_score(data_y, data_y, cph_surv1, times)
sks_cph_brier


### Custom Brier

In [0]:
# custom brier

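def get_brier(y_time, delta, times, surv, ipcw, kipcw, ig = False, ar_out = False):
    """Time-dependent Brier score, optionally integrated over `times`.

    y_time: observed time per patient; delta: event indicator per patient
    times: evaluation time points
    surv: per-patient survival functions, callable as surv[i](t)
    ipcw: censoring survival G(y_i) at each patient's observed time (0 replaced by inf)
    kipcw: censoring survival G(t) at each evaluation time (0 replaced by inf)
    ig: if True, return the integrated Brier score; ar_out: also return the per-time scores
    """
    t_out = np.zeros(times.shape)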
    n = len(y_time)

    for ti, t in enumerate(times):
        # t = times[ti]
        score = 0
        for i in range(0, len(y_time)):
            sv = surv[i](t)
            
            i1 = 1 if y_time[i] <= t and delta[i] == 1 else 0 
            surv_ipcw1 = pow((0 - sv), 2)/ipcw[i] 

            i2 = 1 if y_time[i] > t else 0
            surv_ipcw2 = pow((1 - sv), 2)/kipcw[ti]

            score += (i1 * surv_ipcw1) + (i2 * surv_ipcw2)
        t_out[ti] = score/n
    
    if not ig:
        return t_out
    else:
        ig_brier = np.trapz(t_out, times)/(times[-1]-times[0])
        if ar_out:
            return ig_brier, t_out
        else:
            return ig_brier


In [0]:
# get inputs
y_times = data_y["Survival_in_days"]
delta = data_y["Status"]
cph_surv = cph.predict_survival_function(data_x_n)

# get times
times = np.percentile(y_times, np.linspace(10,90,10))

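# fit the censoring distribution G(t) = P(C > t) (reverse Kaplan-Meier)
cde = sks.nonparametric.CensoringDistributionEstimator()
cde.fit(data_y)

# G at each patient's observed time; the Brier score divides by these values,
# so replacing 0 with inf makes the corresponding term vanish
# ipcw1 = cde.predict_ipcw(data_y)
ipcw1 = cde.predict_proba(data_y["Survival_in_days"])
ipcw1[ipcw1 == 0] = np.inf


# G at each evaluation time
ipcw2 = cde.predict_proba(times)
ipcw2[ipcw2 == 0] = np.inf
kipcw = ipcw2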

cus_cph_brier = get_brier(y_times, delta, times, cph_surv, ipcw1, kipcw)

print(times)
print(cus_cph_brier)
print(sks_cph_brier[0])
print(sks_cph_brier[1])



In [0]:
ig1 = get_brier(y_times, delta, times, cph_surv, ipcw1, kipcw, ig=True)

ig2 = sks.metrics.integrated_brier_score(data_y, data_y, cph_surv1, times)
print(ig1, ig2)

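#### Using the Kaplan-Meier estimator as baseline
The KM estimate has to be wrapped in a step function (one shared curve for all patients), as shown below.

The integrated Brier score (`ig_brier`) of this baseline is a bit higher on the in-sample scoring.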
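
For intuition, the baseline boils down to one Kaplan-Meier curve shared by all patients and evaluated as a right-continuous step function. The numpy sketch below assumes `kaplan_meier_np` and `step_eval` (hypothetical helpers) behave like `sks.nonparametric.kaplan_meier_estimator` and `StepFunction`; returning 1.0 before the first observed time is our own choice.

```python
import numpy as np


def kaplan_meier_np(event, time):
    """Kaplan-Meier survival estimate at each unique observed time."""
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    uniq = np.unique(time)
    n_at_risk = np.array([(time >= u).sum() for u in uniq])
    n_events = np.array([((time == u) & event).sum() for u in uniq])
    return uniq, np.cumprod(1.0 - n_events / n_at_risk)


def step_eval(x, y, t):
    """Evaluate the right-continuous step function through (x, y) at t; 1.0 before x[0]."""
    idx = np.searchsorted(x, t, side="right") - 1
    return np.where(idx >= 0, y[np.clip(idx, 0, None)], 1.0)


km_t, km_s = kaplan_meier_np([1, 0, 1, 1], [1, 2, 3, 3])
# every patient shares this curve, so surv[i](t) reduces to step_eval(km_t, km_s, t)
print(step_eval(km_t, km_s, np.array([0.5, 1.0, 2.5, 5.0])))
```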

In [0]:
# KM survival curve wrapped as a StepFunction; every patient gets the same curve
km_surv = sks.functions.StepFunction(*sks.nonparametric.kaplan_meier_estimator(delta, y_times, conf_type=None))
km_surv = [km_surv] * len(y_times)

get_brier(y_times, delta, times, km_surv, ipcw1, kipcw, ig=True, ar_out=True)

The custom Brier score implementation was successful.

## BART-pymc
Run BART on the same dataset used for the Cox model.

In [0]:
data_x, data_y = sks.datasets.load_veterans_lung_cancer()
# print(data_x)
# print(data_y)
data_x_n = sks.preprocessing.OneHotEncoder().fit_transform(data_x)

# full_dtst = pd.concat([pd.DataFrame(data_y), data_x_n], axis=1) 
# full_dtst.to_csv("vet_lung_canc.csv")

### Preprocessing
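The data needs to be in a long (person-period) format: one row per patient per unique observed time up to the patient's own time, with a binary outcome that is 1 only on the last row of a patient who had the event.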
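
The expansion that `surv_pre_train2` performs with loops can be sketched in vectorized numpy (no day-0 padding, time prepended as the first column as with `X_TIME=True`). `to_long_format` is a hypothetical helper, not part of any library, and assumes `x` is a purely numeric array such as `data_x_n.to_numpy()`.

```python
import numpy as np


def to_long_format(x, time, event):
    """Expand per-patient data into discrete-time person-period rows.

    Patient i gets one row per unique observed time <= time[i]. The binary
    outcome is 0 on every row except the patient's last, which holds event[i].
    Returns (t_long, y_long, x_long); x_long has the time prepended as column 0.
    """
    x = np.asarray(x, dtype=float)
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=int)
    grid = np.unique(time)
    n_rows = np.searchsorted(grid, time, side="right")  # rows per patient
    patient = np.repeat(np.arange(len(time)), n_rows)  # patient index of each row
    t_long = np.concatenate([grid[:k] for k in n_rows])
    y_long = np.zeros(len(patient), dtype=int)
    y_long[np.cumsum(n_rows) - 1] = event  # last row of each patient
    x_long = np.column_stack([t_long, x[patient]])
    return t_long, y_long, x_long


t_long, y_long, x_long = to_long_format([[10.0], [20.0]], [2, 1], [1, 0])
print(x_long)
```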

In [0]:
def surv_pre_train(data_x_n, data_y, X_TIME=True):
    # set up times
    t_sort = np.append([0], np.unique(data_y["Survival_in_days"]))
    t_ind = np.arange(0,t_sort.shape[0])
    t_dict = dict(zip(t_sort, t_ind))

    # set up delta
    delta = np.array(data_y["Status"], dtype = "int")
    
    t_out = []
    pat_x_out = []
    delta_out = []
    for idx, t in enumerate(data_y["Survival_in_days"]):
        # get the pat_time and use to get the array of times for the patient
        p_t_ind = t_dict[t]
        p_t_set = t_sort[0:p_t_ind+1]
        t_out.append(p_t_set)
        
        size = p_t_set.shape[0]
        # get patient array
        pat_x = np.tile(data_x_n.iloc[idx].to_numpy(), (size, 1))
        pat_x_out.append(pat_x)

        # get delta
        pat_delta = delta[idx]
        delta_set = np.zeros(shape=size, dtype=int)
        delta_set[-1] = pat_delta
        delta_out.append(delta_set)
    
    
    t_out, delta_out, pat_x_out = np.concatenate(t_out), np.concatenate(delta_out), np.concatenate(pat_x_out)
    if X_TIME:
        pat_x_out = np.array([np.concatenate([np.array([t_out[idx]]), i]) for idx, i in enumerate(pat_x_out)])
    return t_out, delta_out, pat_x_out

def surv_pre_train2(data_x_n, data_y, X_TIME=True):
    # set up times
    # t_sort = np.append([0], np.unique(data_y["Survival_in_days"]))
    t_sort = np.unique(data_y["Survival_in_days"])
    t_ind = np.arange(0,t_sort.shape[0])
    t_dict = dict(zip(t_sort, t_ind))

    # set up delta
    delta = np.array(data_y["Status"], dtype = "int")
    
    t_out = []
    pat_x_out = []
    delta_out = []
    for idx, t in enumerate(data_y["Survival_in_days"]):
        # get the pat_time and use to get the array of times for the patient
        p_t_ind = t_dict[t]
        p_t_set = t_sort[0:p_t_ind+1]
        t_out.append(p_t_set)
        
        size = p_t_set.shape[0]
        # get patient array
        pat_x = np.tile(data_x_n.iloc[idx].to_numpy(), (size, 1))
        pat_x_out.append(pat_x)

        # get delta
        pat_delta = delta[idx]
        delta_set = np.zeros(shape=size, dtype=int)
        delta_set[-1] = pat_delta
        delta_out.append(delta_set)
    
    
    t_out, delta_out, pat_x_out = np.concatenate(t_out), np.concatenate(delta_out), np.concatenate(pat_x_out)
    if X_TIME:
        pat_x_out = np.array([np.concatenate([np.array([t_out[idx]]), i]) for idx, i in enumerate(pat_x_out)])
    return t_out, delta_out, pat_x_out

In [0]:
def surv_pre_test(data_x_n, data_y, X_TIME=True):
    t_sort = np.append([0], np.unique(data_y["Survival_in_days"]))
    t_out = []
    pat_x_out = []
    for idx, t in enumerate(data_y["Survival_in_days"]):
        # get the pat_time and use to get the array of times for the patient
        p_t_set = t_sort
        t_out.append(p_t_set)
        
        size = p_t_set.shape[0]
        # get patient array
        pat_x = np.tile(data_x_n.iloc[idx].to_numpy(), (size, 1))
        pat_x_out.append(pat_x)
    
    t_out, pat_x_out = np.concatenate(t_out),  np.concatenate(pat_x_out)
    if X_TIME:
        pat_x_out = np.array([np.concatenate([np.array([t_out[idx]]), i]) for idx, i in enumerate(pat_x_out)])
    return t_out, pat_x_out
    
def surv_pre_test2(data_x_n, data_y, X_TIME=True):
    # t_sort = np.append([0], np.unique(data_y["Survival_in_days"]))
    t_sort = np.unique(data_y["Survival_in_days"])
    t_out = []
    pat_x_out = []
    for idx, t in enumerate(data_y["Survival_in_days"]):
        # get the pat_time and use to get the array of times for the patient
        p_t_set = t_sort
        t_out.append(p_t_set)
        
        size = p_t_set.shape[0]
        # get patient array
        pat_x = np.tile(data_x_n.iloc[idx].to_numpy(), (size, 1))
        pat_x_out.append(pat_x)
    
    t_out, pat_x_out = np.concatenate(t_out),  np.concatenate(pat_x_out)
    if X_TIME:
        pat_x_out = np.array([np.concatenate([np.array([t_out[idx]]), i]) for idx, i in enumerate(pat_x_out)])
    return t_out, pat_x_out

In [0]:
# get train dataset (with day-0 padding; used by the BART1 model)
t1, delta1, trainx1 = surv_pre_train(data_x_n=data_x_n, data_y=data_y)

# get test dataset (with day-0 padding)
t_test, testx = surv_pre_test(data_x_n=data_x_n, data_y=data_y)

# get train dataset 2
t2, delta2, trainx2 = surv_pre_train2(data_x_n=data_x_n, data_y=data_y)

# get test dataset 2
t_test2, testx2 = surv_pre_test2(data_x_n=data_x_n, data_y=data_y)

In [0]:
# # check
# print(t1.shape)
# print(delta1.shape)
# print(trainx1.shape)
# print(testx.shape)
# print(t_test.shape)

# check pre2
print(t2.shape)
print(delta2.shape)
print(trainx2.shape)
print(testx2.shape)
print(t_test2.shape)

Preprocessing is successful.
It is comparable to `BART::surv.pre.bart`, with the minor difference that the first version (`surv_pre_train`) adds day 0 to the time grid, whereas I believe `BART::surv.pre.bart` handles day 0 inside the model.

## Models

### BART1 Model
With 0 day padding

In [0]:
with pm.Model() as bart1:
    x_data = pm.MutableData("x", trainx1)
    # y_data = pm.MutableData("y", delta1)
    
    z = pmb.BART("z", X = x_data, Y = delta1)
    mu = pm.Deterministic("mu", pm.math.invprobit(z))
    y_pred = pm.Bernoulli("y_pred", p=mu, observed=delta1, shape=x_data.shape[0])
    smp1 = pm.sample(random_seed=2)


In [0]:
with bart1:
    pm.set_data({"x":testx})
    pp1 = pm.sample_posterior_predictive(smp1, var_names= ["y_pred", "z", "mu"])

In [0]:
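# get the predicted mean survival curve
# note: the cumulative product is applied to the posterior-mean hazard, which in general
# differs from the posterior mean of the survival curve (computed per draw further below)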
def get_pred_mean(pp1, t_test):
    # get the pp
    # np.mean(pp1.posterior_predictive["mu"][3][:,1])
    rr1 = pp1.posterior_predictive["mu"].mean(("draw", "chain")).values
    # print(rr1.shape)
    # print(pp1.posterior_predictive["mu"].mean(("chain", "draw")))

    K = np.unique(t_test).shape[0]
    N = int(t_test.shape[0]/K)
    rr2 = np.cumprod(1-rr1.reshape((N,K)), axis=1)
    rr2_mean = np.mean(rr2, axis=0)

    return rr2_mean

In [0]:
bart_1_rr = get_pred_mean(pp1, t_test=t_test)
bart_1_rr

In [0]:
ax = plt.step(np.unique(t_test), bart_1_rr)

### BART2 Model
No 0 day padding

In [0]:
with pm.Model() as bart2:
    x_data = pm.MutableData("x", trainx2)
    # y_data = pm.MutableData("y", delta1)
    
    z = pmb.BART("z", X = x_data, Y = delta2)
    mu = pm.Deterministic("mu", pm.math.invprobit(z))
    y_pred = pm.Bernoulli("y_pred", p=mu, observed=delta2, shape=x_data.shape[0])
    smp2 = pm.sample(random_seed=2)


In [0]:
# Get test results
with bart2:
    pm.set_data({"x":testx2})
    pp2 = pm.sample_posterior_predictive(smp2, var_names= ["y_pred", "z", "mu"])

In [0]:
# posterior mean of z (chain 0, averaged over draws) for the first patient's 101 time points
np.mean(pp2.posterior_predictive["z"].values[0], 0)[0:101]

In [0]:
b2_surv = 1- pp2.posterior_predictive["mu"].values

In [0]:
pp2.posterior_predictive["z"]

In [0]:
KK = 101  # number of unique observed times per patient
NN = 137  # number of patients
# one row per posterior draw (chains x draws), one column per patient-time pair
b2_survm = np.matrix(b2_surv.reshape((-1, NN * KK)))

# cumulative product of (1 - hazard) within each patient's block of KK time points
for h in range(0,NN):
    for j in range(1,KK):
        l = KK*h +j
        b2_survm[:, l] = np.array(b2_survm[:, l-1]) * np.array(b2_survm[:,l])

b2_survm

In [0]:
b2_surv_mean2 = np.mean(b2_survm, axis=0).reshape((NN,KK))
b2_surv_mean3 = np.array(np.mean(b2_surv_mean2, axis=0)).reshape((KK))
b2_surv_mean3

In [0]:
b2_surv_mean2[0]

In [0]:
bart_2_rr = get_pred_mean(pp2, t_test2)
bart_2_rr.shape

In [0]:
plt.step(np.unique(t_test2), bart_2_rr)
plt.step(np.unique(t_test2), b2_surv_mean3)

### BART3 Model
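Add a probit offset (the inverse normal CDF of the overall event rate in the long-format data), so that a BART output of 0 corresponds to the average event probability.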
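
The model uses a probit link on a discrete-time hazard, p = Phi(f(x, t) + offset), and survival is the running product of (1 - p). The sketch below shows both pieces under two assumptions: the offset is Phi^-1 of the overall event rate (what `sp.norm.ppf(np.mean(delta2))` computes), and the hazards are ordered patient by patient with `n_times` points each, as in the long-format test data. `probit_offset` and `survival_from_hazard` are hypothetical helpers, and Python's `statistics.NormalDist` stands in for scipy.

```python
import numpy as np
from statistics import NormalDist


def probit_offset(delta):
    """Offset such that invprobit(0 + offset) equals the overall event rate."""
    return NormalDist().inv_cdf(float(np.mean(delta)))


def survival_from_hazard(p, n_patients, n_times):
    """Convert discrete-time hazards into survival curves.

    p: array (..., n_patients * n_times) of per-interval event probabilities,
       ordered patient by patient.
    Returns S with shape (..., n_patients, n_times), S_k = prod_{l <= k} (1 - p_l).
    """
    p = np.asarray(p, dtype=float)
    hazard = p.reshape(p.shape[:-1] + (n_patients, n_times))
    return np.cumprod(1.0 - hazard, axis=-1)


toy_off = probit_offset([1, 0, 0, 0])
print(toy_off, NormalDist().cdf(toy_off))
print(survival_from_hazard([0.1, 0.5, 0.2, 0.2], n_patients=2, n_times=2))
```

Averaging `survival_from_hazard(pp3.posterior_predictive["mu"].values, 137, 101)` over the chain, draw and patient axes should give the same curve as `get_surv(pp3, KK=101, NN=137)`, without the Python loops.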

In [0]:
import scipy.stats as sp
# probit of the overall event rate (R equivalent: qnorm(mean(delta2)))
off = sp.norm.ppf(np.mean(delta2))
off

In [0]:
with pm.Model() as bart3:
    x_data = pm.MutableData("x", trainx2)
    # y_data = pm.MutableData("y", delta1)
    
    f = pmb.BART("f", X = x_data, Y = delta2, m=100)
    z = pm.Deterministic("z", f + off)

    mu = pm.Deterministic("mu", pm.math.invprobit(z))
    y_pred = pm.Bernoulli("y_pred", p=(mu), observed=delta2, shape=x_data.shape[0])
    smp3 = pm.sample(random_seed=2, draws=100)

In [0]:
smp3

In [0]:
with bart3:
    pm.set_data({"x":testx2})
    pp3 = pm.sample_posterior_predictive(smp3, var_names= ["y_pred", "f", "z", "mu"])

In [0]:
pp3

In [0]:
def get_surv(posterior, KK, NN):
    shp = posterior.posterior_predictive["mu"].shape
    rows = shp[0]*shp[1]
    cols = shp[2]

    surv = 1-posterior.posterior_predictive["mu"].values
    survm = np.matrix(surv.reshape((rows, cols)))
   
    for h in range(0,NN):
        for j in range(1,KK):
            l = KK*h +j
            survm[:, l] = np.array(survm[:, l-1]) * np.array(survm[:,l])

    surv_mean = np.mean(survm, axis=0).reshape((NN,KK))
    surv_mean = np.array(np.mean(surv_mean, axis=0)).reshape(KK)
    return surv_mean


In [0]:
b3_surv = get_surv(posterior=pp3, KK=101, NN=137)
b3_surv

In [0]:
# b3_surv_mean = np.mean(b3_survm, axis=0).reshape((NN,KK))
# b3_surv_mean2 = np.array(np.mean(b3_surv_mean, axis=0)).reshape(KK)
# b3_surv_mean2

In [0]:
plt.step(np.unique(t_test2), b3_surv)

### Model BART 4

In [0]:
# trainx2.shape
# delta2.shape
# times2 = np.unique(trainx2[:,0], return_index=True, return_counts=True)
# freq = times2[2]
# freq

# try conditioning y to the z
# delta2

def get_z(delta):
    # draw one latent probit value per row: z > 0 for events, z < 0 otherwise
    z = np.zeros(shape=delta.shape)
    for idx, i in enumerate(delta):
        if i == 1:
            z[idx] = sp.truncnorm.rvs(0, np.inf)
        else:
            z[idx] = sp.truncnorm.rvs(-np.inf, 0)
    return z



zdelt = get_z(delta2)
# print(zdelt[delta2==1])
# print(delta2[delta2==1])

In [0]:
# try with binomial
with pm.Model() as bart4:
    x_data = pm.MutableData("x", trainx2)
    # y_data = pm.MutableData("y", delta1)
    
    # f = pmb.BART("f", X = x_data, Y = zdelt, m=50)
    f = pmb.BART("f", X = x_data, Y = delta2, m=50)
    z = pm.Deterministic("z", f + off)

    mu = pm.Deterministic("mu", pm.math.invprobit(z))
    # y_pred = pm.Binomial("y_pred", p = mu, observed = delta2)
    y_pred = pm.Bernoulli("y_pred", p=(mu), observed=delta2, shape=x_data.shape[0])
    smp4 = pm.sample(random_seed=2, draws=1000)


In [0]:
smp4

In [0]:
with bart4:
    pm.set_data({"x":testx2})
    pp4 = pm.sample_posterior_predictive(smp4, var_names= ["y_pred", "f", "z", "mu"])

In [0]:
pp4

In [0]:
b4_surv = get_surv(posterior=pp4, KK=101, NN=137)

In [0]:
b4_surv

In [0]:
plt.step(np.unique(t_test2), b3_surv)
plt.step(np.unique(t_test2), b4_surv)

In [0]:
# with pm.Model() as bart4:
#     x_data = pm.MutableData("x", trainx2)
#     # y_data = pm.MutableData("y", delta1)
    
#     z = pmb.BART("z", X = x_data, Y = delta2)

#     mu = pm.Deterministic("mu", pm.math.invprobit(z + off))
#     # y_pred = pm.Bernoulli("y_pred", p=(mu), observed=delta2, shape=x_data.shape[0])
#     smp4 = pm.sample(random_seed=2, draws=100)

## Random Forest

In [0]:
import sksurv.ensemble

In [0]:
# set up random forest
rsf = sks.ensemble.RandomSurvivalForest(
    n_estimators=1000, min_samples_split=10, min_samples_leaf=15, n_jobs=-1, random_state=20
)
rsf.fit(data_x_n, data_y)

In [0]:
rsf.score(data_x_n, data_y)

In [0]:
rsf_surv1 = rsf.predict_survival_function(data_x_n, return_array=True)

In [0]:
rsf_surv_mean = np.mean(rsf_surv1,axis=0)
rsf_surv_mean

In [0]:
plt.step(np.unique(data_y["Survival_in_days"]), rsf_surv_mean, )

## Plot a few different models


In [0]:
kpm = sks.nonparametric.kaplan_meier_estimator(data_y["Status"], data_y["Survival_in_days"])

In [0]:
# get the rbart
rbart_mean = pd.read_csv("rbart_lung_mean.csv")
rbart_mean 

In [0]:
# plt.step(np.unique(t_test), bart_1_rr, label="bart")
# plt.step(np.unique(t_test2), bart_2_rr, label="bart2")
# plt.step(np.unique(t_test2), b2_surv_mean3, label="bart3")
plt.step(np.unique(t_test2), b3_surv, label="bart3")

plt.step(rbart_mean.times, rbart_mean.surv, label="rbart")
plt.step(np.unique(data_y["Survival_in_days"]), rsf_surv_mean, label = "rsf")
# plt.step(np.unique(data_y["Survival_in_days"]), cph_surv_mean, label="cph")
plt.step(kpm[0], kpm[1], label="kpm")
plt.legend(loc="best")


There currently seems to be a small bias in the pymc-BART implementations.