## Import necessary packages

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0, 2, 3, 7"
from kilt import retrieval
from kilt import kilt_utils as utils
import tasks
from kilt.retrievers import DPR_connector
import utils
from rouge_score import rouge_scorer
import random
import numpy as np
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from datasets import load_dataset
import json
from tqdm import tqdm

## Set up indexer

In [None]:
torch.set_grad_enabled(False)
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", device_map='cuda')
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", device_map='cuda')
wiki = load_dataset(path='wiki_dpr', name='psgs_w100.multiset.compressed', split='train')

In [None]:
class RQA_dpr:
    def __init__(self, task='nq') -> None:
        self.task = task
        self.query_data, self.validated_data, self.elements = self.load_dataset()
    
    def load_dataset(self) -> None:
        if self.task == 'nq':
            with open("data/biencoder-nq-dev.json", "r") as fin:
                nq_dpr = json.load(fin)
        elif self.task == 'trivia':
            with open("data/biencoder-qas-dev.json", "r") as fin:
                nq_dpr = json.load(fin)
        
        elements = []
        query_data = []
        validated_data = {}
        for idx, record in enumerate(nq_dpr):
            elements.append(record)
            validated_data[idx] = record
            query_data.append( 
                record['question']
            )
        return query_data, validated_data, elements

In [None]:
dataset_dpr = RQA_dpr()

In [298]:
# indices = np.arange(len(dataset_dpr.elements))
indices = np.arange(1000)
random.shuffle(indices)
cal_indices = indices[:int(len(indices) * 0.5)]
test_indices = indices[int(len(indices) * 0.5):]

test_elements = utils.split(dataset_dpr.elements, test_indices)
cal_elements = utils.split(dataset_dpr.elements, cal_indices)

test_query = [element['question'] for element in test_elements]
cal_query = [element['question'] for element in cal_elements]

In [None]:
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

In [299]:
question_embedding = q_encoder(**q_tokenizer(cal_query, return_tensors="pt", padding=True))
question_embedding = question_embedding[0].numpy()
scores, retrieved_examples = wiki.get_nearest_examples_batch('embeddings', question_embedding, k=50)

In [319]:
queries = []
answers = []
retrieved_texts = []
retrieved_scores = []
passages = []
retrieved_true_scores = []
for element, score, retrieved in zip(cal_elements, scores, retrieved_examples):
    query, answer, passage_id, passage_title, passage_text = \
        utils.dataset_info(element)
    retrieved_ids, retrieved_texts, retrieved_title, true_score = \
        utils.retrieved_info(score, retrieved, passage_id[0])
    if len(true_score) == 0:
        continue
    
    retrieved_true_scores.append(true_score)
    queries.append(query)
    answers.append(answer)
    passages.append(passage_text)

## Setup chatgpt

In [239]:
utils.setup_openai()

In [None]:
chatgpt_true_scores = []
chat = True
semantic = False
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"],
                                        use_stemmer=True)
with torch.no_grad():
    for idx, element in enumerate(tqdm(cal_elements)):
        
        query, answer, passage_id, passage_title, passage_text = \
            utils.dataset_info(element)
    
        sequences, prompt = utils.ask(query, passage_text[0], chat)
        
        semantic_set_ids, semantic_probs, item_occurance = \
            utils.clustering(sequences, prompt, scorer=scorer)
            
        true_scores, matched_answer, semantics = utils.processing_answers(
            semantic_set_ids, semantic_probs, 
            item_occurance, answer, scorer,
            threshold=0.3
        )
        chatgpt_true_scores.append(true_scores)

In [308]:
retrieved_threshold = utils.compute_threshold(retrieved_true_scores, alpha=0.05, shuffle=True)

Most relevant threshold: 64.70359802246094
Most relevant coverage: 0.963855421686747


In [305]:
chatgpt_thr_qa = utils.compute_threshold(chatgpt_true_scores, alpha=0.05, shuffle=True)

Most relevant threshold: 0.16666666666666666
Most relevant coverage: 0.9589521452145214


In [309]:
question_embedding = q_encoder(**q_tokenizer(test_query, return_tensors="pt", padding=True))
question_embedding = question_embedding[0].numpy()
test_scores, test_retrieved_examples = wiki.get_nearest_examples_batch('embeddings', question_embedding, k=50)

In [334]:
import importlib
importlib.reload(utils)

<module 'utils' from '/home/lishuo1/retriever_uncertainty/TRAC/utils.py'>

In [348]:
def check_valid(element, score, retrieved, chat=True):
    query, answer, passage_id, passage_title, passage_text = \
        utils.dataset_info(element)
    retrieved_ids, retrieved_texts, retrieved_title, true_score = \
        utils.retrieved_info(score, retrieved, passage_id[0])

    if len(true_score) == 0:
        return False, True

    sequences, prompt = utils.ask(query, passage_text[0], chat, task='Natural Questions')
    semantic_set_ids, semantic_probs, item_occurance = \
        utils.clustering(sequences, prompt, scorer=scorer)
    true_scores, matched_answer, semantics = utils.processing_answers(
        semantic_set_ids, semantic_probs, 
        item_occurance, answer, scorer,
        threshold=0.3
    )
    if len(true_scores) == 0:
        return False, True
    elif np.sum(np.array(true_scores) >= chatgpt_thr_qa) == 0:
        chatgpt_covered.append(False)
        return True, False
    else:
        return True, True

In [None]:
chatgpt_covered = []
chatgpt_sizes = []
queries = []
with torch.no_grad():
    for idx, (element, score, retrieved) in enumerate(zip(test_elements, test_scores, test_retrieved_examples)):
        print(idx)
        query, answer, passage_id, passage_title, passage_text = \
            utils.dataset_info(element)
        retrieved_ids, retrieved_texts, retrieved_title, true_score = \
            utils.retrieved_info(score, retrieved, passage_id[0])
        
        if idx > 100:
            break
        
        valid, covered = check_valid(element, score, retrieved, chatgpt_thr_qa)
        if not valid:
            continue
        elif not covered:
            print(False)
            chatgpt_covered.append(False)
            continue
        
        cover = False
        tmp = []
        query_count = 0
        for ctx_idx, (context, s) in enumerate(zip(contexts, score)):
            if s < retrieved_threshold:
                continue
            query_count += 1
            sequences, prompt = utils.ask(query, context, chat, task='Natural Questions')
            semantic_set_ids, semantic_probs, item_occurance = \
                utils.clustering(sequences, prompt, scorer=scorer)
            true_scores, matched_answer, semantics = utils.processing_answers(
                semantic_set_ids, semantic_probs, 
                item_occurance, answer, scorer,
                threshold=0.3)

            tmp.extend(semantics)
            if len(true_scores) >= 1:
                cover = True
                break
        print(cover)
        chatgpt_covered.append(cover)
        chatgpt_sizes.append(len(tmp))
        queries.append(query_count)

0
1
2


In [315]:
print('coverage', np.mean(chatgpt_covered))
print('average size', np.mean(chatgpt_sizes))
print('average query count', np.mean(queries))

coverage 0.8936170212765957
average size 54.651162790697676
average query count 1.9767441860465116


In [148]:
chatgpt_semantic_sizes = []
for semantic_meaning in semantics:
    semantic_set_ids, semantic_probs, item_occurance = \
        utils.clustering(semantic_meaning, prompt, scorer=scorer)
    chatgpt_semantic_sizes.append(len(semantic_set_ids.keys()))

In [149]:
print('average semantic size', np.mean(chatgpt_semantic_sizes))

average semantic size 6.4
