In [1]:
import numpy as np

import sys
from pycox import datasets
from lifelines.datasets import load_rossi
from sksurv.datasets import (
    load_aids,
    load_breast_cancer,
    load_flchain,
    load_gbsg2,
    load_whas500,
)
from sklearn.preprocessing import LabelEncoder
import synthcity.logger as log

# Send synthcity log messages to stderr at the INFO level so benchmark
# progress ("Benchmarking plugin", "Experiment repeat") and evaluation
# errors appear in the cell output.
log.add(sink=sys.stderr, level="INFO")

def get_dataset(name: str):
    """Load a survival benchmark dataset as a single, fully numeric DataFrame.

    Datasets come from three sources:
      * pycox: ``"metabric"``, ``"support"``, ``"gbsg"``
      * lifelines: ``"rossi"`` (``week``/``arrest`` renamed to
        ``duration``/``event``)
      * scikit-survival: ``"aids"``, ``"flchain"``, ``"gbsg2"``, ``"whas500"``

    Object/category columns are label-encoded to integers, and any remaining
    missing values are filled with 0.

    Args:
        name: Dataset identifier, one of the names listed above.

    Returns:
        A tuple ``(df, duration_col, event_col, time_horizons)`` where
        ``duration_col == "duration"``, ``event_col == "event"``, and
        ``time_horizons`` is a list of the three interior points of five
        evenly spaced values between the minimum and maximum duration.

    Raises:
        ValueError: If ``name`` is not a supported dataset.
    """
    duration_col = "duration"
    event_col = "event"

    pycox_names = ("metabric", "support", "gbsg")
    sksurv_names = ("aids", "flchain", "gbsg2", "whas500")
    supported = pycox_names + ("rossi",) + sksurv_names
    # Without this check an unknown name falls through every branch and
    # fails later with an UnboundLocalError on `df`.
    if name not in supported:
        raise ValueError(
            f"Unknown dataset {name!r}; expected one of {', '.join(supported)}"
        )

    if name in pycox_names:
        # pycox exposes each dataset as a same-named attribute of `datasets`.
        df = getattr(datasets, name).read_df()
    elif name == "rossi":
        df = load_rossi().rename(columns={"week": duration_col, "arrest": event_col})
    else:
        loader = {
            "aids": load_aids,
            "flchain": load_flchain,
            "gbsg2": load_gbsg2,
            "whas500": load_whas500,
        }[name]
        X, Y = loader()
        # Y is a two-field structured array. Casting it to a new structured
        # dtype relabels the fields by position. This assumes sksurv orders
        # them as (event indicator, time) -- TODO confirm per dataset.
        Y_unp = np.array(Y, dtype=[(event_col, "int"), (duration_col, "float")])
        df = X.copy()
        df[event_col] = Y_unp[event_col]
        df[duration_col] = Y_unp[duration_col]

    # Replace string/categorical features with integer codes so every column
    # is numeric.
    for col in df.columns:
        if df[col].dtype.name in ["object", "category"]:
            df[col] = LabelEncoder().fit_transform(df[col])

    # NOTE: runs after label encoding, so categorical columns are already
    # integer-coded when the NaNs are filled.
    df = df.fillna(0)

    # Evaluation horizons: the three interior points of an evenly spaced
    # 5-point grid spanning the observed durations.
    durations = df[duration_col]
    time_horizons = np.linspace(durations.min(), durations.max(), num=5)[1:-1].tolist()

    return df, duration_col, event_col, time_horizons

In [2]:
from synthcity.plugins import Plugins
from synthcity.benchmark import Benchmarks

# Names of all generator plugins registered with synthcity
# (skip_debug=True presumably hides debug-only plugins -- see output below).
plugins = Plugins().list(skip_debug=True)
plugins

['gaussian_copula',
 'survival_tvae',
 'survival_bayesian_network',
 'nflow',
 'bayesian_network',
 'survival_ctgan',
 'adsgan',
 'tvae',
 'survival_gan',
 'privbayes',
 'survival_adsgan',
 'ctgan',
 'copulagan',
 'marginal_distributions',
 'rtvae',
 'survival_nflow',
 'pategan']

In [3]:
# Subset of the registered plugins to benchmark, in evaluation order.
plugins = [
    "survival_gan",
    "privbayes",
    "adsgan",
    "bayesian_network",
    "ctgan",
    "tvae",
    "nflow",
    "survival_ctgan",
    "survival_tvae",
    "survival_bayesian_network",
    "survival_nflow",
]

# Number of benchmark repetitions per plugin (passed as `repeats=` below).
repeats = 3

## AIDS dataset

In [4]:
import pandas as pd

# AIDS: numeric frame plus the column names and horizons the benchmark needs.
(df, duration_col, event_col, time_horizons) = get_dataset("aids")
df

Unnamed: 0,age,cd4,hemophil,ivdrug,karnof,priorzdv,raceth,sex,strat2,tx,txgrp,event,duration
0,34.0,169.0,0,0,0,39.0,0,0,1,0,0,0,189.0
1,34.0,149.5,0,0,3,15.0,1,1,1,0,0,0,287.0
2,20.0,23.5,1,0,0,9.0,0,0,0,1,1,0,242.0
3,48.0,46.0,0,0,3,53.0,0,0,1,0,0,0,199.0
4,46.0,10.0,0,2,3,12.0,0,0,0,1,1,0,286.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1146,44.0,65.5,0,0,0,103.0,0,0,1,1,1,0,273.0
1147,41.0,7.5,0,0,2,20.0,1,0,0,1,1,1,47.0
1148,43.0,170.0,0,2,3,27.0,1,0,1,0,0,0,272.0
1149,44.0,282.5,0,2,2,12.0,0,0,1,0,0,0,192.0


In [5]:
# Benchmark each selected plugin on the AIDS data, `repeats` times,
# requesting as many synthetic rows as there are real ones.
score = Benchmarks.evaluate(
    plugins,
    df,
    repeats=repeats,
    synthetic_size=len(df),
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
)


[2022-05-04T19:51:25.943137+0300][128303][INFO] Benchmarking plugin : survival_gan
[2022-05-04T19:51:25.944488+0300][128303][INFO]  Experiment repeat: 0
[2022-05-04T19:59:23.280813+0300][128303][ERROR] Failed to evaluate synthetic performance. cox_ph: delta contains nan value(s). Convergence halted. Please see the following tips in the lifelines documentation: https://lifelines.readthedocs.io/en/latest/Examples.html#problems-with-convergence-in-the-cox-proportional-hazard-model
[2022-05-04T20:00:30.511435+0300][128303][INFO]  Experiment repeat: 1
[2022-05-04T20:09:40.833245+0300][128303][INFO]  Experiment repeat: 2
[2022-05-04T20:18:39.194353+0300][128303][INFO] Benchmarking plugin : privbayes
[2022-05-04T20:18:39.195158+0300][128303][INFO]  Experiment repeat: 0
[2022-05-04T20:19:49.402978+0300][128303][INFO]  Experiment repeat: 1
[2022-05-04T20:20:59.562638+0300][128303][INFO]  Experiment repeat: 2
[2022-05-04T20:22:09.488974+0300][128303][INFO] Benchmarking plugin : adsgan
[2022-05-0

In [6]:
# Print the AIDS metrics table (rows are metrics, columns are plugins).
Benchmarks.print(score)


[4m[1mComparatives[0m[0m


Unnamed: 0,survival_gan,privbayes,adsgan,bayesian_network,ctgan,tvae,nflow,survival_ctgan,survival_tvae,survival_bayesian_network,survival_nflow
sanity.data_mismatch.score,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
sanity.common_rows_proportion.score,0.0,0.0,0.0,0.046626,0.0,0.0,0.0,0.0,0.0,0.023458,0.0
sanity.nearest_syn_neighbor_distance.mean,0.122971,0.12264,0.088133,0.117401,0.087674,0.082663,0.080699,0.139765,0.24956,0.116652,0.096006
sanity.close_values_probability.score,0.814075,0.886476,0.924414,0.889951,0.965827,0.924703,0.961483,0.757892,0.54764,0.864466,0.916594
sanity.distant_values_probability.score,0.002027,0.001738,0.001158,0.001738,0.001738,0.002606,0.002027,0.001448,0.019114,0.000869,0.001738
stats.jensenshannon_dist.marginal,0.146731,0.097651,0.096583,0.03512,0.121531,0.183452,0.063241,0.19272,0.261684,0.11259,0.130447
stats.chi_squared_test.marginal,0.464525,0.748581,0.47426,0.76151,0.727763,0.222259,0.857147,0.648351,0.143921,0.818769,0.732552
stats.feature_corr.joint,1.994026,2.006168,2.494336,0.732558,1.989845,2.055663,0.685895,2.213278,2.767587,1.533355,1.533333
stats.inv_kl_divergence.marginal,0.834091,0.923741,0.884493,0.987808,0.930694,0.641119,0.971272,0.843153,0.537013,0.902814,0.89372
stats.ks_test.marginal,0.830426,0.909555,0.900688,0.965225,0.866894,0.822028,0.931542,0.767671,0.716434,0.867473,0.840072


## FLChain dataset

In [4]:
import pandas as pd

# FLChain: numeric frame plus the column names and horizons the benchmark needs.
(df, duration_col, event_col, time_horizons) = get_dataset("flchain")
df

Unnamed: 0,age,chapter,creatinine,flc.grp,kappa,lambda,mgus,sample.yr,sex,event,duration
0,97.0,1,1.7,1,5.700,4.860,0,2,0,1,85.0
1,92.0,12,0.9,0,0.870,0.683,0,5,0,1,1281.0
2,94.0,1,1.4,1,4.360,3.850,0,2,0,1,69.0
3,92.0,1,1.0,9,2.420,2.220,0,1,0,1,115.0
4,93.0,1,1.1,6,1.320,1.690,0,1,0,1,1039.0
...,...,...,...,...,...,...,...,...,...,...,...
7869,52.0,16,1.0,6,1.210,1.610,0,0,0,0,4997.0
7870,52.0,16,0.8,0,0.858,0.581,0,4,0,0,3652.0
7871,54.0,16,0.0,8,1.700,1.720,0,7,0,0,2507.0
7872,53.0,16,0.0,9,1.710,2.690,0,0,0,0,4982.0


In [None]:
# Benchmark each selected plugin on the FLChain data, `repeats` times,
# requesting as many synthetic rows as there are real ones.
score = Benchmarks.evaluate(
    plugins,
    df,
    repeats=repeats,
    synthetic_size=len(df),
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
)


[2022-05-04T22:28:41.755781+0300][130765][INFO] Benchmarking plugin : survival_gan
[2022-05-04T22:28:41.757035+0300][130765][INFO]  Experiment repeat: 0
[2022-05-04T23:20:14.730509+0300][130765][ERROR] Failed to evaluate synthetic performance. cox_ph: delta contains nan value(s). Convergence halted. Please see the following tips in the lifelines documentation: https://lifelines.readthedocs.io/en/latest/Examples.html#problems-with-convergence-in-the-cox-proportional-hazard-model
[2022-05-04T23:24:30.350751+0300][130765][INFO]  Experiment repeat: 1
[2022-05-05T00:16:06.799151+0300][130765][ERROR] Failed to evaluate synthetic performance. cox_ph: delta contains nan value(s). Convergence halted. Please see the following tips in the lifelines documentation: https://lifelines.readthedocs.io/en/latest/Examples.html#problems-with-convergence-in-the-cox-proportional-hazard-model
[2022-05-05T00:20:31.198485+0300][130765][INFO]  Experiment repeat: 2
[2022-05-05T01:12:10.043415+0300][130765][ERROR

In [None]:
# Print the FLChain metrics table (rows are metrics, columns are plugins).
Benchmarks.print(score)

## GBSG2 dataset

In [None]:
import pandas as pd

# GBSG2: numeric frame plus the column names and horizons the benchmark needs.
(df, duration_col, event_col, time_horizons) = get_dataset("gbsg2")
df

In [None]:
# Benchmark each selected plugin on the GBSG2 data, `repeats` times,
# requesting as many synthetic rows as there are real ones.
score = Benchmarks.evaluate(
    plugins,
    df,
    repeats=repeats,
    synthetic_size=len(df),
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
)


In [None]:
# Print the GBSG2 metrics table (rows are metrics, columns are plugins).
Benchmarks.print(score)

## METABRIC dataset

In [None]:
import pandas as pd

# METABRIC: the dataset name must match this section's heading. Loading
# "gbsg2" here would silently re-benchmark the GBSG2 data.
df, duration_col, event_col, time_horizons = get_dataset("metabric")

df

In [None]:
# Benchmark each selected plugin on the data loaded in this section,
# `repeats` times, requesting as many synthetic rows as there are real ones.
score = Benchmarks.evaluate(
    plugins,
    df,
    repeats=repeats,
    synthetic_size=len(df),
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
)


In [None]:
# Print this section's metrics table (rows are metrics, columns are plugins).
Benchmarks.print(score)

## GBSG dataset

In [None]:
import pandas as pd

# GBSG (pycox): numeric frame plus the column names and horizons the benchmark needs.
(df, duration_col, event_col, time_horizons) = get_dataset("gbsg")
df

In [None]:
# Benchmark each selected plugin on the GBSG data, `repeats` times,
# requesting as many synthetic rows as there are real ones.
score = Benchmarks.evaluate(
    plugins,
    df,
    repeats=repeats,
    synthetic_size=len(df),
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
)


In [None]:
# Print the GBSG metrics table (rows are metrics, columns are plugins).
Benchmarks.print(score)

## SUPPORT dataset

In [None]:
import pandas as pd

# SUPPORT: numeric frame plus the column names and horizons the benchmark needs.
(df, duration_col, event_col, time_horizons) = get_dataset("support")
df

In [None]:
# Benchmark each selected plugin on the SUPPORT data, `repeats` times,
# requesting as many synthetic rows as there are real ones.
score = Benchmarks.evaluate(
    plugins,
    df,
    repeats=repeats,
    synthetic_size=len(df),
    task_type="survival_analysis",
    target_column=event_col,
    time_to_event_column=duration_col,
    time_horizons=time_horizons,
)


In [None]:
# Print the SUPPORT metrics table (rows are metrics, columns are plugins).
Benchmarks.print(score)