In [None]:
# Notebook bootstrap. The star import supplies `register_kedro`; names used
# later without an explicit import (e.g. `catalog`) presumably come from this
# import or from the `register_kedro()` call below — TODO confirm which.
from pasteur.kedro.ipython import * # type: ignore
register_kedro() # type: ignore

# Reload edited project modules automatically before each cell execution.
%load_ext autoreload
%autoreload 2

In [None]:
from pasteur.hierarchy import rebalance_attributes
from pasteur.marginal import MarginalOracle

# Attribute definitions of the "idx" table from the transformed dataset, and
# the working (encoded) idx table that the marginals are computed on.
old_attrs = catalog.load("mimic_billion.trn.table")["idx"].get_attributes()
wrk = catalog.load("mimic_billion.wrk.idx_table")

# Count values with a temporary oracle. The oracle is closed in a `finally`
# so that it is released even when `get_counts()` raises; previously a failure
# here skipped `close()` and left the oracle open in the kernel.
m = MarginalOracle(old_attrs, wrk)
try:
    counts = m.get_counts()
finally:
    m.close()

# Rebalance the attribute hierarchies using the observed counts. `fixed` and
# `u` are passed through unchanged; see `rebalance_attributes` for their
# meaning.
attrs = rebalance_attributes(counts, old_attrs, fixed=[2, 4, 8, 16, 32, 48], u=4)

In [None]:
from pasteur.marginal import AttrSelector
from pasteur.marginal.oracle import parallel_load
from pasteur.marginal.memory import map_to_memory, load_from_memory
from pasteur.utils import LazyDataset, LazyPartition
from pasteur.utils.progress import init_pool
import time

In [None]:
# The marginal request that the benchmark below submits repeatedly.
# Each entry is keyed by attribute name and built from three positional
# arguments: the attribute name again, an integer (0 everywhere here), and a
# dict mapping column names to an integer.
# NOTE(review): the integers presumably select a hierarchy height/level per
# attribute and per column — confirm against AttrSelector's definition.
mar = {
    "gender": AttrSelector(
        "gender",
        0,
        {"gender": 0},
    ),
    "warning": AttrSelector(
        "warning",
        0,
        {"warning": 0},
    ),
    "intime": AttrSelector(
        "intime",
        0,
        {"intime_day": 1},
    ),
    "outtime": AttrSelector(
        "outtime",
        0,
        {"outtime_day": 1},
    ),
    "charttime": AttrSelector(
        "charttime",
        0,
        {"charttime_day": 0, "charttime_time": 2},
    ),
}

In [None]:
from functools import partial

# Draw one sample from the working table up front; `get_small_dataset` below
# slices it (`sample[:n]`) and reads `.shape` from the slice.
# NOTE(review): presumably this materializes the table in memory — confirm.
sample = wrk.sample()

def get_small_dataset(n):
    """Return a LazyDataset whose merged partition is ``sample[:n]``.

    Both callables given to ``LazyPartition`` read the module-level
    ``sample`` only when they are invoked, so the slice is taken lazily.

    Args:
        n: Upper bound of the slice taken from ``sample``.
    """

    def load_slice():
        return sample[:n]

    def slice_shape():
        return sample[:n].shape

    return LazyDataset(merged_load=LazyPartition(load_slice, slice_shape))

def get_big_dataset(n):
    """Return a LazyDataset made of the first ``n`` partitions of ``wrk``.

    Partitions are taken in the iteration order of ``wrk._partitions`` and no
    merged loader is provided.

    NOTE(review): this reads the private ``_partitions`` attribute of the
    dataset; it may break if that internal layout changes.

    Args:
        n: Number of leading partitions to keep.
    """
    leading = list(wrk._partitions.items())[:n]
    return LazyDataset(merged_load=None, partitions=dict(leading))


In [None]:
# Benchmark cases as (N, ds, M) tuples, unpacked in that order by the
# benchmark loop: N is the size label that gets printed, ds the dataset handed
# to the oracle, and M the request budget (the loop divides it before use).
tests = [
    # Cases that slice the in-memory sample; N doubles as the slice length.
    *(
        (size, get_small_dataset(size), budget)
        for size, budget in [
            (100_000, 10_000_000),
            (500_000, 100_000),
            (1_000_000, 1_000_000),
            (10_000_000, 200_000),
        ]
    ),
    # Cases built from whole partitions of `wrk`.
    # NOTE(review): 6 and 31 partitions presumably correspond to roughly the
    # labelled row counts — confirm against the partition sizes.
    (100_000_000, get_big_dataset(6), 50_000),
    (500_000_000, get_big_dataset(31), 20_000),
    # The full working table.
    (1_000_000_000, wrk, 10_000),
]

In [None]:
# Time `MarginalOracle.process` for every case in `tests` and print the
# number of requests handled per second.
print("> Single Core")
for N, ds, M in tests:
    # Only a 30th of the configured request budget is used in this run.
    M //= 30

    with MarginalOracle(
        attrs,
        ds,
        mode="inmemory_shared",
        max_worker_mult=1,
        log=False,
    ) as m:
        # Data is loaded with a pool limited to one worker; this happens
        # before the timer starts, so loading is excluded from the measurement.
        init_pool(max_workers=1)
        m.load_data()

        # M references to the same request dict.
        reqs = [mar] * M

        # NOTE(review): the pool is re-initialized with default arguments
        # before the timed section, under a "Single Core" header — confirm
        # that `max_worker_mult=1` is what keeps processing on one core.
        init_pool()

        start = time.perf_counter()
        m.process(reqs, desc=f"N={N:,}")
        end = time.perf_counter()

        # Requests per second for this dataset size.
        print(f"N={N: 5d}: {(M / (end - start)):.3f}")
