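# Single Table Analysis Template

Compares one table of a Kedro dataset view across splits. It reports chi-square statistics for each column's 1-way marginal and the KL divergence of every 2-way marginal, comparing the working split (`wrk`) against the test split (`tst`). Select the view, table and synthesis algorithm with the `DATASET_VIEW`, `DATASET_TABLE` and `SYNTH_ALG` environment variables.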

In [None]:
# Removes lint errors from VS Code
from typing import Dict, TYPE_CHECKING, Tuple, List

# `catalog`, `session` and `pipelines` are used below without being defined in
# this notebook. They are injected into the kernel at runtime (presumably by
# Kedro's IPython/Jupyter integration — confirm). These annotations exist only
# so static analysers know their types; nothing here runs at runtime.
if TYPE_CHECKING:
    import kedro

    catalog: kedro.io.data_catalog.DataCatalog
    session: kedro.framework.session.session.KedroSession
    pipelines: Dict[str, kedro.pipeline.pipeline.Pipeline]

In [None]:
import os

# Dataset view, table and synthesis algorithm under analysis. Each one can be
# overridden through an environment variable; an unset or empty variable
# falls back to the default shown.
VIEW = os.environ.get("DATASET_VIEW") or "tab_adult"
TABLE = os.environ.get("DATASET_TABLE") or "table"
ALG = os.environ.get("SYNTH_ALG") or "privbayes"

import numpy as np
import pandas as pd
from pasteur.transform.table import TableTransformer

# Per-split frames of TABLE (presumably index-encoded, given the "idx_" prefix
# and the integer casts below — confirm):
# - the working split,
# - the output of the synthesis algorithm ALG,
# - the held-out ("tst") split.
wrk: pd.DataFrame = catalog.load(f"{VIEW}.wrk.idx_{TABLE}")
alg: pd.DataFrame = catalog.load(f"{VIEW}.{ALG}.idx_{TABLE}")
tst: pd.DataFrame = catalog.load(f"{VIEW}.tst.idx_{TABLE}")

# Transformer stored for the working split of TABLE.
trn: TableTransformer = catalog.load(f"{VIEW}.wrk.trn_{TABLE}")

from pasteur.metadata import Metadata

# The dataset name has no placeholders, so a plain string is used, not an f-string.
params = catalog.load("parameters")
meta = Metadata.from_kedro_params(params, VIEW, {TABLE: wrk}).get_table(TABLE)

random_state = catalog.load("params:random_state")

In [None]:
# Column order comes from the working table. The other splits are indexed with
# the same list, so all three arrays share one column layout.
cols = list(wrk.keys())


def _assert_fits_uint16(name: str, frame: pd.DataFrame) -> None:
    """Fail loudly if any value in ``frame`` would wrap when cast to uint16.

    ``to_numpy(dtype="uint16")`` wraps negative or too-large codes silently.
    ``max + 1`` on a uint16 array wraps to 0 at 65535. Codes must therefore lie
    in ``[0, 65535)``. NaNs also fail this check.
    """
    limit = int(np.iinfo(np.uint16).max)
    ok = ((frame >= 0) & (frame < limit)).all(axis=None)
    assert ok, f"{name}: index codes must lie in [0, {limit}) to be stored as uint16"


for _split_name, _split in (("wrk", wrk), ("alg", alg), ("tst", tst)):
    _assert_fits_uint16(_split_name, _split[cols])

a = wrk[cols].to_numpy(dtype="uint16")
b = alg[cols].to_numpy(dtype="uint16")
c = tst[cols].to_numpy(dtype="uint16")

# Domain size per column: the largest code seen in any split, plus one.
domain = np.concatenate([a, b, c]).max(axis=0) + 1

In [None]:
from pasteur.synth.privbayes import calc_marginal_1way

# Probability mass of a single working-split sample. It is passed to the
# marginal helper so the resulting distributions are valid chi-square inputs
# (presumably it fills otherwise-empty cells — confirm in calc_marginal_1way).
zero_fill = 1 / len(a)


def marg(s: np.ndarray, names: list[str]):
    """Marginal of ``s`` over the named columns.

    Args:
        s: code array whose columns follow ``cols`` (e.g. ``a``, ``b`` or ``c``).
        names: column names to marginalise onto, in order.

    Returns:
        The result of ``calc_marginal_1way`` for the matching column indices.
        It uses the notebook-wide ``domain`` and ``zero_fill``; the positional
        ``True`` flag is forwarded unchanged.
    """
    col_idx = [*map(cols.index, names)]
    return calc_marginal_1way(s, domain, col_idx, True, zero_fill)

In [None]:
from scipy.stats import chisquare

# Chi-square goodness of fit: each column's 1-way marginal in the working split
# (`a`) is tested against the test split (`c`) as the reference distribution.
#
# `marg` returns normalised distributions, but `chisquare` expects frequencies.
# The statistic scales linearly with the sample size, so passing probabilities
# shrinks X^2 by a factor of len(a) and pushes every p-value towards 1. Both
# distributions are therefore rescaled to the working split's sample count.
# They keep equal totals, which `chisquare` requires.
n_obs = len(a)

rows = []
for col in cols:
    observed = marg(a, [col]) * n_obs
    expected = marg(c, [col]) * n_obs
    chi, p = chisquare(observed, expected)
    rows.append([col, chi, p])

res = pd.DataFrame(rows, columns=["col", "X^2", "p"])
res.set_index(keys=["col"]).sort_index().style.background_gradient(axis=0)

<pandas.io.formats.style.Styler object at 0x7fde3fb8e5f0>


Unnamed: 0_level_0,X^2,p
col,Unnamed: 1_level_1,Unnamed: 2_level_1
age,0.001818,1.0
capital-gain,0.000749,1.0
capital-loss,0.001481,1.0
education,0.00114,1.0
education-num,0.00114,1.0
fnlwgt,0.002013,1.0
hours-per-week,0.003205,1.0
income,9.5e-05,0.992243
marital-status,0.000498,1.0
native-country,0.003623,1.0


In [None]:
from scipy.special import rel_entr

# Pairwise (2-way) marginal comparison between the working split (`a`) and the
# test split (`c`), computed for every ordered column pair:
# - `kl` is KL(wrk || tst), the summed `rel_entr`.
# - `kl_norm` is the bounded score 1 / (1 + KL), which equals 1 for identical marginals.
# - `mlen` is the length of the joint marginal.
records = []
for first in cols:
    for second in cols:
        pair = [first, second]
        p_wrk = marg(a, pair)
        p_tst = marg(c, pair)

        divergence = rel_entr(p_wrk, p_tst).sum()
        records.append(
            {
                "col_i": first,
                "col_j": second,
                "kl": divergence,
                "kl_norm": 1 / (1 + divergence),
                "mlen": len(p_wrk),
            }
        )

res = pd.DataFrame(records)
res.head().style

<pandas.io.formats.style.Styler object at 0x7fde411bfbb0>


Unnamed: 0,col_i,col_j,kl,kl_norm,mlen
0,age,age,0.000909,0.999092,1089
1,age,workclass,0.009549,0.990542,297
2,age,fnlwgt,0.017001,0.983283,1089
3,age,education,0.018385,0.981947,528
4,age,education-num,0.018385,0.981947,528


In [None]:
def mk_kl_plot(filter=None, val="kl_norm"):
    pt = res[filter] if filter is not None else res
    pt = pt.pivot_table(
        values=val, index=["col_i"], columns=["col_j"]
    )

    # Try to tweak colormap
    vmin = vmax = None
    # match val:
    #     case "kl_norm":
    #         vmin = 0.996
    #     case "kl":
    #         vmax = 0.04
    #     case "mlen":
    #         vmax = 100

    # pt = pt.reindex(index=idx, columns=cols)
    pt = pt.style.background_gradient(axis=None, vmin=vmin, vmax=vmax).applymap(
        lambda x: "color: transparent; background-color: transparent"
        if pd.isnull(x)
        else ""
    )

    return pt


In [None]:
# Heatmap of the normalised pairwise score: cells closer to 1 mean the 2-way
# marginal of the working split is closer to that of the test split.
mk_kl_plot(filter=None, val="kl_norm")

<pandas.io.formats.style.Styler object at 0x7fde401c3fd0>


col_j,age,capital-gain,capital-loss,education,education-num,fnlwgt,hours-per-week,income,marital-status,native-country,occupation,race,relationship,sex,workclass
col_i,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
age,0.999092,0.993576,0.989944,0.981947,0.981947,0.983283,0.970795,0.997874,0.993315,0.971974,0.982241,0.994418,0.992813,0.997862,0.990542
capital-gain,0.993576,0.999646,0.998975,0.995962,0.995962,0.995505,0.993548,0.999456,0.998008,0.995224,0.995769,0.999017,0.998117,0.999082,0.997691
capital-loss,0.989944,0.998975,0.99933,0.993175,0.993175,0.993721,0.991688,0.998912,0.996861,0.995208,0.993671,0.99796,0.996815,0.998756,0.995885
education,0.981947,0.995962,0.993175,0.999428,0.999428,0.991199,0.982406,0.998556,0.996235,0.982727,0.990311,0.99658,0.996914,0.997753,0.995466
education-num,0.981947,0.995962,0.993175,0.999428,0.999428,0.991199,0.982406,0.998556,0.996235,0.982727,0.990311,0.99658,0.996914,0.997753,0.995466
fnlwgt,0.983283,0.995505,0.993721,0.991199,0.991199,0.999168,0.984487,0.998461,0.99535,0.986219,0.988499,0.997092,0.995149,0.998269,0.995434
hours-per-week,0.970795,0.993548,0.991688,0.982406,0.982406,0.984487,0.99853,0.996875,0.992768,0.980495,0.986004,0.995115,0.992923,0.997414,0.992336
income,0.997874,0.999456,0.998912,0.998556,0.998556,0.998461,0.996875,0.999953,0.999339,0.997195,0.998192,0.99969,0.999733,0.999878,0.999274
marital-status,0.993315,0.998008,0.996861,0.996235,0.996235,0.99535,0.992768,0.999339,0.999744,0.990499,0.995741,0.999166,0.999195,0.999336,0.997589
native-country,0.971974,0.995224,0.995208,0.982727,0.982727,0.986219,0.980495,0.997195,0.990499,0.998185,0.982576,0.994955,0.989265,0.996231,0.990752
