In [19]:
import os
import context
os.chdir(context.proj_dir)

from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import OrderedDict
import re

from cont_gen.utils import load_jsonl, get_ckpt_paths, save_json, load_json

In [28]:
def get_ckpt_ov_met(ckpt):
    """Return the id and ood overal metric of a checkpoint"""
    return {'id': load_json(Path(ckpt) / 'ov_metrics_id_sampled.csv'),
            'ood': load_json(Path(ckpt) / 'ov_metrics_ood_sampled.csv')}

def get_run_ov_met(run_dir):
    """Return the ov metrics of all checkpoints"""
    ckpt_dirs = get_ckpt_paths(run_dir)
    run_ov = OrderedDict()
    for ckpt in ckpt_dirs:
        run_ov[int(str(ckpt).split('-')[-1])] = get_ckpt_ov_met(ckpt)
    return run_ov

def get_run_ov_best(run_dir, rep_met = 'macro_iou'):
    """Return the best checkpoint number and id and ood metrics"""
    run_ov = get_run_ov_met(run_dir)
    # find the best checkpoint
    run_ov = list(run_ov.items())
    rep_met_ckpts = [met['id'][rep_met] for _, met in run_ov]
    best_i = np.argmax(rep_met_ckpts)
    return best_i, run_ov[best_i][1]

def get_model_results(mod_run_dir, run_name_pat = r'pmt_01(_all)?_lr', rep_met = 'macro_iou'):
    """Return a dataframe of split, id_metrics and ood_metrics"""
    
    results = pd.DataFrame(columns = ['split', 'best_epoch', 'id_metrics', 'ood_metrics', 'run_path'])
    
    mod_run_dir = Path(mod_run_dir)
    for spl_path in mod_run_dir.glob('*'):
        run_paths = spl_path.glob('*')
        # filter the pmt_01 run
        run_paths = [k for k in run_paths if re.match(run_name_pat, k.name)]
        if len(run_paths) > 1:
            print(f'Warning: more than one runs for a split: {run_paths}')
            if 't5-large' in str(mod_run_dir):
                run_paths = [k for k in run_paths if 'lr1e-4' in k.name]
        if len(run_paths) == 0:
            print(f'Warning: no runs of a split: {spl_path}, {list(spl_path.glob("*"))}')
            continue
        run_p = run_paths[0]
        be, best_mets = get_run_ov_best(run_p, rep_met)
        results.loc[len(results)] = [spl_path.name, be, best_mets['id'], best_mets['ood'], str(run_p)]
    return results


In [3]:
# get_run_ov_best('runs/ood/mistral/seed89_tr29/pmt_01_all_lr1e-5_bs16_wd0.0')

In [21]:
mod_names = ['t5-large', 'flan-t5-large',  'flan-t5-xl', 'llama3', 'llama3_chat', 'mistral', 'mistral_chat']
main_df = None
for mod in mod_names:
    mod_df = get_model_results(f'runs/ood/{mod}')
    mod_df.insert(0, 'model', mod)
    if main_df is None:
        main_df = mod_df
    else:
        main_df = pd.concat([main_df, mod_df], axis = 0)
display(main_df)



Unnamed: 0,model,split,best_epoch,id_metrics,ood_metrics,run_path
0,t5-large,seed42_tr29,4,"{'doc_macro_f1': 0.23765104940683315, 'doc_mac...","{'doc_macro_f1': 0.27653595424702, 'doc_macro_...",runs/ood/t5-large/seed42_tr29/pmt_01_lr1e-4_bs...
1,t5-large,seed128_tr29,4,"{'doc_macro_f1': 0.1661662132211322, 'doc_macr...","{'doc_macro_f1': 0.23620615102515183, 'doc_mac...",runs/ood/t5-large/seed128_tr29/pmt_01_lr1e-4_b...
2,t5-large,seed89_tr29,2,"{'doc_macro_f1': 0.10306172587922359, 'doc_mac...","{'doc_macro_f1': 0.11281999801274999, 'doc_mac...",runs/ood/t5-large/seed89_tr29/pmt_01_lr1e-4_bs...
0,flan-t5-large,seed89_tr29,3,"{'doc_macro_f1': 0.8539321097926461, 'doc_macr...","{'doc_macro_f1': 0.8519106168389748, 'doc_macr...",runs/ood/flan-t5-large/seed89_tr29/pmt_01_lr1e...
1,flan-t5-large,seed128_tr29,4,"{'doc_macro_f1': 0.8786140720589551, 'doc_macr...","{'doc_macro_f1': 0.8009756117168526, 'doc_macr...",runs/ood/flan-t5-large/seed128_tr29/pmt_01_lr1...
2,flan-t5-large,seed42_tr29,4,"{'doc_macro_f1': 0.8679008972398385, 'doc_macr...","{'doc_macro_f1': 0.7574326459225548, 'doc_macr...",runs/ood/flan-t5-large/seed42_tr29/pmt_01_lr1e...
0,flan-t5-xl,seed42_tr29,4,"{'doc_macro_f1': 0.8883701388084073, 'doc_macr...","{'doc_macro_f1': 0.8137001432967591, 'doc_macr...",runs/ood/flan-t5-xl/seed42_tr29/pmt_01_lr1e-4_...
1,flan-t5-xl,seed128_tr29,4,"{'doc_macro_f1': 0.8888435417631178, 'doc_macr...","{'doc_macro_f1': 0.7964729951434358, 'doc_macr...",runs/ood/flan-t5-xl/seed128_tr29/pmt_01_lr1e-4...
2,flan-t5-xl,seed89_tr29,4,"{'doc_macro_f1': 0.8778129059488043, 'doc_macr...","{'doc_macro_f1': 0.843953571484343, 'doc_macro...",runs/ood/flan-t5-xl/seed89_tr29/pmt_01_lr1e-4_...
0,llama3,seed128_tr29,4,"{'doc_macro_f1': 0.8657928792420393, 'doc_macr...","{'doc_macro_f1': 0.8026115318850245, 'doc_macr...",runs/ood/llama3/seed128_tr29/pmt_01_all_lr1e-5...


In [25]:
for _, row in main_df[main_df['model'] == 'flan-t5-large'].iterrows():
    print(row['model'], row['split'])
    mets = row['ood_metrics']
    print({k:v for k,v in mets.items() if k.startswith('macro')})

flan-t5-large seed89_tr29
{'macro_f1': 0.4900600375112804, 'macro_iou': 0.36262888112283287, 'macro_p': 0.7933159210489532, 'macro_r': 0.4069200449655857}
flan-t5-large seed128_tr29
{'macro_f1': 0.415667370415379, 'macro_iou': 0.2936412684353065, 'macro_p': 0.6619257031651343, 'macro_r': 0.45579211229511535}
flan-t5-large seed42_tr29
{'macro_f1': 0.4740948482902716, 'macro_iou': 0.3413932168052664, 'macro_p': 0.5755152041860386, 'macro_r': 0.5583714490677154}


In [30]:
run_ov = get_run_ov_met('runs/ood/flan-t5-large/seed128_tr29/pmt_01_lr1e-4_bs16_wd0.0')
macro_iou = {k: v['ood']['macro_iou'] for k,v in run_ov.items()}
print(macro_iou)

{7213: 0.2774533722561047, 14426: 0.29234445276880855, 21639: 0.2978490388593292, 28852: 0.29325434267447953, 36065: 0.2936412684353065}


In [31]:
run_ov = list(run_ov.items())
rep_met_ckpts = [met['id']['macro_iou'] for _, met in run_ov]
best_i = np.argmax(rep_met_ckpts)
print(rep_met_ckpts)
print(best_i)

[0.44543560546887034, 0.4383646570629703, 0.45735813970012223, 0.44922889330414045, 0.4811198040240143]
4


In [22]:
keys = [f'{a}_{b}' for a in ['macro', 'micro'] for b in ['p', 'r', 'f1', 'iou']]

mets = main_df[(main_df['model'] == 'mistral') & (main_df['split'] == 'seed89_tr29')]['ood_metrics'].iloc[0]
sort_mets = [f'{mets[k]*100:.2f}' for k in keys]
print('  '.join(sort_mets))

77.34  39.25  46.68  34.29  75.88  41.94  54.02  37.01


## Print the main table

In [5]:
# average over splits
def dict_ave(dicts):
    """Average values of dicts"""
    keys = list(dicts[0].keys())
    # turn into a table
    dict_values = [[dt[k] for k in keys] for dt in dicts]
    # average each column
    ave_values = [np.mean(k) for k in zip(*dict_values)]
    return dict(zip(keys, ave_values))

main_ave_df = main_df.groupby('model').apply(lambda gp: pd.Series({
    'id_metrics': dict_ave(gp['id_metrics'].to_list()),
    'ood_metrics': dict_ave(gp['ood_metrics'].to_list()),
})).reset_index()

In [6]:
main_ave_df

Unnamed: 0,model,id_metrics,ood_metrics
0,flan-t5-large,"{'doc_macro_f1': 0.86681569303048, 'doc_macro_...","{'doc_macro_f1': 0.8034396248261274, 'doc_macr..."
1,flan-t5-xl,"{'doc_macro_f1': 0.8850088621734432, 'doc_macr...","{'doc_macro_f1': 0.8180422366415127, 'doc_macr..."
2,llama3,"{'doc_macro_f1': 0.8496301142050342, 'doc_macr...","{'doc_macro_f1': 0.8214805048095752, 'doc_macr..."
3,llama3_chat,"{'doc_macro_f1': 0.8525921787134653, 'doc_macr...","{'doc_macro_f1': 0.8322974095601343, 'doc_macr..."
4,mistral,"{'doc_macro_f1': 0.8729672388732119, 'doc_macr...","{'doc_macro_f1': 0.8254737389365082, 'doc_macr..."
5,mistral_chat,"{'doc_macro_f1': 0.8460275086018848, 'doc_macr...","{'doc_macro_f1': 0.8211038460019057, 'doc_macr..."
6,t5-large,"{'doc_macro_f1': 0.16895966283572963, 'doc_mac...","{'doc_macro_f1': 0.20852070109497392, 'doc_mac..."


In [17]:
# print table for latex or excel
keys = [f'{a}_{b}' for a in ['macro', 'micro'] for b in ['p', 'r', 'f1', 'iou']]
sep = ' & ' # ' & ' for latex
mod2mets = {row['model']: row['ood_metrics'] for _, row in main_ave_df.iterrows()}

table_main = [['Model', *keys]]

for mod in mod_names:
    vals = [round(mod2mets[mod][k]*100, 2) for k in keys]
    table_main.append([mod, *vals])
    # val_str = sep.join(map(lambda k: f'{k*100:.2f}', vals))
    # print(f'{mod} {sep} {val_str}')

for line in table_main:
    print(sep.join(map(str, line)) + r' \\')

Model & macro_p & macro_r & macro_f1 & macro_iou & micro_p & micro_r & micro_f1 & micro_iou \\
t5-large & 3.71 & 13.02 & 4.21 & 2.18 & 3.73 & 10.0 & 4.55 & 2.34 \\
flan-t5-large & 67.69 & 47.37 & 45.99 & 33.26 & 63.19 & 45.75 & 51.22 & 34.42 \\
flan-t5-xl & 69.14 & 50.04 & 49.75 & 36.3 & 65.18 & 51.16 & 56.52 & 39.42 \\
llama3 & 68.05 & 43.88 & 44.1 & 31.73 & 64.58 & 44.85 & 52.2 & 35.48 \\
llama3_chat & 68.45 & 46.93 & 48.76 & 35.45 & 69.22 & 48.35 & 56.85 & 39.74 \\
mistral & 67.03 & 45.96 & 45.8 & 33.33 & 63.84 & 46.25 & 51.91 & 35.1 \\
mistral_chat & 62.63 & 40.91 & 42.09 & 29.83 & 64.96 & 41.82 & 50.1 & 33.64 \\


In [18]:
save_json(table_main, 'scripts/plot/save_data/main_table.json')