In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import os, sys

# Make the parent of the notebook's working directory importable (presumably
# where the local `btdsa` package lives). The membership guard keeps the cell
# idempotent: re-running it no longer appends duplicate sys.path entries.
PROJECT_ROOT = os.path.abspath('..')
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [3]:
from btdsa.config import Config, BASELINE_MODEL_FAMILY, TDSA_MODEL_LIST
from btdsa.eval_utils import EvalSurv
from btdsa.train_utils import init_trainer
from btdsa.utils import create_logger

In [4]:
# Directory for experiment logs, relative to the notebook's working directory.
LOGS_DIR = './logs'
logger = create_logger(logs_dir=LOGS_DIR)  # attached to every trainer created in run_experiment

In [5]:
def run_experiment(model_name, time_range='full', random_state=1234):
    """Train and evaluate one model on every dataset in the config, then report.

    Parameters
    ----------
    model_name : str
        Model identifier written to ``cfg.model_name`` (e.g. an entry of ``MODEL_LIST``).
    time_range : str, default 'full'
        Written to ``cfg.time_range``; this notebook uses ``'full'`` and ``'truncated'``.
    random_state : int, default 1234
        Seed written to ``cfg.random_state``. Exposed as a parameter so the same
        experiment can be repeated over several seeds instead of a hard-coded one.

    Returns
    -------
    EvalSurv
        The evaluator after ``report()`` has been called, so results can be
        inspected further instead of only being logged.
    """
    ev = EvalSurv()  # custom evaluation interface

    # NOTE(review): ``Config`` is aliased, not instantiated, so these attribute
    # writes land on the shared imported object and persist after this call.
    # All three are reassigned on every call, so consecutive runs do not leak
    # settings into each other, but confirm whether ``Config()`` was intended.
    cfg = Config
    cfg.model_name = model_name
    cfg.time_range = time_range
    cfg.random_state = random_state
    cfg.setup()  # presumably derives list_of_datasets etc. from the fields above — confirm

    for dataset in cfg.list_of_datasets:
        trainer = init_trainer(cfg)
        trainer.logger = logger  # module-level logger created in an earlier cell

        surv, _ = trainer.fit_and_predict(dataset)  # the fitted model itself is not needed here
        ev.trainer = trainer
        x_test, y_test = trainer.test  # presumably the held-out split used for evaluation — confirm
        ev.evaluate(surv, x_test, y_test)

    ev.report()  # log and report results as tables
    return ev
    
def run_all_exp(model_list, time_range='full'):
    """Run ``run_experiment`` once for each model name in ``model_list``.

    Models run sequentially in list order, and each call reports its own results.
    ``time_range`` is forwarded unchanged; its default matches ``run_experiment``'s
    so both entry points agree when it is omitted.
    """
    for model_name in model_list:
        run_experiment(model_name, time_range=time_range)

## Models

In [6]:
# True: quick run on a reduced set of four models.
# False: run the full BASELINE_MODEL_FAMILY + TDSA_MODEL_LIST (slower).
simpler_ver = True

# The full-list concatenation is only evaluated when simpler_ver is False.
MODEL_LIST = (
    ['CoxPH', 'DeepHitSingle', 'DRSA', 'BTDSA']
    if simpler_ver
    else BASELINE_MODEL_FAMILY + TDSA_MODEL_LIST
)

## time_range = `full`

In [7]:
run_all_exp(MODEL_LIST, time_range='full')

+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|                                                                                       CoxPH                                                                                       |
+-----------------------------------------------------------+-----------------------------------------------------------+-----------------------------------------------------------+
|                            gbsg                           |                          metabric                         |                          support                          |
+-----------------------------------------------------------+-----------------------------------------------------------+-----------------------------------------------------------+
|                0.641857 (0.638766,0.644949)               |                0.676998 (0.6

## time_range = `truncated`

In [8]:
run_all_exp(MODEL_LIST, time_range='truncated')

+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|                                                                                       CoxPH                                                                                       |
+-----------------------------------------------------------+-----------------------------------------------------------+-----------------------------------------------------------+
|                            gbsg                           |                          metabric                         |                          support                          |
+-----------------------------------------------------------+-----------------------------------------------------------+-----------------------------------------------------------+
|                0.641857 (0.638766,0.644949)               |                0.676998 (0.6