In [0]:
import sksurv as sks
import sksurv.preprocessing
import sksurv.metrics
import sksurv.datasets
import sksurv.linear_model
import sksurv.ensemble

from pathlib import Path
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import numpy as np
import sklearn as skl
import scipy.stats as sp

import pymc as pm
import pymc_bart as pmb
import pandas as pd

import importlib
import mlflow as ml
import simsurv_func as ssf
import subprocess
import lifelines
import pytensor.tensor as tt
import subprocess

In [0]:
# Load the lung data and recode it in one pass:
#   - ph.karno: gaps are filled from pat.karno
#   - status:   shifted down by one
#   - time:     divided by 30 and rounded up
lung = pd.read_csv("lung.csv").assign(
    **{"ph.karno": lambda d: d["ph.karno"].fillna(d["pat.karno"])},
    status=lambda d: d["status"] - 1,
    time=lambda d: np.ceil(d["time"] / 30),
)

# Columns used for modelling: the outcome pair first, then the covariates.
model_columns = ["time", "status", "sex", "age", "ph.karno"]
train = lung[model_columns]

In [0]:
# Preview the modelling frame (columns: time, status, sex, age, ph.karno).
train

In [0]:
# File names passed to the Rscript call: the training frame is written to
# TRAIN_CSV here and the R results are later read back from RBART_CSV.
TRAIN_CSV, RBART_CSV = "lung_train.csv", "lung_result.csv"
train.to_csv(path_or_buf=TRAIN_CSV)

# Set up pymc-bart

In [0]:
# Build the sklearn-style inputs.
# Covariates: everything after the first two columns (time, status) of `train`.
covariate_columns = train.columns[2:]
x_sk = train[covariate_columns]
# Outcome: built by the project helper from the status and time columns.
y_sk = ssf.get_y_sklearn(lung["status"], lung["time"])

In [0]:
# Transform the data to long form for training (time, event indicator, design matrix).
b_tr_t, b_tr_delta, b_tr_x = ssf.surv_pre_train2(x_sk, y_sk)

# Test design: the covariates paired with the distinct training times.
train_times = np.unique(b_tr_t)
b_te_x = ssf.get_bart_test(x_sk, train_times)

# Probit-scale offset: the standard-normal quantile of the mean event indicator.
event_rate = np.mean(b_tr_delta)
off = sp.norm.ppf(event_rate)

In [0]:
# Probit BART model for the long-format event indicator.
# The design matrix is registered as data named "x" so it can be replaced
# with pm.set_data after sampling (done in the second `with` block below).
with pm.Model() as bart:    
    x_data = pm.MutableData("x", b_tr_x)
    # m=200 trees. One split rule is given per column of x; the second column
    # is the only one split one-hot.
    # NOTE(review): presumably the columns are (time, sex, age, ph.karno) --
    # confirm against ssf.surv_pre_train2.
    f = pmb.BART("f", 
            X=x_data, 
            Y=b_tr_delta,
            m=200, 
            split_rules = [pmb.ContinuousSplitRule(), 
                            pmb.OneHotSplitRule(),
                            pmb.ContinuousSplitRule(),
                            pmb.ContinuousSplitRule()])
    # `off` is the probit of the mean event indicator; it is added to the
    # BART output before mapping back to a probability.
    z = pm.Deterministic("z", f + off)
    mu = pm.Deterministic("mu", pm.math.invprobit(z))
    # The likelihood's shape follows the current number of rows of x.
    y_pred = pm.Bernoulli("y_pred", p=mu, observed=b_tr_delta, shape=x_data.shape[0])
    # Short run: 200 tuning + 200 kept draws per chain, fixed seed.
    bdata = pm.sample(random_seed=2, draws=200, tune = 200, cores=4)

# Swap in the test design and draw the posterior predictive for the outcome
# and the intermediate quantities ("f", "z", "mu").
with bart:
    # NOTE(review): no dims are declared on "x" in this cell, so the "obs"
    # coords passed here look unused -- confirm.
    pm.set_data({"x":pd.DataFrame(b_te_x)}, coords= {"obs":np.arange(0,b_te_x.shape[0],1)})
    pp = pm.sample_posterior_predictive(bdata, var_names = ["y_pred", "f", "z", "mu"])



In [0]:
# Convert the posterior predictive draws into survival values (project helper).
bart_sv_fx = ssf.get_sv_fx(pp, x_sk)

# Time grid: the distinct training times, with a leading time 0.
bart_sv_t = np.concatenate([np.array([0]), np.unique(b_tr_t)])

# Matching survival rows: each row gets a leading 1 (survival probability at time 0).
bart_sv_val = []
for sv in bart_sv_fx:
    bart_sv_val.append(np.concatenate([np.array([1]), sv]))


In [0]:
# Raw probabilities: average "mu" over chains and draws, then lay the result
# out as an (n, n_t) grid (rows of x_sk by distinct training times).
n = x_sk.shape[0]
n_t = np.unique(b_tr_t).shape[0]
mu_mean = pp.posterior_predictive["mu"].mean(("chain", "draw"))
bart_prob_val = mu_mean.values.reshape(n, n_t)

# Run RBART

In [0]:
# Run the R script, passing the two CSV paths as arguments.
# NOTE(review): presumably lung_run.r reads TRAIN_CSV and writes RBART_CSV --
# confirm in the script.
p1 = subprocess.Popen([
        "Rscript",
        "lung_run.r",
        TRAIN_CSV,
        RBART_CSV
        ])
rbart_exit_code = p1.wait()

# Fail loudly on a non-zero exit: otherwise the cells that read RBART_CSV
# would silently pick up a stale or missing result file.
if rbart_exit_code != 0:
    raise subprocess.CalledProcessError(rbart_exit_code, p1.args)

# Keep the exit code as the cell output, as before.
rbart_exit_code

In [0]:
# Load the result file named by RBART_CSV and display it.
rbart = pd.read_csv(RBART_CSV)
rbart

In [0]:
# Unpack the RBART results.
# Grid size: rows of `train` by distinct values of train["time"].
N = train.shape[0]
tshape = np.unique(train["time"]).shape[0]
grid_shape = (N, tshape)

# Survival and probability columns reshaped onto that grid.
rb_surv_val = rbart["surv"].to_numpy().reshape(grid_shape)
rb_prob_val = rbart["prob"].to_numpy().reshape(grid_shape)

# Covariates taken from the rows where t == 1.
# NOTE(review): presumably one such row per patient -- confirm against lung_run.r.
at_first_time = rbart["t"] == 1
rb_mat = rbart.loc[at_first_time, ["sex", "age", "ph.karno"]].to_numpy()

# Observed outcome pair from the training frame.
rb_time = train["time"].to_numpy()
rb_delta = train["status"].to_numpy()

# Metrics

In [0]:
from sksurv.metrics import (
    concordance_index_censored,
    concordance_index_ipcw,
    cumulative_dynamic_auc,
    integrated_brier_score,
)

## C-index

In [0]:
# Column indices at the 0.25, 0.5 and 0.75 positions of the time grid.
idx_quant = [np.array(np.round(tshape * q), dtype="int") for q in (.25, .5, .75)]

# One C-index per evaluated column, for each model.
rb_cindx = np.zeros(len(idx_quant))
pb_cindx = np.zeros(len(idx_quant))

# Both models are scored against the same observed outcome.
observed_event = y_sk["Status"]
observed_time = y_sk["Survival_in_days"]

for slot, col in enumerate(idx_quant):
    for scores, prob_grid in ((rb_cindx, rb_prob_val), (pb_cindx, bart_prob_val)):
        # concordance_index_censored returns a tuple; the index is its first item.
        scores[slot] = concordance_index_censored(
            event_indicator=observed_event,
            event_time=observed_time,
            estimate=prob_grid[:, col],
        )[0]

In [0]:
# Report the C-index of both models, one line per model.
cindex_report = (
    "CINDEX at .25, .5, .75 quantile times using probability values",
    f"RBART CINDEX: {rb_cindx}",
    f"PYMC-BART CINDEX: {pb_cindx}",
)
for report_line in cindex_report:
    print(report_line)

## TIME-AUC


In [0]:
# Column indices at the 0.1, 0.25, 0.5, 0.75 and 0.9 positions of the time grid.
idx_quant = [np.array(np.round(tshape * q), dtype="int") for q in (.1, .25, .5, .75, .9)]
rb_uniq_t = np.unique(rb_time)

# Both models are evaluated on the same data and at the same times;
# the same sample serves as both the train and the test set.
shared_auc_args = dict(
    survival_train=y_sk,
    survival_test=y_sk,
    times=rb_uniq_t[idx_quant],
)
r_cda = cumulative_dynamic_auc(estimate=rb_prob_val[:, idx_quant], **shared_auc_args)
p_cda = cumulative_dynamic_auc(estimate=bart_prob_val[:,idx_quant], **shared_auc_args)

In [0]:
# Report the per-time AUC values and their average for both models.
auc_report = (
    f"PYMC Cumul Dynamic AUC {p_cda[0]} \n AVE {p_cda[1]}",
    f"RBART Cumul Dynamic AUC {r_cda[0]} \n AVE {r_cda[1]}",
)
for report_line in auc_report:
    print(report_line)

In [0]:
# Time-dependent AUC of both models at the evaluated times.
auc_times = rb_uniq_t[idx_quant]
for auc_values, model_label in ((r_cda[0], "rbart"), (p_cda[0], "pymc")):
    plt.plot(auc_times, auc_values, marker="o", label=model_label)
plt.legend()

## BRIER SCORE

In [0]:
# Integrated Brier score for both models, with the first and last grid
# points left out of the evaluation.
# NOTE(review): presumably dropped to keep the times inside the range
# scikit-survival accepts -- confirm.
inner = slice(1, -1)
brier_times = rb_uniq_t[inner]

p_ibs = integrated_brier_score(
    survival_train=y_sk,
    survival_test=y_sk,
    estimate=bart_sv_fx[:, inner],
    times=brier_times,
)
r_ibs = integrated_brier_score(
    survival_train=y_sk,
    survival_test=y_sk,
    estimate=rb_surv_val[:, inner],
    times=brier_times,
)

brier_report = (
    f"PYMC BRIER SCORE: {p_ibs}",
    f"RBART BRIER SCORE: {r_ibs}",
)
for report_line in brier_report:
    print(report_line)

In [0]:
# Scratch inspection of the evaluation grid.
# Only the last expression is rendered by the notebook; the expression on the
# line before it is evaluated but its value is discarded.
# Earlier variants, left commented out:
# bart_sv_t[idx_quant]
# rb_prob_val[0].shape
# rb_time[idx_quant]
# np.unique(rb_time)[idx_quant]
rb_uniq_t[idx_quant]
rb_prob_val[:,idx_quant]

# Plot Posterior Predictive on sex (NOT A FPD)

In [0]:
# Prepend time 0 to the RBART grid, with survival value 1 for every row.
n_rows = rb_surv_val.shape[0]
ones_at_zero = np.ones((n_rows, 1), dtype=int)
rb_surv_adj = np.hstack([ones_at_zero, rb_surv_val])
rb_t_adj = np.concatenate(([0], rb_uniq_t))

# Split the rows on the first covariate column (sex): value 1 vs value 2.
sex_column = rb_mat[:, 0]
rb_ml = rb_surv_adj[sex_column == 1, :]
rb_fl = rb_surv_adj[sex_column == 2, :]


In [0]:
# Same split for the PyMC-BART survival values; the list of rows is
# converted to a 2-D array once and indexed twice.
bart_sv_grid = np.array(bart_sv_val)
sex_column = rb_mat[:, 0]
p_ml = bart_sv_grid[sex_column == 1, :]
p_fl = bart_sv_grid[sex_column == 2, :]

In [0]:
# Per-row survival curves for the sex == 2 group: RBART in red, PyMC-BART in purple.
rb_lines = plt.plot(rb_t_adj.T, rb_fl.T, color="red", alpha=0.2)
pb_lines = plt.plot(bart_sv_t.T, p_fl.T, color = "purple", alpha=0.1)

# The original called plt.legend() with no labelled artists, which yields an
# empty legend and a warning. Label only the first line of each model so the
# legend has exactly one entry per model (the [:1] slice is a no-op when a
# group is empty).
for line in rb_lines[:1]:
    line.set_label("rbart (sex == 2)")
for line in pb_lines[:1]:
    line.set_label("pymc (sex == 2)")

legend = plt.legend()
# The curves are drawn nearly transparent; make the legend swatches readable.
for handle in legend.get_lines():
    handle.set_alpha(1)

In [0]:
# Per-row survival curves for the sex == 1 group: RBART in light green,
# PyMC-BART in light blue.
rb_lines = plt.plot(rb_t_adj.T, rb_ml.T, color="lightgreen", alpha=0.2)
pb_lines = plt.plot(bart_sv_t.T, p_ml.T, color = "lightblue", alpha=0.2)

# The original called plt.legend() with no labelled artists, which yields an
# empty legend and a warning. Label only the first line of each model so the
# legend has exactly one entry per model (the [:1] slice is a no-op when a
# group is empty).
for line in rb_lines[:1]:
    line.set_label("rbart (sex == 1)")
for line in pb_lines[:1]:
    line.set_label("pymc (sex == 1)")

legend = plt.legend()
# The curves are drawn nearly transparent; make the legend swatches readable.
for handle in legend.get_lines():
    handle.set_alpha(1)

## Patient-level difference

In [0]:
# Row-by-row gap between the two survival estimates (RBART minus PyMC-BART),
# summarised at every point of the time grid `rb_t_adj`.
#
# The summaries are computed column-wise in one pass instead of being written
# into buffers created with np.zeros_like(rb_t_adj): those buffers take the
# dtype of the time grid, so an integer grid would silently truncate every
# difference to 0. The list-to-array conversion is also done once, not per
# time point.
n_times = rb_t_adj.shape[0]
# Only the first n_times columns are used, matching the original loop.
surv_diff = rb_surv_adj[:, :n_times] - np.asarray(bart_sv_val)[:, :n_times]

# Mean gap per time point.
diff_mean = surv_diff.mean(axis=0)

# 5%, 25%, 50%, 75% and 95% quantiles of the gap per time point;
# np.quantile returns one row per requested quantile.
diff_05, diff_25, diff_50, diff_75, diff_95 = np.quantile(
    surv_diff, [0.05, 0.25, 0.5, 0.75, 0.95], axis=0
)

    # print(diff_pc)
    # print(i)


In [0]:
# Mean gap (with markers) followed by the quantile curves, on the shared time grid.
plt.plot(rb_t_adj, diff_mean, label = "mean diff", marker="1")

quantile_curves = (
    (diff_05, ".05 diff"),
    (diff_25, ".25 diff"),
    (diff_50, ".50 diff"),
    (diff_75, ".75 diff"),
    (diff_95, ".95 diff"),
)
for curve, curve_label in quantile_curves:
    plt.plot(rb_t_adj, curve, label=curve_label)

plt.legend()

The mean absolute difference hovers around 0.01.
The increasing trend over time indicates that PyMC-BART approaches 0% survival slightly more quickly than RBART. However, as the mean and the 50% interval indicate, the difference is fairly insignificant across most of the matched patient values.