In [1]:
# This notebook runs the NND experiment from
# Section 5.1 of the paper: "Extrapolating Model Performance".
# The 7 models that were part of the Quiz Design paper,
# as well as 3 additional models, are compared on
# Quiz Design NND, extrapolating that more recent or larger models
# would lead to further gains on the Quiz Design task.

In [2]:
import utils_misc
# Select a GPU before the remaining imports execute; the recorded output of
# this cell shows "Will use GPU: 0".
# NOTE(review): presumably this call must precede the utils_nnd imports so the
# device choice is in place before any model library initialises -- confirm in
# utils_misc before reordering these lines.
freer_gpu = utils_misc.select_freer_gpu()

from utils_nnd_dataset import load_qd_nnd
from utils_nnd import GeneratorHF
import pandas as pd
import utils_nnd

Will use GPU: 0


In [3]:
# Question generators compared in this experiment, one dict per model:
#   model_card   -- model identifier used when loading the data and the generator
#   starter_file -- checkpoint path for the locally fine-tuned models, None otherwise
#   model_name   -- short label used to tag the model's results
#   params       -- optional extra keyword arguments for the generator
qgens = [
    # The seven models from the Quiz Design paper (per the notebook header).
    dict(model_card="distilgpt2", starter_file="qgen/dgpt2_squad_aaware_1.794.bin", model_name="dgpt2_sup"),
    dict(model_card="gpt2", starter_file="qgen/gpt2b_squad_aaware_1.575.bin", model_name="gpt2b_sup"),
    dict(model_card="gpt2-medium", starter_file="qgen/gpt2m_nf_squad_aaware_1.392.bin", model_name="gpt2m_sup"),
    dict(model_card="facebook/bart-base", starter_file="qgen/bartb_nf_squad_aaware_1.492.bin", model_name="bartb_sup"),
    dict(model_card="facebook/bart-large", starter_file="qgen/bartL_nf_squad_aaware_1.290.bin", model_name="bartl_sup"),
    dict(model_card="microsoft/prophetnet-large-uncased-squad-qg", starter_file=None, model_name="prophetnet"),
    dict(model_card="Salesforce/mixqg-large", starter_file=None, model_name="mixqg"),

    # Additional, more recent or larger models.
    dict(model_card="Salesforce/mixqg-3b", starter_file=None, model_name="mixqg-3b"),
    dict(
        model_card="allenai/macaw-3b",
        starter_file=None,
        model_name="macaw-3b",
        params=dict(force_dec_prepend="$question$ = "),
    ),
    # The 11B models are better run on CPU to avoid core dumps. They are very
    # slow (~10 hours), so feel free to comment them out.
    # dict(model_card="allenai/macaw-11b", starter_file=None, model_name="macaw-11b", params=dict(force_dec_prepend="$question$ = ", device="cpu")),
    dict(
        model_card="allenai/macaw-answer-11b",
        starter_file=None,
        model_name="macaw-answer-11b",
        params=dict(force_dec_prepend="$question$ = ", device="cpu"),
    ),
]

In [6]:
# Run the NND evaluation once per entry in `qgens` and tabulate the reports.
#
# For each generator config the loop loads the Quiz Design NND test set for
# that model card, builds the generator, runs `utils_nnd.run_nnd` in
# "accuracy" report mode, and collects the returned report. The reports are
# finally shown as one DataFrame row per model.
#
# In the recorded output, 'accuracy' is on a 0-100 scale while the per-label
# 'A_*' entries are 0-1 fractions.

# Root folder handed to load_qd_nnd; change it here to point at another data location.
DATA_FOLDER = "/export/home/data/"

results = []
for gen in qgens:
    # Release the previous iteration's generator (or one left over from an
    # earlier run of this cell) before building the next. Otherwise the old
    # model stays referenced until the new one is fully constructed, so two
    # models are alive at once.
    model = None

    qge_nnd_test = load_qd_nnd(datafolder=DATA_FOLDER, model_card=gen["model_card"])
    # Entries without a "params" key contribute no extra keyword arguments.
    model = GeneratorHF(model_card=gen["model_card"], starter_file=gen["starter_file"], **gen.get("params", {}))
    result = utils_nnd.run_nnd(qge_nnd_test, model, gen["model_name"], no_error_label="No error", report_type="accuracy")
    results.append(result)
    # Print each report as soon as it is ready, so partial results are visible
    # while later models are still running.
    print(result)
pd.DataFrame(results)

NND dgpt2_sup:   0%|          | 3/2686 [00:00<01:47, 24.98it/s]

<All keys matched successfully>


NND dgpt2_sup: 100%|██████████| 2686/2686 [00:52<00:00, 50.90it/s]


{'model_name': 'dgpt2_sup', 'accuracy': 44.899478778853315, 'A_disfluent': 0.5274261603375527, 'A_No error': 0.44899478778853313, 'A_wrong_context': 0.4599078341013825, 'A_off_target': 0.37303370786516854}


NND gpt2b_sup:   0%|          | 1/2686 [00:00<04:32,  9.86it/s]

<All keys matched successfully>


NND gpt2b_sup: 100%|██████████| 2686/2686 [01:39<00:00, 27.07it/s]


{'model_name': 'gpt2b_sup', 'accuracy': 52.34549516008935, 'A_No error': 0.5234549516008935, 'A_wrong_context': 0.4930875576036866, 'A_disfluent': 0.6033755274261603, 'A_off_target': 0.49662921348314604}


NND gpt2m_sup:   0%|          | 0/2686 [00:00<?, ?it/s]

<All keys matched successfully>


NND gpt2m_sup: 100%|██████████| 2686/2686 [03:18<00:00, 13.51it/s]


{'model_name': 'gpt2m_sup', 'accuracy': 60.79672375279226, 'A_No error': 0.6079672375279226, 'A_wrong_context': 0.5612903225806452, 'A_disfluent': 0.6329113924050633, 'A_off_target': 0.6449438202247191}


NND bartb_sup:   0%|          | 3/2686 [00:00<01:52, 23.84it/s]

<All keys matched successfully>


NND bartb_sup: 100%|██████████| 2686/2686 [01:02<00:00, 42.71it/s]


{'model_name': 'bartb_sup', 'accuracy': 59.60536113179449, 'A_disfluent': 0.6047819971870605, 'A_No error': 0.5960536113179449, 'A_wrong_context': 0.5502304147465438, 'A_off_target': 0.6449438202247191}


NND bartl_sup:   0%|          | 1/2686 [00:00<05:27,  8.20it/s]

<All keys matched successfully>


NND bartl_sup: 100%|██████████| 2686/2686 [02:03<00:00, 21.75it/s]


{'model_name': 'bartl_sup', 'accuracy': 64.22189128816083, 'A_No error': 0.6422189128816084, 'A_wrong_context': 0.5944700460829493, 'A_disfluent': 0.6329113924050633, 'A_off_target': 0.7078651685393258}


NND prophetnet: 100%|██████████| 2686/2686 [03:11<00:00, 14.04it/s]


{'model_name': 'prophetnet', 'accuracy': 67.6842889054356, 'A_disfluent': 0.580872011251758, 'A_No error': 0.676842889054356, 'A_wrong_context': 0.6405529953917051, 'A_off_target': 0.797752808988764}


NND mixqg: 100%|██████████| 2686/2686 [04:40<00:00,  9.58it/s]


{'model_name': 'mixqg', 'accuracy': 70.88607594936708, 'A_disfluent': 0.6694796061884669, 'A_No error': 0.7088607594936709, 'A_wrong_context': 0.6525345622119816, 'A_off_target': 0.8089887640449438}


NND mixqg-3b: 100%|██████████| 2686/2686 [05:31<00:00,  8.11it/s]


{'model_name': 'mixqg-3b', 'accuracy': 72.85927029039463, 'A_disfluent': 0.6947960618846695, 'A_No error': 0.7285927029039464, 'A_wrong_context': 0.6783410138248849, 'A_off_target': 0.8168539325842696}


NND macaw-3b: 100%|██████████| 2686/2686 [05:10<00:00,  8.64it/s]


{'model_name': 'macaw-3b', 'accuracy': 69.1734921816828, 'A_No error': 0.691734921816828, 'A_wrong_context': 0.6506912442396313, 'A_disfluent': 0.7032348804500703, 'A_off_target': 0.7325842696629213}


NND macaw-answer-11b: 100%|██████████| 2686/2686 [7:02:31<00:00,  9.44s/it]   

{'model_name': 'macaw-answer-11b', 'accuracy': 70.62546537602384, 'A_No error': 0.7062546537602383, 'A_wrong_context': 0.6543778801843319, 'A_disfluent': 0.6933895921237694, 'A_off_target': 0.7797752808988764}





Unnamed: 0,model_name,accuracy,A_disfluent,A_No error,A_wrong_context,A_off_target
0,dgpt2_sup,44.899479,0.527426,0.448995,0.459908,0.373034
1,gpt2b_sup,52.345495,0.603376,0.523455,0.493088,0.496629
2,gpt2m_sup,60.796724,0.632911,0.607967,0.56129,0.644944
3,bartb_sup,59.605361,0.604782,0.596054,0.55023,0.644944
4,bartl_sup,64.221891,0.632911,0.642219,0.59447,0.707865
5,prophetnet,67.684289,0.580872,0.676843,0.640553,0.797753
6,mixqg,70.886076,0.66948,0.708861,0.652535,0.808989
7,mixqg-3b,72.85927,0.694796,0.728593,0.678341,0.816854
8,macaw-3b,69.173492,0.703235,0.691735,0.650691,0.732584
9,macaw-answer-11b,70.625465,0.69339,0.706255,0.654378,0.779775
