# Generate data optimism plots

_Note_: This notebook must be executed after `experiments_01_benchmark_synthetic_survival_data.ipynb`.

In [None]:
import platform
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
from lifelines import KaplanMeierFitter
from sklearn.manifold import TSNE
from synthcity.plugins import Plugins
from synthcity.plugins.core.models.survival_analysis.metrics import (
    nonparametric_distance,
)
from synthcity.plugins.core.models.time_to_event.loader import get_model_template
from synthcity.utils.serialization import dataframe_hash, load_from_file, save_to_file

from datasets import get_dataset

# Directory holding the synthetic-data backups (.bkp) loaded below.
out_dir = Path("workspace")
fontsize = 14
# matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# 3.8 removed the old names, so use whichever this installation provides.
plt.style.use(
    "seaborn-v0_8-whitegrid"
    if "seaborn-v0_8-whitegrid" in plt.style.available
    else "seaborn-whitegrid"
)


def generate_score(metric: np.ndarray) -> tuple:
    """Summarize repeated measurements as a mean and a 95% half-width.

    The half-width is ``1.96 * np.std(metric) / sqrt(n)`` (normal-approximation
    95% interval of the mean; ``np.std`` uses its default ``ddof=0``).

    Args:
        metric: Sequence of per-repeat metric values.

    Returns:
        ``(mean, half_width)``, each rounded to 4 decimal places.
    """
    z_95 = 1.96  # two-sided 95% standard-normal quantile
    n_samples = len(metric)
    mean_value = np.mean(metric)
    half_width = z_95 * np.std(metric) / np.sqrt(n_samples)
    return round(mean_value, 4), round(half_width, 4)


def generate_score_str(metric: np.ndarray) -> str:
    """Format ``generate_score(metric)`` as ``"<mean> +/- <half_width>"``."""
    return " +/- ".join(str(part) for part in generate_score(metric))


def map_models(model):
    """Return the display name used in plots for a plugin name.

    Raises:
        KeyError: if ``model`` is not one of the known names.
    """
    display_names = dict(
        survival_gan="SurvivalGAN",
        adsgan="AdsGAN",
        ctgan="CTGAN",
        tvae="TVAE",
        privbayes="PrivBayes",
        nflow="nFlows",
    )
    return display_names[model]


def plot_km(
    ax,
    title,
    T: pd.Series,
    E: pd.Series,
    syn_T: pd.Series,
    syn_E: pd.Series,
    ci_show: bool = True,
    ci_alpha: float = 0.2,
    show_censors: bool = False,
    syn_label: str = None,
    xaxis_label: str = None,
    yaxis_label: str = None,
    show_title: bool = True,
) -> None:
    """Overlay Kaplan-Meier curves of real and synthetic survival data on ``ax``.

    A red dashed vertical line marks the latest real-data time with ``E == 1``.

    Args:
        ax: Matplotlib axes to draw on (passed to ``KaplanMeierFitter.plot``).
        title: Used as the x-axis label when ``show_title`` is true.
        T, E: Real durations and event indicators (0 = censored, 1 = event).
        syn_T, syn_E: Synthetic durations and event indicators.
        ci_show, ci_alpha, show_censors: Forwarded to ``KaplanMeierFitter.plot``
            for both curves.
        syn_label: Legend label of the synthetic curve (default
            "Synthetic data").
        xaxis_label: X-axis label used when ``show_title`` is false
            (default "Days").
        yaxis_label: Y-axis label (default "Temporal fidelity(Kaplan-Meier)").
        show_title: If true, ``title`` is used as the x-axis label.
    """
    plot_kwargs = dict(ci_show=ci_show, ci_alpha=ci_alpha, show_censors=show_censors)

    real_kmf = KaplanMeierFitter()
    real_kmf.fit(T, E, label="Real data")
    ax = real_kmf.plot(ax=ax, **plot_kwargs)

    if syn_label is None:
        syn_label = "Synthetic data"
    syn_kmf = KaplanMeierFitter()
    syn_kmf.fit(syn_T, syn_E, label=syn_label)
    ax = syn_kmf.plot(ax=ax, **plot_kwargs)

    # Mark the last observed real event. Skip when there are no events: max()
    # of an empty selection is NaN for pandas and raises for NumPy arrays.
    event_times = T[E == 1]
    if len(event_times) > 0:
        ax.axvline(event_times.max(), color="r", linestyle="--")

    if show_title:
        ax.set_xlabel(title, horizontalalignment="center")
    else:
        ax.set_xlabel(
            "Days" if xaxis_label is None else xaxis_label,
            horizontalalignment="center",
        )

    if yaxis_label is None:
        yaxis_label = "Temporal fidelity(Kaplan-Meier)"
    ax.set_ylabel(yaxis_label, fontsize=20)


def _model_backup_path(df_hash, model: str, repeat_id: int) -> Path:
    """Return the backup file path of one (model, seed) synthetic-data run."""
    return (
        out_dir
        / f"{df_hash}_{model}_{model}__{platform.python_version()}_{repeat_id}.bkp"
    )


def _summarize_km_metrics(results: dict) -> pd.DataFrame:
    """Build one summary row per model from its per-seed metric lists."""
    cols = ["model", "opt", "opt_avg", "abs_opt", "abs_opt_avg", "sight", "sight_avg"]
    rows = []
    for model, metrics in results.items():
        row = [model]
        for name in ("opt", "abs_opt", "sight"):
            row.append(generate_score_str(metrics[name]))
            row.append(np.mean(metrics[name]))
        rows.append(row)
    # Built in one go: DataFrame.append was removed in pandas 2.0.
    return pd.DataFrame(rows, columns=cols)


def plot_grouped_km(
    dataset: str,
    models: list,
    individual: bool = False,
    ci_show: bool = True,
    ci_alpha: float = 0.2,
    save: bool = False,
    n_repeats: int = 5,
    **kwargs,
):
    """Plot real-vs-synthetic KM curves per seed and summarize distance metrics.

    For each seed ``0 .. n_repeats - 1`` that has a ``survival_gan`` backup in
    ``out_dir``, draw one row of subplots (one per entry of ``models``) and
    record the ``nonparametric_distance`` metrics for every model whose backup
    exists. Models without a backup for a seed leave their subplot empty.

    Args:
        dataset: Name passed to ``get_dataset``.
        models: Non-empty list of plugin names (see ``map_models``).
        individual: Unused; kept for call compatibility.
        ci_show, ci_alpha: Forwarded to ``plot_km``.
        save: If true, save each seed's figure as a PDF under ``diagrams/``.
        n_repeats: Number of seeds to look for.
        **kwargs: Unused; kept for call compatibility.

    Returns:
        DataFrame with one row per model and columns ``model``, ``opt``,
        ``opt_avg``, ``abs_opt``, ``abs_opt_avg``, ``sight``, ``sight_avg``
        (``*_avg`` are means; the others are ``generate_score_str`` strings).
    """
    df, duration_col, event_col, _ = get_dataset(dataset)
    df_hash = dataframe_hash(df)
    real_durations, real_events = df[duration_col], df[event_col]

    results = {}
    for repeat_id in range(n_repeats):
        # Only plot seeds for which the reference SurvivalGAN run exists.
        if not _model_backup_path(df_hash, "survival_gan", repeat_id).exists():
            continue

        # squeeze=False keeps `axs` 2-D even when a single model is given.
        fig, axs = plt.subplots(
            1, len(models), figsize=(5 * len(models), 4), squeeze=False
        )

        for idx, model in enumerate(models):
            metrics = results.setdefault(model, {"opt": [], "abs_opt": [], "sight": []})
            model_bkp = _model_backup_path(df_hash, model, repeat_id)
            if not model_bkp.exists():
                continue
            syn_df = load_from_file(model_bkp).dataframe()
            syn_durations, syn_events = syn_df[duration_col], syn_df[event_col]

            auc_opt, auc_abs_opt, sightedness = nonparametric_distance(
                (real_durations, real_events),
                (syn_durations, syn_events),
            )
            metrics["opt"].append(auc_opt)
            metrics["abs_opt"].append(auc_abs_opt)
            metrics["sight"].append(sightedness)

            plot_km(
                axs[0, idx],
                map_models(model),
                real_durations,
                real_events,
                syn_durations,
                syn_events,
                ci_show=ci_show,
                ci_alpha=ci_alpha,
                syn_label=f"Synthetic model: {map_models(model)}",
                show_title=True,
                yaxis_label=f"Seed {repeat_id}",
            )

        if save:
            Path("diagrams").mkdir(parents=True, exist_ok=True)
            # NOTE(review): the file name embeds the last entry of `models`
            # (preserved from the original, where the inner loop variable
            # leaked into it) even though the figure shows every model.
            fig.savefig(
                f"diagrams/grouped_kmplot_synth_gen_{dataset}_{models[-1]}_{repeat_id}.pdf"
            )
        plt.show()

    return _summarize_km_metrics(results)

In [None]:
# Generator plugin names compared in this notebook (display names: map_models).
baseline_models = [
    "survival_gan",
    "adsgan",
    "ctgan",
    "tvae",
    "privbayes",
    "nflow",
]

In [None]:
# KM plots plus the optimism / absolute optimism / sightedness summary for AIDS,
# tagged with the dataset name.
aids_metrics = plot_grouped_km("aids", models=baseline_models, save=True).assign(
    dataset="aids"
)

aids_metrics