In [1]:
from datasets import get_dataset
from plots import plot_survival_individual, plot_survival_grouped

import matplotlib.pyplot as plt
import pandas as pd
from synthcity.plugins.models.time_to_event.loader import get_model_template
from synthcity.plugins import Plugins
from synthcity.utils.serialization import save_to_file, load_from_file
from pathlib import Path

In [2]:
# Directory where the synthetic datasets produced by each experiment are cached
# (one file per scenario/dataset/sampling-strategy combination). The path is
# relative, so it is resolved against the notebook's working directory.
out_dir = Path("output")


def plot_dataset_perf_baselines(
    dataset: str,
    ci_show: bool = True,
    ci_alpha: float = 0.2,
    **kwargs,
):
    """Compare real vs. synthetic survival data for each dataloader sampling strategy.

    For every sampling strategy, a synthetic dataset is obtained either from the
    on-disk cache under ``out_dir`` or by fitting a ``survival_gan`` plugin on the
    real data and generating ``len(df)`` rows (the result is then cached). The
    synthetic duration/event columns are plotted against the real ones, once per
    strategy and once more with all strategies grouped together.

    Args:
        dataset: Name passed to ``get_dataset`` and embedded in the cache file name.
        ci_show: Forwarded to the plotting helpers as ``ci_show``.
        ci_alpha: Forwarded to the plotting helpers as ``ci_alpha``.
        **kwargs: Accepted for call-site compatibility; currently unused.

    Returns:
        None. The function only prints progress and produces plots.
    """
    # The fourth element returned by get_dataset is not needed here.
    df, duration_col, event_col, _ = get_dataset(dataset)

    scenario = "imbalanced_sampler"

    T = df[duration_col]
    E = df[event_col]

    preds = []
    for dataloader_sampling_strategy in [
        "none",
        "imbalanced_censoring",
        "imbalanced_time_censoring",
        "imbalanced_cov_censoring",
        "imbalanced_full",
    ]:
        model_bkp = (
            out_dir / f"experiment_{scenario}_{dataset}_{dataloader_sampling_strategy}"
        )

        print("eval ", model_bkp)
        label = f"sampling strategy: {dataloader_sampling_strategy}"

        if model_bkp.exists():
            # Reuse the cached synthetic data instead of retraining.
            syn_df = load_from_file(model_bkp)
        else:
            syn_model = Plugins().get(
                "survival_gan",
                dataloader_sampling_strategy=dataloader_sampling_strategy,
            )

            try:
                syn_model.fit(df)

                syn_df = syn_model.generate(len(df))
            except Exception as e:
                # Best effort: a failing strategy is reported and skipped.
                # Catch Exception rather than BaseException so that
                # KeyboardInterrupt/SystemExit still stop the notebook.
                print("plugin failed", e)
                continue

            # Make sure the cache directory exists before writing, so a
            # completed training run is not lost to a missing folder.
            model_bkp.parent.mkdir(parents=True, exist_ok=True)
            save_to_file(model_bkp, syn_df)

        syn_T = syn_df[duration_col]
        syn_E = syn_df[event_col]

        local_data = (label, syn_T, syn_E)

        plot_survival_individual(
            scenario,
            dataset,
            label,
            T,
            E,
            [local_data],
            ci_show=ci_show,
            ci_alpha=ci_alpha,
        )
        preds.append(local_data)

    plot_survival_grouped(
        scenario, dataset, T, E, preds, ci_show=ci_show, ci_alpha=ci_alpha
    )

## AIDS

In [None]:
# Survival curves for the AIDS dataset, all sampling strategies.
plot_dataset_perf_baselines("aids")

eval  output/experiment_imbalanced_sampler_aids_none


## CUTRACT

In [None]:
# Survival curves for the CUTRACT dataset, all sampling strategies.
plot_dataset_perf_baselines("cutract")

## MAGGIC

In [None]:
# Survival curves for the MAGGIC dataset, all sampling strategies.
plot_dataset_perf_baselines("maggic")

## METABRIC

In [None]:
# Survival curves for the METABRIC dataset, all sampling strategies.
plot_dataset_perf_baselines("metabric")

## SEER

In [None]:
# Survival curves for the SEER dataset, all sampling strategies.
# The dataset name is the first (and only required) argument; the previous call
# passed a stale extra positional argument and an undefined `baseline_models`.
plot_dataset_perf_baselines(
    "seer",
)