In [1]:
import sys

import numpy as np

from datasets import get_dataset

for dataset in ["aids", "cutract", "maggic", "seer", "metabric"]:
    df, duration_col, event_col, time_horizons = get_dataset(dataset)

    print(dataset, np.linspace(df[duration_col].min(), df[duration_col].max(), 7)[1:-1])

aids [ 61.5 122.  182.5 243.  303.5]
cutract [1051. 2082. 3113. 4144. 5175.]
maggic [1404.66666667 2809.33333333 4214.         5618.66666667 7023.33333333]
seer [ 775. 1550. 2325. 3100. 3875.]
metabric [ 56.25555555 112.4111111  168.56666665 224.7222222  280.87777775]


In [41]:
import sys
from pathlib import Path

import synthcity.logger as log
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation
from adjutorium.utils.metrics import generate_score, print_score
from adjutorium.utils.tester import evaluate_survival_estimator
from synthcity.benchmark import Benchmarks
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import SurvivalAnalysisDataLoader
from synthcity.utils.serialization import (dataframe_hash, load_from_file,
                                           save_to_file)

from datasets import get_dataset

log.remove()
log.add(sink=sys.stderr, level="INFO")

workspace_dir = Path("workspace")

plugins = [
    "survival_gan",
    "ctgan",
    "tvae",
]


def evaluate_dataset(dataset: str):
    df, duration_col, event_col, time_horizons = get_dataset(dataset)
    df_hash = dataframe_hash(df)

    X = df.drop(columns=[duration_col, event_col])
    T = df[duration_col]
    E = df[event_col]

    for plugin in plugins:
        print(" >>> ", plugin)

        for horizon in time_horizons:
            cindex = []
            brier = []

            for seed in range(5):
                model_bkp = workspace_dir / f"{df_hash}_{plugin}_{seed}.bkp"

                syn_df = load_from_file(model_bkp)
                try:
                    syn_df = syn_df.dataframe()
                except BaseException:
                    pass

                Xsyn = syn_df.drop(columns=[duration_col, event_col])
                Tsyn = syn_df[duration_col]
                Esyn = syn_df[event_col]

                model = RiskEstimation().get("survival_xgboost")
                try:
                    model.fit(Xsyn, Tsyn, Esyn)

                    score = evaluate_survival_estimator(
                        [model] * 3,
                        X,
                        T,
                        E,
                        time_horizons=[horizon],
                        pretrained=True,
                        metrics=["c_index", "brier_score"],
                    )
                    # print("          ", seed, horizon, score["str"])

                    cindex.append(score["clf"]["c_index"][0])
                    brier.append(score["clf"]["brier_score"][0])

                except BaseException:
                    continue
            print(
                f"          horizon = {horizon}. ciondex = {print_score(generate_score(cindex))} brier = {print_score(generate_score(brier))}"
            )

In [42]:
evaluate_dataset("aids")

 >>>  survival_gan
          horizon = 91.75. ciondex = 0.748 +/- 0.032 brier = 0.042 +/- 0.001
          horizon = 182.5. ciondex = 0.719 +/- 0.034 brier = 0.071 +/- 0.004
          horizon = 273.25. ciondex = 0.705 +/- 0.034 brier = 0.086 +/- 0.004
 >>>  ctgan
          horizon = 91.75. ciondex = 0.506 +/- 0.106 brier = 0.043 +/- 0.001
          horizon = 182.5. ciondex = 0.502 +/- 0.101 brier = 0.074 +/- 0.004
          horizon = 273.25. ciondex = 0.502 +/- 0.102 brier = 0.101 +/- 0.009
 >>>  tvae
None
          horizon = 91.75. ciondex = 0.611 +/- 0.091 brier = 0.043 +/- 0.0
None
          horizon = 182.5. ciondex = 0.59 +/- 0.073 brier = 0.072 +/- 0.001
None
          horizon = 273.25. ciondex = 0.578 +/- 0.063 brier = 0.093 +/- 0.001


In [43]:
evaluate_dataset("cutract")

 >>>  survival_gan
          horizon = 1566.5. ciondex = 0.783 +/- 0.012 brier = 0.044 +/- 0.004
          horizon = 3113.0. ciondex = 0.772 +/- 0.016 brier = 0.154 +/- 0.041
          horizon = 4659.5. ciondex = 0.74 +/- 0.017 brier = 0.208 +/- 0.052
 >>>  ctgan
          horizon = 1566.5. ciondex = 0.822 +/- 0.01 brier = 0.064 +/- 0.007
          horizon = 3113.0. ciondex = 0.809 +/- 0.008 brier = 0.216 +/- 0.038
          horizon = 4659.5. ciondex = 0.782 +/- 0.007 brier = 0.327 +/- 0.048
 >>>  tvae
          horizon = 1566.5. ciondex = 0.719 +/- 0.024 brier = 0.062 +/- 0.018
          horizon = 3113.0. ciondex = 0.706 +/- 0.024 brier = 0.111 +/- 0.02
          horizon = 4659.5. ciondex = 0.68 +/- 0.02 brier = 0.163 +/- 0.01


In [44]:
evaluate_dataset("maggic")

 >>>  survival_gan


FileNotFoundError: [Errno 2] No such file or directory: 'workspace/4879234145147014154_survival_gan_3.bkp'

In [None]:
evaluate_dataset("seer")

In [3]:
import sys
from pathlib import Path

import synthcity.logger as log
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation
from adjutorium.utils.tester import evaluate_survival_estimator
from synthcity.benchmark import Benchmarks
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import SurvivalAnalysisDataLoader
from synthcity.utils.serialization import (dataframe_hash, load_from_file,
                                           save_to_file)

from datasets import get_dataset

log.remove()
log.add(sink=sys.stderr, level="INFO")

workspace_dir = Path("workspace_rebuttal")


def evaluate_dataset(dataset: str, plugin: tuple, repeats: int = 2):
    df, duration_col, event_col, time_horizons = get_dataset(dataset)
    df_hash = dataframe_hash(df)

    X = df.drop(columns=[duration_col, event_col])
    T = df[duration_col]
    E = df[event_col]

    for horizon in time_horizons:
        model_bkp = workspace_dir / f"{df_hash}_{plugin}_{plugin}__0.bkp"

        syn_df = load_from_file(model_bkp)
        try:
            syn_df = syn_df.dataframe()
        except BaseException:
            pass

        Xsyn = syn_df.drop(columns=[duration_col, event_col])
        Tsyn = syn_df[duration_col]
        Esyn = syn_df[event_col]

        model = RiskEstimation().get("survival_xgboost")
        model.fit(Xsyn, Tsyn, Esyn)

        score = evaluate_survival_estimator(
            [model] * 3,
            X,
            T,
            E,
            time_horizons=[horizon],
            pretrained=True,
            metrics=["c_index", "brier_score"],
        )
        print(horizon, score["str"])

In [4]:
evaluate_dataset("metabric", "survival_gan")

84.333333325 {'c_index': '0.737 +/- 0.038', 'brier_score': '0.074 +/- 0.013'}
168.56666665 {'c_index': '0.738 +/- 0.024', 'brier_score': '0.224 +/- 0.025'}
252.79999997500002 {'c_index': '0.665 +/- 0.011', 'brier_score': '0.245 +/- 0.018'}
