In [None]:
from pasteur.kedro.ipython import *
# `register_kedro` comes from the wildcard import above. Presumably it wires the
# Kedro session into this kernel (the `catalog` object and the `%pipe` magic used
# below are not defined anywhere else in this notebook) — TODO confirm.
register_kedro()

In [None]:
# sensitive
# Invokes the `%pipe` line magic with the `mimic_tab_admissions.ingest` target.
# NOTE(review): `%pipe` is not a built-in IPython magic; presumably it is
# registered by `register_kedro()` and runs the named pipeline — confirm.
%pipe mimic_tab_admissions.ingest

In [None]:
# Load IPython's autoreload extension and use mode 2, so edited modules are
# re-imported automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np

In [None]:
from pasteur.metadata import Metadata
from pasteur.transform import TableTransformer, Attributes, get_type

# Prefix shared by the catalog entries loaded below.
view = "mimic_tab_admissions"
# NOTE(review): `catalog` is not defined in this notebook; presumably it is
# injected into the namespace by `register_kedro()` — confirm.
trn: TableTransformer = catalog.load(f"{view}.trn.table")
table: pd.DataFrame = catalog.load(f"{view}.wrk.idx_table")

In [None]:
# sensitive
# Preview the first rows of the table loaded from the catalog.
table.head()

In [None]:
# Attribute descriptions of the "idx" entry of the table transformer; these are
# what `expand_table` iterates over.
attrs = trn["idx"].get_attributes()

In [None]:
def expand_table(attrs: Attributes, table: pd.DataFrame):
    """Materialises every hierarchy height of every column described by `attrs`.

    For each column `name` in `attr.cols` of each attribute, and for each
    height `h` in `range(col.lvl.height)`:

    - `domains[name][h]` is `col.lvl.get_domain(h)`;
    - `cols[name][h]` is `col.lvl.get_mapping(h)` indexed by `table[name]` and
      cast to `get_type(domain)`;
    - `cols_noncommon[name][h]` is the same code array shifted down by
      `attr.common`, with every code that is not above `attr.common` set to 0.
      It is only produced when `attr.common > 0`; otherwise the list for that
      column stays empty.

    Args:
        attrs: Mapping of attributes; each value exposes `.common` and `.cols`.
        table: Table holding one column per name found in the attributes.

    Returns:
        The tuple `(cols, cols_noncommon, domains)`. Each element is a dict
        keyed by column name whose value is a list with one entry per height.
    """
    expanded = {}
    expanded_noncommon = {}
    domain_sizes = {}

    for attribute in attrs.values():
        for col_name, column in attribute.cols.items():
            level = column.lvl
            per_height, per_height_nc, per_height_dom = [], [], []

            for h in range(level.height):
                size = level.get_domain(h)
                per_height_dom.append(size)

                # Translate the stored values to the codes of this height and
                # store them in the dtype chosen for the domain size.
                codes = level.get_mapping(h)[table[col_name]].astype(get_type(size))
                per_height.append(codes)

                if attribute.common > 0:
                    # Codes up to and including `common` collapse to 0; the
                    # remaining ones are shifted down by `common`.
                    above_common = codes > attribute.common
                    per_height_nc.append(
                        np.where(above_common, codes - attribute.common, 0)
                    )

            expanded[col_name] = per_height
            expanded_noncommon[col_name] = per_height_nc
            domain_sizes[col_name] = per_height_dom

    return expanded, expanded_noncommon, domain_sizes

# Expand the loaded table once; the marginal calculations below reuse these
# per-column, per-height arrays and domain sizes.
cols, cols_noncommon, domains = expand_table(attrs, table)

In [None]:
from functools import reduce
from itertools import chain
from typing import NamedTuple

class AttrSelector(NamedTuple):
    """Selects columns of one attribute, each at a chosen height.

    Used by `calc_marginal`, which treats the first entry of `cols` differently
    from the rest: the first column is used with its full domain, every later
    column has its domain reduced by `common` (and is read from the
    "noncommon" arrays when `common` is non-zero).
    """
    # Amount subtracted from the domain of every column after the first one.
    common: int
    # Column name -> height index, used as `cols[name][height]` and
    # `domains[name][height]`.
    cols: dict[str, int]

# A list of selectors, e.g. the parents `p` passed to `calc_marginal`.
AttrSelectors = list[AttrSelector]

def _selector_domain(domains: dict[str, list[int]], sel: AttrSelector) -> int:
    """Returns the number of value combinations spanned by one selector.

    The first column of the selector contributes its full domain; every later
    column contributes its domain minus `sel.common`.
    """
    size = 1
    for i, (name, height) in enumerate(sel.cols.items()):
        size *= domains[name][height] - (sel.common if i > 0 else 0)
    return size


def calc_marginal(
    cols: dict[str, list[np.ndarray]],
    cols_noncommon: dict[str, list[np.ndarray]],
    domains: dict[str, list[int]],
    x: AttrSelector,
    p: AttrSelectors,
    zero_fill: float | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Calculates the joint marginal of `x` and its parents `p`, plus the two
    1-way marginals derived from it.

    Every row is encoded as a single mixed-radix integer (one digit per selected
    column), the codes are counted with `np.bincount` and the counts are
    reshaped into an `(x_dom, p_dom)` table.

    Args:
        cols: Column name -> list (one entry per height) of code arrays.
        cols_noncommon: Same layout as `cols`; read for every column after the
            first one of a selector whose `common` is non-zero.
        domains: Column name -> list (one entry per height) of domain sizes.
        x: Selector of the attribute whose marginal is computed.
        p: Selectors of the parent attributes; may be empty.
        zero_fill: If not None, added to every cell of the joint marginal after
            normalisation (the result then no longer sums to exactly 1).

    Returns:
        `(j_mar, x_mar, p_mar)`: the float32 joint marginal of shape
        `(x_dom, p_dom)`, its sum over the parent axis and its sum over the
        `x` axis.
    """
    xp = [x] + p

    # Domain sizes of the parents (combined) and of x. The domain of x must be
    # derived from x's own `common`; it was previously taken from the loop
    # variable left over from the parents loop, which is the last parent's
    # `common` (or undefined when `p` is empty).
    p_dom = 1
    for parent in p:
        p_dom *= _selector_domain(domains, parent)
    x_dom = _selector_domain(domains, x)

    # Find integer dtype based on domain
    dtype = get_type(p_dom * x_dom)

    n_rows = len(next(iter(cols.values()))[0])
    _sum_nd = np.zeros((n_rows,), dtype=dtype)
    _tmp_nd = np.zeros((n_rows,), dtype=dtype)

    # Parents are the low-order digits and x the high-order ones, so the flat
    # code is `x_code * p_dom + p_code`, matching the reshape below.
    mul = 1
    for attr in reversed(xp):
        common = attr.common
        for i, (name, height) in enumerate(attr.cols.items()):
            if i == 0 or common == 0:
                src = cols[name][height]
                radix = domains[name][height]
            else:
                src = cols_noncommon[name][height]
                radix = domains[name][height] - common

            # `out=` alone does not widen the computation: without `dtype=` the
            # product is evaluated in the (narrower) dtype of `src` and can wrap
            # around before it is stored in the wider buffer.
            np.multiply(src, mul, out=_tmp_nd, dtype=_tmp_nd.dtype)
            np.add(_sum_nd, _tmp_nd, out=_sum_nd)
            mul *= radix

    counts = np.bincount(_sum_nd, minlength=p_dom * x_dom)
    margin = counts.reshape(x_dom, p_dom).astype("float32")

    margin /= margin.sum()
    if zero_fill is not None:
        # Mutual info turns into NaN without this
        margin += zero_fill

    j_mar = margin
    x_mar = np.sum(margin, axis=1)
    p_mar = np.sum(margin, axis=0)

    return j_mar, x_mar, p_mar

# Benchmark input: `ethnicity` against four parent columns, all at height 0.
# Every selector here has a single column, so its `common` value does not
# influence the result (only columns after the first are reduced by `common`).
x = AttrSelector(1, {"ethnicity": 0})
p = [
    AttrSelector(1, {"hospital_expire_flag": 0}),
    AttrSelector(1, {"language": 0}),
    AttrSelector(1, {"insurance": 0}),
    AttrSelector(0, {"admission_type": 0})
]
# Time the bincount-based implementation.
%timeit calc_marginal(cols, cols_noncommon, domains, x, p)

1.32 ms ± 2.62 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [None]:
# Baseline for comparison: build the same joint histogram with
# `np.histogramdd`, using the full code arrays and domain sizes of the columns
# selected by `x` and `p`.
scols = []
sdoms = []
for attr in [x] + p:
    for n, h in attr.cols.items():
        scols.append(cols[n][h])
        sdoms.append(domains[n][h])

# NOTE: `scols` is rebound from a list of arrays to the stacked, transposed
# array (one row per observation, assuming each code array is 1-D).
scols = np.stack(scols).T
%timeit np.histogramdd(scols, bins=sdoms)

29.3 ms ± 42.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
