In [1]:
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import synthcity.logger as log

from datasets import get_dataset

# Warning filters: FutureWarning is silenced explicitly, then every other
# category as well (Warning is the default category of filterwarnings).
for ignored_category in (FutureWarning, Warning):
    warnings.filterwarnings("ignore", category=ignored_category)

# Logger: reset the synthcity logger, then attach a single stderr sink at
# DEBUG level.
log.remove()
log.add(sink=sys.stderr, level="DEBUG")

In [2]:
from pathlib import Path

import tabulate
from adjutorium.utils.metrics import generate_score, print_score
from synthcity.metrics.eval_statistical import WassersteinDistance
from synthcity.plugins.core.dataloader import SurvivalAnalysisDataLoader
from synthcity.utils.serialization import (dataframe_hash, load_from_file,
                                           save_to_file)

# Detach the logger sink(s) so the evaluation loop below runs quietly
# (presumably this removes the DEBUG sink registered earlier in the notebook).
log.remove()

# Benchmark datasets and the generators whose synthetic samples are scored.
datasets = [
    "aids",
    "cutract",
    "maggic",
    "seer",
]
methods = ["survival_gan", "ctgan", "nflow", "tvae", "privbayes", "adsgan"]

# Directory with one serialized synthetic sample per (dataset, method, seed).
out_dir = Path("workspace")
headers = ["dataset"] + methods
distances = []
# Iterate the list defined above rather than repeating the literal, so the
# dataset selection lives in exactly one place.
for ref_df in datasets:

    print("=======================")
    print("Evaluate ", ref_df)

    df, duration_col, event_col, time_horizons = get_dataset(ref_df)
    # The hash of the real dataframe is the prefix of the backup file names.
    df_hash = dataframe_hash(df)

    # `duration_col` is passed as the time-to-event column and `event_col` as
    # the target (presumably the event indicator); passing `event_col` for
    # both would leave the duration column unused by the loader.
    real_dataloader = SurvivalAnalysisDataLoader(
        df, time_to_event_column=duration_col, target_column=event_col
    )
    local_distance = [ref_df]
    for method in methods:
        scores = []
        # Three backups (seeds 0..2) are expected per dataset/method pair.
        for seed in range(3):
            model_bkp = out_dir / f"{df_hash}_{method}_{seed}.bkp"
            syn_df = load_from_file(model_bkp)
            # A backup may hold a plain DataFrame or a wrapper exposing
            # `.dataframe()`; unwrap only the latter, and let real failures
            # surface instead of swallowing every exception.
            if not isinstance(syn_df, pd.DataFrame):
                syn_df = syn_df.dataframe()

            syn_dataloader = SurvivalAnalysisDataLoader(
                syn_df, time_to_event_column=duration_col, target_column=event_col
            )
            # Only the "joint" entry of the metric's result dict is reported.
            score = WassersteinDistance().evaluate(real_dataloader, syn_dataloader)[
                "joint"
            ]
            assert not np.isnan(
                score
            ), f"NaN Wasserstein distance for {ref_df}/{method}/seed {seed}"
            scores.append(score)
        # Aggregate the per-seed scores into one printable table cell.
        final_score = print_score(generate_score(scores))

        local_distance.append(final_score)
    distances.append(local_distance)

# Last expression of the cell: the HTML table is the cell's displayed output.
tabulate.tabulate(distances, headers=headers, tablefmt="html")

Evaluate  aids
Evaluate  cutract
Evaluate  maggic
Evaluate  seer


dataset,survival_gan,ctgan,nflow,tvae,privbayes,adsgan
aids,0.153 +/- 0.021,0.851 +/- 0.072,1.701 +/- 0.193,1.967 +/- 0.053,1.08 +/- 0.273,1.694 +/- 0.505
cutract,0.228 +/- 0.111,0.032 +/- 0.006,0.325 +/- 0.05,1.43 +/- 0.199,0.101 +/- 0.002,2.393 +/- 0.206
maggic,0.441 +/- 0.124,0.621 +/- 0.053,2.086 +/- 0.025,0.892 +/- 0.295,0.828 +/- 0.014,2.207 +/- 0.549
seer,0.42 +/- 0.291,0.019 +/- 0.007,0.169 +/- 0.04,1.902 +/- 0.093,0.146 +/- 0.011,2.108 +/- 0.0


In [4]:
## Metabric: same Wasserstein evaluation, reading backups from the rebuttal workspace

from pathlib import Path

import tabulate
from adjutorium.utils.metrics import generate_score, print_score
from synthcity.metrics.eval_statistical import WassersteinDistance
from synthcity.plugins.core.dataloader import SurvivalAnalysisDataLoader
from synthcity.utils.serialization import (dataframe_hash, load_from_file,
                                           save_to_file)

# Detach the logger sink(s) so the evaluation loop below runs quietly.
log.remove()

# Directory with the serialized synthetic samples for this experiment.
out_dir = Path("workspace_rebuttal")
methods = ["survival_gan", "ctgan", "nflow", "tvae", "privbayes", "adsgan"]

headers = ["dataset"] + methods
distances = []

for ref_df in ["metabric"]:

    print("=======================")
    print("Evaluate ", ref_df)

    df, duration_col, event_col, time_horizons = get_dataset(ref_df)
    # The hash of the real dataframe is the prefix of the backup file names.
    df_hash = dataframe_hash(df)

    # `duration_col` is passed as the time-to-event column and `event_col` as
    # the target (presumably the event indicator); passing `event_col` for
    # both would leave the duration column unused by the loader.
    real_dataloader = SurvivalAnalysisDataLoader(
        df, time_to_event_column=duration_col, target_column=event_col
    )
    local_distance = [ref_df]
    for method in methods:
        scores = []
        # Three backups (seeds 0..2) are expected per method.
        for seed in range(3):
            # NOTE(review): the method name appears twice and the seed follows
            # a double underscore — confirm this matches how the backups in
            # `workspace_rebuttal` were written.
            model_bkp = out_dir / f"{df_hash}_{method}_{method}__{seed}.bkp"
            syn_df = load_from_file(model_bkp)
            # A backup may hold a plain DataFrame or a wrapper exposing
            # `.dataframe()`; unwrap only the latter, and let real failures
            # surface instead of swallowing every exception.
            if not isinstance(syn_df, pd.DataFrame):
                syn_df = syn_df.dataframe()

            syn_dataloader = SurvivalAnalysisDataLoader(
                syn_df, time_to_event_column=duration_col, target_column=event_col
            )
            # Only the "joint" entry of the metric's result dict is reported.
            score = WassersteinDistance().evaluate(real_dataloader, syn_dataloader)[
                "joint"
            ]
            assert not np.isnan(
                score
            ), f"NaN Wasserstein distance for {ref_df}/{method}/seed {seed}"
            scores.append(score)
        # Aggregate the per-seed scores into one printable table cell.
        final_score = print_score(generate_score(scores))

        local_distance.append(final_score)
    distances.append(local_distance)

# Last expression of the cell: the HTML table is the cell's displayed output.
tabulate.tabulate(distances, headers=headers, tablefmt="html")

Evaluate  metabric


dataset,survival_gan,ctgan,nflow,tvae,privbayes,adsgan
metabric,9.613 +/- 0.074,16.309 +/- 0.92,14.649 +/- 0.764,9.654 +/- 0.159,15.718 +/- 4.394,17.806 +/- 0.281
