In [3]:
import sys
from pathlib import Path

import synthcity.logger as log
from synthcity.benchmark import Benchmarks
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import SurvivalAnalysisDataLoader
from synthcity.utils.serialization import load_from_file, save_to_file

from datasets import get_dataset

log.remove()
log.add(sink=sys.stderr, level="INFO")

out_dir = Path("output")


def evaluate_dataset(dataset: str, plugin: tuple, repeats: int = 5):
    df, duration_col, event_col, time_horizons = get_dataset(dataset)
    dataloader = SurvivalAnalysisDataLoader(
        df,
        target_column=event_col,
        time_to_event_column=duration_col,
        time_horizons=time_horizons,
    )

    score = Benchmarks.evaluate(
        [plugin],
        dataloader,
        task_type="survival_analysis",
        target_column=event_col,
        time_to_event_column=duration_col,
        time_horizons=time_horizons,
        synthetic_size=10 * len(df),
        repeats=repeats,
        metrics={"performance": ["linear_model", "xgb"],},
        workspace=Path("workspace_rebuttal"),
    )

    return score

In [4]:
evaluate_dataset("aids", ("survival_gan", "survival_gan", {}))

[2022-08-07T10:39:41.099059+0300][61974][INFO] Testcase : survival_gan
[2022-08-07T10:39:41.099645+0300][61974][INFO] workspace_rebuttal/1772407149569308346_survival_gan_survival_gan__generator_0.bkp
[2022-08-07T10:39:41.100084+0300][61974][INFO] [testcase] Experiment repeat: 0 task type: survival_analysis Train df hash = 1772407149569308346
[2022-08-07T10:39:41.143944+0300][61974][INFO]  Performance eval for df hash = 2678763277623207393 ood hash = 9102943594045962678
[2022-08-07T10:39:41.732212+0300][61974][INFO] Baseline performance score: {'c_index': (0.7411857900611897, 0.07439925818906981), 'brier_score': (0.0638745778490171, 0.004673590915155082)}
[2022-08-07T10:39:42.167747+0300][61974][INFO] Synthetic ID performance score: {'c_index': (0.6422237426127403, 0.08429824568371114), 'brier_score': (0.15447922527488733, 0.02682873132400233)}
[2022-08-07T10:39:42.466661+0300][61974][INFO] Synthetic OOD performance score: {'c_index': (0.7187026024931593, 0.0976476408474261), 'brier_sco





[2022-08-07T10:43:55.368146+0300][61974][INFO] Baseline performance score: {'c_index': (0.7471664251463902, 0.02464876187281145), 'brier_score': (0.0642595906341933, 0.0031067638293718197)}


[2022-08-07T10:44:01.623980+0300][61974][INFO] Synthetic ID performance score: {'c_index': (0.7321883432214492, 0.009905974172991888), 'brier_score': (0.06293135467123516, 0.0028256114554899557)}


[2022-08-07T10:44:11.771706+0300][61974][INFO] Synthetic OOD performance score: {'c_index': (0.6945300903598842, 0.12369699109610874), 'brier_score': (0.06351847660332295, 0.011159395413589592)}


[2022-08-07T10:44:11.791989+0300][61974][INFO] workspace_rebuttal/1772407149569308346_survival_gan_survival_gan__generator_3.bkp
[2022-08-07T10:44:11.792521+0300][61974][INFO] [testcase] Experiment repeat: 3 task type: survival_analysis Train df hash = 1772407149569308346
[2022-08-07T10:44:11.794337+0300][61974][INFO] 
            Training SurvivalGAN using
                dataloader_sampling_strategy = imbalanced_time_censoring;
                tte_strategy = survival_function;
                uncensoring_model=survival_function_regression
                censoring_strategy = random
                device=cuda
            
[2022-08-07T10:44:11.802625+0300][61974][INFO] Using imbalanced time and censoring sampling
[2022-08-07T10:44:12.244741+0300][61974][INFO] Train the uncensoring model
[2022-08-07T10:44:15.694823+0300][61974][INFO] Train the synthetic generator
[2022-08-07T10:44:16.340929+0300][61974][INFO] Training GAN on device cuda. features = 57
[2022-08-07T10:46:34.145092+0300][

[2022-08-07T10:49:46.748618+0300][61974][ERROR] [cd4] quality loss for constraints ge = 0.0. Remaining 11126. prev length 11287. Original dtype float64.
[2022-08-07T10:49:47.436991+0300][61974][ERROR] [duration] quality loss for constraints le = 364.0. Remaining 11478. prev length 11510. Original dtype float64.
[2022-08-07T10:49:47.500907+0300][61974][INFO]  Performance eval for df hash = 7387946097098023449 ood hash = 4574549252857847638
[2022-08-07T10:49:48.076471+0300][61974][INFO] Baseline performance score: {'c_index': (0.7343020843880352, 0.0635506795719404), 'brier_score': (0.06404864023388862, 0.003001440507468054)}
[2022-08-07T10:49:48.499754+0300][61974][INFO] Synthetic ID performance score: {'c_index': (0.6206730871679912, 0.06344932760082479), 'brier_score': (0.1122089445197915, 0.018894808098772402)}
[2022-08-07T10:49:48.794161+0300][61974][INFO] Synthetic OOD performance score: {'c_index': (0.7180606195055469, 0.13044289902545717), 'brier_score': (0.10204648973947915, 0.0

{'survival_gan':                                                    min       max      mean  \
 performance.linear_model.gt.c_index           0.734302  0.741186  0.736023   
 performance.linear_model.gt.brier_score       0.063875  0.064049  0.064005   
 performance.linear_model.syn_id.c_index       0.620673  0.697232  0.657809   
 performance.linear_model.syn_id.brier_score   0.101031  0.154479  0.123823   
 performance.linear_model.syn_ood.c_index      0.653226  0.718703  0.688131   
 performance.linear_model.syn_ood.brier_score  0.102046  0.149606  0.128743   
 performance.xgb.gt.c_index                    0.727303  0.759002  0.745561   
 performance.xgb.gt.brier_score                0.064082  0.064260  0.064207   
 performance.xgb.syn_id.c_index                0.582190  0.732188  0.679951   
 performance.xgb.syn_id.brier_score            0.062931  0.064932  0.063604   
 performance.xgb.syn_ood.c_index               0.639837  0.694530  0.659042   
 performance.xgb.syn_ood.brier_score









In [24]:
import pandas as pd
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation
from adjutorium.utils.metrics import generate_score, print_score
from adjutorium.utils.tester import evaluate_survival_estimator
from sklearn.model_selection import train_test_split
from synthcity.utils.serialization import dataframe_hash, load_from_file, save_to_file

out_dir = Path("workspace")


def eval_dataset(dataset: str):
    df, duration_col, event_col, time_horizons = get_dataset(dataset)
    df_hash = dataframe_hash(df)

    cindex = []
    brier_score = []

    for seed in range(5):
        train_df, test_df = train_test_split(df, random_state=seed)

        model_bkp = out_dir / f"{df_hash}_survival_gan_{seed}.bkp"
        syn_df = load_from_file(model_bkp)

        eval_df = pd.concat([syn_df, train_df], ignore_index=True)
        Xtrain = eval_df.drop(columns=[duration_col, event_col])
        Ttrain = eval_df[duration_col]
        Etrain = eval_df[event_col]

        Xtest = test_df.drop(columns=[duration_col, event_col])
        Ttest = test_df[duration_col]
        Etest = test_df[event_col]

        model = RiskEstimation().get("survival_xgboost")
        model.fit(Xtrain, Ttrain, Etrain)

        score = evaluate_survival_estimator(
            [model] * 3,
            Xtest,
            Ttest,
            Etest,
            pretrained=True,
            time_horizons=time_horizons,
        )
        print(score["str"])
        cindex.append(score["clf"]["c_index"][0])
        brier_score.append(score["clf"]["brier_score"][0])

    return print_score(generate_score(cindex)), print_score(generate_score(brier_score))

In [25]:
eval_dataset("aids")

{'c_index': '0.742 +/- 0.117', 'brier_score': '0.06 +/- 0.015', 'aucroc': '0.749 +/- 0.058', 'sensitivity': '0.03 +/- 0.024', 'specificity': '0.995 +/- 0.008', 'PPV': '0.148 +/- 0.157', 'NPV': '0.915 +/- 0.001', 'predicted_cases': '1.333 +/- 1.411'}
{'c_index': '0.679 +/- 0.052', 'brier_score': '0.057 +/- 0.014', 'aucroc': '0.69 +/- 0.096', 'sensitivity': '0.048 +/- 0.044', 'specificity': '0.998 +/- 0.002', 'PPV': '0.222 +/- 0.178', 'NPV': '0.925 +/- 0.003', 'predicted_cases': '1.333 +/- 1.067'}
{'c_index': '0.722 +/- 0.15', 'brier_score': '0.049 +/- 0.01', 'aucroc': '0.734 +/- 0.047', 'sensitivity': '0.0 +/- 0.0', 'specificity': '0.997 +/- 0.005', 'PPV': '0.0 +/- 0.0', 'NPV': '0.923 +/- 0.003', 'predicted_cases': '0.333 +/- 0.533'}
{'c_index': '0.742 +/- 0.068', 'brier_score': '0.064 +/- 0.004', 'aucroc': '0.781 +/- 0.089', 'sensitivity': '0.06 +/- 0.066', 'specificity': '0.985 +/- 0.007', 'PPV': '0.157 +/- 0.141', 'NPV': '0.907 +/- 0.008', 'predicted_cases': '3.667 +/- 2.667'}
{'c_in

('0.72 +/- 0.02', '0.058 +/- 0.004')

In [26]:
eval_dataset("cutract")

{'c_index': '0.827 +/- 0.016', 'brier_score': '0.087 +/- 0.01', 'aucroc': '0.868 +/- 0.012', 'sensitivity': '0.34 +/- 0.022', 'specificity': '0.951 +/- 0.023', 'PPV': '0.51 +/- 0.037', 'NPV': '0.835 +/- 0.011', 'predicted_cases': '116.667 +/- 6.289'}
{'c_index': '0.823 +/- 0.009', 'brier_score': '0.089 +/- 0.008', 'aucroc': '0.855 +/- 0.014', 'sensitivity': '0.314 +/- 0.007', 'specificity': '0.944 +/- 0.013', 'PPV': '0.504 +/- 0.042', 'NPV': '0.823 +/- 0.005', 'predicted_cases': '114.0 +/- 6.468'}
{'c_index': '0.824 +/- 0.022', 'brier_score': '0.085 +/- 0.009', 'aucroc': '0.863 +/- 0.021', 'sensitivity': '0.39 +/- 0.017', 'specificity': '0.939 +/- 0.013', 'PPV': '0.484 +/- 0.02', 'NPV': '0.855 +/- 0.009', 'predicted_cases': '138.667 +/- 8.072'}
{'c_index': '0.833 +/- 0.006', 'brier_score': '0.081 +/- 0.002', 'aucroc': '0.87 +/- 0.005', 'sensitivity': '0.345 +/- 0.013', 'specificity': '0.945 +/- 0.002', 'PPV': '0.491 +/- 0.011', 'NPV': '0.839 +/- 0.002', 'predicted_cases': '122.333 +/- 

('0.824 +/- 0.006', '0.085 +/- 0.003')

In [27]:
eval_dataset("maggic")

{'c_index': '0.649 +/- 0.006', 'brier_score': '0.179 +/- 0.004', 'aucroc': '0.796 +/- 0.018', 'sensitivity': '0.867 +/- 0.002', 'specificity': '0.529 +/- 0.064', 'PPV': '0.935 +/- 0.001', 'NPV': '0.182 +/- 0.005', 'predicted_cases': '3305.667 +/- 13.272'}
{'c_index': '0.643 +/- 0.012', 'brier_score': '0.19 +/- 0.023', 'aucroc': '0.803 +/- 0.002', 'sensitivity': '0.853 +/- 0.013', 'specificity': '0.544 +/- 0.013', 'PPV': '0.932 +/- 0.003', 'NPV': '0.173 +/- 0.013', 'predicted_cases': '3209.333 +/- 47.269'}
{'c_index': '0.653 +/- 0.013', 'brier_score': '0.18 +/- 0.014', 'aucroc': '0.803 +/- 0.008', 'sensitivity': '0.903 +/- 0.001', 'specificity': '0.431 +/- 0.023', 'PPV': '0.933 +/- 0.001', 'NPV': '0.207 +/- 0.002', 'predicted_cases': '3482.667 +/- 2.134'}


FileNotFoundError: [Errno 2] No such file or directory: 'workspace/4879234145147014154_survival_gan_3.bkp'

In [20]:
eval_dataset("seer")

KeyboardInterrupt: 