In [28]:
import fairlearn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from fairlearn.datasets import fetch_diabetes_hospital
from sdv.utils import load_synthesizer
import pickle

from scipy.stats import ks_2samp, chi2_contingency
from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, average_precision_score
from sklearn.feature_selection import mutual_info_classif

import os
from pathlib import Path

In [2]:
# Fetch the dataset as pandas objects (as_frame=True).
data = fetch_diabetes_hospital(as_frame=True)

# Work on copies so the fetched bunch itself is never mutated.
X = data.data.copy()
y = data.target.copy()

# Remove both readmission columns from the feature frame
# (presumably so no label information stays among the features).
dropped_columns = ['readmitted', 'readmit_binary']
X = X.drop(columns=dropped_columns)

# Re-attach the label as a boolean column: True where the target equals 1.
real_data = X.copy()
real_data['readmit_binary'] = (y == 1)

# 80/20 split with a fixed seed, stratified on the label column.
real_train, real_test = train_test_split(
    real_data,
    test_size=0.2,
    random_state=66,
    stratify=real_data['readmit_binary']
)

# Give both splits a fresh 0..n-1 index.
real_train = real_train.reset_index(drop=True)
real_test = real_test.reset_index(drop=True)

# Last expression: shown as the cell output.
real_train.shape, real_test.shape

((81412, 23), (20354, 23))

In [6]:
# Path of the pickled model, relative to the notebook's working directory.
# NOTE(review): "copuula" looks like a typo, but it must match the file name on disk — confirm before renaming.
gc_path = Path("../artifacts/gaussian_copuula_diabetes.pkl")

In [10]:
def load_model_syn_data(model_path, sample_len):
    """Load a pickled model and draw a synthetic sample from it.

    Parameters
    ----------
    model_path : pathlib.Path
        Location of the pickle file holding the model.
    sample_len : int
        Number of rows requested through ``model.sample(num_rows=...)``.

    Returns
    -------
    tuple
        ``(model, synthetic_dataset)``.

    Raises
    ------
    FileNotFoundError
        If ``model_path`` does not exist. Without this check the function
        would fall through and return ``None``, so a caller unpacking the
        result would fail with an unrelated ``TypeError``.
    """
    if not model_path.exists():
        raise FileNotFoundError(f"Model file not found: {model_path}")

    # SECURITY: pickle.load can execute arbitrary code stored in the file;
    # only load artifacts that come from a trusted source.
    with model_path.open("rb") as f:
        model = pickle.load(f)

    synthetic_dataset = model.sample(num_rows=sample_len)
    return model, synthetic_dataset

In [11]:
# Load the model stored at gc_path and sample as many synthetic rows as there are real training rows.
gc_model, gc_gendata = load_model_syn_data(gc_path, len(real_train))
# Preview the first synthetic rows.
gc_gendata.head()

Unnamed: 0,race,gender,age,discharge_disposition_id,admission_source_id,time_in_hospital,medical_specialty,num_lab_procedures,num_procedures,num_medications,...,A1Cresult,insulin,change,diabetesMed,medicare,medicaid,had_emergency,had_inpatient_days,had_outpatient_days,readmit_binary
0,Caucasian,Male,'Over 60 years','Discharged to Home',Emergency,4,Missing,54,3,9,...,,Steady,Ch,Yes,True,False,False,True,False,False
1,Caucasian,Male,'30 years or younger','Discharged to Home',Emergency,6,Family/GeneralPractice,29,0,15,...,,No,Ch,Yes,False,False,False,False,True,False
2,Caucasian,Male,'30-60 years','Discharged to Home',Emergency,4,Emergency/Trauma,11,0,5,...,,Steady,No,Yes,True,False,False,False,False,False
3,Caucasian,Male,'30-60 years','Discharged to Home',Emergency,2,Missing,6,6,8,...,,No,Ch,No,False,False,False,True,False,True
4,Unknown,Female,'Over 60 years','Discharged to Home',Other,6,InternalMedicine,53,0,27,...,,No,Ch,Yes,False,False,False,False,True,False


In [12]:
# Synthetic datasets to evaluate, keyed by the name used in the result tables.
synthetic_datasets = {
    "GaussianCopula": gc_gendata,
}

In [13]:
Target = 'readmit_binary'

def infer_column_types(df: pd.DataFrame, target: str):
    """Split the feature columns of ``df`` into numeric and categorical names.

    The ``target`` column is excluded. Boolean columns are classed as
    categorical even though pandas reports them as numeric; every other
    non-numeric column is categorical as well.

    Returns
    -------
    tuple[list, list]
        ``(numeric_columns, categorical_columns)``, each in ``df`` column order.
    """
    numeric_columns, categorical_columns = [], []

    for name in df.columns:
        if name == target:
            continue

        series = df[name]
        # The bool check comes first: bool dtypes also pass the numeric check.
        is_plain_number = (
            not pd.api.types.is_bool_dtype(series)
            and pd.api.types.is_numeric_dtype(series)
        )
        bucket = numeric_columns if is_plain_number else categorical_columns
        bucket.append(name)

    return numeric_columns, categorical_columns

In [14]:
# Classify the real training features once; the label column (Target) is excluded.
numeric_columns, categorical_columns = infer_column_types(real_train, Target)
# Show how many columns landed in each group.
len(numeric_columns), len(categorical_columns)

(5, 17)

In [15]:
# Show every numeric column name and the first ten categorical ones.
numeric_columns[:], categorical_columns[:10]

(['time_in_hospital',
  'num_lab_procedures',
  'num_procedures',
  'num_medications',
  'number_diagnoses'],
 ['race',
  'gender',
  'age',
  'discharge_disposition_id',
  'admission_source_id',
  'medical_specialty',
  'primary_diagnosis',
  'max_glu_serum',
  'A1Cresult',
  'insulin'])

In [22]:
def as_safe_string_series(s : pd.Series) -> pd.Series:
    """Return ``s`` as strings, with every missing entry replaced by "MISSING"."""
    as_object = s.astype('object')
    # Overwrite the missing entries before the string cast so they get an
    # explicit label instead of the text form of NaN/None.
    labelled = as_object.mask(as_object.isna(), "MISSING")
    return labelled.astype(str)

In [37]:
# Statistical similarity
def ks_similarity_table(real_data, syn_data, numeric_columns):
    """Compare numeric columns of two frames with the two-sample KS test.

    Parameters
    ----------
    real_data, syn_data : pd.DataFrame
        Frames that both contain every name in ``numeric_columns``.
    numeric_columns : list
        Columns to compare. Values are coerced with ``pd.to_numeric`` and
        non-convertible or missing entries are dropped.

    Returns
    -------
    pd.DataFrame
        One row per compared column (KS statistic, p-value, means and sample
        standard deviations), sorted by ``ks_stat`` descending. Columns with
        fewer than 10 usable values on either side are skipped. When nothing
        is compared, an empty frame with the same column names is returned.
    """
    result_columns = [
        "column", "ks_stat", "ks_pvalue",
        "real_mean", "syn_mean", "real_std", "syn_std",
    ]

    rows = []
    for col in numeric_columns:
        r = pd.to_numeric(real_data[col], errors="coerce").dropna()
        s = pd.to_numeric(syn_data[col], errors="coerce").dropna()

        # Too few observations to compare.
        if len(r) < 10 or len(s) < 10:
            continue

        stat, p = ks_2samp(r, s)
        rows.append({
            "column": col,
            "ks_stat": float(stat),
            "ks_pvalue": float(p),
            "real_mean": float(r.mean()),
            "syn_mean": float(s.mean()),
            "real_std": float(r.std(ddof=1)),
            "syn_std": float(s.std(ddof=1)),
        })

    # Explicit columns keep sort_values from raising KeyError when rows is empty.
    out = pd.DataFrame(rows, columns=result_columns)
    return out.sort_values("ks_stat", ascending=False)

In [17]:
def summarize_ks(ks_df):
    """Reduce a KS table to the mean, median and maximum of its ``ks_stat`` column.

    Returns an empty dict when ``ks_df`` has no rows.
    """
    if ks_df.empty:
        return {}

    stats = ks_df["ks_stat"]
    reducers = (
        ("ks_mean", stats.mean),
        ("ks_median", stats.median),
        ("ks_worst", stats.max),
    )
    return {label: float(reduce()) for label, reduce in reducers}

In [18]:
def tvd(p : pd.Series, q : pd.Series) -> float:
    """Half the summed absolute difference between two label-indexed series.

    Labels present in only one series count as 0.0 in the other.
    """
    support = p.index.union(q.index)
    p_full, q_full = (dist.reindex(support, fill_value=0.0) for dist in (p, q))
    gap = (p_full - q_full).abs().sum()
    return 0.5 * float(gap)

In [49]:
def chi2_and_tvd_table(real_data, syn_data, cat_cols):
    """Compare categorical columns of two frames via chi-square and TVD.

    For each column, both sides are converted with ``as_safe_string_series``
    (missing values become their own category) and counted over the union of
    observed categories. The 2 x k count table is passed to
    ``chi2_contingency``, and the normalised counts are passed to ``tvd``.

    Parameters
    ----------
    real_data, syn_data : pd.DataFrame
        Frames that both contain every name in ``cat_cols``.
    cat_cols : list
        Columns to compare.

    Returns
    -------
    pd.DataFrame
        One row per column, sorted by ``tvd`` descending. When ``cat_cols``
        is empty, an empty frame with the same column names is returned.
    """
    result_columns = [
        "column", "chi2_stat", "chi2_pvalue", "tvd",
        "real_unique", "syn_unique",
        "missing_rate_real", "missing_rate_synth",
    ]

    rows = []
    for col in cat_cols:
        r = as_safe_string_series(real_data[col])
        s = as_safe_string_series(syn_data[col])

        # Count both sides over the same, sorted category list so the two
        # rows of the contingency table line up.
        all_cats = sorted(set(r.unique()).union(set(s.unique())))
        r_counts = r.value_counts().reindex(all_cats, fill_value=0)
        s_counts = s.value_counts().reindex(all_cats, fill_value=0)

        table = np.vstack([r_counts.values, s_counts.values])

        # Only the statistic and p-value are reported.
        chi2, p, _, _ = chi2_contingency(table)

        r_dist = (r_counts / r_counts.sum())
        s_dist = (s_counts / s_counts.sum())

        rows.append({
            "column" : col,
            "chi2_stat" : float(chi2),
            "chi2_pvalue" : float(p),
            "tvd": float(tvd(r_dist, s_dist)),
            "real_unique" : int(r.nunique()),
            "syn_unique" : int(s.nunique()),
            "missing_rate_real" : float(real_data[col].isna().mean()),
            "missing_rate_synth" : float(syn_data[col].isna().mean())
        })

    # Explicit columns keep sort_values from raising KeyError when rows is empty.
    out = pd.DataFrame(rows, columns=result_columns)
    return out.sort_values("tvd", ascending=False)

In [50]:
def summarize_cat(cat_df):
    """Reduce a categorical table to the mean, median and maximum of its ``tvd`` column.

    Returns an empty dict when ``cat_df`` has no rows.
    """
    if cat_df.empty:
        return {}

    distances = cat_df["tvd"]
    reducers = (
        ("tvd_mean", distances.mean),
        ("tvd_median", distances.median),
        ("tvd_worst", distances.max),
    )
    return {label: float(reduce()) for label, reduce in reducers}

In [51]:
def numeric_corr_drift(real_df, syn_df, numeric_cols):
    """Mean absolute difference between the two Pearson correlation matrices.

    Only off-diagonal entries are averaged. Undefined correlations are
    treated as 0.0. With fewer than two columns the result is NaN.

    Returns
    -------
    dict
        ``{"corr_mae": value}``.
    """
    if len(numeric_cols) < 2:
        return {"corr_mae": np.nan}

    def _pearson(frame):
        # Coerce to numeric first so unparsable entries become NaN.
        coerced = frame[numeric_cols].apply(pd.to_numeric, errors="coerce")
        return coerced.corr(method="pearson").fillna(0.0)

    gap = (_pearson(real_df) - _pearson(syn_df)).abs()

    # Drop the diagonal from the average.
    off_diagonal = ~np.eye(gap.shape[0], dtype=bool)
    return {"corr_mae": float(gap.values[off_diagonal].mean())}

In [60]:
# Downstream ML utility
def make_clf_pipline(numeric_columns, categorical_columns, random_state=66):
    """Build the preprocessing + logistic-regression pipeline used for utility scoring.

    Parameters
    ----------
    numeric_columns : list
        Columns that are median-imputed and standardised.
    categorical_columns : list
        Columns that are mode-imputed and one-hot encoded. Categories not
        seen during fit are ignored at transform time.
    random_state : int or None, default 66
        Seed passed to ``LogisticRegression``. The ``saga`` solver shuffles
        the data, so without a seed the fitted model, and every metric
        derived from it, changes between runs. Pass ``None`` to get the
        unseeded behaviour.

    Returns
    -------
    sklearn.pipeline.Pipeline
        Unfitted pipeline with steps ``"pre"`` and ``"clf"``. Columns not
        listed in either group are dropped.
    """
    numeric_pipe = Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="median")),
        ("scaler", StandardScaler())
    ])

    cat_pipe = Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="most_frequent")),
        ("onehot", OneHotEncoder(handle_unknown="ignore"))
    ])

    pre = ColumnTransformer(
        transformers=[
            ("num", numeric_pipe, numeric_columns),
            ("cat", cat_pipe, categorical_columns),
        ],
        remainder='drop'
    )

    clf = LogisticRegression(
        max_iter=2000,
        solver='saga',
        n_jobs=-1,
        class_weight='balanced',
        random_state=random_state,
    )

    return Pipeline(steps=[
        ("pre", pre),
        ("clf", clf)
    ])

In [61]:
def eval_train_test(train_df, test_df, target, numeric_columns, categorical_columns):
    """Fit the classifier pipeline on ``train_df`` and score it on ``test_df``.

    The ``target`` column is cast to int and used as the label; all other
    columns are handed to the pipeline. Hard predictions use a 0.5 cut-off
    on the probability of the second class.

    Returns
    -------
    dict
        ``roc_auc``, ``avg_precision``, ``accuracy`` and ``f1`` as floats.
    """
    def _labels_and_features(frame):
        labels = frame[target].astype(int)
        features = frame.drop(columns=[target])
        return labels, features

    y_fit, X_fit = _labels_and_features(train_df)
    y_eval, X_eval = _labels_and_features(test_df)

    model = make_clf_pipline(numeric_columns, categorical_columns)
    model.fit(X_fit, y_fit)

    positive_scores = model.predict_proba(X_eval)[:, 1]
    hard_labels = (positive_scores >= 0.5).astype(int)

    # Ranking metrics take the scores; threshold metrics take the hard labels.
    metric_specs = (
        ("roc_auc", roc_auc_score, positive_scores),
        ("avg_precision", average_precision_score, positive_scores),
        ("accuracy", accuracy_score, hard_labels),
        ("f1", f1_score, hard_labels),
    )
    return {
        label: float(metric(y_eval, estimate))
        for label, metric, estimate in metric_specs
    }

In [62]:
def numeric_out_of_range_table(real_df, syn_df, numeric_columns):
    """Report how often synthetic numeric values fall outside the real min/max.

    Parameters
    ----------
    real_df, syn_df : pd.DataFrame
        Frames that both contain every name in ``numeric_columns``.
    numeric_columns : list
        Columns to check. Values are coerced with ``pd.to_numeric``.

    Returns
    -------
    pd.DataFrame
        One row per column with the real range and the rate and count of
        synthetic values outside it, sorted by rate descending. Columns with
        no usable real values are skipped. If a column has no usable
        synthetic values, its rate is NaN and its count is 0. When nothing is
        checked, an empty frame with the same column names is returned.
    """
    result_columns = [
        "column", "real_min", "real_max",
        "syn_out_of_range_rate", "syn_out_of_range_count",
    ]

    rows = []
    for col in numeric_columns:
        r_valid = pd.to_numeric(real_df[col], errors="coerce").dropna()
        if r_valid.empty:
            continue

        rmin, rmax = float(r_valid.min()), float(r_valid.max())

        s_valid = pd.to_numeric(syn_df[col], errors="coerce").dropna()
        out_mask = (s_valid < rmin) | (s_valid > rmax)
        has_values = len(s_valid) > 0

        rows.append({
            "column" : col,
            "real_min" : rmin,
            "real_max" : rmax,
            "syn_out_of_range_rate" : float(out_mask.mean()) if has_values else np.nan,
            "syn_out_of_range_count" : int(out_mask.sum()) if has_values else 0
        })

    # Explicit columns keep sort_values from raising KeyError when rows is empty.
    out = pd.DataFrame(rows, columns=result_columns)
    return out.sort_values("syn_out_of_range_rate", ascending=False)

In [63]:
def categorical_unseen_table(real_df, syn_df, categorical_columns):
    """Report synthetic category values that never occur in the real data.

    Both sides are converted with ``as_safe_string_series``, so missing
    values are compared as their own category.

    Parameters
    ----------
    real_df, syn_df : pd.DataFrame
        Frames that both contain every name in ``categorical_columns``.
    categorical_columns : list
        Columns to check.

    Returns
    -------
    pd.DataFrame
        One row per column with the share of synthetic rows holding an
        unseen value, the number of distinct unseen values and up to ten
        examples, sorted by ``syn_unseen_rate`` descending. When
        ``categorical_columns`` is empty, an empty frame with the same
        column names is returned.
    """
    result_columns = [
        "column", "syn_unseen_rate", "n_unseen_categories", "unseen_examples",
    ]

    rows = []
    for col in categorical_columns:
        r = as_safe_string_series(real_df[col])
        s = as_safe_string_series(syn_df[col])

        real_set = set(r.unique())
        syn_set = set(s.unique())

        unseen = syn_set - real_set
        unseen_rate = float((~s.isin(list(real_set))).mean())

        rows.append({
            "column" : col,
            "syn_unseen_rate" : unseen_rate,
            "n_unseen_categories" : int(len(unseen)),
            "unseen_examples" : ", ".join(sorted(unseen)[:10])
        })

    # Explicit columns keep sort_values from raising KeyError when rows is empty.
    out = pd.DataFrame(rows, columns=result_columns)
    return out.sort_values("syn_unseen_rate", ascending=False)

In [64]:
def missingness_change_table(real_df, syn_df):
    """Compare per-column missing-value rates of two frames.

    Returns
    -------
    pd.DataFrame
        Indexed by column name, with ``missing_rate_real``,
        ``missing_rate_synth`` and their absolute difference ``abs_diff``,
        sorted by ``abs_diff`` descending.
    """
    rate_real = real_df.isna().mean()
    rate_synth = syn_df.isna().mean()
    gap = (rate_real - rate_synth).abs()

    comparison = pd.DataFrame({
        "missing_rate_real" : rate_real,
        "missing_rate_synth" : rate_synth,
        "abs_diff" : gap
    })
    return comparison.sort_values("abs_diff", ascending=False)

In [65]:
def rare_category_collapse_table(real_df, syn_df, categorical_columns, rare_thresh=0.01):
    """Check whether categories that are rare in the real data survive in the synthetic data.

    A category is "rare" when its relative frequency in the real column is
    below ``rare_thresh``. Both sides are converted with
    ``as_safe_string_series`` before counting.

    Parameters
    ----------
    real_df, syn_df : pd.DataFrame
        Frames that both contain every name in ``categorical_columns``.
    categorical_columns : list
        Columns to check.
    rare_thresh : float, default 0.01
        Frequency below which a real category counts as rare.

    Returns
    -------
    pd.DataFrame
        One row per column that has at least one rare category, sorted by
        ``rare_recall`` ascending. Each row holds the share of rare
        categories present in the synthetic data, the rare probability mass
        on each side, the drop in mass, and up to ten rare categories missing
        from the synthetic data. When no column has a rare category, an empty
        frame with the same column names is returned.
    """
    result_columns = [
        "column", "n_rare_real", "rare_recall",
        "rare_mass_real", "rare_mass_syn", "rare_mass_drop",
        "missing_rare_examples",
    ]

    rows = []
    for col in categorical_columns:
        r = as_safe_string_series(real_df[col])
        s = as_safe_string_series(syn_df[col])

        r_freq = r.value_counts(normalize=True)
        s_freq = s.value_counts(normalize=True)

        rare_cats = r_freq[r_freq < rare_thresh].index.tolist()
        if len(rare_cats) == 0:
            continue

        # Synthetic frequency of each rare category; 0.0 when it never appears.
        s_rare = s_freq.reindex(rare_cats, fill_value=0.0)

        n_rare = len(rare_cats)
        n_rare_present = int((s_rare > 0).sum())
        # n_rare >= 1 here because columns without rare categories were skipped.
        recall = n_rare_present / n_rare

        rare_mass_real = float(r_freq.reindex(rare_cats, fill_value=0.0).sum())
        rare_mass_syn = float(s_rare.sum())

        missing_rare = [c for c in rare_cats if s_freq.get(c, 0.0) == 0.0]

        rows.append({
            "column": col,
            "n_rare_real": n_rare,
            "rare_recall": float(recall),
            "rare_mass_real": rare_mass_real,
            "rare_mass_syn": rare_mass_syn,
            "rare_mass_drop": float(rare_mass_real - rare_mass_syn),
            "missing_rare_examples": ", ".join(missing_rare[:10])
        })

    # Explicit columns keep sort_values from raising KeyError when rows is
    # empty, which happens whenever no column has a rare category.
    out = pd.DataFrame(rows, columns=result_columns)
    return out.sort_values("rare_recall", ascending=True)

In [66]:
def evaluate_one_model(
    name : str,
    real_train : pd.DataFrame,
    real_test : pd.DataFrame,
    syn_train : pd.DataFrame,
    target : str
):
    """Run every fidelity and utility check for one synthetic dataset.

    ``syn_train`` is first reindexed to the column order of ``real_train``.
    Column types are inferred from ``real_train``.

    Returns
    -------
    dict
        ``summary`` (flat dict of headline numbers), the detail tables
        (``ks_table``, ``cat_table``, ``out_of_range_table``,
        ``unseen_category_table``, ``missingness_table``,
        ``rare_collapse_table``) and the two classifier score dicts:
        ``tstr`` (fit on synthetic, scored on ``real_test``) and ``rtst``
        (fit on ``real_train``, scored on synthetic).
    """
    def _column_mean_or_nan(table, column):
        # An empty detail table has nothing to average.
        if table.empty:
            return np.nan
        return float(table[column].mean())

    syn_aligned = syn_train.reindex(columns=real_train.columns)
    num_cols, cat_cols = infer_column_types(real_train, target)

    # Distribution-level similarity.
    ks_table = ks_similarity_table(real_train, syn_aligned, num_cols)
    cat_table = chi2_and_tvd_table(real_train, syn_aligned, cat_cols)
    corr_summary = numeric_corr_drift(real_train, syn_aligned, num_cols)

    # Downstream utility; tstr is fitted before rtst.
    tstr = eval_train_test(syn_aligned, real_test, target, num_cols, cat_cols)
    rtst = eval_train_test(real_train, syn_aligned, target, num_cols, cat_cols)

    # Validity checks.
    oor_table = numeric_out_of_range_table(real_train, syn_aligned, num_cols)
    unseen_table = categorical_unseen_table(real_train, syn_aligned, cat_cols)
    miss_table = missingness_change_table(real_train, syn_aligned)
    rare_table = rare_category_collapse_table(real_train, syn_aligned, cat_cols, rare_thresh=0.01)

    summary = {"model": name, "n_rows": int(len(syn_aligned))}
    summary.update(summarize_ks(ks_table))
    summary.update(summarize_cat(cat_table))
    summary.update(corr_summary)
    summary.update({
        "tstr_roc_auc": tstr["roc_auc"],
        "tstr_accuracy": tstr["accuracy"],
        "tstr_f1": tstr["f1"],
        "rtst_roc_auc": rtst["roc_auc"],
        "rtst_accuracy": rtst["accuracy"],
        "rtst_f1": rtst["f1"],
        "avg_out_of_range_rate_num": _column_mean_or_nan(oor_table, "syn_out_of_range_rate"),
        "avg_unseen_rate_cat": _column_mean_or_nan(unseen_table, "syn_unseen_rate"),
        "mean_abs_missingness_diff": float(miss_table["abs_diff"].mean()),
        "mean_rare_recall": _column_mean_or_nan(rare_table, "rare_recall"),
    })

    return {
        "summary": summary,
        "ks_table": ks_table,
        "cat_table": cat_table,
        "out_of_range_table": oor_table,
        "unseen_category_table": unseen_table,
        "missingness_table": miss_table,
        "rare_collapse_table": rare_table,
        "tstr": tstr,
        "rtst": rtst,
    }

In [67]:
# Full result dict per generator, keyed by generator name.
all_results = {}
# One flat summary dict per generator; becomes one row of summary_df.
summary_rows = []

# Evaluate every registered synthetic dataset against the same real train/test split.
for name, syn_df in synthetic_datasets.items():
    res = evaluate_one_model(name, real_train, real_test, syn_df, Target)
    all_results[name] = res
    summary_rows.append(res["summary"])

# Rows are ordered by tstr_roc_auc, highest first.
summary_df = pd.DataFrame(summary_rows).sort_values("tstr_roc_auc", ascending=False)
summary_df



Unnamed: 0,model,n_rows,ks_mean,ks_median,ks_worst,tvd_mean,tvd_median,tvd_worst,corr_mae,tstr_roc_auc,tstr_accuracy,tstr_f1,rtst_roc_auc,rtst_accuracy,rtst_f1,avg_out_of_range_rate_num,avg_unseen_rate_cat,mean_abs_missingness_diff,mean_rare_recall
0,GaussianCopula,81412,0.081813,0.052351,0.184494,0.001563,0.001056,0.005024,0.037062,0.608109,0.58755,0.231016,0.531826,0.579349,0.190555,0.0,0.0,0.0,1.0
