# Import necessary packages

In [13]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0, 1, 2, 3"
from kilt import retrieval
from kilt import kilt_utils as utils
import tasks
import utils
from rouge_score import rouge_scorer
import random
import numpy as np
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from datasets import load_dataset
import json
from tqdm import tqdm

# Set up indexer

In [2]:
torch.set_grad_enabled(False)
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-multiset-base", device_map='cuda')
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-multiset-base", device_map='cuda')
wiki = load_dataset(path='wiki_dpr', name='psgs_w100.multiset.compressed', split='train')

Some weights of the model checkpoint at facebook/dpr-ctx_encoder-multiset-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.weight', 'ctx_encoder.bert_model.pooler.dense.bias']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenize

In [3]:
task='nq'
dataset_dpr = tasks.RQA_dpr(task=task)

In [4]:
semantic = False
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"],
                                        use_stemmer=True)
if semantic:
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    # setup semantic model
    semantic_tokenizer = \
        AutoTokenizer.from_pretrained("microsoft/deberta-large-mnli")
    semantic_model = \
        AutoModelForSequenceClassification.from_pretrained(
            "microsoft/deberta-large-mnli"
        ).cuda()

# Collect data

In [5]:
indices = np.arange(len(dataset_dpr.elements))
random.shuffle(indices)
cal_indices = indices[:int(len(indices) * 0.5)]
test_indices = indices[int(len(indices) * 0.5):]

elements = dataset_dpr.elements
query = [element['question'] for element in elements]

In [6]:
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-multiset-base")
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-multiset-base")

Some weights of the model checkpoint at facebook/dpr-question_encoder-multiset-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
question_embedding = q_encoder(**q_tokenizer(query, return_tensors="pt", padding=True))
question_embedding = question_embedding[0].numpy()
scores, retrieved_examples = wiki.get_nearest_examples_batch('embeddings', question_embedding, k=20)

In [8]:
import pickle
def write_list(a_list, file_name):
    # store list in binary file so 'wb' mode
    with open(file_name, 'wb') as fp:
        pickle.dump(a_list, fp)
#         print('Done writing list into a binary file')
def read_list(file_name):
    # for reading also binary mode is important
    with open(file_name, 'rb') as fp:
        n_list = pickle.load(fp)
        return n_list

def save_results(task):
    # save retrieved_scores to a pickle file
    write_list(retrieved_scores, f'chatgpt_retrieved_scores_{task}_fewshot.p')
    # save retrieved_true_scores to a pickle file
    write_list(retrieved_true_scores, f'chatgpt_retrieved_true_scores_{task}_fewshot.p')
    # save queries to a pickle file
    write_list(queries, f'chatgpt_queries_{task}_fewshot.p')
    # save answers to a pickle file
    write_list(answers, f'chatgpt_true_answers_{task}_fewshot.p')
    # save passages to a pickle file
    write_list(passages, f'chatgpt_passages_{task}_fewshot.p')
    # save chatgpt_true_scores to a pickle file
    write_list(chatgpt_true_scores, f'chatgpt_true_scores_{task}_fewshot.p')
    # save chatgpt_texts to a pickle file
#     write_list(chatgpt_texts, f'chatgpt_texts_{task}.p')
    # save chatgpt_answers to a pickle file
    write_list(chatgpt_answers, f'chatgpt_answers_{task}_fewshot.p')
    # save chatgpt_semantics to a picle file
    write_list(chatgpt_semantics, f'chatgpt_semantics_{task}_fewshot.p')
    # save feasibilities to a pickle file
    write_list(feasibilities, f'chatgpt_feasibilities_{task}_fewshot.p')
    # save occurances to a pickle file
    write_list(occurances, f'chatgpt_occurances_{task}_fewshot.p')
    # save semantic_ids to a pickle file
    write_list(semantic_ids, f'chatgpt_semantic_ids_{task}_fewshot.p')
    # save probs to a picle file
    write_list(probs, f'chatgpt_probs_{task}_fewshot.p')
    
    write_list(retrieved_scores_unc, f'chatgpt_retrieved_scores_unc_{task}_fewshot.p')
    write_list(retrieved_true_scores_unc, f'chatgpt_retrieved_true_scores_unc_{task}_fewshot.p')
    write_list(queries_unc, f'chatgpt_queries_unc_{task}_fewshot.p')
    write_list(answers_unc, f'chatgpt_answers_unc_{task}_fewshot.p')
    write_list(passages_unc, f'chatgpt_passages_unc_{task}_fewshot.p')
    write_list(chatgpt_true_scores_unc, f'chatgpt_true_scores_unc_{task}_fewshot.p')
    write_list(chatgpt_answers_unc, f'chatgpt_answers_unc_{task}_fewshot.p')
    write_list(occurances_unc, f'chatgpt_occurances_unc_{task}_fewshot.p')
    write_list(semantic_ids_unc, f'chatgpt_semantic_ids_unc_{task}_fewshot.p')
    write_list(probs_unc, f'chatgpt_probs_unc_{task}_fewshot.p')

In [9]:
def read_results(task):
    retrieved_scores = read_list(f'chatgpt_retrieved_scores_{task}_fewshot.p')
    retrieved_true_scores = read_list(f'chatgpt_retrieved_true_scores_{task}_fewshot.p')
    queries = read_list(f'chatgpt_queries_{task}_fewshot.p')
    answers = read_list(f'chatgpt_answers_{task}_fewshot.p')
    chatgpt_true_scores = read_list(f'chatgpt_true_scores_{task}_fewshot.p')
    chatgpt_answers = read_list(f'chatgpt_answers_{task}_fewshot.p')
    chatgpt_passages = read_list(f'chatgpt_passages_{task}_fewshot.p')
    chatgpt_semantics = read_list(f'chatgpt_semantics_{task}_fewshot.p')
    chatgpt_occurances = read_list(f'chatgpt_occurances_{task}_fewshot.p')
    chatgpt_semantic_ids = read_list(f'chatgpt_semantic_ids_{task}_fewshot.p')
    chatgpt_probs = read_list(f'chatgpt_probs_{task}_fewshot.p')
    
    retrieved_scores_unc = read_list(f'chatgpt_retrieved_scores_unc_{task}_fewshot.p')
    retrieved_true_scores_unc = read_list(f'chatgpt_retrieved_true_scores_unc_{task}_fewshot.p')
    queries_unc = read_list(f'chatgpt_queries_unc_{task}_fewshot.p')
    answers_unc = read_list(f'chatgpt_answers_unc_{task}_fewshot.p')
    passages_unc = read_list(f'chatgpt_passages_unc_{task}_fewshot.p')
    chatgpt_true_scores_unc = read_list(f'chatgpt_true_scores_unc_{task}_fewshot.p')
    chatgpt_answers_unc = read_list(f'chatgpt_answers_unc_{task}_fewshot.p')
    chatgpt_occurances_unc = read_list(f'chatgpt_occurances_unc_{task}_fewshot.p')
    chatgpt_semantic_ids_unc = read_list(f'chatgpt_semantic_ids_unc_{task}_fewshot.p')
    chatgpt_probs_unc = read_list(f'chatgpt_probs_unc_{task}_fewshot.p')
    
    return retrieved_scores, retrieved_true_scores, queries, answers, chatgpt_true_scores, chatgpt_answers, chatgpt_passages, chatgpt_semantics, chatgpt_occurances, chatgpt_semantic_ids, chatgpt_probs, retrieved_scores_unc, retrieved_true_scores_unc, queries_unc, answers_unc, passages_unc, chatgpt_true_scores_unc, chatgpt_answers_unc, chatgpt_occurances_unc, chatgpt_semantic_ids_unc, chatgpt_probs_unc

## Setup chatgpt

In [10]:
utils.setup_openai()

In [11]:
# construct demonstrations
demonstrations = []
for element in elements[-2:]:
    query, answer, passage_id, passage_title, passage_text = \
        utils.dataset_info(element, dataset=task)
    demonstrations.append([query, passage_text[0], answer])

In [15]:
import importlib
importlib.reload(utils)

<module 'utils' from '/home/lishuo1/retriever_uncertainty/TRAC/utils.py'>

In [16]:
%%time
chat = True
semantic = False
queries = []
answers = []
passages = []
retrieved_scores = []
retrieved_true_scores = []
chatgpt_true_scores = []
chatgpt_texts = []
chatgpt_answers = []
chatgpt_semantics = []
semantic_probs = []
feasibilities = []
occurances = []
semantic_ids = []
probs = []
input_token_counts = []
output_token_counts = []

retrieved_scores_unc = []
retrieved_true_scores_unc = []
queries_unc = []
answers_unc = []
passages_unc = []
chatgpt_true_scores_unc = []
chatgpt_answers_unc = []
occurances_unc = []
semantic_ids_unc = []
probs_unc = []
        
for idx, (element, score, retrieved) in enumerate(zip(elements, scores, retrieved_examples)):
    if len(queries) > 1004:
        break
    
    print(f'{idx}', file=open(f'chatgpt_{task}_fewshot.txt', 'a'))
    feasible = False
    if idx % 10 == 0:
        print(idx)
        save_results(task)
    query, answer, passage_id, passage_title, passage_text = \
        utils.dataset_info(element, dataset=task)
    if len(passage_id) == 0:
        continue
    retrieved_ids, retrieved_texts, retrieved_title, true_score = \
        utils.retrieved_info(score, retrieved, passage_id[0])
    if len(true_score) == 0:
        continue
    
    prompt = utils.get_prompt_template_fewshot(
        query, 
        passage_text[0], 
        demonstrations)
    break
    if chat:
        sequences, input_token_count, output_token_count = \
            utils.ask_chatgpt(prompt, n_answers=30, model="gpt-3.5-turbo-0613")
    else:
        sequences, probs = utils.ask_chatgpt(prompt)
    semantic_set_ids, semantic_probs, item_occurance = \
        utils.clustering(sequences, prompt, scorer=scorer)
    true_scores, matched_answer, semantics = utils.processing_answers(
        semantic_set_ids, semantic_probs, 
        item_occurance, answer, scorer,
        threshold=0.3
    )
    if len(true_scores) == 0:
        retrieved_scores_unc.append(score)
        retrieved_true_scores_unc.append(true_score)
        queries_unc.append(query)
        answers_unc.append(answer)
        passages_unc.append(passage_text)
        chatgpt_true_scores_unc.append(true_scores)
        chatgpt_answers_unc.append(sequences)
        occurances_unc.append(item_occurance)
        semantic_ids_unc.append(semantic_set_ids)
        probs_unc.append(semantic_probs)
        input_token_counts.append(input_token_count)
        output_token_counts.append(output_token_count)
        continue
    else:
        feasible = True
        retrieved_scores.append(score)
        retrieved_true_scores.append(true_score)
        queries.append(query)
        answers.append(answer)
        passages.append(passage_text)
        chatgpt_true_scores.append(true_scores)

        probs_tmp = []
        answers_tmp = []
        semantic_id_tmp = []
        occurance_tmp = []
        semantic_tmp = []
        for context, s in zip(retrieved_texts, score):

            prompt = utils.get_prompt_template_fewshot(
                query, 
                context, 
                demonstrations)
            if chat:
                sequences, input_token_count_tmp, output_token_count_tmp = \
                    utils.ask_chatgpt(prompt, n_answers=30, model="gpt-3.5-turbo-0613")
            else:
                sequences, probs = utils.ask_chatgpt(prompt, n_answers=30)
            input_token_count += input_token_count_tmp
            output_token_count += output_token_count_tmp

            if semantic:
                semantic_set_ids, semantic_probs, item_occurance = \
                    utils.compute_semantic_clusterring(
                        semantic_model, 
                        semantic_tokenizer,
                        prompt,
                        sequences,
                    )
            else:
                semantic_set_ids, semantic_probs, item_occurance = \
                    utils.clustering(sequences, prompt, scorer=scorer)

            probs_tmp.append(semantic_probs)
            answers_tmp.append(sequences)
            occurance_tmp.append(item_occurance)
            semantic_id_tmp.append(semantic_set_ids)

        chatgpt_answers.append(answers_tmp)
        feasibilities.append(feasible)
        occurances.append(occurance_tmp)
        semantic_ids.append(semantic_id_tmp)
        probs.append(probs_tmp)
        input_token_counts.append(input_token_count)
        output_token_counts.append(output_token_count)
print('Finished!', file=open(f'chatgpt_{task}_fewshot.txt', 'a'))

0
CPU times: user 4.57 ms, sys: 191 µs, total: 4.76 ms
Wall time: 4.92 ms


In [17]:
prompt

'Answer the following question based on the given context; Answer the question shortly.\n                Question: \'\'\'when did sunday became the seventh day of the week\'\'\'\n                Context: \'\'\'일요일 Il-yo-Il, meaning "day of sun". The international standard ISO 8601 for representation of dates and times, states that Sunday is the seventh and last day of the week. This method of representing dates and times unambiguously was first published in 1988. In the Judaic, some Christian, as well as in some Islamic tradition, Sunday has been considered the first day of the week. A number of languages express this position either by the name for the day or by the naming of the other days. In Hebrew it is called יום ראשון "yom rishon", in Arabic الأحد "al-ahad", in\'\'\'\n                Answer: \'\'\'[\'1988\']\'\'\'\n                Answer the following question based on the given context; Answer the question shortly.\n                Question: \'\'\'girl from the shut up and danc

## Estimated cost for each query

In [27]:
average_cost = np.mean([input_token_count]) / 1000 * 0.0015 + \
               np.mean([output_token_count]) / 1000 * 0.002
print(average_cost)

0.019968
