In [1]:
import json
import os
import csv 
import pandas as pd

from eval_v1_1_question import evaluate

In [2]:
def get_augmented_filename(input_dir, model_name, question_set, parts_of_speech=None, frequency_percentile=None):
    """Build the path of a question file for one augmentation setting.

    Only one definition of this function is kept: an earlier copy with a
    different signature was always shadowed by this one and referenced an
    undefined ``model_type`` name.

    Args:
        input_dir: Directory prefix; it is concatenated as-is, so it must end
            with a path separator.
        model_name: Name of the augmenting model. The special value ``'orig'``
            selects the un-augmented file ``<input_dir><question_set>.json``
            and ignores the two keyword arguments.
        question_set: Question-set name used as the file-name prefix.
        parts_of_speech: Sequence of tag strings, joined with ``_`` in the
            file name.
        frequency_percentile: Percentile value, rendered with ``str()`` after
            a ``Percentile_`` marker.

    Returns:
        The file path as a string.

    Raises:
        AssertionError: If ``model_name`` is not ``'orig'`` and the truthiness
            of ``parts_of_speech`` and ``frequency_percentile`` is the same
            (both given or both missing/falsy).
    """
    if model_name == 'orig':
        return input_dir+question_set+".json"

    # Raised explicitly (instead of using an ``assert`` statement) so that the
    # check is not stripped when Python runs with -O; the exception type and
    # message are unchanged.
    if bool(parts_of_speech) == bool(frequency_percentile):
        raise AssertionError("Can only pass one of parts_of_speech and frequency_percentile")

    if parts_of_speech:
        filename = input_dir+question_set+"_"+model_name+"_"+"_".join(parts_of_speech)+".json"
    else:
        filename = input_dir+question_set+"_"+model_name+"_Percentile_"+str(frequency_percentile)+".json"

    return filename


def get_prediction_filename(output_dir, model_name, question_set, qa_model, parts_of_speech=None, frequency_percentile=None):
    """Build the path of the predictions file a QA model produced for one question file.

    For ``model_name == 'orig'`` the name is
    ``<output_dir><question_set>_<qa_model>.json``. Otherwise exactly one of
    ``parts_of_speech`` / ``frequency_percentile`` must be truthy, and the name
    carries the augmenting model plus either the ``_``-joined tags or a
    ``Percentile_<value>`` marker before the QA-model suffix.

    ``output_dir`` is concatenated as-is, so it must end with a path separator.
    """
    if model_name == 'orig':
        return "".join((output_dir, question_set, "_", qa_model, ".json"))

    assert bool(parts_of_speech) ^ bool(frequency_percentile), "Can only pass one of parts_of_speech and frequency_percentile"

    # Describe the augmentation variant first, then assemble the name once.
    if parts_of_speech:
        variant = "_".join(parts_of_speech)
    elif frequency_percentile:
        variant = "Percentile_"+str(frequency_percentile)

    return "".join((output_dir, question_set, "_", model_name, "_", variant, "_", qa_model, ".json"))

def strip_filename(filepath):
    """Return the last ``/``-separated component of ``filepath`` without its final five characters.

    The five-character cut matches a ``.json`` suffix, which is what every
    path passed in by this notebook ends with.
    """
    _, _, basename = filepath.rpartition('/')
    return basename[:-5]


def get_eval_1_1(qa_json, preds_json):
    """Score a predictions file against a question file.

    Reads the ``'data'`` entry of the JSON document at ``qa_json`` and the
    whole JSON document at ``preds_json``, then hands both to ``evaluate``
    (imported from ``eval_v1_1_question``).

    Returns:
        The two values produced by ``evaluate``, in order. Callers in this
        notebook index the first with ``'exact_match'`` / ``'f1'`` and iterate
        over the second as per-question rows.
    """
    with open(qa_json) as questions_fh:
        dataset = json.load(questions_fh)['data']

    with open(preds_json) as predictions_fh:
        predictions = json.load(predictions_fh)

    overall_scores, per_question_scores = evaluate(dataset, predictions)
    return overall_scores, per_question_scores

# Directory prefixes (with trailing separator) for the augmented question
# files and for the QA-model prediction files.
augmented_dir = '/data/distribution_shift/augmented_qa/'
predictions_dir = '/data/distribution_shift/augmented_qa/predictions/'

# Question-set name -> path of the original (un-augmented) question file.
# The key is also the prefix used when building augmented / prediction file names.
# NOTE(review): two keys spell the version "v1_0" and two spell it "v1.0";
# confirm both spellings match the file names on disk.
qa_files = {
    "amazon_reviews_v1_0": '/data/distribution_shift/new_qa/amazon_reviews_v1.0.json',
    "reddit_v1_0": '/data/distribution_shift/new_qa/reddit_v1.0.json',
    "new_wiki_v1.0": '/data/distribution_shift/new_qa/new_wiki_v1.0.json',
    "nyt_v1.0": '/data/distribution_shift/new_qa/nyt_v1.0.json',
}

# Tag combinations that were used for augmentation; each list is joined with
# "_" to form part of a file name.
# NOTE(review): 'RBZ' does not look like a standard part-of-speech tag
# (possibly meant 'RBS') — left unchanged because the tag is embedded in file
# names; confirm against the augmentation step.
parts_of_speech_list = [
    ['JJ', 'VB'],
    ['JJ'],
    ['VB', 'RB'],
    ['VB'],
    ['RB'],
    ['RB', 'RBR', 'RBZ'],
    ['VB', 'VBD', 'VBG', 'VBN', 'VBP'],
    # A comma was missing between 'VBG' and 'VBN', which silently concatenated
    # them into the single tag 'VBGVBN'.
    # NOTE(review): if the augmentation step had the same typo, the files on
    # disk are named "..._VBGVBN_VBP..." and need renaming/regenerating.
    ['RB', 'RBR', 'RBZ', 'VB', 'VBD', 'VBG', 'VBN', 'VBP']
]

# Augmenting models; 'orig' stands for the un-augmented question files.
augmenting_models = [
    'orig',
    'bert',
    'roberta'
]

# Question-answering models whose prediction files are evaluated.
qa_models = [
    'bert-large-cased-whole-word-masking-finetuned-squad',
    'bert-large-uncased-whole-word-masking-finetuned-squad',
    'distilbert-base-cased-distilled-squad',
    'distilbert-base-uncased-distilled-squad'
]

# Frequency percentiles used for augmentation; rendered with str() in file names.
frequency_percentiles = [
    0.10,
    0.20,
    0.30,
    0.50
]


# Result tables are built as lists of rows; the first row is the header.
# One row per evaluated (question file, prediction file) pair.
model_results = [["question_set", 
                  "questions", 
                  "predictions", 
                  "augmentation_model", 
                  "parts_of_speech", 
                  "frequency_percentile", 
                  "augmentation_name", 
                  "question_answering_model", 
                  "exact_match", "f1"
                 ]]

# One row per question of every evaluated pair.
model_question_results = [["question_set", 
                           "questions", 
                           "predictions", 
                           "augmentation_model", 
                           "augmentation_name", 
                           "question_answering_model", 
                           "question_id", 
                           "exact_match", 
                           "f1"]
                         ]

def _record_evaluation(question_set, qa_filepath, pred_filepath, model_name,
                       parts_of_speech, frequency_percentile, augmentation_name,
                       qa_model, missing_template, announce=False):
    """Evaluate one (question file, prediction file) pair and record the scores.

    Appends one row to the module-level ``model_results`` list and one row per
    question to ``model_question_results``. When either file does not exist,
    nothing is recorded and ``missing_template`` (formatted with both paths)
    is printed instead.

    Args:
        question_set: Question-set name stored in the first column.
        qa_filepath: Path of the question file.
        pred_filepath: Path of the predictions file.
        model_name: Augmenting-model name stored in the row.
        parts_of_speech: Tag list for this setting, or None.
        frequency_percentile: Percentile for this setting, or None.
        augmentation_name: Human-readable label of the augmentation setting.
        qa_model: Question-answering model name stored in the row.
        missing_template: Message with two ``{}`` placeholders printed when a
            file is missing.
        announce: When True, print the pair of paths before evaluating.
    """
    if not (os.path.exists(qa_filepath) and os.path.exists(pred_filepath)):
        print(missing_template.format(qa_filepath, pred_filepath))
        return

    if announce:
        print("Predicting for\n QA: {}\n Pred:{}".format(qa_filepath, pred_filepath))

    model_eval, question_results = get_eval_1_1(qa_filepath, pred_filepath)

    questions_name = strip_filename(qa_filepath)
    predictions_name = strip_filename(pred_filepath)

    model_results.append([question_set,
                          questions_name,
                          predictions_name,
                          model_name,
                          parts_of_speech,
                          frequency_percentile,
                          augmentation_name,
                          qa_model,
                          model_eval['exact_match'],
                          model_eval['f1'],
                         ])

    # Prefix every per-question row with the metadata that identifies the run.
    for question_result in question_results:
        model_question_results.append([question_set,
                                       questions_name,
                                       predictions_name,
                                       model_name,
                                       augmentation_name,
                                       qa_model
                                      ] + question_result
                                     )


for model_name in augmenting_models:
    for question_set, orig_filename in qa_files.items():
        for qa_model in qa_models:

            # The un-augmented question file is evaluated once per QA model.
            if model_name == 'orig':
                _record_evaluation(question_set,
                                   orig_filename,
                                   get_prediction_filename(predictions_dir, model_name, question_set, qa_model),
                                   model_name,
                                   None,
                                   None,
                                   "Original",
                                   qa_model,
                                   "Missing one of the orig files:\n{}\n{}\n\n",
                                   announce=True)
                continue

            # Augmentations selected by part-of-speech tags.
            for parts_of_speech in parts_of_speech_list:
                _record_evaluation(question_set,
                                   get_augmented_filename(augmented_dir, model_name, question_set, parts_of_speech=parts_of_speech),
                                   get_prediction_filename(predictions_dir, model_name, question_set, qa_model, parts_of_speech=parts_of_speech),
                                   model_name,
                                   parts_of_speech,
                                   None,
                                   "PoS_"+"_".join(parts_of_speech),
                                   qa_model,
                                   "Missing one of the files:\n{}\n{}\n\n")

            # Augmentations selected by word-frequency percentile.
            for frequency_percentile in frequency_percentiles:
                _record_evaluation(question_set,
                                   get_augmented_filename(augmented_dir, model_name, question_set, frequency_percentile=frequency_percentile),
                                   get_prediction_filename(predictions_dir, model_name, question_set, qa_model, frequency_percentile=frequency_percentile),
                                   model_name,
                                   None,
                                   frequency_percentile,
                                   "Percentile_"+str(frequency_percentile),
                                   qa_model,
                                   "Missing one of the files:\n{}\n{}\n\n")




Predicting for
 QA: /data/distribution_shift/new_qa/amazon_reviews_v1.0.json
 Pred:/data/distribution_shift/augmented_qa/predictions/amazon_reviews_v1_0_bert-large-cased-whole-word-masking-finetuned-squad.json
Predicting for
 QA: /data/distribution_shift/new_qa/amazon_reviews_v1.0.json
 Pred:/data/distribution_shift/augmented_qa/predictions/amazon_reviews_v1_0_bert-large-uncased-whole-word-masking-finetuned-squad.json
Predicting for
 QA: /data/distribution_shift/new_qa/amazon_reviews_v1.0.json
 Pred:/data/distribution_shift/augmented_qa/predictions/amazon_reviews_v1_0_distilbert-base-cased-distilled-squad.json
Predicting for
 QA: /data/distribution_shift/new_qa/amazon_reviews_v1.0.json
 Pred:/data/distribution_shift/augmented_qa/predictions/amazon_reviews_v1_0_distilbert-base-uncased-distilled-squad.json
Predicting for
 QA: /data/distribution_shift/new_qa/reddit_v1.0.json
 Pred:/data/distribution_shift/augmented_qa/predictions/reddit_v1_0_bert-large-cased-whole-word-masking-finetuned-s

Unanswered question 5d706342c8e4820a9b66f05a will receive score 0.
Unanswered question 5d6fb418c8e4820a9b66a7ae will receive score 0.
Unanswered question 5d706575c8e4820a9b66f08f will receive score 0.
Unanswered question 5d70055bc8e4820a9b66a8c3 will receive score 0.
Unanswered question 5d70068bc8e4820a9b66ab9c will receive score 0.
Unanswered question 5d700715c8e4820a9b66acd4 will receive score 0.
Unanswered question 5d7004a9c8e4820a9b66a800 will receive score 0.
Unanswered question 5d700632c8e4820a9b66aac5 will receive score 0.
Unanswered question 5d700632c8e4820a9b66aac8 will receive score 0.
Unanswered question 5d7004f2c8e4820a9b66a82f will receive score 0.
Unanswered question 5d700598c8e4820a9b66a94c will receive score 0.
Unanswered question 5d7023f3c8e4820a9b66d03d will receive score 0.
Unanswered question 5d700525c8e4820a9b66a87e will receive score 0.
Unanswered question 5d7005e5c8e4820a9b66a9ff will receive score 0.
Unanswered question 5d7005e4c8e4820a9b66a9ed will receive scor

Unanswered question 5d704668c8e4820a9b66e848 will receive score 0.
Unanswered question 5d7048dbc8e4820a9b66e910 will receive score 0.
Unanswered question 5d708f46c8e4820a9b66f56d will receive score 0.
Unanswered question 5d701892c8e4820a9b66c4af will receive score 0.
Unanswered question 5d701d56c8e4820a9b66c8e6 will receive score 0.
Unanswered question 5d7025e9c8e4820a9b66d21c will receive score 0.
Unanswered question 5d7019fec8e4820a9b66c5eb will receive score 0.
Unanswered question 5d7019fec8e4820a9b66c5ec will receive score 0.
Unanswered question 5d701bb1c8e4820a9b66c727 will receive score 0.
Unanswered question 5d701af7c8e4820a9b66c67e will receive score 0.
Unanswered question 5d701af7c8e4820a9b66c67f will receive score 0.
Unanswered question 5d701be8c8e4820a9b66c76a will receive score 0.
Unanswered question 5d701ebec8e4820a9b66ca55 will receive score 0.
Unanswered question 5d701ebec8e4820a9b66ca56 will receive score 0.
Unanswered question 5d702056c8e4820a9b66cc1a will receive scor

Unanswered question 5d7078f6c8e4820a9b66f2d9 will receive score 0.
Unanswered question 5d7078f6c8e4820a9b66f2da will receive score 0.
Unanswered question 5d7079d7c8e4820a9b66f308 will receive score 0.
Unanswered question 5d707c9ec8e4820a9b66f354 will receive score 0.
Unanswered question 5d707c49c8e4820a9b66f349 will receive score 0.
Unanswered question 5d707d9dc8e4820a9b66f37b will receive score 0.
Unanswered question 5d707efdc8e4820a9b66f3ac will receive score 0.
Unanswered question 5d7082d5c8e4820a9b66f405 will receive score 0.
Unanswered question 5d7082d5c8e4820a9b66f407 will receive score 0.
Unanswered question 5d708701c8e4820a9b66f454 will receive score 0.
Unanswered question 5d708701c8e4820a9b66f455 will receive score 0.
Unanswered question 5d70896fc8e4820a9b66f4aa will receive score 0.
Unanswered question 5d708b0bc8e4820a9b66f504 will receive score 0.
Unanswered question 5d70a21dc8e4820a9b66f67c will receive score 0.
Unanswered question 5d70aa38c8e4820a9b66f6c4 will receive scor

Unanswered question 5d700d41c8e4820a9b66b8be will receive score 0.
Unanswered question 5d700b4ac8e4820a9b66b5f1 will receive score 0.
Unanswered question 5d700cc7c8e4820a9b66b821 will receive score 0.
Unanswered question 5d700a4ac8e4820a9b66b3f1 will receive score 0.
Unanswered question 5d700bb3c8e4820a9b66b693 will receive score 0.
Unanswered question 5d700a54c8e4820a9b66b40d will receive score 0.
Unanswered question 5d700c5ac8e4820a9b66b75a will receive score 0.
Unanswered question 5d700a69c8e4820a9b66b419 will receive score 0.
Unanswered question 5d700a93c8e4820a9b66b491 will receive score 0.
Unanswered question 5d700b05c8e4820a9b66b549 will receive score 0.
Unanswered question 5d700b0ec8e4820a9b66b557 will receive score 0.
Unanswered question 5d700b0ec8e4820a9b66b55b will receive score 0.
Unanswered question 5d700b4ac8e4820a9b66b5fc will receive score 0.
Unanswered question 5d700b4fc8e4820a9b66b609 will receive score 0.
Unanswered question 5d700f6bc8e4820a9b66bb7b will receive scor

Unanswered question 5d702e78c8e4820a9b66dbbd will receive score 0.
Unanswered question 5d703188c8e4820a9b66dd53 will receive score 0.
Unanswered question 5d703188c8e4820a9b66dd56 will receive score 0.
Unanswered question 5d702ed3c8e4820a9b66dc0b will receive score 0.
Unanswered question 5d703091c8e4820a9b66dcc3 will receive score 0.
Unanswered question 5d702ee3c8e4820a9b66dc17 will receive score 0.
Unanswered question 5d702f4fc8e4820a9b66dc29 will receive score 0.
Unanswered question 5d7030a4c8e4820a9b66dcd4 will receive score 0.
Unanswered question 5d7030a4c8e4820a9b66dcd5 will receive score 0.
Unanswered question 5d7030a4c8e4820a9b66dcd7 will receive score 0.
Unanswered question 5d702ff0c8e4820a9b66dc89 will receive score 0.
Unanswered question 5d70315fc8e4820a9b66dd38 will receive score 0.
Unanswered question 5d7031ddc8e4820a9b66dd84 will receive score 0.
Unanswered question 5d703059c8e4820a9b66dca3 will receive score 0.
Unanswered question 5d7031ffc8e4820a9b66dd9c will receive scor

Unanswered question 5d7078f6c8e4820a9b66f2d9 will receive score 0.
Unanswered question 5d7078f6c8e4820a9b66f2da will receive score 0.
Unanswered question 5d7079d7c8e4820a9b66f308 will receive score 0.
Unanswered question 5d707c9ec8e4820a9b66f354 will receive score 0.
Unanswered question 5d707c49c8e4820a9b66f349 will receive score 0.
Unanswered question 5d707d9dc8e4820a9b66f37b will receive score 0.
Unanswered question 5d707efdc8e4820a9b66f3ac will receive score 0.
Unanswered question 5d7082d5c8e4820a9b66f405 will receive score 0.
Unanswered question 5d7082d5c8e4820a9b66f407 will receive score 0.
Unanswered question 5d708701c8e4820a9b66f454 will receive score 0.
Unanswered question 5d708701c8e4820a9b66f455 will receive score 0.
Unanswered question 5d70896fc8e4820a9b66f4aa will receive score 0.
Unanswered question 5d708b0bc8e4820a9b66f504 will receive score 0.
Unanswered question 5d70a21dc8e4820a9b66f67c will receive score 0.
Unanswered question 5d70aa38c8e4820a9b66f6c4 will receive scor

Unanswered question 5d700d41c8e4820a9b66b8be will receive score 0.
Unanswered question 5d700b4ac8e4820a9b66b5f1 will receive score 0.
Unanswered question 5d700cc7c8e4820a9b66b821 will receive score 0.
Unanswered question 5d700a4ac8e4820a9b66b3f1 will receive score 0.
Unanswered question 5d700bb3c8e4820a9b66b693 will receive score 0.
Unanswered question 5d700a54c8e4820a9b66b40d will receive score 0.
Unanswered question 5d700c5ac8e4820a9b66b75a will receive score 0.
Unanswered question 5d700a69c8e4820a9b66b419 will receive score 0.
Unanswered question 5d700a93c8e4820a9b66b491 will receive score 0.
Unanswered question 5d700b05c8e4820a9b66b549 will receive score 0.
Unanswered question 5d700b0ec8e4820a9b66b557 will receive score 0.
Unanswered question 5d700b0ec8e4820a9b66b55b will receive score 0.
Unanswered question 5d700b4ac8e4820a9b66b5fc will receive score 0.
Unanswered question 5d700b4fc8e4820a9b66b609 will receive score 0.
Unanswered question 5d700f6bc8e4820a9b66bb7b will receive scor

Unanswered question 5d702cd9c8e4820a9b66da59 will receive score 0.
Unanswered question 5d702dd6c8e4820a9b66db35 will receive score 0.
Unanswered question 5d702dd6c8e4820a9b66db37 will receive score 0.
Unanswered question 5d702ed1c8e4820a9b66dc02 will receive score 0.
Unanswered question 5d702e78c8e4820a9b66dbbd will receive score 0.
Unanswered question 5d703188c8e4820a9b66dd53 will receive score 0.
Unanswered question 5d703188c8e4820a9b66dd56 will receive score 0.
Unanswered question 5d702ed3c8e4820a9b66dc0b will receive score 0.
Unanswered question 5d703091c8e4820a9b66dcc3 will receive score 0.
Unanswered question 5d702ee3c8e4820a9b66dc17 will receive score 0.
Unanswered question 5d702f4fc8e4820a9b66dc29 will receive score 0.
Unanswered question 5d7030a4c8e4820a9b66dcd4 will receive score 0.
Unanswered question 5d7030a4c8e4820a9b66dcd5 will receive score 0.
Unanswered question 5d7030a4c8e4820a9b66dcd7 will receive score 0.
Unanswered question 5d702ff0c8e4820a9b66dc89 will receive scor

### Write model results to CSV and parquet

In [3]:
def write_results_csv(model_results, filename='/data/distribution_shift/augmented_qa/results/model_results.csv'):
    """Write the rows of ``model_results`` to a CSV file and return its path.

    Args:
        model_results: Iterable of rows (each an iterable of cell values); in
            this notebook the first row is the header.
        filename: Destination path. Defaults to the location this notebook has
            always written to, so existing one-argument calls are unaffected.

    Returns:
        The path that was written.
    """
    print("Writting: {}".format(filename))
    # newline='' lets the csv writer control line endings itself; without it,
    # platforms that translate '\n' emit an extra '\r' per row and embedded
    # newlines inside quoted fields are altered.
    with open(filename, 'w', encoding='utf-8', newline='') as f:
        csv_writer = csv.writer(f, dialect='excel')
        csv_writer.writerows(model_results)

    return filename

# Persist the per-run results table; as the cell's last expression, the
# returned file path is what the cell displays.
write_results_csv(model_results)

Writting: /data/distribution_shift/augmented_qa/results/model_results.csv


'/data/distribution_shift/augmented_qa/results/model_results.csv'

In [4]:
filename = '/data/distribution_shift/augmented_qa/results/model_results.parquet.gzip'

# First row of model_results is the header; the remaining rows are the data.
# The identifier columns are stored as categoricals, cast in a single call.
model_results_df = pd.DataFrame(model_results[1:], columns=model_results[0]).astype({
    'question_set': 'category',
    'questions': 'category',
    'predictions': 'category',
    'augmentation_model': 'category',
    'question_answering_model': 'category',
    'augmentation_name': 'category',
})

print("Writting: {}".format(filename))
model_results_df.to_parquet(filename, compression='gzip')

# Preview of the table that was written.
model_results_df.head()

Writting: /data/distribution_shift/augmented_qa/results/model_results.parquet.gzip


Unnamed: 0,question_set,questions,predictions,augmentation_model,parts_of_speech,frequency_percentile,augmentation_name,question_answering_model,exact_match,f1
0,amazon_reviews_v1_0,amazon_reviews_v1.0,amazon_reviews_v1_0_bert-large-cased-whole-wor...,orig,,,Original,bert-large-cased-whole-word-masking-finetuned-...,61.294891,76.448281
1,amazon_reviews_v1_0,amazon_reviews_v1.0,amazon_reviews_v1_0_bert-large-uncased-whole-w...,orig,,,Original,bert-large-uncased-whole-word-masking-finetune...,61.446636,77.010119
2,amazon_reviews_v1_0,amazon_reviews_v1.0,amazon_reviews_v1_0_distilbert-base-cased-dist...,orig,,,Original,distilbert-base-cased-distilled-squad,52.149722,67.433575
3,amazon_reviews_v1_0,amazon_reviews_v1.0,amazon_reviews_v1_0_distilbert-base-uncased-di...,orig,,,Original,distilbert-base-uncased-distilled-squad,51.785534,67.140229
4,reddit_v1_0,reddit_v1.0,reddit_v1_0_bert-large-cased-whole-word-maskin...,orig,,,Original,bert-large-cased-whole-word-masking-finetuned-...,61.215954,75.00533


### Write model question results to CSV and parquet

In [5]:
def write_question_results_csv(question_results, filename='/data/distribution_shift/augmented_qa/results/question_results.csv'):
    """Write the rows of ``question_results`` to a CSV file and return its path.

    Args:
        question_results: Iterable of rows (each an iterable of cell values);
            in this notebook the first row is the header.
        filename: Destination path. Defaults to the location this notebook has
            always written to, so existing one-argument calls are unaffected.

    Returns:
        The path that was written.
    """
    print("Writting: {}".format(filename))
    # newline='' lets the csv writer control line endings itself; without it,
    # platforms that translate '\n' emit an extra '\r' per row and embedded
    # newlines inside quoted fields are altered.
    with open(filename, 'w', encoding='utf-8', newline='') as f:
        csv_writer = csv.writer(f, dialect='excel')
        csv_writer.writerows(question_results)

    return filename

# Persist the per-question results table; as the cell's last expression, the
# returned file path is what the cell displays.
write_question_results_csv(model_question_results)

Writting: /data/distribution_shift/augmented_qa/results/question_results.csv


'/data/distribution_shift/augmented_qa/results/question_results.csv'

In [6]:
filename = '/data/distribution_shift/augmented_qa/results/question_results.parquet.gzip'

# First row of model_question_results is the header; the remaining rows are
# the data. The identifier columns are stored as categoricals, cast in a
# single call.
model_question_results_df = pd.DataFrame(
    model_question_results[1:],
    columns=model_question_results[0],
).astype({
    'question_set': 'category',
    'questions': 'category',
    'predictions': 'category',
    'augmentation_model': 'category',
    'augmentation_name': 'category',
    'question_answering_model': 'category',
})


print("Writting: {}".format(filename))
model_question_results_df.to_parquet(filename, compression='gzip')

# Preview of the table that was written.
model_question_results_df.head()

Writting: /data/distribution_shift/augmented_qa/results/question_results.parquet.gzip


Unnamed: 0,question_set,questions,predictions,augmentation_model,augmentation_name,question_answering_model,question_id,exact_match,f1
0,amazon_reviews_v1_0,amazon_reviews_v1.0,amazon_reviews_v1_0_bert-large-cased-whole-wor...,orig,Original,bert-large-cased-whole-word-masking-finetuned-...,5dd465dacc027a086d65bc6c,True,1.0
1,amazon_reviews_v1_0,amazon_reviews_v1.0,amazon_reviews_v1_0_bert-large-cased-whole-wor...,orig,Original,bert-large-cased-whole-word-masking-finetuned-...,5dd465dacc027a086d65bc6d,False,0.888889
2,amazon_reviews_v1_0,amazon_reviews_v1.0,amazon_reviews_v1_0_bert-large-cased-whole-wor...,orig,Original,bert-large-cased-whole-word-masking-finetuned-...,5dd465dacc027a086d65bc6e,True,1.0
3,amazon_reviews_v1_0,amazon_reviews_v1.0,amazon_reviews_v1_0_bert-large-cased-whole-wor...,orig,Original,bert-large-cased-whole-word-masking-finetuned-...,5dd465dacc027a086d65bc6f,False,0.571429
4,amazon_reviews_v1_0,amazon_reviews_v1.0,amazon_reviews_v1_0_bert-large-cased-whole-wor...,orig,Original,bert-large-cased-whole-word-masking-finetuned-...,5dd465dacc027a086d65bc70,False,0.0
