In [4]:
from bart_survival import simulation as sim
from bart_survival import surv_bart as sb

In [5]:
import importlib
import numpy as np
rng = np.random.default_rng(1)

In [70]:
x_mat = sim.get_x_matrix(
    N=100,
    x_vars=1,
    VAR_CLASS=[2],
    VAR_PROB=[.5],
    rng = rng
)
# x_mat

In [75]:
# x_mat = np.array([1,2]).reshape(-1,1)
event_dict, sv_true, sv_scale_true = sim.simulate_survival(
    x_mat = x_mat,
    scale_f = "np.exp(2 + .4*x_mat[:,0])",
    shape_f = "1",
    eos = 20,
    cens_scale=None,
    time_scale=5,
    true_only=False,
    rng = rng
)

# HR check
print(np.exp(-.4))
print(sv_true["hz_true"][0,0]/sv_true["hz_true"][1,0])

mean shape 1.0
mean scale 8.951727820066358
mean time draws 8.183425597928592
0.6703200460356392
1.0


In [163]:
print(np.exp(-.4))
print(sv_true["hz_true"][-1,0]/sv_true["hz_true"][0,0])


0.6703200460356392
0.6703200460356394


In [86]:
# test time transform works 
t_scale = sb.get_time_transform(event_dict["t_event"], time_scale = 5)
np.alltrue(t_scale == event_dict["t_event_scale"])

True

In [111]:
importlib.reload(sb)
t_scale = sb.get_time_transform(event_dict["t_event"], time_scale = 5)
y_sk = sb.get_y_sklearn(event_dict["status"], t_scale)
trn = sb.get_surv_pre_train(y_sk, x_mat, weight=None)
post_test = sb.get_posterior_test(y_sk=y_sk, x_test = x_mat)


In [123]:
# # intitialize models
SPLIT_RULES =  [
    "pmb.ContinuousSplitRule()", # time
    "pmb.OneHotSplitRule", # ccsr_ind_p2
]
model_dict = {"trees": 50,
    "split_rules": SPLIT_RULES
}
sampler_dict = {
            "draws": 100,
            "tune": 100,
            "cores": 8,
            "chains": 8,
            "compute_convergence_checks": False
        }

In [254]:
# event_dict["t_event_scale"]
importlib.reload(sb)
BSM = sb.BartSurvModel(model_config=model_dict, sampler_config=sampler_dict)

BSM.fit(
    y =  trn["y"],
    X = trn["x"],
    weights=trn["w"],
    coords = trn["coord"],
    random_seed=5
)


    
    


Only 100 samples in chain.
Multiprocess sampling (8 chains in 8 jobs)
PGBART: [f]


 |████████████████████████████████| 100.00% [1600/1600 00:14<00:00 Sampling 8 chains, 0 divergences]

Sampling 8 chains for 100 tune and 100 draw iterations (800 + 800 draws total) took 16 seconds.


In [258]:
# importlib.reload(sb)
# check sample_posterior_predictive and bart_predict
# rng = None will maintain a singular state
post1 = BSM.sample_posterior_predictive(X_pred=post_test["post_x"], coords=post_test["coords"])
post2 = BSM.bart_predict(X_pred=post_test["post_x"], coords = post_test["coords"], rng=rng)


In [259]:
importlib.reload(sb)
sv_prob = sb.get_sv_prob(post1)
msk_1 = x_mat[:,0] == 1
HR = (sv_prob["prob"][:,msk_1,:].mean(1) / sv_prob["prob"][:,~msk_1,:].mean(1)).mean(0)
HR_QT = np.quantile((sv_prob["prob"][:,msk_1,:].mean(1) / sv_prob["prob"][:,~msk_1,:].mean(1)),[0.025,0.975])
print(HR)
print(HR_QT)
sv_prob = sb.get_sv_prob(post2)
msk_1 = x_mat[:,0] == 1
HR = (sv_prob["prob"][:,msk_1,:].mean(1) / sv_prob["prob"][:,~msk_1,:].mean(1)).mean(0)
HR_QT = np.quantile((sv_prob["prob"][:,msk_1,:].mean(1) / sv_prob["prob"][:,~msk_1,:].mean(1)),[0.025,0.975])
print(HR)
print(HR_QT)

[0.62760576 0.58929952 0.6071536  0.59607398]
[0.46375165 0.77200443]
[0.6198279  0.57943217 0.61163861 0.59489684]
[0.4476553  0.77291287]


In [274]:
# check that save and load works
BSM.save("idata.pkl", "tree.pkl")
importlib.reload(sb)
BSM2 = sb.BartSurvModel.load("idata.pkl", "tree.pkl")

print(BSM2.is_fitted_)

post3 = BSM2.bart_predict(X_pred=post_test["post_x"], coords = post_test["coords"], rng=rng)

sv_prob = sb.get_sv_prob(post3)
msk_1 = x_mat[:,0] == 1
HR = (sv_prob["prob"][:,msk_1,:].mean(1) / sv_prob["prob"][:,~msk_1,:].mean(1)).mean(0)
HR_QT = np.quantile((sv_prob["prob"][:,msk_1,:].mean(1) / sv_prob["prob"][:,~msk_1,:].mean(1)),[0.025,0.975])
print(HR)
print(HR_QT)

False
[0.61833547 0.58106054 0.60661005 0.58899409]
[0.44757265 0.74400483]


In [333]:
sv_prob = sb.get_sv_prob(post1)
assert(10>np.abs((sv_prob["sv"].mean(0)-sv_scale_true["sv_true"]).mean()))
np.abs((sv_prob["sv"].mean(0)-sv_scale_true["sv_true"]).mean())

0.02528974210233843

In [329]:
# check pdp
pdp1 = sb.get_pdp(x_mat, var_col = [0], values = [[0,1]], sample_n = None)
# pdp1[0].shape
pdp_tst = sb.get_posterior_test(y_sk, pdp1[0])
pdp_post = BSM.sample_posterior_predictive(pdp_tst["post_x"], pdp_tst["coords"])

sv_prob = sb.get_sv_prob(pdp_post)
msk_1 = pdp1[1]["coord"] == 1
HR = (sv_prob["prob"][:,msk_1,:].mean(1) / sv_prob["prob"][:,~msk_1,:].mean(1)).mean(0)
HR_QT = np.quantile((sv_prob["prob"][:,msk_1,:].mean(1) / sv_prob["prob"][:,~msk_1,:].mean(1)),[0.025,0.975])
print(HR)
print(HR_QT)

█

Sampling: [f]


[0.61642049 0.57997963 0.6140865  0.59736831]00.00% [800/800 00:04<00:00]
[0.4594366 0.7593368]


In [335]:
(np.exp(-.4) - HR.mean())

0.06835631007238019

In [316]:
type(BSM2)

bart_survival.surv_bart.BartSurvModel