In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import bootstrap
from tqdm import tqdm
import wandb
import tempfile
import os

tqdm.pandas()

In [2]:
# Weights & Biases public API client; used below to look up runs and fetch their files.
api = wandb.Api()

In [3]:
# W&B run IDs keyed by a short run label (all of these runs are from old_bench).
run_name_to_id = {
    "retmass_preconly": "1k6bzqu8",
    "retmass_fp": "wihjpo5l",
    "retmass_gnn": "ow2np6w0",
    "retform_preconly": "b48mjrk4",
    "retform_fp": "a9o2s4dn",
    "retform_gnn": "xvx3i4d1",
}

# Columns of the downloaded test dataframes that get summarised (mean and CI) below.
metric_cols = [
    "cos_sim",
    "cos_sim_sqrt",
    "js_sim",
    "test_hit_rate@1",
    "test_hit_rate@5",
    "test_hit_rate@20"
]

In [4]:
# Fetch each run's "df_test.pkl" into a throwaway directory and load it,
# keyed by the run label from run_name_to_id.
run_name_to_df = {}
with tempfile.TemporaryDirectory() as download_dir:
    for label, wandb_id in run_name_to_id.items():
        remote_file = api.run(f"adamoyoung/msg/{wandb_id}").file("df_test.pkl")
        remote_file.download(root=download_dir, replace=True)
        local_path = os.path.join(download_dir, remote_file.name)
        # NOTE(review): unpickling can execute arbitrary code; only load files
        # from runs you trust.
        run_name_to_df[label] = pd.read_pickle(local_path)

In [5]:
# Quick sanity check of one run's column means. numeric_only=True keeps this
# cell re-runnable when the frame also carries a non-numeric column such as
# "method" (a plain .mean() would then raise on pandas >= 2).
run_name_to_df["retform_preconly"].mean(numeric_only=True)

losses              0.845981
cos_sim             0.154019
js_sim              0.408519
cos_sim_sqrt        0.165258
cos_sim_obj         0.154019
test_hit_rate@1     0.020896
test_hit_rate@5     0.085192
test_hit_rate@20    0.226542
dtype: float64

In [6]:
# Stack every run's dataframe into one long frame, tagging each row with the
# run label in a "method" column. `assign` returns a new frame, so the frames
# held in run_name_to_df are not modified and earlier cells stay re-runnable.
df = pd.concat(
    [method_df.assign(method=method) for method, method_df in run_name_to_df.items()]
)
df

Unnamed: 0,losses,cos_sim,js_sim,cos_sim_sqrt,cos_sim_obj,test_hit_rate@1,test_hit_rate@5,test_hit_rate@20,method
0,1.000000,0.000000,0.306853,0.000000,0.000000,0.0,0.0,0.0,retmass_preconly
1,1.000000,0.000000,0.306853,0.000000,0.000000,0.0,0.0,0.0,retmass_preconly
2,1.000000,0.000000,0.306853,0.000000,0.000000,0.0,0.0,0.0,retmass_preconly
3,1.000000,0.000000,0.306853,0.000000,0.000000,0.0,0.0,0.0,retmass_preconly
4,0.139411,0.860589,0.802057,0.730126,0.860589,0.0,1.0,1.0,retmass_preconly
...,...,...,...,...,...,...,...,...,...
9949,0.759845,0.204946,0.437068,0.240155,0.240155,0.0,0.0,1.0,retform_gnn
9950,0.982515,0.000745,0.317624,0.017485,0.017485,0.0,0.0,0.0,retform_gnn
9951,0.817273,0.038494,0.385755,0.182727,0.182727,0.0,0.0,0.0,retform_gnn
9952,1.000000,0.000000,0.306853,0.000000,0.000000,0.0,0.0,0.0,retform_gnn


In [12]:
def calculate_ci(df, metric_cols, confidence_level=0.999, n_resamples=20_000, random_state=None):
    """Summarise each method's metrics as ``"mean (low-high)"`` strings.

    The metric columns are multiplied by 100, rows are grouped by the
    ``method`` column, and for every (method, metric) pair the mean and a
    bootstrap confidence interval of the mean are computed.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``method`` column and every column in ``metric_cols``.
        The frame is NOT modified.
    metric_cols : list of str
        Names of the numeric columns to summarise.
    confidence_level : float, default 0.999
        Confidence level passed to ``scipy.stats.bootstrap``.
    n_resamples : int, default 20_000
        Number of bootstrap resamples passed to ``scipy.stats.bootstrap``.
    random_state : optional
        Seed / generator forwarded to ``scipy.stats.bootstrap`` for
        reproducible intervals. ``None`` (default) leaves it unseeded.

    Returns
    -------
    pandas.DataFrame
        Indexed by method, one column per metric, each cell formatted as
        ``"<mean> (<ci low>-<ci high>)"`` with two decimals throughout.
    """
    metric_cols = list(metric_cols)

    # Scale a copy: rescaling the caller's frame in place would multiply the
    # metrics by another factor of 100 every time the function is re-run.
    scaled = df[['method'] + metric_cols].copy()
    scaled[metric_cols] = scaled[metric_cols] * 100
    grouped = scaled.groupby('method')[metric_cols]

    # Calculate means for all metrics into a single table
    df_mean = grouped.mean()

    # Only forward a seed when one was given, so the default call to
    # bootstrap is exactly the unseeded one.
    seed_kwargs = {} if random_state is None else {'random_state': random_state}

    # Calculate confidence intervals for all metrics into a single table
    def get_ci(col_vals):
        res = bootstrap(
            (col_vals.to_numpy(),),
            np.mean,
            confidence_level=confidence_level,
            n_resamples=n_resamples,
            **seed_kwargs,
        )
        ci = res.confidence_interval
        return f'{ci.low:.2f}-{ci.high:.2f}'

    def get_ci_for_each_col(df_method):
        return df_method.apply(get_ci, axis=0)

    # `progress_apply` only exists after `tqdm.pandas()` has been called;
    # fall back to the plain `apply` so the function works without it.
    apply_fn = getattr(grouped, 'progress_apply', grouped.apply)
    df_ci = apply_fn(get_ci_for_each_col)

    # Merge means and confidence intervals. Means are formatted with a fixed
    # two decimals so they line up with the CI bounds (e.g. "22.70", not "22.7").
    return pd.DataFrame({
        col: df_mean[col].map('{:.2f}'.format) + ' (' + df_ci[col] + ')'
        for col in metric_cols
    })

In [13]:
# Per-method mean and bootstrap CI for every metric column.
# Slow with the default bootstrap settings (the recorded run took about 3.5 minutes).
df_ci = calculate_ci(df, metric_cols)

  0%|          | 0/6 [00:00<?, ?it/s]

100%|██████████| 6/6 [03:23<00:00, 33.93s/it]


In [14]:
# Show the formatted "mean (low-high)" table, one row per method.
df_ci

Unnamed: 0_level_0,cos_sim,cos_sim_sqrt,js_sim,test_hit_rate@1,test_hit_rate@5,test_hit_rate@20
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
retform_fp,24.43 (23.34-25.51),26.84 (25.91-27.80),47.52 (46.89-48.18),7.62 (6.77-8.54),22.7 (21.32-24.12),44.12 (42.51-45.75)
retform_gnn,18.97 (18.03-19.95),22.01 (21.15-22.86),44.26 (43.70-44.87),3.63 (3.05-4.29),13.55 (12.46-14.68),33.77 (32.26-35.37)
retform_preconly,15.4 (14.44-16.43),16.53 (15.63-17.42),40.85 (40.26-41.49),2.09 (1.66-2.59),8.52 (7.65-9.53),22.65 (21.26-24.01)
retmass_fp,24.95 (23.91-25.98),27.28 (26.32-28.21),47.8 (47.17-48.47),8.44 (7.56-9.34),21.43 (20.10-22.79),38.57 (36.99-40.23)
retmass_gnn,19.22 (18.23-20.23),21.81 (20.99-22.69),44.15 (43.56-44.75),3.95 (3.37-4.62),11.92 (10.87-13.00),26.27 (24.83-27.82)
retmass_preconly,15.4 (14.40-16.42),16.53 (15.61-17.46),40.85 (40.27-41.46),0.38 (0.21-0.62),1.72 (1.32-2.18),7.17 (6.32-8.04)
