In [1]:
import numpy as np

from pycox import datasets
from lifelines.datasets import load_rossi
from sksurv.datasets import (
    load_aids,
    load_breast_cancer,
    load_flchain,
    load_gbsg2,
    load_whas500,
)
from sklearn.preprocessing import LabelEncoder


def get_dataset(name: str):
    if name == "metabric":
        df = datasets.metabric.read_df()
    elif name == "support":
        df = datasets.support.read_df()
    elif name == "gbsg":
        df = datasets.gbsg.read_df()
    elif name == "rossi":
        df = load_rossi()
        df = df.rename(columns={"week": "duration", "arrest": "event"})
    elif name == "aids":
        X, Y = load_aids()
        Y_unp = np.array(Y, dtype=[("event", "int"), ("duration", "float")])
        df = X.copy()
        df["event"] = Y_unp["event"]
        df["duration"] = Y_unp["duration"]
    elif name == "flchain":
        X, Y = load_flchain()
        Y_unp = np.array(Y, dtype=[("event", "int"), ("duration", "float")])
        df = X.copy()
        df["event"] = Y_unp["event"]
        df["duration"] = Y_unp["duration"]
    elif name == "gbsg2":
        X, Y = load_gbsg2()
        Y_unp = np.array(Y, dtype=[("event", "int"), ("duration", "float")])
        df = X.copy()
        df["event"] = Y_unp["event"]
        df["duration"] = Y_unp["duration"]
    elif name == "whas500":
        X, Y = load_whas500()
        Y_unp = np.array(Y, dtype=[("event", "int"), ("duration", "float")])
        df = X.copy()
        df["event"] = Y_unp["event"]
        df["duration"] = Y_unp["duration"]

    for col in df.columns:
        if df[col].dtype.name in ["object", "category"]:
            df[col] = LabelEncoder().fit_transform(df[col])

    duration_col = "duration"
    event_col = "event"
    
    df = df[df["duration"] > 0].fillna(0)

    X = df.drop(columns=[duration_col, event_col])
    Y = df[event_col]
    T = df[duration_col]

    time_horizons = np.linspace(T.min(), T.max(), num=5)[1:-1]

    return df, X, T, Y, time_horizons

In [2]:
from synthcity.plugins import Plugins
from adjutorium.utils.tester import evaluate_survival_estimator
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation
import pandas as pd


def constant_columns(dataframe: pd.DataFrame) -> list:
    """
    Drops constant value columns of pandas dataframe.
    """
    result = []
    for column in dataframe.columns:
        if len(dataframe[column].unique()) == 1:
            result.append(column)
    return result


def _train_and_evaluate(
    X_train, T_train, Y_train, X_test, T_test, Y_test, time_horizons, n_folds=3, key = "clf"
):
    predictor = RiskEstimation().get("cox_ph")
    n_folds = 3

    const_cols = constant_columns(X_train)
    X_train = X_train.drop(columns=const_cols)
    X_test = X_test.drop(columns=const_cols)
    try:
        predictor.fit(X_train, T_train, Y_train)

        return evaluate_survival_estimator(
            [predictor] * n_folds,
            X_test,
            T_test,
            Y_test,
            time_horizons=time_horizons,
            n_folds=n_folds,
            metrics = ["c_index", "brier_score"],
            pretrained=True,
        )[key]

    except BaseException:
        return {
            "c_index": 0,
            "aucroc": 0,
            "brier_score": 1,
        }


def _fold_evaluate(
    X_test,
    T_test,
    Y_test,
    time_horizons,
    n_folds=3,
     key = "clf",
):
    predictor = RiskEstimation().get("cox_ph")
    n_folds = 3

    const_cols = constant_columns(X_test)
    X_test = X_test.drop(columns=const_cols)

    try:
        return evaluate_survival_estimator(
            predictor,
            X_test,
            T_test,
            Y_test,
            time_horizons=time_horizons,
            n_folds=n_folds,
            metrics = ["c_index", "brier_score"],
        )[key]
    except BaseException:
        return {
            "c_index": 0,
            "aucroc": 0,
            "brier_score": 1,
        }


def evaluate_surv_generation(generative_method: str, dataset: str, generative_method_args: dict = {}):
    df, X_real, T_real, Y_real, time_horizons = get_dataset(dataset)

    syn_model = Plugins().get(generative_method, **generative_method_args)
    syn_model.fit(df)

    df_generated = syn_model.generate()

    X_syn = df_generated.drop(columns=["duration", "event"])
    Y_syn = df_generated["event"]
    T_syn = df_generated["duration"]

    real_real_score = _fold_evaluate(X_real, T_real, Y_real, time_horizons, key = "str")
    syn_syn_score = _fold_evaluate(X_syn, T_syn, Y_syn, time_horizons, key = "str")
    real_syn_score = _train_and_evaluate(
        X_real, T_real, Y_real, X_syn, T_syn, Y_syn, time_horizons, key = "str"
    )
    syn_real_score = _train_and_evaluate(
        X_syn, T_syn, Y_syn, X_real, T_real, Y_real, time_horizons, key = "str"
    )

    return {
        "real_real": real_real_score,
        "syn_syn": syn_syn_score,
        "real_syn": real_syn_score,
        "syn_real": syn_real_score,
    }

In [3]:
import tabulate
from IPython.display import HTML, display


def evaluate_dataset(dataset):
    model_metrics = {}

    for generator in Plugins().list():
        try:
            model_metrics[generator] = evaluate_surv_generation(generator, dataset)
        except BaseException as e:
            print("generator failed", generator, e)
            continue

    return model_metrics


def plot_metrics(model_metrics):
    for metric in ["c_index", "brier_score"]:
        headers = ["generator", "real_real", "syn_syn", "real_syn", "syn_real"]

        results = []
        for generator in model_metrics:
            local_results = [
                generator,
                model_metrics[generator]["real_real"][metric],
                model_metrics[generator]["syn_syn"][metric],
                model_metrics[generator]["real_syn"][metric],
                model_metrics[generator]["syn_real"][metric],
            ]
            results.append(local_results)

        print(metric)
        display(HTML(tabulate.tabulate(results, headers=headers, tablefmt="html")))

In [4]:
# Automl templates

import sys
import synthcity.logger as log
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation
from adjutorium.utils.tester import evaluate_survival_estimator
# Baseline evaluation
from synthcity.plugins import Plugins
from adjutorium.utils.tester import evaluate_survival_estimator
import optuna
from synthcity.plugins.models.time_to_event.tte_tenn import TENNTimeToEvent

optuna.logging.set_verbosity(optuna.logging.DEBUG)
optuna.logging.enable_propagation()
optuna.logging.enable_default_handler()

log.remove()
log.add(sink=sys.stderr, level="CRITICAL")


def _trial_params(trial, param_space):
    out = {}

    for param in param_space:
        if hasattr(param, "choices"):
            out[param.name] = trial.suggest_categorical(
                param.name, choices=param.choices
            )
        elif hasattr(param, "step"):
            out[param.name] = trial.suggest_int(
                param.name, param.low, param.high, param.step
            )
        else:
            out[param.name] = trial.suggest_float(param.name, param.low, param.high)

    return out

def objective_meta(model_name, dataset):
    def objective(trial):
        template = Plugins().get_type(model_name)
        
        params = _trial_params(trial, template.hyperparameter_space())
        df, X_real, T_real, Y_real, time_horizons = get_dataset(dataset)

        syn_model = Plugins().get(model_name, **params)
        syn_model.fit(df)

        df_generated = syn_model.generate()

        X_syn = df_generated.drop(columns=["duration", "event"])
        Y_syn = df_generated["event"]
        T_syn = df_generated["duration"]

        score = _train_and_evaluate(
            X_syn, T_syn, Y_syn, X_real, T_real, Y_real, time_horizons
        )["c_index"][0]

        return score

    return objective

def search_args(model_name, dataset, n_trials = 30):
    storage = optuna.storages.RedisStorage(url="redis://localhost")
    study = optuna.create_study(
        study_name = f"{model_name}_{dataset}_v4", 
        direction='maximize',
        storage = storage,
        load_if_exists = True,
    )    
    print("previous best", study.best_trial.params, flush = True)
    study.optimize(objective_meta(model_name, dataset), n_trials=n_trials)

    if study.best_trial.params is not None:
        return study.best_trial.params
        
    return {}

def uncensoring_objective_meta(model_name, dataset, uncensoring_model, model_args = {}):
    def objective(trial):
        if uncensoring_model == "tenn":
            template = TENNTimeToEvent
        else:
            raise ValueError(uncensoring_model)
            
        params = _trial_params(trial, template.hyperparameter_space())
        
        args = model_args
        args["uncensoring_model"] = "tenn"
        args["uncensoring_model_args"] = params
        
        df, X_real, T_real, Y_real, time_horizons = get_dataset(dataset)

        syn_model = Plugins().get(model_name, **args)
        syn_model.fit(df)

        df_generated = syn_model.generate()

        X_syn = df_generated.drop(columns=["duration", "event"])
        Y_syn = df_generated["event"]
        T_syn = df_generated["duration"]

        score = _train_and_evaluate(
            X_syn, T_syn, Y_syn, X_real, T_real, Y_real, time_horizons
        )["c_index"][0]

        return score

    return objective

def uncensoring_search_args(model_name, dataset, uncensoring_model, model_args = {}, n_trials = 30):
    storage = optuna.storages.RedisStorage(url="redis://localhost")
    study = optuna.create_study(
        study_name = f"{model_name}_{dataset}_{uncensoring_model}_v4", 
        direction='maximize',
        storage = storage,
        load_if_exists = True,
    )
    study.optimize(uncensoring_objective_meta(model_name, dataset, uncensoring_model, model_args = model_args), n_trials=n_trials)

    if study.best_trial.params is not None:
        return study.best_trial.params
        
    return {}

## AIDS dataset

In [5]:
import pandas as pd

df, X_real, T_real, Y_real, time_horizons = get_dataset("aids")

df

Unnamed: 0,age,cd4,hemophil,ivdrug,karnof,priorzdv,raceth,sex,strat2,tx,txgrp,event,duration
0,34.0,169.0,0,0,0,39.0,0,0,1,0,0,0,189.0
1,34.0,149.5,0,0,3,15.0,1,1,1,0,0,0,287.0
2,20.0,23.5,1,0,0,9.0,0,0,0,1,1,0,242.0
3,48.0,46.0,0,0,3,53.0,0,0,1,0,0,0,199.0
4,46.0,10.0,0,2,3,12.0,0,0,0,1,1,0,286.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1146,44.0,65.5,0,0,0,103.0,0,0,1,1,1,0,273.0
1147,41.0,7.5,0,0,2,20.0,1,0,0,1,1,1,47.0
1148,43.0,170.0,0,2,3,27.0,1,0,1,0,0,0,272.0
1149,44.0,282.5,0,2,2,12.0,0,0,1,0,0,0,192.0


In [6]:
# Baseline evaluation
from synthcity.plugins import Plugins
from adjutorium.utils.tester import evaluate_survival_estimator
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation

predictor = RiskEstimation().get_type("cox_ph")
n_folds = 3

const_cols = constant_columns(X_real)
X_real = X_real.drop(columns=const_cols)

evaluate_survival_estimator(
    predictor(),
    X_real,
    T_real,
    Y_real,
    time_horizons=time_horizons,
    n_folds=n_folds,
)["str"]

{'c_index': '0.7389 +/- 0.0531',
 'brier_score': '0.0617 +/- 0.0034',
 'aucroc': '0.7269 +/- 0.0281'}

In [7]:
model_metrics = evaluate_dataset("aids")

In [8]:
plot_metrics(model_metrics)

c_index


generator,real_real,syn_syn,real_syn,syn_real
bayesian_network,0.7389 +/- 0.0531,0.6572 +/- 0.0386,0.5549 +/- 0.0228,0.5924 +/- 0.0297
copulagan,0.7389 +/- 0.0531,0,0,0.5 +/- 0.0
ctgan,0.7389 +/- 0.0531,0.6783 +/- 0.0185,0.7031 +/- 0.0267,0.6991 +/- 0.0555
pategan,0.7389 +/- 0.0531,0.7084 +/- 0.0179,0.4562 +/- 0.0181,0.5471 +/- 0.0267
survae,0.7389 +/- 0.0531,0.5905 +/- 0.0244,0.5094 +/- 0.0354,0.642 +/- 0.043
adsgan,0.7389 +/- 0.0531,0,0,0.5185 +/- 0.0083
rtvae,0.7389 +/- 0.0531,0.7592 +/- 0.0147,0.4852 +/- 0.0243,0.4631 +/- 0.0323
tvae,0.7389 +/- 0.0531,0,0,0.608 +/- 0.0385
survival_gan,0.7389 +/- 0.0531,0.7323 +/- 0.0344,0.748 +/- 0.0175,0.7359 +/- 0.0277
gaussian_copula,0.7389 +/- 0.0531,0,0,0.5 +/- 0.0


brier_score


generator,real_real,syn_syn,real_syn,syn_real
bayesian_network,0.0617 +/- 0.0034,0.0892 +/- 0.0041,0.0951 +/- 0.0028,0.0677 +/- 0.0004
copulagan,0.0617 +/- 0.0034,1,1,0.0701 +/- 0.0012
ctgan,0.0617 +/- 0.0034,0.0423 +/- 0.0061,0.0446 +/- 0.0067,0.0637 +/- 0.0017
pategan,0.0617 +/- 0.0034,0.1279 +/- 0.0014,0.2278 +/- 0.0011,0.1468 +/- 0.0044
survae,0.0617 +/- 0.0034,0.1571 +/- 0.011,0.2151 +/- 0.015,0.1395 +/- 0.0038
adsgan,0.0617 +/- 0.0034,1,1,0.087 +/- 0.0051
rtvae,0.0617 +/- 0.0034,0.1161 +/- 0.0021,0.2035 +/- 0.0071,0.2389 +/- 0.0074
tvae,0.0617 +/- 0.0034,1,1,0.0691 +/- 0.0012
survival_gan,0.0617 +/- 0.0034,0.0689 +/- 0.0127,0.0995 +/- 0.0167,0.0916 +/- 0.0041
gaussian_copula,0.0617 +/- 0.0034,1,1,0.0701 +/- 0.0012


## FLChain dataset

In [5]:
import pandas as pd

df, X_real, T_real, Y_real, time_horizons = get_dataset("flchain")

df

Unnamed: 0,age,chapter,creatinine,flc.grp,kappa,lambda,mgus,sample.yr,sex,event,duration
0,97.0,1,1.7,1,5.700,4.860,0,2,0,1,85.0
1,92.0,12,0.9,0,0.870,0.683,0,5,0,1,1281.0
2,94.0,1,1.4,1,4.360,3.850,0,2,0,1,69.0
3,92.0,1,1.0,9,2.420,2.220,0,1,0,1,115.0
4,93.0,1,1.1,6,1.320,1.690,0,1,0,1,1039.0
...,...,...,...,...,...,...,...,...,...,...,...
7869,52.0,16,1.0,6,1.210,1.610,0,0,0,0,4997.0
7870,52.0,16,0.8,0,0.858,0.581,0,4,0,0,3652.0
7871,54.0,16,0.0,8,1.700,1.720,0,7,0,0,2507.0
7872,53.0,16,0.0,9,1.710,2.690,0,0,0,0,4982.0


In [6]:
# Baseline evaluation
from synthcity.plugins import Plugins
from adjutorium.utils.tester import evaluate_survival_estimator
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation

predictor = RiskEstimation().get_type("survival_xgboost")
n_folds = 3

const_cols = constant_columns(X_real)
X_real = X_real.drop(columns=const_cols)

evaluate_survival_estimator(
    predictor(),
    X_real,
    T_real,
    Y_real,
    time_horizons=time_horizons,
    n_folds=n_folds,
)["str"]

{'c_index': '0.9179 +/- 0.0033',
 'brier_score': '0.0557 +/- 0.0014',
 'aucroc': '0.9526 +/- 0.0003'}

In [7]:
model_metrics = evaluate_dataset("flchain")

generator failed survival_gan 


In [8]:
plot_metrics(model_metrics)

c_index


generator,real_real,syn_syn,real_syn,syn_real
adsgan,0.8841 +/- 0.0035,0,0,0.5 +/- 0.0
copulagan,0.8841 +/- 0.0035,0.9015 +/- 0.0012,0.8849 +/- 0.0015,0.8964 +/- 0.0008
bayesian_network,0.8841 +/- 0.0035,0.7741 +/- 0.0106,0.7622 +/- 0.0121,0.8914 +/- 0.0015
gaussian_copula,0.8841 +/- 0.0035,0.8522 +/- 0.0012,0.8243 +/- 0.003,0.9036 +/- 0.001
survae,0.8841 +/- 0.0035,0.7312 +/- 0.0015,0.5323 +/- 0.0149,0.4786 +/- 0.0179
tvae,0.8841 +/- 0.0035,0.91 +/- 0.0059,0.9105 +/- 0.0053,0.8593 +/- 0.0068
privbayes,0.8841 +/- 0.0035,0.5109 +/- 0.0201,0.5019 +/- 0.0255,0.1875 +/- 0.0021
rtvae,0.8841 +/- 0.0035,0.6751 +/- 0.0141,0.4108 +/- 0.0116,0.1817 +/- 0.0109
nflow,0.8841 +/- 0.0035,0.9236 +/- 0.0018,0.891 +/- 0.0039,0.8978 +/- 0.0022
ctgan,0.8841 +/- 0.0035,0.8673 +/- 0.0058,0.8567 +/- 0.0052,0.8681 +/- 0.005


brier_score


generator,real_real,syn_syn,real_syn,syn_real
adsgan,0.0828 +/- 0.0014,1,1,0.169 +/- 0.0044
copulagan,0.0828 +/- 0.0014,0.0768 +/- 0.0011,0.0825 +/- 0.0018,0.0876 +/- 0.0011
bayesian_network,0.0828 +/- 0.0014,0.1275 +/- 0.0049,0.1511 +/- 0.0046,0.0927 +/- 0.0023
gaussian_copula,0.0828 +/- 0.0014,0.084 +/- 0.0016,0.0946 +/- 0.001,0.0912 +/- 0.001
survae,0.0828 +/- 0.0014,0.1249 +/- 0.0008,0.2053 +/- 0.0008,0.2002 +/- 0.0031
tvae,0.0828 +/- 0.0014,0.0745 +/- 0.0029,0.0916 +/- 0.0007,0.086 +/- 0.0021
privbayes,0.0828 +/- 0.0014,0.1209 +/- 0.0024,0.1558 +/- 0.0053,0.1399 +/- 0.003
rtvae,0.0828 +/- 0.0014,0.1136 +/- 0.0026,0.1991 +/- 0.0007,0.1913 +/- 0.0024
nflow,0.0828 +/- 0.0014,0.0627 +/- 0.0017,0.0641 +/- 0.001,0.0961 +/- 0.001
ctgan,0.0828 +/- 0.0014,0.088 +/- 0.0032,0.0981 +/- 0.003,0.0856 +/- 0.0025


## gbsg2 dataset

In [13]:
import pandas as pd

df, X_real, T_real, Y_real, time_horizons = get_dataset("gbsg2")

df

Unnamed: 0,age,estrec,horTh,menostat,pnodes,progrec,tgrade,tsize,event,duration
0,70.0,66.0,0,0,3.0,48.0,1,21.0,1,1814.0
1,56.0,77.0,1,0,7.0,61.0,1,12.0,1,2018.0
2,58.0,271.0,1,0,9.0,52.0,1,35.0,1,712.0
3,59.0,29.0,1,0,4.0,60.0,1,17.0,1,1807.0
4,73.0,65.0,0,0,1.0,26.0,1,35.0,1,772.0
...,...,...,...,...,...,...,...,...,...,...
681,49.0,84.0,0,1,3.0,1.0,2,30.0,0,721.0
682,53.0,0.0,1,0,17.0,0.0,2,25.0,0,186.0
683,51.0,0.0,0,1,5.0,43.0,2,25.0,1,769.0
684,52.0,34.0,0,0,3.0,15.0,1,23.0,1,727.0


In [14]:
# Baseline evaluation
from synthcity.plugins import Plugins
from adjutorium.utils.tester import evaluate_survival_estimator
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation

predictor = RiskEstimation().get_type("survival_xgboost")
n_folds = 3

const_cols = constant_columns(X_real)
X_real = X_real.drop(columns=const_cols)

evaluate_survival_estimator(
    predictor(),
    X_real,
    T_real,
    Y_real,
    time_horizons=time_horizons,
    n_folds=n_folds,
)["str"]

{'c_index': '0.6794 +/- 0.0158',
 'brier_score': '0.1998 +/- 0.0241',
 'aucroc': '0.7095 +/- 0.0159'}

In [15]:
model_metrics = evaluate_dataset("gbsg2")

In [16]:
plot_metrics(model_metrics)

c_index


generator,real_real,syn_syn,real_syn,syn_real
bayesian_network,0.6735 +/- 0.0254,0.6475 +/- 0.0131,0.6483 +/- 0.0166,0.6864 +/- 0.0179
copulagan,0.6735 +/- 0.0254,0.5671 +/- 0.0322,0.5662 +/- 0.0044,0.5849 +/- 0.0303
ctgan,0.6735 +/- 0.0254,0.5144 +/- 0.0229,0.5049 +/- 0.0049,0.5056 +/- 0.0026
pategan,0.6735 +/- 0.0254,0.6647 +/- 0.017,0.5519 +/- 0.0235,0.5284 +/- 0.0146
survae,0.6735 +/- 0.0254,0.6635 +/- 0.0177,0.4476 +/- 0.0093,0.5479 +/- 0.0065
adsgan,0.6735 +/- 0.0254,0.8416 +/- 0.0291,0.7569 +/- 0.0299,0.656 +/- 0.0364
rtvae,0.6735 +/- 0.0254,0.6146 +/- 0.0736,0.5566 +/- 0.0471,0.5174 +/- 0.0401
tvae,0.6735 +/- 0.0254,0.466 +/- 0.2493,0.7242 +/- 0.0181,0.6784 +/- 0.02
survival_gan,0.6735 +/- 0.0254,0.687 +/- 0.0434,0.6862 +/- 0.0459,0.6618 +/- 0.0319
gaussian_copula,0.6735 +/- 0.0254,0.6355 +/- 0.0212,0.6289 +/- 0.0273,0.6619 +/- 0.0369


brier_score


generator,real_real,syn_syn,real_syn,syn_real
bayesian_network,0.1957 +/- 0.0171,0.2023 +/- 0.0113,0.2042 +/- 0.0109,0.2001 +/- 0.0208
copulagan,0.1957 +/- 0.0171,0.102 +/- 0.005,0.3191 +/- 0.0159,0.3036 +/- 0.0035
ctgan,0.1957 +/- 0.0171,0.2145 +/- 0.0127,0.232 +/- 0.0056,0.2319 +/- 0.0249
pategan,0.1957 +/- 0.0171,0.1576 +/- 0.0095,0.2394 +/- 0.0033,0.2479 +/- 0.0295
survae,0.1957 +/- 0.0171,0.152 +/- 0.0081,0.2647 +/- 0.0053,0.2691 +/- 0.0363
adsgan,0.1957 +/- 0.0171,0.1133 +/- 0.0226,0.1895 +/- 0.017,0.2807 +/- 0.0329
rtvae,0.1957 +/- 0.0171,0.1117 +/- 0.0467,0.1908 +/- 0.0266,0.2936 +/- 0.0256
tvae,0.1957 +/- 0.0171,0.2755 +/- 0.0977,0.1977 +/- 0.0009,0.2251 +/- 0.01
survival_gan,0.1957 +/- 0.0171,0.1567 +/- 0.0185,0.2052 +/- 0.0153,0.25 +/- 0.0465
gaussian_copula,0.1957 +/- 0.0171,0.1989 +/- 0.0097,0.2019 +/- 0.0117,0.2071 +/- 0.0269


## Metabric

In [17]:
import pandas as pd

df, X_real, T_real, Y_real, time_horizons = get_dataset("metabric")

df

Unnamed: 0,x0,x1,x2,x3,x4,x5,x6,x7,x8,duration,event
0,5.603834,7.811392,10.797988,5.967607,1.0,1.0,0.0,1.0,56.840000,99.333336,0
1,5.284882,9.581043,10.204620,5.664970,1.0,0.0,0.0,1.0,85.940002,95.733330,1
2,5.920251,6.776564,12.431715,5.873857,0.0,1.0,0.0,1.0,48.439999,140.233337,0
3,6.654017,5.341846,8.646379,5.655888,0.0,0.0,0.0,0.0,66.910004,239.300003,0
4,5.456747,5.339741,10.555724,6.008429,1.0,0.0,0.0,1.0,67.849998,56.933334,1
...,...,...,...,...,...,...,...,...,...,...,...
1899,5.946987,5.370492,12.345780,5.741395,1.0,1.0,0.0,1.0,76.839996,87.233330,1
1900,5.339228,5.408853,12.176101,5.693043,1.0,1.0,0.0,1.0,63.090000,157.533340,0
1901,5.901610,5.272237,14.200950,6.139390,0.0,0.0,0.0,1.0,57.770000,37.866665,1
1902,6.818109,5.372744,11.652624,6.077852,1.0,0.0,0.0,1.0,58.889999,198.433334,0


In [18]:
# Baseline evaluation
from synthcity.plugins import Plugins
from adjutorium.utils.tester import evaluate_survival_estimator
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation
import pandas as pd

df, X_real, T_real, Y_real, time_horizons = get_dataset("metabric")

predictor = RiskEstimation().get_type("survival_xgboost")
n_folds = 3

const_cols = constant_columns(X_real)
X_real = X_real.drop(columns=const_cols)


evaluate_survival_estimator(
    predictor(),
    X_real,
    T_real,
    Y_real,
    time_horizons=time_horizons,
    n_folds=n_folds,
)["str"]

{'c_index': '0.6412 +/- 0.0078',
 'brier_score': '0.1988 +/- 0.0124',
 'aucroc': '0.6874 +/- 0.0143'}

In [19]:
model_metrics = evaluate_dataset("metabric")

In [20]:
plot_metrics(model_metrics)

c_index


generator,real_real,syn_syn,real_syn,syn_real
bayesian_network,0.6355 +/- 0.0128,0.6374 +/- 0.0098,0.6412 +/- 0.0106,0.6368 +/- 0.0137
copulagan,0.6355 +/- 0.0128,0.5207 +/- 0.0152,0.5011 +/- 0.0126,0.5895 +/- 0.012
ctgan,0.6355 +/- 0.0128,0.5747 +/- 0.0239,0.4679 +/- 0.0454,0.4702 +/- 0.0128
pategan,0.6355 +/- 0.0128,0.6694 +/- 0.0132,0.4188 +/- 0.0111,0.4304 +/- 0.004
survae,0.6355 +/- 0.0128,0.7072 +/- 0.0074,0.4776 +/- 0.0193,0.459 +/- 0.0258
adsgan,0.6355 +/- 0.0128,0.812 +/- 0.0158,0.6367 +/- 0.0255,0.5309 +/- 0.0169
rtvae,0.6355 +/- 0.0128,0.7515 +/- 0.0136,0.4231 +/- 0.0094,0.4962 +/- 0.005
tvae,0.6355 +/- 0.0128,0.8011 +/- 0.0122,0.7404 +/- 0.0178,0.6407 +/- 0.0085
survival_gan,0.6355 +/- 0.0128,0.6458 +/- 0.0115,0.6207 +/- 0.0125,0.6105 +/- 0.0186
gaussian_copula,0.6355 +/- 0.0128,0.6066 +/- 0.0293,0.5932 +/- 0.0287,0.6265 +/- 0.0081


brier_score


generator,real_real,syn_syn,real_syn,syn_real
bayesian_network,0.1876 +/- 0.0077,0.1896 +/- 0.0048,0.1877 +/- 0.0055,0.1872 +/- 0.0081
copulagan,0.1876 +/- 0.0077,0.1934 +/- 0.0058,0.2346 +/- 0.0049,0.2196 +/- 0.0082
ctgan,0.1876 +/- 0.0077,0.1891 +/- 0.0024,0.249 +/- 0.0248,0.2597 +/- 0.0026
pategan,0.1876 +/- 0.0077,0.1176 +/- 0.0062,0.208 +/- 0.005,0.2845 +/- 0.0144
survae,0.1876 +/- 0.0077,0.1121 +/- 0.0039,0.1877 +/- 0.006,0.2733 +/- 0.0227
adsgan,0.1876 +/- 0.0077,0.3903 +/- 0.0368,0.1347 +/- 0.0149,0.4819 +/- 0.0077
rtvae,0.1876 +/- 0.0077,0.118 +/- 0.0051,0.2481 +/- 0.0107,0.2695 +/- 0.0123
tvae,0.1876 +/- 0.0077,0.1292 +/- 0.0113,0.1664 +/- 0.0129,0.2149 +/- 0.0106
survival_gan,0.1876 +/- 0.0077,0.2308 +/- 0.174,0.1571 +/- 0.0028,0.2305 +/- 0.0107
gaussian_copula,0.1876 +/- 0.0077,0.1963 +/- 0.0029,0.2107 +/- 0.002,0.1951 +/- 0.0103


In [9]:
search_args("survival_gan", "metabric")

[32m[I 2022-05-04 11:32:12,869][0m A new study created in Redis with name: survival_gan_metabric_v4[0m


ValueError: No trials are completed yet.

## gbsg

In [21]:
import pandas as pd

df, X_real, T_real, Y_real, time_horizons = get_dataset("gbsg")

df

Unnamed: 0,x0,x1,x2,x3,x4,x5,x6,duration,event
0,0.0,0.0,0.0,32.0,1.0,155.0,168.0,84.000000,0
1,0.0,1.0,0.0,27.0,1.0,717.0,95.0,84.000000,0
2,0.0,1.0,1.0,52.0,1.0,120.0,437.0,84.000000,0
3,0.0,0.0,0.0,28.0,1.0,251.0,11.0,84.000000,0
4,0.0,0.0,0.0,39.0,1.0,241.0,92.0,66.234085,1
...,...,...,...,...,...,...,...,...,...
2227,0.0,1.0,0.0,49.0,3.0,1.0,84.0,23.687885,0
2228,1.0,1.0,1.0,53.0,17.0,0.0,0.0,6.110883,0
2229,0.0,1.0,0.0,51.0,5.0,43.0,0.0,25.264887,1
2230,0.0,1.0,1.0,52.0,3.0,15.0,34.0,23.885010,1


In [22]:
model_metrics = evaluate_dataset("gbsg")

In [23]:
plot_metrics(model_metrics)

c_index


generator,real_real,syn_syn,real_syn,syn_real
bayesian_network,0.6776 +/- 0.0169,0.6791 +/- 0.0112,0.6802 +/- 0.0119,0.6766 +/- 0.0171
copulagan,0.6776 +/- 0.0169,0.5187 +/- 0.0179,0.5371 +/- 0.0199,0.6085 +/- 0.0171
ctgan,0.6776 +/- 0.0169,0.5769 +/- 0.0131,0.5622 +/- 0.0227,0.668 +/- 0.0104
pategan,0.6776 +/- 0.0169,0.6534 +/- 0.0105,0.4666 +/- 0.0193,0.4634 +/- 0.0101
survae,0.6776 +/- 0.0169,0.7257 +/- 0.0088,0.5183 +/- 0.0169,0.5434 +/- 0.0122
adsgan,0.6776 +/- 0.0169,0.7185 +/- 0.0336,0.5849 +/- 0.0228,0.4979 +/- 0.01
rtvae,0.6776 +/- 0.0169,0.646 +/- 0.0334,0.516 +/- 0.0259,0.4427 +/- 0.0049
tvae,0.6776 +/- 0.0169,0.8285 +/- 0.0142,0.8031 +/- 0.0147,0.6774 +/- 0.0206
survival_gan,0.6776 +/- 0.0169,0.6495 +/- 0.012,0.5864 +/- 0.0101,0.6261 +/- 0.0242
gaussian_copula,0.6776 +/- 0.0169,0.6437 +/- 0.0132,0.6414 +/- 0.0155,0.6843 +/- 0.0169


brier_score


generator,real_real,syn_syn,real_syn,syn_real
bayesian_network,0.2048 +/- 0.0022,0.2033 +/- 0.0023,0.2043 +/- 0.0019,0.2066 +/- 0.0024
copulagan,0.2048 +/- 0.0022,0.1718 +/- 0.0009,0.2193 +/- 0.0057,0.2656 +/- 0.0024
ctgan,0.2048 +/- 0.0022,0.2064 +/- 0.0048,0.2168 +/- 0.0095,0.2196 +/- 0.0008
pategan,0.2048 +/- 0.0022,0.1117 +/- 0.0022,0.2844 +/- 0.0021,0.3198 +/- 0.013
survae,0.2048 +/- 0.0022,0.1181 +/- 0.0031,0.2506 +/- 0.0029,0.2996 +/- 0.01
adsgan,0.2048 +/- 0.0022,0.0942 +/- 0.0056,0.1792 +/- 0.0022,0.3437 +/- 0.0049
rtvae,0.2048 +/- 0.0022,0.1201 +/- 0.0034,0.2785 +/- 0.0034,0.3232 +/- 0.002
tvae,0.2048 +/- 0.0022,0.1293 +/- 0.0046,0.1866 +/- 0.0032,0.2204 +/- 0.0098
survival_gan,0.2048 +/- 0.0022,0.1233 +/- 0.0006,0.1947 +/- 0.0014,0.2777 +/- 0.0069
gaussian_copula,0.2048 +/- 0.0022,0.2101 +/- 0.0022,0.2157 +/- 0.0015,0.2075 +/- 0.0043


## Support

In [24]:
import pandas as pd

df, X_real, T_real, Y_real, time_horizons = get_dataset("support")

df

Unnamed: 0,x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,duration,event
0,82.709961,1.0,2.0,1.0,0.0,0.0,0.0,160.0,55.0,16.0,38.195309,142.0,19.000000,1.099854,30.0,1
1,79.660950,1.0,0.0,1.0,0.0,0.0,1.0,54.0,67.0,16.0,38.000000,142.0,10.000000,0.899902,1527.0,0
2,23.399990,1.0,2.0,3.0,0.0,0.0,1.0,87.0,144.0,45.0,37.296879,130.0,5.199219,1.199951,96.0,1
3,53.075989,1.0,4.0,3.0,0.0,0.0,0.0,55.0,100.0,18.0,36.000000,135.0,8.699219,0.799927,892.0,0
4,71.794983,0.0,1.0,1.0,0.0,0.0,0.0,65.0,135.0,40.0,38.593750,146.0,0.099991,0.399963,7.0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8868,81.064941,0.0,4.0,1.0,0.0,0.0,1.0,111.0,110.0,34.0,39.593750,135.0,13.000000,1.500000,36.0,1
8869,72.560966,0.0,2.0,1.0,0.0,0.0,1.0,53.0,74.0,28.0,34.695309,139.0,7.899414,1.899902,49.0,1
8870,63.228001,0.0,1.0,1.0,0.0,0.0,2.0,95.0,110.0,22.0,38.695309,132.0,7.799805,1.500000,6.0,1
8871,75.405937,0.0,2.0,1.0,1.0,0.0,2.0,109.0,110.0,30.0,36.195309,140.0,15.398438,0.899902,10.0,1


In [26]:
model_metrics = evaluate_dataset("support")

In [27]:
plot_metrics(model_metrics)

c_index


generator,real_real,syn_syn,real_syn,syn_real
bayesian_network,0.5706 +/- 0.0054,0.5655 +/- 0.0036,0.5652 +/- 0.0018,0.5719 +/- 0.0051
copulagan,0.5706 +/- 0.0054,0.6711 +/- 0.0052,0.6305 +/- 0.0069,0.537 +/- 0.0037
ctgan,0.5706 +/- 0.0054,0.6399 +/- 0.0086,0.6098 +/- 0.0086,0.5331 +/- 0.0031
pategan,0.5706 +/- 0.0054,0.6457 +/- 0.0066,0.5033 +/- 0.0034,0.5276 +/- 0.0112
survae,0.5706 +/- 0.0054,0.5669 +/- 0.0037,0.4592 +/- 0.0046,0.4797 +/- 0.001
adsgan,0.5706 +/- 0.0054,0.8949 +/- 0.0044,0.8283 +/- 0.0034,0.5144 +/- 0.0098
rtvae,0.5706 +/- 0.0054,0.6562 +/- 0.0141,0.5214 +/- 0.0035,0.5054 +/- 0.0067
tvae,0.5706 +/- 0.0054,0.7223 +/- 0.0032,0.7022 +/- 0.0049,0.5568 +/- 0.001
survival_gan,0.5706 +/- 0.0054,0.6414 +/- 0.0069,0.5886 +/- 0.0156,0.5199 +/- 0.0055
gaussian_copula,0.5706 +/- 0.0054,0.5688 +/- 0.0065,0.5645 +/- 0.0049,0.5694 +/- 0.0059


brier_score


generator,real_real,syn_syn,real_syn,syn_real
bayesian_network,0.2084 +/- 0.004,0.2048 +/- 0.007,0.2049 +/- 0.0062,0.2087 +/- 0.0044
copulagan,0.2084 +/- 0.004,0.1797 +/- 0.0057,0.1938 +/- 0.0063,0.2283 +/- 0.0062
ctgan,0.2084 +/- 0.004,0.1818 +/- 0.006,0.1993 +/- 0.0081,0.2263 +/- 0.0057
pategan,0.2084 +/- 0.004,0.6783 +/- 0.0019,0.1789 +/- 0.0009,0.5892 +/- 0.0078
survae,0.2084 +/- 0.004,0.1481 +/- 0.0009,0.2214 +/- 0.0034,0.2713 +/- 0.0065
adsgan,0.2084 +/- 0.004,0.0553 +/- 0.0049,0.1794 +/- 0.001,0.4198 +/- 0.0088
rtvae,0.2084 +/- 0.004,0.1824 +/- 0.0014,0.2837 +/- 0.0045,0.2945 +/- 0.0058
tvae,0.2084 +/- 0.004,0.1275 +/- 0.0018,0.1738 +/- 0.0023,0.2813 +/- 0.0019
survival_gan,0.2084 +/- 0.004,0.1657 +/- 0.0025,0.1938 +/- 0.0033,0.2419 +/- 0.0073
gaussian_copula,0.2084 +/- 0.004,0.2168 +/- 0.0022,0.2192 +/- 0.002,0.2106 +/- 0.0045
