In [1]:
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import synthcity.logger as log

from datasets import get_dataset

# Silence all Python warnings. This already covers FutureWarning, so a separate
# FutureWarning-only filter would be redundant.
warnings.filterwarnings("ignore")

# Route synthcity's logger to stderr at DEBUG level. `log.remove()` presumably
# clears previously registered sinks first (loguru-style API) — confirm.
# NOTE(review): the evaluation cells below call `log.remove()` again, which
# drops this DEBUG sink.
log.remove()
log.add(sink=sys.stderr, level="DEBUG")

In [7]:
from pathlib import Path

import tabulate
from adjutorium.utils.metrics import generate_score, print_score
from synthcity.metrics.eval_performance import (PerformanceEvaluatorLinear,
                                                PerformanceEvaluatorMLP,
                                                PerformanceEvaluatorXGB)
from synthcity.plugins.core.dataloader import SurvivalAnalysisDataLoader
from synthcity.utils.serialization import (dataframe_hash, load_from_file,
                                           save_to_file)

# Silence the synthcity logger for this evaluation run (drops the sink added
# in the setup cell).
log.remove()

# Real datasets whose saved synthetic counterparts are evaluated below.
datasets = [
    "aids",
    "cutract",
    "maggic",
    "seer",
]
# Synthetic-data generators; one backup file per (dataset, method, seed).
methods = ["survival_gan", "ctgan", "nflow", "tvae", "privbayes", "adsgan"]

out_dir = Path("workspace")
headers = ["dataset"] + methods

# Number of saved synthetic datasets (seeds) per method.
N_SEEDS = 3

# Layout:
#   results[dataset][evaluator]["gt"]   -> {"c_index": float, "brier_score": float}
#   results[dataset][evaluator][method] -> {"c_index": [...], "brier_score": [...]}
# with one list entry per seed.
results = {}
for ref_df in datasets:
    results[ref_df] = {}

    print("=======================")
    print("Evaluate ", ref_df)

    df, duration_col, event_col, time_horizons = get_dataset(ref_df)
    df_hash = dataframe_hash(df)

    real_dataloader = SurvivalAnalysisDataLoader(
        df,
        time_to_event_column=duration_col,
        target_column=event_col,
        time_horizons=time_horizons,
    )

    for src_eval in [
        PerformanceEvaluatorMLP,
        PerformanceEvaluatorXGB,
        PerformanceEvaluatorLinear,
    ]:
        eval_name = src_eval.name()
        # "gt" holds the evaluator's real-data baseline. It is overwritten with a
        # scalar on every evaluation below, so it starts as None rather than a list.
        eval_results = results[ref_df][eval_name] = {
            "gt": {"c_index": None, "brier_score": None},
        }

        for method in methods:
            method_results = eval_results[method] = {
                "c_index": [],
                "brier_score": [],
            }

            for seed in range(N_SEEDS):
                model_bkp = out_dir / f"{df_hash}_{method}_{seed}.bkp"
                syn_df = load_from_file(model_bkp)
                # A backup may hold a wrapper exposing `.dataframe()` or a plain
                # DataFrame. Only the missing-attribute case is expected here;
                # anything else (including Ctrl-C) should propagate.
                try:
                    syn_df = syn_df.dataframe()
                except AttributeError:
                    pass

                syn_dataloader = SurvivalAnalysisDataLoader(
                    syn_df,
                    time_to_event_column=duration_col,
                    target_column=event_col,
                    time_horizons=time_horizons,
                )
                score = src_eval(task_type="survival_analysis").evaluate(
                    real_dataloader, syn_dataloader
                )
                print(ref_df, eval_name, method, seed, score)

                eval_results["gt"]["c_index"] = score["gt.c_index"]
                eval_results["gt"]["brier_score"] = score["gt.brier_score"]
                # NOTE(review): some runs report syn_ood c_index 0.0 with
                # brier_score 1.0 (e.g. aids ctgan/tvae/adsgan in the output).
                # These look like failure sentinels rather than real scores;
                # confirm before averaging.
                method_results["c_index"].append(score["syn_ood.c_index"])
                method_results["brier_score"].append(score["syn_ood.brier_score"])

results

Evaluate  aids
aids mlp survival_gan 0 {'gt.c_index': 0.7059304149697022, 'gt.brier_score': 0.06806232093352434, 'syn_id.c_index': 0.7213261019376978, 'syn_id.brier_score': 0.06819882377771379, 'syn_ood.c_index': 0.6454852592877075, 'syn_ood.brier_score': 0.07742694851561303}
aids mlp survival_gan 1 {'gt.c_index': 0.7059304149697022, 'gt.brier_score': 0.06806232093352434, 'syn_id.c_index': 0.6549150597652914, 'syn_id.brier_score': 0.06818851102562119, 'syn_ood.c_index': 0.7165071752721839, 'syn_ood.brier_score': 0.07776424405215514}
aids mlp survival_gan 2 {'gt.c_index': 0.7059304149697022, 'gt.brier_score': 0.06806232093352434, 'syn_id.c_index': 0.7161469010591421, 'syn_id.brier_score': 0.08966694941026791, 'syn_ood.c_index': 0.7644641382612013, 'syn_ood.brier_score': 0.07990441130531196}
aids mlp ctgan 0 {'gt.c_index': 0.7059304149697022, 'gt.brier_score': 0.06806232093352434, 'syn_id.c_index': 0.5876149645646521, 'syn_id.brier_score': 0.06679559797339725, 'syn_ood.c_index': 0.0, 'sy

cutract mlp survival_gan 2 {'gt.c_index': 0.7227856874466392, 'gt.brier_score': 0.09078878664723468, 'syn_id.c_index': 0.7984934437588945, 'syn_id.brier_score': 0.09093075652740061, 'syn_ood.c_index': 0.7781613895561756, 'syn_ood.brier_score': 0.10017100712450633}
cutract mlp ctgan 0 {'gt.c_index': 0.7227856874466392, 'gt.brier_score': 0.09078878664723468, 'syn_id.c_index': 0.7959520880770041, 'syn_id.brier_score': 0.16775249071453444, 'syn_ood.c_index': 0.7553184846865162, 'syn_ood.brier_score': 0.1849173173231573}
cutract mlp ctgan 1 {'gt.c_index': 0.7227856874466392, 'gt.brier_score': 0.09078878664723468, 'syn_id.c_index': 0.7012475926951557, 'syn_id.brier_score': 0.23557166878759894, 'syn_ood.c_index': 0.6777539893172868, 'syn_ood.brier_score': 0.23579553681046647}
cutract mlp ctgan 2 {'gt.c_index': 0.7227856874466392, 'gt.brier_score': 0.09078878664723468, 'syn_id.c_index': 0.7952237452968512, 'syn_id.brier_score': 0.18644685290042226, 'syn_ood.c_index': 0.7390306737034017, 'syn_o

cutract linear_model survival_gan 2 {'gt.c_index': 0.8098309079742018, 'gt.brier_score': 0.07536036182071539, 'syn_id.c_index': 0.8087741715962581, 'syn_id.brier_score': 0.09581397133039615, 'syn_ood.c_index': 0.7687297248123496, 'syn_ood.brier_score': 0.10819228259054055}
cutract linear_model ctgan 0 {'gt.c_index': 0.8098309079742018, 'gt.brier_score': 0.07536036182071539, 'syn_id.c_index': 0.7976123629288813, 'syn_id.brier_score': 0.16112846550197069, 'syn_ood.c_index': 0.7586253153808049, 'syn_ood.brier_score': 0.1686652101277557}
cutract linear_model ctgan 1 {'gt.c_index': 0.8098309079742018, 'gt.brier_score': 0.07536036182071539, 'syn_id.c_index': 0.8073633392249625, 'syn_id.brier_score': 0.2032599226008458, 'syn_ood.c_index': 0.7676091861839117, 'syn_ood.brier_score': 0.21098839351746976}
cutract linear_model ctgan 2 {'gt.c_index': 0.8098309079742018, 'gt.brier_score': 0.07536036182071539, 'syn_id.c_index': 0.8087018009483685, 'syn_id.brier_score': 0.16374077187360953, 'syn_ood.c

maggic xgb survival_gan 0 {'gt.c_index': 0.6654334315503624, 'gt.brier_score': 0.15120604857477493, 'syn_id.c_index': 0.6330676237824038, 'syn_id.brier_score': 0.17433459614622904, 'syn_ood.c_index': 0.6347711754571775, 'syn_ood.brier_score': 0.17898369974451991}
maggic xgb survival_gan 1 {'gt.c_index': 0.6654334315503624, 'gt.brier_score': 0.15120604857477493, 'syn_id.c_index': 0.575788000307857, 'syn_id.brier_score': 0.18412801136219992, 'syn_ood.c_index': 0.5774669041777848, 'syn_ood.brier_score': 0.19151403324599658}
maggic xgb survival_gan 2 {'gt.c_index': 0.6654334315503624, 'gt.brier_score': 0.15120604857477493, 'syn_id.c_index': 0.6262730181469522, 'syn_id.brier_score': 0.1742823519423781, 'syn_ood.c_index': 0.6252707348144699, 'syn_ood.brier_score': 0.180402368625159}
maggic xgb ctgan 0 {'gt.c_index': 0.6654334315503624, 'gt.brier_score': 0.15120604857477493, 'syn_id.c_index': 0.5879598233338613, 'syn_id.brier_score': 0.1921943952931453, 'syn_ood.c_index': 0.5828279669020512, 

seer mlp survival_gan 0 {'gt.c_index': 0.5000004041745689, 'gt.brier_score': 0.02412814886111232, 'syn_id.c_index': 0.5, 'syn_id.brier_score': 0.026039657151882833, 'syn_ood.c_index': 0.5, 'syn_ood.brier_score': 0.025576156578280145}
seer mlp survival_gan 1 {'gt.c_index': 0.5000004041745689, 'gt.brier_score': 0.02412814886111232, 'syn_id.c_index': 0.5, 'syn_id.brier_score': 0.024177845268765107, 'syn_ood.c_index': 0.5, 'syn_ood.brier_score': 0.02413232719705384}
seer mlp survival_gan 2 {'gt.c_index': 0.5000004041745689, 'gt.brier_score': 0.02412814886111232, 'syn_id.c_index': 0.5000252864096146, 'syn_id.brier_score': 0.039814137272375864, 'syn_ood.c_index': 0.5001266850867621, 'syn_ood.brier_score': 0.03877040837008744}
seer mlp ctgan 0 {'gt.c_index': 0.5000004041745689, 'gt.brier_score': 0.02412814886111232, 'syn_id.c_index': 0.8116450109088582, 'syn_id.brier_score': 0.08061988389914133, 'syn_ood.c_index': 0.5, 'syn_ood.brier_score': 0.15493806825413384}
seer mlp ctgan 1 {'gt.c_index'

seer xgb adsgan 2 {'gt.c_index': 0.8521464322079112, 'gt.brier_score': 0.021661690647627116, 'syn_id.c_index': 0.5, 'syn_id.brier_score': 0.024903875094444892, 'syn_ood.c_index': 0.5, 'syn_ood.brier_score': 0.024896275616259202}
seer linear_model survival_gan 0 {'gt.c_index': 0.8478194755263354, 'gt.brier_score': 0.020776657214431535, 'syn_id.c_index': 0.8432312236522831, 'syn_id.brier_score': 0.02157766562628803, 'syn_ood.c_index': 0.8462229709304018, 'syn_ood.brier_score': 0.021506527394729944}
seer linear_model survival_gan 1 {'gt.c_index': 0.8478194755263354, 'gt.brier_score': 0.020776657214431535, 'syn_id.c_index': 0.8288225255745708, 'syn_id.brier_score': 0.02144128327592952, 'syn_ood.c_index': 0.8286079327437789, 'syn_ood.brier_score': 0.02144260306350441}
seer linear_model survival_gan 2 {'gt.c_index': 0.8478194755263354, 'gt.brier_score': 0.020776657214431535, 'syn_id.c_index': 0.8257583258265292, 'syn_id.brier_score': 0.028407781851670696, 'syn_ood.c_index': 0.823190110041310

{'aids': {'mlp': {'gt': {'c_index': 0.7059304149697022,
    'brier_score': 0.06806232093352434},
   'survival_gan': {'c_index': [0.6454852592877075,
     0.7165071752721839,
     0.7644641382612013],
    'brier_score': [0.07742694851561303,
     0.07776424405215514,
     0.07990441130531196]},
   'ctgan': {'c_index': [0.0, 0.520568649905171, 0.5542162720615037],
    'brier_score': [1.0, 0.07613622330571594, 0.10876656241751877]},
   'nflow': {'c_index': [0.3836290426816391,
     0.37855223712306135,
     0.5309118212646274],
    'brier_score': [0.1360939367406394,
     0.16071780684920026,
     0.1273532349081433]},
   'tvae': {'c_index': [0.0, 0.4973162193929452, 0.0],
    'brier_score': [1.0, 0.07745913854287934, 1.0]},
   'privbayes': {'c_index': [0.47586271354782933,
     0.4173548020659998,
     0.2991700915100185],
    'brier_score': [0.07847409831109108,
     0.09323820812133977,
     0.07883306714479758]},
   'adsgan': {'c_index': [0.0, 0.5985419306860772, 0.0],
    'brier_scor

In [12]:
## metabric

from pathlib import Path

import tabulate
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation
from adjutorium.utils.metrics import generate_score, print_score
from synthcity.metrics.eval_statistical import JensenShannonDistance
from synthcity.plugins.core.dataloader import SurvivalAnalysisDataLoader
from synthcity.utils.serialization import (dataframe_hash, load_from_file,
                                           save_to_file)

# Silence the synthcity logger for this evaluation run.
log.remove()

# Synthetic METABRIC data is read from a separate workspace directory.
out_dir = Path("workspace_rebuttal")
methods = ["survival_gan", "ctgan", "nflow", "tvae", "privbayes", "adsgan"]

headers = ["dataset"] + methods

# NOTE(review): this cell extends the `results` dict and reuses the
# PerformanceEvaluator* classes from the main evaluation cell above. Run that
# cell first.
for ref_df in ["metabric"]:
    results[ref_df] = {}

    print("=======================")
    print("Evaluate ", ref_df)

    df, duration_col, event_col, time_horizons = get_dataset(ref_df)
    df_hash = dataframe_hash(df)

    real_dataloader = SurvivalAnalysisDataLoader(
        df,
        time_to_event_column=duration_col,
        target_column=event_col,
        time_horizons=time_horizons,
    )

    for src_eval in [
        PerformanceEvaluatorMLP,
        PerformanceEvaluatorXGB,
        PerformanceEvaluatorLinear,
    ]:
        eval_name = src_eval.name()
        # "gt" is overwritten with scalars on each successful evaluation. It
        # stays None if every evaluation for this evaluator fails.
        eval_results = results[ref_df][eval_name] = {
            "gt": {"c_index": None, "brier_score": None},
        }
        for method in methods:
            method_results = eval_results[method] = {
                "c_index": [],
                "brier_score": [],
            }
            for seed in range(3):
                # NOTE(review): the filename scheme differs from the main workspace
                # (method repeated, double underscore before the seed). Presumably
                # this matches how these backups were saved — confirm.
                model_bkp = out_dir / f"{df_hash}_{method}_{method}__{seed}.bkp"
                syn_df = load_from_file(model_bkp)
                # A backup may hold a wrapper exposing `.dataframe()` or a plain
                # DataFrame. Only the missing-attribute case is expected.
                try:
                    syn_df = syn_df.dataframe()
                except AttributeError:
                    pass

                syn_dataloader = SurvivalAnalysisDataLoader(
                    syn_df,
                    time_to_event_column=duration_col,
                    target_column=event_col,
                    time_horizons=time_horizons,
                )
                # Best-effort: skip a seed whose evaluation fails, but report why.
                # A bare `except:` would hide the cause and also trap Ctrl-C.
                try:
                    score = src_eval(task_type="survival_analysis").evaluate(
                        real_dataloader, syn_dataloader
                    )
                except Exception as exc:
                    print(ref_df, eval_name, method, seed, "evaluation failed:", repr(exc))
                    continue
                print(ref_df, eval_name, method, seed, score)

                eval_results["gt"]["c_index"] = score["gt.c_index"]
                eval_results["gt"]["brier_score"] = score["gt.brier_score"]
                # NOTE(review): unlike the main evaluation cell (which records
                # syn_ood.*), this records syn_id.* scores. The metabric output
                # shows syn_ood c_index 0.0 / brier_score 1.0, which presumably
                # motivated the switch. The two are not directly comparable, so
                # confirm before tabulating them together.
                method_results["c_index"].append(score["syn_id.c_index"])
                method_results["brier_score"].append(score["syn_id.brier_score"])

Evaluate  metabric
metabric mlp survival_gan 0 {'gt.c_index': 0.6341500002733197, 'gt.brier_score': 0.17096242975213496, 'syn_id.c_index': 0.6652024503200676, 'syn_id.brier_score': 0.17794654775701013, 'syn_ood.c_index': 0.0, 'syn_ood.brier_score': 1.0}
metabric mlp survival_gan 1 {'gt.c_index': 0.6341500002733197, 'gt.brier_score': 0.17096242975213496, 'syn_id.c_index': 0.5963303172720473, 'syn_id.brier_score': 0.18972298382369654, 'syn_ood.c_index': 0.0, 'syn_ood.brier_score': 1.0}
metabric mlp survival_gan 2 {'gt.c_index': 0.6341500002733197, 'gt.brier_score': 0.17096242975213496, 'syn_id.c_index': 0.5765371676049208, 'syn_id.brier_score': 0.17498457365085787, 'syn_ood.c_index': 0.0, 'syn_ood.brier_score': 1.0}
metabric mlp ctgan 0 {'gt.c_index': 0.6341500002733197, 'gt.brier_score': 0.17096242975213496, 'syn_id.c_index': 0.5827398561563402, 'syn_id.brier_score': 0.17423682934493914, 'syn_ood.c_index': 0.0, 'syn_ood.brier_score': 1.0}
metabric mlp ctgan 1 {'gt.c_index': 0.6341500002

In [13]:
# Display the accumulated nested scores dict:
# results[dataset][evaluator]["gt" | method][metric].
results

{'aids': {'mlp': {'gt': {'c_index': 0.7059304149697022,
    'brier_score': 0.06806232093352434},
   'survival_gan': {'c_index': [0.6454852592877075,
     0.7165071752721839,
     0.7644641382612013],
    'brier_score': [0.07742694851561303,
     0.07776424405215514,
     0.07990441130531196]},
   'ctgan': {'c_index': [0.0, 0.520568649905171, 0.5542162720615037],
    'brier_score': [1.0, 0.07613622330571594, 0.10876656241751877]},
   'nflow': {'c_index': [0.3836290426816391,
     0.37855223712306135,
     0.5309118212646274],
    'brier_score': [0.1360939367406394,
     0.16071780684920026,
     0.1273532349081433]},
   'tvae': {'c_index': [0.0, 0.4973162193929452, 0.0],
    'brier_score': [1.0, 0.07745913854287934, 1.0]},
   'privbayes': {'c_index': [0.47586271354782933,
     0.4173548020659998,
     0.2991700915100185],
    'brier_score': [0.07847409831109108,
     0.09323820812133977,
     0.07883306714479758]},
   'adsgan': {'c_index': [0.0, 0.5985419306860772, 0.0],
    'brier_scor