In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import bootstrap
from tqdm import tqdm
import wandb
import tempfile
import os

# Register tqdm's pandas integration: adds `progress_apply` to pandas objects,
# which `calculate_ci` below uses to show a progress bar over the method groups.
tqdm.pandas()

In [2]:
# W&B public-API client; used below to look up runs and download their files.
api = wandb.Api()

In [3]:
# W&B run IDs (project adamoyoung/msg) for the "old_bench" evaluation, keyed by
# the method name that labels each run's rows further down.
# The trailing "alt" IDs are alternative runs kept for reference only; they are
# not used anywhere in this notebook.
run_name_to_id = dict(
    preconly_mass="gof6lho0",     # alt: 1k6bzqu8
    fp_mass="l5jcvdf1",           # alt: wihjpo5l
    gnn_mass="ymixvlg2",          # alt: ow2np6w0
    preconly_formula="6gomcjff",  # alt: b48mjrk4
    fp_formula="alt70ypf",        # alt: a9o2s4dn
    gnn_formula="fyadjjed",       # alt: xvx3i4d1
)

# Per-row metric columns to summarize as mean + bootstrap confidence interval.
metric_cols = [
    "cos_sim",
    "cos_sim_sqrt",
    "js_sim",
    *(f"test_hit_rate@{k}" for k in (1, 5, 20)),
]

In [4]:
# Download each run's per-example test results ("df_test.pkl") from W&B into a
# throw-away directory and load them into `run_name_to_df`, keyed by run name.
# NOTE(security): pd.read_pickle can execute arbitrary code; only load files
# from runs you trust.
WANDB_PROJECT_PATH = "adamoyoung/msg"
run_name_to_df = {}
with tempfile.TemporaryDirectory() as tmp_dp:
    for run_name, run_id in run_name_to_id.items():
        df_file = api.run(f"{WANDB_PROJECT_PATH}/{run_id}").file("df_test.pkl")
        # File.download hands back an open handle to the downloaded file; close it
        # explicitly so the temp dir can be removed cleanly (an open handle blocks
        # deletion on Windows and triggers a ResourceWarning elsewhere).
        handle = df_file.download(replace=True, root=tmp_dp)
        if handle is not None:
            handle.close()
        run_name_to_df[run_name] = pd.read_pickle(Path(tmp_dp) / df_file.name)

In [5]:
# Stack the per-run frames into one long frame, labelling every row with its
# run name in a "method" column. `assign` returns a copy, so the frames stored in
# `run_name_to_df` are not mutated (re-running this cell leaves them unchanged).
# The index is not reset, so row labels repeat across methods.
df = pd.concat(
    [method_df.assign(method=method) for method, method_df in run_name_to_df.items()]
)
# The hit-rate columns appear to hold per-row 0/1 hit indicators; scale by 100 so
# the per-method means reported later read as percentages.
hit_rate_cols = ["test_hit_rate@1", "test_hit_rate@5", "test_hit_rate@20"]
df[hit_rate_cols] = df[hit_rate_cols] * 100
df

Unnamed: 0,losses,cos_sim,js_sim,cos_sim_sqrt,cos_sim_obj,test_hit_rate@1,test_hit_rate@5,test_hit_rate@20,method
0,1.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.0,preconly_mass
1,1.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.0,preconly_mass
2,1.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.0,preconly_mass
3,1.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.0,preconly_mass
4,0.139411,0.860589,0.714429,0.730126,0.860589,0.0,100.0,100.0,preconly_mass
...,...,...,...,...,...,...,...,...,...
9949,1.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.0,gnn_formula
9950,0.902007,0.030227,0.080346,0.097993,0.097993,0.0,0.0,0.0,gnn_formula
9951,0.620707,0.407422,0.329592,0.379293,0.379293,100.0,100.0,100.0,gnn_formula
9952,0.996564,0.000027,0.003427,0.003436,0.003436,0.0,0.0,0.0,gnn_formula


In [6]:
def calculate_ci(df, metric_cols, confidence_level=0.999, n_resamples=20_000, random_state=None):
    """Summarize each metric per method as ``"mean (ci_low-ci_high)"`` strings.

    For every group of ``df.groupby('method')`` and every column in
    ``metric_cols``, compute the sample mean and a bootstrap confidence interval
    for the mean (``scipy.stats.bootstrap`` with its default method). The mean
    and both CI bounds are formatted with exactly two decimals.

    Args:
        df: Frame with a ``'method'`` column plus the metric columns.
        metric_cols: Names of the metric columns to summarize.
        confidence_level: Confidence level of each bootstrap interval.
        n_resamples: Number of bootstrap resamples per (method, metric) pair.
        random_state: Seed or ``np.random.Generator`` for the resampling. The
            default ``None`` draws fresh entropy, so CIs change between runs.

    Returns:
        DataFrame indexed by method (in ``groupby`` order) with one string
        column per entry of ``metric_cols``.
    """
    # One generator shared by every bootstrap call: a fixed seed reproduces the
    # whole table, while `None` keeps the previous unseeded behavior.
    rng = np.random.default_rng(random_state)

    def fmt(value):
        return f'{value:.2f}'

    def get_ci(col_vals):
        vals = np.asarray(col_vals)
        if vals.min() == vals.max():
            # Constant sample (e.g. a hit rate that is 0 on every row): the
            # bootstrap distribution is a single point and BCa would yield NaN
            # bounds, so the interval collapses to that value.
            return f'{fmt(vals[0])}-{fmt(vals[0])}'
        res = bootstrap((vals,), np.mean, confidence_level=confidence_level,
                        n_resamples=n_resamples, random_state=rng)
        ci = res.confidence_interval
        return f'{fmt(ci.low)}-{fmt(ci.high)}'

    grouped = df.groupby('method')[metric_cols]
    df_mean = grouped.mean()

    # `progress_apply` only exists after `tqdm.pandas()` has been called; fall
    # back to plain `apply` so this function does not depend on that global setup.
    apply = getattr(grouped, 'progress_apply', grouped.apply)
    # Per method group, run get_ci on each metric column -> one row of CI strings.
    df_ci = apply(lambda df_method: df_method.apply(get_ci, axis=0))

    # Format the means with the same 2-decimal precision as the CI bounds
    # (`round(2).astype(str)` rendered e.g. 0.20 as "0.2" and 23.80 as "23.8").
    return pd.DataFrame(
        {col: df_mean[col].map(fmt) + ' (' + df_ci[col] + ')' for col in metric_cols},
        index=df_mean.index,
    )

In [7]:
# Bootstrap a CI for every (method, metric) pair; slow (the recorded run took ~3.5 min).
df_ci = calculate_ci(df=df, metric_cols=metric_cols)

  0%|          | 0/6 [00:00<?, ?it/s]

100%|██████████| 6/6 [03:32<00:00, 35.38s/it]


In [8]:
# Show the methods in the order they are listed in run_name_to_id (groupby sorts
# them alphabetically). These numbers differ slightly from those reported in the
# paper because of randomness in training; the bootstrap CIs are also random
# from run to run, since no resampling seed is set.
method_order = list(run_name_to_id)
df_ci.loc[method_order]

Unnamed: 0_level_0,cos_sim,cos_sim_sqrt,js_sim,test_hit_rate@1,test_hit_rate@5,test_hit_rate@20
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
preconly_mass,0.15 (0.14-0.16),0.17 (0.16-0.17),0.15 (0.14-0.16),0.38 (0.21-0.61),1.72 (1.33-2.16),7.17 (6.35-8.03)
fp_mass,0.25 (0.24-0.26),0.27 (0.26-0.28),0.25 (0.24-0.26),9.82 (8.87-10.78),23.8 (22.41-25.28),40.69 (39.04-42.27)
gnn_mass,0.2 (0.19-0.21),0.23 (0.22-0.24),0.2 (0.20-0.21),4.17 (3.55-4.84),12.85 (11.75-13.97),28.8 (27.31-30.30)
preconly_formula,0.15 (0.14-0.16),0.17 (0.16-0.17),0.15 (0.14-0.16),2.09 (1.67-2.57),8.52 (7.62-9.45),22.65 (21.31-24.01)
fp_formula,0.19 (0.18-0.20),0.21 (0.20-0.22),0.19 (0.18-0.20),3.7 (3.12-4.34),13.56 (12.47-14.70),32.88 (31.35-34.46)
gnn_formula,0.25 (0.24-0.26),0.27 (0.26-0.28),0.25 (0.24-0.26),7.32 (6.49-8.25),23.0 (21.67-24.37),43.69 (42.11-45.38)
