## Analyze cross validation results

In [1]:
import collections
import json
import numpy as np
import os
import pandas as pd
import yaml
import pprint
import scipy.stats
import tqdm

pd.set_option("display.max_columns", 100)
pd.set_option("display.max_rows", 100)

In [2]:
# Read yaml files
cross_val_dir = os.path.join(os.getenv("DATA_DIR"),
                             "mica_text_coref/movie_coref/results/coreference/cross_val_Dec06-07")
files_with_empty_results = []
rows = []
for directory in tqdm.tqdm(os.listdir(cross_val_dir), unit="dir"):
    result_file = os.path.join(cross_val_dir, directory, "result.yaml")
    with open(result_file, "r") as f:
        result = yaml.load(f, Loader=yaml.FullLoader)
    file_rows = []
    for document_len, document_len_metric in result["test_metric"].items():
        for overlap_len, overlap_len_metric in document_len_metric.items():
            for strategy, strategy_metric in overlap_len_metric.items():
                row = [result["preprocess"], result["bert_lr"], result["coref_lr"], result["warmup"],
                       document_len, overlap_len, strategy, result["test_movie"]]
                for data_metric in [result["dev_metric"], strategy_metric]:
                    metric_f1s = []
                    for metric in ["muc", "bcub", "ceafe"]:
                        for score in ["precision", "recall", "f1"]:
                            row.append(data_metric["span"][metric][score])
                        metric_f1s.append(data_metric["span"][metric]["f1"])
                    row.append(np.mean(metric_f1s))
                file_rows.append(row)
    if len(file_rows) == 0:
        files_with_empty_results.append(directory)
    rows += file_rows
data = np.array(rows)

hyperparams = ["preprocess", "bert_lr", "model_lr", "warmup", "document_len", "overlap_len", "strategy"]
score_cols = []
for dataset in ["dev", "test"]:
    for metric in ["muc", "bcub", "ceafe"]:
        for score in ["p", "r", "f"]:
            score_cols.append(f"{dataset}_{metric}_{score}")
    score_cols.append(f"{dataset}_conll_f1")
cross_val_df = pd.DataFrame(rows, columns=hyperparams + ["test_movie"] + score_cols)

print(f"shape = {cross_val_df.shape}")
display(cross_val_df)
print(f"directories with empty results: {files_with_empty_results}")
print()

print("Columns:")
for column_name, dtype in zip(cross_val_df.columns[:8], cross_val_df.dtypes[:8]):
    unique_vals = cross_val_df[column_name].unique().tolist()
    print(f"\t{column_name:20s} ({dtype}) : {unique_vals}")

100%|██████████| 720/720 [00:11<00:00, 61.25dir/s]

shape = (2142, 28)





Unnamed: 0,preprocess,bert_lr,model_lr,warmup,document_len,overlap_len,strategy,test_movie,dev_muc_p,dev_muc_r,dev_muc_f,dev_bcub_p,dev_bcub_r,dev_bcub_f,dev_ceafe_p,dev_ceafe_r,dev_ceafe_f,dev_conll_f1,test_muc_p,test_muc_r,test_muc_f,test_bcub_p,test_bcub_r,test_bcub_f,test_ceafe_p,test_ceafe_r,test_ceafe_f,test_conll_f1
0,addsays,0.00005,0.0002,0.0,20480,512,avg,zootopia,89.98,90.44,90.21,72.01,78.08,74.92,61.67,45.38,52.29,72.473333,93.20,91.27,92.23,79.95,64.08,71.14,45.46,33.39,38.50,67.290000
1,addsays,0.00005,0.0002,0.0,20480,1024,avg,zootopia,89.98,90.44,90.21,72.01,78.08,74.92,61.67,45.38,52.29,72.473333,93.02,91.30,92.15,79.56,64.19,71.06,44.44,34.21,38.66,67.290000
2,addsays,0.00005,0.0002,0.0,20480,2048,avg,zootopia,89.98,90.44,90.21,72.01,78.08,74.92,61.67,45.38,52.29,72.473333,93.15,91.27,92.20,76.97,63.21,69.42,47.42,35.67,40.71,67.443333
3,regular,0.00002,0.0001,0.5,10240,512,avg,john_wick,89.98,91.18,90.58,68.54,73.87,71.11,53.64,45.04,48.96,70.216667,94.03,92.93,93.48,82.81,74.27,78.31,37.11,70.94,48.73,73.506667
4,regular,0.00002,0.0001,0.5,10240,1024,avg,john_wick,89.98,91.18,90.58,68.54,73.87,71.11,53.64,45.04,48.96,70.216667,93.88,92.81,93.34,85.61,77.09,81.13,43.20,74.96,54.81,76.426667
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2137,regular,0.00002,0.0001,-1.0,10240,1024,avg,prestige,90.50,90.74,90.62,80.18,76.47,78.28,61.26,46.81,53.07,73.990000,93.30,93.41,93.36,78.90,73.68,76.20,29.78,66.57,41.15,70.236667
2138,regular,0.00002,0.0001,-1.0,10240,2048,avg,prestige,90.50,90.74,90.62,80.18,76.47,78.28,61.26,46.81,53.07,73.990000,93.40,93.43,93.42,79.28,63.75,70.67,30.36,66.97,41.78,68.623333
2139,regular,0.00001,0.0002,-1.0,20480,512,avg,avengers_endgame,88.92,91.52,90.20,62.11,77.18,68.83,52.19,41.85,46.45,68.493333,92.26,92.91,92.58,48.44,58.08,52.82,29.97,42.21,35.05,60.150000
2140,regular,0.00001,0.0002,-1.0,20480,1024,avg,avengers_endgame,88.92,91.52,90.20,62.11,77.18,68.83,52.19,41.85,46.45,68.493333,92.29,92.83,92.56,44.18,59.34,50.65,29.64,37.58,33.14,58.783333


directories with empty results: ['Dec09_11:19:31PM_avengers_endgame', 'Dec09_09:36:06PM_avengers_endgame', 'Dec10_01:58:40AM_zootopia', 'Dec10_12:47:38AM_zootopia', 'Dec09_03:35:19PM_john_wick', 'Dec09_04:36:32PM_dead_poets_society']

Columns:
	preprocess           (object) : ['addsays', 'regular']
	bert_lr              (float64) : [5e-05, 2e-05, 1e-05]
	model_lr             (float64) : [0.0002, 0.0001]
	warmup               (float64) : [0.0, 0.5, 0.25, 1.0, -1.0]
	document_len         (int64) : [20480, 10240]
	overlap_len          (int64) : [512, 1024, 2048]
	strategy             (object) : ['avg']
	test_movie           (object) : ['zootopia', 'john_wick', 'quiet_place', 'avengers_endgame', 'dead_poets_society', 'prestige']


In [3]:
# Aggregate (macro) coreference performance scores for each movie
index_rows = []
data_rows = []
hyperparams = ["preprocess", "bert_lr", "model_lr", "warmup", "document_len", "overlap_len", "strategy"]
for hparam_values, df in cross_val_df.groupby(hyperparams):
    if len(df) == 6:
        row = df.iloc[:,len(hparam_values) + 1:].agg(["mean", "std"]).values.T.flatten().tolist()
        index_rows.append(list(hparam_values))
        data_rows.append(row)

index = pd.MultiIndex.from_product([["dev"], ["muc", "bcub", "ceafe"], ["p", "r", "f"], ["mean", "std"]],
                                   sortorder=None)
index = index.union(pd.MultiIndex.from_product([["dev"], ["conll"], ["f"], ["mean", "std"]], sortorder=None),
                    sort=False)
index = index.union(pd.MultiIndex.from_product([["test"], ["muc", "bcub", "ceafe"], ["p", "r", "f"], ["mean", "std"]],
                                               sortorder=None), sort=False)
index = index.union(pd.MultiIndex.from_product([["test"], ["conll"], ["f"], ["mean", "std"]], sortorder=None),
                    sort=False)
index.names = ["dataset", "metric", "score", "func"]

cross_val_index = pd.MultiIndex.from_frame(pd.DataFrame(index_rows, columns=hyperparams))
cross_val_agg_df = pd.DataFrame(data_rows, columns=index, index=cross_val_index)

display(cross_val_agg_df)
print(f"shape = {cross_val_agg_df.shape}")

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,dataset,dev,dev,dev,dev,dev,dev,dev,dev,dev,dev,dev,dev,dev,dev,dev,dev,dev,dev,dev,dev,test,test,test,test,test,test,test,test,test,test,test,test,test,test,test,test,test,test,test,test
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,metric,muc,muc,muc,muc,muc,muc,bcub,bcub,bcub,bcub,bcub,bcub,ceafe,ceafe,ceafe,ceafe,ceafe,ceafe,conll,conll,muc,muc,muc,muc,muc,muc,bcub,bcub,bcub,bcub,bcub,bcub,ceafe,ceafe,ceafe,ceafe,ceafe,ceafe,conll,conll
Unnamed: 0_level_2,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,score,p,p,r,r,f,f,p,p,r,r,f,f,p,p,r,r,f,f,f,f,p,p,r,r,f,f,p,p,r,r,f,f,p,p,r,r,f,f,f,f
Unnamed: 0_level_3,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,func,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std
preprocess,bert_lr,model_lr,warmup,document_len,overlap_len,strategy,Unnamed: 7_level_4,Unnamed: 8_level_4,Unnamed: 9_level_4,Unnamed: 10_level_4,Unnamed: 11_level_4,Unnamed: 12_level_4,Unnamed: 13_level_4,Unnamed: 14_level_4,Unnamed: 15_level_4,Unnamed: 16_level_4,Unnamed: 17_level_4,Unnamed: 18_level_4,Unnamed: 19_level_4,Unnamed: 20_level_4,Unnamed: 21_level_4,Unnamed: 22_level_4,Unnamed: 23_level_4,Unnamed: 24_level_4,Unnamed: 25_level_4,Unnamed: 26_level_4,Unnamed: 27_level_4,Unnamed: 28_level_4,Unnamed: 29_level_4,Unnamed: 30_level_4,Unnamed: 31_level_4,Unnamed: 32_level_4,Unnamed: 33_level_4,Unnamed: 34_level_4,Unnamed: 35_level_4,Unnamed: 36_level_4,Unnamed: 37_level_4,Unnamed: 38_level_4,Unnamed: 39_level_4,Unnamed: 40_level_4,Unnamed: 41_level_4,Unnamed: 42_level_4,Unnamed: 43_level_4,Unnamed: 44_level_4,Unnamed: 45_level_4,Unnamed: 46_level_4
addsays,0.00001,0.0001,-1.0,10240,512,avg,89.785000,0.557127,91.993333,1.023302,90.868333,0.274621,64.293333,6.674524,77.433333,3.704266,70.000000,3.324732,55.598333,2.834667,43.100000,4.053093,48.398333,2.133808,69.755556,1.774844,93.268333,1.385473,93.238333,1.841678,93.248333,1.440145,71.096667,16.875146,64.896667,13.852042,67.596667,14.923850,32.290000,7.746222,52.151667,13.828830,39.051667,7.890448,66.632222,7.393902
addsays,0.00001,0.0001,-1.0,10240,1024,avg,89.785000,0.557127,91.993333,1.023302,90.868333,0.274621,64.293333,6.674524,77.433333,3.704266,70.000000,3.324732,55.598333,2.834667,43.100000,4.053093,48.398333,2.133808,69.755556,1.774844,93.308333,1.550399,93.308333,1.999024,93.301667,1.591571,71.271667,16.310279,67.590000,9.917736,69.020000,12.756879,34.158333,7.210188,52.108333,15.111480,40.521667,8.401894,67.614444,6.978330
addsays,0.00001,0.0001,-1.0,10240,2048,avg,89.785000,0.557127,91.993333,1.023302,90.868333,0.274621,64.293333,6.674524,77.433333,3.704266,70.000000,3.324732,55.598333,2.834667,43.100000,4.053093,48.398333,2.133808,69.755556,1.774844,93.166667,1.474499,93.553333,1.933853,93.351667,1.483475,69.970000,17.678938,70.103333,8.909282,69.453333,13.597955,34.831667,9.311261,51.855000,13.845201,40.836667,9.652290,67.880556,7.137034
addsays,0.00001,0.0001,-1.0,20480,512,avg,89.785000,0.557127,91.993333,1.023302,90.868333,0.274621,64.293333,6.674524,77.433333,3.704266,70.000000,3.324732,55.598333,2.834667,43.100000,4.053093,48.398333,2.133808,69.755556,1.774844,93.038333,1.463235,93.538333,1.896506,93.280000,1.492890,68.050000,19.820868,69.590000,8.807735,67.423333,14.819808,37.358333,10.337701,48.525000,15.372656,41.035000,10.153117,67.246111,7.965339
addsays,0.00001,0.0001,-1.0,20480,1024,avg,89.785000,0.557127,91.993333,1.023302,90.868333,0.274621,64.293333,6.674524,77.433333,3.704266,70.000000,3.324732,55.598333,2.834667,43.100000,4.053093,48.398333,2.133808,69.755556,1.774844,93.141667,1.509668,93.620000,1.773415,93.375000,1.454741,69.068333,19.501076,71.495000,8.180496,68.750000,14.211186,37.490000,10.609573,48.870000,13.463684,41.316667,9.553531,67.813889,7.168448
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
regular,0.00005,0.0002,1.0,10240,1024,avg,90.156667,2.260864,86.000000,8.332483,87.750000,3.978779,74.786667,8.508092,66.830000,12.216805,69.331667,6.164188,49.611667,11.259193,43.503333,4.906998,46.033333,7.289811,67.705000,5.674917,93.441667,2.492849,90.553333,5.285318,91.861667,2.403218,72.418333,23.106016,68.273333,7.823328,68.363333,14.848357,32.958333,6.940750,50.988333,17.760436,38.581667,8.166903,66.268889,7.393538
regular,0.00005,0.0002,1.0,10240,2048,avg,90.156667,2.260864,86.000000,8.332483,87.750000,3.978779,74.786667,8.508092,66.830000,12.216805,69.331667,6.164188,49.611667,11.259193,43.503333,4.906998,46.033333,7.289811,67.705000,5.674917,93.408333,2.788415,90.825000,5.199549,91.985000,2.466307,71.576667,22.834657,67.430000,10.146944,67.253333,14.760268,34.698333,7.526764,52.160000,19.364947,40.061667,9.822563,66.433333,7.997954
regular,0.00005,0.0002,1.0,20480,512,avg,90.156667,2.260864,86.000000,8.332483,87.750000,3.978779,74.786667,8.508092,66.830000,12.216805,69.331667,6.164188,49.611667,11.259193,43.503333,4.906998,46.033333,7.289811,67.705000,5.674917,93.380000,2.489024,90.918333,5.075689,92.025000,2.278629,67.916667,23.623895,67.723333,11.698854,64.990000,14.348695,34.288333,7.280353,49.958333,19.338394,39.238333,9.912183,65.417778,7.892054
regular,0.00005,0.0002,1.0,20480,1024,avg,90.156667,2.260864,86.000000,8.332483,87.750000,3.978779,74.786667,8.508092,66.830000,12.216805,69.331667,6.164188,49.611667,11.259193,43.503333,4.906998,46.033333,7.289811,67.705000,5.674917,93.445000,2.581757,90.983333,5.170376,92.088333,2.378583,70.243333,23.377408,69.586667,12.034138,67.210000,15.035839,34.856667,7.547279,51.458333,19.426548,40.153333,9.835214,66.483889,8.345340


shape = (342, 40)


In [4]:
# Find hyperparameter configs for which results are empty
hyperparams_ext = hyperparams + ["test_movie"]
hyperparams_ext_listoftups = []
for hyperparam in hyperparams_ext:
    hparam_vals = sorted(cross_val_df[hyperparam].unique().tolist())
    hyperparams_ext_listoftups.append(hparam_vals)
allpermute_hyperparam_df = pd.MultiIndex.from_product(hyperparams_ext_listoftups).to_frame(index=False)
allpermute_hyperparam_df.columns = hyperparams_ext
allpermute_hyperparam_df = allpermute_hyperparam_df.sort_values(hyperparams_ext)

hyperparam_df = cross_val_df.loc[:, hyperparams_ext].copy()
indicator = allpermute_hyperparam_df.merge(hyperparam_df, how="left", indicator=True)["_merge"]
rem_df = allpermute_hyperparam_df[indicator == "left_only"]
display(rem_df)

Unnamed: 0,preprocess,bert_lr,model_lr,warmup,document_len,overlap_len,strategy,test_movie
794,addsays,5e-05,0.0001,0.25,10240,512,avg,john_wick
800,addsays,5e-05,0.0001,0.25,10240,1024,avg,john_wick
806,addsays,5e-05,0.0001,0.25,10240,2048,avg,john_wick
865,addsays,5e-05,0.0001,1.0,10240,512,avg,dead_poets_society
871,addsays,5e-05,0.0001,1.0,10240,1024,avg,dead_poets_society
877,addsays,5e-05,0.0001,1.0,10240,2048,avg,dead_poets_society
936,addsays,5e-05,0.0002,0.0,10240,512,avg,avengers_endgame
942,addsays,5e-05,0.0002,0.0,10240,1024,avg,avengers_endgame
948,addsays,5e-05,0.0002,0.0,10240,2048,avg,avengers_endgame
954,addsays,5e-05,0.0002,0.0,20480,512,avg,avengers_endgame


In [6]:
# maximum test set performance
cross_val_agg_df.loc[cross_val_agg_df["test"]["conll"]["f"]["mean"]
                     == cross_val_agg_df["test"]["conll"]["f"]["mean"].max(), ("test", slice(None), "f")]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,dataset,test,test,test,test,test,test,test,test
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,metric,muc,muc,bcub,bcub,ceafe,ceafe,conll,conll
Unnamed: 0_level_2,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,score,f,f,f,f,f,f,f,f
Unnamed: 0_level_3,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,func,mean,std,mean,std,mean,std,mean,std
preprocess,bert_lr,model_lr,warmup,document_len,overlap_len,strategy,Unnamed: 7_level_4,Unnamed: 8_level_4,Unnamed: 9_level_4,Unnamed: 10_level_4,Unnamed: 11_level_4,Unnamed: 12_level_4,Unnamed: 13_level_4,Unnamed: 14_level_4
regular,2e-05,0.0001,-1.0,20480,2048,avg,93.621667,1.223461,73.2,11.976626,46.708333,8.389394,71.176667,6.752167


In [64]:
# maximum dev set performance
cross_val_agg_df.loc[cross_val_agg_df["dev"]["conll"]["f"]["mean"]
                     == cross_val_agg_df["dev"]["conll"]["f"]["mean"].max(), ("dev", slice(None), "f")]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,dataset,dev,dev,dev,dev,dev,dev,dev,dev
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,metric,muc,muc,bcub,bcub,ceafe,ceafe,conll,conll
Unnamed: 0_level_2,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,score,f,f,f,f,f,f,f,f
Unnamed: 0_level_3,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,func,mean,std,mean,std,mean,std,mean,std
preprocess,bert_lr,model_lr,warmup,document_len,overlap_len,strategy,Unnamed: 7_level_4,Unnamed: 8_level_4,Unnamed: 9_level_4,Unnamed: 10_level_4,Unnamed: 11_level_4,Unnamed: 12_level_4,Unnamed: 13_level_4,Unnamed: 14_level_4
addsays,5e-05,0.0001,0.0,10240,512,avg,89.706667,0.855048,73.816667,2.36768,51.17,2.92189,71.564444,1.834156
addsays,5e-05,0.0001,0.0,10240,1024,avg,89.706667,0.855048,73.816667,2.36768,51.17,2.92189,71.564444,1.834156
addsays,5e-05,0.0001,0.0,10240,2048,avg,89.706667,0.855048,73.816667,2.36768,51.17,2.92189,71.564444,1.834156
addsays,5e-05,0.0001,0.0,20480,512,avg,89.706667,0.855048,73.816667,2.36768,51.17,2.92189,71.564444,1.834156
addsays,5e-05,0.0001,0.0,20480,1024,avg,89.706667,0.855048,73.816667,2.36768,51.17,2.92189,71.564444,1.834156
addsays,5e-05,0.0001,0.0,20480,2048,avg,89.706667,0.855048,73.816667,2.36768,51.17,2.92189,71.564444,1.834156


In [65]:
# test set performance at maximum dev set performance
cross_val_agg_df.loc[cross_val_agg_df["dev"]["conll"]["f"]["mean"]
                     == cross_val_agg_df["dev"]["conll"]["f"]["mean"].max(), ("test", slice(None), "f")]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,dataset,test,test,test,test,test,test,test,test
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,metric,muc,muc,bcub,bcub,ceafe,ceafe,conll,conll
Unnamed: 0_level_2,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,score,f,f,f,f,f,f,f,f
Unnamed: 0_level_3,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,func,mean,std,mean,std,mean,std,mean,std
preprocess,bert_lr,model_lr,warmup,document_len,overlap_len,strategy,Unnamed: 7_level_4,Unnamed: 8_level_4,Unnamed: 9_level_4,Unnamed: 10_level_4,Unnamed: 11_level_4,Unnamed: 12_level_4,Unnamed: 13_level_4,Unnamed: 14_level_4
addsays,5e-05,0.0001,0.0,10240,512,avg,92.963333,1.362317,66.251667,10.889822,37.995,6.344379,65.736667,5.331055
addsays,5e-05,0.0001,0.0,10240,1024,avg,92.96,1.39446,69.698333,9.989906,38.75,5.29655,67.136111,4.84675
addsays,5e-05,0.0001,0.0,10240,2048,avg,93.003333,1.340308,70.468333,10.238925,40.215,4.868416,67.895556,5.23436
addsays,5e-05,0.0001,0.0,20480,512,avg,92.951667,1.310197,67.726667,11.718882,39.868333,6.538197,66.848889,6.113898
addsays,5e-05,0.0001,0.0,20480,1024,avg,92.99,1.352938,69.096667,11.437835,40.113333,6.746637,67.4,5.735403
addsays,5e-05,0.0001,0.0,20480,2048,avg,93.105,1.239689,69.465,11.016511,40.406667,6.431459,67.658889,5.522639


In [62]:
# top 50 test set performances
display(cross_val_agg_df.sort_values(("test", "conll", "f", "mean"), ascending=False)
                        .loc[:,(slice(None), "conll", "f", "mean")]
                        .iloc[:50])

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,dataset,dev,test
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,metric,conll,conll
Unnamed: 0_level_2,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,score,f,f
Unnamed: 0_level_3,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,func,mean,mean
preprocess,bert_lr,model_lr,warmup,document_len,overlap_len,strategy,Unnamed: 7_level_4,Unnamed: 8_level_4
regular,2e-05,0.0001,-1.0,20480,2048,avg,69.984444,71.176667
regular,2e-05,0.0001,0.25,20480,2048,avg,70.026667,70.183333
regular,2e-05,0.0001,-1.0,20480,1024,avg,69.984444,70.107222
regular,2e-05,0.0001,-1.0,10240,2048,avg,69.984444,70.081667
regular,5e-05,0.0001,0.25,20480,2048,avg,71.479444,70.076111
regular,1e-05,0.0002,-1.0,20480,2048,avg,69.022778,70.067222
regular,2e-05,0.0002,0.5,10240,2048,avg,71.279444,70.011111
regular,1e-05,0.0002,0.0,20480,2048,avg,71.243889,70.000556
regular,2e-05,0.0002,0.5,20480,2048,avg,71.279444,69.909444
regular,2e-05,0.0001,0.25,10240,2048,avg,70.026667,69.879444


In [68]:
cross_val_df.groupby("test_movie").agg({"test_conll_f1": "max"}).mean()

test_conll_f1    73.558333
dtype: float64