In [None]:
from pasteur.kedro.ipython import *
# Hook the notebook up to the Kedro project. `catalog` (used below) and the
# `%pipe` magic are not defined anywhere in this notebook; presumably they come
# from this star import / registration — confirm in pasteur.kedro.ipython.
register_kedro()

In [None]:
# sensitive
# Run the `mimic_tab_admissions.ingest` pipeline through the `%pipe` magic
# (magic presumably registered by `register_kedro()` — confirm).
%pipe mimic_tab_admissions.ingest

In [None]:
# Re-import edited project modules before each cell executes, and enable the
# `%lprun` line profiler used by the profiling cells further down.
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [None]:
import pandas as pd
import numpy as np

In [None]:
from pasteur.metadata import Metadata
from pasteur.transform import TableTransformer, Attributes, get_dtype

# Dataset view examined throughout this notebook.
view = "mimic_tab_admissions"

# Fetch the view's table transformer and its working `idx_table` from the
# project catalog (keys are "<view>.trn.table" and "<view>.wrk.idx_table").
trn: TableTransformer = catalog.load(view + ".trn.table")
table: pd.DataFrame = catalog.load(view + ".wrk.idx_table")

In [None]:
# sensitive
# Preview the first rows of the working table. The cell is tagged sensitive:
# clear this output before sharing or committing the notebook.
table.head()

In [None]:
# Attribute definitions of the "idx" transformer; consumed by `expand_table` below.
attrs = trn["idx"].get_attributes()

In [None]:
from pasteur.synth.math import expand_table, calc_marginal, calc_marginal_1way, AttrSelector, AttrSelectors

In [None]:
# Inputs for the `calc_marginal_*` functions: column arrays and domain sizes,
# each looked up as `[name][h]` by the selectors (see `calc_marginal_1way2`).
cols, cols_noncommon, domains = expand_table(attrs, table)

In [None]:
# Synthetic benchmark input: six independent, identical selectors on
# `dod_week` (column index 0; first positional arg presumably `common` = 0).
# Each iteration builds a fresh AttrSelector and a fresh dict.
x = [AttrSelector(0, {"dod_week": 0}) for _ in range(6)]
# Baseline timing of the library implementation on the selectors above.
%timeit calc_marginal_1way(cols, cols_noncommon, domains, x)

In [None]:
# Line-by-line profile of the library implementation to locate its hot spots.
%lprun -f calc_marginal_1way calc_marginal_1way(cols, cols_noncommon, domains, x)

In [None]:
# Column names of the working table (candidate attribute columns for selectors).
table.columns


Index(['dod_year', 'dod_week', 'dod_day', 'admittime_year', 'admittime_week',
       'admittime_day', 'admittime_time', 'dischtime_year', 'dischtime_week',
       'dischtime_day', 'dischtime_time', 'deathtime_year', 'deathtime_week',
       'deathtime_day', 'deathtime_time', 'admission_type',
       'admission_location', 'discharge_location', 'insurance', 'language',
       'marital_status', 'ethnicity', 'hospital_expire_flag', 'gender'],
      dtype='object')


In [None]:
import numexpr as ne

def calc_marginal_1way2(
    cols: dict[str, list[np.ndarray]],
    cols_noncommon: dict[str, list[np.ndarray]],
    domains: dict[str, list[int]],
    x: AttrSelectors,
    zero_fill: float | None = None,
):
    """Calculates the 1-way marginal of the subsections of attributes ``x``.

    numexpr-based variant of ``calc_marginal_1way``. Every selected column is
    folded into one flat mixed-radix index per row (the last selector in ``x``
    is the least-significant digit). The indices are counted with
    ``np.bincount`` and the counts are normalized to sum to 1.

    Args:
        cols: maps a column name to a list of encoded value arrays, indexed by
            the selector value ``h``. Assumed to be equal-length 1-D integer
            arrays (TODO confirm against ``expand_table``).
        cols_noncommon: same layout as ``cols``. It is read instead of ``cols``
            for every column after the first of a selector whose ``common`` is
            non-zero.
        domains: maps the same names to a list of domain sizes indexed by ``h``.
        x: attribute selectors, each exposing ``cols`` (name -> h) and ``common``.
        zero_fill: if given, added to every cell after normalization, so the
            result then no longer sums to exactly 1.

    Returns:
        1-D float32 array with one entry per combination of selected values.

    Raises:
        ValueError: if ``cols`` is empty (the row count cannot be inferred).
    """
    # Size of the joint domain. Columns after a selector's first contribute
    # ``domain - common`` values, matching the radix used when indexing below.
    dom = 1
    for attr in x:
        for i, (name, h) in enumerate(attr.cols.items()):
            dom *= domains[name][h] - (attr.common if i > 0 else 0)

    if not cols:
        raise ValueError("cols must contain at least one column to infer the row count")
    n_rows = len(next(iter(cols.values()))[0])

    # The largest flat index is dom - 1. numexpr only handles signed int32/int64
    # (no unsigned types), so stay on int32 while it fits and switch to int64
    # rather than silently overflowing on large joint domains.
    acc_dtype = np.int32 if dom - 1 <= np.iinfo(np.int32).max else np.int64
    _sum_nd = np.zeros((n_rows,), dtype=acc_dtype)

    mul = 1
    for attr in reversed(x):
        common = attr.common
        for i, (name, h) in enumerate(attr.cols.items()):
            use_noncommon = i > 0 and common != 0
            col = cols_noncommon[name][h] if use_noncommon else cols[name][h]
            # numexpr resolves `_sum_nd`, `col` and `mul` from this frame's locals.
            ne.evaluate("_sum_nd + col*mul", out=_sum_nd)
            mul *= domains[name][h] - (common if i > 0 else 0)

    counts = np.bincount(_sum_nd, minlength=dom)
    margin = counts.astype("float32")
    # With zero rows the sum is 0 and this division yields NaN.
    margin /= margin.sum()
    if zero_fill is not None:
        # Mutual info turns into NaN without this
        margin += zero_fill

    return margin.reshape(-1)

In [None]:
# Time the numexpr variant for comparison with the library `calc_marginal_1way`
# timing above. For a line-level profile of this variant, run:
#   %lprun -f calc_marginal_1way2 calc_marginal_1way2(cols, cols_noncommon, domains, x)
%timeit calc_marginal_1way2(cols, cols_noncommon, domains, x)