In [None]:
import numpy as np
import os

import sys
from pycox import datasets
from lifelines.datasets import load_rossi
from sksurv.datasets import (
    load_aids,
    load_breast_cancer,
    load_flchain,
    load_gbsg2,
    load_whas500,
)
from sklearn.preprocessing import LabelEncoder
import synthcity.logger as log
from medicaldata.CUTRACT import download as cutract_download, load as cutract_load
from medicaldata.SEER_prostate_cancer import (
    download as seer_download,
    load as seer_load,
)
from pathlib import Path

# Route synthcity's logger output to stderr at INFO level.
log.add(sink=sys.stderr, level="INFO")


def get_dataset(name: str) -> tuple:
    """Load a named survival dataset and prepare it for benchmarking.

    Supported names:
      * ``"metabric"``, ``"support"``, ``"gbsg"`` -- read through pycox.
      * ``"rossi"`` -- read through lifelines; ``week``/``arrest`` are
        renamed to ``duration``/``event``.
      * ``"aids"``, ``"flchain"``, ``"gbsg2"``, ``"whas500"`` -- read
        through scikit-survival.
      * ``"cutract"``, ``"seer"`` -- read through medicaldata; the CSV is
        downloaded to ``data/`` on first use.

    After loading, columns of dtype ``object`` or ``category`` are
    integer-encoded with ``LabelEncoder`` and remaining missing values are
    replaced by 0.

    Args:
        name: Dataset identifier, one of the names listed above.

    Returns:
        A 4-tuple ``(df, duration_col, event_col, time_horizons)``.
        ``duration_col`` is ``"duration"``, ``event_col`` is ``"event"``,
        and ``time_horizons`` is a list with the three interior points of
        a 5-point evenly spaced grid between the smallest and largest
        duration in ``df``.

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    supported = (
        "metabric",
        "support",
        "gbsg",
        "rossi",
        "aids",
        "flchain",
        "gbsg2",
        "whas500",
        "cutract",
        "seer",
    )

    if name == "metabric":
        df = datasets.metabric.read_df()
    elif name == "support":
        df = datasets.support.read_df()
    elif name == "gbsg":
        df = datasets.gbsg.read_df()
    elif name == "rossi":
        df = load_rossi()
        df = df.rename(columns={"week": "duration", "arrest": "event"})
    elif name in ("aids", "flchain", "gbsg2", "whas500"):
        # The loader table is built only inside this branch so that the
        # other branches never touch the scikit-survival names.
        sksurv_loaders = {
            "aids": load_aids,
            "flchain": load_flchain,
            "gbsg2": load_gbsg2,
            "whas500": load_whas500,
        }
        X, Y = sksurv_loaders[name]()
        # Re-label the two-field outcome array as (event, duration).
        # NOTE(review): relies on numpy casting structured dtypes by field
        # position and on the loaders returning (event, time) in that order.
        Y_unp = np.array(Y, dtype=[("event", "int"), ("duration", "float")])
        df = X.copy()
        df["event"] = Y_unp["event"]
        df["duration"] = Y_unp["duration"]
    elif name == "cutract":
        file_id = "1mew1S3-N2GdVu5nGjaqmpLo7sTKRf4Vj"
        csv_path = Path("data/cutract.csv")
        if not csv_path.exists():
            # Make sure the target directory exists before downloading.
            csv_path.parent.mkdir(parents=True, exist_ok=True)
            cutract_download(file_id, csv_path)

        X, T, Y = cutract_load(csv_path, preprocess=False)
        df = X.copy()
        df["event"] = Y
        df["duration"] = T
    elif name == "seer":
        file_id = "1PNXLjy8r1xHZq7SspduAMK6SGUTvuwM6"
        csv_path = Path("data/seer.csv")
        if not csv_path.exists():
            # Make sure the target directory exists before downloading.
            csv_path.parent.mkdir(parents=True, exist_ok=True)
            seer_download(file_id, csv_path)

        X, T, Y = seer_load(csv_path, preprocess=False)
        df = X.copy()
        df["event"] = Y
        df["duration"] = T
    else:
        # Previously an unknown name fell through and crashed below with an
        # UnboundLocalError on `df`; fail early with a clear message instead.
        raise ValueError(
            f"Unknown dataset {name!r}; expected one of: {', '.join(supported)}"
        )

    # Integer-encode every non-numeric (object / category) column.
    for col in df.columns:
        if df[col].dtype.name in ["object", "category"]:
            df[col] = LabelEncoder().fit_transform(df[col])

    duration_col = "duration"
    event_col = "event"

    # Remaining missing values are imputed with 0.
    df = df.fillna(0)

    T = df[duration_col]

    # 5 evenly spaced points over the observed duration range, minus the
    # two endpoints -> 3 evaluation horizons.
    time_horizons = np.linspace(T.min(), T.max(), num=5)[1:-1].tolist()

    return df, duration_col, event_col, time_horizons

In [None]:
from lifelines import KaplanMeierFitter
from lifelines.datasets import load_waltons

# Example data shipped with lifelines, used here for a Kaplan-Meier demo plot.
waltons = load_waltons()

# Fit on column "T" with column "E" as the second argument
# (lifelines' fit signature is fit(durations, event_observed)), then plot.
kmf = KaplanMeierFitter(label="waltons_data")
kmf.fit(waltons["T"], waltons["E"])
kmf.plot(ci_alpha=0.1)

In [None]:
# Load the "cutract" dataset through get_dataset and display the resulting frame.
df, duration_col, event_col, time_horizons = get_dataset("cutract")

df

In [None]:
from synthcity.plugins import Plugins
from synthcity.benchmark import Benchmarks

# List the plugins known to synthcity and display them.
# NOTE(review): skip_debug=True presumably hides debug-only plugins — confirm in the synthcity docs.
plugins = Plugins().list(skip_debug=True)

plugins

In [None]:
# Plugin names passed to Benchmarks.evaluate in the "base" cells of every section.
# NOTE(review): presumably general-purpose generators without survival-specific handling — confirm.
base_plugins = [
    "privbayes",
    "adsgan",
    "bayesian_network",
    "ctgan",
    "tvae",
    "nflow",
]
# Plugin names passed to Benchmarks.evaluate in the "survival" cells of every section.
survival_plugins = [
    "survival_gan",
    "survival_ctgan",
    "survival_tvae",
    "survival_bayesian_network",
    "survival_nflow",
]

# Value forwarded as `repeats=` to every Benchmarks.evaluate call below.
repeats = 3

## AIDS dataset

In [None]:
import pandas as pd

# Load the "aids" dataset through get_dataset and display the resulting frame.
df, duration_col, event_col, time_horizons = get_dataset("aids")

df

In [None]:
# Benchmark the plugins in `base_plugins` on this section's `df`.
# The event column is the target, the duration column is the time-to-event,
# and the requested synthetic size equals the number of real rows.
base_score = Benchmarks.evaluate(
    base_plugins,
    df,
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
    synthetic_size=len(df),
    repeats=repeats,
)

In [None]:
# Print the benchmark results obtained for the base plugins.
Benchmarks.print(base_score)

In [None]:
# Benchmark the plugins in `survival_plugins` on this section's `df`,
# with the same evaluation arguments as the base-plugin run.
survival_score = Benchmarks.evaluate(
    survival_plugins,
    df,
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
    synthetic_size=len(df),
    repeats=repeats,
)

In [None]:
# Print the benchmark results obtained for the survival plugins.
Benchmarks.print(survival_score)

## FLChain dataset

In [None]:
import pandas as pd

# Load the "flchain" dataset through get_dataset and display the resulting frame.
df, duration_col, event_col, time_horizons = get_dataset("flchain")

df

In [None]:
# Benchmark the plugins in `base_plugins` on this section's `df`.
# The event column is the target, the duration column is the time-to-event,
# and the requested synthetic size equals the number of real rows.
base_score = Benchmarks.evaluate(
    base_plugins,
    df,
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
    synthetic_size=len(df),
    repeats=repeats,
)

In [None]:
# Print the benchmark results obtained for the base plugins.
Benchmarks.print(base_score)

In [None]:
# Benchmark the plugins in `survival_plugins` on this section's `df`,
# with the same evaluation arguments as the base-plugin run.
survival_score = Benchmarks.evaluate(
    survival_plugins,
    df,
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
    synthetic_size=len(df),
    repeats=repeats,
)

In [None]:
# Print the benchmark results obtained for the survival plugins.
Benchmarks.print(survival_score)

## gbsg2 dataset

In [None]:
import pandas as pd

# Load the "gbsg2" dataset through get_dataset and display the resulting frame.
df, duration_col, event_col, time_horizons = get_dataset("gbsg2")

df

In [None]:
# Benchmark the plugins in `base_plugins` on this section's `df`.
# The event column is the target, the duration column is the time-to-event,
# and the requested synthetic size equals the number of real rows.
base_score = Benchmarks.evaluate(
    base_plugins,
    df,
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
    synthetic_size=len(df),
    repeats=repeats,
)

In [None]:
# Print the benchmark results obtained for the base plugins.
Benchmarks.print(base_score)

In [None]:
# Benchmark the plugins in `survival_plugins` on this section's `df`,
# with the same evaluation arguments as the base-plugin run.
survival_score = Benchmarks.evaluate(
    survival_plugins,
    df,
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
    synthetic_size=len(df),
    repeats=repeats,
)

In [None]:
# Print the benchmark results obtained for the survival plugins.
Benchmarks.print(survival_score)

## Metabric

In [None]:
import pandas as pd

# Load the "metabric" dataset through get_dataset and display the resulting frame.
# (This cell previously requested "gbsg2", so the Metabric section re-ran the
# benchmark of the preceding section instead of evaluating Metabric.)
df, duration_col, event_col, time_horizons = get_dataset("metabric")

df

In [None]:
# Benchmark the plugins in `base_plugins` on this section's `df`.
# The event column is the target, the duration column is the time-to-event,
# and the requested synthetic size equals the number of real rows.
base_score = Benchmarks.evaluate(
    base_plugins,
    df,
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
    synthetic_size=len(df),
    repeats=repeats,
)

In [None]:
# Print the benchmark results obtained for the base plugins.
Benchmarks.print(base_score)

In [None]:
# Benchmark the plugins in `survival_plugins` on this section's `df`,
# with the same evaluation arguments as the base-plugin run.
survival_score = Benchmarks.evaluate(
    survival_plugins,
    df,
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
    synthetic_size=len(df),
    repeats=repeats,
)

In [None]:
# Print the benchmark results obtained for the survival plugins.
Benchmarks.print(survival_score)

## gbsg

In [None]:
import pandas as pd

# Load the "gbsg" dataset through get_dataset and display the resulting frame.
df, duration_col, event_col, time_horizons = get_dataset("gbsg")

df

In [None]:
# Benchmark the plugins in `base_plugins` on this section's `df`.
# The event column is the target, the duration column is the time-to-event,
# and the requested synthetic size equals the number of real rows.
base_score = Benchmarks.evaluate(
    base_plugins,
    df,
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
    synthetic_size=len(df),
    repeats=repeats,
)

In [None]:
# Print the benchmark results obtained for the base plugins.
Benchmarks.print(base_score)

In [None]:
# Benchmark the plugins in `survival_plugins` on this section's `df`,
# with the same evaluation arguments as the base-plugin run.
survival_score = Benchmarks.evaluate(
    survival_plugins,
    df,
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
    synthetic_size=len(df),
    repeats=repeats,
)

In [None]:
# Print the benchmark results obtained for the survival plugins.
Benchmarks.print(survival_score)

## Support

In [None]:
import pandas as pd

# Load the "support" dataset through get_dataset and display the resulting frame.
df, duration_col, event_col, time_horizons = get_dataset("support")

df

In [None]:
# Benchmark the plugins in `base_plugins` on this section's `df`.
# The event column is the target, the duration column is the time-to-event,
# and the requested synthetic size equals the number of real rows.
base_score = Benchmarks.evaluate(
    base_plugins,
    df,
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
    synthetic_size=len(df),
    repeats=repeats,
)

In [None]:
# Print the benchmark results obtained for the base plugins.
Benchmarks.print(base_score)

In [None]:
# Benchmark the plugins in `survival_plugins` on this section's `df`,
# with the same evaluation arguments as the base-plugin run.
survival_score = Benchmarks.evaluate(
    survival_plugins,
    df,
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
    synthetic_size=len(df),
    repeats=repeats,
)

In [None]:
# Print the benchmark results obtained for the survival plugins.
Benchmarks.print(survival_score)

## CUTRACT

In [None]:
import pandas as pd

# Load the "cutract" dataset through get_dataset and display the resulting frame.
df, duration_col, event_col, time_horizons = get_dataset("cutract")

df

In [None]:
# Benchmark the plugins in `base_plugins` on this section's `df`.
# The event column is the target, the duration column is the time-to-event,
# and the requested synthetic size equals the number of real rows.
base_score = Benchmarks.evaluate(
    base_plugins,
    df,
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
    synthetic_size=len(df),
    repeats=repeats,
)

In [None]:
# Print the benchmark results obtained for the base plugins.
Benchmarks.print(base_score)

In [None]:
# Benchmark the plugins in `survival_plugins` on this section's `df`,
# with the same evaluation arguments as the base-plugin run.
survival_score = Benchmarks.evaluate(
    survival_plugins,
    df,
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
    synthetic_size=len(df),
    repeats=repeats,
)

In [None]:
# Print the benchmark results obtained for the survival plugins.
Benchmarks.print(survival_score)

## SEER prostate cancer dataset

In [None]:
import pandas as pd

# Load the "seer" dataset through get_dataset and display the resulting frame.
df, duration_col, event_col, time_horizons = get_dataset("seer")

df

In [None]:
# Benchmark the plugins in `base_plugins` on this section's `df`.
# The event column is the target, the duration column is the time-to-event,
# and the requested synthetic size equals the number of real rows.
base_score = Benchmarks.evaluate(
    base_plugins,
    df,
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
    synthetic_size=len(df),
    repeats=repeats,
)

In [None]:
# Print the benchmark results obtained for the base plugins.
Benchmarks.print(base_score)

In [None]:
# Benchmark the plugins in `survival_plugins` on this section's `df`,
# with the same evaluation arguments as the base-plugin run.
survival_score = Benchmarks.evaluate(
    survival_plugins,
    df,
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
    synthetic_size=len(df),
    repeats=repeats,
)

In [None]:
# Print the benchmark results obtained for the survival plugins.
Benchmarks.print(survival_score)