In [None]:
from datasets import get_dataset
import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter
import pandas as pd
from synthcity.plugins.models.time_to_event.loader import get_model_template
from synthcity.plugins import Plugins
import cloudpickle
from synthcity.utils.serialization import save_to_file, load_from_file
from sklearn.model_selection import train_test_split
from pathlib import Path
from synthcity.plugins.models.time_to_event import get_model_template


# Cache directory for per-model predictions, relative to the working directory.
out_dir = Path(".") / "output"

In [None]:
# Names of the time-to-event model templates compared in every plot,
# in the order their panels appear.
tte_models = (
    "cox_ph weibull_aft survival_xgboost random_survival_forest deephit date"
).split()

In [None]:
from lifelines import KaplanMeierFitter
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from pathlib import Path
from synthcity.utils.serialization import save_to_file, load_from_file

out_dir = Path("output")

# matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and
# later releases dropped the old names, so "seaborn-whitegrid" raises OSError
# on current matplotlib. Prefer the new name and fall back for old versions.
_WHITEGRID_STYLE = (
    "seaborn-v0_8-whitegrid"
    if "seaborn-v0_8-whitegrid" in plt.style.available
    else "seaborn-whitegrid"
)
plt.style.use(_WHITEGRID_STYLE)


def plot_tte(title, X, T, E, models, **kwargs):
    """Plot Kaplan-Meier curves of real vs. model-predicted event times, one panel per model.

    For each model, predictions on ``X`` are loaded from a cache file under
    ``out_dir`` when it exists. Otherwise the model is fitted on ``(X, T, E)``,
    and its predictions on ``X`` are computed and cached. Each panel overlays
    two curves: the Kaplan-Meier curve of the real durations, and the curve
    obtained by treating every predicted time as an observed event. The grouped
    figure is saved as a PDF under ``diagrams/``, shown, and then closed.

    Args:
        title: Dataset name, used in the figure title and in cache/figure file names.
        X: Covariates passed to ``model.fit`` / ``model.predict``.
        T: Observed durations. They are indexed with the boolean mask ``E == 1``,
            presumably as a pandas Series or numpy array (TODO confirm).
        E: Event indicators aligned with ``T``; ``1`` marks an observed event.
        models: Model instances exposing ``name()``, ``fit(X, T, E)`` and ``predict(X)``.
        **kwargs: Accepted for call-site compatibility; currently unused.

    A model whose fit/predict (or cache write) raises an ``Exception`` is
    reported with ``print`` and its panel is left empty.
    """
    model_names = "_".join(model.name() for model in models)
    fig_name = Path(f"diagrams/tte_km_plot_grouped_{title}_{model_names}.pdf")

    # squeeze=False keeps ``axs`` a 2-D array even for a single model, so the
    # ``axs[0, idx]`` indexing below works for any number of models.
    fig, axs = plt.subplots(
        1,
        len(models),
        figsize=(4 * len(models), 4),
        constrained_layout=True,
        squeeze=False,
    )

    # The real-data KM curve and the last observed event time do not depend on
    # the model, so compute them once instead of once per panel.
    real_kmf = KaplanMeierFitter()
    real_kmf.fit(T, E, label="Real dataset")
    last_event_time = T[E == 1].max()

    for idx, model in enumerate(models):
        name = model.name()
        model_bkp = out_dir / f"tte_km_plot_model_{title}_{name}"
        ax = axs[0, idx]

        if model_bkp.exists():
            preds = load_from_file(model_bkp)
        else:
            try:
                model.fit(X, T, E)
                preds = model.predict(X)
                # Create the cache directory first; otherwise the save raises
                # and a successfully fitted model is reported as failed.
                model_bkp.parent.mkdir(parents=True, exist_ok=True)
                save_to_file(model_bkp, preds)
            except Exception as e:  # not BaseException: let Ctrl-C / SystemExit through
                print("model failed", e)
                continue

        real_kmf.plot(ax=ax)

        # Shorten the longest model name so the legend fits in the panel.
        label_name = "rsf" if name == "random_survival_forest" else name

        syn_kmf = KaplanMeierFitter()
        # Every predicted time is treated as an observed event (no censoring).
        syn_kmf.fit(preds, [1] * len(preds), label=f"Time-to-event model: {label_name}")
        syn_kmf.plot(ax=ax)

        # Dashed red marker at the last observed event time in the real data.
        ax.axvline(last_event_time, color="r", linestyle="--")
        # Clear any per-panel x-label; a single shared label is added below.
        ax.set_xlabel("", fontsize=14)

    # NOTE(review): KaplanMeierFitter.plot draws the survival function
    # S(t) = P(T > t); "Event probability" may be a misnomer. Confirm the label.
    axs[0, 0].set_ylabel("Event probability", fontsize=14)
    fig.suptitle(f"Dataset: {title}", fontsize=18)
    # Shared x-label placed just below the axes (y < 0, i.e. outside the canvas).
    fig.text(0.5, -0.05, "Days to event", ha="center", fontsize=14)

    fig_name.parent.mkdir(parents=True, exist_ok=True)
    # bbox_inches="tight" expands the saved area to include the x-label that
    # sits outside the figure; without it the label is clipped from the PDF.
    fig.savefig(fig_name, bbox_inches="tight")

    plt.show()
    # Release the figure: this function is called once per dataset in a loop.
    plt.close(fig)

In [None]:
for dataset in ["aids", "cutract", "maggic", "seer"]:
    # Build fresh, unfitted model instances for every dataset.
    models = [get_model_template(model_key)() for model_key in tte_models]
    df, duration_col, event_col, time_horizons = get_dataset(dataset)

    # Covariates are every column except the duration and event columns.
    Xcov = df.drop(columns=[duration_col, event_col])
    T, E = df[duration_col], df[event_col]

    plot_tte(dataset, Xcov, T, E, models)