In [1]:
from utils.config import load_datasets, get_train_dataset_indexes, get_test_dataset_indexes

folder = "../outputs/datasets"

dataset_indexes = get_train_dataset_indexes(folder)

datasets = load_datasets(folder=folder, names=dataset_indexes)


In [2]:
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import SurvivalAnalysisDataLoader


# https://github.com/vanderschaarlab/synthcity/issues/249
def run_surv_gan(data_loader, device):
    model = Plugins().get("survival_gan", device=device)
    model.fit(data_loader)
    return model
    
def run_surv_vae(data_loader, device):
    model = Plugins().get("survae", device=device)
    model.fit(data_loader)
    return model

ModuleNotFoundError: No module named 'synthcity'

In [None]:
from synthcity.metrics.eval_sanity import CloseValuesProbability, DataMismatchScore, CommonRowsProportion, NearestSyntheticNeighborDistance, DistantValuesProbability

def evaluate_model(data_loader, generated_data):
    # Initialize metrics
    close = CloseValuesProbability()
    data_mismatch = DataMismatchScore()
    proportion = CommonRowsProportion()
    nn_distance = NearestSyntheticNeighborDistance()
    distant = DistantValuesProbability()
    
    # Evaluate metrics
    close_val = close.evaluate(data_loader, generated_data)['score']
    mis = data_mismatch.evaluate(data_loader, generated_data)['score']
    prop = proportion.evaluate(data_loader, generated_data)['score']
    nn_dist = nn_distance.evaluate(data_loader, generated_data)['mean']
    dist = distant.evaluate(data_loader, generated_data)['score']
    
    # Determine if metrics are within expected values
    correct = True

    # Define correctness based on descriptions
    if not (0 <= close_val <= 1):
        correct = False
    if not (0 <= mis <= 1):
        correct = False
    if not (0 <= prop <= 1):
        correct = False
    if nn_dist < 0:
        correct = False
    if not (0 <= dist <= 1):
        correct = False
    
    # Prepare results
    results = {
        'close_values': {
            "value": close_val,
            "description": "0 means there is no chance to have synthetic rows similar to the real. 1 means that all the synthetic rows are similar to some real rows."
        },
        'data_mismatch': {
            "value": mis,
            "description": "0: no datatype mismatch. 1: complete data type mismatch between the datasets."
        },
        'proportion': {
            "value": prop,
            "description": "0: there are no common rows between the real and synthetic datasets. 1: all the rows in the real dataset are leaked in the synthetic dataset.",
        },
        'nn_distance': {
            "value": nn_dist,
            "description": "Computes the distance from the real data to the closest neighbor in the synthetic data"
        },
        'distant_values': {
            "value": dist,
            "description": "0 means there is no chance to have rows in the synthetic far away from the real data. 1 means all the synthetic datapoints are far away from the real data."
        }
    }
    
    return results, correct


In [10]:
from utils.preprocess import impute_missing_values
from synthcity.utils.serialization import save_to_file
from utils.config import save_dataset, save_checkpoint, load_checkpoint
import torch
start_index = load_checkpoint()
device = torch.device('cuda')

for index, dataset_index in enumerate(dataset_indexes):
    if index <= start_index:
        continue
    print(f"training model on {dataset_index}")
    ds_train = datasets[dataset_index]
    ds_train = impute_missing_values(ds_train)
    # print(ds_train.head())
    data_loader = SurvivalAnalysisDataLoader(ds_train, target_column="event", time_to_event_column="time")
    surv_gan_model = run_surv_gan(data_loader, device)
    surv_vae_model = run_surv_vae(data_loader, device)
    
    save_to_file(f"../outputs/model_outputs/sim_model_{dataset_index}_gan.pkl", surv_gan_model)
    save_to_file(f"../outputs/model_outputs/sim_model_{dataset_index}_vae.pkl", surv_vae_model)
    
    generated_data_gan = surv_gan_model.generate(5000)
    generated_data_vae = surv_vae_model.generate(5000)
    
    _, eval_gan = evaluate_model(data_loader, generated_data_gan)
    _, eval_vae = evaluate_model(data_loader, generated_data_vae)
    
    print(f"training completed, eval: gan:{eval_gan} vae:{eval_vae}")
    
    generated_data_vae = generated_data_vae.dataframe()
    generated_data_gan = generated_data_gan.dataframe()
    
    save_dataset(generated_data_gan, f"{dataset_index}_gan","../outputs/generated_datasets")
    save_dataset(generated_data_vae, f"{dataset_index}_vae","../outputs/generated_datasets")
    save_checkpoint(index)
    

[2024-08-11T17:48:28.727445+0200][47092][CRITICAL] module disabled: C:\Users\johan\anaconda3\envs\research_project_simulation_env\lib\site-packages\synthcity\plugins\generic\plugin_goggle.py
[2024-08-11T17:48:28.888376+0200][47092][CRITICAL] module disabled: C:\Users\johan\anaconda3\envs\research_project_simulation_env\lib\site-packages\synthcity\plugins\generic\plugin_goggle.py


training model on cost_train


  5%|▍         | 499/10000 [00:23<07:23, 21.40it/s]
[2024-08-11T17:48:53.035792+0200][47092][CRITICAL] module disabled: C:\Users\johan\anaconda3\envs\research_project_simulation_env\lib\site-packages\synthcity\plugins\generic\plugin_goggle.py
[2024-08-11T17:48:53.066791+0200][47092][CRITICAL] module disabled: C:\Users\johan\anaconda3\envs\research_project_simulation_env\lib\site-packages\synthcity\plugins\generic\plugin_goggle.py
 40%|████      | 400/1000 [00:17<00:26, 22.63it/s]
[2024-08-11T17:49:12.705438+0200][47092][CRITICAL] module disabled: C:\Users\johan\anaconda3\envs\research_project_simulation_env\lib\site-packages\synthcity\plugins\generic\plugin_goggle.py
[2024-08-11T17:49:12.805151+0200][47092][CRITICAL] module disabled: C:\Users\johan\anaconda3\envs\research_project_simulation_env\lib\site-packages\synthcity\plugins\generic\plugin_goggle.py


training completed, eval: gan:True vae:True
training model on d.oropha.rec_train


ValueError: Need `time` to have same type as `self.durations`.