In [None]:
# Removes lint errors from VS Code
from typing import Dict, TYPE_CHECKING, Tuple, List

if TYPE_CHECKING:
    # Type-checker-only declarations: TYPE_CHECKING is False at runtime, so nothing
    # in this branch is imported or bound when the notebook actually executes.
    import kedro

    # Annotation-only statements (no assignment).
    # NOTE(review): `catalog` is used in later cells without being defined in this
    # notebook — presumably injected by the Kedro notebook environment; confirm.
    catalog: kedro.io.data_catalog.DataCatalog
    session: kedro.framework.session.session.KedroSession
    pipelines: Dict[str, kedro.pipeline.pipeline.Pipeline]


In [None]:
import os

def env_flag(name: str, default: bool) -> bool:
    """Read a boolean switch from the environment.

    Returns `default` when the variable is unset. Otherwise the value is
    interpreted case-insensitively: an empty string, "0", "false", "no" and
    "off" mean False, anything else means True.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() not in ("", "0", "false", "no", "off")


# Dataset view / table to work on; an unset or empty variable falls back to the default.
VIEW = os.getenv("DATASET_VIEW") or "tab_adult"
TABLE = os.getenv("DATASET_TABLE") or "table"
# Parsed to a real bool: the raw environment string would make values such as
# "0" or "false" truthy.
MULTI_PROCESS = env_flag("MULTI_PROCESS", True)

import numpy as np
import pandas as pd
from importlib import reload

from pasteur.transform import TableTransformer

# Load the working artifacts for the selected view/table from the catalog.
# NOTE(review): `catalog` is not defined in this notebook — presumably provided by
# the Kedro notebook environment; confirm before running on a plain kernel.
bhr: pd.DataFrame = catalog.load(f"{VIEW}.wrk.bhr_{TABLE}")
trn: TableTransformer = catalog.load(f"{VIEW}.wrk.trn_{TABLE}")
# Not referenced by the cells visible here (they seed numpy with a literal instead).
random_state = catalog.load("params:random_state")


In [None]:
# %load_ext line_profiler
# %lprun -f calc_mutual_info calc_mutual_info(data, domain, x, p)

# Attribute name -> column names, as reported by the transformer for `bhr`.
attr_str = trn.get_attributes("bhr", bhr)
cols = list(bhr.columns)

# Same grouping expressed as positional column indices, one list per attribute.
attr = [[cols.index(name) for name in names] for names in attr_str.values()]

# int16 matrix of the table and, per column, its largest value plus one.
data = bhr.to_numpy(dtype="int16")
domain = data.max(axis=0) + 1

In [None]:
def sens_mutual_info(n: float):
    """Return the log2 sensitivity of the mutual information function for a
    dataset with `n` rows."""
    first_term = 2 / n * np.log2((n + 1) / 2)
    second_term = (n - 1) / n * np.log2((n + 1) / (n - 1))
    return first_term + second_term


def calc_mutual_info(data: np.ndarray, domain: np.ndarray, x: list[int], p: list[int]):
    """Compute the mutual information I(X, P), in bits, between two column groups.

    `x` and `p` are lists of column indices into `data`; `domain` supplies the
    number of histogram bins for each column. The joint histogram is normalised
    and a tiny constant is added to every cell so the logarithm never sees zero.
    """
    columns = x + p
    joint, _ = np.histogramdd(data[:, columns], domain[columns])
    joint /= joint.sum()
    joint += 1e-24

    # The leading axes belong to x, the trailing ones to p.
    x_axes = tuple(range(len(x)))
    p_axes = tuple(range(-len(p), 0))

    x_mar = joint.sum(axis=p_axes).reshape(-1)
    p_mar = joint.sum(axis=x_axes).reshape(-1)
    j_mar = joint.reshape((len(x_mar), len(p_mar)))

    independent = np.outer(x_mar, p_mar)
    return np.sum(j_mar * np.log2(j_mar / independent))


# Sample column-index lists, passed as the `x` and `p` arguments of
# calc_mutual_info by the commented-out profiling/timing calls below.
x = [11, 12, 13, 14]
p = [22, 23, 24, 25]
# %lprun -f calc_mutual_info calc_mutual_info(data, domain, x, p)
# %timeit calc_mutual_info(data, domain, x, p)
# calc_mutual_info(data, domain, x, p)


In [None]:
from functools import reduce

# Number of rows in the encoded table.
n = len(data)
# Number of attributes (each attribute may span several columns).
d = len(attr)
# e1 scales the selection in exponential_pick_candidate and e2 enters the threshold
# computed in greedy_bayes; both are read as globals at call time and are rebound
# in a later cell before the tree is built.
# NOTE(review): presumably the two privacy-budget shares — confirm.
e1 = 0.3
e2 = 0.7

def height(a):
    """Return the height of hierarchical attribute `a`: the number of columns it spans."""
    return len(attr[a])


def dom(a, h):
    """Return the domain size of hierarchical attribute `a` at height `h` (h=0 is max).

    The size is the product of the per-column domains of the first
    `height(a) - h` columns of the attribute.
    """
    size = 1
    for c in attr[a][: height(a) - h]:
        # `domain` is derived from an int16 array; multiplying its scalars directly
        # wraps around once the product exceeds 32767, so use Python ints.
        size *= int(domain[c])
    return size


def pick_from(V):
    """Return the first item produced by iterating over `V`."""
    # Sets are not insertion-ordered. The pick is still reproducible for the ints and
    # int tuples used in this notebook because their hashes are not randomized.
    return next(iter(V))

# A parent set is a tuple with one slot per attribute: the slot holds the height at
# which that attribute takes part in the set, or -1 when the attribute is absent.
def add_to_pset(z, a, h):
    """Return a copy of parent set `z` in which the slot of attribute `a` is `h`."""
    updated = []
    for i, c in enumerate(z):
        updated.append(h if i == a else c)
    return tuple(updated)


# Parent set containing no attribute at all.
empty_pset = (-1,) * len(attr)

def maximal_parent_sets(V: set[int], t: float) -> set[tuple[int, ...]]:
    """Enumerate parent sets over the attributes in `V` within the domain budget `t`.

    Each result is a tuple with one slot per attribute (height, or -1 if absent).
    Returns an empty set when `t < 1`, and the set holding only the empty parent
    set when `V` is empty.
    """
    if t < 1:
        return set()
    if not V:
        return {empty_pset}

    x = pick_from(V)
    # The remaining attributes are the same for every recursive call below.
    rest = V - {x}

    S = set()
    U = set()  # sets over `rest` that have already been extended with x
    for h in range(height(x)):
        for z in maximal_parent_sets(rest, t / dom(x, h)):
            if z in U:
                continue

            U.add(z)
            S.add(add_to_pset(z, x, h))

    # Sets that could not take x at any height are kept without it.
    for z in maximal_parent_sets(rest, t):
        if z not in U:
            S.add(z)

    return S

# Example run over every attribute with a threshold of 10000; the cell shows how
# many parent sets come back.
ex_psets = maximal_parent_sets(set(range(len(attr))), 10000)
len(ex_psets)

In [None]:
def greedy_bayes(theta, cb: callable=pick_from):
    """Greedily build a set of (attribute, parent set) nodes covering every attribute.

    Args:
        theta: enters the parent-set threshold `t = (n * e2) / (2 * d * theta)`;
            `n`, `d` and `e2` are read from module globals at call time.
        cb: receives the set of candidate (attribute, parent set) pairs of each
            round and returns the one to add. Defaults to `pick_from`.

    Returns:
        A set of (attribute index, parent set) tuples, one per attribute.
        NOTE(review): being a set, it does not record the order in which nodes
        were added — confirm callers do not depend on that order.
    """
    A = set(range(len(attr)))
    x1 = pick_from(A)
    t = (n * e2) / (2 * d * theta)

    V = {x1}
    N = {(x1, empty_pset)}

    for i in range(1, d):
        O = set()
        for x in A - V:
            psets = maximal_parent_sets(V, t / dom(x, 0))
            if psets:
                O.update((x, pset) for pset in psets)
            else:
                # Nothing fits the budget: offer x with no parents.
                O.add((x, empty_pset))

        print(f"{i:2d}: Calculating {len(O):5d} marginals.")
        # Delegate the choice to the caller-supplied strategy; `cb` was previously ignored.
        node = cb(O)
        V.add(node[0])
        N.add(node)

    return N

# greedy_bayes(0.01)

In [None]:
%load_ext line_profiler

In [None]:
def calc_mutual_info_list(candidates: list[tuple[int, tuple[int]]]):
    """Return the mutual information of every (attribute, parent set) candidate.

    For each candidate the attribute's columns and the columns selected by the
    parent set (each parent truncated according to its height) are handed to
    `calc_mutual_info`. Results keep the order of `candidates`.
    """

    def parent_columns(pset):
        # Collect the columns of every attribute present in the parent set.
        selected = []
        for parent, level in enumerate(pset):
            if level != -1:
                selected += attr[parent][: height(parent) - level]
        return selected

    return [
        calc_mutual_info(data, domain, attr[target], parent_columns(pset))
        for target, pset in candidates
    ]
        
# %lprun -f calc_mutual_info_list calc_mutual_info_list([(0, add_to_pset(empty_pset, 1, 0)), (2, add_to_pset(empty_pset, 1, 0))])

In [None]:
def exponential_pick_candidate(candidates: set[tuple[int, tuple[int]]]):
    """Sample one candidate with probability proportional to exp(I / delta).

    `I` is the mutual information of each candidate and
    `delta = (d - 1) * sens_mutual_info(n) / e1`, with `d`, `n` and `e1` read
    from module globals. Sampling uses numpy's global random state.
    """
    candidates = list(candidates)
    I = np.array(calc_mutual_info_list(candidates))
    delta = (d - 1) * sens_mutual_info(n) / e1

    # NOTE(review): the PrivBayes paper appears to sample with exp(I / (2 * delta));
    # the factor 2 is absent here — confirm against the reference before relying on
    # the privacy guarantee.
    scores = I / delta
    # Shift by the maximum before exponentiating: the distribution is unchanged, but
    # large scores no longer overflow to inf and produce NaN probabilities.
    p = np.exp(scores - scores.max())
    p /= p.sum()

    choice = np.random.choice(len(candidates), size=1, p=p)[0]

    return candidates[choice]

# np.random.seed(0)
# exponential_pick_candidate({(0, add_to_pset(empty_pset, 1, 0)), (2, add_to_pset(empty_pset, 1, 0))})

In [None]:
# Run configuration. e1 and e2 rebind the module-level values assigned in an earlier
# cell; greedy_bayes and exponential_pick_candidate read them as globals at call time.
e1 = 1
e2 = 3
theta = 4
# Seed numpy's global RNG, which exponential_pick_candidate samples from.
np.random.seed(0)
tree = greedy_bayes(theta, cb=exponential_pick_candidate)

tree

 1: Calculating    13 marginals.
 2: Calculating    55 marginals.
 3: Calculating    64 marginals.
 4: Calculating    96 marginals.
 5: Calculating   100 marginals.
 6: Calculating   103 marginals.
 7: Calculating    93 marginals.
 8: Calculating    71 marginals.
 9: Calculating   135 marginals.
10: Calculating    49 marginals.
11: Calculating    63 marginals.
12: Calculating    74 marginals.
13: Calculating    21 marginals.


In [None]:
def print_tree(tree: set[tuple[int, tuple[int]]]):
    """Format the (attribute, parent set) nodes of `tree` as a text report.

    After a two-line banner, every node gets one line: the attribute name
    followed by each parent present in its parent set as `name.height`. Nodes
    are listed in the reverse of the order in which `tree` iterates.
    """
    names = list(attr_str.keys())
    lines = [f"//////////\n///{'_Bayesian Tree_':>20s}"]

    for node, pset in reversed(list(tree)):
        parents = "".join(
            f"{names[parent]:>15s}.{level}"
            for parent, level in enumerate(pset)
            if level != -1
        )
        lines.append(f"///{names[node]:>20s}: {parents}")

    return "\n".join(lines)


# print() is needed so the newlines in the report render as separate lines.
print(print_tree(tree))


//////////
///     _Bayesian Tree_
///                 age: 
///        capital-loss:             age.5         fnlwgt.5
///      native-country:          fnlwgt.4 hours-per-week.5
///           workclass:             age.4 marital-status.0
///              fnlwgt:             age.5            sex.0
///       education-num:             age.5   capital-gain.4
///        relationship:    capital-gain.4 hours-per-week.5      workclass.0
///        capital-gain:             age.3
///                race:    capital-gain.2 marital-status.0
///      marital-status:    capital-gain.4     occupation.0
///      hours-per-week:    capital-gain.4            sex.0
///           education:             age.5   capital-gain.2
///                 sex:    capital-gain.2     occupation.0
///          occupation:             age.4   capital-gain.3
