In [None]:
# NOTE(review): the star import hides which names this cell brings into scope
# (e.g. `catalog`, used in a later cell, is never defined explicitly in the
# notebook) — confirm where it comes from and prefer explicit imports.
from pasteur.kedro.ipython import *

# Presumably wires the Kedro project into this IPython session — verify.
register_kedro()

In [None]:
# Reload edited modules automatically so code changes are picked up live.
%load_ext autoreload
%autoreload 2
# Provides the %lprun magic used by the profiling cell further down.
%load_ext line_profiler

In [None]:
import pandas as pd
import numpy as np


In [None]:
from pasteur.metadata import Metadata
from pasteur.transform import TableTransformer, Attributes, get_dtype

# Name of the dataset view; used as the key prefix for both catalog entries below.
# NOTE(review): `catalog` is not defined in any visible cell — presumably it is
# provided by the kedro bootstrap cell at the top; confirm.
view = "mimic_tab_admissions"
# Table transformer stored for this view.
trn: TableTransformer = catalog.load(f"{view}.trn.table")
# Table stored under the "wrk.idx_table" key — presumably the index-encoded
# form produced by the transformer; TODO confirm.
table: pd.DataFrame = catalog.load(f"{view}.wrk.idx_table")


In [None]:
# sensitive: previews raw rows of `table`; keep this cell's output cleared
# whenever the notebook is shared or committed
table.head()


In [None]:
# Attribute definitions exposed by the "idx" entry of the table transformer.
# Later cells iterate `attrs.values()` and read each attribute's `cols` and `common`.
attrs = trn["idx"].get_attributes()


In [None]:
# from pasteur.synth.math import expand_table, calc_marginal, calc_marginal_1way, AttrSelector, AttrSelectors
# cols, cols_noncommon, domains = expand_table(attrs, table)
from functools import reduce


In [None]:
from typing import Callable

# Flatten the attribute -> column hierarchy into parallel lists indexed by
# column position, so the parent-set search can address columns by integer.
col_idx = {}  # column name -> index of the owning attribute
# NOTE(review): col_idx stores the attribute index (same value as `groups`),
# not the column's flat position — confirm this is intended.
groups = []  # per column: index of the attribute it belongs to
heights = []  # per column: number of hierarchy levels
common = []  # per column: `common` value of the owning attribute
domains = []  # per column: domains[x][h] is the domain size at level h
for group, attr in enumerate(attrs.values()):
    for name, col in attr.cols.items():
        # Use a single height for both lists: the search indexes
        # domains[x][h] for every h in range(heights[x]), so the two must
        # always have matching lengths.
        height = col.lvl.height
        col_idx[name] = group
        groups.append(group)
        heights.append(height)
        common.append(attr.common)
        domains.append([col.get_domain(h) for h in range(height)])

# Number of columns; every parent set is a tuple with one slot per column.
n = len(col_idx)
# Parent set that selects no column at all (-1 means "column not used").
empty_pset = tuple(-1 for _ in range(n))


def add_to_pset(s, x, h):
    """Return a copy of parent set ``s`` with slot ``x`` set to ``h``.

    Args:
        s: Iterable of per-column entries describing a parent set.
        x: Position of the slot to overwrite.
        h: Value to store in that slot.

    Returns:
        A new tuple; the input ``s`` is left unmodified.
    """
    updated = [*s]
    updated[x] = h
    return tuple(updated)


In [None]:
def maximal_parents(
    V: tuple[int, ...], tau: float, _pgroups=frozenset()
) -> list[tuple[int, ...]]:
    """Enumerate parent sets over the columns in ``V`` within the budget ``tau``.

    Each column ``x`` may be left out or included at one of its
    ``heights[x]`` levels. Including it at level ``h`` divides the remaining
    budget by ``domains[x][h]``. When another column of the same group is
    already part of the set, ``common[x]`` is subtracted from that domain
    first.

    Args:
        V: Indices of the columns that may still be added.
        tau: Remaining domain budget; a value below 1 yields no parent set.
        _pgroups: Internal. Groups that already have a column in the set
            being built. Callers normally omit it.

    Returns:
        A list of tuples shaped like ``empty_pset``: slot ``x`` holds the
        chosen level of column ``x``, or -1 when the column is not used.
    """
    if tau < 1:
        return []
    if not V:
        return [empty_pset]
    x, rest = V[0], V[1:]

    S = []
    U = set()

    # The group of x does not depend on the level, so resolve it once.
    g = groups[x]
    group_already_used = g in _pgroups
    # Union builds a new set, so neither the caller's set nor the shared
    # default is ever mutated.
    pgroups_with_x = _pgroups if group_already_used else _pgroups | {g}

    for h in range(heights[x]):
        dom = domains[x][h]
        if group_already_used:
            # NOTE(review): if a level's domain equals common[x] this becomes
            # 0 and the division below raises ZeroDivisionError — confirm
            # that cannot happen for the loaded attributes.
            dom -= common[x]

        for z in maximal_parents(rest, tau / dom, pgroups_with_x):
            if z not in U:
                U.add(z)
                S.append(add_to_pset(z, x, h))

    # Sets without x are kept only when they were not already extended with x.
    for z in maximal_parents(rest, tau, _pgroups):
        if z not in U:
            S.append(z)

    return S


In [None]:
# Enumerate the parent sets over all n columns with tau = 100000000;
# the cell displays how many were found.
psets = maximal_parents(tuple(range(n)), 100000000)
len(psets)


[1;36m400866[0m


In [None]:
# Time the search; note this uses a much smaller tau (10000) than the enumeration cell.
%timeit maximal_parents(tuple(range(n)), 10000)

1.11 s ± 7.88 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# Line-by-line profile of the same call (needs the line_profiler extension loaded above).
%lprun -f maximal_parents maximal_parents(tuple(range(n)), 10000)

Timer unit: 1e-06 s

Total time: 8.71678 s
File: /tmp/ipykernel_2390497/3433675883.py
Function: maximal_parents at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def maximal_parents(V: tuple[int], tau: float, _pgroups=set()) -> list[tuple[int]]:
     2   1246985     552431.0      0.4      6.3      if tau < 1:
     3    307524     119424.0      0.4      1.4          return []
     4    939461     366322.0      0.4      4.2      if not V:
     5    324176     133626.0      0.4      1.5          return [empty_pset]
     6    615285     259700.0      0.4      3.0      x = V[0]
     7    615285     289405.0      0.5      3.3      V = V[1:]
     8                                           
     9    615285     238856.0      0.4      2.7      S = []
    10    615285     275013.0      0.4      3.2      U = set()
    11                                           
    12   1246984     624042.0      0.5      7.2      for h in

In [None]:
from pasteur.synth.privbayes import greedy_bayes, PrivBayesSynth

In [None]:
# NOTE(review): e2=10000 is a hard-coded magic number — presumably a privacy
# budget parameter of PrivBayesSynth; confirm its meaning and intended value.
s = PrivBayesSynth(e2=10000)
# Single-table setup: attributes and data are both keyed by "table"; no ids are passed.
s.bake(attrs={"table": attrs}, data={"table": table}, ids=None)

>>>>>>>  Finding Nodes: 100%|[34m##########################################################################[0m| 23/23 [00:00<00:00,  1.00it/s][0m
[0m[34mINFO    [0m Bayesian Network Tree:                                                                                        [2mprivbayes.py[0m[2m:[0m[2m559[0m
[0m         [1m([0mPrivBayes [33me1[0m=[1;36m0[0m[1;36m.30[0m, [33me2[0m=[1;36m10000[0m[1;36m.00[0m, [33mtheta[0m=[1;36m4[0m[1;36m.00[0m, available [33mt[0m=[1;36m41947[0m[1;36m.20[0m[1m)[0m                                            [2m                [0m
[0m         ┌─────────────────────┬─────┬──────────┬─────────────────────────────────────────────────────────┐            [2m                [0m
[0m         │           Attribute │ Dom │ Avail. t │ Parents                                                 │            [2m                [0m
[0m         ├─────────────────────┼─────┼──────────┼─────────────────────────────────────────

In [None]:
# Fit on the same single-table data that was passed to bake().
s.fit(data={"table": table}, ids=None)



In [None]:
# For every network node, print its column name, the shape of its marginal,
# and the total number of cells in that marginal.
node_names = [node.col for node in s.nodes]
for col_name, marginal in zip(node_names, s.marginals):
    cell_count = marginal.shape[0] * marginal.shape[1]
    print(col_name, marginal.shape, cell_count)

admittime_year (128, 1) 128
insurance (3, 128) 384
marital_status (5, 384) 1920
admission_type (9, 1920) 17280
admission_location (12, 135) 1620
discharge_location (14, 1620) 22680
dischtime_time (48, 210) 10080
admittime_time (48, 210) 10080
admittime_week (52, 210) 10920
dischtime_week (52, 468) 24336
dischtime_day (7, 2704) 18928
admittime_day (7, 4410) 30870
ethnicity (8, 1620) 12960
gender (2, 17280) 34560
language (2, 18720) 37440
dod_year (129, 256) 33024
dod_day (8, 4032) 32256
dod_week (53, 520) 27560
hospital_expire_flag (2, 26460) 52920
deathtime_time (49, 784) 38416
deathtime_day (8, 6720) 53760
deathtime_year (129, 208) 26832
deathtime_week (53, 1344) 71232
dischtime_year (128, 312) 39936


In [None]:
# Display the last entry of s.nodes to inspect its selected parents.
s.nodes[-1]


[1;35mNode[0m[1m([0m
    [33mattr[0m=[32m'dischtime'[0m,
    [33mcol[0m=[32m'dischtime_year'[0m,
    [33mdomain[0m=[1;36m128[0m,
    [33mp[0m=[1m{[0m
        [32m'dod'[0m: [1;35mAttrSelector[0m[1m([0m[33mcommon[0m=[1;36m1[0m, [33mcols[0m=[1m{[0m[32m'dod_year'[0m: [1;36m1[0m, [32m'dod_week'[0m: [1;36m1[0m, [32m'dod_day'[0m: [1;36m1[0m[1m}[0m[1m)[0m,
        [32m'marital_status'[0m: [1;35mAttrSelector[0m[1m([0m[33mcommon[0m=[1;36m1[0m, [33mcols[0m=[1m{[0m[32m'marital_status'[0m: [1;36m0[0m[1m}[0m[1m)[0m
    [1m}[0m
[1m)[0m


In [None]:
# Size reported by the level object of the marital_status column —
# presumably its number of distinct values; TODO confirm.
attrs["marital_status"].cols["marital_status"].lvl.size

[1;36m5[0m
