In [None]:
import logging

import numpy as np
import pandas as pd
import ray
from contexttimer import Timer
from logzero import logger

from bds.bb import get_ground_truth_count
from bds.meel import approx_mc2, approx_mc2_core
from bds.rule import Rule
from bds.utils import bin_random, randints

In [None]:
# Start a local Ray runtime with 16 logical CPUs.
# `ignore_reinit_error=True` keeps this cell re-runnable: without it, executing
# the cell a second time in the same kernel raises a RuntimeError because Ray
# is already initialised.
ray.init(num_cpus=16, ignore_reinit_error=True)

In [None]:
# Uncomment and run to shut down the Ray runtime started above.
# ray.shutdown()

In [None]:
# Only let WARNING and above through the logzero logger
# (logging.WARN is an alias of logging.WARNING).
logger.setLevel(logging.WARNING)

# --- synthetic problem size ---
num_pts = 1000  # forwarded to Rule.random and bin_random

# --- settings shared by approx_mc2 and get_ground_truth_count ---
lmbd = 0.1
ub = 0.8

# --- settings forwarded only to approx_mc2 ---
# NOTE(review): presumably the tolerance / confidence parameters of the
# approximate counter -- confirm against bds.meel.approx_mc2.
eps = 0.8
delta = 0.8

# --- run control ---
rand_seed = 1234  # seeds NumPy's global RNG and is also passed to approx_mc2
show_progres = True  # (sic) passed as approx_mc2's `show_progress` argument

In [None]:
# Experiment grid: every rule-set size is evaluated `n_reps` times,
# largest size first. For a quick smoke test, shrink the list to e.g. [200].
n_reps = 1
num_rules_list = [50, 100, 150, 200][::-1]

# Keyword arguments of approx_mc2 that are identical for every trial.
approx_mc2_options = dict(
    lmbd=lmbd,
    ub=ub,
    delta=delta,
    eps=eps,
    rand_seed=rand_seed,
    show_progress=show_progres,
    parallel=True,
    log_level=logging.WARN,
)

# Seed NumPy's global RNG once, before the first trial.
np.random.seed(rand_seed)

# One tuple per trial:
# (num_rules, approx_mc2 time, reference time, approx_mc2 count, reference count)
res_rows = []
for _ in range(n_reps):
    for num_rules in num_rules_list:
        # Build a fresh random instance. The order of these three steps is kept
        # as-is -- presumably randints / bin_random draw from the seeded global
        # RNG, so reordering them could change the generated data.
        rule_random_seeds = randints(num_rules)
        rules = [
            Rule.random(rule_id, num_pts, random_seed=rule_random_seeds[rule_id - 1])
            for rule_id in range(1, num_rules + 1)
        ]
        y = bin_random(num_pts)

        # Time the approximate counter.
        with Timer() as cbb_timer:
            test_cnt = approx_mc2(rules, y, **approx_mc2_options)
        test_elapsed = cbb_timer.elapsed

        # Time the reference counter on the same instance.
        with Timer() as bb_timer:
            ref_count = get_ground_truth_count(rules, y, lmbd, ub)
        ref_elapsed = bb_timer.elapsed

        trial_row = (num_rules, test_elapsed, ref_elapsed, test_cnt, ref_count)
        res_rows.append(trial_row)

In [None]:
# Tabulate the trials: one row per tuple in `res_rows`, with the columns named
# in the same order as the tuple fields.
result_columns = [
    "num_rules",
    "running_time_approx_mc2",
    "running_time_bb",
    "estimate_count",
    "true_count",
]
df = pd.DataFrame(res_rows, columns=result_columns)

# Running time of approx_mc2 divided by the reference running time.
df["runtime-factor"] = df["running_time_approx_mc2"] / df["running_time_bb"]

# Signed difference between estimate and reference count, relative to the
# reference count.
count_diff = df["estimate_count"] - df["true_count"]
df["estimation-rel-diff"] = count_diff / df["true_count"]

# Last expression of the cell: render the frame.
df

In [None]:
# Mean running times (and their ratio) per rule-set size, printed as a Markdown
# table so it can be pasted into a report.
#
# The columns are selected with a *list*. Indexing a GroupBy with a bare tuple
# of keys (`gb["a", "b"]`) was deprecated in pandas 1.x and raises ValueError
# on pandas >= 2.0.
timing_columns = ["running_time_approx_mc2", "running_time_bb", "runtime-factor"]
print(df.groupby("num_rules")[timing_columns].mean().to_markdown())