# Tutorial 2: Generating static survival-analysis data

`synthcity` includes models that target specific tabular modalities, one of which is survival-analysis data. The general-purpose models can also be used for this task.

The main requirement for survival analysis is to wrap the data in a `SurvivalAnalysisDataLoader`.

In [None]:
# stdlib
import sys
import warnings

# synthcity absolute
import synthcity.logger as log
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import SurvivalAnalysisDataLoader

# Register stderr as a sink for the synthcity logger at the "INFO" level.
log.add(sink=sys.stderr, level="INFO")
# Install an "ignore" filter so Python warnings are not shown in the notebook.
warnings.filterwarnings("ignore")

In [None]:
# third party
from pycox import datasets

# Read the `gbsg` dataset shipped with pycox into a DataFrame.
df = datasets.gbsg.read_df()
# Keep only the rows whose "duration" is strictly positive.
df = df[df["duration"] > 0]

# Last expression of the cell: display the filtered frame.
df

In [None]:
# Wrap the frame in a survival-analysis loader: "event" is passed as the
# target column and "duration" as the time-to-event column.
loader = SurvivalAnalysisDataLoader(
    df,
    target_column="event",
    time_to_event_column="duration",
)

## List the available generative models

In [None]:
# synthcity absolute
from synthcity.plugins import Plugins

# List the plugins available when restricting to the "survival_analysis" category.
Plugins(categories=["survival_analysis"]).list()

## Load and train a generative model

In [None]:
# synthcity absolute
from synthcity.plugins import Plugins

# Fetch the "survival_gan" plugin (with n_iter=100) and fit it on the loader.
syn_model = Plugins().get("survival_gan", n_iter=100)
syn_model.fit(loader)

## Generate new data using the model

In [None]:
# Ask the fitted model for 10 rows and call `.dataframe()` on the result for display.
syn_model.generate(count=10).dataframe()

## Generate new data using a conditional

We will use the `event` outcome to condition the data generation.

In [None]:
# synthcity absolute
from synthcity.plugins import Plugins

# Build a fresh "survival_gan" plugin and fit it, passing the "event" column
# of the original frame as the conditional.
syn_model = Plugins().get("survival_gan", n_iter=100)
cond = df["event"]
syn_model.fit(loader, cond=cond)

In [None]:
import numpy as np

# Request `count` rows, supplying an all-ones conditional of the same length.
count = 10
syn_model.generate(count=count, cond=np.ones(count)).dataframe()

## Serialization

In [None]:
# synthcity absolute
from synthcity.utils.serialization import load, load_from_file, save, save_to_file

# Serialize the trained model with `save` and keep the result in `buff`.
buff = save(syn_model)

# Show what kind of object `save` returned.
type(buff)

In [None]:
# Restore a model from the serialized buffer produced by `save`.
reloaded = load(buff)

# Display the restored object's name() as a quick check of the round trip.
reloaded.name()

## Plot real-synthetic distributions

Compared to the general case, the plots for survival-analysis data also include Kaplan-Meier (KM) curves for the real and synthetic data.

In [None]:
# third party
import matplotlib.pyplot as plt

# Let the model draw its plots, handing it the pyplot module and the real-data loader.
syn_model.plot(plt, loader)

# Render the figures.
plt.show()

## Benchmark the quality of plugins

For survival analysis, general-purpose generators can be used as well.

In [None]:
# synthcity absolute
from synthcity.benchmark import Benchmarks

# Benchmark several generators on the survival-analysis loader. The list
# comprehension builds one (label, model name, empty kwargs dict) tuple per
# model. The tuple must NOT be followed by a comma before `for`: a comma
# there is a SyntaxError.
score = Benchmarks.evaluate(
    [
        (f"test_{model}", model, {})
        for model in ["adsgan", "survival_gan", "survae"]
    ],
    loader,
    synthetic_size=1000,
    repeats=2,
)

In [None]:
# Print the benchmark results collected in `score`.
Benchmarks.print(score)

In [None]:
# third party
import numpy as np
import pandas as pd

# Gather the "mean" column of every benchmarked plugin so they can be
# compared side by side.
means = []
for plugin in score:
    data = score[plugin]["mean"]
    # Reassigned on every iteration: the mapping taken from the last plugin
    # is the one left in `directions` for the `highlights` helper.
    directions = score[plugin]["direction"].to_dict()
    means.append(data)

out = pd.concat(means, axis=1)
# Label each column with its plugin key. Plain assignment is used instead of
# `set_axis(..., inplace=True)`, which pandas >= 2.0 rejects because the
# `inplace` parameter was removed.
out.columns = list(score.keys())

# CSS snippets for the comparison table: one for the worst cell of a row,
# one for the best cell, and an empty style for everything else.
bad_highlight = "background-color: lightcoral;"
ok_highlight = "background-color: green;"
default = ""


def highlights(row):
    """Return one CSS style string per cell of a metric row.

    The row's name is looked up in the module-level ``directions`` mapping.
    For a "minimize" metric the smallest value is the best and the largest
    the worst; for any other direction the roles are swapped. Cells equal to
    the best value get ``ok_highlight``, cells equal to the worst value get
    ``bad_highlight``, and all remaining cells get ``default``.
    """
    minimizing = directions[row.name] == "minimize"
    cells = row.values
    # The best value is computed on the raw array and the worst value on the
    # row object itself, exactly as the two calls are written here.
    if minimizing:
        best_val, worst_val = np.min(cells), np.max(row)
    else:
        best_val, worst_val = np.max(cells), np.min(row)

    def style_for(val):
        # "Best" is tested first, so a value that is both best and worst
        # (e.g. all cells equal) is styled as best.
        if val == best_val:
            return ok_highlight
        if val == worst_val:
            return bad_highlight
        return default

    return [style_for(val) for val in cells]


# Run the highlighter over every row (axis=1) of the comparison table.
out.style.apply(highlights, axis=1)

# 