In [1]:
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import synthcity.logger as log
from sklearn.mixture import GaussianMixture as GMM

from datasets import get_dataset

warnings.filterwarnings("ignore", category=FutureWarning)

log.remove()
warnings.filterwarnings("ignore")
log.add(sink=sys.stderr, level="DEBUG")

In [2]:
from synthcity.plugins.core.models.time_to_event import \
    get_model_template as get_tte_model_template


class TabularGMM:
    def __init__(self, components: int = 100, random_state: int = 0):
        self.model = GMM(100, covariance_type="full", random_state=random_state)
        self.tte_regressor = get_tte_model_template("survival_function_regression")()

    def fit(self, X, T, E):
        self.model.fit(X)
        self.tte_regressor.fit(X, T, E)

        self.E = E
        self.count = len(X)
        self.columns = X.columns

    def generate(self, count: int = None):
        if count is None:
            count = self.count
        sampled, _ = self.model.sample(count)
        sampled = pd.DataFrame(sampled, columns=self.columns)

        E = self.E.reset_index(drop=True).head(count)

        T = pd.Series(self.tte_regressor.predict_any(sampled, E,))
        return sampled, T, E

<stdin>:1:10: fatal error: cuda.h: No such file or directory
compilation terminated.

<stdin>:1:10: fatal error: cuda.h: No such file or directory
compilation terminated.

<stdin>:1:10: fatal error: cuda.h: No such file or directory
compilation terminated.



In [3]:
import numpy as np
import scipy
from scipy.stats import norm


def censored_nll(params, durations, events):
    mu, std_dev = params
    observed_durations = durations[events == 1]  # or events==True
    censored_durations = durations[~(events == 1)]  # or events==True
    return -(
        np.log(1e-8 + norm.pdf(observed_durations, loc=mu, scale=std_dev)).sum()
        + np.log(1e-8 + norm.sf(censored_durations, loc=mu, scale=std_dev)).sum()
    )


def negative_log_likelihood(T, E):
    return scipy.optimize.minimize(
        censored_nll, method="Nelder-Mead", x0=np.array([0, 1]), args=(T, E)
    ).fun

In [14]:
from pathlib import Path

import tabulate
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation
from adjutorium.utils.metrics import generate_score, print_score
from synthcity.plugins.core.dataloader import SurvivalAnalysisDataLoader
from synthcity.utils.serialization import (dataframe_hash, load_from_file,
                                           save_to_file)

log.remove()

out_dir = Path("workspace")
headers = [
    "dataset",
    "method",
    "degrees of freedom/covariates",
    "with outcome",
    "censored",
    "log_likelihood",
    "log_likelihood_ratio_test",
]

for ref_df in ["aids", "cutract", "maggic", "seer"]:
    distances = []

    print("=======================")
    print("Evaluate ", ref_df)

    df, duration_col, event_col, time_horizons = get_dataset(ref_df)
    df_hash = dataframe_hash(df)
    model = RiskEstimation().get("cox_ph")

    X = df.drop(columns=[duration_col, event_col])
    T = df[duration_col]
    E = df[event_col]

    model.fit(X, T, E)

    ref_log_likelihood = -model.model.model.log_likelihood_
    ref_log_likelihood_ratio_test = (
        model.model.model.log_likelihood_ratio_test().test_statistic
    )
    censored = (E == 0).sum()
    outcome = (E == 1).sum()

    distances.append(
        [
            ref_df,
            "Real data",
            X.shape[1],
            censored,
            outcome,
            ref_log_likelihood,
            ref_log_likelihood_ratio_test,
        ]
    )

    for method in ["survival_gan", "ctgan", "nflow", "tvae", "privbayes"]:
        local_log_likelihood = []
        local_log_likelihood_ratio = []
        local_censored = []
        local_outcome = []
        for seed in range(3):
            model_bkp = out_dir / f"{df_hash}_{method}_{seed}.bkp"
            syn_df = load_from_file(model_bkp)

            Xsyn = syn_df.drop(columns=[duration_col, event_col])
            Tsyn = syn_df[duration_col]
            Esyn = syn_df[event_col]
            model = RiskEstimation().get("cox_ph")
            model.fit(Xsyn, Tsyn, Esyn)

            log_likelihood = -model.model.model.log_likelihood_
            log_likelihood_ratio_test = (
                model.model.model.log_likelihood_ratio_test().test_statistic
            )
            # other_log_likelihood = negative_log_likelihood(syn_df[duration_col], syn_df[event_col])

            local_log_likelihood.append(log_likelihood)
            local_log_likelihood_ratio.append(log_likelihood_ratio_test)
            local_censored.append((Esyn == 0).sum())
            local_outcome.append((Esyn == 1).sum())

        log_likelihood_str = print_score(generate_score(local_log_likelihood))
        log_likelihood_ratio_test_str = print_score(
            generate_score(local_log_likelihood_ratio)
        )
        censored_str = print_score(generate_score(local_censored))
        outcome_str = print_score(generate_score(local_outcome))

        distances.append(
            [
                ref_df,
                method,
                X.shape[1],
                censored_str,
                outcome_str,
                log_likelihood_str,
                log_likelihood_ratio_test_str,
            ]
        )
    tabulate.tabulate(distances, headers=headers, tablefmt="html")

distances

Evaluate  aids
summary 725.1744256224594 87.07914805031237
summary 579.7611765235978 160.9820293857499
summary 605.8748153791545 151.21944345920042
summary 1088.9391273014364 19.842472995845583
summary 876.3413535156371 15.202155522317753
summary 1047.4136412176324 67.99492827240738
summary 2256.745217037288 27.689669766745283
summary 2598.165643151686 42.71808069144117
summary 3070.908661414584 85.23526189618315
summary -0.0 0.0
summary 29.090882696880946 6.368913339048561
summary 6.627024835734486 0.6475798654161977
summary 130.4083230622987 6.937359196501745
summary 739.3935437802996 3.7917565861569074
summary 489.4890599234387 14.118437238718911
Evaluate  cutract
summary 8409.027030284571 1675.5190982276254
summary 8439.319123212346 1613.9388135831941
summary 16881.155979047864 4948.247420701933
summary 48079.6871022308 4547.000117689124
summary 48994.96792428096 2569.000571900673
summary 46082.64325910512 4007.1992522378277
summary 56413.043634462636 70.19856599083869
summary 5305

[['aids', 'Real data', 11, 1055, 96, 621.926143124486, 73.03276504216637],
 ['aids',
  'survival_gan',
  11,
  '1049.667 +/- 7.858',
  '101.333 +/- 7.858',
  '636.937 +/- 71.628',
  '133.094 +/- 37.094'],
 ['aids',
  'ctgan',
  11,
  '980.667 +/- 18.019',
  '170.333 +/- 18.019',
  '1004.231 +/- 104.116',
  '34.347 +/- 27.01'],
 ['aids',
  'nflow',
  11,
  '713.333 +/- 65.139',
  '437.667 +/- 65.139',
  '2641.94 +/- 377.752',
  '51.881 +/- 27.577'],
 ['aids',
  'tvae',
  11,
  '1149.0 +/- 2.445',
  '2.0 +/- 2.445',
  '11.906 +/- 14.087',
  '2.339 +/- 3.239'],
 ['aids',
  'privbayes',
  11,
  '1077.333 +/- 45.652',
  '73.667 +/- 45.652',
  '453.097 +/- 282.84',
  '8.283 +/- 4.891'],
 ['cutract', 'Real data', 6, 8881, 1205, 9788.128760249894, 1325.219160236502],
 ['cutract',
  'survival_gan',
  6,
  '8664.0 +/- 597.733',
  '1422.0 +/- 597.733',
  '11243.167 +/- 4511.352',
  '2745.902 +/- 1762.473'],
 ['cutract',
  'ctgan',
  6,
  '4208.333 +/- 153.733',
  '5877.667 +/- 153.733',
  '47719.

dataset,method,degrees of freedom/covariates,with outcome,censored,log_likelihood,log_likelihood_ratio_test
aids,Real data,11,1055,96,621.926143124486,73.03276504216637
aids,survival_gan,11,1049.667 +/- 7.858,101.333 +/- 7.858,636.937 +/- 71.628,133.094 +/- 37.094
aids,ctgan,11,980.667 +/- 18.019,170.333 +/- 18.019,1004.231 +/- 104.116,34.347 +/- 27.01
aids,nflow,11,713.333 +/- 65.139,437.667 +/- 65.139,2641.94 +/- 377.752,51.881 +/- 27.577
aids,tvae,11,1149.0 +/- 2.445,2.0 +/- 2.445,11.906 +/- 14.087,2.339 +/- 3.239
aids,privbayes,11,1077.333 +/- 45.652,73.667 +/- 45.652,453.097 +/- 282.84,8.283 +/- 4.891
cutract,Real data,6,8881,1205,9788.128760249894,1325.219160236502
cutract,survival_gan,6,8664.0 +/- 597.733,1422.0 +/- 597.733,11243.167 +/- 4511.352,2745.902 +/- 1762.473
cutract,ctgan,6,4208.333 +/- 153.733,5877.667 +/- 153.733,47719.099 +/- 1376.016,3707.733 +/- 944.685
cutract,nflow,6,3351.667 +/- 239.561,6734.333 +/- 239.561,55333.749 +/- 1824.276,112.582 +/- 40.116


In [7]:
## metabric

from pathlib import Path

import tabulate
from adjutorium.plugins.prediction.risk_estimation import RiskEstimation
from adjutorium.utils.metrics import generate_score, print_score
from synthcity.plugins.core.dataloader import SurvivalAnalysisDataLoader
from synthcity.utils.serialization import (dataframe_hash, load_from_file,
                                           save_to_file)

log.remove()

out_dir = Path("workspace_rebuttal")
headers = [
    "dataset",
    "method",
    "degrees of freedom/covariates",
    "with outcome",
    "censored",
    "log_likelihood",
    "log_likelihood_ratio_test",
]

for ref_df in ["metabric"]:
    distances = []

    print("=======================")
    print("Evaluate ", ref_df)

    df, duration_col, event_col, time_horizons = get_dataset(ref_df)
    df_hash = dataframe_hash(df)
    model = RiskEstimation().get("cox_ph")

    X = df.drop(columns=[duration_col, event_col])
    T = df[duration_col]
    E = df[event_col]

    model.fit(X, T, E)

    ref_log_likelihood = -model.model.model.log_likelihood_
    ref_log_likelihood_ratio_test = (
        model.model.model.log_likelihood_ratio_test().test_statistic
    )
    censored = (E == 0).sum()
    outcome = (E == 1).sum()

    distances.append(
        [
            ref_df,
            "Real data",
            X.shape[1],
            censored,
            outcome,
            ref_log_likelihood,
            ref_log_likelihood_ratio_test,
        ]
    )

    for method in ["survival_gan", "ctgan", "tvae"]:
        local_log_likelihood = []
        local_log_likelihood_ratio = []
        local_censored = []
        local_outcome = []
        for seed in range(3):
            model_bkp = out_dir / f"{df_hash}_{method}_{method}__{seed}.bkp"
            syn_df = load_from_file(model_bkp)
            try:
                syn_df = syn_df.dataframe()
            except:
                pass

            Xsyn = syn_df.drop(columns=[duration_col, event_col])
            Tsyn = syn_df[duration_col]
            Esyn = syn_df[event_col]
            model = RiskEstimation().get("cox_ph")
            model.fit(Xsyn, Tsyn, Esyn)

            log_likelihood = -model.model.model.log_likelihood_
            log_likelihood_ratio_test = (
                model.model.model.log_likelihood_ratio_test().test_statistic
            )
            # other_log_likelihood = negative_log_likelihood(syn_df[duration_col], syn_df[event_col])

            local_log_likelihood.append(log_likelihood)
            local_log_likelihood_ratio.append(log_likelihood_ratio_test)
            local_censored.append((Esyn == 0).sum())
            local_outcome.append((Esyn == 1).sum())

        log_likelihood_str = print_score(generate_score(local_log_likelihood))
        log_likelihood_ratio_test_str = print_score(
            generate_score(local_log_likelihood_ratio)
        )
        censored_str = print_score(generate_score(local_censored))
        outcome_str = print_score(generate_score(local_outcome))

        distances.append(
            [
                ref_df,
                method,
                X.shape[1],
                censored_str,
                outcome_str,
                log_likelihood_str,
                log_likelihood_ratio_test_str,
            ]
        )
    tabulate.tabulate(distances, headers=headers, tablefmt="html")

distances

Evaluate  metabric


[['metabric',
  'Real data',
  689,
  609,
  484,
  2023.0801846856543,
  1370.5835433178845],
 ['metabric',
  'survival_gan',
  689,
  '537.333 +/- 15.743',
  '336.667 +/- 15.743',
  '1345.862 +/- 141.743',
  '790.366 +/- 428.547'],
 ['metabric',
  'ctgan',
  689,
  '577.0 +/- 37.429',
  '297.0 +/- 37.429',
  '1189.339 +/- 188.613',
  '1046.335 +/- 40.113'],
 ['metabric',
  'tvae',
  689,
  '564.0 +/- 39.233',
  '310.0 +/- 39.233',
  '1204.746 +/- 367.507',
  '1047.265 +/- 244.918']]

In [8]:
import tabulate

tabulate.tabulate(distances, headers=headers, tablefmt="html")

dataset,method,degrees of freedom/covariates,with outcome,censored,log_likelihood,log_likelihood_ratio_test
metabric,Real data,689,609,484,2023.0801846856543,1370.5835433178845
metabric,survival_gan,689,537.333 +/- 15.743,336.667 +/- 15.743,1345.862 +/- 141.743,790.366 +/- 428.547
metabric,ctgan,689,577.0 +/- 37.429,297.0 +/- 37.429,1189.339 +/- 188.613,1046.335 +/- 40.113
metabric,tvae,689,564.0 +/- 39.233,310.0 +/- 39.233,1204.746 +/- 367.507,1047.265 +/- 244.918
