In [None]:
# Removes lint errors from VS Code
from typing import Dict, TYPE_CHECKING, Tuple, List

if TYPE_CHECKING:
    import kedro

    catalog: kedro.io.data_catalog.DataCatalog
    session: kedro.framework.session.session.KedroSession
    pipelines: Dict[str, kedro.pipeline.pipeline.Pipeline]


In [None]:
import os
from importlib import reload

import numpy as np
import pandas as pd


def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean flag from the environment.

    Unset -> ``default``. Otherwise the flag is False for "", "0", "false",
    "no" or "off" (case-insensitive, surrounding whitespace ignored) and True
    for anything else.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() not in {"", "0", "false", "no", "off"}


# Dataset selection; empty strings also fall back to the defaults.
VIEW = os.getenv("DATASET_VIEW") or "tab_adult"
TABLE = os.getenv("DATASET_TABLE") or "table"
# os.getenv returns a string, so using it directly made MULTI_PROCESS=false truthy.
MULTI_PROCESS = _env_flag("MULTI_PROCESS", True)

# `catalog` is not defined in this notebook; presumably it is injected by the
# Kedro kernel/extension (see the TYPE_CHECKING declarations) — TODO confirm.
bin: pd.DataFrame = catalog.load(f"{VIEW}.wrk.bin_{TABLE}")
random_state = catalog.load("params:random_state")


In [None]:
# Peek at the first five rows of the binarized table.
bin.head(n=5)

In [None]:
import pasteur.synth.privbayes as pb

reload(pb)  # re-execute the module so local edits are picked up without a kernel restart

# Privacy budget split by beta. NOTE(review): presumably e1 is meant for
# structure learning (unused here) and e2 funds the noisy marginals (noise_scale).
e, beta = 1, 0.1
e1, e2 = beta * e, (1 - beta) * e

# Table size: n records x d attributes.
n, d = bin.shape

noise_ratio = 4
# Presumably the PrivBayes degree (max parents per attribute) — see pb.calc_k.
k = pb.calc_k(d, n, e2, noise_ratio)

In [None]:
# Every column of the binarized table must have the plain "bool" dtype.
assert {col_dtype.name for col_dtype in bin.dtypes} <= {"bool"}

In [None]:
data = bin[["age_0", "workclass_0", "workclass_1", "workclass_3"]]

# Row counts per combination of the selected columns.
# NOTE(review): groupby only lists combinations that occur, so zero-count cells are absent.
marginal = data.groupby(data.columns.tolist()).size()
marginal

In [None]:
from scipy.stats import laplace

# Laplace scale for the count marginals: 2 * (d - k) / e2.
noise_scale = 2 * (d - k) / e2
# One independent draw per marginal cell from a frozen Laplace(0, noise_scale).
noise = laplace(loc=0, scale=noise_scale).rvs(size=marginal.shape)
noise

In [None]:
# Perturb the counts, clip negatives to zero, then rescale into a distribution.
# NOTE(review): if every cell clips to 0 the division yields NaN.
noisy_marginal = (marginal + noise).clip(lower=0)
noisy_marginal = noisy_marginal.div(noisy_marginal.sum())
noisy_marginal, noisy_marginal.sum()

In [None]:
# Laplace scale in use; compare against the per-cell counts in `marginal`.
noise_scale

In [None]:
# Average number of records per observed marginal cell.
bin.shape[0] / marginal.size

In [None]:
def calc_noisy_marginal(data: pd.DataFrame, noise_scale=None, random_state=None) -> pd.Series:
    """Return a Laplace-noised, normalized marginal distribution of ``data``.

    Counts rows for every combination of the columns of ``data`` (including
    combinations that never occur), adds Laplace(0, ``noise_scale``) noise to
    each count, clips negatives to 0 and rescales so the result sums to 1.

    Args:
        data: frame whose columns define the marginal. Boolean columns use the
            domain {False, True}; other columns use their observed unique values.
        noise_scale: Laplace scale. Defaults to ``2 * (d - k) / e2`` using the
            notebook-level ``d``, ``k`` and ``e2``.
        random_state: forwarded to ``scipy.stats.laplace.rvs`` (None, int seed
            or a numpy Generator/RandomState).

    Returns:
        Series indexed by column value(s) (a MultiIndex for several columns)
        holding non-negative probabilities. If every noisy cell clips to 0,
        a uniform distribution is returned instead of NaN.
    """
    from scipy.stats import laplace

    cols = list(data.keys())
    counts = data.groupby(cols).size()

    # groupby only emits combinations present in `data`. Empty cells must be
    # noised too, otherwise the set of released cells depends on the data, so
    # expand to the full domain with zero counts first.
    domains = [
        [False, True] if pd.api.types.is_bool_dtype(data[col]) else sorted(data[col].unique())
        for col in cols
    ]
    if len(cols) == 1:
        # A single grouping key yields a flat Index, not a MultiIndex.
        full_index = pd.Index(domains[0], name=cols[0])
    else:
        full_index = pd.MultiIndex.from_product(domains, names=cols)
    counts = counts.reindex(full_index, fill_value=0)

    if noise_scale is None:
        # Notebook-level PrivBayes parameters defined in earlier cells.
        noise_scale = 2 * (d - k) / e2

    noise = laplace.rvs(loc=0, scale=noise_scale, size=counts.shape, random_state=random_state)
    noisy = (counts + noise).clip(lower=0)

    total = noisy.sum()
    if total > 0:
        return noisy / total
    # Every cell clipped to zero: fall back to uniform rather than 0 / 0 = NaN.
    return pd.Series(1.0 / max(len(noisy), 1), index=noisy.index)

In [None]:
# Number of records and number of (binary) attribute columns.
n = len(bin)
d = len(bin.columns)
(n, d)

In [None]:
# Child attribute and candidate parent set used by the cells that follow.
x = "age_0"
p = ["workclass_0", "workclass_1", "workclass_3"]

# Row counts per observed (x, parents) combination.
joint = bin.groupby([x, *p]).size()
joint

In [None]:
# Collapse the joint counts onto the child attribute's index level.
x_marginal = joint.groupby(level=x).sum()

x_marginal

In [None]:
# Collapse the joint counts onto the parent attributes' index levels.
p_marginal = joint.groupby(level=p).sum()

p_marginal

In [None]:
# Normalized contingency table: one row per observed parent combination, one
# column per value of x. unstack(fill_value=0) replaces pd.pivot because
# (a) pandas >= 2.0 rejects pd.pivot's positional arguments and (b) pivot
# leaves unobserved (parents, x) pairs as NaN, which made sum() -- and thus
# every normalized entry -- NaN.
contigency = bin.groupby([x] + p).size().unstack(x, fill_value=0).to_numpy(dtype="float")
contigency /= contigency.sum()
contigency

In [None]:
# Column sums: marginal distribution of x.
x_marginal = np.sum(contigency, axis=0)
x_marginal

In [None]:
# Row sums: marginal distribution of the parent combinations.
p_marginal = np.sum(contigency, axis=1)
p_marginal

In [None]:
# Mutual information I(X; P) in bits. Empty cells must contribute 0
# (0 * log 0 := 0); evaluated naively they give 0 * -inf = NaN and the whole
# sum becomes NaN, so only the non-zero cells are summed.
nonzero = contigency > 0
expected = np.outer(p_marginal, x_marginal)
(contigency[nonzero] * np.log2(contigency[nonzero] / expected[nonzero])).sum()

In [None]:
# Registers the %lprun line-by-line profiling magic.
%load_ext line_profiler

In [None]:
par = [bin[pi] for pi in p]
chld = bin[x]
%timeit pd.crosstab(par, chld)

7.16 ms ± 14.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%%timeit 
joint_dist = bin.groupby([x] + p).size()
contigency_pd = pd.DataFrame(joint_dist).reset_index()
contigency_pd = pd.pivot(contigency_pd, p, x)

5.63 ms ± 15.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
# Cost of the groupby/size step on its own, for comparison with the timings above.
%timeit joint_dist = bin.groupby([x] + p).size()


1.65 ms ± 9.77 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [None]:
def calc_mutual_information(data: pd.DataFrame, x: str, p: list[str]) -> float:
    """Return the empirical mutual information I(x; p) of ``data``, in bits.

    Args:
        data: table containing column ``x`` and every column in ``p``.
        x: child column.
        p: parent columns; any iterable of column names (list, tuple or set,
            e.g. the candidate parent sets produced by ``greedy_bayes``).

    Returns:
        I(x; p) as a float. An empty parent set gives 0.0.
    """
    parents = list(p)
    if not parents:
        # Nothing to condition on: I(X; {}) = 0.
        return 0.0

    # Rows: observed parent combinations; columns: values of x.
    # fill_value=0 keeps unobserved (parents, x) pairs as zero counts instead of
    # NaN, and unstack avoids pd.pivot's positional arguments (rejected by
    # pandas >= 2.0).
    counts = data.groupby([x, *parents]).size().unstack(x, fill_value=0)
    joint = counts.to_numpy(dtype="float")
    joint /= joint.sum()

    x_mar = joint.sum(axis=0)
    p_mar = joint.sum(axis=1)
    expected = np.outer(p_mar, x_mar)

    # 0 * log 0 is taken as 0: only non-empty cells contribute to the sum.
    nonzero = joint > 0
    return float(np.sum(joint[nonzero] * np.log2(joint[nonzero] / expected[nonzero])))


# Line-by-line profile of calc_mutual_information (needs the line_profiler extension).
%lprun -f calc_mutual_information calc_mutual_information(bin, x, p)
# Unprofiled alternative: calc_mutual_information(bin, x, p)

Timer unit: 1e-06 s

Total time: 0.011686 s
File: /tmp/ipykernel_4168668/872837184.py
Function: calc_mutual_information at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def calc_mutual_information(data: pd.DataFrame, x: str, p: list[str]):
     2         1       3764.0   3764.0     32.2      joint_dist = data.groupby([x] + p).size()
     3         1       2837.0   2837.0     24.3      contigency_pd = pd.DataFrame(joint_dist).reset_index()
     4         1       4965.0   4965.0     42.5      contigency_pd = pd.pivot(contigency_pd, p, x)
     5                                           
     6         1         40.0     40.0      0.3      contigency = contigency_pd.to_numpy(dtype="float")
     7         1         14.0     14.0      0.1      cg = contigency / contigency.sum()
     8                                           
     9                                               # Marginals
    10         1          7

In [None]:
import random
from itertools import combinations as n_choose_k

# Seed Python's global `random` module, which greedy_bayes uses to pick the root.
# NOTE(review): the Kedro `random_state` parameter loaded at the top is not used
# here — consider seeding from it instead of a hard-coded 0.
random.seed(0)

def greedy_bayes(data: pd.DataFrame, k: int):
    """Sketch of the GreedyBayes structure search; candidates are not scored yet.

    Picks a random root attribute, then adds one attribute per step. At each
    step every (attribute, parent set) candidate is enumerated: the parents are
    all already-placed attributes while at most ``k`` are placed, otherwise
    every ``k``-subset of them. The first candidate is taken (placeholder for
    a mutual-information choice). The candidate count per step, and the total
    at the end, are printed to gauge how many marginals a real run evaluates.

    Args:
        data: table whose columns are the attributes.
        k: maximum number of parents per attribute.

    Returns:
        ``(N, V)``: ``N`` maps each attribute to a tuple of its parents (the
        root maps to ``()``); ``V`` is the set of placed attributes.
    """
    total_marginal = 0
    N = {}
    V = set()
    A = set(data.keys())

    if not A:
        # No attributes to order (random.choice would fail on an empty list).
        return N, V

    # Sort before sampling: random.sample() rejects sets on Python 3.11+, and
    # set iteration order depends on string hashing, so sorting makes the choice
    # reproducible under random.seed().
    root = random.choice(sorted(A))
    N[root] = ()
    V.add(root)

    # len(A) rather than the column count, so duplicate column names cannot
    # run the loop past the available attributes.
    for _ in range(len(A) - 1):
        placed = sorted(V)
        O = []

        for x in sorted(A - V):
            if len(V) > k:
                O += [(x, c) for c in n_choose_k(placed, k)]
            else:
                O += [(x, tuple(placed))]

        x, p = O[0]  # TODO: choose by (noisy) mutual information instead of the first
        N[x] = p
        V.add(x)
        print(f"{len(V)}: {len(O)}")
        total_marginal += len(O)

    print(total_marginal)
    return N, V

# Dry run with at most 2 parents per attribute; the trailing `;` hides the (N, V) repr.
greedy_bayes(bin, 2);

2: 145
3: 144
4: 429
5: 852
6: 1410
7: 2100
8: 2919
9: 3864
10: 4932
11: 6120
12: 7425
13: 8844
14: 10374
15: 12012
16: 13755
17: 15600
18: 17544
19: 19584
20: 21717
21: 23940
22: 26250
23: 28644
24: 31119
25: 33672
26: 36300
27: 39000
28: 41769
29: 44604
30: 47502
31: 50460
32: 53475
33: 56544
34: 59664
35: 62832
36: 66045
37: 69300
38: 72594
39: 75924
40: 79287
41: 82680
42: 86100
43: 89544
44: 93009
45: 96492
46: 99990
47: 103500
48: 107019
49: 110544
50: 114072
51: 117600
52: 121125
53: 124644
54: 128154
55: 131652
56: 135135
57: 138600
58: 142044
59: 145464
60: 148857
61: 152220
62: 155550
63: 158844
64: 162099
65: 165312
66: 168480
67: 171600
68: 174669
69: 177684
70: 180642
71: 183540
72: 186375
73: 189144
74: 191844
75: 194472
76: 197025
77: 199500
78: 201894
79: 204204
80: 206427
81: 208560
82: 210600
83: 212544
84: 214389
85: 216132
86: 217770
87: 219300
88: 220719
89: 222024
90: 223212
91: 224280
92: 225225
93: 226044
94: 226734
95: 227292
96: 227715
97: 228000
98: 228144
99

In [None]:
# Scratch check: a set minus itself is empty.
{1, 2, 3} - {1, 2, 3}

In [None]:
# Observed counts per (x, parents) combination.
joint_dist = bin.groupby([x, *p]).size()
joint_dist

In [None]:
# Counts for rows where workclass_3 is False (that level dropped), as a pandas array.
joint_dist.xs(key=False, level="workclass_3").array