In [1]:
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import synthcity.logger as log

from datasets import get_dataset

# Warning configuration: register the FutureWarning-specific filter first and
# the blanket filter second, so every Python warning is suppressed from here on.
# The blanket filter already covers FutureWarning; both are kept so the
# resulting filter list is the same as before.
warnings.filterwarnings(action="ignore", category=FutureWarning)
warnings.filterwarnings(action="ignore")

# Logger configuration: reset the synthcity logger (presumably a loguru-style
# API where remove() drops the registered sinks - confirm), then attach stderr
# as a sink at DEBUG level.
log.remove()
log.add(level="DEBUG", sink=sys.stderr)

In [2]:
from pathlib import Path

import tabulate
from adjutorium.utils.metrics import generate_score, print_score
from synthcity.metrics.eval_statistical import JensenShannonDistance
from synthcity.plugins.core.dataloader import SurvivalAnalysisDataLoader
from synthcity.utils.serialization import (dataframe_hash, load_from_file,
                                           save_to_file)

# Marginal Jensen-Shannon distance between each real survival dataset and the
# synthetic data stored in `out_dir` for every (generator, seed) pair.
# The cell's displayed value is an HTML table with one row per dataset and
# one column per generator.
log.remove()

datasets = [
    "aids",
    "cutract",
    "maggic",
    "seer",
]
methods = ["survival_gan", "ctgan", "nflow", "tvae", "privbayes", "adsgan"]
n_seeds = 3  # backups for seeds 0 .. n_seeds - 1 are read per generator

out_dir = Path("workspace")
headers = ["dataset"] + methods
distances = []  # one row per dataset: [dataset name, score per method ...]
for ref_df in datasets:

    print("=======================")
    print("Evaluate ", ref_df)

    df, duration_col, event_col, time_horizons = get_dataset(ref_df)
    # The hash of the real dataframe is the prefix of the backup file names.
    df_hash = dataframe_hash(df)

    # The time-to-event column is the duration column; the event indicator is
    # the target. (Previously the event column was passed for both.)
    real_dataloader = SurvivalAnalysisDataLoader(
        df, time_to_event_column=duration_col, target_column=event_col
    )
    local_distance = [ref_df]
    for method in methods:
        scores = []
        for seed in range(n_seeds):
            model_bkp = out_dir / f"{df_hash}_{method}_{seed}.bkp"
            syn_df = load_from_file(model_bkp)
            # Presumably the backup holds either a plain DataFrame or an
            # object exposing .dataframe(); unwrap the latter. Only a missing
            # attribute is tolerated - anything else (including
            # KeyboardInterrupt) propagates instead of being swallowed.
            try:
                syn_df = syn_df.dataframe()
            except AttributeError:
                pass

            syn_dataloader = SurvivalAnalysisDataLoader(
                syn_df, time_to_event_column=duration_col, target_column=event_col
            )
            score = JensenShannonDistance().evaluate(real_dataloader, syn_dataloader)[
                "marginal"
            ]
            # A NaN would silently poison the aggregated score, so fail loudly.
            assert not np.isnan(score), score
            scores.append(score)
        # Aggregate the per-seed scores into a single formatted table entry.
        final_score = print_score(generate_score(scores))

        local_distance.append(final_score)
    distances.append(local_distance)

tabulate.tabulate(distances, headers=headers, tablefmt="html")

Evaluate  aids
Evaluate  cutract
Evaluate  maggic
Evaluate  seer


dataset,survival_gan,ctgan,nflow,tvae,privbayes,adsgan
aids,0.012 +/- 0.002,0.031 +/- 0.001,0.048 +/- 0.003,0.054 +/- 0.006,0.028 +/- 0.004,0.052 +/- 0.007
cutract,0.024 +/- 0.01,0.011 +/- 0.001,0.038 +/- 0.001,0.04 +/- 0.003,0.02 +/- 0.0,0.054 +/- 0.013
maggic,0.013 +/- 0.002,0.015 +/- 0.002,0.044 +/- 0.001,0.017 +/- 0.001,0.008 +/- 0.0,0.039 +/- 0.012
seer,0.022 +/- 0.001,0.008 +/- 0.001,0.03 +/- 0.004,0.034 +/- 0.002,0.022 +/- 0.0,0.036 +/- 0.0


In [6]:
## metabric: marginal Jensen-Shannon distance, read from the "workspace_rebuttal" backups

from pathlib import Path

import tabulate
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation
from adjutorium.utils.metrics import generate_score, print_score
from synthcity.metrics.eval_statistical import JensenShannonDistance
from synthcity.plugins.core.dataloader import SurvivalAnalysisDataLoader
from synthcity.utils.serialization import (dataframe_hash, load_from_file,
                                           save_to_file)

# Marginal Jensen-Shannon distance between the real metabric dataset and the
# synthetic data stored in `out_dir` for every (generator, seed) pair.
# The cell's displayed value is an HTML table with one row per dataset and
# one column per generator.
log.remove()

out_dir = Path("workspace_rebuttal")
methods = ["survival_gan", "ctgan", "nflow", "tvae", "privbayes", "adsgan"]
n_seeds = 3  # backups for seeds 0 .. n_seeds - 1 are read per generator

headers = ["dataset"] + methods
distances = []  # one row per dataset: [dataset name, score per method ...]

for ref_df in ["metabric"]:

    print("=======================")
    print("Evaluate ", ref_df)

    df, duration_col, event_col, time_horizons = get_dataset(ref_df)
    # The hash of the real dataframe is the prefix of the backup file names.
    df_hash = dataframe_hash(df)

    # The time-to-event column is the duration column; the event indicator is
    # the target. (Previously the event column was passed for both.)
    real_dataloader = SurvivalAnalysisDataLoader(
        df, time_to_event_column=duration_col, target_column=event_col
    )
    local_distance = [ref_df]
    for method in methods:
        scores = []
        for seed in range(n_seeds):
            # NOTE(review): the method name appears twice and is followed by a
            # double underscore, unlike the "workspace" backups - presumably
            # this mirrors how the rebuttal runs were saved; confirm.
            model_bkp = out_dir / f"{df_hash}_{method}_{method}__{seed}.bkp"
            syn_df = load_from_file(model_bkp)
            # Presumably the backup holds either a plain DataFrame or an
            # object exposing .dataframe(); unwrap the latter. Only a missing
            # attribute is tolerated - anything else (including
            # KeyboardInterrupt) propagates instead of being swallowed.
            try:
                syn_df = syn_df.dataframe()
            except AttributeError:
                pass

            syn_dataloader = SurvivalAnalysisDataLoader(
                syn_df, time_to_event_column=duration_col, target_column=event_col
            )
            score = JensenShannonDistance().evaluate(real_dataloader, syn_dataloader)[
                "marginal"
            ]
            # A NaN would silently poison the aggregated score, so fail loudly.
            assert not np.isnan(score), score
            scores.append(score)
        # Aggregate the per-seed scores into a single formatted table entry.
        final_score = print_score(generate_score(scores))

        local_distance.append(final_score)
    distances.append(local_distance)

tabulate.tabulate(distances, headers=headers, tablefmt="html")

Evaluate  metabric


dataset,survival_gan,ctgan,nflow,tvae,privbayes,adsgan
metabric,0.008 +/- 0.0,0.015 +/- 0.0,0.007 +/- 0.0,0.008 +/- 0.0,0.043 +/- 0.008,0.041 +/- 0.002
