In [None]:
import sys

from synthcity.plugins import Plugins
from synthcity.benchmark import Benchmarks
import synthcity.logger as log

from pathlib import Path
from datasets import get_dataset
from synthcity.utils.serialization import save_to_file, load_from_file

# Attach stderr as a log sink for the synthcity logger at level "INFO"
log.add(sink=sys.stderr, level="INFO")

# Result of Plugins().list() (presumably the names of the available plugins - verify)
plugins = Plugins().list()

# Directory where evaluate_dataset looks for and stores its per-scenario .bkp files
out_dir = Path("output")

def _gain_scenario(
    label,
    uncensoring_model,
    tte_strategy,
    sampling_strategy,
    use_conditional=None,
):
    """Build one ``(label, plugin kwargs)`` pair for ``gain_scenarios``.

    ``use_conditional`` is written into the kwargs only when it is passed
    explicitly; otherwise the key is omitted from the dict altogether.
    """
    kwargs = {
        "uncensoring_model": uncensoring_model,
        "tte_strategy": tte_strategy,
        "dataloader_sampling_strategy": sampling_strategy,
    }
    if use_conditional is not None:
        kwargs["use_conditional"] = use_conditional
    return label, kwargs


# Each entry pairs a scenario label with the keyword arguments that
# evaluate_dataset forwards to the "survival_gan" plugin.
gain_scenarios = [
    _gain_scenario(
        "uncensoring_with_date",
        "date",
        "uncensoring",
        "none",
        use_conditional=False,
    ),
    _gain_scenario(
        "survival_function_with_date",
        "date",
        "survival_function",
        "none",
        use_conditional=False,
    ),
    _gain_scenario(
        "survival_function_regression_no_sampling",
        "survival_function_regression",
        "survival_function",
        "none",
        use_conditional=False,
    ),
    _gain_scenario(
        "survival_function_regression_imbalanced_cens_sampling",
        "survival_function_regression",
        "survival_function",
        "imbalanced_censoring",
        use_conditional=False,
    ),
    _gain_scenario(
        "survival_function_regression_imbalanced_time_cens_no_cond",
        "survival_function_regression",
        "survival_function",
        "imbalanced_time_censoring",
        use_conditional=False,
    ),
    # The final scenario carries no "use_conditional" key at all.
    _gain_scenario(
        "survival_function_regression_imbalanced_time_cens_sampling",
        "survival_function_regression",
        "survival_function",
        "imbalanced_time_censoring",
    ),
]


def evaluate_dataset(dataset: str, scenarios: list) -> dict:
    """Benchmark the ``survival_gan`` plugin on one dataset under each scenario.

    For every ``(scenario_name, scenario_args)`` pair, the benchmark score is
    loaded from a backup file under ``out_dir`` when that file exists;
    otherwise it is computed with ``Benchmarks.evaluate`` and saved to that
    file. Each score is printed with ``Benchmarks.print``.

    Notes:
        * The number of repetitions is read from the module-level ``repeats``
          variable, which must be defined before this function is called.
        * The backup file name is built from the experiment tag, the dataset
          and the scenario name only, so an existing backup is reused even if
          ``repeats`` or the scenario arguments have changed since it was
          written.

    Args:
        dataset: Dataset name passed to ``get_dataset``; also part of the
            backup file name.
        scenarios: Iterable of ``(scenario_name, scenario_args)`` pairs, where
            ``scenario_args`` is a dict of keyword arguments for the
            ``survival_gan`` plugin. The dicts are left unmodified.

    Returns:
        A dict mapping each scenario name to its benchmark score.
    """
    df, duration_col, event_col, time_horizons = get_dataset(dataset)
    # Earlier experiment tag: "gain_of_function_parametric"
    experiment = "sources_of_gain_parametric"

    # Create the output directory up front so that saving cannot fail on a
    # missing folder after an expensive benchmark run has already finished.
    out_dir.mkdir(parents=True, exist_ok=True)

    scores = {}
    for scenario_name, scenario_args in scenarios:
        bkp = out_dir / f"experiment_{experiment}_{dataset}_{scenario_name}.bkp"

        # Work on a copy: forcing the device must not leak into the caller's
        # scenario definitions, which may be reused for other runs.
        plugin_args = {**scenario_args, "device": "cpu"}

        if bkp.exists():
            score = load_from_file(bkp)
        else:
            score = Benchmarks.evaluate(
                ["survival_gan"],
                df,
                task_type="survival_analysis",
                target_column=event_col,
                time_to_event_column=duration_col,
                time_horizons=time_horizons,
                synthetic_size=len(df),
                repeats=repeats,
                plugin_kwargs={"survival_gan": plugin_args},
            )
            save_to_file(bkp, score)

        print("Scenario", scenario_name, plugin_args)
        Benchmarks.print(score)
        scores[scenario_name] = score

    return scores

In [None]:
# Read as a global inside evaluate_dataset and forwarded to Benchmarks.evaluate(repeats=...)
repeats = 3

## AIDS dataset

In [None]:
# Run every scenario in gain_scenarios on the "aids" dataset; evaluate_dataset prints each scenario's result
base_score = evaluate_dataset("aids", gain_scenarios)