In [ ]:
# Removes lint errors from VS Code
from typing import Dict, TYPE_CHECKING, Tuple, List

if TYPE_CHECKING:
    import kedro

    catalog: kedro.io.data_catalog.DataCatalog
    session: kedro.framework.session.session.KedroSession
    catalog: kedro.io.data_catalog.DataCatalog
    pipelines: Dict[str, kedro.pipeline.pipeline.Pipeline]

import numpy as np
import pandas as pd

patients_wrk: pd.DataFrame = catalog.load("mimic_mm_core.wrk.patients")
admissions_wrk: pd.DataFrame = catalog.load("mimic_mm_core.wrk.admissions")
transfers_wrk: pd.DataFrame = catalog.load("mimic_mm_core.wrk.transfers")

patients_ref: pd.DataFrame = catalog.load("mimic_mm_core.ref.patients")
admissions_ref: pd.DataFrame = catalog.load("mimic_mm_core.ref.admissions")
transfers_ref: pd.DataFrame = catalog.load("mimic_mm_core.ref.transfers")

metadata: Dict = catalog.load("params:mimic_mm_core.metadata")


2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `mimic_mm_core.wrk.patients` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `mimic_mm_core.wrk.admissions` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `mimic_mm_core.wrk.transfers` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `mimic_mm_core.ref.patients` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `mimic_mm_core.ref.admissions` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `mimic_mm_core.ref.transfers` (ParquetDataSet)...
2000-01-01 00:00:00,000 - kedro.io.data_catalog - INFO - Loading data from `params:mimic_mm_core.metadata` (MemoryDataSet)...


In [ ]:
tables = {
    "patients": [patients_wrk, patients_ref],
    "admissions": [admissions_wrk, admissions_ref],
    "transfers": [transfers_wrk, transfers_ref],
}

def gen_freq(a, b, cols, is_cat=None, fillna=1e-6, bins=32):
    if isinstance(cols, str):
        cols = [cols]
    if isinstance(is_cat, bool):
        is_cat = [is_cat]
    
    # Shallow copy a, b and only keep cols
    # FIXME: stop dropping NAs
    a, b = a[cols].dropna(), b[cols].dropna()

    # Convert any continuous variables to discrete
    if is_cat is not None and not all(is_cat):
        for col, cat in zip(cols, is_cat):
            if cat:
                continue
            
            c, d = pd.to_numeric(a[col]), pd.to_numeric(b[col])

            col_bins = np.histogram_bin_edges(c, bins=bins).astype(np.float32)
            a[col] = np.digitize(c, col_bins)
            b[col] = np.digitize(d, col_bins)


    ## Calculate frequencies of discrete variables
    a, b = a.value_counts(), b.value_counts()
    c = pd.concat([a, b], axis=1)
    c = c / c.sum()
    c = c.fillna(value=fillna)
    c = c / c.sum()
    return c.iloc[:, 0], c.iloc[:, 1]


In [ ]:
from scipy.stats import chisquare

res = []
for name, (a, b) in tables.items():
    for col in a.keys():
        if metadata["tables"][name]["fields"][col]["type"] == "categorical":
            k, j = gen_freq(a, b, col)
            chi, p = chisquare(k, j)
            # print(f"{name:12}.{col:20}: X^2={chi:3.3f} p={100*p:7.3f}%")
            res.append([name, col, chi, p])

res = pd.DataFrame(res, columns=["table", "col", "X^2", "p"])
res.set_index(keys=["table", "col"]).style.background_gradient(axis=0)


Unnamed: 0_level_0,Unnamed: 1_level_0,X^2,p
table,col,Unnamed: 2_level_1,Unnamed: 3_level_1
patients,gender,1.4e-05,0.997065
patients,year_group,4.7e-05,1.0
admissions,admission_type,7.1e-05,1.0
admissions,admission_location,0.000102,1.0
admissions,discharge_location,0.000225,1.0
admissions,insurance,0.000203,0.999899
admissions,language,2.7e-05,0.995881
admissions,marital_status,0.000221,0.999999
admissions,ethnicity,0.000156,1.0
admissions,hospital_expire_flag,2.8e-05,0.995768


In [ ]:
from scipy.stats import ks_2samp

res = []
for name, (a, b) in tables.items():
    for col in a.keys():
        c, d = a[col].dropna(), b[col].dropna()
        match metadata["tables"][name]["fields"][col]["type"]:
            case "datetime" | "timespan":
                c, d = pd.to_numeric(c), pd.to_numeric(d)
            case "numerical":
                pass
            case _:
                continue
        ks, p = ks_2samp(c, d)
        res.append([name, col, ks, p])

res = pd.DataFrame(res, columns=["table", "col", "K-S", "p"])
res.set_index(keys=["table", "col"]).style.background_gradient(axis=0)


Unnamed: 0_level_0,Unnamed: 1_level_0,K-S,p
table,col,Unnamed: 2_level_1,Unnamed: 3_level_1
patients,death_age,0.014049,0.836151
patients,death_date,0.016165,0.689379
admissions,admittime,0.004006,0.06905
admissions,dischtime,0.003988,0.071149
admissions,deathtime,0.016131,0.703661
admissions,edregtime,0.005217,0.066937
admissions,edouttime,0.005217,0.066945
transfers,intime,0.001215,0.647334
transfers,outtime,0.001698,0.432048


In [ ]:
def get_parent(table):
    match table:
        case 'patients':
            return None
        case 'admissions':
            return 'patients'
        case 'transfers':
            return 'admissions'
        case _:
            assert False, 'Table not found'

In [ ]:
from scipy.special import rel_entr


def is_categorical(table, col):
    return metadata["tables"][table]["fields"][col]["type"] == "categorical"

def is_id(table, col):
    return metadata["tables"][table]["fields"][col]["type"] == "id"

res = []
for name, (a, b) in tables.items():
    cols = sorted(a.keys())
    for i, col_i in enumerate(cols):
        if is_id(name, col_i):
            continue
        cat_i = is_categorical(name, col_i)

        for col_j in cols[i + 1 :]:
            if is_id(name, col_j):
                continue
            cat_j = is_categorical(name, col_j)

            k, j = gen_freq(a, b, [col_i, col_j], [cat_i, cat_j])
            kl = rel_entr(k, j).sum()
            kl_norm = 1 / (1 + kl)
            res.append([name, col_i, name, col_j, kl, kl_norm, len(k), cat_i, cat_j])

for table in tables:
    a, b = tables[table]
    cols_i = sorted(a.keys())
    parent = table

    while get_parent(parent):
        parent = get_parent(parent)

        c, d = tables[parent]
        a, b = a.join(c, rsuffix=parent), b.join(d, rsuffix=parent)
        cols_j = sorted(c.keys())

        for col_i in cols_i:
            if is_id(table, col_i):
                continue
            cat_i = is_categorical(table, col_i)

            for col_j in cols_j:
                if is_id(parent, col_j):
                    continue
                cat_j = is_categorical(parent, col_j)

                k, j = gen_freq(a, b, [col_i, col_j], [cat_i, cat_j])
                kl = rel_entr(k, j).sum()
                kl_norm = 1 / (1 + kl)
                res.append([table, col_i, parent, col_j, kl, kl_norm, len(k), cat_i, cat_j])

res = pd.DataFrame(res, columns=["table_i", "col_i", "table_j", "col_j", "kl", "kl_norm", "mlen", "cat_i", "cat_j"])

In [ ]:
res

Unnamed: 0,table_i,col_i,table_j,col_j,kl,kl_norm,mlen,cat_i,cat_j
0,patients,death_age,patients,death_date,0.016903,0.983378,54,False,False
1,patients,death_age,patients,gender,0.011108,0.989014,55,False,True
2,patients,death_age,patients,year_group,0.051329,0.951177,107,False,True
3,patients,death_date,patients,gender,0.011796,0.988341,56,False,True
4,patients,death_date,patients,year_group,0.057189,0.945904,109,False,True
...,...,...,...,...,...,...,...,...,...
205,transfers,intime,patients,year_group,0.000174,0.999826,21,False,True
206,transfers,outtime,patients,death_age,0.174832,0.851185,222,False,False
207,transfers,outtime,patients,death_date,0.173277,0.852314,221,False,False
208,transfers,outtime,patients,gender,0.000247,0.999753,16,False,True


In [ ]:
def mk_ks_plot(filter=None, val="kl_norm"):
    # idx_tuples = []
    # for table in reversed(tables):
    #     a, _ = tables[table]
    #     for col in sorted(a.keys()):
    #         if is_categorical(table, col):
    #             idx_tuples.append((table, col))

    # idx = pd.MultiIndex.from_tuples(idx_tuples, names=['table_i', 'col_i'])
    # cols = pd.MultiIndex.from_tuples(reversed(idx_tuples), names=['table_j', 'col_j'])

    pt = res[filter] if filter is not None else res
    pt = pt.pivot_table(
        values=val, index=["table_i", "col_i"], columns=["table_j", "col_j"]
    )

    idx_tuples = []
    for table in reversed(tables):
        idx_tuples.extend([i for i in pt.index if i[0] == table])
    idx = pd.MultiIndex.from_tuples(idx_tuples, names=["table_i", "col_i"])

    col_tuples = []
    for table in tables:
        col_tuples.extend([i for i in reversed(pt.columns) if i[0] == table])
    cols = pd.MultiIndex.from_tuples(col_tuples, names=["table_j", "col_j"])

    # Try to tweak colormap
    vmin = vmax = None
    match val:
        case "kl_norm":
            vmin = 0.996
        case "kl":
            vmax = 0.04
        case "mlen":
            vmax = 100

    pt = pt.reindex(index=idx, columns=cols)
    pt = pt.style.background_gradient(axis=None, vmin=vmin, vmax=vmax).applymap(
        lambda x: "color: transparent; background-color: transparent"
        if pd.isnull(x)
        else ""
    )

    return pt


In [ ]:
mk_ks_plot(res['cat_i'] & res['cat_j'])

Unnamed: 0_level_0,table_j,patients,patients,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,transfers
Unnamed: 0_level_1,col_j,year_group,gender,marital_status,language,insurance,hospital_expire_flag,ethnicity,discharge_location,admission_type,admission_location,eventtype
table_i,col_i,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2
transfers,careunit,0.997964,0.998714,0.997765,0.998871,0.998571,0.998953,0.996487,0.993935,0.996203,0.995754,0.999652
transfers,eventtype,0.999905,0.999949,0.999832,0.999971,0.999858,0.999955,0.999819,0.999635,0.999863,0.999806,
admissions,admission_location,0.999635,0.999811,0.99947,0.999899,0.999764,0.999893,0.999275,0.998945,0.999598,,
admissions,admission_type,0.999713,0.99985,0.999653,0.999909,0.999757,0.999923,0.999565,0.99927,,,
admissions,discharge_location,0.999447,0.999568,0.999174,0.999657,0.99951,0.999832,0.998801,,,,
admissions,ethnicity,0.999684,0.999831,0.999183,0.999878,0.999483,0.999855,,,,,
admissions,hospital_expire_flag,0.999901,0.999932,0.99984,0.999954,0.999863,,,,,,
admissions,insurance,0.999877,0.999885,0.999594,0.999839,,,,,,,
admissions,language,0.999968,0.999939,0.999667,,,,,,,,
admissions,marital_status,0.999723,0.999797,,,,,,,,,


In [ ]:
mk_ks_plot(res['cat_i'] & res['cat_j'], "mlen")

Unnamed: 0_level_0,table_j,patients,patients,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,transfers
Unnamed: 0_level_1,col_j,year_group,gender,marital_status,language,insurance,hospital_expire_flag,ethnicity,discharge_location,admission_type,admission_location,eventtype
table_i,col_i,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2
transfers,careunit,168.0,84.0,168.0,84.0,126.0,84.0,333.0,513.0,375.0,438.0,83.0
transfers,eventtype,16.0,8.0,16.0,8.0,12.0,8.0,32.0,52.0,36.0,44.0,
admissions,admission_location,44.0,22.0,44.0,22.0,33.0,21.0,83.0,134.0,66.0,,
admissions,admission_type,36.0,18.0,36.0,18.0,27.0,18.0,72.0,104.0,,,
admissions,discharge_location,52.0,26.0,52.0,26.0,39.0,22.0,103.0,,,,
admissions,ethnicity,32.0,16.0,32.0,16.0,24.0,16.0,,,,,
admissions,hospital_expire_flag,8.0,4.0,8.0,4.0,6.0,,,,,,
admissions,insurance,12.0,6.0,12.0,6.0,,,,,,,
admissions,language,8.0,4.0,8.0,,,,,,,,
admissions,marital_status,16.0,8.0,,,,,,,,,


In [ ]:
mk_ks_plot(~res['cat_i'] & ~res['cat_j'])

Unnamed: 0_level_0,table_j,patients,patients,admissions,admissions,admissions,admissions,admissions,transfers
Unnamed: 0_level_1,col_j,death_date,death_age,edregtime,edouttime,dischtime,deathtime,admittime,outtime
table_i,col_i,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
transfers,intime,0.870785,0.871142,0.999079,0.999079,0.999479,0.955714,0.999483,0.999962
transfers,outtime,0.852314,0.851185,0.998452,0.998455,0.999172,0.815146,0.999169,
admissions,admittime,0.750701,0.752574,0.999153,0.999328,0.99943,0.983446,,
admissions,deathtime,0.130858,0.125708,0.962904,0.962904,0.994148,,,
admissions,dischtime,0.750954,0.752755,0.998813,0.998844,,,,
admissions,edouttime,0.600952,0.601551,0.999139,,,,,
admissions,edregtime,0.600952,0.601551,,,,,,
patients,death_age,0.983378,,,,,,,


In [ ]:
mk_ks_plot(res['cat_i'] ^ res['cat_j'])

Unnamed: 0_level_0,table_j,patients,patients,patients,patients,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,transfers,transfers
Unnamed: 0_level_1,col_j,year_group,gender,death_date,death_age,marital_status,language,insurance,hospital_expire_flag,ethnicity,edregtime,edouttime,dischtime,discharge_location,deathtime,admittime,admission_type,admission_location,outtime,intime
table_i,col_i,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2
transfers,careunit,,,0.607822,0.615019,,,,,,0.980829,0.980824,0.989082,,0.586182,0.989068,,,0.999691,0.999697
transfers,eventtype,,,0.960007,0.964646,,,,,,0.998503,0.998501,0.999232,,0.97287,0.999239,,,0.999943,0.999976
transfers,intime,0.999826,0.999849,,,0.999762,0.99984,0.999798,0.999871,0.999786,,,,0.999738,,,0.999848,0.99981,,
transfers,outtime,0.999727,0.999753,,,0.999631,0.999775,0.999742,0.999784,0.9996,,,,0.999613,,,0.999718,0.999724,,
admissions,admission_location,,,0.891783,0.889972,,,,,,0.996601,0.996597,0.997634,,0.917374,0.997623,,,,
admissions,admission_type,,,0.920345,0.908474,,,,,,0.99749,0.997484,0.997965,,0.934443,0.997953,,,,
admissions,admittime,0.999066,0.999381,,,0.997853,0.999226,0.998262,0.999565,0.996667,,,,0.997074,,,,,,
admissions,deathtime,0.932126,0.975499,,,0.961944,0.979443,0.976739,0.994148,0.935272,,,,0.94714,,,,,,
admissions,discharge_location,,,0.82653,0.824613,,,,,,0.994982,0.994982,0.997061,,,,,,,
admissions,dischtime,0.999085,0.999391,,,0.997838,0.999227,0.998272,0.999553,0.996633,,,,,,,,,,


In [ ]:
mk_ks_plot()

Unnamed: 0_level_0,table_j,patients,patients,patients,patients,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,transfers,transfers,transfers
Unnamed: 0_level_1,col_j,year_group,gender,death_date,death_age,marital_status,language,insurance,hospital_expire_flag,ethnicity,edregtime,edouttime,dischtime,discharge_location,deathtime,admittime,admission_type,admission_location,outtime,intime,eventtype
table_i,col_i,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
transfers,careunit,0.997964,0.998714,0.607822,0.615019,0.997765,0.998871,0.998571,0.998953,0.996487,0.980829,0.980824,0.989082,0.993935,0.586182,0.989068,0.996203,0.995754,0.999691,0.999697,0.999652
transfers,eventtype,0.999905,0.999949,0.960007,0.964646,0.999832,0.999971,0.999858,0.999955,0.999819,0.998503,0.998501,0.999232,0.999635,0.97287,0.999239,0.999863,0.999806,0.999943,0.999976,
transfers,intime,0.999826,0.999849,0.870785,0.871142,0.999762,0.99984,0.999798,0.999871,0.999786,0.999079,0.999079,0.999479,0.999738,0.955714,0.999483,0.999848,0.99981,0.999962,,
transfers,outtime,0.999727,0.999753,0.852314,0.851185,0.999631,0.999775,0.999742,0.999784,0.9996,0.998452,0.998455,0.999172,0.999613,0.815146,0.999169,0.999718,0.999724,,,
admissions,admission_location,0.999635,0.999811,0.891783,0.889972,0.99947,0.999899,0.999764,0.999893,0.999275,0.996601,0.996597,0.997634,0.998945,0.917374,0.997623,0.999598,,,,
admissions,admission_type,0.999713,0.99985,0.920345,0.908474,0.999653,0.999909,0.999757,0.999923,0.999565,0.99749,0.997484,0.997965,0.99927,0.934443,0.997953,,,,,
admissions,admittime,0.999066,0.999381,0.750701,0.752574,0.997853,0.999226,0.998262,0.999565,0.996667,0.999153,0.999328,0.99943,0.997074,0.983446,,,,,,
admissions,deathtime,0.932126,0.975499,0.130858,0.125708,0.961944,0.979443,0.976739,0.994148,0.935272,0.962904,0.962904,0.994148,0.94714,,,,,,,
admissions,discharge_location,0.999447,0.999568,0.82653,0.824613,0.999174,0.999657,0.99951,0.999832,0.998801,0.994982,0.994982,0.997061,,,,,,,,
admissions,dischtime,0.999085,0.999391,0.750954,0.752755,0.997838,0.999227,0.998272,0.999553,0.996633,0.998813,0.998844,,,,,,,,,


In [ ]:
mk_ks_plot(val="kl")

Unnamed: 0_level_0,table_j,patients,patients,patients,patients,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,admissions,transfers,transfers,transfers
Unnamed: 0_level_1,col_j,year_group,gender,death_date,death_age,marital_status,language,insurance,hospital_expire_flag,ethnicity,edregtime,edouttime,dischtime,discharge_location,deathtime,admittime,admission_type,admission_location,outtime,intime,eventtype
table_i,col_i,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
transfers,careunit,0.00204,0.001288,0.64522,0.625967,0.00224,0.00113,0.001431,0.001048,0.003526,0.019546,0.019551,0.011038,0.006102,0.705954,0.011052,0.003811,0.004264,0.000309,0.000303,0.000348
transfers,eventtype,9.5e-05,5.1e-05,0.041659,0.03665,0.000168,2.9e-05,0.000142,4.5e-05,0.000181,0.0015,0.001501,0.000769,0.000365,0.027886,0.000762,0.000137,0.000194,5.7e-05,2.4e-05,
transfers,intime,0.000174,0.000151,0.148389,0.147919,0.000238,0.00016,0.000202,0.000129,0.000214,0.000922,0.000922,0.000521,0.000262,0.046338,0.000517,0.000152,0.00019,3.8e-05,,
transfers,outtime,0.000273,0.000247,0.173277,0.174832,0.000369,0.000225,0.000258,0.000216,0.0004,0.001551,0.001548,0.000829,0.000387,0.226774,0.000832,0.000282,0.000276,,,
admissions,admission_location,0.000365,0.000189,0.121349,0.123631,0.000531,0.000101,0.000236,0.000107,0.000726,0.003411,0.003414,0.002372,0.001056,0.090068,0.002383,0.000402,,,,
admissions,admission_type,0.000287,0.00015,0.086549,0.100747,0.000348,9.1e-05,0.000244,7.7e-05,0.000436,0.002517,0.002522,0.002039,0.000731,0.070156,0.002051,,,,,
admissions,admittime,0.000935,0.00062,0.332088,0.328772,0.002151,0.000775,0.001741,0.000435,0.003344,0.000848,0.000672,0.00057,0.002935,0.016833,,,,,,
admissions,deathtime,0.072816,0.025117,6.641887,6.954929,0.039562,0.020988,0.023815,0.005887,0.069208,0.038526,0.038526,0.005887,0.05581,,,,,,,
admissions,discharge_location,0.000553,0.000433,0.209878,0.21269,0.000827,0.000343,0.000491,0.000168,0.0012,0.005043,0.005043,0.002947,,,,,,,,
admissions,dischtime,0.000916,0.00061,0.331639,0.328453,0.002166,0.000773,0.001731,0.000447,0.003378,0.001188,0.001157,,,,,,,,,
