In [ ]:
# Removes lint errors from VS Code
from typing import Dict, TYPE_CHECKING, Tuple, List

if TYPE_CHECKING:
    import kedro

    catalog: kedro.io.data_catalog.DataCatalog
    session: kedro.framework.session.session.KedroSession
    catalog: kedro.io.data_catalog.DataCatalog
    pipelines: Dict[str, kedro.pipeline.pipeline.Pipeline]

import pandas as pd

patients_wrk: pd.DataFrame = catalog.load("mimic_mm_core.wrk.patients")
admissions_wrk: pd.DataFrame = catalog.load("mimic_mm_core.wrk.admissions")
transfers_wrk: pd.DataFrame = catalog.load("mimic_mm_core.wrk.transfers")

patients_ref: pd.DataFrame = catalog.load("mimic_mm_core.ref.patients")
admissions_ref: pd.DataFrame = catalog.load("mimic_mm_core.ref.admissions")
transfers_ref: pd.DataFrame = catalog.load("mimic_mm_core.ref.transfers")

metadata: Dict = catalog.load("params:mimic_mm_core.metadata")


2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `mimic_mm_core.wrk.patients` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `mimic_mm_core.wrk.admissions` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `mimic_mm_core.wrk.transfers` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `mimic_mm_core.ref.patients` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `mimic_mm_core.ref.admissions` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `mimic_mm_core.ref.transfers` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `params:mimic_mm_core.metadata` (MemoryDataSet)...


In [ ]:
tables = {
    "patients": [patients_wrk, patients_ref],
    "admissions": [admissions_wrk, admissions_ref],
    "transfers": [transfers_wrk, transfers_ref],
}


def gen_freq(a, b, cols, fillna=1e-6):
    if isinstance(cols, str):
        cols = [cols]

    a, b = a[cols].value_counts(), b[cols].value_counts()
    c = pd.concat([a, b], axis=1)
    c = c / c.sum()
    c = c.fillna(value=fillna)
    c = c / c.sum()
    return c.iloc[:, 0], c.iloc[:, 1]


In [ ]:
from scipy.stats import chisquare

res = []
for name, (a, b) in tables.items():
    for col in a.keys():
        if metadata["tables"][name]["fields"][col]["type"] == "categorical":
            k, j = gen_freq(a, b, col)
            chi, p = chisquare(k, j)
            # print(f"{name:12}.{col:20}: X^2={chi:3.3f} p={100*p:7.3f}%")
            res.append([name, col, chi, p])

res = pd.DataFrame(res, columns=["table", "col", "X^2", "p"])
res.set_index(keys=["table", "col"]).style.background_gradient(axis=0)


Unnamed: 0_level_0,Unnamed: 1_level_0,X^2,p
table,col,Unnamed: 2_level_1,Unnamed: 3_level_1
patients,gender,1.4e-05,0.997065
patients,anchor_year_group,4.7e-05,1.0
admissions,admission_type,7.1e-05,1.0
admissions,admission_location,0.000102,1.0
admissions,discharge_location,0.000225,1.0
admissions,insurance,0.000203,0.999899
admissions,language,2.7e-05,0.995881
admissions,marital_status,0.000221,0.999999
admissions,ethnicity,0.000156,1.0
admissions,hospital_expire_flag,2.8e-05,0.995768


In [ ]:
from scipy.stats import ks_2samp

res = []
for name, (a, b) in tables.items():
    for col in a.keys():
        c, d = a[col].dropna(), b[col].dropna()
        match metadata["tables"][name]["fields"][col]["type"]:
            case "datetime" | "timespan":
                c, d = pd.to_numeric(c), pd.to_numeric(d)
            case "numerical":
                pass
            case _:
                continue
        ks, p = ks_2samp(c, d)
        res.append([name, col, ks, p])

res = pd.DataFrame(res, columns=["table", "col", "K-S", "p"])
res.set_index(keys=["table", "col"]).style.background_gradient(axis=0)


Unnamed: 0_level_0,Unnamed: 1_level_0,K-S,p
table,col,Unnamed: 2_level_1,Unnamed: 3_level_1
patients,anchor_age,0.003486,0.310107
patients,aod,0.014049,0.836151
patients,dod,0.019823,0.430722
admissions,admittime,0.002856,0.359276
admissions,dischtime,0.00295,0.320756
admissions,deathtime,0.021946,0.319443
admissions,edregtime,0.004714,0.124735
admissions,edouttime,0.004698,0.127109
transfers,intime,0.001215,0.647334
transfers,outtime,0.001698,0.432048


In [ ]:
def get_parent(table):
    match table:
        case 'patients':
            return None
        case 'admissions':
            return 'patients'
        case 'transfers':
            return 'admissions'
        case _:
            assert False, 'Table not found'

In [ ]:
from scipy.special import kl_div


def is_categorical(table, col):
    return metadata["tables"][table]["fields"][col]["type"] == "categorical"

res = []
for name, (a, b) in tables.items():
    cols = sorted(a.keys())
    for i, col_i in enumerate(cols):
        if not is_categorical(name, col_i):
            continue

        for col_j in cols[i + 1 :]:
            if not is_categorical(name, col_j):
                continue

            k, j = gen_freq(a, b, [col_i, col_j])
            kl = 1 / (1 + kl_div(k, j).sum())
            res.append([name, col_i, name, col_j, kl])

for table in tables:
    a, b = tables[table]
    cols_i = sorted(a.keys())
    parent = table

    while get_parent(parent):
        parent = get_parent(parent)

        c, d = tables[parent]
        a, b = a.join(c, rsuffix=parent), b.join(d, rsuffix=parent)
        cols_j = sorted(c.keys())

        for col_i in cols_i:
            if not is_categorical(table, col_i):
                continue

            for col_j in cols_j:
                if not is_categorical(parent, col_j):
                    continue

                k, j = gen_freq(a, b, [col_i, col_j])
                kl = 1 / (1 + kl_div(k, j).sum())
                res.append([table, col_i, parent, col_j, kl])


In [ ]:
idx_tuples = []
for table in reversed(tables):
    a, _ = tables[table]
    for col in sorted(a.keys()):
        if is_categorical(table, col):
            idx_tuples.append((table, col))

idx = pd.MultiIndex.from_tuples(idx_tuples, names=['table_i', 'col_i'])
cols = pd.MultiIndex.from_tuples(reversed(idx_tuples), names=['table_j', 'col_j'])

pt = pd.DataFrame(res, columns=["table_i", "col_i", "table_j", "col_j", "KL"])
pt = pt.pivot_table(
    values="KL", index=["table_i", "col_i"], columns=["table_j", "col_j"]
)
pt = pt.reindex(index=idx, columns=cols)
pt.style.background_gradient(axis=1).applymap(
    lambda x: "color: transparent; background-color: transparent"
    if pd.isnull(x)
    else ""
)

Unnamed: 0_level_0,table_j,patients,patients,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,transfers,transfers
Unnamed: 0_level_1,col_j,gender,anchor_year_group,marital_status,language,insurance,hospital_expire_flag,ethnicity,discharge_location,admission_type,admission_location,eventtype,careunit
table_i,col_i,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
transfers,careunit,0.998714,0.997964,0.997765,0.998871,0.998571,0.998953,0.996487,0.993935,0.996203,0.995754,0.999652,
transfers,eventtype,0.999949,0.999905,0.999832,0.999971,0.999858,0.999955,0.999819,0.999635,0.999863,0.999806,,
admissions,admission_location,0.999811,0.999635,0.99947,0.999899,0.999764,0.999893,0.999275,0.998945,0.999598,,,
admissions,admission_type,0.99985,0.999713,0.999653,0.999909,0.999757,0.999923,0.999565,0.99927,,,,
admissions,discharge_location,0.999568,0.999447,0.999174,0.999657,0.99951,0.999832,0.998801,,,,,
admissions,ethnicity,0.999831,0.999684,0.999183,0.999878,0.999483,0.999855,,,,,,
admissions,hospital_expire_flag,0.999932,0.999901,0.99984,0.999954,0.999863,,,,,,,
admissions,insurance,0.999885,0.999877,0.999594,0.999839,,,,,,,,
admissions,language,0.999939,0.999968,0.999667,,,,,,,,,
admissions,marital_status,0.999797,0.999723,,,,,,,,,,
