In [86]:
import numpy as np

from pycox import datasets
from lifelines.datasets import load_rossi
from sksurv.datasets import (
    load_aids,
    load_breast_cancer,
    load_flchain,
    load_gbsg2,
    load_whas500,
)
from sklearn.preprocessing import LabelEncoder


def get_dataset(name: str):
    if name == "metabric":
        df = datasets.metabric.read_df()
    elif name == "support":
        df = datasets.support.read_df()
    elif name == "gbsg":
        df = datasets.gbsg.read_df()
    elif name == "rossi":
        df = load_rossi()
        df = df.rename(columns={"week": "duration", "arrest": "event"})
    elif name == "aids":
        X, Y = load_aids()
        Y_unp = np.array(Y, dtype=[("event", "int"), ("duration", "float")])
        df = X.copy()
        df["event"] = Y_unp["event"]
        df["duration"] = Y_unp["duration"]
    elif name == "flchain":
        X, Y = load_flchain()
        Y_unp = np.array(Y, dtype=[("event", "int"), ("duration", "float")])
        df = X.copy()
        df["event"] = Y_unp["event"]
        df["duration"] = Y_unp["duration"]
    elif name == "gbsg2":
        X, Y = load_gbsg2()
        Y_unp = np.array(Y, dtype=[("event", "int"), ("duration", "float")])
        df = X.copy()
        df["event"] = Y_unp["event"]
        df["duration"] = Y_unp["duration"]
    elif name == "whas500":
        X, Y = load_whas500()
        Y_unp = np.array(Y, dtype=[("event", "int"), ("duration", "float")])
        df = X.copy()
        df["event"] = Y_unp["event"]
        df["duration"] = Y_unp["duration"]

    for col in df.columns:
        if df[col].dtype.name in ["object", "category"]:
            df[col] = LabelEncoder().fit_transform(df[col])

    duration_col = "duration"
    event_col = "event"

    X = df.drop(columns=[duration_col, event_col])
    Y = df[event_col]
    T = df[duration_col]

    time_horizons = np.linspace(T.min(), T.max(), num=5)[1:-1]

    return df, X, T, Y, time_horizons

In [87]:
from synthcity.plugins import Plugins
from adjutorium.utils.tester import evaluate_survival_estimator
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation
import pandas as pd


def constant_columns(dataframe: pd.DataFrame) -> list:
    """
    Drops constant value columns of pandas dataframe.
    """
    result = []
    for column in dataframe.columns:
        if len(dataframe[column].unique()) == 1:
            result.append(column)
    return result


def _train_and_evaluate(
    X_train, T_train, Y_train, X_test, T_test, Y_test, time_horizons, n_folds=3
):
    predictor = RiskEstimation().get("survival_xgboost")
    n_folds = 3

    const_cols = constant_columns(X_train)
    X_train = X_train.drop(columns=const_cols)
    X_test = X_test.drop(columns=const_cols)
    try:
        predictor.fit(X_train, T_train, Y_train)

        return evaluate_survival_estimator(
            [predictor] * n_folds,
            X_test,
            T_test,
            Y_test,
            time_horizons=time_horizons,
            n_folds=n_folds,
            pretrained=True,
        )["str"]

    except BaseException:
        return {
            "c_index": "0 +/-0",
            "aucroc": "0 +/- 0",
            "brier_score": "1 +/- 0",
        }


def _fold_evaluate(
    X_test,
    T_test,
    Y_test,
    time_horizons,
    n_folds=3,
):
    predictor = RiskEstimation().get("survival_xgboost")
    n_folds = 3

    const_cols = constant_columns(X_test)
    X_test = X_test.drop(columns=const_cols)

    try:
        return evaluate_survival_estimator(
            predictor,
            X_test,
            T_test,
            Y_test,
            time_horizons=time_horizons,
            n_folds=n_folds,
        )["str"]
    except BaseException:
        return {
            "c_index": "0 +/-0",
            "aucroc": "0 +/- 0",
            "brier_score": "1 +/- 0",
        }


def evaluate_surv_generation(generative_method: str, dataset: str):
    df, X_real, T_real, Y_real, time_horizons = get_dataset(dataset)

    syn_model = Plugins().get(generative_method)
    syn_model.fit(df)

    df_generated = syn_model.generate()

    X_syn = df_generated.drop(columns=["duration", "event"])
    Y_syn = df_generated["event"]
    T_syn = df_generated["duration"]

    real_real_score = _fold_evaluate(X_real, T_real, Y_real, time_horizons)
    syn_syn_score = _fold_evaluate(X_syn, T_syn, Y_syn, time_horizons)
    real_syn_score = _train_and_evaluate(
        X_real, T_real, Y_real, X_syn, T_syn, Y_syn, time_horizons
    )
    syn_real_score = _train_and_evaluate(
        X_syn, T_syn, Y_syn, X_real, T_real, Y_real, time_horizons
    )

    return {
        "real_real": real_real_score,
        "syn_syn": syn_syn_score,
        "real_syn": real_syn_score,
        "syn_real": syn_real_score,
    }

In [106]:
import tabulate
from IPython.display import HTML, display


def evaluate_dataset(dataset):
    model_metrics = {}

    for generator in Plugins().list():
        try:
            model_metrics[generator] = evaluate_surv_generation(generator, dataset)
        except BaseException as e:
            print("generator failed", generator, e)
            continue

    return model_metrics


def plot_metrics(model_metrics):
    for metric in ["c_index", "aucroc", "brier_score"]:
        headers = ["generator", "real_real", "syn_syn", "real_syn", "syn_real"]

        results = []
        for generator in model_metrics:
            local_results = [
                generator,
                model_metrics[generator]["real_real"][metric],
                model_metrics[generator]["syn_syn"][metric],
                model_metrics[generator]["real_syn"][metric],
                model_metrics[generator]["syn_real"][metric],
            ]
            results.append(local_results)

        print(metric)
        display(HTML(tabulate.tabulate(results, headers=headers, tablefmt="html")))

In [89]:
# Automl templates

import sys
import synthcity.logger as log
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation
from adjutorium.utils.tester import evaluate_survival_estimator

log.remove()
log.add(sink=sys.stderr, level="CRITICAL")


def _trial_params(trial, param_space):
    out = {}

    for param in param_space:
        if hasattr(param, "choices"):
            out[param.name] = trial.suggest_categorical(
                param.name, choices=param.choices
            )
        elif hasattr(param, "step"):
            out[param.name] = trial.suggest_int(
                param.name, param.low, param.high, param.step
            )
        else:
            out[param.name] = trial.suggest_float(param.name, param.low, param.high)

    return out

## AIDS dataset

In [98]:
import pandas as pd

df, X_real, T_real, Y_real, time_horizons = get_dataset("aids")

df

Unnamed: 0,age,cd4,hemophil,ivdrug,karnof,priorzdv,raceth,sex,strat2,tx,txgrp,event,duration
0,34.0,169.0,0,0,0,39.0,0,0,1,0,0,0,189.0
1,34.0,149.5,0,0,3,15.0,1,1,1,0,0,0,287.0
2,20.0,23.5,1,0,0,9.0,0,0,0,1,1,0,242.0
3,48.0,46.0,0,0,3,53.0,0,0,1,0,0,0,199.0
4,46.0,10.0,0,2,3,12.0,0,0,0,1,1,0,286.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1146,44.0,65.5,0,0,0,103.0,0,0,1,1,1,0,273.0
1147,41.0,7.5,0,0,2,20.0,1,0,0,1,1,1,47.0
1148,43.0,170.0,0,2,3,27.0,1,0,1,0,0,0,272.0
1149,44.0,282.5,0,2,2,12.0,0,0,1,0,0,0,192.0


In [99]:
# Baseline evaluation
from synthcity.plugins import Plugins
from adjutorium.utils.tester import evaluate_survival_estimator
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation

predictor = RiskEstimation().get_type("survival_xgboost")
n_folds = 3

const_cols = constant_columns(X_real)
X_real = X_real.drop(columns=const_cols)

evaluate_survival_estimator(
    predictor(),
    X_real,
    T_real,
    Y_real,
    time_horizons=time_horizons,
    n_folds=n_folds,
)["str"]

{'c_index': '0.7374 +/- 0.0241',
 'brier_score': '0.0623 +/- 0.0024',
 'aucroc': '0.7385 +/- 0.0419'}

In [100]:
model_metrics = evaluate_dataset("aids")

     fun: -5.346954018840134e+283
     jac: array([-inf, -inf,  inf,  inf])
 message: 'Rank-deficient equality constraint subproblem HFTI'
    nfev: 46
     nit: 13
    njev: 13
  status: 7
 success: False
       x: array([-1.01476494e+02, -5.05252821e-01,  3.02026426e+00,  1.41642475e-02])
None


In [101]:
plot_metrics(model_metrics)

c_index


generator,real_real,syn_syn,real_syn,syn_real
nflow,0.7374 +/- 0.0241,0.7599 +/- 0.0261,0.7304 +/- 0.0405,0.6792 +/- 0.0231
adsgan,0.7374 +/- 0.0241,0.865 +/- 0.0177,0.822 +/- 0.0243,0.7223 +/- 0.0588
bayesian_network,0.7374 +/- 0.0241,0.5903 +/- 0.0127,0.5584 +/- 0.0246,0.6317 +/- 0.0397
rtvae,0.7374 +/- 0.0241,0.7841 +/- 0.0151,0.631 +/- 0.0301,0.6527 +/- 0.0459
gaussian_copula,0.7374 +/- 0.0241,0 +/-0,0 +/-0,0 +/-0
copulagan,0.7374 +/- 0.0241,0 +/-0,0 +/-0,0 +/-0
ctgan,0.7374 +/- 0.0241,0 +/-0,0.5103 +/- 0.0224,0.5098 +/- 0.0626
tvae,0.7374 +/- 0.0241,0 +/-0,0 +/-0,0 +/-0
pategan,0.7374 +/- 0.0241,0 +/-0,0 +/-0,0.5459 +/- 0.061
privbayes,0.7374 +/- 0.0241,0 +/-0,0.4996 +/- 0.0425,0.6269 +/- 0.0466


aucroc


generator,real_real,syn_syn,real_syn,syn_real
nflow,0.7385 +/- 0.0419,0.7247 +/- 0.0163,0.7313 +/- 0.0358,0.6868 +/- 0.0426
adsgan,0.7385 +/- 0.0419,0.834 +/- 0.0079,0.835 +/- 0.0275,0.7325 +/- 0.0115
bayesian_network,0.7385 +/- 0.0419,0.5456 +/- 0.004,0.5616 +/- 0.0371,0.6464 +/- 0.0216
rtvae,0.7385 +/- 0.0419,0.7965 +/- 0.0277,0.7557 +/- 0.0259,0.6665 +/- 0.0097
gaussian_copula,0.7385 +/- 0.0419,0 +/- 0,0 +/- 0,0 +/- 0
copulagan,0.7385 +/- 0.0419,0 +/- 0,0 +/- 0,0 +/- 0
ctgan,0.7385 +/- 0.0419,0 +/- 0,0.504 +/- 0.0419,0.5248 +/- 0.0205
tvae,0.7385 +/- 0.0419,0 +/- 0,0 +/- 0,0 +/- 0
pategan,0.7385 +/- 0.0419,0 +/- 0,0 +/- 0,0.5549 +/- 0.0161
privbayes,0.7385 +/- 0.0419,0 +/- 0,0.5099 +/- 0.0806,0.6324 +/- 0.0186


brier_score


generator,real_real,syn_syn,real_syn,syn_real
nflow,0.0623 +/- 0.0024,0.0332 +/- 0.0053,0.0349 +/- 0.0051,0.0637 +/- 0.0014
adsgan,0.0623 +/- 0.0024,0.0421 +/- 0.0028,0.0487 +/- 0.0036,0.0747 +/- 0.0062
bayesian_network,0.0623 +/- 0.0024,0.0962 +/- 0.0046,0.104 +/- 0.0048,0.0687 +/- 0.0014
rtvae,0.0623 +/- 0.0024,0.1298 +/- 0.0073,0.262 +/- 0.0062,0.1396 +/- 0.0061
gaussian_copula,0.0623 +/- 0.0024,1 +/- 0,1 +/- 0,1 +/- 0
copulagan,0.0623 +/- 0.0024,1 +/- 0,1 +/- 0,1 +/- 0
ctgan,0.0623 +/- 0.0024,1 +/- 0,0.1354 +/- 0.0085,0.0818 +/- 0.003
tvae,0.0623 +/- 0.0024,1 +/- 0,1 +/- 0,1 +/- 0
pategan,0.0623 +/- 0.0024,1 +/- 0,1 +/- 0,0.2412 +/- 0.0035
privbayes,0.0623 +/- 0.0024,1 +/- 0,0.0539 +/- 0.0049,0.065 +/- 0.0017


## FLChain dataset

In [104]:
import pandas as pd

df, X_real, T_real, Y_real, time_horizons = get_dataset("flchain")
df = df.fillna(0)

df

Unnamed: 0,age,chapter,creatinine,flc.grp,kappa,lambda,mgus,sample.yr,sex,event,duration
0,97.0,1,1.7,1,5.700,4.860,0,2,0,1,85.0
1,92.0,12,0.9,0,0.870,0.683,0,5,0,1,1281.0
2,94.0,1,1.4,1,4.360,3.850,0,2,0,1,69.0
3,92.0,1,1.0,9,2.420,2.220,0,1,0,1,115.0
4,93.0,1,1.1,6,1.320,1.690,0,1,0,1,1039.0
...,...,...,...,...,...,...,...,...,...,...,...
7869,52.0,16,1.0,6,1.210,1.610,0,0,0,0,4997.0
7870,52.0,16,0.8,0,0.858,0.581,0,4,0,0,3652.0
7871,54.0,16,0.0,8,1.700,1.720,0,7,0,0,2507.0
7872,53.0,16,0.0,9,1.710,2.690,0,0,0,0,4982.0


In [93]:
# Baseline evaluation
from synthcity.plugins import Plugins
from adjutorium.utils.tester import evaluate_survival_estimator
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation

predictor = RiskEstimation().get_type("survival_xgboost")
n_folds = 3

const_cols = constant_columns(X_real)
X_real = X_real.drop(columns=const_cols)

evaluate_survival_estimator(
    predictor(),
    X_real,
    T_real,
    Y_real,
    time_horizons=time_horizons,
    n_folds=n_folds,
)["str"]

{'c_index': '0.9172 +/- 0.0003',
 'brier_score': '0.0561 +/- 0.0013',
 'aucroc': '0.9997 +/- 0.0004'}

In [107]:
model_metrics = evaluate_dataset("flchain")

generator failed nflow Shape of passed values is (7874, 2), indices imply (7874, 3)
generator failed adsgan Shape of passed values is (7874, 2), indices imply (7874, 3)
generator failed bayesian_network Input contains NaN, infinity or a value too large for dtype('float64').
generator failed rtvae Shape of passed values is (7874, 2), indices imply (7874, 3)
generator failed pategan Shape of passed values is (7874, 2), indices imply (7874, 3)


In [108]:
plot_metrics(model_metrics)

c_index


generator,real_real,syn_syn,real_syn,syn_real
gaussian_copula,0.9172 +/- 0.0003,0.8531 +/- 0.0102,0.7668 +/- 0.0148,0.9041 +/- 0.004
copulagan,0.9172 +/- 0.0003,0.8828 +/- 0.0038,0.8553 +/- 0.0022,0.9053 +/- 0.0028
ctgan,0.9172 +/- 0.0003,0.8903 +/- 0.005,0.8662 +/- 0.0014,0.9037 +/- 0.0009
tvae,0.9172 +/- 0.0003,0.9444 +/- 0.0004,0.9242 +/- 0.0024,0.9111 +/- 0.0042
privbayes,0.9172 +/- 0.0003,0.5093 +/- 0.0153,0.4887 +/- 0.0037,0.4133 +/- 0.0222
marginal_distributions,0.9172 +/- 0.0003,0 +/-0,0 +/-0,0.5925 +/- 0.0166


aucroc


generator,real_real,syn_syn,real_syn,syn_real
gaussian_copula,0.9997 +/- 0.0004,0.933 +/- 0.0032,0.8468 +/- 0.0015,0.9996 +/- 0.0002
copulagan,0.9997 +/- 0.0004,0.9769 +/- 0.0024,0.9595 +/- 0.0029,0.9981 +/- 0.0006
ctgan,0.9997 +/- 0.0004,0.9756 +/- 0.0009,0.9662 +/- 0.0023,0.9954 +/- 0.0007
tvae,0.9997 +/- 0.0004,0.998 +/- 0.0009,0.9946 +/- 0.0004,0.9997 +/- 0.0
privbayes,0.9997 +/- 0.0004,0.4947 +/- 0.0076,0.4873 +/- 0.0121,0.3905 +/- 0.0028
marginal_distributions,0.9997 +/- 0.0004,0 +/- 0,0 +/- 0,0.6238 +/- 0.0065


brier_score


generator,real_real,syn_syn,real_syn,syn_real
gaussian_copula,0.0561 +/- 0.0013,0.0828 +/- 0.0029,0.1666 +/- 0.0042,0.0765 +/- 0.0026
copulagan,0.0561 +/- 0.0013,0.0884 +/- 0.001,0.0937 +/- 0.0017,0.0797 +/- 0.0006
ctgan,0.0561 +/- 0.0013,0.0836 +/- 0.0007,0.0862 +/- 0.0014,0.0895 +/- 0.0005
tvae,0.0561 +/- 0.0013,0.0408 +/- 0.0011,0.0535 +/- 0.0013,0.0789 +/- 0.003
privbayes,0.0561 +/- 0.0013,0.1289 +/- 0.0008,0.2379 +/- 0.0089,0.1518 +/- 0.0018
marginal_distributions,0.0561 +/- 0.0013,1 +/- 0,1 +/- 0,0.1817 +/- 0.0021


## gbsg2 dataset

In [109]:
import pandas as pd

df, X_real, T_real, Y_real, time_horizons = get_dataset("gbsg2")

df

Unnamed: 0,age,estrec,horTh,menostat,pnodes,progrec,tgrade,tsize,event,duration
0,70.0,66.0,0,0,3.0,48.0,1,21.0,1,1814.0
1,56.0,77.0,1,0,7.0,61.0,1,12.0,1,2018.0
2,58.0,271.0,1,0,9.0,52.0,1,35.0,1,712.0
3,59.0,29.0,1,0,4.0,60.0,1,17.0,1,1807.0
4,73.0,65.0,0,0,1.0,26.0,1,35.0,1,772.0
...,...,...,...,...,...,...,...,...,...,...
681,49.0,84.0,0,1,3.0,1.0,2,30.0,0,721.0
682,53.0,0.0,1,0,17.0,0.0,2,25.0,0,186.0
683,51.0,0.0,0,1,5.0,43.0,2,25.0,1,769.0
684,52.0,34.0,0,0,3.0,15.0,1,23.0,1,727.0


In [95]:
# Baseline evaluation
from synthcity.plugins import Plugins
from adjutorium.utils.tester import evaluate_survival_estimator
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation

predictor = RiskEstimation().get_type("survival_xgboost")
n_folds = 3

const_cols = constant_columns(X_real)
X_real = X_real.drop(columns=const_cols)

evaluate_survival_estimator(
    predictor(),
    X_real,
    T_real,
    Y_real,
    time_horizons=time_horizons,
    n_folds=n_folds,
)["str"]

{'c_index': '0.6794 +/- 0.0158',
 'brier_score': '0.1998 +/- 0.0241',
 'aucroc': '0.7124 +/- 0.0176'}

In [110]:
model_metrics = evaluate_dataset("gbsg2")

None
None
None
None


In [111]:
plot_metrics(model_metrics)

c_index


generator,real_real,syn_syn,real_syn,syn_real
nflow,0.6794 +/- 0.0158,0.6402 +/- 0.0426,0.6562 +/- 0.0219,0.6748 +/- 0.0269
adsgan,0.6794 +/- 0.0158,0.7297 +/- 0.019,0.6721 +/- 0.015,0.6163 +/- 0.0287
bayesian_network,0.6794 +/- 0.0158,0 +/-0,0 +/-0,0.6995 +/- 0.0193
rtvae,0.6794 +/- 0.0158,0 +/-0,0.5377 +/- 0.0198,0.4395 +/- 0.0206
gaussian_copula,0.6794 +/- 0.0158,0.6329 +/- 0.0191,0.6206 +/- 0.0263,0.696 +/- 0.0238
copulagan,0.6794 +/- 0.0158,0.5697 +/- 0.0207,0.4377 +/- 0.0176,0.3342 +/- 0.0219
ctgan,0.6794 +/- 0.0158,0 +/-0,0.495 +/- 0.0539,0.5341 +/- 0.0235
tvae,0.6794 +/- 0.0158,0 +/-0,0.7289 +/- 0.0248,0.6521 +/- 0.0111
pategan,0.6794 +/- 0.0158,0 +/-0,0 +/-0,0.4745 +/- 0.0236
privbayes,0.6794 +/- 0.0158,0 +/-0,0.4777 +/- 0.0427,0.5275 +/- 0.0177


aucroc


generator,real_real,syn_syn,real_syn,syn_real
nflow,0.7124 +/- 0.0176,0.5951 +/- 0.0567,0.7152 +/- 0.0295,0.7577 +/- 0.0283
adsgan,0.7124 +/- 0.0176,0.7056 +/- 0.0158,0.7474 +/- 0.0417,0.6504 +/- 0.0104
bayesian_network,0.7124 +/- 0.0176,0 +/- 0,0 +/- 0,0.7668 +/- 0.0179
rtvae,0.7124 +/- 0.0176,0 +/- 0,0.5273 +/- 0.038,0.4381 +/- 0.0203
gaussian_copula,0.7124 +/- 0.0176,0.6334 +/- 0.0475,0.6855 +/- 0.0331,0.7705 +/- 0.0218
copulagan,0.7124 +/- 0.0176,0.5579 +/- 0.0591,0.362 +/- 0.0241,0.2906 +/- 0.0183
ctgan,0.7124 +/- 0.0176,0 +/- 0,0.4575 +/- 0.0346,0.5227 +/- 0.025
tvae,0.7124 +/- 0.0176,0 +/- 0,0.8254 +/- 0.0009,0.7087 +/- 0.0365
pategan,0.7124 +/- 0.0176,0 +/- 0,0 +/- 0,0.4817 +/- 0.0263
privbayes,0.7124 +/- 0.0176,0 +/- 0,0.4903 +/- 0.0237,0.5355 +/- 0.0022


brier_score


generator,real_real,syn_syn,real_syn,syn_real
nflow,0.1998 +/- 0.0241,0.1981 +/- 0.0237,0.2011 +/- 0.0262,0.1981 +/- 0.0154
adsgan,0.1998 +/- 0.0241,0.1496 +/- 0.0108,0.1649 +/- 0.0125,0.2292 +/- 0.032
bayesian_network,0.1998 +/- 0.0241,1 +/- 0,1 +/- 0,0.1996 +/- 0.0319
rtvae,0.1998 +/- 0.0241,1 +/- 0,0.2255 +/- 0.0167,0.3214 +/- 0.0319
gaussian_copula,0.1998 +/- 0.0241,0.1991 +/- 0.0128,0.1994 +/- 0.0163,0.207 +/- 0.0295
copulagan,0.1998 +/- 0.0241,0.2262 +/- 0.0151,0.2797 +/- 0.0227,0.3221 +/- 0.0274
ctgan,0.1998 +/- 0.0241,1 +/- 0,0.2296 +/- 0.0199,0.2335 +/- 0.015
tvae,0.1998 +/- 0.0241,1 +/- 0,0.1812 +/- 0.0033,0.2329 +/- 0.0068
pategan,0.1998 +/- 0.0241,1 +/- 0,1 +/- 0,0.2983 +/- 0.0129
privbayes,0.1998 +/- 0.0241,1 +/- 0,0.2467 +/- 0.011,0.232 +/- 0.0201


## Metabric

In [113]:
import pandas as pd

df, X_real, T_real, Y_real, time_horizons = get_dataset("metabric")

df

Unnamed: 0,x0,x1,x2,x3,x4,x5,x6,x7,x8,duration,event
0,5.603834,7.811392,10.797988,5.967607,1.0,1.0,0.0,1.0,56.840000,99.333336,0
1,5.284882,9.581043,10.204620,5.664970,1.0,0.0,0.0,1.0,85.940002,95.733330,1
2,5.920251,6.776564,12.431715,5.873857,0.0,1.0,0.0,1.0,48.439999,140.233337,0
3,6.654017,5.341846,8.646379,5.655888,0.0,0.0,0.0,0.0,66.910004,239.300003,0
4,5.456747,5.339741,10.555724,6.008429,1.0,0.0,0.0,1.0,67.849998,56.933334,1
...,...,...,...,...,...,...,...,...,...,...,...
1899,5.946987,5.370492,12.345780,5.741395,1.0,1.0,0.0,1.0,76.839996,87.233330,1
1900,5.339228,5.408853,12.176101,5.693043,1.0,1.0,0.0,1.0,63.090000,157.533340,0
1901,5.901610,5.272237,14.200950,6.139390,0.0,0.0,0.0,1.0,57.770000,37.866665,1
1902,6.818109,5.372744,11.652624,6.077852,1.0,0.0,0.0,1.0,58.889999,198.433334,0


In [15]:
# Baseline evaluation
from synthcity.plugins import Plugins
from adjutorium.utils.tester import evaluate_survival_estimator
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation
import optuna

optuna.logging.set_verbosity(optuna.logging.DEBUG)
optuna.logging.enable_propagation()
optuna.logging.enable_default_handler()


predictor = RiskEstimation().get_type("survival_xgboost")
n_folds = 3

const_cols = constant_columns(X_real)
X_real = X_real.drop(columns=const_cols)


def objective(trial):
    params = _trial_params(trial, predictor.hyperparameter_space())

    try:
        score = evaluate_survival_estimator(
            predictor(**params),
            X_real,
            T_real,
            Y_real,
            time_horizons=time_horizons,
            n_folds=n_folds,
        )["clf"]["c_index"][0]
    except BaseException as e:
        print("   >>>>>>>> trial failed", e)
        score = 0

    return score


# study = optuna.create_study(direction='maximize')
# study.optimize(objective, n_trials=30)

In [14]:
# Baseline evaluation
from synthcity.plugins import Plugins
from adjutorium.utils.tester import evaluate_survival_estimator
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation
import pandas as pd

df, X_real, T_real, Y_real, time_horizons = get_dataset("metabric")

predictor = RiskEstimation().get_type("survival_xgboost")
n_folds = 3

const_cols = constant_columns(X_real)
X_real = X_real.drop(columns=const_cols)


evaluate_survival_estimator(
    predictor(),
    X_real,
    T_real,
    Y_real,
    time_horizons=time_horizons,
    n_folds=n_folds,
)["str"]

{'c_index': '0.6412 +/- 0.0078',
 'brier_score': '0.1988 +/- 0.0124',
 'aucroc': '0.7188 +/- 0.0156'}

In [6]:
model_metrics = evaluate_dataset("metabric")

In [7]:
plot_metrics(model_metrics)

c_index


generator,real_real,syn_syn,real_syn,syn_real
adsgan,0.6355 +/- 0.0128,0.6964 +/- 0.0132,0.5863 +/- 0.0191,0.6191 +/- 0.0039
pategan,0.6355 +/- 0.0128,0.7896 +/- 0.0094,0.4082 +/- 0.0086,0.4789 +/- 0.0093
tvae,0.6355 +/- 0.0128,0.7509 +/- 0.0039,0.7225 +/- 0.0006,0.6348 +/- 0.0136
bayesian_network,0.6355 +/- 0.0128,0 +/-0,0 +/-0,0.6322 +/- 0.0138
privbayes,0.6355 +/- 0.0128,0.4889 +/- 0.0012,0.4991 +/- 0.028,0.4988 +/- 0.0208
ctgan,0.6355 +/- 0.0128,0.5771 +/- 0.0175,0.5158 +/- 0.0299,0.4856 +/- 0.0051
marginal_distributions,0.6355 +/- 0.0128,0 +/-0,0 +/-0,0.4646 +/- 0.0058
copulagan,0.6355 +/- 0.0128,0 +/-0,0 +/-0,0.5728 +/- 0.0145
gaussian_copula,0.6355 +/- 0.0128,0.6278 +/- 0.0162,0.614 +/- 0.0267,0.6234 +/- 0.0085
rtvae,0.6355 +/- 0.0128,0.6953 +/- 0.0115,0.4959 +/- 0.0225,0.435 +/- 0.0052


aucroc


generator,real_real,syn_syn,real_syn,syn_real
adsgan,0.7471 +/- 0.0132,0.828 +/- 0.0131,0.708 +/- 0.0153,0.6698 +/- 0.0059
pategan,0.7471 +/- 0.0132,0.9467 +/- 0.0069,0.3824 +/- 0.0398,0.3952 +/- 0.0223
tvae,0.7471 +/- 0.0132,0.8731 +/- 0.0113,0.8893 +/- 0.0059,0.7489 +/- 0.019
bayesian_network,0.7471 +/- 0.0132,0 +/- 0,0 +/- 0,0.7704 +/- 0.0136
privbayes,0.7471 +/- 0.0132,0.4621 +/- 0.0206,0.5047 +/- 0.0536,0.4153 +/- 0.025
ctgan,0.7471 +/- 0.0132,0.6207 +/- 0.0497,0.5259 +/- 0.019,0.4824 +/- 0.0314
marginal_distributions,0.7471 +/- 0.0132,0 +/- 0,0 +/- 0,0.485 +/- 0.0085
copulagan,0.7471 +/- 0.0132,0 +/- 0,0 +/- 0,0.6238 +/- 0.0371
gaussian_copula,0.7471 +/- 0.0132,0.7156 +/- 0.0014,0.6983 +/- 0.0127,0.7405 +/- 0.0198
rtvae,0.7471 +/- 0.0132,0.792 +/- 0.0476,0.5707 +/- 0.0338,0.3895 +/- 0.0169


brier_score


generator,real_real,syn_syn,real_syn,syn_real
adsgan,0.1876 +/- 0.0077,0.1882 +/- 0.0089,0.2638 +/- 0.0184,0.2347 +/- 0.01
pategan,0.1876 +/- 0.0077,0.1045 +/- 0.0065,0.2692 +/- 0.0225,0.2915 +/- 0.0159
tvae,0.1876 +/- 0.0077,0.1511 +/- 0.0105,0.1701 +/- 0.0099,0.2127 +/- 0.006
bayesian_network,0.1876 +/- 0.0077,1 +/- 0,1 +/- 0,0.1865 +/- 0.0072
privbayes,0.1876 +/- 0.0077,0.1833 +/- 0.0041,0.2239 +/- 0.01,0.2214 +/- 0.0101
ctgan,0.1876 +/- 0.0077,0.2008 +/- 0.0069,0.2507 +/- 0.0204,0.2199 +/- 0.0096
marginal_distributions,0.1876 +/- 0.0077,1 +/- 0,1 +/- 0,0.2611 +/- 0.0039
copulagan,0.1876 +/- 0.0077,1 +/- 0,1 +/- 0,0.2293 +/- 0.007
gaussian_copula,0.1876 +/- 0.0077,0.1951 +/- 0.0079,0.2081 +/- 0.0119,0.1955 +/- 0.0098
rtvae,0.1876 +/- 0.0077,0.1493 +/- 0.0074,0.2201 +/- 0.0109,0.2794 +/- 0.0114


## gbsg

In [114]:
import pandas as pd

df, X_real, T_real, Y_real, time_horizons = get_dataset("gbsg")

df

Unnamed: 0,x0,x1,x2,x3,x4,x5,x6,duration,event
0,0.0,0.0,0.0,32.0,1.0,155.0,168.0,84.000000,0
1,0.0,1.0,0.0,27.0,1.0,717.0,95.0,84.000000,0
2,0.0,1.0,1.0,52.0,1.0,120.0,437.0,84.000000,0
3,0.0,0.0,0.0,28.0,1.0,251.0,11.0,84.000000,0
4,0.0,0.0,0.0,39.0,1.0,241.0,92.0,66.234085,1
...,...,...,...,...,...,...,...,...,...
2227,0.0,1.0,0.0,49.0,3.0,1.0,84.0,23.687885,0
2228,1.0,1.0,1.0,53.0,17.0,0.0,0.0,6.110883,0
2229,0.0,1.0,0.0,51.0,5.0,43.0,0.0,25.264887,1
2230,0.0,1.0,1.0,52.0,3.0,15.0,34.0,23.885010,1


In [8]:
model_metrics = evaluate_dataset("gbsg")

Dataset 'gbsg' not locally available. Downloading...
Done


In [9]:
plot_metrics(model_metrics)

c_index


generator,real_real,syn_syn,real_syn,syn_real
adsgan,0.6776 +/- 0.0169,0.7019 +/- 0.0142,0.6709 +/- 0.0139,0.6266 +/- 0.0135
pategan,0.6776 +/- 0.0169,0.6531 +/- 0.0079,0.5628 +/- 0.0055,0.5129 +/- 0.0075
tvae,0.6776 +/- 0.0169,0.856 +/- 0.0029,0.8071 +/- 0.0078,0.6756 +/- 0.0185
bayesian_network,0.6776 +/- 0.0169,0.6838 +/- 0.024,0.6908 +/- 0.0192,0.6711 +/- 0.0169
privbayes,0.6776 +/- 0.0169,0.4768 +/- 0.0132,0.497 +/- 0.0071,0.5988 +/- 0.018
ctgan,0.6776 +/- 0.0169,0.6057 +/- 0.0262,0.5595 +/- 0.014,0.6014 +/- 0.0144
marginal_distributions,0.6776 +/- 0.0169,0.4849 +/- 0.022,0.522 +/- 0.0409,0.4314 +/- 0.0153
copulagan,0.6776 +/- 0.0169,0.5201 +/- 0.0105,0.5194 +/- 0.0163,0.6302 +/- 0.0122
gaussian_copula,0.6776 +/- 0.0169,0.6307 +/- 0.0208,0.631 +/- 0.0266,0.6787 +/- 0.0196
rtvae,0.6776 +/- 0.0169,0 +/-0,0 +/-0,0.4705 +/- 0.0148


aucroc


generator,real_real,syn_syn,real_syn,syn_real
adsgan,0.7292 +/- 0.0071,0.8138 +/- 0.0039,0.7396 +/- 0.0157,0.6778 +/- 0.0077
pategan,0.7292 +/- 0.0071,0.7728 +/- 0.0093,0.5267 +/- 0.0044,0.534 +/- 0.0171
tvae,0.7292 +/- 0.0071,0.9576 +/- 0.0027,0.8986 +/- 0.0067,0.7353 +/- 0.0097
bayesian_network,0.7292 +/- 0.0071,0.7465 +/- 0.0102,0.758 +/- 0.009,0.7312 +/- 0.0027
privbayes,0.7292 +/- 0.0071,0.4782 +/- 0.0148,0.5058 +/- 0.0211,0.6417 +/- 0.0103
ctgan,0.7292 +/- 0.0071,0.6979 +/- 0.0088,0.613 +/- 0.0279,0.6393 +/- 0.0127
marginal_distributions,0.7292 +/- 0.0071,0.4988 +/- 0.0038,0.5178 +/- 0.0116,0.3999 +/- 0.0116
copulagan,0.7292 +/- 0.0071,0.5275 +/- 0.0125,0.5421 +/- 0.0099,0.6875 +/- 0.0094
gaussian_copula,0.7292 +/- 0.0071,0.7029 +/- 0.0116,0.6995 +/- 0.0067,0.7397 +/- 0.0041
rtvae,0.7292 +/- 0.0071,0 +/- 0,0 +/- 0,0.4459 +/- 0.0143


brier_score


generator,real_real,syn_syn,real_syn,syn_real
adsgan,0.2048 +/- 0.0022,0.1931 +/- 0.0084,0.2123 +/- 0.0054,0.2242 +/- 0.0021
pategan,0.2048 +/- 0.0022,0.1836 +/- 0.0038,0.2265 +/- 0.0065,0.301 +/- 0.002
tvae,0.2048 +/- 0.0022,0.1059 +/- 0.0036,0.1922 +/- 0.0034,0.2283 +/- 0.0112
bayesian_network,0.2048 +/- 0.0022,0.203 +/- 0.0132,0.2026 +/- 0.0106,0.2063 +/- 0.0019
privbayes,0.2048 +/- 0.0022,0.2057 +/- 0.005,0.2334 +/- 0.0041,0.2348 +/- 0.0037
ctgan,0.2048 +/- 0.0022,0.2277 +/- 0.0032,0.2432 +/- 0.0103,0.2217 +/- 0.0014
marginal_distributions,0.2048 +/- 0.0022,0.1902 +/- 0.0034,0.2622 +/- 0.0172,0.2574 +/- 0.0015
copulagan,0.2048 +/- 0.0022,0.2232 +/- 0.0068,0.238 +/- 0.0102,0.2264 +/- 0.001
gaussian_copula,0.2048 +/- 0.0022,0.2124 +/- 0.0049,0.218 +/- 0.0064,0.2078 +/- 0.0046
rtvae,0.2048 +/- 0.0022,1 +/- 0,1 +/- 0,0.2767 +/- 0.0033


## Support

In [115]:
import pandas as pd

df, X_real, T_real, Y_real, time_horizons = get_dataset("support")

df

Unnamed: 0,x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,duration,event
0,82.709961,1.0,2.0,1.0,0.0,0.0,0.0,160.0,55.0,16.0,38.195309,142.0,19.000000,1.099854,30.0,1
1,79.660950,1.0,0.0,1.0,0.0,0.0,1.0,54.0,67.0,16.0,38.000000,142.0,10.000000,0.899902,1527.0,0
2,23.399990,1.0,2.0,3.0,0.0,0.0,1.0,87.0,144.0,45.0,37.296879,130.0,5.199219,1.199951,96.0,1
3,53.075989,1.0,4.0,3.0,0.0,0.0,0.0,55.0,100.0,18.0,36.000000,135.0,8.699219,0.799927,892.0,0
4,71.794983,0.0,1.0,1.0,0.0,0.0,0.0,65.0,135.0,40.0,38.593750,146.0,0.099991,0.399963,7.0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8868,81.064941,0.0,4.0,1.0,0.0,0.0,1.0,111.0,110.0,34.0,39.593750,135.0,13.000000,1.500000,36.0,1
8869,72.560966,0.0,2.0,1.0,0.0,0.0,1.0,53.0,74.0,28.0,34.695309,139.0,7.899414,1.899902,49.0,1
8870,63.228001,0.0,1.0,1.0,0.0,0.0,2.0,95.0,110.0,22.0,38.695309,132.0,7.799805,1.500000,6.0,1
8871,75.405937,0.0,2.0,1.0,1.0,0.0,2.0,109.0,110.0,30.0,36.195309,140.0,15.398438,0.899902,10.0,1


In [11]:
# Baseline evaluation
from synthcity.plugins import Plugins
from adjutorium.utils.tester import evaluate_survival_estimator
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation
import optuna

optuna.logging.set_verbosity(optuna.logging.DEBUG)
optuna.logging.enable_propagation()
optuna.logging.enable_default_handler()

predictor = RiskEstimation().get_type("deephit")
n_folds = 3

const_cols = constant_columns(X_real)
X_real = X_real.drop(columns=const_cols)


def objective(trial):
    params = _trial_params(trial, predictor.hyperparameter_space())

    try:
        score = evaluate_survival_estimator(
            predictor(**params),
            X_real,
            T_real,
            Y_real,
            time_horizons=time_horizons,
            n_folds=n_folds,
        )["clf"]["c_index"][0]
    except BaseException as e:
        print("   >>>>>>>> trial failed", e)
        score = 0

    return score


study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=30)

[32m[I 2022-04-22 15:46:17,554][0m A new study created in memory with name: no-name-dfdb1edd-5658-4015-9ff4-9baecf02b103[0m
[32m[I 2022-04-22 15:48:33,589][0m Trial 0 finished with value: 0.5502709557988722 and parameters: {'batch_size': 500, 'lr': 0.0001, 'dim_hidden': 50, 'alpha': 0.06042287324857454, 'sigma': 0.36422384946059305, 'dropout': 0.11645823812927213, 'patience': 20}. Best is trial 0 with value: 0.5502709557988722.[0m
[32m[I 2022-04-22 15:50:39,668][0m Trial 1 finished with value: 0.5085104566257822 and parameters: {'batch_size': 500, 'lr': 0.001, 'dim_hidden': 30, 'alpha': 0.013266474341216694, 'sigma': 0.0315171123791268, 'dropout': 0.006484416044899222, 'patience': 37}. Best is trial 0 with value: 0.5502709557988722.[0m
[32m[I 2022-04-22 15:51:29,555][0m Trial 2 finished with value: 0.5548003190485068 and parameters: {'batch_size': 500, 'lr': 0.0001, 'dim_hidden': 100, 'alpha': 0.2262973576839481, 'sigma': 0.3299211212604994, 'dropout': 0.15590544126228495, '

   >>>>>>>> trial failed 


[32m[I 2022-04-22 15:58:47,857][0m Trial 6 finished with value: 0.0 and parameters: {'batch_size': 100, 'lr': 0.01, 'dim_hidden': 10, 'alpha': 0.11606913400475982, 'sigma': 0.460656612158441, 'dropout': 0.02848921131048996, 'patience': 23}. Best is trial 3 with value: 0.5863910854138442.[0m


   >>>>>>>> trial failed 


[32m[I 2022-04-22 15:58:48,850][0m Trial 7 finished with value: 0.0 and parameters: {'batch_size': 100, 'lr': 0.01, 'dim_hidden': 30, 'alpha': 0.035132233761227816, 'sigma': 0.3098157406257913, 'dropout': 0.16515114487039093, 'patience': 25}. Best is trial 3 with value: 0.5863910854138442.[0m


   >>>>>>>> trial failed 


[32m[I 2022-04-22 15:58:49,939][0m Trial 8 finished with value: 0.0 and parameters: {'batch_size': 500, 'lr': 0.01, 'dim_hidden': 40, 'alpha': 0.15647647518529512, 'sigma': 0.30464823944809266, 'dropout': 0.1808975040893317, 'patience': 12}. Best is trial 3 with value: 0.5863910854138442.[0m


   >>>>>>>> trial failed 


[32m[I 2022-04-22 15:58:50,742][0m Trial 9 finished with value: 0.0 and parameters: {'batch_size': 200, 'lr': 0.0001, 'dim_hidden': 50, 'alpha': 0.028788624579713595, 'sigma': 0.17802989652633067, 'dropout': 0.08699026051503163, 'patience': 47}. Best is trial 3 with value: 0.5863910854138442.[0m
[32m[I 2022-04-22 15:58:50,941][0m Trial 10 finished with value: 0.0 and parameters: {'batch_size': 100, 'lr': 0.001, 'dim_hidden': 90, 'alpha': 0.4491850300328254, 'sigma': 0.2052650937786391, 'dropout': 0.11894459769556441, 'patience': 46}. Best is trial 3 with value: 0.5863910854138442.[0m


   >>>>>>>> trial failed 
   >>>>>>>> trial failed 


[32m[I 2022-04-22 15:58:51,133][0m Trial 11 finished with value: 0.0 and parameters: {'batch_size': 100, 'lr': 0.0001, 'dim_hidden': 100, 'alpha': 0.3361121555324741, 'sigma': 0.2528440388818439, 'dropout': 0.14129458062000472, 'patience': 39}. Best is trial 3 with value: 0.5863910854138442.[0m
[32m[I 2022-04-22 15:58:51,335][0m Trial 12 finished with value: 0.0 and parameters: {'batch_size': 500, 'lr': 0.001, 'dim_hidden': 20, 'alpha': 0.27490296173413126, 'sigma': 0.46740254189861036, 'dropout': 0.14669667588995566, 'patience': 16}. Best is trial 3 with value: 0.5863910854138442.[0m


   >>>>>>>> trial failed 
   >>>>>>>> trial failed 


[32m[I 2022-04-22 15:58:51,533][0m Trial 13 finished with value: 0.0 and parameters: {'batch_size': 100, 'lr': 0.0001, 'dim_hidden': 100, 'alpha': 0.22707099183619733, 'sigma': 0.38099842993213723, 'dropout': 0.08946517704614483, 'patience': 50}. Best is trial 3 with value: 0.5863910854138442.[0m
[32m[I 2022-04-22 15:58:51,715][0m Trial 14 finished with value: 0.0 and parameters: {'batch_size': 200, 'lr': 0.001, 'dim_hidden': 60, 'alpha': 0.343850788740538, 'sigma': 0.2704074985735507, 'dropout': 0.18984182916201142, 'patience': 30}. Best is trial 3 with value: 0.5863910854138442.[0m


   >>>>>>>> trial failed 
   >>>>>>>> trial failed 


[32m[I 2022-04-22 15:58:51,905][0m Trial 15 finished with value: 0.0 and parameters: {'batch_size': 100, 'lr': 0.0001, 'dim_hidden': 80, 'alpha': 0.1776895627279264, 'sigma': 0.1584867027233737, 'dropout': 0.12591802223715876, 'patience': 48}. Best is trial 3 with value: 0.5863910854138442.[0m
[32m[I 2022-04-22 15:58:52,095][0m Trial 16 finished with value: 0.0 and parameters: {'batch_size': 500, 'lr': 0.001, 'dim_hidden': 50, 'alpha': 0.3115482087704986, 'sigma': 0.33520125878267426, 'dropout': 0.06189547771282751, 'patience': 15}. Best is trial 3 with value: 0.5863910854138442.[0m


   >>>>>>>> trial failed 
   >>>>>>>> trial failed 


[32m[I 2022-04-22 15:58:52,295][0m Trial 17 finished with value: 0.0 and parameters: {'batch_size': 500, 'lr': 0.0001, 'dim_hidden': 100, 'alpha': 0.38969049975656567, 'sigma': 0.41549368498073574, 'dropout': 0.15944038025737362, 'patience': 11}. Best is trial 3 with value: 0.5863910854138442.[0m
[32m[I 2022-04-22 15:58:52,488][0m Trial 18 finished with value: 0.0 and parameters: {'batch_size': 100, 'lr': 0.0001, 'dim_hidden': 70, 'alpha': 0.4995502070468597, 'sigma': 0.20975217046103184, 'dropout': 0.09450775378779251, 'patience': 42}. Best is trial 3 with value: 0.5863910854138442.[0m


   >>>>>>>> trial failed 
   >>>>>>>> trial failed 


[32m[I 2022-04-22 15:58:52,677][0m Trial 19 finished with value: 0.0 and parameters: {'batch_size': 200, 'lr': 0.001, 'dim_hidden': 90, 'alpha': 0.1947640525837956, 'sigma': 0.09330063933532487, 'dropout': 0.03680142014844953, 'patience': 33}. Best is trial 3 with value: 0.5863910854138442.[0m
[32m[I 2022-04-22 15:58:52,879][0m Trial 20 finished with value: 0.0 and parameters: {'batch_size': 500, 'lr': 0.0001, 'dim_hidden': 60, 'alpha': 0.2665128996875936, 'sigma': 0.2890194306158457, 'dropout': 0.1084509470839192, 'patience': 49}. Best is trial 3 with value: 0.5863910854138442.[0m


   >>>>>>>> trial failed 
   >>>>>>>> trial failed 


[32m[I 2022-04-22 15:58:53,088][0m Trial 21 finished with value: 0.0 and parameters: {'batch_size': 500, 'lr': 0.0001, 'dim_hidden': 50, 'alpha': 0.08108644177396808, 'sigma': 0.32606455819124636, 'dropout': 0.12847508990359496, 'patience': 45}. Best is trial 3 with value: 0.5863910854138442.[0m
[32m[I 2022-04-22 15:58:53,289][0m Trial 22 finished with value: 0.0 and parameters: {'batch_size': 500, 'lr': 0.0001, 'dim_hidden': 50, 'alpha': 0.08021881647985404, 'sigma': 0.36091978482394604, 'dropout': 0.07581306476287604, 'patience': 34}. Best is trial 3 with value: 0.5863910854138442.[0m


   >>>>>>>> trial failed 
   >>>>>>>> trial failed 


[32m[I 2022-04-22 15:58:53,500][0m Trial 23 finished with value: 0.0 and parameters: {'batch_size': 500, 'lr': 0.0001, 'dim_hidden': 50, 'alpha': 0.13771523558521848, 'sigma': 0.4263180684271063, 'dropout': 0.11047529620723545, 'patience': 20}. Best is trial 3 with value: 0.5863910854138442.[0m


   >>>>>>>> trial failed 
   >>>>>>>> trial failed 


[32m[I 2022-04-22 15:58:53,701][0m Trial 24 finished with value: 0.0 and parameters: {'batch_size': 500, 'lr': 0.0001, 'dim_hidden': 20, 'alpha': 0.21454931955649598, 'sigma': 0.3486662181699749, 'dropout': 0.14251976052357104, 'patience': 29}. Best is trial 3 with value: 0.5863910854138442.[0m
[32m[I 2022-04-22 15:58:53,892][0m Trial 25 finished with value: 0.0 and parameters: {'batch_size': 100, 'lr': 0.001, 'dim_hidden': 40, 'alpha': 0.07951500257646571, 'sigma': 0.23498705500786515, 'dropout': 0.17082591116835763, 'patience': 28}. Best is trial 3 with value: 0.5863910854138442.[0m


   >>>>>>>> trial failed 


[32m[I 2022-04-22 15:58:54,171][0m Trial 26 finished with value: 0.0 and parameters: {'batch_size': 500, 'lr': 0.0001, 'dim_hidden': 80, 'alpha': 0.29117520360188365, 'sigma': 0.41141945527691903, 'dropout': 0.10308119502753599, 'patience': 14}. Best is trial 3 with value: 0.5863910854138442.[0m


   >>>>>>>> trial failed 


[32m[I 2022-04-22 15:58:54,766][0m Trial 27 finished with value: 0.0 and parameters: {'batch_size': 100, 'lr': 0.001, 'dim_hidden': 10, 'alpha': 0.3809323086245342, 'sigma': 0.2752494143306842, 'dropout': 0.19680685904843945, 'patience': 10}. Best is trial 3 with value: 0.5863910854138442.[0m


   >>>>>>>> trial failed 


[32m[I 2022-04-22 15:58:55,033][0m Trial 28 finished with value: 0.0 and parameters: {'batch_size': 200, 'lr': 0.01, 'dim_hidden': 100, 'alpha': 0.1802282327772528, 'sigma': 0.4868832210315739, 'dropout': 0.14976972421187615, 'patience': 22}. Best is trial 3 with value: 0.5863910854138442.[0m


   >>>>>>>> trial failed 


[32m[I 2022-04-22 15:58:55,296][0m Trial 29 finished with value: 0.0 and parameters: {'batch_size': 500, 'lr': 0.0001, 'dim_hidden': 50, 'alpha': 0.2502638045686847, 'sigma': 0.3626733978912186, 'dropout': 0.13469199664674325, 'patience': 17}. Best is trial 3 with value: 0.5863910854138442.[0m


   >>>>>>>> trial failed 


In [10]:
model_metrics = evaluate_dataset("support")

In [11]:
plot_metrics(model_metrics)

c_index


generator,real_real,syn_syn,real_syn,syn_real
adsgan,0.5706 +/- 0.0054,0.6813 +/- 0.0069,0.5457 +/- 0.0055,0.5209 +/- 0.0069
pategan,0.5706 +/- 0.0054,0 +/-0,0 +/-0,0.5004 +/- 0.0039
tvae,0.5706 +/- 0.0054,0.7028 +/- 0.0034,0.6946 +/- 0.0036,0.5664 +/- 0.0025
bayesian_network,0.5706 +/- 0.0054,0.5642 +/- 0.0068,0.5656 +/- 0.0086,0.5707 +/- 0.0046
privbayes,0.5706 +/- 0.0054,0.5028 +/- 0.0025,0.4985 +/- 0.0024,0.483 +/- 0.0044
ctgan,0.5706 +/- 0.0054,0.6706 +/- 0.0035,0.6604 +/- 0.0038,0.5549 +/- 0.0051
marginal_distributions,0.5706 +/- 0.0054,0.5068 +/- 0.0065,0.5041 +/- 0.0102,0.5034 +/- 0.0041
copulagan,0.5706 +/- 0.0054,0.6137 +/- 0.0032,0.4954 +/- 0.0156,0.5104 +/- 0.0041
gaussian_copula,0.5706 +/- 0.0054,0.5691 +/- 0.0066,0.5662 +/- 0.0067,0.5688 +/- 0.007
rtvae,0.5706 +/- 0.0054,0 +/-0,0 +/-0,0.4993 +/- 0.0066


aucroc


generator,real_real,syn_syn,real_syn,syn_real
adsgan,0.6379 +/- 0.0056,0.823 +/- 0.0013,0.6578 +/- 0.0025,0.5522 +/- 0.0068
pategan,0.6379 +/- 0.0056,0 +/- 0,0 +/- 0,0.5402 +/- 0.0074
tvae,0.6379 +/- 0.0056,0.8354 +/- 0.003,0.8159 +/- 0.0017,0.6451 +/- 0.0106
bayesian_network,0.6379 +/- 0.0056,0.6459 +/- 0.0062,0.6625 +/- 0.0049,0.6541 +/- 0.0097
privbayes,0.6379 +/- 0.0056,0.4928 +/- 0.0091,0.494 +/- 0.0226,0.4522 +/- 0.0057
ctgan,0.6379 +/- 0.0056,0.8536 +/- 0.0075,0.8318 +/- 0.0048,0.6618 +/- 0.0071
marginal_distributions,0.6379 +/- 0.0056,0.4828 +/- 0.0063,0.5006 +/- 0.006,0.5804 +/- 0.0075
copulagan,0.6379 +/- 0.0056,0.7389 +/- 0.0057,0.5095 +/- 0.0128,0.5259 +/- 0.0113
gaussian_copula,0.6379 +/- 0.0056,0.6291 +/- 0.0074,0.6238 +/- 0.0091,0.6542 +/- 0.0076
rtvae,0.6379 +/- 0.0056,0 +/- 0,0 +/- 0,0.5415 +/- 0.005


brier_score


generator,real_real,syn_syn,real_syn,syn_real
adsgan,0.2084 +/- 0.004,0.1911 +/- 0.0078,0.2425 +/- 0.0053,0.2712 +/- 0.005
pategan,0.2084 +/- 0.004,1 +/- 0,1 +/- 0,0.3047 +/- 0.0052
tvae,0.2084 +/- 0.004,0.1457 +/- 0.0037,0.1771 +/- 0.0032,0.2484 +/- 0.0027
bayesian_network,0.2084 +/- 0.004,0.2094 +/- 0.0013,0.209 +/- 0.0015,0.2091 +/- 0.0043
privbayes,0.2084 +/- 0.004,0.2227 +/- 0.0011,0.3029 +/- 0.0021,0.2808 +/- 0.0025
ctgan,0.2084 +/- 0.004,0.1676 +/- 0.0061,0.1782 +/- 0.0068,0.2167 +/- 0.0042
marginal_distributions,0.2084 +/- 0.004,0.1909 +/- 0.0014,0.5345 +/- 0.0146,0.3501 +/- 0.0014
copulagan,0.2084 +/- 0.004,0.1987 +/- 0.0069,0.2284 +/- 0.0078,0.2411 +/- 0.0063
gaussian_copula,0.2084 +/- 0.004,0.2157 +/- 0.0037,0.2183 +/- 0.0042,0.2103 +/- 0.0045
rtvae,0.2084 +/- 0.004,1 +/- 0,1 +/- 0,0.2775 +/- 0.0063
