In [None]:
# Removes lint errors from VS Code
from typing import Dict, TYPE_CHECKING, Tuple, List

if TYPE_CHECKING:
    import kedro

    catalog: kedro.io.data_catalog.DataCatalog
    session: kedro.framework.session.session.KedroSession
    pipelines: Dict[str, kedro.pipeline.pipeline.Pipeline]


In [None]:
import os

VIEW = os.getenv("DATASET_VIEW") or "tab_adult"
TABLE = os.getenv("DATASET_TABLE") or "table"


def env_flag(name: str, default: bool) -> bool:
    """Read a boolean flag from the environment variable ``name``.

    Unset -> ``default``. "", "0", "false", "no", "off" (case-insensitive,
    surrounding whitespace ignored) -> False. Any other value -> True.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() not in {"", "0", "false", "no", "off"}


# Environment values are strings, so "0"/"false" must be parsed explicitly;
# using the raw non-empty string would always be truthy.
MULTI_PROCESS = env_flag("MULTI_PROCESS", True)

import numpy as np
import pandas as pd
from importlib import reload

from pasteur.transform import TableTransformer

# Load inputs from the Kedro data catalog. `catalog` is not defined in this
# notebook; it is only declared for type checkers in the first cell, so it is
# presumably injected by the Kedro notebook session.
# Keys are "<VIEW>.wrk.bhr_<TABLE>" (the table's "bhr" frame) and
# "<VIEW>.wrk.trn_<TABLE>" (its fitted TableTransformer, used to look up the
# attribute -> column grouping).
bhr: pd.DataFrame = catalog.load(f"{VIEW}.wrk.bhr_{TABLE}")
trn: TableTransformer = catalog.load(f"{VIEW}.wrk.trn_{TABLE}")
# Shared random seed from the project parameters.
random_state = catalog.load("params:random_state")


In [None]:
# %load_ext line_profiler
# %lprun -f calc_mutual_info calc_mutual_info(data, domain, x, p)

# Group the columns of `bhr` by attribute, as reported by the transformer, and
# translate each column name into its positional index so attributes can
# address columns of the numpy array below.
attr_str = trn.get_attributes("bhr", bhr)
cols = list(bhr.columns)

# Name -> position lookup built once (O(1) per column instead of a list.index
# scan per column). setdefault keeps the first position, matching list.index
# if a column name were duplicated.
col_pos = {}
for pos, col in enumerate(cols):
    col_pos.setdefault(col, pos)

attr = [[col_pos[col] for col in a_cols] for a_cols in attr_str.values()]

data = bhr.to_numpy(dtype="int16")
# Per-column domain size, assuming values are encoded as 0..max
# (NOTE(review): confirm the encoding). Kept as int64: in int16, `max + 1`
# and the products of domain sizes taken later would silently wrap around.
domain = data.max(axis=0).astype(np.int64) + 1

In [None]:
def calc_mutual_info(data: np.ndarray, domain: np.ndarray, x: list[int], p: list[int]):
    """Empirical mutual information I(X; P) between two groups of columns.

    Each column ``c`` of ``data`` is treated as a categorical variable taking
    values ``0 .. domain[c] - 1``. The columns in ``x`` form one joint variable
    and the columns in ``p`` the other.

    Args:
        data: 2-D integer array, one row per record, one column per variable.
        domain: number of categories of each column of ``data``.
        x: column indices of the first group (list or 1-D integer array).
        p: column indices of the second group (list or 1-D integer array).

    Returns:
        I(X; P) in nats. 0.0 when ``x`` or ``p`` is empty or ``data`` has no rows.

    Raises:
        ValueError: if a selected value lies outside ``[0, domain[c])``.
    """
    # Normalise to plain int lists so `x + p` concatenates even when numpy
    # index arrays are passed (array `+` would add element-wise).
    x = [int(c) for c in x]
    p = [int(c) for c in p]
    if not x or not p or data.shape[0] == 0:
        # No variables on one side (or no records): no dependence to measure.
        return 0.0

    cols = x + p
    dims = tuple(int(k) for k in domain[cols])

    # Joint contingency table where value v of each column lands in bin v of
    # its axis: flatten every row's value tuple to one cell index and count.
    # (np.histogramdd with integer bins would instead infer each axis range
    # from the data min/max and bin floats, which is slower.)
    cell = np.ravel_multi_index(tuple(data[:, cols].T), dims)
    margin = (
        np.bincount(cell, minlength=int(np.prod(dims)))
        .reshape(dims)
        .astype(np.float64)
    )
    margin /= margin.sum()
    # Avoid log(0) on empty cells; their contribution stays negligible.
    margin += 1e-24

    x_axes = tuple(range(len(x)))
    p_axes = tuple(range(len(x), len(cols)))

    x_mar = margin.sum(axis=p_axes).reshape(-1)
    p_mar = margin.sum(axis=x_axes).reshape(-1)
    # x axes come first, so a C-order reshape gives rows = X states, cols = P states.
    j_mar = margin.reshape((x_mar.size, p_mar.size))

    return np.sum(j_mar * np.log(j_mar / np.outer(x_mar, p_mar)))

# Benchmark calc_mutual_info on one sample query: columns 11-14 of `data`
# against columns 22-25.
x = [11, 12, 13, 14]
p = [22, 23, 24, 25]
# Line-by-line profiling alternative (needs `%load_ext line_profiler`):
# %lprun -f calc_mutual_info calc_mutual_info(data, domain, x, p)
# NOTE(review): `data` is already created with dtype int16
# (`bhr.to_numpy(dtype="int16")`), so this astype only makes a same-dtype copy.
dt = data.astype(dtype="int16")
%timeit calc_mutual_info(dt, domain, x, p)
# Single un-timed call, for inspecting the value:
# calc_mutual_info(data, domain, x, p)

4.74 ms ± 79.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
from functools import reduce

import math

n = len(data)
d = len(attr)
e1 = 0.3
e2 = 0.7


def height(a: int) -> int:
    """Number of levels (columns) in the hierarchy of attribute ``a``."""
    return len(attr[a])


def dom(a: int, h: int) -> int:
    """Domain size of attribute ``a`` at height ``h``.

    The product of the domain sizes of the first ``height(a) - h`` columns of
    the attribute, so ``h = 0`` uses every column (the largest domain) and
    ``h = height(a)`` gives 1. Computed with Python ints: multiplying NumPy
    small-integer scalars from ``domain`` would silently wrap around.
    """
    return math.prod(int(domain[c]) for c in attr[a][: height(a) - h])


def pick_from(V):
    """Return an arbitrary element of the non-empty collection ``V``.

    Python sets are NOT insertion-ordered (only dicts are). For sets of ints or
    tuples of ints the iteration order is still reproducible across runs,
    because their hashes are not randomized; that does not hold for e.g. strings.
    """
    return next(iter(V))


# Parent sets ("psets") are tuples with one entry per attribute: the height at
# which that attribute is included, or -1 if it is not in the set.
def add_to_pset(z, a, h):
    """Return a copy of pset ``z`` with attribute ``a`` included at height ``h``."""
    return tuple(h if i == a else c for i, c in enumerate(z))


empty_pset = (-1,) * len(attr)

def maximal_parent_sets(V: set[int], t: float) -> set[tuple[int, ...]]:
    """Enumerate maximal parent sets drawn from ``V`` whose domain fits in ``t``.

    Appears to follow PrivBayes' recursive MaximalParentSets, extended to
    hierarchical attributes: an attribute may join a parent set at any height
    ``h`` (see ``dom``), and each smaller parent set is extended only once, at
    the first height that fits the budget.

    Args:
        V: candidate parent attributes (indices into ``attr``).
        t: budget on the product of the parents' domain sizes.

    Returns:
        A set of psets (tuples in the ``empty_pset`` layout): empty when
        ``t < 1``, ``{empty_pset}`` when ``V`` is empty.
    """
    if t < 1:
        return set()
    if not V:
        return {empty_pset}

    S = set()
    U = set()  # sub-psets that have already been extended with x
    x = pick_from(V)
    rest = V - {x}

    # h = 0 is x's largest domain; a sub-pset z is claimed by the first height
    # at which it fits and skipped at the later (smaller-domain) heights.
    for h in range(height(x)):
        for z in maximal_parent_sets(rest, t / dom(x, h)):
            if z not in U:
                U.add(z)
                S.add(add_to_pset(z, x, h))

    # A pset without x is only maximal if it could not be extended by x.
    for z in maximal_parent_sets(rest, t):
        if z not in U:
            S.add(z)

    return S

In [None]:
def greedy_bayes(theta):
    """Greedily build a Bayesian network over the attributes (PrivBayes-style).

    Starting from an arbitrary attribute, each round adds one not-yet-chosen
    attribute ``x`` together with a parent set drawn from the attributes
    already chosen. Candidates are the maximal parent sets whose domain fits
    ``n * e2 / (2 * d * theta) / dom(x, 0)``, or the empty pset if none fit.

    Args:
        theta: presumably PrivBayes' theta-usefulness parameter; larger values
            shrink the parent-set budget.

    Returns:
        List of ``(attribute, pset)`` pairs in the order attributes were
        added, so every pair's parents appear earlier in the list.
    """
    A = set(range(len(attr)))
    if not A:
        return []

    x1 = pick_from(A)
    # Parent-set budget. Named separately so the loops below cannot clobber it.
    tau = (n * e2) / (2 * d * theta)

    V = {x1}
    # A list (not a set) keeps the order in which the network was built.
    N = [(x1, empty_pset)]

    for _ in range(1, d):
        # Candidate (attribute, pset) pairs for this round.
        O = set()
        for x in A - V:
            Tx = maximal_parent_sets(V, tau / dom(x, 0))
            if Tx:
                O.update((x, z) for z in Tx)
            else:
                O.add((x, empty_pset))

        # FIXME: should select the best-scoring candidate (e.g. by mutual
        # information), not an arbitrary one.
        chosen = pick_from(O)
        V.add(chosen[0])
        N.append(chosen)

    return N