## Import necessary packages

In [2]:
import os
# Restrict this process to GPUs 0-3; this must happen before torch initialises CUDA,
# which is why it precedes the torch import below.
os.environ["CUDA_VISIBLE_DEVICES"]="0, 1, 2, 3"
from kilt import retrieval
# NOTE(review): this alias is shadowed by `import utils` a few lines below, so every
# later `utils.` reference resolves to the local utils module, not kilt_utils.
from kilt import kilt_utils as utils
import tasks
from kilt.retrievers import DPR_connector
import utils
from rouge_score import rouge_scorer
import random
import numpy as np
import torch

## Set up the DPR retriever

In [None]:
# Build the DPR retriever from the KILT default retriever configuration file.
DPR_CONFIG_PATH = "kilt/configs/retriever/default_dpr.json"
retriever = DPR_connector.DPR.from_config_file("dpr", DPR_CONFIG_PATH)

## Set up the dataset and extract, for each query:
- the query text
- the gold (provenance) passage titles
- the retrieved passages
- the reference answers

In [None]:
# Task options: 'Natural Questions', 'TriviaQA', 'FEVER'.
task = 'Natural Questions'
dataset = tasks.RQA(task=task)

# Run the retriever over every query of the dataset; the resulting
# `provenance` is looked up by query id further down.
query_payload = dataset.query_data
retriever.feed_data(query_payload)
provenance = retriever.run()

In [None]:
# Materialise the dataset views. Each element is a dict whose 'id', 'input'
# and 'output' keys are read by the calibration and test loops.
(query_data,
 validated_data,
 elements) = dataset.load_dataset()

In [None]:
# Peek at one raw element to inspect its fields.
elements[0]

In [None]:
# Randomly split the KILT elements into a calibration half and a test half.
# The split is seeded so that it, and every threshold calibrated on it, is
# reproducible across kernel restarts.
# NOTE: `elements` is rebound to the calibration half. Re-run the load cell
# before re-running this one, otherwise the calibration half is split again.
SPLIT_SEED = 42
CAL_FRACTION = 0.5

split_rng = np.random.default_rng(SPLIT_SEED)
indices = split_rng.permutation(len(elements))
n_cal = int(len(indices) * CAL_FRACTION)
cal_indices = indices[:n_cal]
test_indices = indices[n_cal:]

test_elements = utils.split(elements, test_indices)
elements = utils.split(elements, cal_indices)

In [None]:
# Build the calibration set. For each query, keep the retriever scores of the
# gold (provenance) passages that the retriever actually returned; queries whose
# gold passages were not retrieved at all are skipped.
queries = []
answers = []
retrieved_texts = []
retrieved_scores = []
for element in elements:
    query_id = element['id']
    answer = [ans['answer'] for ans in element['output'] if "answer" in ans]
    # Flatten the gold Wikipedia ids across every provenance-bearing output.
    # A set makes the membership test below O(1) and avoids rebinding the
    # builtin `id` as a global.
    gold_ids = {
        wiki['wikipedia_id']
        for ans in element['output'] if 'provenance' in ans
        for wiki in ans['provenance']
    }
    retrieved = provenance[query_id]
    # NOTE(review): assumes convert_list_to_dict maps wikipedia_id -> retriever
    # score; confirm in utils.
    convert = utils.convert_list_to_dict(retrieved)
    score = [convert[wiki_id] for wiki_id in convert if wiki_id in gold_ids]
    if not score:
        continue

    queries.append(element['input'])
    answers.append(answer)
    retrieved_texts.append([ans['text'] for ans in retrieved])
    retrieved_scores.append(score)

## Set up the semantic model

In [None]:
# `semantic` selects how sampled answers are clustered in the scoring loops:
# True -> the DeBERTa MNLI model below, False -> the ROUGE scorer.
semantic = False
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)

if semantic:
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    NLI_MODEL_NAME = "microsoft/deberta-large-mnli"
    semantic_tokenizer = AutoTokenizer.from_pretrained(NLI_MODEL_NAME)
    semantic_model = AutoModelForSequenceClassification.from_pretrained(
        NLI_MODEL_NAME
    ).cuda()

## Set up the open-source model

In [None]:
import importlib

import opensource

# Reload so local edits to opensource.py are picked up without a kernel restart.
importlib.reload(opensource)
model, pipeline, tokenizer = opensource.setup_openmodel()

## Build prompts and query the open-source model

In [None]:
# Calibration pass for the open-source model: prompt it with each retrieved
# context, cluster the sampled answers, and record the "true" scores returned
# by utils.processing_answers for each query.
# MAX_CAL_QUERIES caps the run for quick debugging; set it to None to calibrate
# on the full split.
MAX_CAL_QUERIES = 3
opensource_true_scores = []
with torch.no_grad():
    for idx, (query, answer, contexts, score) in enumerate(
        zip(queries, answers, retrieved_texts, retrieved_scores)
    ):
        if MAX_CAL_QUERIES is not None and idx >= MAX_CAL_QUERIES:
            break

        true_scores_tmp = []
        # NOTE(review): zip() truncates `contexts` (all retrieved passages) to
        # len(score) (only the gold passages that were retrieved); confirm intended.
        for context, _gold_score in zip(contexts, score):
            # Condition on the retrieved context so calibration uses the same
            # prompt format as the test-time evaluation, and use the configured
            # task instead of a hard-coded name.
            prompt = utils.get_prompt_template(query, context, task=task)
            sequences = opensource.ask_openmodel(prompt, pipeline, tokenizer)
            # Strip the prompt prefix from each generation, keeping only the continuation.
            generated_texts = [
                seq['generated_text'][len(prompt):].strip() for seq in sequences
            ]

            if semantic:
                semantic_set_ids, semantic_probs, item_occurance = \
                    utils.compute_semantic_clusterring(
                        semantic_model,
                        semantic_tokenizer,
                        prompt,
                        generated_texts,
                    )
            else:
                semantic_set_ids, semantic_probs, item_occurance = \
                    utils.compute_keyword_clusterring(
                        generated_texts,
                        scorer
                    )
            true_scores, matched_answers = utils.processing_answers(
                semantic_set_ids, semantic_probs,
                item_occurance, answer, scorer,
                threshold=0.3
            )
            true_scores_tmp.extend(true_scores)
        opensource_true_scores.append(true_scores_tmp)

In [None]:
# Inspect the most recently built prompt.
prompt

## Set up ChatGPT and query it

In [None]:
# Initialise OpenAI access via the local utils helper; presumably required
# before utils.ask_chatgpt can be called (confirm in utils).
utils.setup_openai()

In [None]:
# ChatGPT calibration pass: for each calibration query, prompt ChatGPT with the
# retrieved contexts, cluster the sampled answers, and record their true scores.
# Only the first three queries are processed (debug cap).
# NOTE(review): zip() pairs the retrieved contexts with `score`, which holds only
# the gold passages that were retrieved, so contexts beyond len(score) are
# skipped; confirm this is intended.
chat = True
chatgpt_true_scores = []
for idx, (query, answer, contexts, score) in enumerate(
    zip(queries, answers, retrieved_texts, retrieved_scores)
):
    if idx >= 3:
        break

    true_scores_tmp = []
    for context, s in zip(contexts, score):
        prompt = utils.get_prompt_template(query, context, task='Natural Questions')

        # With chat=False the helper's result is unpacked as (sequences, probs).
        if not chat:
            sequences, probs = utils.ask_chatgpt(prompt)
        else:
            sequences = utils.ask_chatgpt(prompt)

        # Cluster the sampled answers with the NLI model (semantic=True) or the
        # ROUGE scorer (semantic=False).
        clusters = (
            utils.compute_semantic_clusterring(
                semantic_model, semantic_tokenizer, prompt, sequences
            )
            if semantic
            else utils.compute_keyword_clusterring(sequences, scorer)
        )
        semantic_set_ids, semantic_probs, item_occurance = clusters

        true_scores, matched_answer = utils.processing_answers(
            semantic_set_ids,
            semantic_probs,
            item_occurance,
            answer,
            scorer,
            threshold=0.3,
        )
        true_scores_tmp.extend(true_scores)
    chatgpt_true_scores.append(true_scores_tmp)

Calibration outputs:
- `retrieved_scores`: true (gold-passage) scores from the retriever
- `opensource_true_scores`: true scores from the open-source model
- `chatgpt_true_scores`: true scores from ChatGPT

In [None]:
import importlib

# Re-import the edited helper modules in place so later cells see the changes.
for _module in (utils, opensource):
    importlib.reload(_module)
opensource

## Compute thresholds on the calibration set

In [None]:
# Calibrate the retrieval-score threshold (alpha=0.025) on the gold-passage
# retriever scores collected from the calibration split.
retrieved_threshold = utils.compute_threshold(
    retrieved_scores,
    shuffle=True,
    alpha=0.025,
)

In [None]:
# Calibrate the open-source model's QA threshold (alpha=0.025) on its
# calibration-set true scores.
opensource_thr_qa = utils.compute_threshold(
    opensource_true_scores,
    shuffle=True,
    alpha=0.025,
)

In [None]:
# Calibrate ChatGPT's QA threshold (alpha=0.025) on its calibration-set true scores.
chatgpt_thr_qa = utils.compute_threshold(
    chatgpt_true_scores,
    shuffle=True,
    alpha=0.025,
)

## Evaluate the thresholds on the test set

In [None]:
# Build the test set and measure retrieval coverage. A query is "covered" when
# at least one gold passage survives the calibrated retrieval-score threshold.
# Queries whose gold passage was never retrieved at all are excluded.
queries = []
answers = []
retrieved_texts = []
covered = []
for element in test_elements:
    query_id = element['id']
    answer = [ans['answer'] for ans in element['output'] if "answer" in ans]
    # Flatten the gold Wikipedia ids (same order as the provenance lists).
    ids = [
        wiki['wikipedia_id']
        for ans in element['output'] if 'provenance' in ans
        for wiki in ans['provenance']
    ]
    all_retrieved = provenance[query_id]
    all_id = [r['wikipedia_id'] for r in all_retrieved]
    if len(utils.intersection(all_id, ids)) == 0:
        continue

    # Keep only passages whose retriever score clears the calibrated threshold.
    retrieved = [r for r in all_retrieved if float(r['score']) >= retrieved_threshold]
    retrieved_id = [r['wikipedia_id'] for r in retrieved]
    covered.append(len(utils.intersection(retrieved_id, ids)) >= 1)

    queries.append(element['input'])
    answers.append(answer)
    retrieved_texts.append([r['text'] for r in retrieved])

In [None]:
# Fraction of test queries whose gold passage passes the retrieval threshold.
coverage_rate = np.mean(covered)
print("coverage rate", coverage_rate)

In [None]:
# Test-time evaluation of the open-source model. A query is covered when any of
# its threshold-filtered retrieved contexts yields at least one answer that
# passes the calibrated QA threshold `opensource_thr_qa` (no hard-coded override,
# which would discard the calibration).
# MAX_EVAL_QUERIES caps the run for quick debugging; set it to None to evaluate
# the whole test split.
MAX_EVAL_QUERIES = 4
opensource_covered = []
with torch.no_grad():
    for idx, (query, answer, contexts) in enumerate(
        zip(queries, answers, retrieved_texts)
    ):
        if MAX_EVAL_QUERIES is not None and idx >= MAX_EVAL_QUERIES:
            break

        cover = False
        # Iterate this query's own contexts. Zipping with `score` picked up a
        # stale variable from the calibration loops and truncated the contexts.
        for context in contexts:
            prompt = utils.get_prompt_template(query, context, task=task)
            sequences = opensource.ask_openmodel(prompt, pipeline, tokenizer)
            # Strip the prompt prefix from each generation, keeping only the continuation.
            generated_texts = [
                seq['generated_text'][len(prompt):].strip() for seq in sequences
            ]

            if semantic:
                semantic_set_ids, semantic_probs, item_occurance = \
                    utils.compute_semantic_clusterring(
                        semantic_model,
                        semantic_tokenizer,
                        prompt,
                        generated_texts,
                    )
            else:
                semantic_set_ids, semantic_probs, item_occurance = \
                    utils.compute_keyword_clusterring(
                        generated_texts,
                        scorer
                    )
            true_scores, matched_answers = utils.processing_answers(
                semantic_set_ids, semantic_probs,
                item_occurance, answer, scorer,
                threshold=0.3, thr_qa=opensource_thr_qa
            )
            if len(true_scores) >= 1:
                cover = True
                # Coverage is decided; skip the remaining (costly) model calls.
                break
        opensource_covered.append(cover)

In [None]:
# Fraction of evaluated test queries answered within the calibrated QA threshold.
opensource_coverage_rate = np.mean(opensource_covered)
print("coverage rate", opensource_coverage_rate)

In [None]:
import json

# The DPR bi-encoder NQ dev file is a single JSON document, so json.load applies.
with open("data/biencoder-nq-dev.json", "r") as fin:
    nq_dpr = json.load(fin)

In [None]:
# Inspect one DPR record to see its fields.
nq_dpr[2]

In [None]:
import json

# The .jsonl file is JSON Lines: one JSON object per line. json.load() parses a
# single document and raises "Extra data" at the second record, so decode each
# line separately and skip blank lines (e.g. a trailing newline).
with open("data/nq-dev-kilt.jsonl", "r", encoding="utf-8") as fin:
    nq_kilt = [json.loads(line) for line in fin if line.strip()]

In [None]:
class RQA_dpr:
    """Natural Questions dev set in DPR bi-encoder JSON format.

    Loads a JSON file holding a list of records (each with at least a
    ``"question"`` key) and exposes three views of it:

    * ``query_data``     -- list of ``{"query": question, "id": index}`` dicts
    * ``validated_data`` -- dict mapping record index -> record
    * ``elements``       -- the records themselves, in file order

    Args:
        task: Task tag; stored on the instance but not otherwise used.
        path: Path of the DPR JSON file (defaults to the NQ dev split).
    """

    DEFAULT_PATH = "data/biencoder-nq-dev.json"

    def __init__(self, task='nq', path=DEFAULT_PATH) -> None:
        self.task = task
        self.path = path
        self.query_data, self.validated_data, self.elements = self.load_dataset()

    def load_dataset(self) -> tuple:
        """Read ``self.path`` and return ``(query_data, validated_data, elements)``.

        Raises:
            FileNotFoundError: if ``self.path`` does not exist.
            json.JSONDecodeError: if the file is not a single valid JSON document.
        """
        # Local import so the class does not depend on an earlier cell having run.
        import json

        # JSON text is UTF-8 (RFC 8259); don't rely on the platform default encoding.
        with open(self.path, "r", encoding="utf-8") as fin:
            records = json.load(fin)

        elements = list(records)
        validated_data = dict(enumerate(records))
        query_data = [
            {"query": record["question"], "id": idx}
            for idx, record in enumerate(records)
        ]
        return query_data, validated_data, elements

In [None]:
# Load the DPR-format NQ dev set (default file path).
dataset_dpr = RQA_dpr(task='nq')

In [None]:
# Number of questions in the DPR NQ dev set.
len(dataset_dpr.query_data)

In [None]:
# Retrieve passages for the DPR-format NQ questions. This must feed
# `dataset_dpr.query_data`, not the KILT `dataset.query_data`: provenance_dpr is
# meant to be looked up by DPR record index (the "id" RQA_dpr assigns), and the
# KILT queries carry different ids.
# NOTE(review): assumes retriever.run() keys results by the fed "id"; confirm.
N_DPR_RETRIEVAL = 500  # only the first 500 questions are retrieved (cost cap)
retriever.feed_data(dataset_dpr.query_data[:N_DPR_RETRIEVAL])
provenance_dpr = retriever.run()

In [None]:
# Randomly split the DPR NQ records into calibration and test halves.
# The split is seeded so it, and anything calibrated on it, is reproducible.
# NOTE: this rebinds `test_elements`, replacing the KILT test split.
SPLIT_SEED = 42
CAL_FRACTION = 0.5

dpr_split_rng = np.random.default_rng(SPLIT_SEED)
indices = dpr_split_rng.permutation(len(dataset_dpr.elements))
n_cal = int(len(indices) * CAL_FRACTION)
cal_indices = indices[:n_cal]
test_indices = indices[n_cal:]

test_elements = utils.split(dataset_dpr.elements, test_indices)
elements_dpr = utils.split(dataset_dpr.elements, cal_indices)

In [None]:
# Collect the DPR calibration examples: the question, its reference answers,
# and the texts of its gold ("positive") contexts.
# TODO: score retrieved passages from provenance_dpr (as the KILT calibration
# loop does) and fill retrieved_texts / retrieved_scores; until then they stay
# empty.
queries = []
answers = []
retrieved_texts = []
retrieved_scores = []
passages = []
for element in elements_dpr:
    queries.append(element['question'])
    answers.append(list(element['answers']))
    passages.append([ctx['text'] for ctx in element['positive_ctxs']])

In [None]:
# Closed-book run on the DPR NQ calibration split: ask ChatGPT the bare question
# (no context), cluster the sampled answers, and keep their true scores.
# Results go to their own list instead of clearing opensource_true_scores and
# appending to chatgpt_true_scores from the KILT calibration.
MAX_DPR_QUERIES = 3  # quick-debug cap; set to None to use every query
dpr_closedbook_true_scores = []
with torch.no_grad():  # the optional NLI clustering model is a torch model
    for idx, (query, answer) in enumerate(zip(queries, answers)):
        if MAX_DPR_QUERIES is not None and idx >= MAX_DPR_QUERIES:
            break

        if chat:
            sequences = utils.ask_chatgpt(query)
        else:
            sequences, probs = utils.ask_chatgpt(query)

        if semantic:
            semantic_set_ids, semantic_probs, item_occurance = \
                utils.compute_semantic_clusterring(
                    semantic_model,
                    semantic_tokenizer,
                    query,  # the text actually sent to ChatGPT, not a stale prompt
                    sequences,
                )
        else:
            semantic_set_ids, semantic_probs, item_occurance = \
                utils.compute_keyword_clusterring(
                    sequences,
                    scorer
                )
        true_scores, matched_answer = utils.processing_answers(
            semantic_set_ids, semantic_probs,
            item_occurance, answer, scorer,
            threshold=0.3
        )
        # Record this query's scores (previously a stale true_scores_tmp was appended).
        dpr_closedbook_true_scores.append(true_scores)

In [None]:
# Inspect the cluster ids from the most recent clustering call.
semantic_set_ids

In [None]:
# Inspect the cluster probabilities from the most recent clustering call.
semantic_probs

In [None]:
# Open-book run on the DPR NQ calibration split: prompt ChatGPT with the
# question plus its first gold passage, cluster the sampled answers, and keep
# their true scores. Results go to their own list instead of clobbering the
# KILT calibration lists.
MAX_DPR_QUERIES = 3  # quick-debug cap; set to None to use every query
dpr_context_true_scores = []
with torch.no_grad():  # the optional NLI clustering model is a torch model
    for idx, (query, answer, passage) in enumerate(zip(queries, answers, passages)):
        if MAX_DPR_QUERIES is not None and idx >= MAX_DPR_QUERIES:
            break
        if not passage:
            # No gold context for this record, so there is nothing to condition on.
            continue

        # Keep `query` intact; the prompt gets its own name so clustering sees
        # the text actually sent to ChatGPT.
        prompt = utils.get_prompt_template(query, passage[0], task='Natural Questions')
        if chat:
            sequences = utils.ask_chatgpt(prompt)
        else:
            sequences, probs = utils.ask_chatgpt(prompt)

        if semantic:
            semantic_set_ids, semantic_probs, item_occurance = \
                utils.compute_semantic_clusterring(
                    semantic_model,
                    semantic_tokenizer,
                    prompt,
                    sequences,
                )
        else:
            semantic_set_ids, semantic_probs, item_occurance = \
                utils.compute_keyword_clusterring(
                    sequences,
                    scorer
                )
        true_scores, matched_answer = utils.processing_answers(
            semantic_set_ids, semantic_probs,
            item_occurance, answer, scorer,
            threshold=0.3
        )
        dpr_context_true_scores.append(true_scores)

In [None]:
# Inspect the cluster probabilities after the context-conditioned run.
semantic_probs

In [None]:
# Reference answers for the third calibration question, for comparison.
answers[2]