In [14]:
import pandas as pd

from src.constant import MAIN_DIR
from src.database import DB
from tqdm.auto import tqdm

# Root of the archived phase-2 runs; the loader below expects the layout
# PHASE2_DIR / <n_train> / <policy> / run-policy-*.db
PHASE2_DIR = MAIN_DIR / "archive" / "phase2"
# Training-set sizes, kept as strings because they are used as directory names
# and later compared against the string-valued "n_train" column.
N_TRAIN_LIST = ["25", "100"]
# Policy identifiers; each one is a sub-directory name under every n_train folder.
POLICY_LIST = ["baseline", "ea", "eb", "ec", "ia", "ib"]

In [15]:
# Build one summary record per run database found under
# PHASE2_DIR / <n_train> / <policy> / run-policy-*.db and collect them in `df`.
records = []
total_iterations = len(N_TRAIN_LIST) * len(POLICY_LIST)

# The context manager closes the bar even if loading a database raises.
with tqdm(total=total_iterations, desc="Processing") as progress_bar:
    for n_train in N_TRAIN_LIST:
        for policy in POLICY_LIST:
            # Update description to show current n_train and policy
            progress_bar.set_description(f"n_train={n_train}, policy={policy}")

            # glob() yields paths in filesystem order; sort so the row order of
            # `df` is the same on every machine and every re-run.
            db_path_list = sorted((PHASE2_DIR / n_train / policy).glob("run-policy-*.db"))

            for db_path in db_path_list:
                db = DB(db_path)
                results = db.get_results()

                # Rows whose prefix starts with "test" drive the cost metric.
                results_test = results.loc[results["prefix"].str.startswith("test")]
                if results_test.empty:
                    print(f"No results {db_path}")
                    continue
                # Rows whose prefix starts with "config" drive the CPU-time and
                # surrogate metrics; filter once and reuse below.
                results_config = results.loc[results["prefix"].str.startswith("config")]

                cost = (
                    results_test
                    .groupby(["instance_id", "prefix"])["cost"]
                    .min()  # min for every problem (of 2 solvers)
                    .reset_index()
                    .groupby("instance_id")["cost"]
                    .median()  # median over 5 runs
                    .mean()  # total mean score
                )

                # Only evaluations that were neither cached nor surrogate count
                # towards CPU time. The division by 3600 presumably converts
                # seconds to hours -- TODO confirm the unit of "time".
                is_real_run = results_config["cached"].eq(0) & results_config["surrogate"].eq(0)
                cpu_time = results_config.loc[is_real_run, "time"].sum() / 3600

                # Per (solver, instance) pair: max of the surrogate flag is averaged
                # for surrogate_pct; "min equals 0" is averaged for real_pct.
                # Both values are fractions, not percentages.
                surrogate_by_pair = results_config.groupby(["solver_id", "instance_id"])["surrogate"]
                surrogate_pct = surrogate_by_pair.max().mean()
                real_pct = surrogate_by_pair.min().eq(0).mean()

                records.append(
                    {
                        "db_path": "/".join(db_path.parts[-3:]),
                        "n_train": n_train,
                        "policy": policy,
                        "cost": cost,
                        "cpu_time": cpu_time,
                        "surrogate_pct": surrogate_pct,
                        "real_pct": real_pct,
                    }
                )
            progress_bar.update(1)

df = pd.DataFrame(records)
df

Processing:   0%|          | 0/12 [00:00<?, ?it/s]

Unnamed: 0,db_path,n_train,policy,cost,cpu_time,surrogate_pct,real_pct
0,25/baseline/run-policy-baseline-25-1012821.db,25,baseline,0.20404,1.085407,0.000000,1.000000
1,25/baseline/run-policy-baseline-25-1012829.db,25,baseline,0.18832,1.231176,0.000000,1.000000
2,25/baseline/run-policy-baseline-25-1012830.db,25,baseline,0.73652,1.443054,0.000000,1.000000
3,25/baseline/run-policy-baseline-25-1012832.db,25,baseline,0.30916,0.832994,0.000000,1.000000
4,25/baseline/run-policy-baseline-25-1013020.db,25,baseline,0.20116,1.118927,0.000000,1.000000
...,...,...,...,...,...,...,...
100,100/ib/run-policy-ib-100-1013195.db,100,ib,0.29312,1.845759,0.918033,0.475410
101,100/ib/run-policy-ib-100-1013261.db,100,ib,0.54144,2.299865,0.907407,0.685185
102,100/ib/run-policy-ib-100-1013267.db,100,ib,0.15760,1.967770,0.939024,0.402439
103,100/ib/run-policy-ib-100-1013295.db,100,ib,0.38212,1.128463,0.924242,0.227273


In [3]:
# Number of runs per (policy, n_train); "cpu_time" is just the column being counted.
df.pivot_table(
    values="cpu_time",
    index="policy",
    columns="n_train",
    aggfunc="count",
)

n_train,100,25
policy,Unnamed: 1_level_1,Unnamed: 2_level_1
baseline,9,9
ea,9,9
eb,9,9
ec,6,9
ia,9,9
ib,9,9


In [4]:
# Mean cpu_time across runs for every (policy, n_train) combination.
df.pivot_table(
    values="cpu_time",
    index="policy",
    columns="n_train",
    aggfunc="mean",
)

n_train,100,25
policy,Unnamed: 1_level_1,Unnamed: 2_level_1
baseline,4.403321,1.070346
ea,3.184109,0.816572
eb,2.066793,0.587298
ec,12.555337,3.044383
ia,6.946186,1.777969
ib,1.721727,0.410029


In [5]:
# Mean cost across runs for every (policy, n_train) combination.
df.pivot_table(
    values="cost",
    index="policy",
    columns="n_train",
    aggfunc="mean",
)

n_train,100,25
policy,Unnamed: 1_level_1,Unnamed: 2_level_1
baseline,0.269547,0.32416
ea,0.312129,0.349618
eb,0.366613,0.361996
ec,0.262893,0.270084
ia,0.157129,0.192764
ib,0.36812,0.354658


In [6]:
def agg(x, baseline="baseline"):
    """Average the per-run metrics of each policy and relate them to a baseline.

    Parameters
    ----------
    x : pandas.DataFrame
        One row per run, with the columns ``policy``, ``cost``, ``cpu_time``,
        ``surrogate_pct`` and ``real_pct``.
    baseline : str, default "baseline"
        Policy whose mean ``cost`` and ``cpu_time`` are the denominators of the
        two ``*_ratio_to_baseline`` columns.

    Returns
    -------
    pandas.DataFrame
        Indexed by policy and rounded to 4 decimals, with the columns ``cost``,
        ``cost_ratio_to_baseline``, ``cpu_time``, ``cpu_time_ratio_to_baseline``,
        ``surrogate_pct`` and ``real_pct`` (in that order).

    Raises
    ------
    KeyError
        If ``baseline`` does not occur in ``x["policy"]``.
    """
    column_order = [
        "cost",
        "cost_ratio_to_baseline",
        "cpu_time",
        "cpu_time_ratio_to_baseline",
        "surrogate_pct",
        "real_pct",
    ]

    df_agg = x.groupby("policy").agg(
        cost=("cost", "mean"),
        cpu_time=("cpu_time", "mean"),
        surrogate_pct=("surrogate_pct", "mean"),
        real_pct=("real_pct", "mean"),
    )

    # Fail with an explanatory message instead of the bare lookup error.
    if baseline not in df_agg.index:
        raise KeyError(
            f"baseline policy {baseline!r} not found among policies {list(df_agg.index)}"
        )

    # Ratios are taken on the unrounded means; rounding happens once at the end.
    df_agg["cost_ratio_to_baseline"] = df_agg["cost"] / df_agg.at[baseline, "cost"]
    df_agg["cpu_time_ratio_to_baseline"] = df_agg["cpu_time"] / df_agg.at[baseline, "cpu_time"]

    return df_agg.round(4).loc[:, column_order]

# Split the per-run table by training-set size (n_train holds strings),
# then build the per-policy summary for each subset.
df25 = df[df["n_train"].eq("25")].copy()
df100 = df[df["n_train"].eq("100")].copy()

df25_agg, df100_agg = agg(df25), agg(df100)

In [7]:
df25_agg  # per-policy summary of the n_train == "25" runs

Unnamed: 0_level_0,cost,cost_ratio_to_baseline,cpu_time,cpu_time_ratio_to_baseline,surrogate_pct,real_pct
policy,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
baseline,0.3242,1.0,1.0703,1.0,0.0,1.0
ea,0.3496,1.0785,0.8166,0.7629,0.4641,0.6681
eb,0.362,1.1167,0.5873,0.5487,0.9258,0.6856
ec,0.2701,0.8332,3.0444,2.8443,0.9209,1.0
ia,0.1928,0.5947,1.778,1.6611,0.4847,0.5176
ib,0.3547,1.0941,0.41,0.3831,0.9196,0.356


In [8]:
df100_agg  # per-policy summary of the n_train == "100" runs

Unnamed: 0_level_0,cost,cost_ratio_to_baseline,cpu_time,cpu_time_ratio_to_baseline,surrogate_pct,real_pct
policy,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
baseline,0.2695,1.0,4.4033,1.0,0.0,1.0
ea,0.3121,1.158,3.1841,0.7231,0.4659,0.6622
eb,0.3666,1.3601,2.0668,0.4694,0.9313,0.6615
ec,0.2629,0.9753,12.5553,2.8513,0.9243,1.0
ia,0.1571,0.5829,6.9462,1.5775,0.4847,0.5198
ib,0.3681,1.3657,1.7217,0.391,0.9237,0.3928


In [18]:
# Individual "ia" runs at n_train == "25", ordered from lowest to highest cost.
(
    df.loc[df["policy"].eq("ia") & df["n_train"].eq("25")]
    .sort_values(by="cost")
)

Unnamed: 0,db_path,n_train,policy,cost,cpu_time,surrogate_pct,real_pct
36,25/ia/run-policy-ia-25-1013659.db,25,ia,0.10212,1.683734,0.484694,0.515306
37,25/ia/run-policy-ia-25-1013660.db,25,ia,0.12632,2.052274,0.484694,0.515306
39,25/ia/run-policy-ia-25-1013662.db,25,ia,0.15724,1.519974,0.484694,0.520408
41,25/ia/run-policy-ia-25-1013664.db,25,ia,0.16696,1.701617,0.484694,0.520408
40,25/ia/run-policy-ia-25-1013663.db,25,ia,0.18408,1.90431,0.484694,0.515306
43,25/ia/run-policy-ia-25-1013666.db,25,ia,0.1922,1.637405,0.484694,0.515306
42,25/ia/run-policy-ia-25-1013665.db,25,ia,0.22512,1.540793,0.484694,0.520408
44,25/ia/run-policy-ia-25-1013667.db,25,ia,0.2628,2.054313,0.484694,0.515306
38,25/ia/run-policy-ia-25-1013661.db,25,ia,0.31804,1.907303,0.484694,0.520408


In [19]:
# Load the raw results of one "ia" run for a closer look.
# Note: this rebinds `db` and `results`, which the collection loop also uses.
single_run_path = PHASE2_DIR / "25/ia/run-policy-ia-25-1013659.db"
db = DB(single_run_path)
results = db.get_results()

In [24]:
# How many rows of this run have each value of the "cached" flag.
results["cached"].value_counts()

0    7450
1    3000
Name: cached, dtype: int64

In [29]:
# Rows of this run whose "surrogate" flag equals 1.
results.loc[results["surrogate"].eq(1)]

Unnamed: 0,id,prefix,solver_id,instance_id,cost,time,cut_off_cost,cut_off_time,cached,surrogate,error
150,config;solver=1;attempt=1;aac_iter=7;surrogate...,config;solver=1;attempt=1;aac_iter=7;surrogate,33909028774907840,1199808321398786303,1.537740,0.0,7.4,0.74,0,1,0
151,config;solver=1;attempt=1;aac_iter=7;surrogate...,config;solver=1;attempt=1;aac_iter=7;surrogate,33909028774907840,926855222569918425,2.932883,0.0,77.7,7.77,0,1,0
152,config;solver=1;attempt=1;aac_iter=7;surrogate...,config;solver=1;attempt=1;aac_iter=7;surrogate,33909028774907840,2036633111162358606,1.590186,0.0,13.3,1.33,0,1,0
153,config;solver=1;attempt=1;aac_iter=7;surrogate...,config;solver=1;attempt=1;aac_iter=7;surrogate,33909028774907840,1493964260327799128,2.459986,0.0,25.4,2.54,0,1,0
154,config;solver=1;attempt=1;aac_iter=7;surrogate...,config;solver=1;attempt=1;aac_iter=7;surrogate,33909028774907840,5624699322667089,1.274502,0.0,13.9,1.39,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...
7841,config;solver=2;attempt=4;aac_iter=24;surrogat...,config;solver=2;attempt=4;aac_iter=24;surrogate,570117169212403366,622705132865616942,4.419830,0.0,44.0,4.40,0,1,0
7843,config;solver=2;attempt=4;aac_iter=24;surrogat...,config;solver=2;attempt=4;aac_iter=24;surrogate,570117169212403366,1027898753954028458,0.875935,0.0,16.2,1.62,0,1,0
7845,config;solver=2;attempt=4;aac_iter=24;surrogat...,config;solver=2;attempt=4;aac_iter=24;surrogate,570117169212403366,1084422477911079854,1.499058,0.0,31.1,3.11,0,1,0
7847,config;solver=2;attempt=4;aac_iter=24;surrogat...,config;solver=2;attempt=4;aac_iter=24;surrogate,570117169212403366,1337376689788754595,1.368699,0.0,21.8,2.18,0,1,0
