## Import necessary packages

In [131]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="4,5,6,7"
from kilt import retrieval
from kilt import kilt_utils as utils
import tasks
from kilt.retrievers import DPR_connector
import utils
from rouge_score import rouge_scorer
import random
import numpy as np
import torch

## Set up indexer

In [3]:
retriever = DPR_connector.DPR.from_config_file(
    "dpr", "kilt/configs/retriever/default_dpr.json"
)

[139629549371776] 2023-09-18 14:04:57,613 [INFO] root: Reading saved model from models/dpr_multi_set_hf_bert.0
[139629549371776] 2023-09-18 14:04:58,573 [INFO] root: model_state_dict keys odict_keys(['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset', 'epoch', 'encoder_params'])
[139629549371776] 2023-09-18 14:04:58,621 [INFO] dpr.models.hf_models: Initializing HF BERT Encoder. cfg_name=bert-base-uncased
Some weights of the model checkpoint at bert-base-uncased were not used when initializing HFBertEncoder: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing HFBertEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTrainin

## Setup dataset and get
- query
- golden passage titles
- retrieved passages
- answer

In [5]:
# ['Natural Questions', 'TriviaQA', 'FEVER']
task = 'Natural Questions'
dataset = tasks.RQA(task=task)
retriever.feed_data(dataset.query_data)
provenance = retriever.run()

[139629549371776] 2023-09-18 14:31:48,658 [INFO] root: Total encoded queries tensor torch.Size([2837, 768])
[139629549371776] 2023-09-18 14:59:25,898 [INFO] root: index search time: 1657.236124 sec.


In [7]:
query_data, validated_data, elements = \
    dataset.load_dataset()

In [192]:
indices = np.arange(len(elements))
random.shuffle(indices)
cal_indices = indices[:int(len(indices) * 0.5)]
test_indices = indices[int(len(indices) * 0.5):]

test_elements = utils.split(elements, test_indices)
elements = utils.split(elements, cal_indices)

In [193]:
queries = []
answers = []
retrieved_texts = []
retrieved_scores = []
for element in elements:
    query_id = element['id']
    query = element['input']
    answer = [ans['answer'] for ans in element['output'] if "answer" in ans]
    wiki_id = [[wiki['wikipedia_id'] for wiki in ans['provenance']] for ans in element['output'] if 'provenance' in ans]
    wiki_title = [[wiki['title'] for wiki in ans['provenance']] for ans in element['output'] if 'provenance' in ans]
    ids = []
    for id in wiki_id:
        ids.extend(id)
    retrieved = provenance[query_id]
    retrieved_id = [ans['wikipedia_id'] for ans in retrieved]
    retrieved_title = [ans['wikipedia_title'] for ans in retrieved]
    retrieved_text = [ans['text'] for ans in retrieved]
    convert = utils.convert_list_to_dict(retrieved)
    score = [convert[id] for id in convert if id in ids]
    if len(score) == 0:
        continue
    
    queries.append(query)
    answers.append(answer)
    retrieved_texts.append(retrieved_text)
    retrieved_scores.append(score)

## Setup semantic model

In [121]:
semantic = False
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"],
                                        use_stemmer=True)
if semantic:
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    # setup semantic model
    semantic_tokenizer = \
        AutoTokenizer.from_pretrained("microsoft/deberta-large-mnli")
    semantic_model = \
        AutoModelForSequenceClassification.from_pretrained(
            "microsoft/deberta-large-mnli"
        ).cuda()

[139629549371776] 2023-09-18 16:35:00,748 [INFO] absl: Using default tokenizer.


## Setup open source model

In [22]:
import opensource
import importlib
importlib.reload(opensource)
model, pipeline, tokenizer = opensource.setup_openmodel()

Loading checkpoint shards: 100%|██████████| 2/2 [00:35<00:00, 17.97s/it]
Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers
pip install xformers.


## Setup prompt and ask open source model

In [123]:
opensource_true_scores = []
with torch.no_grad():
    for idx, (query, answer, contexts, score) \
        in enumerate(zip(queries, answers, retrieved_texts, retrieved_scores)):
        
        if idx > 2:
            break
        
        true_scores_tmp = []
        for context, s in zip(contexts, score):
            prompt = utils.get_prompt_template(query, context, task='Natural Questions')
            sequences = opensource.ask_openmodel(prompt, pipeline, tokenizer)
            generated_texts = []
            for seq in sequences:
                generated_texts.append(seq['generated_text'][len(prompt):].strip())

            if semantic:
                semantic_set_ids, semantic_probs, item_occurance = \
                    utils.compute_semantic_clusterring(
                        semantic_model, 
                        semantic_tokenizer,
                        prompt,
                        generated_texts,
                    )
            else:
                semantic_set_ids, semantic_probs, item_occurance = \
                    utils.compute_keyword_clusterring(
                        generated_texts,
                        scorer
                    )
            true_scores, matched_answers = utils.processing_answers(
                semantic_set_ids, semantic_probs, 
                item_occurance, answer, scorer,
                threshold=0.3
            )
            true_scores_tmp.extend(true_scores)
        opensource_true_scores.append(true_scores_tmp)

## Setup chatgpt and ask chatgpt

In [124]:
utils.setup_openai()

In [125]:
chat = True
chatgpt_true_scores = []
for idx, (query, answer, contexts, score) \
    in enumerate(zip(queries, answers, retrieved_texts, retrieved_scores)):
    
    if idx > 2:
        break
    
    
    true_scores_tmp = []
    for context, s in zip(contexts, score):
        
        prompt = utils.get_prompt_template(query, context, task='Natural Questions')
        if chat:
            sequences = utils.ask_chatgpt(prompt)
        else:
            sequences, probs = utils.ask_chatgpt(prompt)
            
        if semantic:
            semantic_set_ids, semantic_probs, item_occurance = \
                utils.compute_semantic_clusterring(
                    semantic_model, 
                    semantic_tokenizer,
                    prompt,
                    sequences,
                )
        else:
            semantic_set_ids, semantic_probs, item_occurance = \
                utils.compute_keyword_clusterring(
                    sequences,
                    scorer
                )
        true_scores, matched_answer = utils.processing_answers(
            semantic_set_ids, semantic_probs, 
            item_occurance, answer, scorer,
            threshold=0.3
        )
        true_scores_tmp.extend(true_scores)
    chatgpt_true_scores.append(true_scores_tmp)

- retrieved_scores: true scores for retriever
- opensource_true_scores: true scores for open source model
- chatgpt_true_scores: true scores for chatgpt

In [140]:
import importlib
importlib.reload(utils)

<module 'utils' from '/home/lishuo1/retriever_uncertainty/TRAC/utils.py'>

## Compute threshold on calibration set

In [194]:
retrieved_threshold = utils.compute_threshold(retrieved_scores, alpha=0.025, shuffle=True)

Most relevant threshold: 69.74229
Most relevant coverage: 0.976


In [195]:
opensource_thr_qa = utils.compute_threshold(opensource_true_scores, alpha=0.025, shuffle=True)

Most relevant threshold: 0.4
Most relevant coverage: 0.7142857142857143


In [196]:
chatgpt_thr_qa = utils.compute_threshold(chatgpt_true_scores, alpha=0.025, shuffle=True)

Most relevant threshold: 1.0
Most relevant coverage: 1.0


## Evaluate thresholds on testing set

In [203]:
queries = []
answers = []
retrieved_texts = []
covered = []
for element in test_elements:
    query_id = element['id']
    query = element['input']
    answer = [ans['answer'] for ans in element['output'] if "answer" in ans]
    wiki_id = [[wiki['wikipedia_id'] for wiki in ans['provenance']] for ans in element['output'] if 'provenance' in ans]
    wiki_title = [[wiki['title'] for wiki in ans['provenance']] for ans in element['output'] if 'provenance' in ans]
    ids = []
    for id in wiki_id:
        ids.extend(id)
    all_retrieved = [r for r in provenance[query_id]]
    all_id = [ans['wikipedia_id'] for ans in all_retrieved]
    retrieved = [r for r in provenance[query_id] if float(r['score']) >= retrieved_threshold]
    retrieved_id = [ans['wikipedia_id'] for ans in retrieved]
    retrieved_title = [ans['wikipedia_title'] for ans in retrieved]
    retrieved_text = [ans['text'] for ans in retrieved]
    
    if len(utils.intersection(all_id, ids)) == 0:
        continue
    
    covered.append(len(utils.intersection(retrieved_id, ids))>=1)
    
    queries.append(query)
    answers.append(answer)
    retrieved_texts.append(retrieved_text)

In [204]:
print("coverage rate", np.mean(covered))

coverage rate 0.9984615384615385


In [205]:
opensource_covered = []
opensource_thr_qa = 0.5
with torch.no_grad():
    for idx, (query, answer, contexts) \
        in enumerate(zip(queries, answers, retrieved_texts)):
        
        if idx > 3:
            break
        
        cover = False
        for context, s in zip(contexts, score):
            prompt = utils.get_prompt_template(query, context, task='Natural Questions')
            sequences = opensource.ask_openmodel(prompt, pipeline, tokenizer)
            generated_texts = []
            for seq in sequences:
                generated_texts.append(seq['generated_text'][len(prompt):].strip())

            if semantic:
                semantic_set_ids, semantic_probs, item_occurance = \
                    utils.compute_semantic_clusterring(
                        semantic_model, 
                        semantic_tokenizer,
                        prompt,
                        generated_texts,
                    )
            else:
                semantic_set_ids, semantic_probs, item_occurance = \
                    utils.compute_keyword_clusterring(
                        generated_texts,
                        scorer
                    )
            true_scores, matched_answers = utils.processing_answers(
                semantic_set_ids, semantic_probs, 
                item_occurance, answer, scorer,
                threshold=0.3, thr_qa=opensource_thr_qa
            )
            if len(true_scores) >= 1:
                cover = True
        opensource_covered.append(cover)

KeyboardInterrupt: 

In [None]:
print("coverage rate", np.mean(opensource_covered))