In [1]:
!export LD_LIBRARY_PATH=/usr/local/lib64:$LD_LIBRARY_PATH 

import sys

from datasets import get_dataset
from synthcity.plugins import Plugins
from synthcity.benchmark import Benchmarks
from pathlib import Path
from synthcity.utils.serialization import save_to_file, load_from_file
import synthcity.logger as log

log.add(sink=sys.stderr, level="INFO")

plugins = Plugins().list()

out_dir = Path("output")

metric = "detection_mlp"
cat = "detection"

def evaluate_dataset(name: str, dataset: str, plugins: list):
    df, duration_col, event_col, time_horizons = get_dataset(dataset)

    for plugin in plugins:
        bkp = out_dir / f"{cat}.{metric}.{dataset}_{name}_{plugin}.bkp"

        if bkp.exists():
            score = load_from_file(bkp)
        else:
            score = Benchmarks.evaluate(
                [plugin],
                df,
                task_type="survival_analysis",
                target_column=event_col,
                time_to_event_column=duration_col,
                time_horizons=time_horizons,
                synthetic_size=len(df),
                repeats=repeats,
                metrics = {
                    cat: [metric],
                },
            )
            save_to_file(bkp, score)

    return score

In [2]:
base_plugins = [
    "privbayes",
    "adsgan",
    "ctgan",
    "tvae",
    "nflow",
]
survival_plugins = [
    "survival_gan",
]

repeats = 3 #5

## AIDS dataset

In [3]:
base_score = evaluate_dataset("baseline", "aids", base_plugins)
Benchmarks.print(base_score)

[2022-05-24T09:38:07.094935+0300][124247][INFO] Benchmarking plugin : privbayes
[2022-05-24T09:38:07.098614+0300][124247][INFO]  Experiment repeat: 0 task type: survival_analysis Train df hash = 7600251698133035800
[2022-05-24T09:38:35.156840+0300][124247][INFO]  Experiment repeat: 1 task type: survival_analysis Train df hash = 4871573768128818830
[2022-05-24T09:39:01.663281+0300][124247][INFO]  Experiment repeat: 2 task type: survival_analysis Train df hash = 5947425095989233042
[2022-05-24T09:39:28.061866+0300][124247][INFO] Benchmarking plugin : adsgan
[2022-05-24T09:39:28.064808+0300][124247][INFO]  Experiment repeat: 0 task type: survival_analysis Train df hash = 7600251698133035800
[2022-05-24T09:39:54.915448+0300][124247][INFO]  Experiment repeat: 1 task type: survival_analysis Train df hash = 4871573768128818830
[2022-05-24T09:40:21.283266+0300][124247][INFO]  Experiment repeat: 2 task type: survival_analysis Train df hash = 5947425095989233042
[2022-05-24T09:40:47.924992+0300]


[4m[1mPlugin : nflow[0m[0m


Unnamed: 0,min,max,mean,stddev,median,iqr,rounds,errors,durations
detection.detection_mlp.mean,0.895673,0.935881,0.914321,0.016543,0.91141,0.020104,3,0,26.82





In [None]:
survival_score = evaluate_dataset("survival", "aids", survival_plugins)
Benchmarks.print(survival_score)

[2022-05-24T09:44:48.859220+0300][124247][INFO] Benchmarking plugin : survival_gan
[2022-05-24T09:44:48.862369+0300][124247][INFO]  Experiment repeat: 0 task type: survival_analysis Train df hash = 7600251698133035800
[2022-05-24T09:45:15.894345+0300][124247][INFO]  Experiment repeat: 1 task type: survival_analysis Train df hash = 4871573768128818830


## CUTRACT

In [None]:
    
base_score = evaluate_dataset("baseline", "cutract", base_plugins)
Benchmarks.print(base_score)

In [None]:
survival_score = evaluate_dataset("survival", "cutract", survival_plugins)
Benchmarks.print(survival_score)

## MAGGIC dataset 

In [None]:
base_score = evaluate_dataset("baseline", "maggic", base_plugins)
Benchmarks.print(base_score)

In [None]:
survival_score = evaluate_dataset("survival", "maggic", survival_plugins)
Benchmarks.print(survival_score)

## SEER prostate 

In [None]:
base_score = evaluate_dataset("baseline", "seer", base_plugins)
Benchmarks.print(base_score)

In [None]:
survival_score = evaluate_dataset("survival", "seer", survival_plugins)
Benchmarks.print(survival_score)