In [1]:
from datasets import get_dataset

from sklearn.manifold import TSNE
import plotly.express as px
from umap import UMAP
import matplotlib.pyplot as plt

import pandas as pd
from synthcity.plugins.models.time_to_event.loader import get_model_template
from synthcity.plugins import Plugins
from synthcity.utils.serialization import save_to_file, load_from_file
from pathlib import Path

# Directory holding the per-(scenario, dataset, model) synthetic-data cache
# files looked up by plot_2d. Relative path: resolved against the kernel's
# current working directory.
out_dir = Path("output")


def _as_dataframe(data) -> pd.DataFrame:
    """Return ``data`` as a pandas DataFrame.

    NOTE(review): depending on the synthcity version, ``Plugin.generate`` may
    return a DataLoader exposing ``.dataframe()`` instead of a DataFrame; both
    forms (including previously cached ones) are accepted here.
    """
    if isinstance(data, pd.DataFrame):
        return data
    if hasattr(data, "dataframe"):
        return data.dataframe()
    return pd.DataFrame(data)


def _projection_frame(proj, source: str) -> pd.DataFrame:
    """Wrap a 2-column projection into a long-form frame tagged with ``source``."""
    return pd.DataFrame({"x": proj[:, 0], "y": proj[:, 1], "source": source})


def plot_2d(
    dataset: str,
    models: list,
    individual: bool = False,
    ci_show: bool = True,
    ci_alpha: float = 0.2,
    **kwargs,
):
    """Plot a 2-D UMAP projection of real vs. synthetic covariates per model.

    A UMAP embedding is fitted once on the real covariates of ``dataset``
    (duration and event columns excluded). For each plugin name in ``models``,
    the function takes a synthetic dataset of the same size. It either loads
    it from the cache in ``out_dir`` or trains the plugin and writes the
    result to that cache. The synthetic covariates are then projected into
    the *same* embedding and shown overlaid on the real points, one figure
    per model.

    Args:
        dataset: Dataset name passed to ``get_dataset``.
        models: synthcity plugin names, e.g. ``"ctgan"``.
        individual, ci_show, ci_alpha: Not used by this function; kept so
            existing call sites remain valid.
        **kwargs: Forwarded to ``Plugins().get`` when a plugin must be trained.

    Returns:
        None. Figures are displayed via ``fig.show()``. A plugin that fails to
        train or generate is reported with ``print`` and skipped.
    """
    df, duration_col, event_col, _time_horizons = get_dataset(dataset)
    Xcov = df.drop(columns=[duration_col, event_col])
    scenario = "synth_gen"

    # Fit the embedding once on the real covariates and reuse it via
    # ``transform``. Independently fitted UMAP embeddings live in unrelated
    # coordinate systems and cannot be compared visually.
    umap_real = UMAP(n_components=2, init="random", random_state=0)
    real_points = _projection_frame(umap_real.fit_transform(Xcov), "real")

    out_dir.mkdir(parents=True, exist_ok=True)

    for ref in models:
        model_bkp = out_dir / f"umap_{scenario}_{dataset}_{ref}"

        if model_bkp.exists():
            syn_df = load_from_file(model_bkp)
        else:
            try:
                syn_model = Plugins().get(ref, **kwargs)
                syn_model.fit(df)
                syn_df = syn_model.generate(len(df))
            except Exception as e:
                # Best effort: one failing plugin must not abort the others.
                print(f"plugin {ref} failed: {e!r}")
                continue
            # Cache only successful generations.
            save_to_file(model_bkp, syn_df)

        # Project exactly the covariates the embedding was fitted on (same
        # columns, same order); outcome columns are excluded like for Xcov.
        syn_cov = _as_dataframe(syn_df)[Xcov.columns]
        syn_points = _projection_frame(umap_real.transform(syn_cov), ref)

        fig = px.scatter(
            pd.concat([real_points, syn_points], ignore_index=True),
            x="x",
            y="y",
            color="source",
            opacity=0.6,
            title=f"{dataset}: real vs. {ref} (UMAP)",
        )
        fig.show()

In [2]:
# synthcity plugin names evaluated in each dataset section of this notebook.
baseline_models = [
    "adsgan",
    "ctgan",
    "tvae",
    "privbayes",
    "nflow",
    "survival_gan",
]

## AIDS

In [None]:
# Real-vs-synthetic UMAP comparison for the AIDS dataset.
plot_2d("aids", models=baseline_models)

plugin failed 


## CUTRACT

In [None]:
# Real-vs-synthetic UMAP comparison for the CUTRACT dataset.
plot_2d(
    "cutract",
    models=baseline_models,
)

## MAGGIC

In [None]:
# Real-vs-synthetic UMAP comparison for the MAGGIC dataset.
plot_2d(
    "maggic",
    models=baseline_models,
)

## SEER

In [None]:
# Real-vs-synthetic UMAP comparison for the SEER dataset.
plot_2d(
    "seer",
    models=baseline_models,
)