In [None]:
import sys

from datasets import get_dataset
from synthcity.plugins import Plugins
from synthcity.benchmark import Benchmarks
from pathlib import Path
from synthcity.utils.serialization import save_to_file, load_from_file
import synthcity.logger as log

# Directory holding the per-plugin benchmark backups (relative to the
# working directory).
out_dir = Path("output")

# Register stderr as a log sink with level "INFO".
log.add(level="INFO", sink=sys.stderr)

# Whatever Plugins().list() reports — presumably the names of all available
# plugins; TODO confirm. Not referenced by the cells visible in this notebook.
plugins = Plugins().list()


def evaluate_dataset(name: str, dataset: str, plugins: list) -> dict:
    """Benchmark each plugin on ``dataset`` and return the combined results.

    Every plugin is evaluated on its own so that its result can be cached in
    ``out_dir`` as ``metrics.{dataset}_{name}_{plugin}.bkp``. A plugin whose
    backup file already exists is loaded from disk instead of re-evaluated.

    Args:
        name: Label for this group of plugins; only used in the backup
            file name.
        dataset: Dataset identifier passed to ``get_dataset``; also used in
            the backup file name.
        plugins: Plugin names to benchmark.

    Returns:
        A dict holding the merged results of all requested plugins (empty
        when ``plugins`` is empty), suitable for ``Benchmarks.print``.

    Note:
        Reads the module-level ``repeats`` and ``out_dir`` at call time, so
        both must be defined before this function is called.
    """
    df, duration_col, event_col, time_horizons = get_dataset(dataset)

    # Accumulate across plugins: returning only the loop variable would
    # silently drop every result except the last plugin's.
    scores: dict = {}

    for plugin in plugins:
        bkp = out_dir / f"metrics.{dataset}_{name}_{plugin}.bkp"

        if bkp.exists():
            score = load_from_file(bkp)
        else:
            # Create the cache directory *before* the expensive benchmark so
            # an unwritable location fails fast rather than after the run.
            bkp.parent.mkdir(parents=True, exist_ok=True)

            score = Benchmarks.evaluate(
                [plugin],
                df,
                task_type="survival_analysis",
                target_column=event_col,
                time_to_event_column=duration_col,
                time_horizons=time_horizons,
                synthetic_size=len(df),
                repeats=repeats,
            )
            save_to_file(bkp, score)

        # NOTE(review): assumes Benchmarks.evaluate returns a dict keyed by
        # plugin name, so merging keeps one entry per plugin — confirm
        # against the installed synthcity version.
        scores.update(score)

    return scores

In [None]:
# Plugin names benchmarked under the "baseline" label.
base_plugins = ["privbayes", "adsgan", "ctgan", "tvae", "nflow"]

# Plugin names benchmarked under the "survival" label.
survival_plugins = ["survival_gan"]

# Forwarded by evaluate_dataset as the `repeats` argument of
# Benchmarks.evaluate; 5 is noted as an alternative setting.
repeats = 3

## AIDS dataset

In [None]:
# Benchmark the baseline plugins on the "aids" dataset and print the metrics.
base_score = evaluate_dataset(
    name="baseline",
    dataset="aids",
    plugins=base_plugins,
)
Benchmarks.print(base_score)

In [None]:
# Benchmark the survival plugins on the "aids" dataset and print the metrics.
survival_score = evaluate_dataset(
    name="survival",
    dataset="aids",
    plugins=survival_plugins,
)
Benchmarks.print(survival_score)

## CUTRACT dataset

In [None]:
# Benchmark the baseline plugins on the "cutract" dataset and print the metrics.
base_score = evaluate_dataset(
    name="baseline",
    dataset="cutract",
    plugins=base_plugins,
)
Benchmarks.print(base_score)

In [None]:
# Benchmark the survival plugins on the "cutract" dataset and print the metrics.
survival_score = evaluate_dataset(
    name="survival",
    dataset="cutract",
    plugins=survival_plugins,
)
Benchmarks.print(survival_score)

## MAGGIC dataset 

In [None]:
# Benchmark the baseline plugins on the "maggic" dataset and print the metrics.
base_score = evaluate_dataset(
    name="baseline",
    dataset="maggic",
    plugins=base_plugins,
)
Benchmarks.print(base_score)

In [None]:
# Benchmark the survival plugins on the "maggic" dataset and print the metrics.
survival_score = evaluate_dataset(
    name="survival",
    dataset="maggic",
    plugins=survival_plugins,
)
Benchmarks.print(survival_score)

## SEER prostate dataset

In [None]:
# Benchmark the baseline plugins on the "seer" dataset and print the metrics.
base_score = evaluate_dataset(
    name="baseline",
    dataset="seer",
    plugins=base_plugins,
)
Benchmarks.print(base_score)

In [None]:
# Benchmark the survival plugins on the "seer" dataset and print the metrics.
survival_score = evaluate_dataset(
    name="survival",
    dataset="seer",
    plugins=survival_plugins,
)
Benchmarks.print(survival_score)