In [1]:
from synthcity.utils.datasets.time_series.pbc import PBCDataloader
from synthcity.plugins.core.dataloader import (
     TimeSeriesDataLoader,
     TimeSeriesSurvivalDataLoader,
 )
import numpy as np
import pandas as pd
import tabulate

# Load the PBC benchmark with numpy-based components (as_numpy=True).
# `temporal_horizons` is passed alongside `temporal` to the evaluator below;
# presumably it holds the observation times of each temporal row -- TODO confirm.
static, temporal, temporal_horizons, outcome = PBCDataloader(
    as_numpy=True
).load()

# `outcome` unpacks into exactly two parts; by naming convention T is the
# time-to-event/censoring and E the event indicator (verify in PBCDataloader).
T, E = outcome

# Evaluation horizons: the 25th, 50th and 75th percentiles of T, as plain floats.
horizons = [quartile / 4 for quartile in (1, 2, 3)]
time_horizons = [float(cut) for cut in np.quantile(T, horizons)]



In [2]:
from synthcity.plugins.core.models.time_series_survival.benchmarks import (
     evaluate_ts_survival_model,
)


def eval_model(mod, *, n_folds: int = 3, **kwargs):
    """Cross-validate one time-series survival model on the loaded dataset.

    Instantiates ``mod(**kwargs)`` and scores it with
    ``evaluate_ts_survival_model`` using the notebook-level ``static``,
    ``temporal``, ``temporal_horizons``, ``T``, ``E`` and ``time_horizons``
    variables created in the data-loading cell.

    Args:
        mod: Model class (or factory) to instantiate.
        n_folds: Number of cross-validation folds. Keyword-only so it can
            never be confused with a model constructor argument; defaults
            to 3, the value previously hard-coded here.
        **kwargs: Constructor arguments forwarded unchanged to ``mod``.

    Returns:
        The result of ``evaluate_ts_survival_model``. The benchmark cells
        read its ``"str"`` entry, whose ``"c_index"`` / ``"brier_score"``
        values are formatted strings such as ``"0.7712 +/- 0.0445"``.
    """
    model = mod(**kwargs)
    # Dataset globals are read here (not passed in) so existing call sites
    # keep working unchanged.
    return evaluate_ts_survival_model(
        model,
        static,
        temporal,
        temporal_horizons,
        T,
        E,
        time_horizons,
        n_folds=n_folds,
    )


In [3]:
from synthcity.plugins.core.models.time_series_survival import (
     DynamicDeephitTimeSeriesSurvival, rnn_modes, output_modes
 )

# Benchmark Dynamic-DeepHit for every combination of `output_modes` x `rnn_modes`.
# Rows are collected in a plain list and turned into a DataFrame once:
# growing a DataFrame with pd.concat inside the loop is quadratic, and
# concatenating onto an empty frame emits a FutureWarning in pandas >= 2.1.
headers = ["Model", "C-Index", "Brier score"]
rows = []

for output_mode in output_modes:
    for base in rnn_modes:
        # "str" holds formatted metric strings such as "0.7712 +/- 0.0445".
        score = eval_model(DynamicDeephitTimeSeriesSurvival, rnn_type = base, output_type = output_mode)["str"]
        rows.append([f"DynDeephit[{base} -> {output_mode}]", score["c_index"], score["brier_score"]])

results = pd.DataFrame(rows, columns = headers)

# headers="keys" keeps the column names in the rendered HTML table.
tabulate.tabulate(results, headers = "keys", tablefmt = 'html')

0,1,2,3
0,DynDeephit[GRU -> MLP],0.7712 +/- 0.0445,0.1632 +/- 0.0195
1,DynDeephit[LSTM -> MLP],0.756 +/- 0.036,0.1686 +/- 0.0274
2,DynDeephit[RNN -> MLP],0.7591 +/- 0.0298,0.1642 +/- 0.0191
3,DynDeephit[Transformer -> MLP],0.7996 +/- 0.0059,0.1248 +/- 0.0063
4,DynDeephit[Wavelet -> MLP],0.6973 +/- 0.052,0.1824 +/- 0.014
5,DynDeephit[GRU -> MiniRocket],0.7388 +/- 0.0253,0.1753 +/- 0.0088
6,DynDeephit[LSTM -> MiniRocket],0.7323 +/- 0.0379,0.18 +/- 0.0134
7,DynDeephit[RNN -> MiniRocket],0.7431 +/- 0.0155,0.1611 +/- 0.012
8,DynDeephit[Transformer -> MiniRocket],0.7532 +/- 0.0504,0.1497 +/- 0.0207
9,DynDeephit[Wavelet -> MiniRocket],0.5661 +/- 0.1138,0.2313 +/- 0.0101


In [4]:
from synthcity.plugins.core.models.time_series_survival import (
     CoxTimeSeriesSurvival,
 )

# Benchmark CoxPH on learned embeddings for every combination of
# `output_modes` x `rnn_modes` (names imported in the DynamicDeephit cell).
# Rows are collected in a plain list and turned into a DataFrame once:
# growing a DataFrame with pd.concat inside the loop is quadratic, and
# concatenating onto an empty frame emits a FutureWarning in pandas >= 2.1.
headers = ["Model", "C-Index", "Brier score"]
rows = []

for output_mode in output_modes:
    for base in rnn_modes:
        # "str" holds formatted metric strings such as "0.8026 +/- 0.0025".
        score = eval_model(CoxTimeSeriesSurvival, emb_rnn_type = base, emb_output_type = output_mode)["str"]
        rows.append([f"CoxPH[{base} -> {output_mode}]", score["c_index"], score["brier_score"]])

results = pd.DataFrame(rows, columns = headers)

# headers="keys" keeps the column names in the rendered HTML table.
tabulate.tabulate(results, headers = "keys", tablefmt = 'html')

0,1,2,3
0,CoxPH[GRU -> MLP],0.8026 +/- 0.0025,0.1383 +/- 0.0094
1,CoxPH[LSTM -> MLP],0.7985 +/- 0.0113,0.1381 +/- 0.0028
2,CoxPH[RNN -> MLP],0.7745 +/- 0.0136,0.1407 +/- 0.0202
3,CoxPH[Transformer -> MLP],0.8026 +/- 0.004,0.1349 +/- 0.0116
4,CoxPH[Wavelet -> MLP],0.7428 +/- 0.0432,0.1531 +/- 0.0283
5,CoxPH[GRU -> MiniRocket],0.782 +/- 0.0156,0.1479 +/- 0.0211
6,CoxPH[LSTM -> MiniRocket],0.7721 +/- 0.0238,0.1523 +/- 0.0254
7,CoxPH[RNN -> MiniRocket],0.7787 +/- 0.0131,0.1431 +/- 0.0165
8,CoxPH[Transformer -> MiniRocket],0.7119 +/- 0.0421,0.1812 +/- 0.0048
9,CoxPH[Wavelet -> MiniRocket],0.7291 +/- 0.0337,0.1523 +/- 0.0051


In [5]:
from synthcity.plugins.core.models.time_series_survival import (
     XGBTimeSeriesSurvival,
 )

# Benchmark XGB on learned embeddings for every combination of
# `output_modes` x `rnn_modes` (names imported in the DynamicDeephit cell).
# Rows are collected in a plain list and turned into a DataFrame once:
# growing a DataFrame with pd.concat inside the loop is quadratic, and
# concatenating onto an empty frame emits a FutureWarning in pandas >= 2.1.
headers = ["Model", "C-Index", "Brier score"]
rows = []

for output_mode in output_modes:
    for base in rnn_modes:
        # "str" holds formatted metric strings such as "0.7493 +/- 0.0099".
        score = eval_model(XGBTimeSeriesSurvival, emb_rnn_type = base, emb_output_type = output_mode)["str"]
        rows.append([f"XGB[{base} -> {output_mode}]", score["c_index"], score["brier_score"]])

results = pd.DataFrame(rows, columns = headers)

# headers="keys" keeps the column names in the rendered HTML table.
tabulate.tabulate(results, headers = "keys", tablefmt = 'html')

0,1,2,3
0,XGB[GRU -> MLP],0.7493 +/- 0.0099,0.1866 +/- 0.0081
1,XGB[LSTM -> MLP],0.7501 +/- 0.0162,0.1867 +/- 0.0171
2,XGB[RNN -> MLP],0.7604 +/- 0.02,0.1769 +/- 0.0184
3,XGB[Transformer -> MLP],0.7927 +/- 0.0189,0.161 +/- 0.0089
4,XGB[Wavelet -> MLP],0.6991 +/- 0.0399,0.2166 +/- 0.0307
5,XGB[GRU -> MiniRocket],0.7344 +/- 0.0176,0.1813 +/- 0.0136
6,XGB[LSTM -> MiniRocket],0.724 +/- 0.0312,0.204 +/- 0.0345
7,XGB[RNN -> MiniRocket],0.7481 +/- 0.0013,0.1722 +/- 0.0118
8,XGB[Transformer -> MiniRocket],0.728 +/- 0.0519,0.1946 +/- 0.0256
9,XGB[Wavelet -> MiniRocket],0.7076 +/- 0.0266,0.1695 +/- 0.0131
