# Single Table Models Template

In [None]:
# Removes lint errors from VS Code
from typing import Dict, TYPE_CHECKING, Tuple, List

if TYPE_CHECKING:
    # Only static analysers evaluate this branch (TYPE_CHECKING is False at
    # runtime). It declares the types of the names `catalog`, `session` and
    # `pipelines`, which this notebook uses without defining them itself.
    from kedro.framework.session import KedroSession
    from kedro.io import DataCatalog
    from kedro.pipeline import Pipeline

    catalog: DataCatalog
    session: KedroSession
    pipelines: Dict[str, Pipeline]


In [None]:
import os


def env_flag(name: str, default: bool = True) -> bool:
    """Read a boolean switch from the environment.

    Args:
        name: Name of the environment variable.
        default: Value returned when the variable is not set at all.

    Returns:
        ``default`` if the variable is unset; ``False`` if it is set to an
        empty string or one of ``0``/``false``/``no``/``off`` (case-insensitive);
        ``True`` for any other value.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    # Environment values are always strings, so "0" or "false" would be truthy
    # if used directly in an `if`; interpret them explicitly instead.
    return raw.strip().lower() not in ("", "0", "false", "no", "off")


# Each setting can be overridden through the environment; an unset or empty
# variable falls back to the default on the right-hand side.
VIEW = os.getenv("DATASET_VIEW") or "tab_adult"
TABLE = os.getenv("DATASET_TABLE") or "table"
ALG = os.getenv("SYNTH_ALG") or "ref"
MULTI_PROCESS = env_flag("MULTI_PROCESS", default=True)

import numpy as np
import pandas as pd

from pasteur.metadata import Metadata

# Declared separately so the DataFrame annotations survive tuple unpacking.
wrk: pd.DataFrame
alg: pd.DataFrame
dev: pd.DataFrame

# Load the three splits of the selected table in a fixed order:
# "wrk", then the split named by ALG, then "dev".
wrk, alg, dev = [
    catalog.load(f"{VIEW}.{split}.{TABLE}") for split in ("wrk", ALG, "dev")
]

# Build the view metadata from its parameters and the wrk split, then keep
# only the entry for the selected table.
meta = Metadata(catalog.load(f"params:{VIEW}.metadata"), wrk).get_table(TABLE)

random_state = catalog.load("params:random_state")


2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `tab_adult.wrk.table` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `tab_adult.ref.table` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `tab_adult.dev.table` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `params:tab_adult.metadata` (MemoryDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `params:random_state` (MemoryDataSet)...


In [None]:
# Preview the first rows of the wrk split as a sanity check on the load.
wrk.head()


Unnamed: 0,id,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country
0,4,28,Private,338409,Bachelors,13,Married-civ-spouse,Prof-specialty,Wife,Black,Female,0,0,40,Cuba
1,6,49,Private,160187,9th,5,Married-spouse-absent,Other-service,Not-in-family,Black,Female,0,0,16,Jamaica
2,7,52,Self-emp-not-inc,209642,HS-grad,9,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,0,45,United-States
3,9,42,Private,159449,Bachelors,13,Married-civ-spouse,Exec-managerial,Husband,White,Male,5178,0,40,United-States
4,10,37,Private,280464,Some-college,10,Married-civ-spouse,Exec-managerial,Husband,Black,Male,0,0,80,United-States


In [None]:
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.model_selection import train_test_split

from sklearn.svm import SVC, SVR
from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.naive_bayes import GaussianNB, CategoricalNB
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor, RandomForestRegressor, RandomForestClassifier

# Each entry is (short name, estimator for categorical targets, estimator for
# non-categorical targets). A None in either slot makes test_alg skip that
# model for the corresponding kind of target.
models = [
    ("lr", LogisticRegression, LinearRegression),
    # SVR is the regression counterpart of SVC; SVC is a classifier and is not
    # suitable for continuous targets.
    ("svm", SVC, SVR),
    ("tree", DecisionTreeClassifier, DecisionTreeRegressor),
    ("forest", RandomForestClassifier, RandomForestRegressor),
    ("gb", GradientBoostingClassifier, GradientBoostingRegressor),
    # ("bayes", GaussianNB, None),
]
# Fraction of each dataset held out as its test split by train_test_split.
ratio = 0.2
# Every column that gets a model fitted to predict it.
targets = [*meta.targets, *meta.sensitive]


def fit_data(target_col: str, train: pd.DataFrame, *tests: pd.DataFrame):
    """Encode the feature columns of ``train`` and of each frame in ``tests``.

    The encoder is built from the columns listed in the module-level ``meta``:
    the target column and id columns are left out, categorical columns are
    one-hot encoded and every other column is standardised. It is fitted on
    ``train`` only and then applied unchanged to the remaining frames.

    Args:
        target_col: Name of the column being predicted; excluded from the
            features.
        train: Frame used to fit the transformer.
        *tests: Additional frames transformed with the already-fitted
            transformer.

    Returns:
        A tuple with the transformed ``train`` followed by the transformed
        ``tests`` in the order they were passed.
    """
    transformers = []

    for name, col in meta.cols.items():
        # Neither the prediction target nor id columns are used as features.
        if name == target_col or col.is_id():
            continue

        if col.is_cat():
            encoder = OneHotEncoder(handle_unknown="infrequent_if_exist")
        else:
            encoder = StandardScaler()
        transformers.append((name, encoder, [name]))

    # Columns not listed above (target, ids, anything absent from meta) are
    # dropped from the output.
    trans = ColumnTransformer(
        transformers, remainder="drop", verbose_feature_names_out=False
    )

    train_t = trans.fit_transform(train)
    tests_t = [trans.transform(test) for test in tests]
    return (train_t, *tests_t)


def test_alg(args):
    """Fit one model for one target and score it on the available splits.

    Takes a single tuple so it can be mapped over a job list.

    Args:
        args: Tuple ``(target, data_type, train, test, model, clf_l, clf_r)``
            where ``target`` is the column to predict, ``data_type`` labels the
            dataset the split came from (``"wrk"`` or ``"alg"``), ``train`` and
            ``test`` are the split frames, ``model`` is the model's short name,
            and ``clf_l`` / ``clf_r`` are the estimator classes used for
            categorical and non-categorical targets respectively.

    Returns:
        ``(model, data_type, target, train_score, test_score, wrk_score,
        dev_score)``, or ``None`` when the model has no estimator for this
        kind of target. ``test_score`` and ``wrk_score`` are NaN unless
        ``data_type`` is ``"alg"``.
    """
    target, data_type, train, test, model, clf_l, clf_r = args

    # Pick the estimator first so unsupported combinations are skipped before
    # any encoding work is done.
    clf_c = clf_l if meta[target].is_cat() else clf_r
    if clf_c is None:
        return None

    x_train, x_test, x_wrk, x_dev = fit_data(target, train, test, wrk, dev)
    y_train, y_test, y_wrk, y_dev = (
        train[target],
        test[target],
        wrk[target],
        dev[target],
    )

    clf = clf_c()
    clf.fit(x_train, y_train)

    res_train = clf.score(x_train, y_train)
    res_dev = clf.score(x_dev, y_dev)

    if data_type == "alg":
        res_test = clf.score(x_test, y_test)
        res_wrk = clf.score(x_wrk, y_wrk)
    else:
        # Only models trained on the "alg" data report these two scores.
        # np.nan is used because the np.NAN alias was removed in NumPy 2.0.
        res_test = np.nan
        res_wrk = np.nan

    return (model, data_type, target, res_train, res_test, res_wrk, res_dev)


def test_targets():
    """Evaluate every model on every target for both the wrk and alg data.

    Builds one job per (model, dataset, target) combination, runs the jobs
    through ``test_alg`` (in worker processes when ``MULTI_PROCESS`` is truthy,
    sequentially otherwise) and collects the scores.

    Returns:
        A DataFrame with one row per completed job and the columns ``model``,
        ``data``, ``target``, ``train_results``, ``test_results``,
        ``wrk_results`` and ``dev_results``. Jobs for which ``test_alg``
        returned ``None`` are omitted.
    """
    from tqdm import tqdm

    # The split depends only on the dataset, so compute it once per dataset
    # instead of once per model. With an integer random_state this yields the
    # same split the per-model calls produced.
    splits = [
        (
            data_type,
            *train_test_split(data, test_size=ratio, random_state=random_state),
        )
        for data_type, data in [("wrk", wrk), ("alg", alg)]
    ]

    jobs = [
        (target, data_type, train, test, model, clf_l, clf_r)
        for model, clf_l, clf_r in models
        for data_type, train, test in splits
        for target in targets
    ]

    if MULTI_PROCESS:
        from tqdm.contrib.concurrent import process_map

        results = process_map(test_alg, jobs, tqdm_class=tqdm)
    else:
        # Materialise the lazy map so the progress bar completes here.
        results = list(tqdm(map(test_alg, jobs), total=len(jobs)))

    # test_alg returns None for unsupported model/target pairs; those cannot
    # be turned into rows.
    rows = [row for row in results if row is not None]

    return pd.DataFrame(
        rows,
        columns=[
            "model",
            "data",
            "target",
            "train_results",
            "test_results",
            "wrk_results",
            "dev_results",
        ],
    )


# Run every (model, dataset, target) evaluation, then preview the first rows
# of the collected scores.
target_res = test_targets()
target_res.head()


100%|██████████| 30/30 [00:00<00:00,  1.00it/s]


Unnamed: 0,model,data,target,train_results,test_results,wrk_results,dev_results
0,lr,wrk,education,0.873884,,,0.855344
1,lr,wrk,race,0.878203,,,0.876843
2,lr,wrk,relationship,0.790767,,,0.782862
3,lr,alg,education,0.867262,0.849136,0.844287,0.84398
4,lr,alg,race,0.877243,0.874856,0.877764,0.877457


In [None]:
from IPython.display import display

# Put the scores of the model trained on the original ("wrk") data and of the
# model trained on the synthetic ("alg") data side by side, one row per
# (model, target) pair.
res = pd.merge(
    target_res[target_res["data"] == "wrk"],
    target_res[target_res["data"] == "alg"],
    on=["model", "target"],
    suffixes=("_wrk", "_alg"),
)

# Maps merged column names to the names shown in the report; the dict order
# is the display order.
columns = {
    "model": "model",
    "target": "target",
    "train_results_wrk": "orig_train",
    "train_results_alg": "synth_train",
    "dev_results_wrk": "orig_test",
    "test_results_alg": "synth_test_synth",
    "dev_results_alg": "synth_test_real",
    "privacy_leak": "privacy_leak",
}

# Difference between the synthetic-trained model's score on the wrk data and
# its score on the dev data.
res["privacy_leak"] = res["wrk_results_alg"] - res["dev_results_alg"]

# Index with a real list; a dict_values view is not a documented indexer.
res = res.rename(columns=columns)[list(columns.values())]

for target in targets:
    if target in meta.targets:
        caption = "Target: "
    elif target in meta.sensitive:
        caption = "Sensitive: "
    else:
        caption = ""
    caption += target.capitalize()

    pt = res[res["target"] == target].drop(columns=["target"]).set_index("model")

    styler = (
        pt.style.set_caption(caption)
        .format(lambda x: f"{100*x:.1f}%")
        .background_gradient(axis=0)
    )
    # Styler.applymap was renamed to Styler.map in pandas 2.1; use the new
    # name when available and fall back on older versions.
    elementwise = getattr(styler, "map", None) or styler.applymap
    # Hide cells with missing values instead of rendering them as "nan%".
    pt = elementwise(
        lambda x: "color: transparent; background-color: transparent"
        if pd.isnull(x)
        else ""
    )

    display(pt)


Unnamed: 0_level_0,orig_train,synth_train,orig_test,synth_test_synth,synth_test_real,privacy_leak
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
lr,87.4%,86.7%,85.5%,84.9%,84.4%,0.0%
svm,98.5%,98.1%,97.0%,95.7%,96.0%,-0.0%
tree,100.0%,100.0%,100.0%,100.0%,100.0%,0.0%
forest,100.0%,100.0%,89.5%,89.1%,89.0%,-0.1%
gb,100.0%,100.0%,100.0%,100.0%,100.0%,0.0%


Unnamed: 0_level_0,orig_train,synth_train,orig_test,synth_test_synth,synth_test_real,privacy_leak
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
lr,87.8%,87.7%,87.7%,87.5%,87.7%,0.0%
svm,87.4%,87.1%,86.9%,86.8%,86.9%,0.2%
tree,100.0%,100.0%,79.8%,77.9%,78.7%,0.1%
forest,100.0%,100.0%,87.1%,87.1%,86.9%,0.1%
gb,89.5%,89.2%,87.2%,87.0%,87.4%,0.2%


Unnamed: 0_level_0,orig_train,synth_train,orig_test,synth_test_synth,synth_test_real,privacy_leak
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
lr,79.1%,78.7%,78.3%,77.5%,78.5%,0.0%
svm,80.9%,80.8%,79.1%,78.6%,79.1%,0.0%
tree,100.0%,100.0%,72.0%,70.6%,71.0%,0.9%
forest,100.0%,100.0%,77.7%,77.1%,77.2%,0.6%
gb,82.1%,82.2%,79.1%,78.3%,79.4%,-0.0%
