In [ ]:
# Removes lint errors from VS Code
from typing import Dict, TYPE_CHECKING, Tuple, List

if TYPE_CHECKING:
    import kedro

    catalog: kedro.io.data_catalog.DataCatalog
    session: kedro.framework.session.session.KedroSession
    catalog: kedro.io.data_catalog.DataCatalog
    pipelines: Dict[str, kedro.pipeline.pipeline.Pipeline]

import numpy as np
import pandas as pd

patients_wrk: pd.DataFrame = catalog.load("mimic_mm_core.wrk.patients")
admissions_wrk: pd.DataFrame = catalog.load("mimic_mm_core.wrk.admissions")
transfers_wrk: pd.DataFrame = catalog.load("mimic_mm_core.wrk.transfers")

patients_ref: pd.DataFrame = catalog.load("mimic_mm_core.ref.patients")
admissions_ref: pd.DataFrame = catalog.load("mimic_mm_core.ref.admissions")
transfers_ref: pd.DataFrame = catalog.load("mimic_mm_core.ref.transfers")

metadata: Dict = catalog.load("params:mimic_mm_core.metadata")


2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `mimic_mm_core.wrk.patients` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `mimic_mm_core.wrk.admissions` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `mimic_mm_core.wrk.transfers` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `mimic_mm_core.ref.patients` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `mimic_mm_core.ref.admissions` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `mimic_mm_core.ref.transfers` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `params:mimic_mm_core.metadata` (MemoryDataSet)...


In [ ]:
tables = {
    "patients": [patients_wrk, patients_ref],
    "admissions": [admissions_wrk, admissions_ref],
    "transfers": [transfers_wrk, transfers_ref],
}

def gen_freq(a, b, cols, is_cat=None, fillna=1e-6, bins=32):
    if isinstance(cols, str):
        cols = [cols]
    if isinstance(is_cat, bool):
        is_cat = [is_cat]
    
    # Shallow copy a, b and only keep cols
    # FIXME: stop dropping NAs
    a, b = a[cols].dropna(), b[cols].dropna()

    # Convert any continuous variables to discrete
    if is_cat is not None and not all(is_cat):
        for col, cat in zip(cols, is_cat):
            if cat:
                continue
            
            c, d = pd.to_numeric(a[col]), pd.to_numeric(b[col])

            col_bins = np.histogram_bin_edges(c, bins=bins).astype(np.float32)
            a[col] = np.digitize(c, col_bins)
            b[col] = np.digitize(d, col_bins)


    ## Calculate frequencies of discrete variables
    a, b = a.value_counts(), b.value_counts()
    c = pd.concat([a, b], axis=1)
    c = c / c.sum()
    c = c.fillna(value=fillna)
    c = c / c.sum()
    return c.iloc[:, 0], c.iloc[:, 1]


In [ ]:
from scipy.stats import chisquare

res = []
for name, (a, b) in tables.items():
    for col in a.keys():
        if metadata["tables"][name]["fields"][col]["type"] == "categorical":
            k, j = gen_freq(a, b, col)
            chi, p = chisquare(k, j)
            # print(f"{name:12}.{col:20}: X^2={chi:3.3f} p={100*p:7.3f}%")
            res.append([name, col, chi, p])

res = pd.DataFrame(res, columns=["table", "col", "X^2", "p"])
res.set_index(keys=["table", "col"]).style.background_gradient(axis=0)


Unnamed: 0_level_0,Unnamed: 1_level_0,X^2,p
table,col,Unnamed: 2_level_1,Unnamed: 3_level_1
patients,gender,1.4e-05,0.997065
patients,anchor_year_group,4.7e-05,1.0
admissions,admission_type,7.1e-05,1.0
admissions,admission_location,0.000102,1.0
admissions,discharge_location,0.000225,1.0
admissions,insurance,0.000203,0.999899
admissions,language,2.7e-05,0.995881
admissions,marital_status,0.000221,0.999999
admissions,ethnicity,0.000156,1.0
admissions,hospital_expire_flag,2.8e-05,0.995768


In [ ]:
from scipy.stats import ks_2samp

res = []
for name, (a, b) in tables.items():
    for col in a.keys():
        c, d = a[col].dropna(), b[col].dropna()
        match metadata["tables"][name]["fields"][col]["type"]:
            case "datetime" | "timespan":
                c, d = pd.to_numeric(c), pd.to_numeric(d)
            case "numerical":
                pass
            case _:
                continue
        ks, p = ks_2samp(c, d)
        res.append([name, col, ks, p])

res = pd.DataFrame(res, columns=["table", "col", "K-S", "p"])
res.set_index(keys=["table", "col"]).style.background_gradient(axis=0)


Unnamed: 0_level_0,Unnamed: 1_level_0,K-S,p
table,col,Unnamed: 2_level_1,Unnamed: 3_level_1
patients,anchor_age,0.003486,0.310107
patients,aod,0.014049,0.836151
patients,dod,0.019823,0.430722
admissions,admittime,0.002856,0.359276
admissions,dischtime,0.00295,0.320756
admissions,deathtime,0.021946,0.319443
admissions,edregtime,0.004714,0.124735
admissions,edouttime,0.004698,0.127109
transfers,intime,0.001215,0.647334
transfers,outtime,0.001698,0.432048


In [ ]:
def get_parent(table):
    match table:
        case 'patients':
            return None
        case 'admissions':
            return 'patients'
        case 'transfers':
            return 'admissions'
        case _:
            assert False, 'Table not found'

In [ ]:
from scipy.special import kl_div


def is_categorical(table, col):
    return metadata["tables"][table]["fields"][col]["type"] == "categorical"

def is_id(table, col):
    return metadata["tables"][table]["fields"][col]["type"] == "id"

res = []
for name, (a, b) in tables.items():
    cols = sorted(a.keys())
    for i, col_i in enumerate(cols):
        if is_id(name, col_i):
            continue
        cat_i = is_categorical(name, col_i)

        for col_j in cols[i + 1 :]:
            if is_id(name, col_j):
                continue
            cat_j = is_categorical(name, col_j)

            k, j = gen_freq(a, b, [col_i, col_j], [cat_i, cat_j])
            kl = 1 / (1 + kl_div(k, j).sum())
            res.append([name, col_i, name, col_j, kl, len(k), cat_i, cat_j])

for table in tables:
    a, b = tables[table]
    cols_i = sorted(a.keys())
    parent = table

    while get_parent(parent):
        parent = get_parent(parent)

        c, d = tables[parent]
        a, b = a.join(c, rsuffix=parent), b.join(d, rsuffix=parent)
        cols_j = sorted(c.keys())

        for col_i in cols_i:
            if is_id(table, col_i):
                continue
            cat_i = is_categorical(table, col_i)

            for col_j in cols_j:
                if is_id(parent, col_j):
                    continue
                cat_j = is_categorical(parent, col_j)

                k, j = gen_freq(a, b, [col_i, col_j], [cat_i, cat_j])
                kl = 1 / (1 + kl_div(k, j).sum())
                res.append([table, col_i, parent, col_j, kl, len(k), cat_i, cat_j])

res = pd.DataFrame(res, columns=["table_i", "col_i", "table_j", "col_j", "KL", "mlen", "cat_i", "cat_j"])

In [ ]:
res

Unnamed: 0,table_i,col_i,table_j,col_j,KL,mlen,cat_i,cat_j
0,patients,anchor_age,patients,anchor_year_group,0.999017,112,False,True
1,patients,anchor_age,patients,aod,0.952908,128,False,False
2,patients,anchor_age,patients,dod,0.744475,623,False,False
3,patients,anchor_age,patients,gender,0.999599,56,False,True
4,patients,anchor_year_group,patients,aod,0.951177,107,True,False
...,...,...,...,...,...,...,...,...
226,transfers,outtime,patients,anchor_age,0.999152,66,False,False
227,transfers,outtime,patients,anchor_year_group,0.999727,23,False,True
228,transfers,outtime,patients,aod,0.851185,222,False,False
229,transfers,outtime,patients,dod,0.841235,225,False,False


In [ ]:
def mk_ks_plot(filter=None, val="KL"):
    # idx_tuples = []
    # for table in reversed(tables):
    #     a, _ = tables[table]
    #     for col in sorted(a.keys()):
    #         if is_categorical(table, col):
    #             idx_tuples.append((table, col))

    # idx = pd.MultiIndex.from_tuples(idx_tuples, names=['table_i', 'col_i'])
    # cols = pd.MultiIndex.from_tuples(reversed(idx_tuples), names=['table_j', 'col_j'])

    pt = res[filter] if filter is not None else res
    pt = pt.pivot_table(
        values=val, index=["table_i", "col_i"], columns=["table_j", "col_j"]
    )

    idx_tuples = []
    for table in reversed(tables):
        idx_tuples.extend([i for i in pt.index if i[0] == table])
    idx = pd.MultiIndex.from_tuples(idx_tuples, names=["table_i", "col_i"])

    col_tuples = []
    for table in tables:
        col_tuples.extend([i for i in reversed(pt.columns) if i[0] == table])
    cols = pd.MultiIndex.from_tuples(col_tuples, names=["table_j", "col_j"])

    pt = pt.reindex(index=idx, columns=cols)
    pt = pt.style.background_gradient(axis=1).applymap(
        lambda x: "color: transparent; background-color: transparent"
        if pd.isnull(x)
        else ""
    )

    return pt


In [ ]:
mk_ks_plot(res['cat_i'] & res['cat_j'])

Unnamed: 0_level_0,table_j,patients,patients,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,transfers
Unnamed: 0_level_1,col_j,gender,anchor_year_group,marital_status,language,insurance,hospital_expire_flag,ethnicity,discharge_location,admission_type,admission_location,eventtype
table_i,col_i,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2
transfers,careunit,0.998714,0.997964,0.997765,0.998871,0.998571,0.998953,0.996487,0.993935,0.996203,0.995754,0.999652
transfers,eventtype,0.999949,0.999905,0.999832,0.999971,0.999858,0.999955,0.999819,0.999635,0.999863,0.999806,
admissions,admission_location,0.999811,0.999635,0.99947,0.999899,0.999764,0.999893,0.999275,0.998945,0.999598,,
admissions,admission_type,0.99985,0.999713,0.999653,0.999909,0.999757,0.999923,0.999565,0.99927,,,
admissions,discharge_location,0.999568,0.999447,0.999174,0.999657,0.99951,0.999832,0.998801,,,,
admissions,ethnicity,0.999831,0.999684,0.999183,0.999878,0.999483,0.999855,,,,,
admissions,hospital_expire_flag,0.999932,0.999901,0.99984,0.999954,0.999863,,,,,,
admissions,insurance,0.999885,0.999877,0.999594,0.999839,,,,,,,
admissions,language,0.999939,0.999968,0.999667,,,,,,,,
admissions,marital_status,0.999797,0.999723,,,,,,,,,


In [ ]:
mk_ks_plot(res['cat_i'] & res['cat_j'], "mlen")

Unnamed: 0_level_0,table_j,patients,patients,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,transfers
Unnamed: 0_level_1,col_j,gender,anchor_year_group,marital_status,language,insurance,hospital_expire_flag,ethnicity,discharge_location,admission_type,admission_location,eventtype
table_i,col_i,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2
transfers,careunit,84.0,168.0,168.0,84.0,126.0,84.0,333.0,513.0,375.0,438.0,83.0
transfers,eventtype,8.0,16.0,16.0,8.0,12.0,8.0,32.0,52.0,36.0,44.0,
admissions,admission_location,22.0,44.0,44.0,22.0,33.0,21.0,83.0,134.0,66.0,,
admissions,admission_type,18.0,36.0,36.0,18.0,27.0,18.0,72.0,104.0,,,
admissions,discharge_location,26.0,52.0,52.0,26.0,39.0,22.0,103.0,,,,
admissions,ethnicity,16.0,32.0,32.0,16.0,24.0,16.0,,,,,
admissions,hospital_expire_flag,4.0,8.0,8.0,4.0,6.0,,,,,,
admissions,insurance,6.0,12.0,12.0,6.0,,,,,,,
admissions,language,4.0,8.0,8.0,,,,,,,,
admissions,marital_status,8.0,16.0,,,,,,,,,


In [ ]:
mk_ks_plot(~res['cat_i'] & ~res['cat_j'])

Unnamed: 0_level_0,table_j,patients,patients,patients,admissions,admissions,admissions,admissions,admissions,transfers
Unnamed: 0_level_1,col_j,dod,aod,anchor_age,edregtime,edouttime,dischtime,deathtime,admittime,outtime
table_i,col_i,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2
transfers,intime,0.886423,0.871142,0.999403,0.999374,0.99938,0.999691,0.954824,0.999678,0.999962
transfers,outtime,0.841235,0.851185,0.999152,0.999007,0.999037,0.999452,0.861023,0.999462,
admissions,admittime,0.739644,0.734971,0.995674,0.999513,0.999527,0.999632,0.964872,,
admissions,deathtime,0.126661,0.121534,0.629555,0.952566,0.950708,0.988772,,,
admissions,dischtime,0.736879,0.735196,0.99565,0.999422,0.999432,,,,
admissions,edouttime,0.61241,0.621356,0.993422,0.999345,,,,,
admissions,edregtime,0.612163,0.619081,0.9934,,,,,,
patients,anchor_age,0.744475,0.952908,,,,,,,
patients,aod,0.75194,,,,,,,,


In [ ]:
mk_ks_plot(res['cat_i'] ^ res['cat_j'])

Unnamed: 0_level_0,table_j,patients,patients,patients,patients,patients,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,transfers,transfers
Unnamed: 0_level_1,col_j,gender,dod,aod,anchor_year_group,anchor_age,marital_status,language,insurance,hospital_expire_flag,ethnicity,edregtime,edouttime,dischtime,discharge_location,deathtime,admittime,admission_type,admission_location,outtime,intime
table_i,col_i,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
transfers,careunit,,0.60359,0.615019,,0.987303,,,,,,0.987202,0.987208,0.992427,,0.598843,0.992493,,,0.999691,0.999697
transfers,eventtype,,0.972013,0.964646,,0.999348,,,,,,0.999105,0.999125,0.999489,,0.969429,0.999479,,,0.999943,0.999976
transfers,intime,0.999849,,,0.999826,,0.999762,0.99984,0.999798,0.999871,0.999786,,,,0.999738,,,0.999848,0.99981,,
transfers,outtime,0.999753,,,0.999727,,0.999631,0.999775,0.999742,0.999784,0.9996,,,,0.999613,,,0.999718,0.999724,,
admissions,admission_location,,0.887385,0.889972,,0.997357,,,,,,0.998348,0.998345,0.998774,,0.906651,0.998759,,,,
admissions,admission_type,,0.913515,0.908474,,0.998275,,,,,,0.998482,0.99849,0.998909,,0.920544,0.998929,,,,
admissions,admittime,0.999403,,,0.999219,,0.999004,0.999621,0.999162,0.999736,0.998503,,,,0.99807,,,,,,
admissions,deathtime,0.973491,,,0.948453,,0.958451,0.986936,0.966912,0.993006,0.933752,,,,0.95582,,,,,,
admissions,discharge_location,,0.857385,0.824613,,0.996608,,,,,,0.996818,0.996809,0.998012,,,,,,,
admissions,dischtime,0.999384,,,0.999207,,0.99906,0.999626,0.99915,0.999749,0.998523,,,,,,,,,,


In [ ]:
mk_ks_plot()

Unnamed: 0_level_0,table_j,patients,patients,patients,patients,patients,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,transfers,transfers,transfers
Unnamed: 0_level_1,col_j,gender,dod,aod,anchor_year_group,anchor_age,marital_status,language,insurance,hospital_expire_flag,ethnicity,edregtime,edouttime,dischtime,discharge_location,deathtime,admittime,admission_type,admission_location,outtime,intime,eventtype
table_i,col_i,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2
transfers,careunit,0.998714,0.60359,0.615019,0.997964,0.987303,0.997765,0.998871,0.998571,0.998953,0.996487,0.987202,0.987208,0.992427,0.993935,0.598843,0.992493,0.996203,0.995754,0.999691,0.999697,0.999652
transfers,eventtype,0.999949,0.972013,0.964646,0.999905,0.999348,0.999832,0.999971,0.999858,0.999955,0.999819,0.999105,0.999125,0.999489,0.999635,0.969429,0.999479,0.999863,0.999806,0.999943,0.999976,
transfers,intime,0.999849,0.886423,0.871142,0.999826,0.999403,0.999762,0.99984,0.999798,0.999871,0.999786,0.999374,0.99938,0.999691,0.999738,0.954824,0.999678,0.999848,0.99981,0.999962,,
transfers,outtime,0.999753,0.841235,0.851185,0.999727,0.999152,0.999631,0.999775,0.999742,0.999784,0.9996,0.999007,0.999037,0.999452,0.999613,0.861023,0.999462,0.999718,0.999724,,,
admissions,admission_location,0.999811,0.887385,0.889972,0.999635,0.997357,0.99947,0.999899,0.999764,0.999893,0.999275,0.998348,0.998345,0.998774,0.998945,0.906651,0.998759,0.999598,,,,
admissions,admission_type,0.99985,0.913515,0.908474,0.999713,0.998275,0.999653,0.999909,0.999757,0.999923,0.999565,0.998482,0.99849,0.998909,0.99927,0.920544,0.998929,,,,,
admissions,admittime,0.999403,0.739644,0.734971,0.999219,0.995674,0.999004,0.999621,0.999162,0.999736,0.998503,0.999513,0.999527,0.999632,0.99807,0.964872,,,,,,
admissions,deathtime,0.973491,0.126661,0.121534,0.948453,0.629555,0.958451,0.986936,0.966912,0.993006,0.933752,0.952566,0.950708,0.988772,0.95582,,,,,,,
admissions,discharge_location,0.999568,0.857385,0.824613,0.999447,0.996608,0.999174,0.999657,0.99951,0.999832,0.998801,0.996818,0.996809,0.998012,,,,,,,,
admissions,dischtime,0.999384,0.736879,0.735196,0.999207,0.99565,0.99906,0.999626,0.99915,0.999749,0.998523,0.999422,0.999432,,,,,,,,,
