In [None]:
from synthcity.utils.datasets.time_series.pbc import PBCDataloader
from synthcity.plugins.core.dataloader import (
     TimeSeriesDataLoader,
     TimeSeriesSurvivalDataLoader,
 )
import numpy as np
import pandas as pd
import tabulate

# Load the PBC time-series survival dataset as numpy arrays.
(static, temporal, temporal_horizons, outcome) = PBCDataloader(as_numpy=True).load()
# `outcome` unpacks into two parts; presumably times (T) and event indicators (E) — confirm against PBCDataloader.
T, E = outcome

# Evaluation horizons: the 25th, 50th and 75th percentiles of T, as plain Python floats.
horizons = [0.25, 0.5, 0.75]
time_horizons = [float(t) for t in np.quantile(T, horizons)]

In [2]:
from synthcity.plugins.core.models.time_series_survival.benchmarks import (
     evaluate_ts_survival_model,
)


def eval_model(mod, **kwargs):
    """Instantiate ``mod(**kwargs)`` and cross-validate it on the loaded dataset.

    Uses the notebook-level ``static``, ``temporal``, ``temporal_horizons``,
    ``T``, ``E`` and ``time_horizons`` and always runs 3 folds.

    Args:
        mod: model class (or factory) to instantiate.
        **kwargs: forwarded verbatim to ``mod``.

    Returns:
        Whatever ``evaluate_ts_survival_model`` returns; the cells below index
        it with ``["str"]``.
    """
    return evaluate_ts_survival_model(
        mod(**kwargs),
        static,
        temporal,
        temporal_horizons,
        T,
        E,
        time_horizons,
        n_folds=3,
    )


In [3]:
from synthcity.plugins.core.models.time_series_survival import (
     DynamicDeephitTimeSeriesSurvival, rnn_modes, output_modes
 )

headers = ["Model", "C-Index", "Brier score"]

# One row per (output head, backbone) pair. The frame is built once at the end:
# growing it with pd.concat inside the loop is quadratic, and concatenating onto
# an empty frame raises a pandas FutureWarning.
rows = []
for output_mode in output_modes:
    for base in rnn_modes:
        score = eval_model(
            DynamicDeephitTimeSeriesSurvival,
            rnn_type=base,
            output_type=output_mode,
        )["str"]
        rows.append(
            [f"DynDeephit[{base} -> {output_mode}]", score["c_index"], score["brier_score"]]
        )
results = pd.DataFrame(rows, columns=headers)

# Without `headers`, tabulate renders the table with no column names.
tabulate.tabulate(results, headers="keys", tablefmt="html")

0,1,2,3
0,DynDeephit[GRU -> MLP],0.8213 +/- 0.0154,0.116 +/- 0.0109
1,DynDeephit[LSTM -> MLP],0.8115 +/- 0.0101,0.1237 +/- 0.0123
2,DynDeephit[RNN -> MLP],0.8356 +/- 0.012,0.1103 +/- 0.003
3,DynDeephit[Transformer -> MLP],0.826 +/- 0.0082,0.1141 +/- 0.0053
4,DynDeephit[Wavelet -> MLP],0.805 +/- 0.0307,0.1255 +/- 0.0135
5,DynDeephit[GRU -> MiniRocket],0.796 +/- 0.0163,0.1327 +/- 0.0114
6,DynDeephit[LSTM -> MiniRocket],0.7883 +/- 0.0083,0.1406 +/- 0.0059
7,DynDeephit[RNN -> MiniRocket],0.8257 +/- 0.0187,0.125 +/- 0.0104
8,DynDeephit[Transformer -> MiniRocket],0.821 +/- 0.0133,0.1229 +/- 0.0087
9,DynDeephit[Wavelet -> MiniRocket],0.7811 +/- 0.0245,0.1365 +/- 0.0124


In [4]:
from synthcity.plugins.core.models.time_series_survival import (
     DynamicDeephitTimeSeriesSurvival, rnn_modes, output_modes
 )

headers = ["Model", "C-Index", "Brier score"]

# Sweep the wavelet basis of the Wavelet backbone with a fixed MLP head.
base = "Wavelet"
output_mode = "MLP"

# Build the frame once (see the backbone sweep): avoids quadratic pd.concat
# growth and pandas' FutureWarning on concatenating an empty frame.
rows = []
for wavelet_type in ["haar", "db4", "db6", "sym2", "sym4", "sym14"]:
    score = eval_model(
        DynamicDeephitTimeSeriesSurvival,
        rnn_type=base,
        output_type=output_mode,
        wavelet_type=wavelet_type,
        wavelet_mode="symmetric",
    )["str"]
    rows.append(
        [f"DynDeephit[{base} -> {wavelet_type}]", score["c_index"], score["brier_score"]]
    )
results = pd.DataFrame(rows, columns=headers)

# Without `headers`, tabulate renders the table with no column names.
tabulate.tabulate(results, headers="keys", tablefmt="html")

0,1,2,3
0,DynDeephit[Wavelet -> haar],0.805 +/- 0.0307,0.1255 +/- 0.0135
1,DynDeephit[Wavelet -> db4],0.7968 +/- 0.0202,0.1225 +/- 0.0055
2,DynDeephit[Wavelet -> db6],0.8108 +/- 0.0017,0.1212 +/- 0.007
3,DynDeephit[Wavelet -> sym2],0.8278 +/- 0.0032,0.1051 +/- 0.0074
4,DynDeephit[Wavelet -> sym4],0.7888 +/- 0.0292,0.1299 +/- 0.0088
5,DynDeephit[Wavelet -> sym14],0.8062 +/- 0.0382,0.1241 +/- 0.0182


In [5]:
from synthcity.plugins.core.models.time_series_survival import (
     CoxTimeSeriesSurvival,
 )

headers = ["Model", "C-Index", "Brier score"]

# One row per (embedding head, embedding backbone) pair; the frame is built once
# to avoid quadratic pd.concat growth and pandas' empty-frame concat FutureWarning.
rows = []
for output_mode in output_modes:
    for base in rnn_modes:
        score = eval_model(
            CoxTimeSeriesSurvival,
            emb_rnn_type=base,
            emb_output_type=output_mode,
        )["str"]
        rows.append(
            [f"CoxPH[{base} -> {output_mode}]", score["c_index"], score["brier_score"]]
        )
results = pd.DataFrame(rows, columns=headers)

# Without `headers`, tabulate renders the table with no column names.
tabulate.tabulate(results, headers="keys", tablefmt="html")

0,1,2,3
0,CoxPH[GRU -> MLP],0.8505 +/- 0.0074,0.1008 +/- 0.0109
1,CoxPH[LSTM -> MLP],0.8335 +/- 0.0023,0.1085 +/- 0.0076
2,CoxPH[RNN -> MLP],0.846 +/- 0.0099,0.1059 +/- 0.0114
3,CoxPH[Transformer -> MLP],0.8222 +/- 0.0401,0.1201 +/- 0.0205
4,CoxPH[Wavelet -> MLP],0.8146 +/- 0.0153,0.1298 +/- 0.02
5,CoxPH[GRU -> MiniRocket],0.8294 +/- 0.0137,0.1128 +/- 0.0222
6,CoxPH[LSTM -> MiniRocket],0.8072 +/- 0.022,0.1256 +/- 0.0217
7,CoxPH[RNN -> MiniRocket],0.8305 +/- 0.014,0.1264 +/- 0.0055
8,CoxPH[Transformer -> MiniRocket],0.8102 +/- 0.0351,0.1269 +/- 0.0238
9,CoxPH[Wavelet -> MiniRocket],0.7869 +/- 0.0183,0.1508 +/- 0.0111


In [6]:
from synthcity.plugins.core.models.time_series_survival import (
     CoxTimeSeriesSurvival,
 )

headers = ["Model", "C-Index", "Brier score"]

# Sweep the wavelet basis of the Wavelet embedding backbone with a fixed MLP head.
base = "Wavelet"
output_mode = "MLP"

# Build the frame once: avoids quadratic pd.concat growth and pandas'
# FutureWarning on concatenating an empty frame.
rows = []
for wavelet_type in ["haar", "db4", "db6", "sym2", "sym4", "sym14"]:
    score = eval_model(
        CoxTimeSeriesSurvival,
        emb_rnn_type=base,
        emb_output_type=output_mode,
        emb_wavelet_type=wavelet_type,
        emb_wavelet_mode="symmetric",
    )["str"]
    rows.append(
        [f"CoxPH[{base} -> {wavelet_type}]", score["c_index"], score["brier_score"]]
    )
results = pd.DataFrame(rows, columns=headers)

# Without `headers`, tabulate renders the table with no column names.
tabulate.tabulate(results, headers="keys", tablefmt="html")

0,1,2,3
0,CoxPH[Wavelet -> haar],0.8146 +/- 0.0153,0.1298 +/- 0.02
1,CoxPH[Wavelet -> db4],0.8185 +/- 0.0091,0.1309 +/- 0.027
2,CoxPH[Wavelet -> db6],0.8224 +/- 0.0061,0.1316 +/- 0.0237
3,CoxPH[Wavelet -> sym2],0.8279 +/- 0.0236,0.1255 +/- 0.0236
4,CoxPH[Wavelet -> sym4],0.7986 +/- 0.0324,0.1348 +/- 0.0075
5,CoxPH[Wavelet -> sym14],0.8264 +/- 0.0197,0.1106 +/- 0.0068


In [7]:
from synthcity.plugins.core.models.time_series_survival import (
     XGBTimeSeriesSurvival,
 )

headers = ["Model", "C-Index", "Brier score"]

# One row per (embedding head, embedding backbone) pair; the frame is built once
# to avoid quadratic pd.concat growth and pandas' empty-frame concat FutureWarning.
rows = []
for output_mode in output_modes:
    for base in rnn_modes:
        score = eval_model(
            XGBTimeSeriesSurvival,
            emb_rnn_type=base,
            emb_output_type=output_mode,
        )["str"]
        rows.append(
            [f"XGB[{base} -> {output_mode}]", score["c_index"], score["brier_score"]]
        )
results = pd.DataFrame(rows, columns=headers)

# Without `headers`, tabulate renders the table with no column names.
tabulate.tabulate(results, headers="keys", tablefmt="html")

0,1,2,3
0,XGB[GRU -> MLP],0.791 +/- 0.0089,0.1425 +/- 0.0111
1,XGB[LSTM -> MLP],0.7804 +/- 0.0165,0.1447 +/- 0.0148
2,XGB[RNN -> MLP],0.7917 +/- 0.0194,0.1384 +/- 0.0194
3,XGB[Transformer -> MLP],0.802 +/- 0.0233,0.1534 +/- 0.0324
4,XGB[Wavelet -> MLP],0.7812 +/- 0.024,0.1535 +/- 0.0187
5,XGB[GRU -> MiniRocket],0.7582 +/- 0.0432,0.1668 +/- 0.0363
6,XGB[LSTM -> MiniRocket],0.7421 +/- 0.0282,0.1749 +/- 0.029
7,XGB[RNN -> MiniRocket],0.7856 +/- 0.0166,0.1666 +/- 0.0132
8,XGB[Transformer -> MiniRocket],0.7936 +/- 0.0145,0.1554 +/- 0.0257
9,XGB[Wavelet -> MiniRocket],0.7575 +/- 0.0117,0.1729 +/- 0.0159


In [8]:
from synthcity.plugins.core.models.time_series_survival import (
     XGBTimeSeriesSurvival,
 )

headers = ["Model", "C-Index", "Brier score"]

# Sweep the wavelet basis of the Wavelet embedding backbone with a fixed MLP head.
base = "Wavelet"
output_mode = "MLP"

# Build the frame once: avoids quadratic pd.concat growth and pandas'
# FutureWarning on concatenating an empty frame.
rows = []
for wavelet_type in ["haar", "db4", "db6", "sym2", "sym4", "sym14"]:
    score = eval_model(
        XGBTimeSeriesSurvival,
        emb_rnn_type=base,
        emb_output_type=output_mode,
        emb_wavelet_type=wavelet_type,
        emb_wavelet_mode="symmetric",
    )["str"]
    rows.append(
        [f"XGB[{base} -> {wavelet_type}]", score["c_index"], score["brier_score"]]
    )
results = pd.DataFrame(rows, columns=headers)

# Without `headers`, tabulate renders the table with no column names.
tabulate.tabulate(results, headers="keys", tablefmt="html")

0,1,2,3
0,XGB[Wavelet -> haar],0.7812 +/- 0.024,0.1535 +/- 0.0187
1,XGB[Wavelet -> db4],0.7852 +/- 0.011,0.1616 +/- 0.0344
2,XGB[Wavelet -> db6],0.7866 +/- 0.0203,0.1761 +/- 0.0498
3,XGB[Wavelet -> sym2],0.8063 +/- 0.0121,0.1359 +/- 0.0143
4,XGB[Wavelet -> sym4],0.7519 +/- 0.0285,0.1705 +/- 0.0238
5,XGB[Wavelet -> sym14],0.766 +/- 0.0256,0.1542 +/- 0.0109
