In [1]:
import itertools
import json
import requests
import re
import os
import time
import numpy as np
import string
import pickle
from tqdm import tqdm
from itertools import compress

In [2]:
# Settings (do not change): fixed by how the passages and the FiD
# discriminator predictions were generated.

n_context = 5  # number of retrieved passages per question
trainx_prob = "75"  # selects the 'pert_lbls_75' labels of the one-shot demonstration
# 1-based passage ids {1, ..., n_context}; derived so it cannot drift from n_context.
all_passage_set = set(range(1, n_context + 1))

# Instruction placed between the passages and the question (passed as
# `inst_at_end` to `prompt_from_nqopen`) when the prompt is NOT
# perturbation-aware. The leading newline and trailing blank line are part of
# the prompt layout; the exact text must not be edited.
instruction_no_pert_aware = \
"""
Refer to the above passages and your knowledge, and answer the following question. 

"""

# Perturbation-aware variant: additionally asks the model to identify and
# ignore perturbed passages. Same layout conventions as above.
instruction_pert_aware = \
"""
Refer to the above passages and your knowledge, and answer the following question. Some passages may have been perturbed with wrong information. If there are passages that have been perturbed, find the perturbed passages and ignore them when eliciting the correct answer. If there is no perturbed passage, skip the process of finding the perturbed passage and derive the answer directly.

"""

def pert_inst_generator(pert_lbls, all_passages=None):
    """Build the "perturbed: ..." line naming perturbed and usable passages.

    Args:
        pert_lbls: per-passage labels in passage order; a label equal to 1
            marks the passage at that position (1-based id = index + 1) as
            perturbed.
        all_passages: set of 1-based passage ids available to the model.
            Defaults to the notebook-level ``all_passage_set``.

    Returns:
        A single line such as
        ``"perturbed: Passage 1, 3 are perturbed. Deriving the answer based on Passage 2, 4, 5."``
        Passage ids are listed in ascending order.
    """
    if all_passages is None:
        all_passages = all_passage_set
    perturbed = {pid + 1 for pid, lbl in enumerate(pert_lbls) if lbl == 1}

    pert_text = "perturbed: "
    if not perturbed:
        pert_text += "No passages are perturbed. "
    else:
        verb = " is" if len(perturbed) == 1 else " are"
        # sorted(): set iteration order is not guaranteed, but the prompt must be deterministic.
        pert_text += "Passage " + ", ".join(str(p) for p in sorted(perturbed)) + verb + " perturbed. "

    remaining = set(all_passages) - perturbed
    if not remaining:
        pert_text += "Deriving the answer with ignoring all the passages."
    else:
        pert_text += "Deriving the answer based on Passage " + ", ".join(str(p) for p in sorted(remaining)) + "."
    return pert_text

def prompt_from_nqopen(nq_data_ins, pert_lbls_str, is_pert_prompt=False, pert_fid=None, post_fix=None, inst_at_end = None):
    """Render one NQ-open instance as a block of prompt text.

    Layout:
        Passage 1: <text>
        ...
        <inst_at_end, if given>question: <question>?
        then either a solved demonstration tail or an open query tail.

    Args:
        nq_data_ins: instance dict with 'ctxs' (each having 'text' and
            'text_pert'), 'question', 'answers' and a 'pert_lbls_<pert_lbls_str>'
            label list.
        pert_lbls_str: suffix selecting which perturbation label list to use;
            label 0 keeps the original passage text (with its first
            len("title: ") characters sliced off), any other label uses
            'text_pert'.
        is_pert_prompt: demonstration mode only; add the gold "perturbed:" line.
        pert_fid: query mode only; labels rendered as the "perturbed:" line
            (e.g. discriminator predictions). None means no such line.
        post_fix: None for a demonstration (gold answer appended); otherwise
            the string the query ends with (e.g. "answer:").
        inst_at_end: optional instruction text inserted before the question.

    Returns:
        The assembled prompt segment as a string.
    """
    labels = nq_data_ins['pert_lbls' + '_' + pert_lbls_str]
    title_prefix = "title: "

    passage_lines = []
    for number, ctx in enumerate(nq_data_ins['ctxs'], start=1):
        body = ctx['text'][len(title_prefix):] if labels[number - 1] == 0 else ctx['text_pert']
        passage_lines.append(("Passage " + str(number) + ": " + body).strip() + "\n")
    segment = "".join(passage_lines)

    if inst_at_end is not None:
        segment += inst_at_end
    segment += "question: " + nq_data_ins['question'].strip() + "?\n"

    if post_fix is not None:
        # Query: optionally reveal externally-predicted perturbation labels,
        # then leave the completion cue for the model.
        if pert_fid is not None:
            segment += pert_inst_generator(pert_fid) + "\n"
        return segment + post_fix

    # Demonstration: optionally show gold perturbation labels, then the answer.
    if is_pert_prompt:
        segment += pert_inst_generator(labels) + "\n"
    return segment + "answer: " + nq_data_ins['answers'][0] + "\n\n"

def api_GPT_oneshot_pinst(prompt, api_key=None, model="text-davinci-003", max_tokens=256, timeout=120):
    """Send one greedy (temperature 0) request to OpenAI's /v1/completions endpoint.

    Args:
        prompt: full prompt text (one-shot demonstration + query).
        api_key: OpenAI API key. If None or empty, the OPENAI_API_KEY
            environment variable is used instead.
        model: completion model name.
        max_tokens: maximum number of tokens to generate.
        timeout: seconds to wait for the HTTP response before
            ``requests`` raises; prevents the notebook from hanging forever.

    Returns:
        The raw ``requests.Response``; callers parse ``.json()`` themselves.

    Raises:
        ValueError: if no API key is available.
    """
    if not api_key:
        api_key = os.environ.get("OPENAI_API_KEY", "")
    if not api_key:
        raise ValueError("No OpenAI API key: pass api_key or set the OPENAI_API_KEY environment variable.")

    headers = {
        'Content-Type': 'application/json',
        'Authorization': 'Bearer ' + api_key,
    }
    json_data = {
        'model': model,
        'prompt': prompt,
        'max_tokens': max_tokens,
        'temperature': 0.0,  # greedy decoding for reproducible outputs
        'top_p': 1,
        'frequency_penalty': 0.0,
        'presence_penalty': 0,
        'logprobs': 10,
    }
    return requests.post('https://api.openai.com/v1/completions', headers=headers, json=json_data, timeout=timeout)

def normalize_answer(s):
    """Canonicalise an answer string for exact-match comparison (SQuAD style).

    Steps, in order: lower-case, drop ASCII punctuation, replace the
    whole-word articles "a"/"an"/"the" with spaces, collapse whitespace.
    """
    punctuation = set(string.punctuation)
    text = s.lower()
    text = "".join(char for char in text if char not in punctuation)
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return " ".join(text.split())

def exact_match_score(prediction, ground_truth):
    """Return True if the two strings are equal after `normalize_answer`."""
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth

def ems(prediction, ground_truths):
    """Exact match against any of several acceptable answers.

    Returns True if `prediction` matches at least one entry of
    `ground_truths` (per `exact_match_score`), else False. An empty
    `ground_truths` raises ValueError (from `max`).
    """
    return max(exact_match_score(prediction, truth) for truth in ground_truths)

def _strip_answer_prefix(text):
    """Return the text after the first 'answer:' (else 'Answer:') marker, stripped."""
    text = text.strip()
    for marker in ("answer:", "Answer:"):
        pos = text.find(marker)
        if pos != -1:
            # Skip exactly the marker; a hard-coded +8 would eat the first
            # character of the answer when no space follows the colon.
            return text[pos + len(marker):].strip()
    return text


def em_calc(predictions, n_sample = -1, answers=None):
    """Mean exact-match score of raw completion responses against gold answers.

    Args:
        predictions: completion API response dicts; the generated text is
            read from ``pred['choices'][0]['text']``.
        n_sample: number of leading predictions to score. A negative value
            (the default) or None scores all of them. The old ``[:-1]`` slice
            silently dropped the last prediction.
        answers: sequence whose i-th element is the list of gold answer
            strings for prediction i. Defaults to the global
            ``sample_answers`` for backward compatibility.

    Returns:
        The mean of the per-prediction ``ems`` scores (a numpy float).
    """
    if answers is None:
        answers = sample_answers  # legacy global, not defined in this notebook section
    extracted = [_strip_answer_prefix(pred['choices'][0]['text']) for pred in predictions]
    if n_sample is not None and n_sample >= 0:
        extracted = extracted[:n_sample]
    return np.mean([ems(pred, answers[i]) for i, pred in enumerate(extracted)])

In [3]:
# Config: experiment switches (edit these).

# OpenAI API key, read from the environment so it is never saved in the notebook.
GPT3_api_key = os.environ.get("OPENAI_API_KEY", "")
is_dev = False  # True: dev-256 split, False: test split

use_parametric_only = False  # True: question-only prompts; the two switches below are then ignored
use_pert_aware_instruction = True  # True: perturbation-aware instruction, False: plain instruction
use_discriminator_fid = True  # True: inject the FiD discriminator's perturbation predictions into prompts, False: let GPT-3 generate them

pert_ratio = '35'  # perturbation probability, one of ['00', '15', '25', '35']
train_sample_idx = 0  # which one-shot training sample to use, one of [0, 1, 2, 3, 4]

train_sample_path = "../../DATA/GPT/GPT_trainx_longpre_samples.json"

assert pert_ratio in ('00', '15', '25', '35'), f"unsupported pert_ratio: {pert_ratio!r}"
assert train_sample_idx in range(5), f"train_sample_idx must be in 0..4, got {train_sample_idx!r}"

In [4]:
# Load the one-shot demonstration pool, the evaluation split, and the
# FiD-discriminator perturbation predictions for that split.
with open(train_sample_path, 'r') as samples_file:
    trainx_ins_list = json.load(samples_file)

dataset_path, discriminator_fid_path = (
    (
        "../../DATA/corpus/NQ_eval_longpre_dev_256_new_fix.json",
        "../../DATA/GPT/Discriminator_FiD_predictions_contra_nq_dev_longpre_256.pkl",
    )
    if is_dev
    else (
        # NOTE(review): unlike the dev path, this one has no ".json" suffix;
        # confirm it matches the file on disk.
        "../../DATA/corpus/NQ_eval_longpre_test_fix",
        "../../DATA/GPT/Discriminator_FiD_predictions_contra_nq_test_longpre_full.pkl",
    )
)

with open(dataset_path, 'r') as dataset_file:
    dataset = json.load(dataset_file)
# pickle.load can execute arbitrary code: only load trusted, locally produced files.
with open(discriminator_fid_path, 'rb') as fid_file:
    discriminator_fid_pred = pickle.load(fid_file)

sample_n = len(dataset)

In [14]:
# Query GPT-3 for every evaluation question, resuming from any predictions
# already collected in this kernel session.
# NOTE: `predictions` deliberately survives re-runs of this cell so an
# interrupted run can resume. After changing the config cell, run
# `predictions = []` first, or outputs from different settings get mixed.
try:
    predictions
except NameError:
    predictions = []

start_from = len(predictions)

if use_parametric_only:
    # One-shot demonstration with question and gold answer only (no passages).
    prompt = "question: " + trainx_ins_list[train_sample_idx]['question'].strip() + "?\n" + "answer: "+ trainx_ins_list[train_sample_idx]['answers'][0] + "\n\n"
else:
    post_fix = "answer:"
    if use_pert_aware_instruction:
        instruction = instruction_pert_aware
        if not use_discriminator_fid:
            post_fix = "perturbed:"  # GPT-3 generates the perturbation predictions itself (discriminator_inst setting)
    else:
        instruction = instruction_no_pert_aware
    prompt = prompt_from_nqopen(trainx_ins_list[train_sample_idx], pert_lbls_str=trainx_prob, is_pert_prompt=use_pert_aware_instruction, inst_at_end = instruction)

for idx, ins in enumerate(tqdm(dataset[start_from:sample_n])):
    example_idx = start_from + idx  # absolute index into dataset / discriminator predictions
    if use_parametric_only:
        prompt_test = "question: " + ins['question'].strip() + "?\n" + "answer: "
    else:
        pert_fid_label = discriminator_fid_pred[pert_ratio][example_idx] if use_discriminator_fid else None  # FiD classification results
        prompt_test = prompt_from_nqopen(ins, pert_lbls_str=pert_ratio, is_pert_prompt=use_pert_aware_instruction, pert_fid=pert_fid_label, post_fix = post_fix)
    result = api_GPT_oneshot_pinst(prompt + prompt_test, api_key = GPT3_api_key, model = 'text-davinci-003')
    response = result.json()  # parse once
    if 'choices' not in response:
        # Predictions gathered so far are kept, so re-running this cell resumes here.
        raise RuntimeError(
            f"Completion request for example {example_idx} failed "
            f"(HTTP {result.status_code}): {response.get('error', response)}"
        )
    predictions.append(response)
assert len(predictions) == sample_n

100%|██████████| 31/31 [01:41<00:00,  3.27s/it]


In [15]:
# Store the one-shot prompt followed by every raw API response.
# The file name encodes split, setting and config, e.g.
# test_semipara_pert_fidpred_p35_sample0.pkl
list_to_store = [prompt] + predictions

gpt_outpath = "nq_longpre_GPT_outputs"
os.makedirs(gpt_outpath, exist_ok=True)  # open(..., 'wb') fails if the directory is missing

filename = "dev_" if is_dev else "test_"
filename += "para_" if use_parametric_only else "semipara_"

if not use_parametric_only:
    filename += "pert_" if use_pert_aware_instruction else ""
    filename += "fidpred_" if use_discriminator_fid else ""
    filename += "p" + pert_ratio + "_"
filename += "sample" + str(train_sample_idx) + ".pkl"

with open(os.path.join(gpt_outpath, filename), 'wb') as output:
    pickle.dump(list_to_store, output)