# Single Table Analysis Template

In [None]:
# Removes lint errors from VS Code
# The annotations below are only seen by static analysis tools: TYPE_CHECKING
# is False at runtime, so nothing in the `if` body executes in the kernel.
# NOTE(review): presumably `catalog`, `session` and `pipelines` are injected
# into the kernel namespace by Kedro's notebook integration -- confirm.
from typing import Dict, TYPE_CHECKING, Tuple, List

if TYPE_CHECKING:
    import kedro

    catalog: kedro.io.data_catalog.DataCatalog
    session: kedro.framework.session.session.KedroSession
    pipelines: Dict[str, kedro.pipeline.pipeline.Pipeline]

In [None]:
import os

import numpy as np
import pandas as pd
from pasteur.metadata import Metadata
from pasteur.transform.table import TableTransformer

# Dataset view, table and synthesis algorithm to analyse. Each one can be
# overridden through an environment variable; an unset or empty variable falls
# back to the default on the right.
VIEW = os.getenv("DATASET_VIEW") or "tab_adult"
TABLE = os.getenv("DATASET_TABLE") or "table"
ALG = os.getenv("SYNTH_ALG") or "privbayes"

# "idx" versions of the table for the wrk split, the output of ALG and the
# tst split.
wrk: pd.DataFrame = catalog.load(f"{VIEW}.wrk.idx_{TABLE}")
alg: pd.DataFrame = catalog.load(f"{VIEW}.{ALG}.idx_{TABLE}")
tst: pd.DataFrame = catalog.load(f"{VIEW}.tst.idx_{TABLE}")

# Table transformer stored in the catalog under the wrk split.
trn: TableTransformer = catalog.load(f"{VIEW}.wrk.trn_{TABLE}")

# Table metadata built from the Kedro parameters and the wrk data.
params = catalog.load("parameters")
meta = Metadata.from_kedro_params(params, VIEW, {TABLE: wrk}).get_table(TABLE)

random_state = catalog.load("params:random_state")

In [None]:
# Column order shared by every array built from the three tables.
cols = list(wrk.columns)

# The three tables as uint16 matrices with identical column order:
# a = wrk, b = alg, c = tst.
a, b, c = (
    frame[cols].to_numpy(dtype="uint16") for frame in (wrk, alg, tst)
)

# Per-column domain size: the largest value seen in any table, plus one.
domain = np.concatenate((a, b, c), axis=0).max(axis=0) + 1

In [None]:
from pasteur.math import calc_marginal_1way

# Fill value handed to calc_marginal_1way for the chi-square marginals: the
# probability mass of a single row of `a`.
# NOTE(review): presumably this keeps empty bins non-zero so that chisquare
# stays valid -- confirm against calc_marginal_1way.
zero_fill = 1 / len(a)


def marg(s: np.ndarray, names: list[str], is_kl = False):
    """Compute the marginal of `s` over the named columns.

    Args:
        s: Data matrix whose columns follow the order of the global `cols`.
        names: Column names to include; each is looked up in `cols`.
        is_kl: When true, 1e-24 is used as the fill value instead of
            `zero_fill`.

    Returns:
        Whatever `calc_marginal_1way` returns for the selected columns.
    """
    col_positions = [cols.index(name) for name in names]
    fill = 1e-24 if is_kl else zero_fill
    return calc_marginal_1way(s, domain, col_positions, fill)

In [None]:
from scipy.stats import chisquare

# Chi-square comparison of every single-column marginal of `a` against the
# same marginal of each comparison array ("alg" -> b, "tst" -> c).
# NOTE(review): scipy.stats.chisquare is documented for observed/expected
# frequencies. `marg` looks like it returns probabilities (every p-value shown
# below is ~1.0); if so the statistic is understated -- confirm whether counts
# should be passed instead.
res = []
for split, s in [("alg", b), ("tst", c)]:
    for col in cols:
        k = marg(a, [col])
        j = marg(s, [col])
        # k is passed as the observed argument and j as the expected one.
        chi, p = chisquare(k, j)
        res.append([split, col, chi, p])

# One row per (split, column) with the statistic and its p-value.
cs_df = pd.DataFrame(res, columns=["split", "col", "X^2", "p"])

In [None]:
# Chi-square results as a table: one row per column, one (statistic, split)
# pair per table column. Each split gets its own per-column colour scale:
# "alg" uses the 'Reds' colormap, "tst" the default one.
# Styler.format(precision=...) replaces the deprecated Styler.set_precision.
cs_df.pivot(
    index=["col"], columns=["split"], values=["X^2", "p"]
).sort_index().style.format(precision=3).background_gradient(
    axis=0, subset=(slice(None), (slice(None), "alg")), cmap='Reds'
).background_gradient(
    axis=0, subset=(slice(None), (slice(None), "tst"))
)

<pandas.io.formats.style.Styler object at 0x7fd9f04e20e0>


Unnamed: 0_level_0,X^2,X^2,p,p
split,alg,tst,alg,tst
col,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
age,0.084,0.002,1.0,1.0
capital-gain,1.753,0.001,1.0,1.0
capital-loss,0.001,0.001,1.0,1.0
education,0.0,0.001,1.0,1.0
education-num,0.0,0.001,1.0,1.0
fnlwgt,0.0,0.002,1.0,1.0
hours-per-week,0.466,0.003,1.0,1.0
income,0.0,0.0,1.0,0.992
marital-status,0.0,0.0,1.0,1.0
native-country,0.001,0.004,1.0,1.0


In [None]:
# Build the gradient map used to shade the "alg" columns of the chi-square
# table: the (alg - tst) difference of each statistic, rescaled per column.
# NOTE(review): pivoting with columns=[] appears to only re-index the frame by
# "col" while keeping one row per (col, split) -- confirm this is intended.
diff_df = cs_df.pivot(
    index=["col"], columns=[], values=["X^2", "p", "split"]
)
# Select the alg rows and the tst rows, drop the label column and subtract;
# pandas aligns the two frames on their index.
diff_df = diff_df[diff_df["split"] == "alg"].drop(columns=["split"]) - diff_df[diff_df["split"] == "tst"].drop(columns=["split"])

# Dividing by the per-column max |difference| gives [-1, 1]; halving and
# shifting by 0.5 gives [0, 1], where 0.5 means "alg equals tst".
cs_gmap = diff_df / diff_df.abs().max(axis=0) / 2 + 0.5
# Sorted by index so the rows line up with a pivot that is sorted the same way.
cs_gmap = cs_gmap.sort_index().to_numpy()

In [None]:
# Chi-square table again: "tst" columns are shaded by their own values, while
# "alg" columns are shaded from cs_gmap on a fixed 0..1 scale.
(
    cs_df.pivot(index=["col"], columns=["split"], values=["X^2", "p"])
    .sort_index()
    .style.background_gradient(
        subset=pd.IndexSlice[:, pd.IndexSlice[:, "tst"]],
        axis=0,
    )
    .background_gradient(
        subset=pd.IndexSlice[:, pd.IndexSlice[:, "alg"]],
        gmap=cs_gmap,
        cmap="RdBu",
        vmin=0,
        vmax=1,
        axis=None,
    )
)

<pandas.io.formats.style.Styler object at 0x7fd9efa66950>


Unnamed: 0_level_0,X^2,X^2,p,p
split,alg,tst,alg,tst
col,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
age,0.084453,0.001818,1.0,1.0
capital-gain,1.753099,0.000749,1.0,1.0
capital-loss,0.000797,0.001481,1.0,1.0
education,0.000194,0.00114,1.0,1.0
education-num,0.000194,0.00114,1.0,1.0
fnlwgt,0.000397,0.002013,1.0,1.0
hours-per-week,0.465833,0.003205,1.0,1.0
income,0.0,9.5e-05,0.999885,0.992243
marital-status,2.1e-05,0.000498,1.0,1.0
native-country,0.000672,0.003623,1.0,1.0


In [None]:
from scipy.special import rel_entr

# Pairwise KL comparison: for every ordered column pair, compare the marginal
# of `a` over that pair with the marginal of each comparison array.
res = []
for col_i in cols:
    for col_j in cols:
        # The marginal of `a` does not depend on the split, so it is computed
        # once per column pair instead of once per split.
        k = marg(a, [col_i, col_j], True)
        for split, s in [("alg", b), ("tst", c)]:
            j = marg(s, [col_i, col_j], True)

            # Sum of element-wise relative entropy of k with respect to j.
            kl = rel_entr(k, j).sum()
            # Equals 1 when kl is 0 and shrinks towards 0 as kl grows.
            kl_norm = 1 / (1 + kl)
            res.append([split, col_i, col_j, kl, kl_norm, len(k)])

# One row per (split, col_i, col_j); "mlen" is the length of the marginal.
kl_df = pd.DataFrame(
    res,
    columns=[
        "split",
        "col_i",
        "col_j",
        "kl",
        "kl_norm",
        "mlen",
    ],
)
kl_df.head().style

<pandas.io.formats.style.Styler object at 0x7fd9f04e1d80>


Unnamed: 0,split,col_i,col_j,kl,kl_norm,mlen
0,alg,age,age,0.054128,0.948651,1089
1,tst,age,age,0.000914,0.999086,1089
2,alg,age,workclass,0.076616,0.928836,297
3,tst,age,workclass,0.055948,0.947016,297
4,alg,age,fnlwgt,0.107825,0.90267,1089


In [None]:
def mk_kl_plot(filter=None, val="kl_norm"):
    """Render the pairwise results in ``kl_df`` as a colour-coded table.

    Args:
        filter: Optional indexer applied as ``kl_df[filter]`` before pivoting
            (for example a boolean mask); ``None`` keeps every row.
        val: Name of the ``kl_df`` column to display.

    Returns:
        A pandas ``Styler`` with ``col_j`` as rows and ``(col_i, split)`` as
        columns, both sorted. "alg" columns use the 'Reds' colormap, "tst"
        columns the default one, and null cells are made transparent.
    """
    pt = kl_df[filter] if filter is not None else kl_df
    pt = pt.pivot(
        values=val, index=["col_j"], columns=["col_i", "split"]
    ).sort_index(axis=0).sort_index(axis=1)

    # Styler.format(precision=...) replaces the deprecated Styler.set_precision.
    pt = pt.style.format(precision=3).background_gradient(
        axis=None, subset=(slice(None), (slice(None), "alg")), cmap='Reds'
    ).background_gradient(
        axis=None, subset=(slice(None), (slice(None), "tst"))
    ).applymap(
        # Hide cells that have no value for this (col_j, col_i, split).
        lambda x: "color: transparent; background-color: transparent"
        if pd.isnull(x)
        else ""
    )

    return pt


mk_kl_plot(val="kl_norm")

<pandas.io.formats.style.Styler object at 0x7fd9f0172680>


col_i,age,age,capital-gain,capital-gain,capital-loss,capital-loss,education,education,education-num,education-num,fnlwgt,fnlwgt,hours-per-week,hours-per-week,income,income,marital-status,marital-status,native-country,native-country,occupation,occupation,race,race,relationship,relationship,sex,sex,workclass,workclass
split,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst
col_j,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2,Unnamed: 25_level_2,Unnamed: 26_level_2,Unnamed: 27_level_2,Unnamed: 28_level_2,Unnamed: 29_level_2,Unnamed: 30_level_2
age,0.949,0.999,0.752,0.923,0.877,0.833,0.882,0.853,0.882,0.853,0.903,0.826,0.765,0.728,0.947,0.988,0.943,0.951,0.844,0.63,0.925,0.893,0.939,0.954,0.944,0.953,0.949,0.997,0.929,0.947
capital-gain,0.752,0.923,0.796,0.998,0.79,0.994,0.794,0.961,0.794,0.961,0.793,0.962,0.713,0.911,0.796,0.998,0.796,0.986,0.772,0.92,0.792,0.968,0.795,0.992,0.797,0.984,0.798,0.994,0.797,0.98
capital-loss,0.877,0.833,0.79,0.994,0.996,0.995,0.971,0.893,0.971,0.893,0.964,0.909,0.838,0.856,0.996,0.992,0.962,0.963,0.937,0.918,0.97,0.917,0.991,0.971,0.96,0.96,0.991,0.991,0.98,0.942
education,0.882,0.853,0.794,0.961,0.971,0.893,1.0,0.999,1.0,0.999,0.967,0.905,0.849,0.839,1.0,0.999,0.987,0.988,0.913,0.762,0.986,0.95,0.997,0.986,0.999,0.991,0.999,0.998,0.981,0.982
education-num,0.882,0.853,0.794,0.961,0.971,0.893,1.0,0.999,1.0,0.999,0.967,0.905,0.849,0.839,1.0,0.999,0.987,0.988,0.913,0.762,0.986,0.95,0.997,0.986,0.999,0.991,0.999,0.998,0.981,0.982
fnlwgt,0.903,0.826,0.793,0.962,0.964,0.909,0.967,0.905,0.967,0.905,0.998,0.992,0.841,0.828,0.994,0.99,0.975,0.938,0.911,0.826,0.961,0.899,0.994,0.97,0.987,0.936,0.997,0.978,0.98,0.952
hours-per-week,0.765,0.728,0.713,0.911,0.838,0.856,0.849,0.839,0.849,0.839,0.841,0.828,0.885,0.999,0.874,0.992,0.863,0.948,0.795,0.708,0.875,0.884,0.876,0.965,0.865,0.956,0.886,0.996,0.863,0.95
income,0.947,0.988,0.796,0.998,0.996,0.992,1.0,0.999,1.0,0.999,0.994,0.99,0.874,0.992,1.0,1.0,1.0,0.999,0.997,0.989,0.992,0.995,0.999,1.0,1.0,1.0,0.999,1.0,0.997,0.997
marital-status,0.943,0.951,0.796,0.986,0.962,0.963,0.987,0.988,0.987,0.988,0.975,0.938,0.863,0.948,1.0,0.999,1.0,1.0,0.936,0.895,0.996,0.994,0.996,0.998,0.999,0.998,1.0,0.999,0.991,0.995
native-country,0.844,0.63,0.772,0.92,0.937,0.918,0.913,0.762,0.913,0.762,0.911,0.826,0.795,0.708,0.997,0.989,0.936,0.895,1.0,0.998,0.955,0.781,0.99,0.943,0.97,0.897,0.998,0.99,0.965,0.873


In [None]:
def color_dataframe(
    df: pd.DataFrame,
    idx: list[str],
    cols: list[str],
    vals: list[str],
    ref_split="tst",
    split_col="split",
    cmap="BrBG",
    cmap_ref="Purples",
    diff_reverse=True,
    formatters: dict[str, dict] | None = None,
):
    pt = (
        df.pivot(index=idx, columns=[*cols, split_col], values=vals)
        .sort_index(0)
        .sort_index(1)
    )
    pts = pt.style

    if formatters:
        for col, form in formatters.items():
            pts = pts.format(
                subset=(
                    slice(None),
                    (col, *[slice(None) for _ in range(len(cols) + 1)]),
                ),
                **form
            )

    # Apply background style to ref columns
    for col in vals:
        pts = pts.background_gradient(
            axis=None,
            subset=(
                slice(None),
                (col, *[slice(None) for _ in range(len(cols))], ref_split),
            ),
            cmap=cmap_ref,
        )

    # Apply background to non-ref columns
    # It is based in the difference between expected value to resulting value
    # red = too low
    # white = same, good
    # white = too high
    df_ref = (
        df[df[split_col] == ref_split]
        .pivot(index=idx, columns=cols, values=vals)
        .sort_index(0)
        .sort_index(1)
    )
    splits = df[split_col].unique()
    for split in splits:
        if split == ref_split:
            continue

        df_split = (
            df[df[split_col] == split]
            .pivot(index=idx, columns=cols, values=vals)
            .sort_index(0)
            .sort_index(1)
        )
        df_diff = df_split - df_ref
        if diff_reverse:
            df_diff = -df_diff
        df_norm = df_diff / df_diff.abs().max(axis=0) / 2 + 0.5

        pts = pts.background_gradient(
            axis=None,
            subset=(
                slice(None),
                (*[slice(None) for _ in range(len(cols) + 1)], split),
            ),
            gmap=df_norm.to_numpy(),
            vmin=0,
            vmax=1,
            cmap=cmap,
        )

    return pts


In [None]:
# Display kl_norm with three decimals.
kl_formatters = {"kl_norm": {"precision": 3}}

color_dataframe(
    kl_df,
    idx=["col_j"],
    cols=["col_i"],
    vals=["kl_norm"],
    formatters=kl_formatters,
)

<pandas.io.formats.style.Styler object at 0x7fd9f0172230>


Unnamed: 0_level_0,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm,kl_norm
col_i,age,age,capital-gain,capital-gain,capital-loss,capital-loss,education,education,education-num,education-num,fnlwgt,fnlwgt,hours-per-week,hours-per-week,income,income,marital-status,marital-status,native-country,native-country,occupation,occupation,race,race,relationship,relationship,sex,sex,workclass,workclass
split,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst,alg,tst
col_j,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3,Unnamed: 19_level_3,Unnamed: 20_level_3,Unnamed: 21_level_3,Unnamed: 22_level_3,Unnamed: 23_level_3,Unnamed: 24_level_3,Unnamed: 25_level_3,Unnamed: 26_level_3,Unnamed: 27_level_3,Unnamed: 28_level_3,Unnamed: 29_level_3,Unnamed: 30_level_3
age,0.949,0.999,0.752,0.923,0.877,0.833,0.882,0.853,0.882,0.853,0.903,0.826,0.765,0.728,0.947,0.988,0.943,0.951,0.844,0.63,0.925,0.893,0.939,0.954,0.944,0.953,0.949,0.997,0.929,0.947
capital-gain,0.752,0.923,0.796,0.998,0.79,0.994,0.794,0.961,0.794,0.961,0.793,0.962,0.713,0.911,0.796,0.998,0.796,0.986,0.772,0.92,0.792,0.968,0.795,0.992,0.797,0.984,0.798,0.994,0.797,0.98
capital-loss,0.877,0.833,0.79,0.994,0.996,0.995,0.971,0.893,0.971,0.893,0.964,0.909,0.838,0.856,0.996,0.992,0.962,0.963,0.937,0.918,0.97,0.917,0.991,0.971,0.96,0.96,0.991,0.991,0.98,0.942
education,0.882,0.853,0.794,0.961,0.971,0.893,1.0,0.999,1.0,0.999,0.967,0.905,0.849,0.839,1.0,0.999,0.987,0.988,0.913,0.762,0.986,0.95,0.997,0.986,0.999,0.991,0.999,0.998,0.981,0.982
education-num,0.882,0.853,0.794,0.961,0.971,0.893,1.0,0.999,1.0,0.999,0.967,0.905,0.849,0.839,1.0,0.999,0.987,0.988,0.913,0.762,0.986,0.95,0.997,0.986,0.999,0.991,0.999,0.998,0.981,0.982
fnlwgt,0.903,0.826,0.793,0.962,0.964,0.909,0.967,0.905,0.967,0.905,0.998,0.992,0.841,0.828,0.994,0.99,0.975,0.938,0.911,0.826,0.961,0.899,0.994,0.97,0.987,0.936,0.997,0.978,0.98,0.952
hours-per-week,0.765,0.728,0.713,0.911,0.838,0.856,0.849,0.839,0.849,0.839,0.841,0.828,0.885,0.999,0.874,0.992,0.863,0.948,0.795,0.708,0.875,0.884,0.876,0.965,0.865,0.956,0.886,0.996,0.863,0.95
income,0.947,0.988,0.796,0.998,0.996,0.992,1.0,0.999,1.0,0.999,0.994,0.99,0.874,0.992,1.0,1.0,1.0,0.999,0.997,0.989,0.992,0.995,0.999,1.0,1.0,1.0,0.999,1.0,0.997,0.997
marital-status,0.943,0.951,0.796,0.986,0.962,0.963,0.987,0.988,0.987,0.988,0.975,0.938,0.863,0.948,1.0,0.999,1.0,1.0,0.936,0.895,0.996,0.994,0.996,0.998,0.999,0.998,1.0,0.999,0.991,0.995
native-country,0.844,0.63,0.772,0.92,0.937,0.918,0.913,0.762,0.913,0.762,0.911,0.826,0.795,0.708,0.997,0.989,0.936,0.895,1.0,0.998,0.955,0.781,0.99,0.943,0.97,0.897,0.998,0.99,0.965,0.873


In [None]:
# X^2 is shown with three decimals; p is multiplied by 100 and shown with one
# decimal (a percentage without the % sign).
cs_formatters = {
    "X^2": {"precision": 3},
    "p": {"formatter": lambda x: f"{100*x:.1f}"},
}

color_dataframe(
    cs_df,
    idx=["col"],
    cols=[],
    vals=["X^2", "p"],
    formatters=cs_formatters,
)

<pandas.io.formats.style.Styler object at 0x7fd9f03fe230>


Unnamed: 0_level_0,X^2,X^2,p,p
split,alg,tst,alg,tst
col,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
age,0.084,0.002,100.0,100.0
capital-gain,1.753,0.001,100.0,100.0
capital-loss,0.001,0.001,100.0,100.0
education,0.0,0.001,100.0,100.0
education-num,0.0,0.001,100.0,100.0
fnlwgt,0.0,0.002,100.0,100.0
hours-per-week,0.466,0.003,100.0,100.0
income,0.0,0.0,100.0,99.2
marital-status,0.0,0.0,100.0,100.0
native-country,0.001,0.004,100.0,100.0
