In [1]:
# this file takes in all files and produce the appropriate d2s analysis
from pathlib import Path
from datasets import load_dataset
from evaluate import load
import pandas as pd
import numpy as np
from pathlib import Path
from functional import seq
from funcutils import underscore as _
from funcutils import get
from IPython.display import display, display_html, HTML
from editdistance import distance as edit_distance

import matplotlib.pyplot as plt

In [2]:
plt.style.use('seaborn-v0_8-whitegrid')
params = {"ytick.color" : "black",
          "xtick.color" : "black",
          "axes.labelcolor" : "black",
          "axes.edgecolor" : "black",
          "text.usetex" : True,
          "font.family" : "serif",
          "font.serif" : ["Computer Modern Serif"]}
plt.rcParams.update(params)

In [3]:
dspl_html = lambda x: display_html(x, raw=True)
rouge = load('rouge')

In [4]:
root_path = Path("/home/vente/repos/nlgs-research")

pkl = max( (root_path / "pipeline/predictions").glob("*mt*"))
pkl.name

'mt-t5-small-5.pkl'

In [5]:
OUTPUT_PATH = root_path / "pipeline/scores" / pkl.name.removesuffix(".pkl")
OUTPUT_PATH.mkdir(exist_ok=True)
OUTPUT_PATH

PosixPath('/home/vente/repos/nlgs-research/pipeline/scores/mt-t5-small-5')

 ## First, Data to sentence.

In [6]:
preds_raw = pd.read_pickle(pkl)
test_predictions = preds_raw[preds_raw.task == 'd2s']

In [7]:
compute_rouge = lambda x,y: rouge.compute(references=[x], predictions=[y], use_stemmer=False, use_aggregator=False)
compute_rouge(["general kenobi"], "general kenobi")
y_pred = test_predictions.drop(columns=['input_ids','attention_mask','pred_ids','labels'])

chunked = (
  seq(y_pred.to_dict('records'))
    .group_by(get.record_idx)
    .map(get[1]) # focus on teh values
    .map(lambda x: [
      seq(x).map(get.nl).to_list(),        # gather up all of the references
      seq(x).map(get.decoded).to_list()[0] # and the first prediction
    ])
)
chunked

0,1
"['Abilene, Texas is served by the Abilene regional airport.', 'Abilene Regional Airport serves the city of Abilene in Texas.']","d2s 1: Abilene Regional Airport|city served|Abilene, Texas"
"['Adolfo Suarez Madrid-Barajas Airport can be found in Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas.', 'Adolfo Suarez Madrid-Barajas airport is located at Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas.', 'Adolfo Suarez Madrid-Barajas Airport is located in Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas.']","d2s 1: Adolfo Suarez Madrid-Barajas Airport|location|Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas"
"['The runway name of Adolfo Suarez Madrid-Barajas Airport is 18L/36R.', 'The runway name at Adolfo Suarez Madrid-Barajas airport is 18L/36R.', 'The runway name of Adolfo Suarez Madrid-Barajas Airport is 18L/36R.']",d2s 1: Adolfo Suarez Madrid-Barajas Airport|runway name|18L/36R
"['Afonso Pena International Airport ICAO Location Idenitifier is SBCT.', 'SBCT is the ICAO location identifier of Afonso Pena International Airport.']",d2s 1: Afonso Pena International Airport|i c a o location identifier|SBCT
"['Afonso Pena International Airport serves the city of Curitiba.', 'Afonso Pena International Airport serves Curitiba.']",d2s 1: Afonso Pena International Airport|city served|Curitiba
"['The Al Taqaddum Air Base serves the city of Fallujah.', 'Al-Taqaddum Air Base serves the city of Fallujah.']",d2s 1: Al-Taqaddum Air Base|city served|Fallujah
"['The runway length of Al-Taqaddum Air Base is 3684.0.', 'The length of the runway at Al-Taqaddum Air Base is 3684.0.']",d2s 0: Al-Taqaddum Air Base|runway length|3684.0
"['Alderney Airport runway name is 14/32.', '14/32 is the runway name of Alderney Airport.', 'The runway name of Alderney Airport is 14/32.']",d2s 0: Alderney Airport|runway name|14/32
"['The runway length at Allama Iqbal International Airport is 3,360.12.', 'The runway at Allama Iqbal International Airport has a length of 3360.12.', 'The runway at Allama Iqbal International Airport is 3360.12 long.']",d2s 1: Allama Iqbal International Airport|runway length|3360.12
"[""The first runway at Amsterdam's Schiphol Airport is known as Number 18."", ""The Amsterdam Airport Schiphol's 1st runway number is 18."", 'The number of the 1st runway at Amsterdam Airport Schiphol is 18.']",d2s 0: Amsterdam Airport Schiphol|1st runway name|No 18.


In [8]:
rouge_scores = (
  chunked.starmap(compute_rouge)
    # only one rouge per sample, so take the first of the values
    .map(lambda x: seq(x.items()).starmap(lambda x,y : {x:y[0]}))
    # rolling union on dictionaries since they are disjoint
    .map(lambda x: seq(x).reduce(lambda x, y: x | y))
    .to_pandas()
)
rouge_scores

Unnamed: 0,rouge1,rouge2,rougeL,rougeLsum
0,0.666667,0.375000,0.631579,0.631579
1,0.842105,0.777778,0.842105,0.842105
2,0.782609,0.571429,0.608696,0.608696
3,0.583333,0.363636,0.545455,0.545455
4,0.666667,0.461538,0.666667,0.666667
...,...,...,...,...
1595,0.454545,0.230769,0.302521,0.302521
1596,0.666667,0.341463,0.404762,0.404762
1597,0.355030,0.168675,0.284024,0.284024
1598,0.530973,0.307692,0.452830,0.452830


In [9]:
rouge_scores.describe()

Unnamed: 0,rouge1,rouge2,rougeL,rougeLsum
count,1600.0,1600.0,1600.0,1600.0
mean,0.561394,0.30651,0.461987,0.461987
std,0.107615,0.139222,0.132266,0.132266
min,0.181818,0.0,0.142857,0.142857
25%,0.489796,0.210526,0.366667,0.366667
50%,0.555556,0.3,0.444444,0.444444
75%,0.625,0.382509,0.545455,0.545455
max,0.909091,0.9,0.909091,0.909091


In [10]:
bleu = load('sacrebleu')
bleu
compute_bleu = lambda x,y: bleu.compute(references=[x], predictions=[y])

In [11]:
bleu_scores = (
  chunked.starmap(compute_bleu)
    .to_pandas()
)

In [12]:
bleu_scores

Unnamed: 0,score,counts,totals,precisions,bp,sys_len,ref_len
0,20.780604,"[8, 4, 2, 0]","[13, 12, 11, 10]","[61.53846153846154, 33.333333333333336, 18.181...",1.0,13,11
1,64.360455,"[17, 15, 13, 11]","[23, 22, 21, 20]","[73.91304347826087, 68.18181818181819, 61.9047...",1.0,23,22
2,28.656122,"[9, 6, 3, 1]","[14, 13, 12, 11]","[64.28571428571429, 46.15384615384615, 25.0, 9...",1.0,14,13
3,18.922406,"[7, 4, 2, 1]","[16, 15, 14, 13]","[43.75, 26.666666666666668, 14.285714285714286...",1.0,16,12
4,23.462350,"[6, 3, 2, 1]","[12, 11, 10, 9]","[50.0, 27.272727272727273, 20.0, 11.1111111111...",1.0,12,10
...,...,...,...,...,...,...,...
1595,15.379561,"[33, 18, 11, 6]","[93, 92, 91, 90]","[35.483870967741936, 19.565217391304348, 12.08...",1.0,93,69
1596,14.928788,"[29, 13, 5, 2]","[54, 53, 52, 51]","[53.7037037037037, 24.528301886792452, 9.61538...",1.0,54,53
1597,8.404055,"[29, 14, 10, 7]","[156, 155, 154, 153]","[18.58974358974359, 9.03225806451613, 6.493506...",1.0,156,47
1598,18.766606,"[31, 16, 11, 7]","[76, 75, 74, 73]","[40.78947368421053, 21.333333333333332, 14.864...",1.0,76,58


In [13]:
bertscore = load('bertscore')
compute_bert = lambda x,y: bertscore.compute(predictions=[y], references=[x], lang="en", model_type="distilbert-base-uncased" )

In [14]:
bert_scores = (
 chunked
   .starmap(compute_bert)
   .to_pandas()
   .drop(columns='hashcode')
   .applymap(np.mean)
)

In [15]:
prepend_name_to_cols = lambda x,y : x.rename(columns=lambda e: y+"_"+e)
prepend_name_to_cols(bert_scores, 'bert')
all_scores = (
  seq(bert_scores, bleu_scores, rouge_scores)
    .zip(['bert','bleu','rouge'])
    .starmap(prepend_name_to_cols)
)
scores_df = pd.concat(all_scores, axis=1)
scores_df

Unnamed: 0,bert_precision,bert_recall,bert_f1,bleu_score,bleu_counts,bleu_totals,bleu_precisions,bleu_bp,bleu_sys_len,bleu_ref_len,rouge_rouge1,rouge_rouge2,rouge_rougeL,rouge_rougeLsum
0,0.785898,0.873805,0.827067,20.780604,"[8, 4, 2, 0]","[13, 12, 11, 10]","[61.53846153846154, 33.333333333333336, 18.181...",1.0,13,11,0.666667,0.375000,0.631579,0.631579
1,0.904106,0.952915,0.927869,64.360455,"[17, 15, 13, 11]","[23, 22, 21, 20]","[73.91304347826087, 68.18181818181819, 61.9047...",1.0,23,22,0.842105,0.777778,0.842105,0.842105
2,0.888998,0.928360,0.908253,28.656122,"[9, 6, 3, 1]","[14, 13, 12, 11]","[64.28571428571429, 46.15384615384615, 25.0, 9...",1.0,14,13,0.782609,0.571429,0.608696,0.608696
3,0.807687,0.892655,0.848048,18.922406,"[7, 4, 2, 1]","[16, 15, 14, 13]","[43.75, 26.666666666666668, 14.285714285714286...",1.0,16,12,0.583333,0.363636,0.545455,0.545455
4,0.792502,0.938147,0.852189,23.462350,"[6, 3, 2, 1]","[12, 11, 10, 9]","[50.0, 27.272727272727273, 20.0, 11.1111111111...",1.0,12,10,0.666667,0.461538,0.666667,0.666667
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1595,0.803384,0.782828,0.792973,15.379561,"[33, 18, 11, 6]","[93, 92, 91, 90]","[35.483870967741936, 19.565217391304348, 12.08...",1.0,93,69,0.454545,0.230769,0.302521,0.302521
1596,0.824956,0.846023,0.835357,14.928788,"[29, 13, 5, 2]","[54, 53, 52, 51]","[53.7037037037037, 24.528301886792452, 9.61538...",1.0,54,53,0.666667,0.341463,0.404762,0.404762
1597,0.803119,0.819222,0.811090,8.404055,"[29, 14, 10, 7]","[156, 155, 154, 153]","[18.58974358974359, 9.03225806451613, 6.493506...",1.0,156,47,0.355030,0.168675,0.284024,0.284024
1598,0.806948,0.834723,0.818820,18.766606,"[31, 16, 11, 7]","[76, 75, 74, 73]","[40.78947368421053, 21.333333333333332, 14.864...",1.0,76,58,0.530973,0.307692,0.452830,0.452830


In [16]:
scores_df.describe()

Unnamed: 0,bert_precision,bert_recall,bert_f1,bleu_score,bleu_bp,bleu_sys_len,bleu_ref_len,rouge_rouge1,rouge_rouge2,rouge_rougeL,rouge_rougeLsum
count,1600.0,1600.0,1600.0,1600.0,1600.0,1600.0,1600.0,1600.0,1600.0,1600.0,1600.0
mean,0.793839,0.845702,0.817315,14.318463,0.996254,31.0875,24.40875,0.561394,0.30651,0.461987,0.461987
std,0.037304,0.03671,0.034328,9.784737,0.018527,16.477053,11.815463,0.107615,0.139222,0.132266,0.132266
min,0.651801,0.701917,0.681451,1.241494,0.793923,6.0,5.0,0.181818,0.0,0.142857,0.142857
25%,0.772815,0.822722,0.797072,7.267074,1.0,17.0,14.0,0.489796,0.210526,0.366667,0.366667
50%,0.794727,0.844941,0.817328,11.71737,1.0,30.0,23.0,0.555556,0.3,0.444444,0.444444
75%,0.817753,0.86829,0.838385,18.391495,1.0,41.0,32.0,0.625,0.382509,0.545455,0.545455
max,0.924009,0.954741,0.93701,69.698246,1.0,156.0,69.0,0.909091,0.9,0.909091,0.909091


In [17]:
scores_df.to_pickle(OUTPUT_PATH / "d2s_scores.pkl")

In [18]:
test_predictions  = preds_raw[preds_raw.task == 's2d']

In [19]:
# define set notion of precision when multiple labels are assigned
# to a single instance, with epsilon preventing div by zero
def compute_f_measure(pred: set[str], gt: set[str], epsilon=1e-99):
    tp = len(pred.intersection(gt)) # pred true and actually true
    fp = len(gt - pred)             # in pred but not in gt
    fn = len(pred - gt)             # not in pred but actualy true

    prec = tp / (tp + fp + epsilon) 
    recl = tp / (tp + fn + epsilon) 
    f1 = (2 * prec * recl) / (prec + recl + epsilon)
    return f1

 ## Unit Tests

In [20]:
assert compute_f_measure(set("a"), set('a')) == 1
assert compute_f_measure(set("ab"), set('a')) == 2/3
assert compute_f_measure(set() , set('a')) == 0

In [21]:
# don't penalize for quotes or spaces
norm_split_set = lambda x: x.str.upper().str.replace("'", '').str.replace(' ','').map(_.split(";")).map(set)
y_pred = norm_split_set(test_predictions['decoded'])
y_pred

77337    {S2D0:ABILENEREGIONALAIRPORTSERVESTHECITYOFABI...
77339    {S2D0:ABILENEREGIONALAIRPORTSERVESTHECITYOFABI...
77341    {S2D0:ADOLFOSUAREZMADRID-BARAJASAIRPORTISLOCAT...
77343    {S2D0:ADOLFOSUAREZMADRID-BARAJASAIRPORTISLOCAT...
77345    {S2D0:ADOLFOSUAREZMADRID-BARAJASAIRPORTISLOCAT...
                               ...                        
85775    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
85777    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
85779    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
85781    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
85783    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
Name: decoded, Length: 4224, dtype: object

In [22]:
y_true = norm_split_set(test_predictions['sd'])
y_true

77337    {S2D0:ABILENE,TEXASISSERVEDBYTHEABILENEREGIONA...
77339    {S2D1:ABILENEREGIONALAIRPORTSERVESTHECITYOFABI...
77341    {S2D0:ADOLFOSUAREZMADRID-BARAJASAIRPORTCANBEFO...
77343    {S2D1:ADOLFOSUAREZMADRID-BARAJASAIRPORTISLOCAT...
77345    {S2D2:ADOLFOSUAREZMADRID-BARAJASAIRPORTISLOCAT...
                               ...                        
85775    {S2D1:ESTABLISHEDIN1928,THESCHOOLOFBUSINESSAND...
85777    {S2D0:DENMARKISLEDBYTHEMONARCHYOFDEMARKANDTHEC...
85779    {S2D1:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
85781    {S2D2:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
85783    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
Name: sd, Length: 4224, dtype: object

In [23]:
len("d2s 1: ")

7

In [24]:
compute_rouge = lambda x,y: rouge.compute(references=[x], predictions=[y], use_stemmer=False, use_aggregator=False)
compute_rouge(["general kenobi"], "general kenobi")
y_pred = test_predictions.drop(columns=['input_ids','attention_mask','pred_ids','labels'])

chunked = (
  seq(y_pred.to_dict('records'))
    .group_by(get.record_idx)
    .map(get[1]) # focus on teh values
    .map(lambda x: [
      seq(x).map(get.nl).to_list(),        # gather up all of the references
      seq(x).map(get.decoded).to_list()[0][7:] # and the first prediction
    ])
)
chunked

0,1
"['Abilene Regional Airport|city served|Abilene, Texas', 'Abilene Regional Airport|city served|Abilene, Texas']","Abilene Regional Airport serves the city of Abilene, Texas."
"['Adolfo Suarez Madrid-Barajas Airport|location|Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas', 'Adolfo Suarez Madrid-Barajas Airport|location|Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas', 'Adolfo Suarez Madrid-Barajas Airport|location|Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas']","Adolfo Suarez Madrid-Barajas Airport is located in Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas."
"['Adolfo Suarez Madrid-Barajas Airport|runway name|18L/36R', 'Adolfo Suarez Madrid-Barajas Airport|runway name|18L/36R', 'Adolfo Suarez Madrid-Barajas Airport|runway name|18L/36R']",The runway name of Adolfo Suarez Madrid-Barajas Airport is 18L/36R.
"['Afonso Pena International Airport|i c a o location identifier|SBCT', 'Afonso Pena International Airport|i c a o location identifier|SBCT']",The ICAO location identifier of Afonso Pena International Airport is SBCT.
"['Afonso Pena International Airport|city served|Curitiba', 'Afonso Pena International Airport|city served|Curitiba']",Afonso Pena International Airport serves the city of Curitiba.
"['Al-Taqaddum Air Base|city served|Fallujah', 'Al-Taqaddum Air Base|city served|Fallujah']",Al-Taqaddum Air Base serves the city of Fallujah.
"['Al-Taqaddum Air Base|runway length|3684.0', 'Al-Taqaddum Air Base|runway length|3684.0']",The runway length of Al-Taqaddum Air Base is 3684.0.
"['Alderney Airport|runway name|14/32', 'Alderney Airport|runway name|14/32', 'Alderney Airport|runway name|14/32']",The runway name of Alderney Airport is 14/32.
"['Allama Iqbal International Airport|runway length|3360.12', 'Allama Iqbal International Airport|runway length|3360.12', 'Allama Iqbal International Airport|runway length|3360.12']",The runway length of Allama Iqbal International Airport is 3360.12.
"['Amsterdam Airport Schiphol|1st runway number|18', 'Amsterdam Airport Schiphol|1st runway number|18', 'Amsterdam Airport Schiphol|1st runway number|18']",The 1st runway at Amsterdam Airport Schiphol is 18.


In [25]:
rouge_scores = (
  chunked.starmap(compute_rouge)
    # only one rouge per sample, so take the first of the values
    .map(lambda x: seq(x.items()).starmap(lambda x,y : {x:y[0]}))
    # rolling union on dictionaries since they are disjoint
    .map(lambda x: seq(x).reduce(lambda x, y: x | y))
    .to_pandas()
)
rouge_scores

Unnamed: 0,rouge1,rouge2,rougeL,rougeLsum
0,0.750000,0.428571,0.750000,0.750000
1,0.888889,0.823529,0.888889,0.888889
2,0.857143,0.631579,0.666667,0.666667
3,0.636364,0.400000,0.454545,0.454545
4,0.750000,0.428571,0.750000,0.750000
...,...,...,...,...
1595,0.643478,0.371681,0.417391,0.417391
1596,0.675325,0.373333,0.285714,0.285714
1597,0.496000,0.227642,0.304000,0.304000
1598,0.545455,0.277778,0.472727,0.472727


In [26]:
rouge_scores.describe()

Unnamed: 0,rouge1,rouge2,rougeL,rougeLsum
count,1600.0,1600.0,1600.0,1600.0
mean,0.589085,0.321559,0.470035,0.470035
std,0.1243,0.152282,0.153842,0.153842
min,0.227273,0.0,0.176471,0.176471
25%,0.5,0.222222,0.355556,0.355556
50%,0.580645,0.310728,0.444444,0.444444
75%,0.666667,0.4,0.571429,0.571429
max,1.0,1.0,1.0,1.0


In [27]:
bleu = load('sacrebleu')
bleu
compute_bleu = lambda x,y: bleu.compute(references=[x], predictions=[y])

In [28]:
bleu_scores = (
  chunked.starmap(compute_bleu)
    .to_pandas()
)

In [29]:
bleu_scores

Unnamed: 0,score,counts,totals,precisions,bp,sys_len,ref_len
0,24.384183,"[7, 4, 2, 0]","[11, 10, 9, 8]","[63.63636363636363, 40.0, 22.22222222222222, 6...",1.000000,11,10
1,70.982323,"[17, 15, 13, 11]","[21, 20, 19, 18]","[80.95238095238095, 75.0, 68.42105263157895, 6...",1.000000,21,20
2,31.170907,"[9, 6, 3, 1]","[13, 12, 11, 10]","[69.23076923076923, 50.0, 27.272727272727273, ...",1.000000,13,11
3,24.107473,"[7, 4, 2, 1]","[12, 11, 10, 9]","[58.333333333333336, 36.36363636363637, 20.0, ...",0.920044,12,13
4,29.071537,"[6, 3, 2, 1]","[10, 9, 8, 7]","[60.0, 33.333333333333336, 25.0, 14.2857142857...",1.000000,10,9
...,...,...,...,...,...,...,...
1595,21.849253,"[37, 21, 14, 8]","[61, 60, 59, 58]","[60.65573770491803, 35.0, 23.728813559322035, ...",0.756776,61,78
1596,12.362695,"[24, 12, 5, 2]","[42, 41, 40, 39]","[57.142857142857146, 29.26829268292683, 12.5, ...",0.683210,42,58
1597,8.453482,"[29, 14, 10, 7]","[46, 45, 44, 43]","[63.04347826086956, 31.11111111111111, 22.7272...",0.289636,46,103
1598,16.348419,"[29, 15, 11, 7]","[55, 54, 53, 52]","[52.72727272727273, 27.77777777777778, 20.7547...",0.646383,55,79


In [30]:
bertscore = load('bertscore')
compute_bert = lambda x,y: bertscore.compute(predictions=[y], references=[x], lang="en", model_type="distilbert-base-uncased" )

In [31]:
bert_scores = (
 chunked
   .starmap(compute_bert)
   .to_pandas()
   .drop(columns='hashcode')
   .applymap(np.mean)
)

In [32]:
prepend_name_to_cols = lambda x,y : x.rename(columns=lambda e: y+"_"+e)
prepend_name_to_cols(bert_scores, 'bert')
all_scores = (
  seq(bert_scores, bleu_scores, rouge_scores)
    .zip(['bert','bleu','rouge'])
    .starmap(prepend_name_to_cols)
)
scores_df = pd.concat(all_scores, axis=1)
scores_df

Unnamed: 0,bert_precision,bert_recall,bert_f1,bleu_score,bleu_counts,bleu_totals,bleu_precisions,bleu_bp,bleu_sys_len,bleu_ref_len,rouge_rouge1,rouge_rouge2,rouge_rougeL,rouge_rougeLsum
0,0.894634,0.867282,0.880746,24.384183,"[7, 4, 2, 0]","[11, 10, 9, 8]","[63.63636363636363, 40.0, 22.22222222222222, 6...",1.000000,11,10,0.750000,0.428571,0.750000,0.750000
1,0.961804,0.954445,0.958110,70.982323,"[17, 15, 13, 11]","[21, 20, 19, 18]","[80.95238095238095, 75.0, 68.42105263157895, 6...",1.000000,21,20,0.888889,0.823529,0.888889,0.888889
2,0.924742,0.933685,0.929192,31.170907,"[9, 6, 3, 1]","[13, 12, 11, 10]","[69.23076923076923, 50.0, 27.272727272727273, ...",1.000000,13,11,0.857143,0.631579,0.666667,0.666667
3,0.893098,0.842519,0.867072,24.107473,"[7, 4, 2, 1]","[12, 11, 10, 9]","[58.333333333333336, 36.36363636363637, 20.0, ...",0.920044,12,13,0.636364,0.400000,0.454545,0.454545
4,0.902310,0.882575,0.892334,29.071537,"[6, 3, 2, 1]","[10, 9, 8, 7]","[60.0, 33.333333333333336, 25.0, 14.2857142857...",1.000000,10,9,0.750000,0.428571,0.750000,0.750000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1595,0.848056,0.822139,0.834896,21.849253,"[37, 21, 14, 8]","[61, 60, 59, 58]","[60.65573770491803, 35.0, 23.728813559322035, ...",0.756776,61,78,0.643478,0.371681,0.417391,0.417391
1596,0.841947,0.821407,0.831550,12.362695,"[24, 12, 5, 2]","[42, 41, 40, 39]","[57.142857142857146, 29.26829268292683, 12.5, ...",0.683210,42,58,0.675325,0.373333,0.285714,0.285714
1597,0.827888,0.823977,0.825928,8.453482,"[29, 14, 10, 7]","[46, 45, 44, 43]","[63.04347826086956, 31.11111111111111, 22.7272...",0.289636,46,103,0.496000,0.227642,0.304000,0.304000
1598,0.820088,0.810658,0.815346,16.348419,"[29, 15, 11, 7]","[55, 54, 53, 52]","[52.72727272727273, 27.77777777777778, 20.7547...",0.646383,55,79,0.545455,0.277778,0.472727,0.472727


In [33]:
scores_df.describe()

Unnamed: 0,bert_precision,bert_recall,bert_f1,bleu_score,bleu_bp,bleu_sys_len,bleu_ref_len,rouge_rouge1,rouge_rouge2,rouge_rougeL,rouge_rougeLsum
count,1600.0,1600.0,1600.0,1600.0,1600.0,1600.0,1600.0,1600.0,1600.0,1600.0,1600.0
mean,0.84887,0.81535,0.831574,14.536473,0.767622,21.78625,28.83,0.589085,0.321559,0.470035,0.470035
std,0.038692,0.043724,0.03945,10.716554,0.206756,10.993286,16.865242,0.1243,0.152282,0.153842,0.153842
min,0.731854,0.683687,0.712789,0.625143,0.107446,5.0,5.0,0.227273,0.0,0.176471,0.176471
25%,0.824437,0.785798,0.805452,7.189701,0.630018,13.0,14.0,0.5,0.222222,0.355556,0.355556
50%,0.847829,0.815189,0.829826,11.731175,0.79189,20.0,28.0,0.580645,0.310728,0.444444,0.444444
75%,0.872457,0.844189,0.857341,18.579157,1.0,29.0,40.0,0.666667,0.4,0.571429,0.571429
max,0.967394,0.9581,0.95811,73.587363,1.0,76.0,111.0,1.0,1.0,1.0,1.0


In [34]:
compute_rouge = lambda x,y: rouge.compute(references=[x], predictions=[y], use_stemmer=False, use_aggregator=False)
compute_rouge(["general kenobi"], "general kenobi")
y_pred = test_predictions.drop(columns=['input_ids','attention_mask','pred_ids','labels'])

chunked = (
  seq(y_pred.to_dict('records'))
    .group_by(get.record_idx)
    .map(get[1]) # focus on teh values
    .map(lambda x: [
      seq(x).map(get.nl).to_list(),        # gather up all of the references
      seq(x).map(get.decoded).to_list()[0][7:] # and the first prediction
    ])
)
chunked

0,1
"['Abilene Regional Airport|city served|Abilene, Texas', 'Abilene Regional Airport|city served|Abilene, Texas']","Abilene Regional Airport serves the city of Abilene, Texas."
"['Adolfo Suarez Madrid-Barajas Airport|location|Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas', 'Adolfo Suarez Madrid-Barajas Airport|location|Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas', 'Adolfo Suarez Madrid-Barajas Airport|location|Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas']","Adolfo Suarez Madrid-Barajas Airport is located in Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas."
"['Adolfo Suarez Madrid-Barajas Airport|runway name|18L/36R', 'Adolfo Suarez Madrid-Barajas Airport|runway name|18L/36R', 'Adolfo Suarez Madrid-Barajas Airport|runway name|18L/36R']",The runway name of Adolfo Suarez Madrid-Barajas Airport is 18L/36R.
"['Afonso Pena International Airport|i c a o location identifier|SBCT', 'Afonso Pena International Airport|i c a o location identifier|SBCT']",The ICAO location identifier of Afonso Pena International Airport is SBCT.
"['Afonso Pena International Airport|city served|Curitiba', 'Afonso Pena International Airport|city served|Curitiba']",Afonso Pena International Airport serves the city of Curitiba.
"['Al-Taqaddum Air Base|city served|Fallujah', 'Al-Taqaddum Air Base|city served|Fallujah']",Al-Taqaddum Air Base serves the city of Fallujah.
"['Al-Taqaddum Air Base|runway length|3684.0', 'Al-Taqaddum Air Base|runway length|3684.0']",The runway length of Al-Taqaddum Air Base is 3684.0.
"['Alderney Airport|runway name|14/32', 'Alderney Airport|runway name|14/32', 'Alderney Airport|runway name|14/32']",The runway name of Alderney Airport is 14/32.
"['Allama Iqbal International Airport|runway length|3360.12', 'Allama Iqbal International Airport|runway length|3360.12', 'Allama Iqbal International Airport|runway length|3360.12']",The runway length of Allama Iqbal International Airport is 3360.12.
"['Amsterdam Airport Schiphol|1st runway number|18', 'Amsterdam Airport Schiphol|1st runway number|18', 'Amsterdam Airport Schiphol|1st runway number|18']",The 1st runway at Amsterdam Airport Schiphol is 18.


In [35]:
preds_raw = pd.read_pickle(pkl)
test_predictions = preds_raw[preds_raw.task == 'd2s']
test_predictions 

Unnamed: 0,record_idx,seed_number,subset,category,split_index,sd,nl,task,input_ids,attention_mask,labels,pred_ids,decoded
77336,14495,0,test,Airport,0,d2s 0: Abilene Regional Airport|city served|Ab...,"Abilene, Texas is served by the Abilene region...",d2s,"[891, 23, 14205, 6, 2514, 19, 2098, 57, 8, 891...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 3, 632, 10, 891, 23, 14205, 76...","[0, 3, 26, 357, 7, 209, 10, 891, 23, 14205, 76...",d2s 1: Abilene Regional Airport|city served|Ab...
77338,14495,1,test,Airport,0,d2s 1: Abilene Regional Airport|city served|Ab...,Abilene Regional Airport serves the city of Ab...,d2s,"[891, 23, 14205, 7676, 5735, 4657, 8, 690, 13,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 891, 23, 14205, 7676,...","[0, 3, 26, 357, 7, 209, 10, 891, 23, 14205, 76...",d2s 1: Abilene Regional Airport|city served|Ab...
77340,14496,0,test,Airport,1,d2s 0: Adolfo Suarez Madrid-Barajas Airport|lo...,Adolfo Suarez Madrid-Barajas Airport can be fo...,d2s,"[1980, 32, 40, 89, 32, 1923, 9, 2638, 12033, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 3, 632, 10, 1980, 32, 40, 89, ...","[0, 3, 26, 357, 7, 209, 10, 1980, 32, 40, 89, ...",d2s 1: Adolfo Suarez Madrid-Barajas Airport|lo...
77342,14496,1,test,Airport,1,d2s 1: Adolfo Suarez Madrid-Barajas Airport|lo...,Adolfo Suarez Madrid-Barajas airport is locate...,d2s,"[1980, 32, 40, 89, 32, 1923, 9, 2638, 12033, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 1980, 32, 40, 89, 32,...","[0, 3, 26, 357, 7, 3, 632, 10, 1980, 32, 40, 8...",d2s 0: Adolfo Suarez Madrid-Barajas Airport|lo...
77344,14496,2,test,Airport,1,d2s 2: Adolfo Suarez Madrid-Barajas Airport|lo...,Adolfo Suarez Madrid-Barajas Airport is locate...,d2s,"[1980, 32, 40, 89, 32, 1923, 9, 2638, 12033, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 204, 10, 1980, 32, 40, 89, 32,...","[0, 3, 26, 357, 7, 3, 632, 10, 1980, 32, 40, 8...",d2s 0: Adolfo Suarez Madrid-Barajas Airport|lo...
...,...,...,...,...,...,...,...,...,...,...,...,...,...
85774,16092,1,test,University,1597,d2s 1: School of Business and Social Sciences ...,"Established in 1928, the School of Business an...",d2s,"[25275, 16, 29004, 6, 8, 1121, 13, 1769, 11, 2...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 1121, 13, 1769, 11, 2...","[0, 3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769,...",d2s 0: School of Business and Social Sciences ...
85776,16093,0,test,University,1598,d2s 0: School of Business and Social Sciences ...,Denmark is led by the Monarchy of Demark and t...,d2s,"[18001, 19, 2237, 57, 8, 2963, 7064, 63, 13, 3...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769, 11...","[0, 3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769,...",d2s 0: School of Business and Social Sciences ...
85778,16093,1,test,University,1598,d2s 1: School of Business and Social Sciences ...,The School of Business and Social Sciences at ...,d2s,"[37, 1121, 13, 1769, 11, 2730, 9226, 44, 8, 71...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 1121, 13, 1769, 11, 2...","[0, 3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769,...",d2s 0: School of Business and Social Sciences ...
85780,16093,2,test,University,1598,d2s 2: School of Business and Social Sciences ...,The School of Business and Social Sciences at ...,d2s,"[37, 1121, 13, 1769, 11, 2730, 9226, 44, 8, 71...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 204, 10, 1121, 13, 1769, 11, 2...","[0, 3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769,...",d2s 0: School of Business and Social Sciences ...


In [36]:
preds_raw = pd.read_pickle(pkl)
test_predictions = preds_raw[preds_raw.task == 's2d']
test_predictions 

Unnamed: 0,record_idx,seed_number,subset,category,split_index,sd,nl,task,input_ids,attention_mask,labels,pred_ids,decoded
77337,14495,0,test,Airport,0,"s2d 0: Abilene, Texas is served by the Abilene...","Abilene Regional Airport|city served|Abilene, ...",s2d,"[891, 23, 14205, 7676, 5735, 9175, 6726, 2098,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 7, 357, 26, 3, 632, 10, 891, 23, 14205, 6,...","[0, 3, 7, 357, 26, 3, 632, 10, 891, 23, 14205,...",s2d 0: Abilene Regional Airport serves the cit...
77339,14495,1,test,Airport,0,s2d 1: Abilene Regional Airport serves the cit...,"Abilene Regional Airport|city served|Abilene, ...",s2d,"[891, 23, 14205, 7676, 5735, 9175, 6726, 2098,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 7, 357, 26, 209, 10, 891, 23, 14205, 7676,...","[0, 3, 7, 357, 26, 3, 632, 10, 891, 23, 14205,...",s2d 0: Abilene Regional Airport serves the cit...
77341,14496,0,test,Airport,1,s2d 0: Adolfo Suarez Madrid-Barajas Airport ca...,Adolfo Suarez Madrid-Barajas Airport|location|...,s2d,"[1980, 32, 40, 89, 32, 1923, 9, 2638, 12033, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 7, 357, 26, 3, 632, 10, 1980, 32, 40, 89, ...","[0, 3, 7, 357, 26, 3, 632, 10, 1980, 32, 40, 8...",s2d 0: Adolfo Suarez Madrid-Barajas Airport is...
77343,14496,1,test,Airport,1,s2d 1: Adolfo Suarez Madrid-Barajas airport is...,Adolfo Suarez Madrid-Barajas Airport|location|...,s2d,"[1980, 32, 40, 89, 32, 1923, 9, 2638, 12033, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 7, 357, 26, 209, 10, 1980, 32, 40, 89, 32,...","[0, 3, 7, 357, 26, 3, 632, 10, 1980, 32, 40, 8...",s2d 0: Adolfo Suarez Madrid-Barajas Airport is...
77345,14496,2,test,Airport,1,s2d 2: Adolfo Suarez Madrid-Barajas Airport is...,Adolfo Suarez Madrid-Barajas Airport|location|...,s2d,"[1980, 32, 40, 89, 32, 1923, 9, 2638, 12033, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 7, 357, 26, 204, 10, 1980, 32, 40, 89, 32,...","[0, 3, 7, 357, 26, 3, 632, 10, 1980, 32, 40, 8...",s2d 0: Adolfo Suarez Madrid-Barajas Airport is...
...,...,...,...,...,...,...,...,...,...,...,...,...,...
85775,16092,1,test,University,1597,"s2d 1: Established in 1928, the School of Busi...",School of Business and Social Sciences at the ...,s2d,"[1121, 13, 1769, 11, 2730, 9226, 44, 8, 71, 29...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 7, 357, 26, 209, 10, 25275, 16, 29004, 6, ...","[0, 3, 7, 357, 26, 3, 632, 10, 37, 1121, 13, 1...",s2d 0: The School of Business and Social Scien...
85777,16093,0,test,University,1598,s2d 0: Denmark is led by the Monarchy of Demar...,School of Business and Social Sciences at the ...,s2d,"[1121, 13, 1769, 11, 2730, 9226, 44, 8, 71, 29...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 7, 357, 26, 3, 632, 10, 18001, 19, 2237, 5...","[0, 3, 7, 357, 26, 3, 632, 10, 37, 1121, 13, 1...",s2d 0: The School of Business and Social Scien...
85779,16093,1,test,University,1598,s2d 1: The School of Business and Social Scien...,School of Business and Social Sciences at the ...,s2d,"[1121, 13, 1769, 11, 2730, 9226, 44, 8, 71, 29...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 7, 357, 26, 209, 10, 37, 1121, 13, 1769, 1...","[0, 3, 7, 357, 26, 3, 632, 10, 37, 1121, 13, 1...",s2d 0: The School of Business and Social Scien...
85781,16093,2,test,University,1598,s2d 2: The School of Business and Social Scien...,School of Business and Social Sciences at the ...,s2d,"[1121, 13, 1769, 11, 2730, 9226, 44, 8, 71, 29...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 7, 357, 26, 204, 10, 37, 1121, 13, 1769, 1...","[0, 3, 7, 357, 26, 3, 632, 10, 37, 1121, 13, 1...",s2d 0: The School of Business and Social Scien...


In [37]:
compute_rouge = lambda x,y: rouge.compute(references=[x], predictions=[y], use_stemmer=False, use_aggregator=False)
compute_rouge(["general kenobi"], "general kenobi")
y_pred = test_predictions.drop(columns=['input_ids','attention_mask','pred_ids','labels'])

In [38]:
chunked = (
  seq(y_pred.to_dict('records'))
    .group_by(get.record_idx)
    .map(get[1]) # focus on teh values
    .map(lambda x: [
      seq(x).map(get.sd).to_list(),        # gather up all of the references
      seq(x).map(get.decoded).to_list()[0][7:] # and the first prediction
    ])
)
chunked

0,1
"['s2d 0: Abilene, Texas is served by the Abilene regional airport.', 's2d 1: Abilene Regional Airport serves the city of Abilene in Texas.']","Abilene Regional Airport serves the city of Abilene, Texas."
"['s2d 0: Adolfo Suarez Madrid-Barajas Airport can be found in Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas.', 's2d 1: Adolfo Suarez Madrid-Barajas airport is located at Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas.', 's2d 2: Adolfo Suarez Madrid-Barajas Airport is located in Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas.']","Adolfo Suarez Madrid-Barajas Airport is located in Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas."
"['s2d 0: The runway name of Adolfo Suarez Madrid-Barajas Airport is 18L/36R.', 's2d 1: The runway name at Adolfo Suarez Madrid-Barajas airport is 18L/36R.', 's2d 2: The runway name of Adolfo Suarez Madrid-Barajas Airport is 18L/36R.']",The runway name of Adolfo Suarez Madrid-Barajas Airport is 18L/36R.
"['s2d 0: Afonso Pena International Airport ICAO Location Idenitifier is SBCT.', 's2d 1: SBCT is the ICAO location identifier of Afonso Pena International Airport.']",The ICAO location identifier of Afonso Pena International Airport is SBCT.
"['s2d 0: Afonso Pena International Airport serves the city of Curitiba.', 's2d 1: Afonso Pena International Airport serves Curitiba.']",Afonso Pena International Airport serves the city of Curitiba.
"['s2d 0: The Al Taqaddum Air Base serves the city of Fallujah.', 's2d 1: Al-Taqaddum Air Base serves the city of Fallujah.']",Al-Taqaddum Air Base serves the city of Fallujah.
"['s2d 0: The runway length of Al-Taqaddum Air Base is 3684.0.', 's2d 1: The length of the runway at Al-Taqaddum Air Base is 3684.0.']",The runway length of Al-Taqaddum Air Base is 3684.0.
"['s2d 0: Alderney Airport runway name is 14/32.', 's2d 1: 14/32 is the runway name of Alderney Airport.', 's2d 2: The runway name of Alderney Airport is 14/32.']",The runway name of Alderney Airport is 14/32.
"['s2d 0: The runway length at Allama Iqbal International Airport is 3,360.12.', 's2d 1: The runway at Allama Iqbal International Airport has a length of 3360.12.', 's2d 2: The runway at Allama Iqbal International Airport is 3360.12 long.']",The runway length of Allama Iqbal International Airport is 3360.12.
"[""s2d 0: The first runway at Amsterdam's Schiphol Airport is known as Number 18."", ""s2d 1: The Amsterdam Airport Schiphol's 1st runway number is 18."", 's2d 2: The number of the 1st runway at Amsterdam Airport Schiphol is 18.']",The 1st runway at Amsterdam Airport Schiphol is 18.


In [39]:
chunked = (
  seq(y_pred.to_dict('records'))
    .group_by(get.record_idx)
    .map(get[1]) # focus on teh values
    .map(lambda x: [
      seq(x).map(get.sd).map(get[7:]).to_list(),        # gather up all of the references
      seq(x).map(get.decoded).to_list()[0][7:] # and the first prediction
    ])
)
chunked

0,1
"['Abilene, Texas is served by the Abilene regional airport.', 'Abilene Regional Airport serves the city of Abilene in Texas.']","Abilene Regional Airport serves the city of Abilene, Texas."
"['Adolfo Suarez Madrid-Barajas Airport can be found in Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas.', 'Adolfo Suarez Madrid-Barajas airport is located at Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas.', 'Adolfo Suarez Madrid-Barajas Airport is located in Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas.']","Adolfo Suarez Madrid-Barajas Airport is located in Madrid, Paracuellos de Jarama, San Sebastian de los Reyes and Alcobendas."
"['The runway name of Adolfo Suarez Madrid-Barajas Airport is 18L/36R.', 'The runway name at Adolfo Suarez Madrid-Barajas airport is 18L/36R.', 'The runway name of Adolfo Suarez Madrid-Barajas Airport is 18L/36R.']",The runway name of Adolfo Suarez Madrid-Barajas Airport is 18L/36R.
"['Afonso Pena International Airport ICAO Location Idenitifier is SBCT.', 'SBCT is the ICAO location identifier of Afonso Pena International Airport.']",The ICAO location identifier of Afonso Pena International Airport is SBCT.
"['Afonso Pena International Airport serves the city of Curitiba.', 'Afonso Pena International Airport serves Curitiba.']",Afonso Pena International Airport serves the city of Curitiba.
"['The Al Taqaddum Air Base serves the city of Fallujah.', 'Al-Taqaddum Air Base serves the city of Fallujah.']",Al-Taqaddum Air Base serves the city of Fallujah.
"['The runway length of Al-Taqaddum Air Base is 3684.0.', 'The length of the runway at Al-Taqaddum Air Base is 3684.0.']",The runway length of Al-Taqaddum Air Base is 3684.0.
"['Alderney Airport runway name is 14/32.', '14/32 is the runway name of Alderney Airport.', 'The runway name of Alderney Airport is 14/32.']",The runway name of Alderney Airport is 14/32.
"['The runway length at Allama Iqbal International Airport is 3,360.12.', 'The runway at Allama Iqbal International Airport has a length of 3360.12.', 'The runway at Allama Iqbal International Airport is 3360.12 long.']",The runway length of Allama Iqbal International Airport is 3360.12.
"[""The first runway at Amsterdam's Schiphol Airport is known as Number 18."", ""The Amsterdam Airport Schiphol's 1st runway number is 18."", 'The number of the 1st runway at Amsterdam Airport Schiphol is 18.']",The 1st runway at Amsterdam Airport Schiphol is 18.


In [40]:
rouge_scores = (
  chunked.starmap(compute_rouge)
    # only one rouge per sample, so take the first of the values
    .map(lambda x: seq(x.items()).starmap(lambda x,y : {x:y[0]}))
    # rolling union on dictionaries since they are disjoint
    .map(lambda x: seq(x).reduce(lambda x, y: x | y))
    .to_pandas()
)
rouge_scores

Unnamed: 0,rouge1,rouge2,rougeL,rougeLsum
0,0.947368,0.823529,0.947368,0.947368
1,1.000000,1.000000,1.000000,1.000000
2,1.000000,1.000000,1.000000,1.000000
3,1.000000,0.800000,0.818182,0.818182
4,1.000000,1.000000,1.000000,1.000000
...,...,...,...,...
1595,0.857143,0.615385,0.603774,0.603774
1596,0.741573,0.413793,0.449438,0.449438
1597,0.857143,0.602410,0.690476,0.690476
1598,0.787879,0.597938,0.707071,0.707071


In [41]:
rouge_scores.describe()

Unnamed: 0,rouge1,rouge2,rougeL,rougeLsum
count,1600.0,1600.0,1600.0,1600.0
mean,0.845672,0.669599,0.746154,0.746154
std,0.112022,0.200196,0.173836,0.173836
min,0.387097,0.125,0.266667,0.266667
25%,0.77551,0.533333,0.619551,0.619551
50%,0.848155,0.652174,0.73913,0.73913
75%,0.933333,0.8,0.894737,0.894737
max,1.0,1.0,1.0,1.0


In [42]:
bleu = load('sacrebleu')
bleu
compute_bleu = lambda x,y: bleu.compute(references=[x], predictions=[y])

In [43]:
bleu_scores = (
  chunked.starmap(compute_bleu)
    .to_pandas()
)

In [44]:
bleu_scores

Unnamed: 0,score,counts,totals,precisions,bp,sys_len,ref_len
0,83.499502,"[11, 10, 7, 5]","[11, 10, 9, 8]","[100.0, 100.0, 77.77777777777777, 62.5]",1.000000,11,11
1,100.000000,"[21, 20, 19, 18]","[21, 20, 19, 18]","[100.0, 100.0, 100.0, 100.0]",1.000000,21,21
2,100.000000,"[13, 12, 11, 10]","[13, 12, 11, 10]","[100.0, 100.0, 100.0, 100.0]",1.000000,13,13
3,73.488892,"[11, 9, 7, 5]","[12, 11, 10, 9]","[91.66666666666667, 81.81818181818181, 70.0, 5...",1.000000,12,12
4,100.000000,"[10, 9, 8, 7]","[10, 9, 8, 7]","[100.0, 100.0, 100.0, 100.0]",1.000000,10,10
...,...,...,...,...,...,...,...
1595,62.766540,"[56, 47, 39, 32]","[61, 60, 59, 58]","[91.80327868852459, 78.33333333333333, 66.1016...",0.877088,61,69
1596,33.549194,"[35, 22, 14, 9]","[42, 41, 40, 39]","[83.33333333333333, 53.65853658536585, 35.0, 2...",0.769584,42,53
1597,58.521497,"[40, 29, 22, 18]","[46, 45, 44, 43]","[86.95652173913044, 64.44444444444444, 50.0, 4...",1.000000,46,46
1598,58.197179,"[46, 36, 27, 21]","[55, 54, 53, 52]","[83.63636363636364, 66.66666666666667, 50.9433...",1.000000,55,54


In [45]:
bertscore = load('bertscore')
compute_bert = lambda x,y: bertscore.compute(predictions=[y], references=[x], lang="en", model_type="distilbert-base-uncased" )

In [46]:
bert_scores = (
 chunked
   .starmap(compute_bert)
   .to_pandas()
   .drop(columns='hashcode')
   .applymap(np.mean)
)

In [47]:
prepend_name_to_cols = lambda x,y : x.rename(columns=lambda e: y+"_"+e)
prepend_name_to_cols(bert_scores, 'bert')
all_scores = (
  seq(bert_scores, bleu_scores, rouge_scores)
    .zip(['bert','bleu','rouge'])
    .starmap(prepend_name_to_cols)
)
scores_df = pd.concat(all_scores, axis=1)
scores_df

Unnamed: 0,bert_precision,bert_recall,bert_f1,bleu_score,bleu_counts,bleu_totals,bleu_precisions,bleu_bp,bleu_sys_len,bleu_ref_len,rouge_rouge1,rouge_rouge2,rouge_rougeL,rouge_rougeLsum
0,0.988835,0.977650,0.983211,83.499502,"[11, 10, 7, 5]","[11, 10, 9, 8]","[100.0, 100.0, 77.77777777777777, 62.5]",1.000000,11,11,0.947368,0.823529,0.947368,0.947368
1,1.000000,1.000000,1.000000,100.000000,"[21, 20, 19, 18]","[21, 20, 19, 18]","[100.0, 100.0, 100.0, 100.0]",1.000000,21,21,1.000000,1.000000,1.000000,1.000000
2,1.000000,1.000000,1.000000,100.000000,"[13, 12, 11, 10]","[13, 12, 11, 10]","[100.0, 100.0, 100.0, 100.0]",1.000000,13,13,1.000000,1.000000,1.000000,1.000000
3,0.967089,0.967089,0.967089,73.488892,"[11, 9, 7, 5]","[12, 11, 10, 9]","[91.66666666666667, 81.81818181818181, 70.0, 5...",1.000000,12,12,1.000000,0.800000,0.818182,0.818182
4,1.000000,1.000000,1.000000,100.000000,"[10, 9, 8, 7]","[10, 9, 8, 7]","[100.0, 100.0, 100.0, 100.0]",1.000000,10,10,1.000000,1.000000,1.000000,1.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1595,0.916515,0.929712,0.920013,62.766540,"[56, 47, 39, 32]","[61, 60, 59, 58]","[91.80327868852459, 78.33333333333333, 66.1016...",0.877088,61,69,0.857143,0.615385,0.603774,0.603774
1596,0.924948,0.906139,0.915447,33.549194,"[35, 22, 14, 9]","[42, 41, 40, 39]","[83.33333333333333, 53.65853658536585, 35.0, 2...",0.769584,42,53,0.741573,0.413793,0.449438,0.449438
1597,0.955313,0.943391,0.949315,58.521497,"[40, 29, 22, 18]","[46, 45, 44, 43]","[86.95652173913044, 64.44444444444444, 50.0, 4...",1.000000,46,46,0.857143,0.602410,0.690476,0.690476
1598,0.927428,0.943713,0.928931,58.197179,"[46, 36, 27, 21]","[55, 54, 53, 52]","[83.63636363636364, 66.66666666666667, 50.9433...",1.000000,55,54,0.787879,0.597938,0.707071,0.707071


In [48]:
scores_df.describe()

Unnamed: 0,bert_precision,bert_recall,bert_f1,bleu_score,bleu_bp,bleu_sys_len,bleu_ref_len,rouge_rouge1,rouge_rouge2,rouge_rougeL,rouge_rougeLsum
count,1600.0,1600.0,1600.0,1600.0,1600.0,1600.0,1600.0,1600.0,1600.0,1600.0,1600.0
mean,0.956374,0.949423,0.951973,62.306456,0.956103,21.78625,22.149375,0.845672,0.669599,0.746154,0.746154
std,0.030983,0.035058,0.032241,23.595339,0.082793,10.993286,10.93283,0.112022,0.200196,0.173836,0.173836
min,0.804568,0.778368,0.789856,5.34666,0.296922,5.0,5.0,0.387097,0.125,0.266667,0.266667
25%,0.935331,0.925107,0.929605,46.022013,0.939413,13.0,13.0,0.77551,0.533333,0.619551,0.619551
50%,0.95708,0.950145,0.9519,61.443645,1.0,20.0,21.0,0.848155,0.652174,0.73913,0.73913
75%,0.981166,0.977079,0.976631,78.961573,1.0,29.0,29.0,0.933333,0.8,0.894737,0.894737
max,1.0,1.0,1.0,100.0,1.0,76.0,69.0,1.0,1.0,1.0,1.0


In [49]:
scores_df.to_pickle(OUTPUT_PATH / "d2s_scores.pkl")

In [50]:
test_predictions  = preds_raw[preds_raw.task == 's2d']

In [51]:
# define set notion of precision when multiple labels are assigned
# to a single instance, with epsilon preventing div by zero
def compute_f_measure(pred: set[str], gt: set[str], epsilon=1e-99):
    tp = len(pred.intersection(gt)) # pred true and actually true
    fp = len(gt - pred)             # in pred but not in gt
    fn = len(pred - gt)             # not in pred but actualy true

    prec = tp / (tp + fp + epsilon) 
    recl = tp / (tp + fn + epsilon) 
    f1 = (2 * prec * recl) / (prec + recl + epsilon)
    return f1

 ## Unit Tests

In [52]:
assert compute_f_measure(set("a"), set('a')) == 1
assert compute_f_measure(set("ab"), set('a')) == 2/3
assert compute_f_measure(set() , set('a')) == 0

In [53]:
# don't penalize for quotes or spaces
norm_split_set = lambda x: x.str.upper().str.replace("'", '').str.replace(' ','').map(_.split(";")).map(set)
y_pred = norm_split_set(test_predictions['decoded'])
y_pred

77337    {S2D0:ABILENEREGIONALAIRPORTSERVESTHECITYOFABI...
77339    {S2D0:ABILENEREGIONALAIRPORTSERVESTHECITYOFABI...
77341    {S2D0:ADOLFOSUAREZMADRID-BARAJASAIRPORTISLOCAT...
77343    {S2D0:ADOLFOSUAREZMADRID-BARAJASAIRPORTISLOCAT...
77345    {S2D0:ADOLFOSUAREZMADRID-BARAJASAIRPORTISLOCAT...
                               ...                        
85775    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
85777    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
85779    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
85781    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
85783    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
Name: decoded, Length: 4224, dtype: object

In [54]:
# don't penalize for quotes or spaces
norm_split_set = lambda x: (
  x.str.upper()
  .str.replace("'", '')
  .str.replace(' ','')
  .str.replace("S2D\d:", "")
  .map(_.split(";")).map(set)
)
y_pred = norm_split_set(test_predictions['decoded'])
y_pred

  x.str.upper()


77337    {ABILENEREGIONALAIRPORTSERVESTHECITYOFABILENE,...
77339    {ABILENEREGIONALAIRPORTSERVESTHECITYOFABILENE,...
77341    {ADOLFOSUAREZMADRID-BARAJASAIRPORTISLOCATEDINM...
77343    {ADOLFOSUAREZMADRID-BARAJASAIRPORTISLOCATEDINM...
77345    {ADOLFOSUAREZMADRID-BARAJASAIRPORTISLOCATEDINM...
                               ...                        
85775    {THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTHEAARH...
85777    {THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTHEAARH...
85779    {THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTHEAARH...
85781    {THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTHEAARH...
85783    {THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTHEAARH...
Name: decoded, Length: 4224, dtype: object

In [55]:
# don't penalize for quotes or spaces
norm_split_set = lambda x: (
  x.str.upper()
  .str.replace("'", '')
  .str.replace(' ','')
  # .str.replace("S2D\d:", "")
  .map(_.split(";")).map(set)
)
y_pred = norm_split_set(test_predictions['decoded'])
y_pred

77337    {S2D0:ABILENEREGIONALAIRPORTSERVESTHECITYOFABI...
77339    {S2D0:ABILENEREGIONALAIRPORTSERVESTHECITYOFABI...
77341    {S2D0:ADOLFOSUAREZMADRID-BARAJASAIRPORTISLOCAT...
77343    {S2D0:ADOLFOSUAREZMADRID-BARAJASAIRPORTISLOCAT...
77345    {S2D0:ADOLFOSUAREZMADRID-BARAJASAIRPORTISLOCAT...
                               ...                        
85775    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
85777    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
85779    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
85781    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
85783    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
Name: decoded, Length: 4224, dtype: object

In [56]:
# don't penalize for quotes or spaces
norm_split_set = lambda x: (
  x.str.upper()
  .str.replace("'", '')
  .str.replace(' ','')
  # .str.replace("S2D\d:", "")
  .map(_.split(";")).map(set)
)
y_pred = norm_split_set(test_predictions.decoded)
y_pred

77337    {S2D0:ABILENEREGIONALAIRPORTSERVESTHECITYOFABI...
77339    {S2D0:ABILENEREGIONALAIRPORTSERVESTHECITYOFABI...
77341    {S2D0:ADOLFOSUAREZMADRID-BARAJASAIRPORTISLOCAT...
77343    {S2D0:ADOLFOSUAREZMADRID-BARAJASAIRPORTISLOCAT...
77345    {S2D0:ADOLFOSUAREZMADRID-BARAJASAIRPORTISLOCAT...
                               ...                        
85775    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
85777    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
85779    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
85781    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
85783    {S2D0:THESCHOOLOFBUSINESSANDSOCIALSCIENCESATTH...
Name: decoded, Length: 4224, dtype: object

In [57]:
test_predictions.decoded

77337    s2d 0: Abilene Regional Airport serves the cit...
77339    s2d 0: Abilene Regional Airport serves the cit...
77341    s2d 0: Adolfo Suarez Madrid-Barajas Airport is...
77343    s2d 0: Adolfo Suarez Madrid-Barajas Airport is...
77345    s2d 0: Adolfo Suarez Madrid-Barajas Airport is...
                               ...                        
85775    s2d 0: The School of Business and Social Scien...
85777    s2d 0: The School of Business and Social Scien...
85779    s2d 0: The School of Business and Social Scien...
85781    s2d 0: The School of Business and Social Scien...
85783    s2d 0: The School of Business and Social Scien...
Name: decoded, Length: 4224, dtype: object

In [58]:
# MISNOMER HACK - 
test_predictions  = preds_raw[preds_raw.task == 'd2s']

In [59]:
# define set notion of precision when multiple labels are assigned
# to a single instance, with epsilon preventing div by zero
def compute_f_measure(pred: set[str], gt: set[str], epsilon=1e-99):
    tp = len(pred.intersection(gt)) # pred true and actually true
    fp = len(gt - pred)             # in pred but not in gt
    fn = len(pred - gt)             # not in pred but actualy true

    prec = tp / (tp + fp + epsilon) 
    recl = tp / (tp + fn + epsilon) 
    f1 = (2 * prec * recl) / (prec + recl + epsilon)
    return f1

 ## Unit Tests

In [60]:
assert compute_f_measure(set("a"), set('a')) == 1
assert compute_f_measure(set("ab"), set('a')) == 2/3
assert compute_f_measure(set() , set('a')) == 0

In [61]:
test_predictions.decoded

77336    d2s 1: Abilene Regional Airport|city served|Ab...
77338    d2s 1: Abilene Regional Airport|city served|Ab...
77340    d2s 1: Adolfo Suarez Madrid-Barajas Airport|lo...
77342    d2s 0: Adolfo Suarez Madrid-Barajas Airport|lo...
77344    d2s 0: Adolfo Suarez Madrid-Barajas Airport|lo...
                               ...                        
85774    d2s 0: School of Business and Social Sciences ...
85776    d2s 0: School of Business and Social Sciences ...
85778    d2s 0: School of Business and Social Sciences ...
85780    d2s 0: School of Business and Social Sciences ...
85782    d2s 1: School of Business and Social Sciences ...
Name: decoded, Length: 4224, dtype: object

In [62]:
# don't penalize for quotes or spaces
norm_split_set = lambda x: (
  x.str.upper()
  .str.replace("'", '')
  .str.replace(' ','')
  .str.replace("d2s\d:", "")
  .map(_.split(";")).map(set)
)
y_pred = norm_split_set(test_predictions.decoded)
y_pred

  x.str.upper()


77336    {D2S1:ABILENEREGIONALAIRPORT|CITYSERVED|ABILEN...
77338    {D2S1:ABILENEREGIONALAIRPORT|CITYSERVED|ABILEN...
77340    {D2S1:ADOLFOSUAREZMADRID-BARAJASAIRPORT|LOCATI...
77342    {D2S0:ADOLFOSUAREZMADRID-BARAJASAIRPORT|LOCATI...
77344    {D2S0:ADOLFOSUAREZMADRID-BARAJASAIRPORT|LOCATI...
                               ...                        
85774    {SCHOOLOFBUSINESSANDSOCIALSCIENCESATTHEAARHUSU...
85776    {D2S0:SCHOOLOFBUSINESSANDSOCIALSCIENCESATTHEAA...
85778    {D2S0:SCHOOLOFBUSINESSANDSOCIALSCIENCESATTHEAA...
85780    {DENMARK|LEADERTITLE|MONARCHYOFDENMARK, D2S0:S...
85782    {SCHOOLOFBUSINESSANDSOCIALSCIENCESATTHEAARHUSU...
Name: decoded, Length: 4224, dtype: object

In [63]:
# don't penalize for quotes or spaces
norm_split_set = lambda x: (
  x.str.upper()
  .str.replace("'", '')
  .str.replace(' ','')
  .str.replace("D2S\d:", "")
  .map(_.split(";")).map(set)
)
y_pred = norm_split_set(test_predictions.decoded)
y_pred

  x.str.upper()


77336    {ABILENEREGIONALAIRPORT|CITYSERVED|ABILENE,TEXAS}
77338    {ABILENEREGIONALAIRPORT|CITYSERVED|ABILENE,TEXAS}
77340    {ADOLFOSUAREZMADRID-BARAJASAIRPORT|LOCATION|MA...
77342    {ADOLFOSUAREZMADRID-BARAJASAIRPORT|LOCATION|MA...
77344    {ADOLFOSUAREZMADRID-BARAJASAIRPORT|LOCATION|MA...
                               ...                        
85774    {SCHOOLOFBUSINESSANDSOCIALSCIENCESATTHEAARHUSU...
85776    {DENMARK|LEADERTITLE|MONARCHYOFDENMARK, SCHOOL...
85778    {SCHOOLOFBUSINESSANDSOCIALSCIENCESATTHEAARHUSU...
85780    {DENMARK|LEADERTITLE|MONARCHYOFDENMARK, SCHOOL...
85782    {SCHOOLOFBUSINESSANDSOCIALSCIENCESATTHEAARHUSU...
Name: decoded, Length: 4224, dtype: object

In [64]:
y_true = norm_split_set(test_predictions.sd)
y_true

  x.str.upper()


77336    {ABILENEREGIONALAIRPORT|CITYSERVED|ABILENE,TEXAS}
77338    {ABILENEREGIONALAIRPORT|CITYSERVED|ABILENE,TEXAS}
77340    {ADOLFOSUAREZMADRID-BARAJASAIRPORT|LOCATION|MA...
77342    {ADOLFOSUAREZMADRID-BARAJASAIRPORT|LOCATION|MA...
77344    {ADOLFOSUAREZMADRID-BARAJASAIRPORT|LOCATION|MA...
                               ...                        
85774    {SCHOOLOFBUSINESSANDSOCIALSCIENCESATTHEAARHUSU...
85776    {DENMARK|LEADERTITLE|MONARCHYOFDENMARK, SCHOOL...
85778    {DENMARK|LEADERTITLE|MONARCHYOFDENMARK, SCHOOL...
85780    {DENMARK|LEADERTITLE|MONARCHYOFDENMARK, SCHOOL...
85782    {SCHOOLOFBUSINESSANDSOCIALSCIENCESATTHEAARHUSU...
Name: sd, Length: 4224, dtype: object

In [65]:
f1_scores = seq(y_pred).zip(y_true).starmap(compute_f_measure).to_list()
f1_scores

[1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0

In [66]:
def compute_closest_edit_dists(y_pred, y_true):
    # we need an alignment of the labels by edit distance
    return (
      seq(y_pred)
        .cartesian(y_true)
        .starmap(edit_distance) 
        .sorted()
        # full penalty for missed guesses or too many guesses
        # .take(seq(y_true, y_pred).map(len).max())
        .to_list()
    )

edit_distances = (
  seq(y_pred)
    .zip(y_true)
    .starmap(compute_closest_edit_dists)
    .map(np.mean)
    .to_list()
)
edit_distances 

[0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 6.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 11.0,
 0.0,
 0.0,
 6.0,
 0.0,
 0.0,
 0.0,
 2.0,
 0.0,
 0.0,
 0.0,
 2.0,
 0.0,
 0.0,
 0.0,
 0.0,
 3.0,
 2.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 10.0,
 10.0,
 10.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 7.0,
 7.0,
 0.0,
 27.0,
 22.0,
 0.0,
 0.0,
 31.5,
 15.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 7.0,
 0.0,
 8.0,
 0.0,
 1.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 27.666666666666668,
 29.25,
 21.5,
 0.0,
 0.0,
 0.0,
 0.0,
 7.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 6.0,
 0.0,
 0.0,
 0.0,
 0.0,
 26.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 

In [67]:
results = test_predictions
results['f1_scores'] = f1_scores
results['med_scores'] = edit_distances # med mean edit distance
results 

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  results['f1_scores'] = f1_scores
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  results['med_scores'] = edit_distances # med mean edit distance


Unnamed: 0,record_idx,seed_number,subset,category,split_index,sd,nl,task,input_ids,attention_mask,labels,pred_ids,decoded,f1_scores,med_scores
77336,14495,0,test,Airport,0,d2s 0: Abilene Regional Airport|city served|Ab...,"Abilene, Texas is served by the Abilene region...",d2s,"[891, 23, 14205, 6, 2514, 19, 2098, 57, 8, 891...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 3, 632, 10, 891, 23, 14205, 76...","[0, 3, 26, 357, 7, 209, 10, 891, 23, 14205, 76...",d2s 1: Abilene Regional Airport|city served|Ab...,1.000000,0.000000
77338,14495,1,test,Airport,0,d2s 1: Abilene Regional Airport|city served|Ab...,Abilene Regional Airport serves the city of Ab...,d2s,"[891, 23, 14205, 7676, 5735, 4657, 8, 690, 13,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 891, 23, 14205, 7676,...","[0, 3, 26, 357, 7, 209, 10, 891, 23, 14205, 76...",d2s 1: Abilene Regional Airport|city served|Ab...,1.000000,0.000000
77340,14496,0,test,Airport,1,d2s 0: Adolfo Suarez Madrid-Barajas Airport|lo...,Adolfo Suarez Madrid-Barajas Airport can be fo...,d2s,"[1980, 32, 40, 89, 32, 1923, 9, 2638, 12033, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 3, 632, 10, 1980, 32, 40, 89, ...","[0, 3, 26, 357, 7, 209, 10, 1980, 32, 40, 89, ...",d2s 1: Adolfo Suarez Madrid-Barajas Airport|lo...,1.000000,0.000000
77342,14496,1,test,Airport,1,d2s 1: Adolfo Suarez Madrid-Barajas Airport|lo...,Adolfo Suarez Madrid-Barajas airport is locate...,d2s,"[1980, 32, 40, 89, 32, 1923, 9, 2638, 12033, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 1980, 32, 40, 89, 32,...","[0, 3, 26, 357, 7, 3, 632, 10, 1980, 32, 40, 8...",d2s 0: Adolfo Suarez Madrid-Barajas Airport|lo...,1.000000,0.000000
77344,14496,2,test,Airport,1,d2s 2: Adolfo Suarez Madrid-Barajas Airport|lo...,Adolfo Suarez Madrid-Barajas Airport is locate...,d2s,"[1980, 32, 40, 89, 32, 1923, 9, 2638, 12033, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 204, 10, 1980, 32, 40, 89, 32,...","[0, 3, 26, 357, 7, 3, 632, 10, 1980, 32, 40, 8...",d2s 0: Adolfo Suarez Madrid-Barajas Airport|lo...,1.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
85774,16092,1,test,University,1597,d2s 1: School of Business and Social Sciences ...,"Established in 1928, the School of Business an...",d2s,"[25275, 16, 29004, 6, 8, 1121, 13, 1769, 11, 2...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 1121, 13, 1769, 11, 2...","[0, 3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769,...",d2s 0: School of Business and Social Sciences ...,0.769231,30.238095
85776,16093,0,test,University,1598,d2s 0: School of Business and Social Sciences ...,Denmark is led by the Monarchy of Demark and t...,d2s,"[18001, 19, 2237, 57, 8, 2963, 7064, 63, 13, 3...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769, 11...","[0, 3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769,...",d2s 0: School of Business and Social Sciences ...,0.833333,41.428571
85778,16093,1,test,University,1598,d2s 1: School of Business and Social Sciences ...,The School of Business and Social Sciences at ...,d2s,"[37, 1121, 13, 1769, 11, 2730, 9226, 44, 8, 71...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 1121, 13, 1769, 11, 2...","[0, 3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769,...",d2s 0: School of Business and Social Sciences ...,0.666667,42.085714
85780,16093,2,test,University,1598,d2s 2: School of Business and Social Sciences ...,The School of Business and Social Sciences at ...,d2s,"[37, 1121, 13, 1769, 11, 2730, 9226, 44, 8, 71...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 204, 10, 1121, 13, 1769, 11, 2...","[0, 3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769,...",d2s 0: School of Business and Social Sciences ...,0.833333,42.200000


In [68]:
# let's define nth finish and "place-number" as 0 for "finishing in first place"
# give find the place-number given a score: ties should have the same place
score_to_nth_finish: dict[float, int]= (
  seq(f1_scores)
    .sorted(reverse=True) # Highest to lowest
    .zip_with_index()     # gives an over-estimate of nth-finish
    .group_by(get[0])     # so we group by the f1 scores
    .map(get[1])          # then we get the actual place of the score
    .map(get[0])          # it's sorted, so take the first to account for ties
    .to_dict()            # convert to dictionary
)
seq(score_to_nth_finish.items()).to_pandas()

Unnamed: 0,0,1
0,1.0,0
1,0.923077,2211
2,0.909091,2230
3,0.888889,2248
4,0.857143,2383
5,0.833333,2497
6,0.833333,2506
7,0.8,2515
8,0.8,2638
9,0.769231,2713


  so we can sort by this key later, and also get a broad impression
 of the distribution of errors. Later we'll plot a histogram anyway.

 ## Error analysis

In [69]:
results = test_predictions
results['f1_scores'] = f1_scores
results['med_scores'] = edit_distances # med mean edit distance
results 

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  results['f1_scores'] = f1_scores
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  results['med_scores'] = edit_distances # med mean edit distance


Unnamed: 0,record_idx,seed_number,subset,category,split_index,sd,nl,task,input_ids,attention_mask,labels,pred_ids,decoded,f1_scores,med_scores
77336,14495,0,test,Airport,0,d2s 0: Abilene Regional Airport|city served|Ab...,"Abilene, Texas is served by the Abilene region...",d2s,"[891, 23, 14205, 6, 2514, 19, 2098, 57, 8, 891...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 3, 632, 10, 891, 23, 14205, 76...","[0, 3, 26, 357, 7, 209, 10, 891, 23, 14205, 76...",d2s 1: Abilene Regional Airport|city served|Ab...,1.000000,0.000000
77338,14495,1,test,Airport,0,d2s 1: Abilene Regional Airport|city served|Ab...,Abilene Regional Airport serves the city of Ab...,d2s,"[891, 23, 14205, 7676, 5735, 4657, 8, 690, 13,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 891, 23, 14205, 7676,...","[0, 3, 26, 357, 7, 209, 10, 891, 23, 14205, 76...",d2s 1: Abilene Regional Airport|city served|Ab...,1.000000,0.000000
77340,14496,0,test,Airport,1,d2s 0: Adolfo Suarez Madrid-Barajas Airport|lo...,Adolfo Suarez Madrid-Barajas Airport can be fo...,d2s,"[1980, 32, 40, 89, 32, 1923, 9, 2638, 12033, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 3, 632, 10, 1980, 32, 40, 89, ...","[0, 3, 26, 357, 7, 209, 10, 1980, 32, 40, 89, ...",d2s 1: Adolfo Suarez Madrid-Barajas Airport|lo...,1.000000,0.000000
77342,14496,1,test,Airport,1,d2s 1: Adolfo Suarez Madrid-Barajas Airport|lo...,Adolfo Suarez Madrid-Barajas airport is locate...,d2s,"[1980, 32, 40, 89, 32, 1923, 9, 2638, 12033, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 1980, 32, 40, 89, 32,...","[0, 3, 26, 357, 7, 3, 632, 10, 1980, 32, 40, 8...",d2s 0: Adolfo Suarez Madrid-Barajas Airport|lo...,1.000000,0.000000
77344,14496,2,test,Airport,1,d2s 2: Adolfo Suarez Madrid-Barajas Airport|lo...,Adolfo Suarez Madrid-Barajas Airport is locate...,d2s,"[1980, 32, 40, 89, 32, 1923, 9, 2638, 12033, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 204, 10, 1980, 32, 40, 89, 32,...","[0, 3, 26, 357, 7, 3, 632, 10, 1980, 32, 40, 8...",d2s 0: Adolfo Suarez Madrid-Barajas Airport|lo...,1.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
85774,16092,1,test,University,1597,d2s 1: School of Business and Social Sciences ...,"Established in 1928, the School of Business an...",d2s,"[25275, 16, 29004, 6, 8, 1121, 13, 1769, 11, 2...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 1121, 13, 1769, 11, 2...","[0, 3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769,...",d2s 0: School of Business and Social Sciences ...,0.769231,30.238095
85776,16093,0,test,University,1598,d2s 0: School of Business and Social Sciences ...,Denmark is led by the Monarchy of Demark and t...,d2s,"[18001, 19, 2237, 57, 8, 2963, 7064, 63, 13, 3...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769, 11...","[0, 3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769,...",d2s 0: School of Business and Social Sciences ...,0.833333,41.428571
85778,16093,1,test,University,1598,d2s 1: School of Business and Social Sciences ...,The School of Business and Social Sciences at ...,d2s,"[37, 1121, 13, 1769, 11, 2730, 9226, 44, 8, 71...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 1121, 13, 1769, 11, 2...","[0, 3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769,...",d2s 0: School of Business and Social Sciences ...,0.666667,42.085714
85780,16093,2,test,University,1598,d2s 2: School of Business and Social Sciences ...,The School of Business and Social Sciences at ...,d2s,"[37, 1121, 13, 1769, 11, 2730, 9226, 44, 8, 71...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 204, 10, 1121, 13, 1769, 11, 2...","[0, 3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769,...",d2s 0: School of Business and Social Sciences ...,0.833333,42.200000


In [70]:
results = test_predictions
results['f1_scores'] = f1_scores
results['med_scores'] = edit_distances # med mean edit distance
results 
results.describe()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  results['f1_scores'] = f1_scores
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  results['med_scores'] = edit_distances # med mean edit distance


Unnamed: 0,record_idx,seed_number,split_index,f1_scores,med_scores
count,4224.0,4224.0,4224.0,4224.0,4224.0
mean,15306.6366,0.902462,811.6366,0.787838,16.194854
std,454.419137,0.838034,454.419137,0.297168,12.522817
min,14495.0,0.0,0.0,0.0,0.0
25%,14929.0,0.0,434.0,0.666667,9.8125
50%,15311.0,1.0,816.0,1.0,17.111111
75%,15697.25,2.0,1202.25,1.0,22.25
max,16094.0,7.0,1599.0,1.0,453.2


In [71]:
results = test_predictions
results['f1_scores'] = f1_scores
results['med_scores'] = edit_distances # med mean edit distance
results 

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  results['f1_scores'] = f1_scores
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  results['med_scores'] = edit_distances # med mean edit distance


Unnamed: 0,record_idx,seed_number,subset,category,split_index,sd,nl,task,input_ids,attention_mask,labels,pred_ids,decoded,f1_scores,med_scores
77336,14495,0,test,Airport,0,d2s 0: Abilene Regional Airport|city served|Ab...,"Abilene, Texas is served by the Abilene region...",d2s,"[891, 23, 14205, 6, 2514, 19, 2098, 57, 8, 891...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 3, 632, 10, 891, 23, 14205, 76...","[0, 3, 26, 357, 7, 209, 10, 891, 23, 14205, 76...",d2s 1: Abilene Regional Airport|city served|Ab...,1.000000,0.000000
77338,14495,1,test,Airport,0,d2s 1: Abilene Regional Airport|city served|Ab...,Abilene Regional Airport serves the city of Ab...,d2s,"[891, 23, 14205, 7676, 5735, 4657, 8, 690, 13,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 891, 23, 14205, 7676,...","[0, 3, 26, 357, 7, 209, 10, 891, 23, 14205, 76...",d2s 1: Abilene Regional Airport|city served|Ab...,1.000000,0.000000
77340,14496,0,test,Airport,1,d2s 0: Adolfo Suarez Madrid-Barajas Airport|lo...,Adolfo Suarez Madrid-Barajas Airport can be fo...,d2s,"[1980, 32, 40, 89, 32, 1923, 9, 2638, 12033, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 3, 632, 10, 1980, 32, 40, 89, ...","[0, 3, 26, 357, 7, 209, 10, 1980, 32, 40, 89, ...",d2s 1: Adolfo Suarez Madrid-Barajas Airport|lo...,1.000000,0.000000
77342,14496,1,test,Airport,1,d2s 1: Adolfo Suarez Madrid-Barajas Airport|lo...,Adolfo Suarez Madrid-Barajas airport is locate...,d2s,"[1980, 32, 40, 89, 32, 1923, 9, 2638, 12033, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 1980, 32, 40, 89, 32,...","[0, 3, 26, 357, 7, 3, 632, 10, 1980, 32, 40, 8...",d2s 0: Adolfo Suarez Madrid-Barajas Airport|lo...,1.000000,0.000000
77344,14496,2,test,Airport,1,d2s 2: Adolfo Suarez Madrid-Barajas Airport|lo...,Adolfo Suarez Madrid-Barajas Airport is locate...,d2s,"[1980, 32, 40, 89, 32, 1923, 9, 2638, 12033, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 204, 10, 1980, 32, 40, 89, 32,...","[0, 3, 26, 357, 7, 3, 632, 10, 1980, 32, 40, 8...",d2s 0: Adolfo Suarez Madrid-Barajas Airport|lo...,1.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
85774,16092,1,test,University,1597,d2s 1: School of Business and Social Sciences ...,"Established in 1928, the School of Business an...",d2s,"[25275, 16, 29004, 6, 8, 1121, 13, 1769, 11, 2...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 1121, 13, 1769, 11, 2...","[0, 3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769,...",d2s 0: School of Business and Social Sciences ...,0.769231,30.238095
85776,16093,0,test,University,1598,d2s 0: School of Business and Social Sciences ...,Denmark is led by the Monarchy of Demark and t...,d2s,"[18001, 19, 2237, 57, 8, 2963, 7064, 63, 13, 3...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769, 11...","[0, 3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769,...",d2s 0: School of Business and Social Sciences ...,0.833333,41.428571
85778,16093,1,test,University,1598,d2s 1: School of Business and Social Sciences ...,The School of Business and Social Sciences at ...,d2s,"[37, 1121, 13, 1769, 11, 2730, 9226, 44, 8, 71...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 1121, 13, 1769, 11, 2...","[0, 3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769,...",d2s 0: School of Business and Social Sciences ...,0.666667,42.085714
85780,16093,2,test,University,1598,d2s 2: School of Business and Social Sciences ...,The School of Business and Social Sciences at ...,d2s,"[37, 1121, 13, 1769, 11, 2730, 9226, 44, 8, 71...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 204, 10, 1121, 13, 1769, 11, 2...","[0, 3, 26, 357, 7, 3, 632, 10, 1121, 13, 1769,...",d2s 0: School of Business and Social Sciences ...,0.833333,42.200000


In [72]:
results.describe()

Unnamed: 0,record_idx,seed_number,split_index,f1_scores,med_scores
count,4224.0,4224.0,4224.0,4224.0,4224.0
mean,15306.6366,0.902462,811.6366,0.787838,16.194854
std,454.419137,0.838034,454.419137,0.297168,12.522817
min,14495.0,0.0,0.0,0.0,0.0
25%,14929.0,0.0,434.0,0.666667,9.8125
50%,15311.0,1.0,816.0,1.0,17.111111
75%,15697.25,2.0,1202.25,1.0,22.25
max,16094.0,7.0,1599.0,1.0,453.2
