In [None]:
# Removes lint errors from VS Code
from typing import Dict, TYPE_CHECKING, Tuple, List
from functools import reduce

if TYPE_CHECKING:
    # Declarations for static analysis only: this branch never executes at
    # runtime. The names below are not defined anywhere in this notebook;
    # presumably they are injected by the Kedro notebook session (verify).
    import kedro

    catalog: kedro.io.data_catalog.DataCatalog
    session: kedro.framework.session.session.KedroSession
    pipelines: Dict[str, kedro.pipeline.pipeline.Pipeline]

import numpy as np
import pandas as pd


In [None]:
from pasteur.transform import TableTransformer

# Both inputs are read through `catalog`, which this notebook never defines
# (it is only declared in the TYPE_CHECKING stub; presumably provided by the
# Kedro session - verify before running on a bare kernel).
table: pd.DataFrame = catalog.load("mimic_tab_admissions.wrk.bhr_table")
trn: TableTransformer = catalog.load("mimic_tab_admissions.wrk.trn_table")

In [None]:
# Result of get_attributes() for the "bhr" entry of the transformer. Within
# this notebook it is only referenced by the commented-out
# PrivBayesSynth.bake(...) call further down.
attr = trn["bhr"].get_attributes()

In [None]:
%load_ext line_profiler

In [None]:
def calc_marginal(
    data: np.ndarray,
    domain: np.ndarray,
    x: list[int],
    p: list[int],
    rm_zeros: bool = False,
):
    """Calculates the 1 way and 2 way marginals between the set of columns in x
    and the set of columns in p.

    Reference implementation based on ``np.histogramdd``.

    Args:
        data: 2D array of integer codes, one row per record.
        domain: 1D array where ``domain[c]`` is the number of distinct codes
            ``0 .. domain[c] - 1`` allowed in column ``c``.
        x: Column indices forming the first group (rows of the joint table).
        p: Column indices forming the second group (columns of the joint table).
        rm_zeros: If True, add a tiny constant to every cell after
            normalisation so no probability is exactly zero.

    Returns:
        ``(j_mar, x_mar, p_mar)``: the joint distribution with shape
        ``(prod(domain[x]), prod(domain[p]))`` and the flattened marginals of
        the ``x`` group and of the ``p`` group.
    """

    cols = x + p
    sub_data = data[:, cols]
    bins = [int(d) for d in domain[cols]]

    # Pin every axis to [-0.5, d - 0.5] so that code k always lands in bin k.
    # Without an explicit range histogramdd derives the edges from the
    # observed min/max, which shifts the bins whenever a column does not
    # contain the value 0.
    ranges = [(-0.5, d - 0.5) for d in bins]

    margin, _ = np.histogramdd(sub_data, bins=bins, range=ranges)
    margin /= margin.sum()
    if rm_zeros:
        # Mutual info turns into NaN without this
        margin += 1e-24

    # The first len(x) axes belong to x, the trailing len(p) axes to p.
    x_idx = tuple(range(len(x)))
    p_idx = tuple(range(-len(p), 0))

    x_mar = np.sum(margin, axis=p_idx).reshape(-1)
    p_mar = np.sum(margin, axis=x_idx).reshape(-1)
    j_mar = margin.reshape((len(x_mar), len(p_mar)))

    return j_mar, x_mar, p_mar

In [None]:
# synth = PrivBayesSynth()
# synth.bake({"table": attr}, {"table": table}, {})

# Dense uint16 matrix view of the table (rows x table columns).
data = table.to_numpy(dtype="uint16")
# Per-column domain size = largest observed code + 1.
# NOTE: the result keeps the uint16 dtype of `data`, so multiplying several
# domain sizes together as NumPy scalars can wrap around past 65535.
domain = data.max(axis=0) + 1

# %lprun -f calc_marginal calc_marginal(data, domain, [0,1,2,3,4], [5,6,7,8,9,10,11,12,13,14,15])
# Reference result: columns 0-4 form x, columns 5-15 form p.
j_mar1, x_mar1, p_mar1 = calc_marginal(data, domain, [0,1,2,3,4], [5,6,7,8,9,10,11,12,13,14,15])

In [None]:
def calc_marginal_2(
    data: np.ndarray,
    domain: np.ndarray,
    x: list[int],
    p: list[int],
    rm_zeros: bool = False,
):
    """Calculates the 1 way and 2 way marginals between the set of columns in x
    and the set of columns in p.

    Every row is encoded as a single mixed-radix integer (the columns of ``x``
    are the most significant digits, followed by ``p``) and the joint
    histogram is one ``np.bincount`` over those codes.

    Args:
        data: 2D array of integer codes, one row per record.
        domain: 1D array with the number of distinct codes of each column.
        x: Column indices forming the rows of the joint table.
        p: Column indices forming the columns of the joint table.
        rm_zeros: If True, add a tiny constant to every cell after
            normalisation so no probability is exactly zero.

    Returns:
        ``(j_mar, x_mar, p_mar)``: joint distribution of shape
        ``(prod(domain[x]), prod(domain[p]))`` and its row/column sums.
    """

    cols = x + p
    sub_data = data[:, cols]
    # Work with plain Python ints: multiplying NumPy scalars of a narrow dtype
    # (e.g. a uint16 domain array) wraps around once the product passes the
    # dtype's maximum, which corrupts minlength and the reshape below.
    sub_domain = [int(d) for d in domain[cols]]

    x_dom = reduce(lambda a, b: a * b, sub_domain[: len(x)], 1)
    p_dom = reduce(lambda a, b: a * b, sub_domain[len(x):], 1)

    # Horner-style accumulation of the mixed-radix row code in int64.
    idx = np.zeros(len(sub_data), dtype="int64")
    for col, dom in zip(sub_data.transpose(), sub_domain):
        idx = idx * dom + col

    counts = np.bincount(idx, minlength=x_dom * p_dom)
    margin = counts.reshape(x_dom, p_dom).astype("float")

    margin /= margin.sum()
    if rm_zeros:
        # Mutual info turns into NaN without this
        margin += 1e-24

    j_mar = margin
    x_mar = np.sum(margin, axis=1)
    p_mar = np.sum(margin, axis=0)

    return j_mar, x_mar, p_mar
    
# Same column split as the calc_marginal reference call (x = 0-4, p = 5-15),
# so the two results can be compared directly.
j_mar2, x_mar2, p_mar2 = calc_marginal_2(data, domain, [0,1,2,3,4], [5,6,7,8,9,10,11,12,13,14,15])

In [None]:
# Check the bincount variant against the histogramdd reference. A tight
# relative tolerance is used instead of bit-exact equality because the
# marginals are float sums that the two implementations reduce in a different
# order; assert_allclose also reports which elements differ on failure.
np.testing.assert_allclose(j_mar2, j_mar1, rtol=1e-12, atol=0)
np.testing.assert_allclose(x_mar2, x_mar1, rtol=1e-12, atol=0)
np.testing.assert_allclose(p_mar2, p_mar1, rtol=1e-12, atol=0)

In [None]:
%timeit calc_marginal(data, domain, [0,1,2,3,4], [5,6,7,8,9,10,11,12,13,14,15,16,17,18])

In [None]:
%timeit calc_marginal_2(data, domain, [0,1,2,3,4], [5,6,7,8,9,10,11,12,13,14,15,16,17,18])

13.9 ms ± 9.81 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%lprun -f calc_marginal_2 calc_marginal_2(data, domain, [0,1,2,3,4], [5,6,7,8,9,10,11,12,13,14,15,16,17,18])

Timer unit: 1e-06 s

Total time: 0.013973 s
File: /tmp/ipykernel_9824/4051514708.py
Function: calc_marginal_2 at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def calc_marginal_2(
     2                                               data: np.ndarray,
     3                                               domain: np.ndarray,
     4                                               x: list[int],
     5                                               p: list[int],
     6                                               rm_zeros: bool = False,
     7                                           ):
     8                                               """Calculates the 1 way and 2 way marginals between the set of columns in x
     9                                               and the set of columns in p."""
    10                                           
    11         1       1719.0   1719.0     12.3      sub_data = data[:, x +

In [None]:
import numexpr as ne

def calc_marginal_3(
    data: np.ndarray,
    domain: np.ndarray,
    x: list[int],
    p: list[int],
    rm_zeros: bool = False,
):
    """Calculates the 1 way and 2 way marginals between the set of columns in x
    and the set of columns in p.

    Same mixed-radix/bincount approach as ``calc_marginal_2`` but reads the
    columns straight from ``data`` instead of first copying the selected
    columns into a sub-array.

    Args:
        data: 2D array of integer codes, one row per record.
        domain: 1D array with the number of distinct codes of each column.
        x: Column indices forming the rows of the joint table.
        p: Column indices forming the columns of the joint table.
        rm_zeros: If True, add a tiny constant to every cell after
            normalisation so no probability is exactly zero.

    Returns:
        ``(j_mar, x_mar, p_mar)``: joint distribution of shape
        ``(prod(domain[x]), prod(domain[p]))`` and its row/column sums.
    """

    xp = x + p
    # int(b) keeps the running product a Python int; multiplying the raw NumPy
    # scalars would wrap around for narrow dtypes such as uint16.
    x_dom = reduce(lambda a, b: a * int(b), domain[x], 1)
    p_dom = reduce(lambda a, b: a * int(b), domain[p], 1)

    idx = np.zeros(len(data), dtype="int64")
    for col in xp:
        # Horner step: shift the code by this column's radix, then add its digit.
        idx = int(domain[col]) * idx + data[:, col]

    counts = np.bincount(idx, minlength=x_dom * p_dom)
    margin = counts.reshape(x_dom, p_dom).astype("float")

    margin /= margin.sum()
    if rm_zeros:
        # Mutual info turns into NaN without this
        margin += 1e-24

    j_mar = margin
    x_mar = np.sum(margin, axis=1)
    p_mar = np.sum(margin, axis=0)

    return j_mar, x_mar, p_mar
    
# Same column split as the calc_marginal reference call (x = 0-4, p = 5-15),
# so the two results can be compared directly.
j_mar3, x_mar3, p_mar3 = calc_marginal_3(data, domain, [0,1,2,3,4], [5,6,7,8,9,10,11,12,13,14,15])

In [None]:
%timeit calc_marginal_3(data, domain, [0,1,2,3,4], [5,6,7,8,9,10,11,12,13,14,15,16,17,18])

12 ms ± 14.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%lprun -f calc_marginal_3 calc_marginal_3(data, domain, [0,1,2,3,4], [5,6,7,8,9,10,11,12,13,14,15,16,17,18])

Timer unit: 1e-06 s

Total time: 0.012281 s
File: /tmp/ipykernel_9824/3634434575.py
Function: calc_marginal_3 at line 3

Line #      Hits         Time  Per Hit   % Time  Line Contents
     3                                           def calc_marginal_3(
     4                                               data: np.ndarray,
     5                                               domain: np.ndarray,
     6                                               x: list[int],
     7                                               p: list[int],
     8                                               rm_zeros: bool = False,
     9                                           ):
    10                                               """Calculates the 1 way and 2 way marginals between the set of columns in x
    11                                               and the set of columns in p."""
    12                                           
    13         1          2.0      2.0      0.0      xp = x + p
    14     

In [None]:
# Check calc_marginal_3 against the histogramdd reference. A tight relative
# tolerance is used instead of bit-exact equality because the marginals are
# float sums that the two implementations reduce in a different order;
# assert_allclose also reports which elements differ on failure.
np.testing.assert_allclose(j_mar3, j_mar1, rtol=1e-12, atol=0)
np.testing.assert_allclose(x_mar3, x_mar1, rtol=1e-12, atol=0)
np.testing.assert_allclose(p_mar3, p_mar1, rtol=1e-12, atol=0)

In [None]:
def calc_marginal_4(
    data: np.ndarray,
    domain: np.ndarray,
    x: list[int],
    p: list[int],
    rm_zeros: bool = False,
):
    """Calculates the 1 way and 2 way marginals between the set of columns in x
    and the set of columns in p.

    Builds the mixed-radix row code from the least significant column upwards
    (``idx += weight * column``) and accumulates in place, instead of
    re-multiplying the whole running code for every column.

    Args:
        data: 2D array of integer codes, one row per record.
        domain: 1D array with the number of distinct codes of each column.
        x: Column indices forming the rows of the joint table.
        p: Column indices forming the columns of the joint table.
        rm_zeros: If True, add a tiny constant to every cell after
            normalisation so no probability is exactly zero.

    Returns:
        ``(j_mar, x_mar, p_mar)``: joint distribution of shape
        ``(prod(domain[x]), prod(domain[p]))`` and its row/column sums.
    """

    xp = x + p
    # Python-int products: NumPy scalars of a narrow dtype would wrap around.
    x_dom = reduce(lambda a, b: a * int(b), domain[x], 1)
    p_dom = reduce(lambda a, b: a * int(b), domain[p], 1)

    idx = np.zeros(len(data), dtype="int64")
    mul = 1
    for col in reversed(xp):
        # Widen the column before scaling it. With a uint16 `data`/`domain`
        # the product weight * column would otherwise be evaluated in uint16
        # and overflow before it ever reaches the int64 accumulator.
        idx += data[:, col].astype("int64") * mul
        mul *= int(domain[col])

    counts = np.bincount(idx, minlength=x_dom * p_dom)
    margin = counts.reshape(x_dom, p_dom).astype("float")

    margin /= margin.sum()
    if rm_zeros:
        # Mutual info turns into NaN without this
        margin += 1e-24

    j_mar = margin
    x_mar = np.sum(margin, axis=1)
    p_mar = np.sum(margin, axis=0)

    return j_mar, x_mar, p_mar
    
# Same column split as the calc_marginal reference call (x = 0-4, p = 5-15),
# so the two results can be compared directly.
j_mar4, x_mar4, p_mar4 = calc_marginal_4(data, domain, [0,1,2,3,4], [5,6,7,8,9,10,11,12,13,14,15])

In [None]:
%timeit calc_marginal_4(data, domain, [0,1,2,3,4], [5,6,7,8,9,10,11,12,13,14,15,16,17,18])

8.74 ms ± 12.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%lprun -f calc_marginal_4 calc_marginal_4(data, domain, [0,1,2,3,4], [5,6,7,8,9,10,11,12,13,14,15,16,17,18])

Timer unit: 1e-06 s

Total time: 0.009247 s
File: /tmp/ipykernel_9824/1461518459.py
Function: calc_marginal_4 at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def calc_marginal_4(
     2                                               data: np.ndarray,
     3                                               domain: np.ndarray,
     4                                               x: list[int],
     5                                               p: list[int],
     6                                               rm_zeros: bool = False,
     7                                           ):
     8                                               """Calculates the 1 way and 2 way marginals between the set of columns in x
     9                                               and the set of columns in p."""
    10                                           
    11         1          1.0      1.0      0.0      xp = x + p
    12     

In [None]:
# Check calc_marginal_4 against the histogramdd reference. A tight relative
# tolerance is used instead of bit-exact equality because the marginals are
# float sums that the two implementations reduce in a different order;
# assert_allclose also reports which elements differ on failure.
np.testing.assert_allclose(j_mar4, j_mar1, rtol=1e-12, atol=0)
np.testing.assert_allclose(x_mar4, x_mar1, rtol=1e-12, atol=0)
np.testing.assert_allclose(p_mar4, p_mar1, rtol=1e-12, atol=0)

In [None]:
def calc_marginal_5(
    data: np.ndarray,
    domain: np.ndarray,
    x: list[int],
    p: list[int],
    rm_zeros: bool = False,
):
    """Calculates the 1 way and 2 way marginals between the set of columns in x
    and the set of columns in p.

    Variant of ``calc_marginal_4`` that reuses one scratch buffer for the
    weighted column and keeps the row codes in uint16 whenever the joint
    domain fits, falling back to int64 when it does not.

    Args:
        data: 2D array of integer codes, one row per record.
        domain: 1D array with the number of distinct codes of each column.
        x: Column indices forming the rows of the joint table.
        p: Column indices forming the columns of the joint table.
        rm_zeros: If True, add a tiny constant to every cell after
            normalisation so no probability is exactly zero.

    Returns:
        ``(j_mar, x_mar, p_mar)``: joint distribution of shape
        ``(prod(domain[x]), prod(domain[p]))`` and its row/column sums.
    """

    xp = x + p
    # Python-int products: NumPy scalars of a narrow dtype would wrap around.
    x_dom = reduce(lambda a, b: a * int(b), domain[x], 1)
    p_dom = reduce(lambda a, b: a * int(b), domain[p], 1)
    total = x_dom * p_dom

    # uint16 codes are only valid while every code (< total) and every column
    # weight (<= total) fits in 16 bits; otherwise they silently wrap and the
    # histogram is wrong, so widen to int64 in that case.
    idx_dtype = "uint16" if total <= 0xFFFF else "int64"
    idx = np.zeros(len(data), dtype=idx_dtype)
    tmp = np.empty(len(data), dtype=idx_dtype)
    mul = 1
    for col in reversed(xp):
        # tmp = mul * data[:, col], computed inside the preallocated buffer
        tmp[...] = data[:, col]
        tmp *= mul
        idx += tmp
        mul *= int(domain[col])

    counts = np.bincount(idx, minlength=total)
    margin = counts.reshape(x_dom, p_dom).astype("float")

    margin /= margin.sum()
    if rm_zeros:
        # Mutual info turns into NaN without this
        margin += 1e-24

    j_mar = margin
    x_mar = np.sum(margin, axis=1)
    p_mar = np.sum(margin, axis=0)

    return j_mar, x_mar, p_mar
    
# Same column split as the calc_marginal reference call (x = 0-4, p = 5-15),
# so the two results can be compared directly.
j_mar5, x_mar5, p_mar5 = calc_marginal_5(data, domain, [0,1,2,3,4], [5,6,7,8,9,10,11,12,13,14,15])

In [None]:
%timeit calc_marginal_5(data, domain, [0,1,2,3,4], [5,6,7,8,9,10,11,12,13,14,15,16,17,18])

4.82 ms ± 10.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%lprun -f calc_marginal_4 calc_marginal_4(data, domain, [0,1,2,3,4], [5,6,7,8,9,10,11,12,13,14,15,16,17,18])

Timer unit: 1e-06 s

Total time: 0.009334 s
File: /tmp/ipykernel_9824/1461518459.py
Function: calc_marginal_4 at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def calc_marginal_4(
     2                                               data: np.ndarray,
     3                                               domain: np.ndarray,
     4                                               x: list[int],
     5                                               p: list[int],
     6                                               rm_zeros: bool = False,
     7                                           ):
     8                                               """Calculates the 1 way and 2 way marginals between the set of columns in x
     9                                               and the set of columns in p."""
    10                                           
    11         1          2.0      2.0      0.0      xp = x + p
    12     

In [None]:
# Check calc_marginal_5 against the histogramdd reference. A tight relative
# tolerance is used instead of bit-exact equality because the marginals are
# float sums that the two implementations reduce in a different order;
# assert_allclose also reports which elements differ on failure.
np.testing.assert_allclose(j_mar5, j_mar1, rtol=1e-12, atol=0)
np.testing.assert_allclose(x_mar5, x_mar1, rtol=1e-12, atol=0)
np.testing.assert_allclose(p_mar5, p_mar1, rtol=1e-12, atol=0)

In [None]:
def calc_marginal_1way(
    data: np.ndarray, domain: np.ndarray, x: list[int], rm_zeros: bool = False
):
    """Calculates the 1 way marginal of x, returned as a 1D array.

    Args:
        data: 2D array of integer codes, one row per record.
        domain: 1D array with the number of distinct codes of each column.
        x: Column indices whose joint distribution is computed; the first
            column is the most significant digit of the flattened index.
        rm_zeros: If True, add a tiny constant to every cell after
            normalisation so no probability is exactly zero.

    Returns:
        1D float array of length ``prod(domain[x])`` that sums to 1 (before
        the optional ``rm_zeros`` offset).
    """

    # Python-int product: NumPy scalars of a narrow dtype would wrap around.
    x_dom = reduce(lambda a, b: a * int(b), domain[x], 1)

    idx = np.zeros(len(data), dtype="int64")
    mul = 1
    for col in reversed(x):
        # Widen the column before scaling it so that weight * column is
        # evaluated in int64 rather than in the (possibly uint16) data dtype.
        idx += data[:, col].astype("int64") * mul
        mul *= int(domain[col])

    counts = np.bincount(idx, minlength=x_dom)
    margin = counts.astype("float")
    margin /= margin.sum()
    if rm_zeros:
        # Mutual info turns into NaN without this
        margin += 1e-24

    return margin

# 1-way marginal over columns 5-18; kept in `mar` for the comparison with the
# histogramdd-based calc_marginal_1way_old further down.
mar = calc_marginal_1way(data, domain, [5,6,7,8,9,10,11,12,13,14,15,16,17,18])

In [None]:
%timeit calc_marginal_1way(data, domain, [5,6,7,8,9,10,11,12,13,14,15,16,17,18])

5.36 ms ± 5.11 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%lprun -f calc_marginal_1way (data, domain, [5,6,7,8,9,10,11,12,13,14,15,16,17,18])

Timer unit: 1e-06 s

Total time: 0 s
File: /tmp/ipykernel_9824/268377759.py
Function: calc_marginal_1way at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def calc_marginal_1way(
     2                                               data: np.ndarray, domain: np.ndarray, x: list[int], rm_zeros: bool = False
     3                                           ):
     4                                               """Calculates the 1 way marginal of x, returned as a 1D array."""
     5                                           
     6                                               x_dom = reduce(lambda a, b: a * b, domain[x], 1)
     7                                           
     8                                               idx = np.zeros((len(data)), dtype="int64")
     9                                               mul = 1
    10                                               for col in reversed(x):
    11       

In [None]:
def calc_marginal_1way_old(
    data: np.ndarray, domain: np.ndarray, x: list[int], rm_zeros: bool = False
):
    """Calculates the 1 way marginal of x, returned as a 1D array.

    ``np.histogramdd``-based baseline kept for comparison with the
    bincount-based ``calc_marginal_1way``.

    Args:
        data: 2D array of integer codes, one row per record.
        domain: 1D array with the number of distinct codes of each column.
        x: Column indices whose joint distribution is computed.
        rm_zeros: If True, add a tiny constant to every cell after
            normalisation so no probability is exactly zero.

    Returns:
        Flattened (C-order) joint distribution of the ``x`` columns.
    """

    sub_data = data[:, x]
    bins = [int(d) for d in domain[x]]
    # Pin every axis to [-0.5, d - 0.5] so that code k always lands in bin k.
    # Without an explicit range histogramdd derives the edges from the
    # observed min/max, which shifts the bins whenever a column does not
    # contain the value 0.
    ranges = [(-0.5, d - 0.5) for d in bins]

    margin, _ = np.histogramdd(sub_data, bins=bins, range=ranges)
    margin /= margin.sum()
    if rm_zeros:
        # Mutual info turns into NaN without this
        margin += 1e-24

    return margin.reshape(-1)

# Baseline result over the same columns (5-18) as `mar`, for the equality
# check below.
mar2 = calc_marginal_1way_old(data, domain, [5,6,7,8,9,10,11,12,13,14,15,16,17,18])

In [None]:
%timeit calc_marginal_1way_old(data, domain, [5,6,7,8,9,10,11,12,13,14,15,16,17,18])

547 ms ± 747 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# The bincount and histogramdd 1-way marginals must agree. assert_allclose is
# used instead of a bare equality assert so that a failure reports which
# elements differ.
np.testing.assert_allclose(mar, mar2, rtol=1e-12, atol=0)

In [None]:
# Load the stored "privbayes.model" catalog entry (the fit log below names it
# PrivBayesSynth).
model = catalog.load("mimic_tab_admissions.privbayes.model")

In [None]:
# Fit on the single working table.
# NOTE(review): the second argument is passed as an empty dict; its meaning is
# not visible in this notebook - confirm against the model's fit() signature.
model.fit({"table": table}, {})

           INFO     Deterministic check: random number after          base.py:29
                    PrivBayesSynth.fit(): <np.random> 0.96812                   
                    <random> 0.14817                                            


In [None]:
import cupy as cp

def calc_marginal_cupy(
    data,
    domain,
    x: list[int],
    p: list[int],
    rm_zeros: bool = False,
):
    """Calculates the 1 way and 2 way marginals between the set of columns in x
    and the set of columns in p.

    CuPy port of ``calc_marginal_5``: the row codes, the bincount and the
    marginal sums are all computed with ``cp`` functions.

    Args:
        data: 2D array of integer codes, one row per record (the notebook
            passes a CuPy array).
        domain: 1D array with the number of distinct codes of each column.
        x: Column indices forming the rows of the joint table.
        p: Column indices forming the columns of the joint table.
        rm_zeros: If True, add a tiny constant to every cell after
            normalisation so no probability is exactly zero.

    Returns:
        ``(j_mar, x_mar, p_mar)`` as arrays produced by ``cp``: joint
        distribution of shape ``(prod(domain[x]), prod(domain[p]))`` and its
        row/column sums.
    """

    xp = x + p
    # Fetch the domain sizes to the host once and keep them as Python ints.
    # Reducing over the device array element by element multiplies 0-d arrays
    # of the domain dtype (uint16 in this notebook), which wraps around for
    # products above 65535.
    host_domain = [int(d) for d in cp.asnumpy(domain)]
    x_dom = reduce(lambda a, b: a * b, (host_domain[c] for c in x), 1)
    p_dom = reduce(lambda a, b: a * b, (host_domain[c] for c in p), 1)
    total = x_dom * p_dom

    # uint16 codes are only valid while every code (< total) and every column
    # weight (<= total) fits in 16 bits; widen to int64 otherwise.
    idx_dtype = "uint16" if total <= 0xFFFF else "int64"
    idx = cp.zeros(len(data), dtype=idx_dtype)
    tmp = cp.empty(len(data), dtype=idx_dtype)
    mul = 1
    for col in reversed(xp):
        # tmp = mul * data[:, col], computed inside the preallocated buffer
        tmp[...] = data[:, col]
        tmp *= mul
        idx += tmp
        mul *= host_domain[col]

    counts = cp.bincount(idx, minlength=total)
    margin = counts.reshape(x_dom, p_dom).astype("float")

    margin /= margin.sum()
    if rm_zeros:
        # Mutual info turns into NaN without this
        margin += 1e-24

    j_mar = margin
    x_mar = cp.sum(margin, axis=1)
    p_mar = cp.sum(margin, axis=0)

    return j_mar, x_mar, p_mar

# Convert the inputs to CuPy arrays once, outside the timed calls below.
cdata = cp.asarray(data)
cdomain = cp.asarray(domain)
# NOTE(review): this rebinds j_mar5/x_mar5/p_mar5, which already hold the
# calc_marginal_5 results checked earlier. Re-running that assert cell after
# this one would compare the reference against the CuPy output instead.
j_mar5, x_mar5, p_mar5 = calc_marginal_cupy(cdata, cdomain, [0,1,2,3,4], [5,6,7,8,9,10,11,12,13,14,15])

In [None]:
j_mar5


array([[0.93426498, 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]])


In [None]:
%timeit calc_marginal_cupy(cdata, cdomain, [0,1,2,3,4], [5,6,7,8,9,10,11,12,13,14,15,16,17,18])[0][0]

3.96 ms ± 5.95 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%lprun -f calc_marginal_cupy calc_marginal_cupy(cdata, cdomain, [0,1,2,3,4], [5,6,7,8,9,10,11,12,13,14,15,16,17,18])

Timer unit: 1e-06 s

Total time: 0.002122 s
File: /tmp/ipykernel_9824/82347716.py
Function: calc_marginal_cupy at line 3

Line #      Hits         Time  Per Hit   % Time  Line Contents
     3                                           def calc_marginal_cupy(
     4                                               data,
     5                                               domain,
     6                                               x: list[int],
     7                                               p: list[int],
     8                                               rm_zeros: bool = False,
     9                                           ):
    10                                               """Calculates the 1 way and 2 way marginals between the set of columns in x
    11                                               and the set of columns in p."""
    12                                           
    13         1          2.0      2.0      0.1      xp = x + p
    14         1        491.0  

In [None]:

from cupyx.profiler import benchmark
# cupyx's benchmark reports CPU and GPU timings separately (see the printed
# line below), unlike the single wall-clock figure from %timeit.
print(benchmark(calc_marginal_cupy, (cdata, cdomain, [0,1,2,3,4], [5,6,7,8,9,10,11,12,13,14,15,16,17,18]), n_repeat=1000))

calc_marginal_cupy  :    CPU: 1490.534 us   +/-49.099 (min: 1453.099 / max: 2265.651) us     GPU-0: 4163.733 us   +/-96.352 (min: 4108.288 / max: 5195.776) us
