In [1]:
import numpy as np

from pycox import datasets
from lifelines.datasets import load_rossi
from sksurv.datasets import (
    load_aids,
    load_breast_cancer,
    load_flchain,
    load_gbsg2,
    load_whas500,
)
from sklearn.preprocessing import LabelEncoder


def _survival_frame(X, Y):
    """Combine a covariate frame and an outcome array into one DataFrame.

    ``Y`` is re-cast to a two-field structured array whose fields are named
    ``event`` (int) and ``duration`` (float); both fields are then appended as
    columns to a copy of ``X``.
    """
    # NOTE(review): assumes Y is a two-field structured array ordered as
    # (event indicator, time), which is what the sksurv loaders appear to
    # return -- confirm against the sksurv documentation.
    Y_unp = np.array(Y, dtype=[("event", "int"), ("duration", "float")])
    df = X.copy()
    df["event"] = Y_unp["event"]
    df["duration"] = Y_unp["duration"]
    return df


def get_dataset(name: str):
    """Load a survival dataset by name and split it into modelling pieces.

    Supported names: ``metabric``, ``support``, ``gbsg``, ``rossi``, ``aids``,
    ``flchain``, ``gbsg2`` and ``whas500``.

    Every dataset is normalised to a frame with ``duration`` and ``event``
    columns. Columns of dtype ``object`` or ``category`` are label-encoded,
    and only rows with ``duration > 0`` are kept.

    Returns:
        A tuple ``(df, X, T, Y, time_horizons)`` where ``df`` is the filtered
        frame, ``X`` is ``df`` without the duration/event columns, ``T`` is the
        duration column, ``Y`` is the event column, and ``time_horizons`` holds
        the three interior points of five evenly spaced values between the
        minimum and maximum duration.

    Raises:
        ValueError: if ``name`` is not one of the supported dataset names
            (previously this surfaced as an ``UnboundLocalError`` on ``df``).
    """
    if name == "metabric":
        df = datasets.metabric.read_df()
    elif name == "support":
        df = datasets.support.read_df()
    elif name == "gbsg":
        df = datasets.gbsg.read_df()
    elif name == "rossi":
        df = load_rossi()
        df = df.rename(columns={"week": "duration", "arrest": "event"})
    elif name == "aids":
        df = _survival_frame(*load_aids())
    elif name == "flchain":
        df = _survival_frame(*load_flchain())
    elif name == "gbsg2":
        df = _survival_frame(*load_gbsg2())
    elif name == "whas500":
        df = _survival_frame(*load_whas500())
    else:
        raise ValueError(
            f"Unknown dataset {name!r}; expected one of: metabric, support, "
            "gbsg, rossi, aids, flchain, gbsg2, whas500"
        )

    # Turn string/categorical covariates into integer codes.
    for col in df.columns:
        if df[col].dtype.name in ["object", "category"]:
            df[col] = LabelEncoder().fit_transform(df[col])

    duration_col = "duration"
    event_col = "event"

    # Keep strictly positive durations only.
    df = df[df[duration_col] > 0]

    X = df.drop(columns=[duration_col, event_col])
    Y = df[event_col]
    T = df[duration_col]

    # Five evenly spaced points over the observed range, minus both endpoints.
    time_horizons = np.linspace(T.min(), T.max(), num=5)[1:-1]

    return df, X, T, Y, time_horizons

In [2]:
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from typing import Tuple
import optuna
import copy

from synthcity.plugins.models.time_to_event.tte_aft import WeibullAFTTimeToEvent
from synthcity.plugins.models.time_to_event.tte_coxph import CoxPHTimeToEvent
from synthcity.plugins.models.time_to_event.tte_rsf import RandomSurvivalForestTimeToEvent
from synthcity.plugins.models.time_to_event.tte_xgb import XGBTimeToEvent
from synthcity.plugins.models.time_to_event.tte_deephit import DeephitTimeToEvent
from synthcity.plugins.models.time_to_event.tte_date import DATETimeToEvent
from synthcity.plugins.models.time_to_event.tte_robust_date import RobustDATETimeToEvent
from synthcity.plugins.models.time_to_event.tte_tenn import TENNTimeToEvent

from synthcity.plugins.models.time_to_event.metrics import (
     expected_time_error,
     ranking_error,
     rush_error,
 )
from synthcity.plugins import Plugins
from adjutorium.utils.tester import evaluate_survival_estimator
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation
from adjutorium.plugins.prediction.regression import Regression
from synthcity.plugins import Plugins

from sksurv.metrics import concordance_index_censored
from sklearn.model_selection import KFold

from adjutorium.utils.risk_estimation import generate_dataset_for_horizon

In [3]:
def constant_columns(dataframe: pd.DataFrame) -> list:
    """
    Return the names of the columns that hold exactly one distinct value.

    Nothing is removed from ``dataframe``; callers use the returned list to
    drop the columns themselves. Distinct values are counted with
    ``Series.unique()``.
    """
    return [
        name
        for name in dataframe.columns
        if len(dataframe[name].unique()) == 1
    ]

def print_results(results):
    """Flatten the nested ``results`` mapping into a C-index summary table.

    ``results`` is indexed as ``results[tte_model][synth_model][scenario]``,
    where each scenario entry (``synth_real``, ``real_synth``,
    ``synth_synth``) carries a ``"c_index"`` value. One row is produced per
    (tte_model, synth_model) pair; the table is returned, not printed.
    """
    columns = ["Decensoring model", "Synth model", "synth_real CINDEX", "real_synth CINDEX", "synth_synth CINDEX"]

    rows = [
        [
            decensoring_name,
            generator_name,
            scores["synth_real"]["c_index"],
            scores["real_synth"]["c_index"],
            scores["synth_synth"]["c_index"],
        ]
        for decensoring_name, per_generator in results.items()
        for generator_name, scores in per_generator.items()
    ]

    return pd.DataFrame(rows, columns = columns)

def get_tte_model(name: str):
    """Instantiate the time-to-event model registered under ``name``.

    Supported names: ``tenn``, ``random_survival_forest``, ``date``,
    ``deephit`` and ``coxph``. Each call builds a fresh model with its
    default constructor arguments.

    Raises:
        NotImplementedError: if ``name`` is not one of the supported names.
    """
    if name == "tenn":
        return TENNTimeToEvent()
    elif name == "random_survival_forest":
        return RandomSurvivalForestTimeToEvent()
    elif name == "date":
        return DATETimeToEvent()
    elif name == "deephit":
        # Previously this branch built a RandomSurvivalForestTimeToEvent,
        # so "deephit" runs duplicated the random-forest ones.
        return DeephitTimeToEvent()
    elif name == "coxph":
        return CoxPHTimeToEvent()

    raise NotImplementedError(name)
    
def dataset_uncensoring(
    model_name: str, 
    X: pd.DataFrame, 
    T: pd.Series, 
    Y: pd.Series):
    """Replace censored durations with model-predicted event times.

    A time-to-event model chosen by ``model_name`` is fitted on ``(X, T, Y)``
    and asked to predict a time for every row of ``X``. Rows whose event
    indicator ``Y`` equals 1 keep their recorded duration from ``T``; all
    other rows receive the prediction.

    Returns:
        A copy of ``X`` with an extra ``event_time`` column. ``X`` itself is
        not modified.
    """
    tte_model = get_tte_model(model_name)

    # Uncensoring step. Fit on the event indicator passed in as ``Y``; the
    # earlier code read a notebook-level global ``E`` here instead.
    tte_model.fit(X, T, Y)

    pred_T = tte_model.predict(X)
    # NOTE(review): if predict() returns a Series whose index differs from
    # X.index, this constructor re-aligns by label and can yield NaN -- confirm
    # what the time-to-event models return.
    pred_T = pd.Series(pred_T, index=X.index)
    # Rows with an observed event keep their recorded time.
    pred_T[Y == 1] = T[Y == 1]

    synth_baseline = X.copy()
    synth_baseline["event_time"] = pred_T

    return synth_baseline

def evaluate_test_dataset(X_train, T_train, Y_train,
                          X_test, T_test, Y_test,
                          time_horizons, model = "cox_ph", n_folds = 3):    
    """Fit a risk estimator on one split and score it on another.

    Columns that are constant in ``X_train`` are removed from both the
    training and the test covariates. The estimator named by ``model`` is
    fitted once on the training split, and that single fitted instance is
    handed to ``evaluate_survival_estimator`` repeated ``n_folds`` times with
    ``pretrained = True``.

    Returns:
        The ``"str"`` entry of the evaluator's result, requested for the
        ``c_index`` and ``brier_score`` metrics.
    """
    estimator = RiskEstimation().get(model)

    # The training split alone decides which columns are dropped.
    dropped = constant_columns(X_train)
    train_features = X_train.drop(columns = dropped)
    test_features = X_test.drop(columns = dropped)

    estimator.fit(train_features, T_train, Y_train)

    # NOTE(review): pretrained = True presumably stops the evaluator from
    # refitting the estimators it is given -- confirm against adjutorium.
    report = evaluate_survival_estimator(
        [estimator] * n_folds,
        test_features, T_test, Y_test,
        time_horizons = time_horizons,
        n_folds = n_folds,
        metrics = ["c_index", "brier_score"],
        pretrained = True,
    )
    return report["str"]
def evaluate_dataset(X, T, Y, time_horizons, model = "cox_ph", n_folds = 3):    
    """Score a risk estimator on a single dataset.

    Constant columns of ``X`` are removed, then an unfitted estimator named
    by ``model`` is passed to ``evaluate_survival_estimator`` together with
    ``n_folds`` and the ``c_index`` / ``brier_score`` metrics.

    Returns:
        The ``"str"`` entry of the evaluator's result.
    """
    estimator = RiskEstimation().get(model)

    features = X.drop(columns = constant_columns(X))

    report = evaluate_survival_estimator(estimator,
                                         features, T, Y,
                                         time_horizons = time_horizons,
                                         n_folds = n_folds,
                                         metrics = ["c_index", "brier_score"],
                                        )
    return report["str"]
    
def synthetic_survival(X, T, Y,
                       time_to_event_model = "tenn",
                       synth_model = "ctgan",
                      ):
    """Generate a synthetic, fully uncensored survival dataset.

    The real data is first uncensored with ``dataset_uncensoring`` (using the
    model named by ``time_to_event_model``), a generator named by
    ``synth_model`` is fitted on that frame, and ``len(X)`` synthetic rows
    are requested from it.

    Returns:
        A tuple ``(synth_eval_X, synth_eval_T, synth_eval_E)``: the synthetic
        covariates, the synthetic ``event_time`` column, and an event
        indicator that is 1 for every synthetic row.
    """
    # Uncensoring. Uses the ``Y`` argument; the earlier code read a
    # notebook-level global ``E`` here and ignored the parameter.
    uncensored_dataset = dataset_uncensoring(time_to_event_model, X, T, Y)

    # Fit the synth generator
    generator = Plugins().get(synth_model)
    generator.fit(uncensored_dataset)

    synth_df = generator.generate(len(X))

    synth_eval_X = synth_df.drop(columns = ["event_time"])
    synth_eval_T = synth_df["event_time"]
    # Every synthetic row counts as an observed event. The vector is sized
    # from the generated frame so it always matches that frame's index.
    synth_eval_E = pd.Series([1] * len(synth_eval_X), index = synth_eval_X.index)

    return synth_eval_X, synth_eval_T, synth_eval_E

## AIDS dataset

In [4]:
import pandas as pd

# Load the AIDS data; get_dataset also returns the evaluation time horizons.
df, X_real, T_real, Y_real, time_horizons = get_dataset("aids")

# Working variables for the cells below: covariates, event indicator, duration.
X, E, T = (
    df.drop(columns=["event", "duration"]),
    df["event"],
    df["duration"],
)

X

Unnamed: 0,age,cd4,hemophil,ivdrug,karnof,priorzdv,raceth,sex,strat2,tx,txgrp
0,34.0,169.0,0,0,0,39.0,0,0,1,0,0
1,34.0,149.5,0,0,3,15.0,1,1,1,0,0
2,20.0,23.5,1,0,0,9.0,0,0,0,1,1
3,48.0,46.0,0,0,3,53.0,0,0,1,0,0
4,46.0,10.0,0,2,3,12.0,0,0,0,1,1
...,...,...,...,...,...,...,...,...,...,...,...
1146,44.0,65.5,0,0,0,103.0,0,0,1,1,1
1147,41.0,7.5,0,0,2,20.0,1,0,0,1,1
1148,43.0,170.0,0,2,3,27.0,1,0,1,0,0
1149,44.0,282.5,0,2,2,12.0,0,0,1,0,0


In [5]:
# Baseline on the real AIDS data, using evaluate_dataset's defaults (model "cox_ph", n_folds 3).
evaluate_dataset(X, T, E, time_horizons)

{'c_index': '0.7389 +/- 0.0531', 'brier_score': '0.0617 +/- 0.0034'}

In [6]:
# Sweep every (uncensoring model, synthetic generator) pair and score three
# scenarios per pair: synthetic only, train-real/test-synthetic, and
# train-synthetic/test-real.
results = {}
for tte_model in ["random_survival_forest", "date", "tenn", "deephit", "coxph"]:
    results[tte_model] = {}
    for synth_model in ["ctgan", "adsgan", "rtvae", "tvae"]:
        synth_eval_X, synth_eval_T, synth_eval_E = synthetic_survival(X, T, E, 
                                                                      time_to_event_model = tte_model, 
                                                                      synth_model = synth_model)

        try:
            eval_synth_synth = evaluate_dataset(synth_eval_X, synth_eval_T, synth_eval_E, time_horizons)
            eval_real_synth = evaluate_test_dataset(X, T, E, synth_eval_X, synth_eval_T, synth_eval_E, time_horizons)
            eval_synth_real = evaluate_test_dataset(synth_eval_X, synth_eval_T, synth_eval_E, X, T, E, time_horizons)
        except Exception as e:
            # Catch Exception rather than BaseException so that a kernel
            # interrupt (KeyboardInterrupt) still stops the sweep; a failed
            # pair is reported and skipped.
            print(" !!! evaluation failed ", tte_model, synth_model, e)
            continue
            
        results[tte_model][synth_model] = {
            "synth_synth": eval_synth_synth,
            "synth_real": eval_synth_real,
            "real_synth": eval_real_synth,
        }
        print(f"[{tte_model}][{synth_model}] >>> evaluation synth_synth={eval_synth_synth} synth_real={eval_synth_real} real_synth={eval_real_synth}")

[random_survival_forest][ctgan] >>> evaluation synth_synth={'c_index': '0.506 +/- 0.0147', 'brier_score': '0.1263 +/- 0.0131'} synth_real={'c_index': '0.6058 +/- 0.0342', 'brier_score': '0.0776 +/- 0.001'} real_synth={'c_index': '0.4802 +/- 0.0296', 'brier_score': '0.1447 +/- 0.0158'}
[random_survival_forest][adsgan] >>> evaluation synth_synth={'c_index': '0.9147 +/- 0.0043', 'brier_score': '0.0771 +/- 0.0059'} synth_real={'c_index': '0.6922 +/- 0.0562', 'brier_score': '0.0999 +/- 0.004'} real_synth={'c_index': '0.8953 +/- 0.0077', 'brier_score': '0.222 +/- 0.0036'}
[random_survival_forest][rtvae] >>> evaluation synth_synth={'c_index': '0.699 +/- 0.0294', 'brier_score': '0.1726 +/- 0.0012'} synth_real={'c_index': '0.5374 +/- 0.0374', 'brier_score': '0.1973 +/- 0.0136'} real_synth={'c_index': '0.6349 +/- 0.0177', 'brier_score': '0.3073 +/- 0.0113'}
[random_survival_forest][tvae] >>> evaluation synth_synth={'c_index': '0.8738 +/- 0.0185', 'brier_score': '0.0792 +/- 0.0075'} synth_real={'

In [7]:
# Tabulate the C-index of each scenario per model pair (Brier scores are not part of this table).
print_results(results)

Unnamed: 0,Decensoring model,Synth model,synth_real CINDEX,real_synth CINDEX,synth_synth CINDEX
0,random_survival_forest,ctgan,0.6058 +/- 0.0342,0.4802 +/- 0.0296,0.506 +/- 0.0147
1,random_survival_forest,adsgan,0.6922 +/- 0.0562,0.8953 +/- 0.0077,0.9147 +/- 0.0043
2,random_survival_forest,rtvae,0.5374 +/- 0.0374,0.6349 +/- 0.0177,0.699 +/- 0.0294
3,random_survival_forest,tvae,0.739 +/- 0.0403,0.874 +/- 0.0239,0.8738 +/- 0.0185
4,date,ctgan,0.4084 +/- 0.0332,0.4987 +/- 0.0207,0.5661 +/- 0.02
5,date,adsgan,0.6155 +/- 0.0056,0.7719 +/- 0.022,0.8112 +/- 0.0168
6,date,rtvae,0.6315 +/- 0.032,0.574 +/- 0.0401,0.8314 +/- 0.0316
7,date,tvae,0.7176 +/- 0.0201,0.861 +/- 0.0261,0.8508 +/- 0.0332
8,tenn,ctgan,0.3389 +/- 0.0283,0.5597 +/- 0.0224,0.4574 +/- 0.0871
9,tenn,adsgan,0.7035 +/- 0.056,0.8201 +/- 0.0127,0.8611 +/- 0.004


## Metabric dataset

In [8]:
import pandas as pd

# Load the METABRIC data; get_dataset also returns the evaluation time horizons.
df, X_real, T_real, Y_real, time_horizons = get_dataset("metabric")
# get_dataset already keeps only rows with duration > 0, so this filter is a
# repeat of that step.
df = df.loc[df["duration"] > 0]

# Working variables for the cells below: covariates, event indicator, duration.
X, E, T = (
    df.drop(columns=["event", "duration"]),
    df["event"],
    df["duration"],
)

X

Unnamed: 0,x0,x1,x2,x3,x4,x5,x6,x7,x8
0,5.603834,7.811392,10.797988,5.967607,1.0,1.0,0.0,1.0,56.840000
1,5.284882,9.581043,10.204620,5.664970,1.0,0.0,0.0,1.0,85.940002
2,5.920251,6.776564,12.431715,5.873857,0.0,1.0,0.0,1.0,48.439999
3,6.654017,5.341846,8.646379,5.655888,0.0,0.0,0.0,0.0,66.910004
4,5.456747,5.339741,10.555724,6.008429,1.0,0.0,0.0,1.0,67.849998
...,...,...,...,...,...,...,...,...,...
1899,5.946987,5.370492,12.345780,5.741395,1.0,1.0,0.0,1.0,76.839996
1900,5.339228,5.408853,12.176101,5.693043,1.0,1.0,0.0,1.0,63.090000
1901,5.901610,5.272237,14.200950,6.139390,0.0,0.0,0.0,1.0,57.770000
1902,6.818109,5.372744,11.652624,6.077852,1.0,0.0,0.0,1.0,58.889999


In [9]:
# Baseline on the real METABRIC data, using evaluate_dataset's defaults (model "cox_ph", n_folds 3).
evaluate_dataset(X, T, E, time_horizons)

{'c_index': '0.6333 +/- 0.0078', 'brier_score': '0.1887 +/- 0.009'}

In [10]:
# Sweep every (uncensoring model, synthetic generator) pair and score three
# scenarios per pair: synthetic only, train-real/test-synthetic, and
# train-synthetic/test-real.
results = {}
for tte_model in ["random_survival_forest", "date", "tenn", "deephit", "coxph"]:
    results[tte_model] = {}
    for synth_model in ["ctgan", "adsgan", "rtvae", "tvae"]:
        synth_eval_X, synth_eval_T, synth_eval_E = synthetic_survival(X, T, E, 
                                                                      time_to_event_model = tte_model, 
                                                                      synth_model = synth_model)

        try:
            eval_synth_synth = evaluate_dataset(synth_eval_X, synth_eval_T, synth_eval_E, time_horizons)
            eval_real_synth = evaluate_test_dataset(X, T, E, synth_eval_X, synth_eval_T, synth_eval_E, time_horizons)
            eval_synth_real = evaluate_test_dataset(synth_eval_X, synth_eval_T, synth_eval_E, X, T, E, time_horizons)
        except Exception as e:
            # Catch Exception rather than BaseException so that a kernel
            # interrupt (KeyboardInterrupt) still stops the sweep; a failed
            # pair is reported and skipped.
            print(" !!! evaluation failed ", tte_model, synth_model, e)
            continue
            
        results[tte_model][synth_model] = {
            "synth_synth": eval_synth_synth,
            "synth_real": eval_synth_real,
            "real_synth": eval_real_synth,
        }
        print(f"[{tte_model}][{synth_model}] >>> evaluation synth_synth={eval_synth_synth} synth_real={eval_synth_real} real_synth={eval_real_synth}")

[random_survival_forest][ctgan] >>> evaluation synth_synth={'c_index': '0.5652 +/- 0.0151', 'brier_score': '0.2015 +/- 0.0021'} synth_real={'c_index': '0.4484 +/- 0.0169', 'brier_score': '0.2614 +/- 0.0088'} real_synth={'c_index': '0.4778 +/- 0.012', 'brier_score': '0.2648 +/- 0.0058'}
[random_survival_forest][adsgan] >>> evaluation synth_synth={'c_index': '0.735 +/- 0.0041', 'brier_score': '0.0713 +/- 0.0052'} synth_real={'c_index': '0.5444 +/- 0.0233', 'brier_score': '0.2885 +/- 0.0169'} real_synth={'c_index': '0.6237 +/- 0.0097', 'brier_score': '0.1896 +/- 0.0045'}
[random_survival_forest][rtvae] >>> evaluation synth_synth={'c_index': '0.6251 +/- 0.0222', 'brier_score': '0.1457 +/- 0.0008'} synth_real={'c_index': '0.568 +/- 0.008', 'brier_score': '0.222 +/- 0.0161'} real_synth={'c_index': '0.478 +/- 0.0135', 'brier_score': '0.2007 +/- 0.003'}
[random_survival_forest][tvae] >>> evaluation synth_synth={'c_index': '0.7713 +/- 0.0074', 'brier_score': '0.1494 +/- 0.0004'} synth_real={'c_

In [11]:
# Tabulate the C-index of each scenario per model pair (Brier scores are not part of this table).
print_results(results)

Unnamed: 0,Decensoring model,Synth model,synth_real CINDEX,real_synth CINDEX,synth_synth CINDEX
0,random_survival_forest,ctgan,0.4484 +/- 0.0169,0.4778 +/- 0.012,0.5652 +/- 0.0151
1,random_survival_forest,adsgan,0.5444 +/- 0.0233,0.6237 +/- 0.0097,0.735 +/- 0.0041
2,random_survival_forest,rtvae,0.568 +/- 0.008,0.478 +/- 0.0135,0.6251 +/- 0.0222
3,random_survival_forest,tvae,0.6322 +/- 0.0178,0.7361 +/- 0.0056,0.7713 +/- 0.0074
4,date,ctgan,0.5264 +/- 0.0239,0.4645 +/- 0.0048,0.6062 +/- 0.0117
5,date,adsgan,0.5059 +/- 0.0093,0.6945 +/- 0.0059,0.8269 +/- 0.0112
6,date,rtvae,0.514 +/- 0.0043,0.5143 +/- 0.0035,0.6213 +/- 0.0106
7,date,tvae,0.6341 +/- 0.0166,0.7012 +/- 0.0096,0.7695 +/- 0.0109
8,tenn,ctgan,0.581 +/- 0.0105,0.5172 +/- 0.0154,0.5719 +/- 0.0092
9,tenn,adsgan,0.4897 +/- 0.0206,0.9182 +/- 0.0098,0.7994 +/- 0.039


## SUPPORT dataset

In [12]:
import pandas as pd

# Load the SUPPORT data; get_dataset also returns the evaluation time horizons.
df, X_real, T_real, Y_real, time_horizons = get_dataset("support")
# get_dataset already keeps only rows with duration > 0, so this filter is a
# repeat of that step.
df = df.loc[df["duration"] > 0]

# Working variables for the cells below: covariates, event indicator, duration.
X, E, T = (
    df.drop(columns=["event", "duration"]),
    df["event"],
    df["duration"],
)

X

Unnamed: 0,x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13
0,82.709961,1.0,2.0,1.0,0.0,0.0,0.0,160.0,55.0,16.0,38.195309,142.0,19.000000,1.099854
1,79.660950,1.0,0.0,1.0,0.0,0.0,1.0,54.0,67.0,16.0,38.000000,142.0,10.000000,0.899902
2,23.399990,1.0,2.0,3.0,0.0,0.0,1.0,87.0,144.0,45.0,37.296879,130.0,5.199219,1.199951
3,53.075989,1.0,4.0,3.0,0.0,0.0,0.0,55.0,100.0,18.0,36.000000,135.0,8.699219,0.799927
4,71.794983,0.0,1.0,1.0,0.0,0.0,0.0,65.0,135.0,40.0,38.593750,146.0,0.099991,0.399963
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8868,81.064941,0.0,4.0,1.0,0.0,0.0,1.0,111.0,110.0,34.0,39.593750,135.0,13.000000,1.500000
8869,72.560966,0.0,2.0,1.0,0.0,0.0,1.0,53.0,74.0,28.0,34.695309,139.0,7.899414,1.899902
8870,63.228001,0.0,1.0,1.0,0.0,0.0,2.0,95.0,110.0,22.0,38.695309,132.0,7.799805,1.500000
8871,75.405937,0.0,2.0,1.0,1.0,0.0,2.0,109.0,110.0,30.0,36.195309,140.0,15.398438,0.899902


In [13]:
# Baseline on the real SUPPORT data, using evaluate_dataset's defaults (model "cox_ph", n_folds 3).
evaluate_dataset(X, T, E, time_horizons)

{'c_index': '0.5706 +/- 0.0054', 'brier_score': '0.2084 +/- 0.004'}

In [14]:
# Sweep every (uncensoring model, synthetic generator) pair and score three
# scenarios per pair: synthetic only, train-real/test-synthetic, and
# train-synthetic/test-real.
results = {}
for tte_model in ["random_survival_forest", "date", "tenn", "deephit", "coxph"]:
    results[tte_model] = {}
    for synth_model in ["ctgan", "adsgan", "rtvae", "tvae"]:
        synth_eval_X, synth_eval_T, synth_eval_E = synthetic_survival(X, T, E, 
                                                                      time_to_event_model = tte_model, 
                                                                      synth_model = synth_model)

        try:
            eval_synth_synth = evaluate_dataset(synth_eval_X, synth_eval_T, synth_eval_E, time_horizons)
            eval_real_synth = evaluate_test_dataset(X, T, E, synth_eval_X, synth_eval_T, synth_eval_E, time_horizons)
            eval_synth_real = evaluate_test_dataset(synth_eval_X, synth_eval_T, synth_eval_E, X, T, E, time_horizons)
        except Exception as e:
            # Catch Exception rather than BaseException so that a kernel
            # interrupt (KeyboardInterrupt) still stops the sweep; a failed
            # pair is reported and skipped.
            print(" !!! evaluation failed ", tte_model, synth_model, e)
            continue
            
        results[tte_model][synth_model] = {
            "synth_synth": eval_synth_synth,
            "synth_real": eval_synth_real,
            "real_synth": eval_real_synth,
        }
        print(f"[{tte_model}][{synth_model}] >>> evaluation synth_synth={eval_synth_synth} synth_real={eval_synth_real} real_synth={eval_real_synth}")

[random_survival_forest][ctgan] >>> evaluation synth_synth={'c_index': '0.5612 +/- 0.0055', 'brier_score': '0.1349 +/- 0.0034'} synth_real={'c_index': '0.531 +/- 0.0067', 'brier_score': '0.2416 +/- 0.008'} real_synth={'c_index': '0.5352 +/- 0.006', 'brier_score': '0.1646 +/- 0.0031'}
[random_survival_forest][adsgan] >>> evaluation synth_synth={'c_index': '0.8282 +/- 0.0042', 'brier_score': '0.0664 +/- 0.0034'} synth_real={'c_index': '0.4582 +/- 0.0061', 'brier_score': '0.3193 +/- 0.0081'} real_synth={'c_index': '0.343 +/- 0.0051', 'brier_score': '0.2332 +/- 0.0023'}
[random_survival_forest][rtvae] >>> evaluation synth_synth={'c_index': '0.5753 +/- 0.0026', 'brier_score': '0.1556 +/- 0.0014'} synth_real={'c_index': '0.5089 +/- 0.0022', 'brier_score': '0.2506 +/- 0.007'} real_synth={'c_index': '0.4723 +/- 0.0019', 'brier_score': '0.2087 +/- 0.0033'}
[random_survival_forest][tvae] >>> evaluation synth_synth={'c_index': '0.6511 +/- 0.0049', 'brier_score': '0.055 +/- 0.003'} synth_real={'c_

In [15]:
# Tabulate the C-index of each scenario per model pair (Brier scores are not part of this table).
print_results(results)

Unnamed: 0,Decensoring model,Synth model,synth_real CINDEX,real_synth CINDEX,synth_synth CINDEX
0,random_survival_forest,ctgan,0.531 +/- 0.0067,0.5352 +/- 0.006,0.5612 +/- 0.0055
1,random_survival_forest,adsgan,0.4582 +/- 0.0061,0.343 +/- 0.0051,0.8282 +/- 0.0042
2,random_survival_forest,rtvae,0.5089 +/- 0.0022,0.4723 +/- 0.0019,0.5753 +/- 0.0026
3,random_survival_forest,tvae,0.5617 +/- 0.0018,0.616 +/- 0.0064,0.6511 +/- 0.0049
4,date,ctgan,0.5384 +/- 0.0053,0.5185 +/- 0.0068,0.5251 +/- 0.0028
5,date,adsgan,0.5377 +/- 0.0024,0.7224 +/- 0.0044,0.796 +/- 0.0061
6,date,rtvae,0.5288 +/- 0.0089,0.548 +/- 0.006,0.6294 +/- 0.0032
7,date,tvae,0.5543 +/- 0.006,0.556 +/- 0.0102,0.5869 +/- 0.0089
8,tenn,ctgan,0.5212 +/- 0.0013,0.5249 +/- 0.0068,0.573 +/- 0.007
9,tenn,adsgan,0.519 +/- 0.0053,0.7002 +/- 0.0067,0.8055 +/- 0.0078
