### Augmentation performance benchmarks for tabular data

Imports

In [None]:
import pickle
import pandas as pd
from pathlib import Path

from synthcity.plugins.core.dataloader import GenericDataLoader
from synthcity.benchmark import Benchmarks
import synthcity.logger as log

Set parameters for model

In [None]:
# Register "synthcity_logs" with the synthcity logger at the "INFO" level.
# NOTE(review): presumably this is a log-file sink in the working directory —
# confirm against synthcity.logger.add.
log.add("synthcity_logs", "INFO")

# Keyword arguments handed to every benchmarked model.
KWARGS = dict(n_iter=100)
# Flat "key:value-key:value" rendering of KWARGS, used when naming result files.
KWARGS_str = "-".join(f"{name}:{value}" for name, value in KWARGS.items())

Main functions

In [None]:

def preprocess_data(file_path, file, time_horizon=14):
    """Load a CSV and derive a binary "dead within the time horizon" target.

    The file ``{file_path}/{file}`` is read with pandas and a new integer
    column named ``is_dead_at_time_horizon={time_horizon}`` is added:

    * 1 when ``is_dead == 1`` and ``Days_hospital_to_outcome <= time_horizon``
    * 0 when ``Days_hospital_to_outcome > time_horizon``
    * 0 when ``is_dead == 0``

    The ``is_dead`` and ``Days_hospital_to_outcome`` columns are then dropped.

    Args:
        file_path: Directory containing the CSV.
        file: Name of the CSV file inside ``file_path``.
        time_horizon: Threshold compared against ``Days_hospital_to_outcome``.

    Returns:
        The loaded DataFrame with the target column added and the two source
        columns removed.

    NOTE(review): a row matched by none of the three rules is left unset
    before the integer cast; pandas is expected to reject that cast — confirm
    the input never contains such rows.
    """
    frame = pd.read_csv(f"{file_path}/{file}")

    target = f"is_dead_at_time_horizon={time_horizon}"
    days_to_outcome = frame["Days_hospital_to_outcome"]
    died = frame["is_dead"]

    # Positive label: died, and the outcome fell inside the horizon.
    frame.loc[(days_to_outcome <= time_horizon) & (died == 1), target] = 1
    # Negative label: outcome after the horizon, or the patient did not die.
    frame.loc[days_to_outcome > time_horizon, target] = 0
    frame.loc[died == 0, target] = 0

    frame[target] = frame[target].astype(int)

    # The raw outcome columns are not returned alongside the derived target.
    return frame.drop(columns=["is_dead", "Days_hospital_to_outcome"])

def run_dataset(loader, workspace_path, models):
    """Run the augmentation benchmark for ``models`` on ``loader``.

    Calls ``Benchmarks.evaluate`` on the loader's train and test splits with
    the module-level ``KWARGS`` for every model, requesting as many synthetic
    rows as the loader's dataframe has, and prints the resulting score.

    Args:
        loader: Data loader providing ``train()``, ``test()`` and
            ``dataframe()``.
        workspace_path: Passed to ``Benchmarks.evaluate`` as ``workspace``.
        models: Iterable of model names to benchmark.

    Returns:
        Whatever ``Benchmarks.evaluate`` returns, or ``None`` if any exception
        was raised (the exception is printed, not re-raised).
    """
    try:
        # The model name fills both of the first two slots of each entry.
        tests = [(name, name, KWARGS) for name in models]
        train_split = loader.train()
        test_split = loader.test()
        n_synthetic = loader.dataframe().shape[0]
        requested_metrics = {
            "performance": [
                "linear_model_augmentation",
                "mlp_augmentation",
                "xgb_augmentation",
            ],
        }
        score = Benchmarks.evaluate(
            tests,
            train_split,
            test_split,
            task_type="classification",
            synthetic_size=n_synthetic,
            synthetic_reuse_if_exists=False,
            augmented_reuse_if_exists=False,
            augmentation_rule="equal",  # equal, log, or ad-hoc
            metrics=requested_metrics,
            workspace=workspace_path,
            repeats=1,
            device="cpu",
        )
        print(score)
        return score
    except Exception as err:
        # Best effort: report the failure and let the caller see None.
        print("\n\n", err)
        return None

def create_absolute_path(cwd, path):
    """Resolve ``path`` against ``cwd`` and return it as a "/"-joined string.

    If the last component of ``cwd`` is not one of the recognised directory
    names, resolution starts one level above ``cwd``. After resolving, any
    path component that already appeared earlier in the path is dropped, so
    only the first occurrence of each component name survives.

    Args:
        cwd: Base directory (``pathlib.Path`` or string).
        path: Path, typically relative, to resolve against the base.

    Returns:
        The resolved path as a string with "/" separators and repeated
        components removed.
    """
    cwd = Path(cwd)
    # NOTE(review): "synthcity-benckmarking" looks like a misspelling of
    # "synthcity-benchmarking"; left unchanged because it affects which base
    # directory is chosen — confirm the intended directory name.
    if cwd.name not in ["tutorials", "tests", "synthcity-benckmarking"]:
        cwd = cwd / Path("../")
    resolved = (cwd / path).resolve()
    # Path objects have no .split(); work on the "/"-separated string form.
    # dict.fromkeys keeps the first occurrence of each component, in order.
    unique_parts = dict.fromkeys(resolved.as_posix().split("/"))
    return "/".join(unique_parts)

def run_synthcity(models=None, save=False):
    """Run the augmentation benchmark on the COVID dataset.

    Resolves the data, workspace and results directories relative to the
    current working directory, preprocesses the CSV into a 14-day mortality
    target, wraps it in a ``GenericDataLoader`` and benchmarks ``models`` via
    ``run_dataset``. When a score is produced it is printed and highlighted,
    and optionally pickled into the results directory.

    Args:
        models: List of model names to benchmark. Defaults to ``["ctgan"]``.
        save: When true, pickle the score into the results directory.

    Returns:
        None.
    """
    # Avoid a mutable default argument shared between calls.
    if models is None:
        models = ["ctgan"]

    cwd = Path.cwd()
    file_path = create_absolute_path(cwd, "../data/augmentation/")
    workspace_path = create_absolute_path(cwd, "../workspace/augmentation/")
    result_path = create_absolute_path(cwd, "../results/augmentation")
    Path(result_path).mkdir(parents=True, exist_ok=True)

    # Load and prepare the data.
    # The preprocessed CSV is expected in the data/augmentation folder under
    # the following name.
    file = "covid_normalised_numericalised.csv"
    print(f"{file_path}/{file}")

    time_horizon = 14  # time horizon for prediction set to 14 days
    X = preprocess_data(file_path, file, time_horizon=time_horizon)

    loader = GenericDataLoader(
        X,
        target_column=f"is_dead_at_time_horizon={time_horizon}",
        sensitive_features=["Ethnicity"],
        fairness_column="Ethnicity",
        domain_column="Ethnicity",
        random_state=42,
    )

    score = run_dataset(loader, workspace_path, models)

    # run_dataset returns None on failure; nothing to report in that case.
    if score:
        Benchmarks.print(score)
        Benchmarks.highlight(score)
        if save:
            output_file = f"{result_path}/{file}-{'-'.join(models)}-{KWARGS_str}.pkl"
            with open(output_file, "wb") as f:
                pickle.dump(score, f)



In [None]:
# Benchmark the "ctgan" and "tvae" models; results are not saved (save defaults to False).
run_synthcity(models=["ctgan", "tvae"])