In [153]:
import sys
sys.path.insert(0, '../src')
from pathlib import Path
import json
import os
import warnings
from tqdm import tqdm
import numpy as np
import pandas as pd

In [154]:
mode = 2
if mode is None:
    mode = 12 # 2: SGD / 12: MW
SHD_ALG_NAME = "SHDILP"
GRAPH_ALG_NAME = "CSILP"
#GRAPH_ALG_NAME = "BCILP"
#GRAPH_ALG_NAME = "MSG2"
FILTER="*"

DATASET = 'SGD' if mode <10 else 'MultiWOZ'
GRAPH_DATASET = 'SGD' if mode <10 else 'MultiWOZ'
INF_GRAPH_DIR = f'../graphs/{GRAPH_DATASET}/{GRAPH_ALG_NAME}'
SHD_DIR = f'../graphs/{GRAPH_DATASET}/{SHD_ALG_NAME}'
#
if mode < 10: # SGD
    PREDICTIONS_DIR = f'../outputs/{DATASET}/'   
else:
    PREDICTIONS_DIR = f'../outputs/MultiWOZ/'

In [155]:
def read_config(config_pth):
    try:
        with config_pth.open() as cf:
            config = json.load(cf)
    except Exception as e:
        print(f'Error while reading {config_pth}')
        print(e)
    #dataset = config['dataset']
    if '_trajectories.json' in config['traj_path']:
        config['domain'] = config['traj_path'].rsplit('/', 1)[-1].split('_trajectories.json')[0]
    else:
        config['domain'] = config['traj_path'].rsplit('/', 1)[-1].split('.json')[0]
    prompt = (config['prompt_style'], config['num_shot'], config['use_mask_prompt'])
    temp = 0.0 if 'temperature' not in config else config['temperature']
    sampling = 'multi' if 'sampling' not in config else config['sampling']
    
    key = [config['domain'], config['model'], prompt, temp, sampling]
    return config, key

## Load predictions. Organize by the models

In [156]:
domains = set()
predictions, mapped_predictions = {}, {}
predictions_by_domain, label_tuple_by_domain = {}, {}
print(f'searching {len(list(Path(PREDICTIONS_DIR).glob(FILTER)))} directories @ {PREDICTIONS_DIR}')
for pth in Path(PREDICTIONS_DIR).glob(FILTER):
    for config_pth in pth.glob('config*.json'):
        config, base_key = read_config(config_pth)
        domains.add(config['domain'])
        for seed in config['seed']:
            key = tuple(base_key + [seed])
            for mapped_file_pth in config_pth.parent.glob(f'DM_mapped_prediction_S{seed}.npy'):
                if key in mapped_predictions:
                    assert False, f'duplicated: {key}, {mapped_file_pth}'
                try:
                    mapped_predictions[key] = np.load(mapped_file_pth, allow_pickle=True)
                    #print(f'loading @ {mapped_file_pth}')
                except Exception as e:
                    print(f'Error while reading {mapped_file_pth}', e)
        print(mapped_predictions.keys())
        label_tuple_by_domain[config['domain']] = (mapped_predictions[key][-2], mapped_predictions[key][-1])
assert len(mapped_predictions) % len(domains) == 0, "Number of predictions per domain is not consistent!"
print(f"Loaded {len(mapped_predictions)} files for {len(domains)} domains")

searching 4 directories @ ../outputs/SGD/
dict_keys([('Events_2', 'flan-t5-xxl', ('entire-concise', 5, False), 1.0, 'multi', 1636423)])
dict_keys([('Events_2', 'flan-t5-xxl', ('entire-concise', 5, False), 1.0, 'multi', 1636423), ('Events_2', 'flan-t5-xxl', ('entire-concise', 5, False), 0.0, 'multi', 1636423)])
dict_keys([('Events_2', 'flan-t5-xxl', ('entire-concise', 5, False), 1.0, 'multi', 1636423), ('Events_2', 'flan-t5-xxl', ('entire-concise', 5, False), 0.0, 'multi', 1636423), ('Banks_1', 'flan-t5-xxl', ('entire-concise', 5, False), 1.0, 'multi', 1636423)])
dict_keys([('Events_2', 'flan-t5-xxl', ('entire-concise', 5, False), 1.0, 'multi', 1636423), ('Events_2', 'flan-t5-xxl', ('entire-concise', 5, False), 0.0, 'multi', 1636423), ('Banks_1', 'flan-t5-xxl', ('entire-concise', 5, False), 1.0, 'multi', 1636423), ('Banks_1', 'flan-t5-xxl', ('entire-concise', 5, False), 0.0, 'multi', 1636423)])
Loaded 4 files for 2 domains


## Load inferred graphs

In [157]:
from util.graph_utils import get_graph_sop
def load_graphs(domains_list, root_dir, dataset, graph_alg_name, label_tuple_by_domain, is_should, filter="*"):
    load_count = 0
    num_graph_per_domain = []
    graphs = {}
    print(f'Loading {root_dir}/{dataset}_"domain"/*{graph_alg_name}{filter}.npy')
    for domain in domains_list:
        #print(f'loading @ {root_dir}/{dataset}_{domain}*/*{graph_alg_name}*.npy')
        all_acts, all_statuses = label_tuple_by_domain[domain]
        graph_algo_dict = {}
        matchings = list(Path(root_dir).glob(f'{dataset}_{domain}/*{graph_alg_name}{filter}.npy'))
        num_graph_per_domain.append(len(matchings))
        for matching in matchings:
            graph_path = str(matching)
            graph_raw = np.load(graph_path, allow_pickle=True).item()
            alg_name = graph_path.split('/')[-1].replace('.npy', '').replace("inferred_graph_", "")
            graph_sop = get_graph_sop(
                graph_raw,
                subtask_list=all_statuses,
                option_list=all_acts,
                empty_value=False if is_should else True
            )
            graph_algo_dict[alg_name] = graph_sop
            load_count += 1
        graphs[domain] = graph_algo_dict
    assert all([num_graph == num_graph_per_domain[0] for num_graph in num_graph_per_domain]), "Error. Num graph is different for each domain"
    return graphs, load_count

In [158]:
domains_list = list(domains)
#
print(INF_GRAPH_DIR)
is_should=False
graphs, load_count = load_graphs(domains_list, INF_GRAPH_DIR, DATASET, GRAPH_ALG_NAME, label_tuple_by_domain, is_should, filter="*")
print(f"Loaded {load_count} inferred CAN+SHDNT graphs from {len(domains)} domains")
#
print(SHD_DIR)
is_should=True
shd_sops, load_count = load_graphs(domains_list, SHD_DIR, DATASET, SHD_ALG_NAME, label_tuple_by_domain, is_should, filter="*")
print(f"Loaded {load_count} inferred SHD graphs from {len(domains)} domains")
#shd_sops = None


../graphs/SGD/CCAOILP
Loading ../graphs/SGD/CCAOILP/SGD_"domain"/*CCAOILP*.npy
Loaded 2 inferred CAN+SHDNT graphs from 2 domains
../graphs/SGD/SHDILP
Loading ../graphs/SGD/SHDILP/SGD_"domain"/*SHDILP*.npy
Loaded 2 inferred SHD graphs from 2 domains


## Calculate metrics

In [159]:
from multiprocess import Pool # use multiprocessing to speed up evaluation!

In [160]:
from util.eval_utils import dact_traj_metrics_report, dact_traj_multi_sample_metrics_report, standardize_dact
from copy import deepcopy

def eval_job(args):
    pred_params, graph_params, traj = args
    domain, model, prompt_params, temp, sampling, seed = pred_params
    prompt_style, num_shot, use_mask_prompt = prompt_params
    is_multisampling = float(temp) > 0 and ('repeat' not in sampling)
    graph_names, graph_tuples = graph_params
    graphs, neg_pcond_mats, should_sops = [], [], []
    for graph_tuple in graph_tuples:
        graph, should_sop = graph_tuple
        graphs.append(graph)
        neg_pcond_mats.append(None)
        should_sops.append(should_sop)
    
    if not isinstance(traj, tuple):
        gt_processed_label_tuple = tuple(traj)
    else:
        gt_processed_label_tuple = traj
    #print(f'In {pred_params} with multisampling={is_multisampling}')
    if is_multisampling:
        report_list = dact_traj_multi_sample_metrics_report(*gt_processed_label_tuple, graph_sop=graphs, neg_precond_mat=neg_pcond_mats, should_sops=should_sops, verbose=False)
    else:
        report_list = dact_traj_metrics_report(*gt_processed_label_tuple, graph_sop=graphs, neg_precond_mat=neg_pcond_mats, should_sops=should_sops, verbose=False)
    
    metrics_list = []
    #print(graph_names)
    for reprt, graph_name in zip(report_list,graph_names):
        if not isinstance(reprt, tuple):
            reprt = [reprt]
        for report in reprt:
            stats = report['Predicted']
            post = report['post']
            metrics = {
                'domain': domain[:4]+domain[-1],
                'model': model,
                'prompt': prompt_style,
                'shot': num_shot,
                'use_mask_prompt': use_mask_prompt,
                'temp': temp,
                'sampling': sampling,
                'seed': seed,
                'graph': graph_name,    
                'precision': stats['precision'],
                'recall': stats['recall'],
                'f1': stats['f1-score'],
                'support': stats['support'],
                'postprocess': post
            }
            metrics_list.append(metrics)
    return metrics_list
    
jobs = []
assert shd_sops is not None, "Error: SHOULD is empty"
for pred_params, mapped_pred_tuple in mapped_predictions.items():
    domain, model, prompt_params, temp, sampling, seed = pred_params
    #print(domain)
    #print(graphs)
    if temp > 0: # in case multi sampling, we cannot run without graph
        graph_list = list(graphs.get(domain, {}).items())
        shd_list = list(shd_sops.get(domain, {}).items())
    else:
        graph_list = list(graphs.get(domain, {}).items())
        shd_list = [('(None)', None)] + list(shd_sops.get(domain, {}).items())
    #print(graph_list)
    assert len(graph_list) == 1, "Error: current code cannot handle more than one precondition"
    can_graph = graph_list[0]
    """
    # ==== TEMPORARY DEBUGGING =====
    for i, shd in enumerate(shd_list):
        shd_sop_list = shd[1]
        if shd_sop_list is None: # no graph
            continue
        for j, sop in enumerate(shd_sop_list):
            shd_list[i][1][j] = False
        #print(f"[{domain}, {shd[0]}] {sum(is_shd_sop_none)} / {len(is_shd_sop_none)} should is None")
    # ==== TEMPORARY DEBUGGING =====
    """
    graph_names = [shd[0] for shd in shd_list]
    graph_tuples = [(can_graph[1], shd[1]) for shd in shd_list]
    #graph_tuples = [(can_graph[1], None) for shd in shd_list] # temporary: Exclude should
    #graph_tuples = [(None, shd[1]) for shd in shd_list] # temporary: should-only
    
    graph_params = (graph_names, graph_tuples)
    mapped_pred_tuple = mapped_pred_tuple[:-2] # remove last two: all_acts, all_statuses
    jobs.append((pred_params, graph_params, mapped_pred_tuple))
print(f"# jobs={len(jobs)}")

with Pool(min(60, len(jobs))) as p:
    raw_metrics = [result for result in tqdm(p.imap(eval_job, jobs)) if result is not None]
#raw_metrics = [eval_job(job) for job in jobs]
metrics = []
for elem in raw_metrics:
    for metric_dict in elem:
        metrics.append(metric_dict)
print(f"output={len(metrics)}")

# jobs=4


4it [00:00,  7.69it/s]

output=12





In [161]:
org_metrics_df = pd.DataFrame(metrics)
metrics_df = org_metrics_df.copy()

nunique = metrics_df.nunique()
cols_to_drop = nunique[nunique == 1].index
metrics_df = metrics_df.drop(cols_to_drop, axis=1)

### 1. Aggregated performance (averaged over schemas)

In [162]:
gpt_base_performance = 0.787513 if DATASET == 'SGD' else 0.446
t5_base_performance = 0.499171 if DATASET == 'SGD' else 0.304
base_performance = (gpt_base_performance + t5_base_performance) / 2
print(f"Mean={base_performance:.3f}, GPT={gpt_base_performance:.3f}, T5={t5_base_performance:.3f}")

Mean=0.643, GPT=0.788, T5=0.499


In [163]:
#rows = ['model', 'graph']
rows = ['graph']
columns = ['postprocess']
display_df = metrics_df.pivot_table(index=rows, columns=columns, values='f1', aggfunc='mean')
display_df

postprocess,None,major,max,uniform,violation
graph,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
(None),0.55866,,,,
SHDILP_bw=4_bd=4_cp=0.01_mins=0.9,0.795395,0.791888,0.741792,0.764794,0.768255


In [164]:
def parse_graph_hparam(hparam_str, param_names):
    tokens = hparam_str.split('_')
    hparam_dict = dict(algo=tokens[0])
    for par, token in zip(param_names, tokens[1:]):
        value = token.split('=')[1]
        try:
            hparam_dict[par] = int(value)
            continue
        except Exception as e:
            pass
        try:
            hparam_dict[par] = float(value)
            continue
        except Exception as e:
            pass
        hparam_dict[par] = value
    return hparam_dict

In [165]:
if mode % 10 == 2 or mode % 10 == 3:
    new_metrics = []
    for metric_dict in metrics:
        hparam_str = metric_dict['graph']
        if SHD_ALG_NAME == 'SHDILP': # inferred_graph_SHDILP_bw=4_bd=4_cp=0.01_mins=0.9
            param_names = ['bw', 'bd', 'cp', 'mins']
        hparam_dict = parse_graph_hparam(hparam_str, param_names)
        metric_dict.update(hparam_dict)
        new_metrics.append(metric_dict)
    print(f'[{SHD_ALG_NAME} @{DATASET}] num runs: {len(new_metrics)}')
    new_metrics_df = pd.DataFrame(new_metrics)
    graph_columns = param_names[1::2]
    graph_rows = param_names[::2]
    df_agg = new_metrics_df.pivot_table(index=graph_rows, columns=graph_columns, values='f1', aggfunc='mean')
    df_agg -= base_performance
    display(df_agg.round(decimals=3))


[SHDILP @SGD] num runs: 12


Unnamed: 0_level_0,bd,4.0
Unnamed: 0_level_1,mins,0.9
bw,cp,Unnamed: 2_level_2
4.0,0.01,0.129


In [166]:
df_group = new_metrics_df.groupby(graph_rows + graph_columns)
print(f"{num_domains * 2} == {list(df_group.size())} ?")
assert all([num_elem == num_domains * 2 for num_elem in df_group.size()])
#df_group.mean()

NameError: name 'num_domains' is not defined

In [None]:
if mode % 10 == 2:
    for hparam_label in param_names:
        new_df = new_metrics_df.groupby(hparam_label).agg('mean', numeric_only=True)['f1'].to_frame()
        new_df -= base_performance # subtract no graph performance
        if len(new_df.index) == 1 or hparam_label == 'model':
            continue
        display_df = new_df.T.round(decimals=3)
        display(display_df)

mins,0.60,0.70,0.80,0.83,0.85,0.87,0.90,0.93,0.95,0.97,1.00
f1,0.102,0.102,0.102,0.102,0.102,0.102,0.102,0.102,0.102,0.102,0.102


In [None]:
graph_columns = ['model'] + param_names[1::2]
graph_columns = param_names[1::2]
graph_rows = ['model'] + param_names[::2]
df_agg2 = new_metrics_df.pivot_table(index=graph_rows, columns=graph_columns, values='f1', aggfunc='mean')
display(df_agg2.round(decimals=3))

Unnamed: 0_level_0,Unnamed: 1_level_0,bd,4,4,4,4,4,4,4,4,4,4,4
Unnamed: 0_level_1,Unnamed: 1_level_1,mins,0.60,0.70,0.80,0.83,0.85,0.87,0.90,0.93,0.95,0.97,1.00
model,bw,cp,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
flan-t5-xxl,4,0.01,0.656,0.656,0.656,0.656,0.656,0.656,0.656,0.656,0.656,0.656,0.656
gpt-turbo,4,0.01,0.835,0.835,0.835,0.835,0.835,0.835,0.835,0.835,0.835,0.835,0.835


### 2. schema-wise performance

In [None]:
from IPython.display import display
print(rows+columns)
display_df = metrics_df.pivot(index=rows+columns, columns=['domain'], values='f1')
display_df = display_df.round(3)
num_columns = len(display_df.columns)
if num_columns > 12:
    df1 = display_df.iloc[:, :num_columns//2]
    df2 = display_df.iloc[:, num_columns//2:]
    display(df1)
    display(df2)
else:
    display(display_df)


['model', 'graph', 'postprocess']


Unnamed: 0_level_0,Unnamed: 1_level_0,domain,Bank1,Buse1,Buse2,Cale1,Even1,Even2,Flig1,Flig2,Home1,Hote1,Hote2,Hote3
model,graph,postprocess,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
flan-t5-xxl,(None),,0.565,0.543,0.582,0.378,0.461,0.464,0.217,0.685,0.813,0.613,0.565,0.551
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.6,,0.565,0.543,0.582,0.378,0.461,0.464,0.217,0.685,0.813,0.613,0.565,0.551
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.6,max,0.591,0.589,0.642,0.586,0.524,0.586,0.389,0.674,0.845,0.642,0.605,0.606
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.7,,0.565,0.543,0.582,0.378,0.461,0.464,0.217,0.685,0.813,0.613,0.565,0.551
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.7,max,0.591,0.589,0.642,0.586,0.524,0.586,0.389,0.674,0.845,0.642,0.605,0.606
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.8,,0.565,0.543,0.582,0.378,0.461,0.464,0.217,0.685,0.813,0.613,0.565,0.551
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.8,max,0.591,0.589,0.642,0.586,0.524,0.586,0.389,0.674,0.845,0.642,0.605,0.606
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.83,,0.565,0.543,0.582,0.378,0.461,0.464,0.217,0.685,0.813,0.613,0.565,0.551
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.83,max,0.591,0.589,0.642,0.586,0.524,0.586,0.389,0.674,0.845,0.642,0.605,0.606
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.85,,0.565,0.543,0.582,0.378,0.461,0.464,0.217,0.685,0.813,0.613,0.565,0.551


Unnamed: 0_level_0,Unnamed: 1_level_0,domain,Medi1,Movi1,Musi1,Musi2,Rent1,Rent2,Rest1,Ride1,Ride2,Serv1,Serv2,Serv3
model,graph,postprocess,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
flan-t5-xxl,(None),,0.778,0.681,0.819,0.692,0.361,0.45,0.473,0.635,0.561,0.692,0.734,0.704
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.6,,0.778,0.681,0.819,0.692,0.361,0.45,0.473,0.635,0.561,0.692,0.734,0.704
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.6,max,0.814,0.73,0.82,0.754,0.502,0.489,0.539,0.736,0.824,0.807,0.733,0.718
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.7,,0.778,0.681,0.819,0.692,0.361,0.45,0.473,0.635,0.561,0.692,0.734,0.704
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.7,max,0.814,0.73,0.82,0.754,0.502,0.489,0.539,0.736,0.824,0.807,0.733,0.718
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.8,,0.778,0.681,0.819,0.692,0.361,0.45,0.473,0.635,0.561,0.692,0.734,0.704
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.8,max,0.814,0.73,0.82,0.754,0.502,0.489,0.539,0.736,0.824,0.807,0.733,0.718
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.83,,0.778,0.681,0.819,0.692,0.361,0.45,0.473,0.635,0.561,0.692,0.734,0.704
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.83,max,0.814,0.73,0.82,0.754,0.502,0.489,0.539,0.736,0.824,0.807,0.733,0.718
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.85,,0.778,0.681,0.819,0.692,0.361,0.45,0.473,0.635,0.561,0.692,0.734,0.704


### 3. Prec and Rec

In [None]:
display_df = metrics_df.pivot_table(index=rows, columns=columns, values=['precision', 'recall'], aggfunc='mean')
display_df.round(3)

Unnamed: 0_level_0,Unnamed: 1_level_0,precision,precision,recall,recall
Unnamed: 0_level_1,postprocess,None,max,None,max
model,graph,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
flan-t5-xxl,(None),0.869,,0.458,
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.6,0.869,0.732,0.458,0.603
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.7,0.869,0.732,0.458,0.603
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.8,0.869,0.732,0.458,0.603
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.83,0.869,0.732,0.458,0.603
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.85,0.869,0.732,0.458,0.603
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.87,0.869,0.732,0.458,0.603
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.9,0.869,0.732,0.458,0.603
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.93,0.869,0.732,0.458,0.603
flan-t5-xxl,SHDILP_bw=4_bd=4_cp=0.01_mins=0.95,0.869,0.732,0.458,0.603


### 3. Paired t-test over all schema

In [None]:
display_df.query("graph == '(None)' and model == 'flan-t5-xxl'")

Unnamed: 0_level_0,Unnamed: 1_level_0,precision,precision,recall,recall
Unnamed: 0_level_1,postprocess,None,max,None,max
model,graph,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
flan-t5-xxl,(None),0.868589,,0.458327,


In [None]:
# ref: https://pythonfordatascienceorg.wordpress.com/paired-samples-t-test-python/
from scipy import stats
if mode == 1: # with/without graph for gpt & T5 on SGD
    model='gpt-turbo' # 'flan-t5-xxl'
    df1 = display_df.query(f"graph == '(None)' and model == '{model}'")
    df1 = df1.squeeze()
    df2 = display_df.query(f"graph == 'RILP' and model == '{model}'")
    df2 = df2.squeeze()
else:
    df1 = display_df.loc[display_df['graph']=='(None)']
    df2 = display_df.loc[display_df['graph']=='RILP']
stat, pval = stats.ttest_rel(df1, df2)
print(f'Mean before={df1.mean()} Mean after={df2.mean()}')
if df1.mean() < df2.mean():
    change_text = "increased"
else:
    change_text = "decreased"
print(f'pval = {pval}')
if pval < 0.05:
    print(f'Statistically significant {change_text} by {df2.mean() - df1.mean()}!')
else:
    print('Not significant')

KeyError: 'graph'

In [None]:
t5_models=['flan-t5-xxl', 't5-xxl-lm-adapt']
mterics_t5 = metrics_df[metrics_df["model"].isin(t5_models)]
mterics_t5_5shot = mterics_t5[mterics_t5["shot"]==5]
display_df = mterics_t5_5shot.pivot(index=['model', 'graph'], columns=['domain'], values='f1')
rounded_df = display_df.round(decimals=3)
rounded_df

ValueError: Index contains duplicate entries, cannot reshape

In [None]:
print(metrics_df.pivot(index=['model', 'graph', 'shot'], columns=['domain'], values='f1').to_latex())


ValueError: Index contains duplicate entries, cannot reshape