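## Aggregating results into a DataFrame

Collects the CatBoost test score of every generation method on every dataset into one table. The score is F1 for classification datasets and R² for regression datasets, and `--` marks a missing evaluation file.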

In [1]:
import os
import lib
import numpy as np
import pandas as pd

DATASETS = [
    "abalone",
    "adult",
    "buddy",
    "california",
    "cardio",
    "churn2",
    "default",
    "diabetes",
    "fb-comments",
    "gesture",
    "higgs-small",
    "house",
    "insurance",
    "king",
    "miniboone",
    "wilt"
]

_REGRESSION = [
    "abalone",
    "california",
    "fb-comments",
    "house",
    "insurance",
    "king",
]


# Experiment directory template per method; "{}" is replaced by the dataset name.
# Insertion order here determines the row order of the results table.
# "real" deliberately points at the tab-ddpm directory: that eval file is read
# under a "real" key for this row and under a "synthetic" key for tab-ddpm.
method2exp = dict([
    ("real", "exp/{}/ddpm_cb_best/"),
    ("tab-ddpm", "exp/{}/ddpm_cb_best/"),
    ("smote", "exp/{}/smote/"),
    ("ctabgan+", "exp/{}/ctabgan-plus/"),
    ("ctabgan", "exp/{}/ctabgan/"),
    ("tvae", "exp/{}/tvae/"),
])

# Name of the evaluation JSON expected inside each experiment directory.
eval_file = "eval_catboost.json"
# When True, each cell also shows the standard deviation as "mean+-std".
show_std = False
# Build the results table: one row per method in `method2exp`, one column per
# dataset (first three letters of its name, upper-cased). Each cell holds the
# test score formatted to 4 decimals, or "--" when the eval file is missing.
# Rows are collected in a list and turned into a DataFrame once, rather than
# growing the frame row-by-row with `df.loc[len(df)] = ...`.
score_columns = [ds[:3].upper() for ds in DATASETS]
# Column labels are 3-letter abbreviations; fail loudly if two datasets collide
# instead of silently producing duplicate column labels.
assert len(set(score_columns)) == len(score_columns), (
    f"ambiguous dataset abbreviations: {score_columns}"
)

rows = []
for algo, exp_template in method2exp.items():
    algo_res = []
    # The "real" row reads scores stored under the "real" key; every other
    # method reads its scores under the "synthetic" key.
    split = "real" if algo == "real" else "synthetic"
    for ds in DATASETS:
        eval_path = os.path.join(exp_template.format(ds), eval_file)
        if not os.path.exists(eval_path):
            algo_res.append("--")
            continue
        metric = "r2" if ds in _REGRESSION else "f1"
        test_scores = lib.load_json(eval_path)[split]["test"]
        res = f'{test_scores[metric + "-mean"]:.4f}'
        if show_std:
            res += f'+-{test_scores[metric + "-std"]:.4f}'
        algo_res.append(res)
    rows.append([algo] + algo_res)

df = pd.DataFrame(rows, columns=["method"] + score_columns)

In [2]:
# Display the aggregated table: one row per method in `method2exp`, one column
# per dataset abbreviation.
df

Unnamed: 0,method,ABA,ADU,BUD,CAL,CAR,CHU,DEF,DIA,FB-,GES,HIG,HOU,INS,KIN,MIN,WIL
0,real,0.5562,0.8152,0.9063,0.8568,0.7379,0.7403,0.688,0.7849,0.8371,0.6365,0.7238,0.6616,0.8137,0.9070,0.9342,0.8982
1,tab-ddpm,0.5499,0.7951,0.9057,0.8362,0.7374,0.7548,0.691,0.7398,0.7128,0.5967,0.7218,0.6766,0.8092,0.8331,0.9362,0.9045
2,smote,0.5486,0.7912,0.8906,0.8397,0.7323,0.7432,0.693,0.6835,0.8035,0.6579,0.7219,0.6625,0.8119,0.8416,0.9323,0.9127
3,ctabgan+,0.4672,0.7724,0.8844,0.5247,0.7327,0.7024,0.6865,0.7339,0.5088,0.4055,0.6639,0.5040,0.7966,0.4438,0.892,0.7983
4,ctabgan,--,0.7831,0.8552,--,0.7171,0.6875,0.6437,0.731,--,0.3922,0.5748,--,--,--,0.8892,0.906
5,tvae,0.4328,0.781,0.8638,0.7518,0.7174,0.7317,0.6564,0.7136,0.6853,0.434,0.6378,0.4926,0.7842,0.8238,0.9125,0.5006
