In [None]:
# Attach this notebook to the Kedro project. The wildcard import provides
# `register_kedro`; the `catalog` object used in later cells is presumably made
# available by this setup — TODO confirm and replace `*` with explicit names.
from pasteur.kedro.ipython import *
register_kedro()

In [None]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [None]:
import pandas as pd
import numpy as np

In [None]:
from pasteur.metadata import Metadata
from pasteur.transform import TableTransformer, Attributes, get_dtype

# Dataset view under study; both catalog entries below are keyed off this name.
view = "mimic_tab_admissions"
# Table transformer for the view; its "idx" entry supplies the attributes used below.
trn: TableTransformer = catalog.load(f"{view}.trn.table")
# Indexed working table for the view (presumably row-level data — treat as sensitive).
table: pd.DataFrame = catalog.load(f"{view}.wrk.idx_table")

In [None]:
# sensitive: `table` presumably holds row-level MIMIC admissions data —
# clear this cell's output before sharing or committing the notebook.
table.head()

In [None]:
# Attribute definitions from the transformer's "idx" entry; each value exposes
# `.cols` (column name -> column) and `.common`, flattened per column next.
attrs = trn["idx"].get_attributes()

In [None]:
# from pasteur.synth.math import expand_table, calc_marginal, calc_marginal_1way, AttrSelector, AttrSelectors
# cols, cols_noncommon, domains = expand_table(attrs, table)
from functools import reduce

In [None]:
from typing import Callable

# Flatten every attribute's columns into parallel per-column lookup lists.
# A column's position in these lists is the index used by `Pset` below.
# NOTE(review): `heights` reads `col.lvl.height` while `domains` iterates
# `range(col.height)`; presumably these agree — confirm.
_flat_cols = [
    (attr_no, attr, col_name, col)
    for attr_no, attr in enumerate(attrs.values())
    for col_name, col in attr.cols.items()
]

# column name -> index of the attribute that owns it
col_idx = {col_name: attr_no for attr_no, _, col_name, _ in _flat_cols}
# per column: owning attribute index (used as the column's "group")
groups = [attr_no for attr_no, _, _, _ in _flat_cols]
heights = [col.lvl.height for _, _, _, col in _flat_cols]
common = [attr.common for _, attr, _, _ in _flat_cols]
# per column: one `get_domain(h)` value per height h
domains = [
    [col.get_domain(h) for h in range(col.height)]
    for _, _, _, col in _flat_cols
]
del _flat_cols

n = len(col_idx)

In [None]:
class Pset(tuple[int, ...]):
    """Parent-set encoding over a fixed, ordered list of columns.

    Position ``i`` stands for the ``i``-th column. Its value is ``-1`` when the
    column is not included and the column's chosen height when it is.

    Subclassing ``tuple`` makes instances immutable and hashable, so they can be
    stored in sets for de-duplication. Every "modifying" method returns a new
    ``Pset`` and leaves ``self`` untouched.
    """

    # Without an empty __slots__, every tuple-subclass instance carries a
    # __dict__. Psets are created in large numbers by the search, so skip it.
    __slots__ = ()

    def _set(self, i: int, v: int) -> "Pset":
        """Return a copy with position ``i`` set to ``v``.

        An out-of-range ``i`` (including negative values) yields an unchanged copy.
        """
        if 0 <= i < len(self):
            # Slicing a tuple subclass returns plain tuples; re-wrap as Pset.
            return Pset(self[:i] + (v,) + self[i + 1 :])
        return Pset(self)

    def pop(self, i: int) -> "Pset":
        """Return a copy with column ``i`` excluded (set to -1)."""
        return self._set(i, -1)

    def pop_first(self):
        """Exclude the first included column.

        Returns ``(i, rest)``, where ``i`` is the index of the first value that
        is not -1 and ``rest`` is a copy with that column excluded. If no column
        is included, returns ``(None, Pset.empty(len(self)))``.
        """
        for i, v in enumerate(self):
            if v != -1:
                return i, self._set(i, -1)
        return None, self.empty(len(self))

    def add(self, i: int, h: int = 0) -> "Pset":
        """Return a copy with column ``i`` included at height ``h``."""
        return self._set(i, h)

    def contains(self, i: int) -> bool:
        """Whether column ``i`` is included. Raises IndexError if out of range."""
        return self[i] != -1

    @staticmethod
    def empty(n: int) -> "Pset":
        """A Pset of ``n`` columns with none included."""
        return Pset((-1,) * n)

    @staticmethod
    def all(n: int) -> "Pset":
        """A Pset of ``n`` columns, all included at height 0."""
        return Pset((0,) * n)

In [None]:
def find_tau(P: Pset, tau: float):
    """Divide ``tau`` by the domain of every column included in ``P``.

    For each included column ``i`` at height ``h``, the divisor is
    ``domains[i][h]``. When a column from the same group was already counted,
    ``common[i]`` is subtracted from the divisor first. The remaining budget is
    returned; ``maximal_parents`` treats a result below 1 as "over budget".

    Reads the module-level ``domains``, ``groups`` and ``common`` lists.
    """
    seen = set()
    remaining = tau
    for col, height in enumerate(P):
        if height == -1:
            continue  # column not part of the set
        divisor = domains[col][height]
        group = groups[col]
        if group in seen:
            # Only the group's first column contributes its full domain;
            # presumably `common` counts values shared within a group — confirm.
            divisor -= common[col]
        else:
            seen.add(group)
        remaining /= divisor
    return remaining

def maximal_parents(V: Pset, P: Pset, tau: float) -> list[Pset]:
    """Enumerate the maximal parent sets drawn from ``V`` that fit budget ``tau``.

    This is a recursive enumeration following PrivBayes' ``MaximalParents``
    procedure, extended so that each column may be included at any of its
    heights (``heights[x]``).

    Args:
        V: candidate columns still to be decided (a value other than -1 marks
            a candidate).
        P: columns already chosen on the current recursion path, with heights.
        tau: budget passed to ``find_tau``. A path is feasible while
            ``find_tau(P, tau) >= 1``.

    Returns:
        A list of ``Pset``s over the columns of ``V`` only; callers add the
        columns of ``P`` back. An empty list means ``P`` is already over budget,
        so no set extending it is feasible.
    """
    # Check the budget before doing any other work. If the path chosen so far
    # is already infeasible, return *no* sets. Returning ``[empty]`` here would
    # make the caller emit ``{x@h}`` even though adding ``x`` at height ``h`` is
    # exactly what broke the budget.
    if find_tau(P, tau) < 1:
        return []

    x, V = V.pop_first()
    if x is None:
        # No candidates left: the only (feasible) extension is the empty set.
        return [Pset.empty(len(V))]

    S = []
    U = set()

    # Sets that include x, trying each of x's heights in order. U is shared
    # across heights, so a set produced at an earlier height is not repeated.
    for h in range(heights[x]):
        for z in maximal_parents(V, P.add(x, h), tau):
            if z not in U:
                U.add(z)
                S.append(z.add(x, h))

    # Sets that exclude x, unless the same set already appeared alongside x
    # (then it is not maximal).
    for z in maximal_parents(V, P, tau):
        if z not in U:
            S.append(z)

    return S

In [None]:
# Candidates: every column except column 0 (presumably the attribute whose
# parents are being searched — confirm). 1000 is the `tau` budget.
candidates = Pset.all(n).pop(0)
psets = maximal_parents(candidates, Pset.empty(n), 1000)
len(psets)

[1;36m3470[0m


In [None]:
%lprun -f maximal_parents maximal_parents(Pset.all(n).pop_first()[1], Pset.empty(n), 1000)

Timer unit: 1e-06 s

Total time: 11.9732 s
File: /tmp/ipykernel_2329787/1750792131.py
Function: maximal_parents at line 15

Line #      Hits         Time  Per Hit   % Time  Line Contents
    15                                           def maximal_parents(V: Pset, P: Pset, tau: float) -> list[Pset]:
    16    150159    3597120.0     24.0     30.0      x, V = V.pop_first()
    17                                           
    18    150159    2447219.0     16.3     20.4      if find_tau(P, tau) < 1 or x is None:
    19     76547    1105020.0     14.4      9.2          return [Pset.empty(n)]
    20                                           
    21     73612      27624.0      0.4      0.2      S = []
    22     73612      32304.0      0.4      0.3      U = set()
    23                                           
    24    150158      72663.0      0.5      0.6      for h in range(heights[x]):
    25    230973    1431580.0      6.2     12.0          for z in maximal_parents(V, P.add(x, h), ta