# Import necessary packages

In [16]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"]="0, 1, 2, 3"
from kilt import retrieval
from kilt import kilt_utils as utils
import tasks
import utils
from rouge_score import rouge_scorer
import random
import numpy as np
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from datasets import load_dataset
import json
from tqdm import tqdm
import opensource

# Set up indexer

In [None]:
# Inference only: switch autograd off globally so later cells can call
# `.numpy()` on model outputs without detaching them first.
torch.set_grad_enabled(False)
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-multiset-base", device_map='cuda')
# A tokenizer holds no weights, so `device_map` (a model-loading option) is
# not passed here; it had no effect on where tokenization runs.
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
# NOTE(review): the retrieval cell queries an index named 'embeddings' on this
# dataset; presumably this configuration ships with it attached -- confirm.
wiki = load_dataset(path='wiki_dpr', name='psgs_w100.multiset.compressed', split='train')

Some weights of the model checkpoint at facebook/dpr-ctx_encoder-multiset-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizer'.

In [11]:
# Name of the QA task; reused below for dataset parsing and output file names.
task = 'trivia'
dataset_dpr = tasks.RQA_dpr(
    task=task,
)

In [12]:
# `semantic` toggles loading of the optional NLI classifier; the ROUGE scorer
# is always created.
semantic = False
scorer = rouge_scorer.RougeScorer(
    ["rouge1", "rouge2", "rougeL"],
    use_stemmer=True,
)
if semantic:
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    # Tokenizer and classifier come from the same checkpoint; the model is
    # moved to the GPU.
    semantic_tokenizer = AutoTokenizer.from_pretrained(
        "microsoft/deberta-large-mnli"
    )
    semantic_model = AutoModelForSequenceClassification.from_pretrained(
        "microsoft/deberta-large-mnli"
    ).cuda()

# Collect data

In [13]:
# Seed the shuffle so the calibration/test split is identical after every
# kernel restart; a dedicated Generator leaves the global RNG state untouched.
SPLIT_SEED = 0
indices = np.random.default_rng(SPLIT_SEED).permutation(len(dataset_dpr.elements))

# First half of the shuffled indices is used for calibration, the rest for testing.
split_point = int(len(indices) * 0.5)
cal_indices = indices[:split_point]
test_indices = indices[split_point:]

elements = dataset_dpr.elements
# One question string per dataset element, in dataset order.
query = [element['question'] for element in elements]

In [14]:
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

# Question-side DPR encoder and its tokenizer, both loaded from the same checkpoint.
q_encoder = DPRQuestionEncoder.from_pretrained(
    "facebook/dpr-question_encoder-multiset-base"
)
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
    "facebook/dpr-question_encoder-multiset-base"
)

Some weights of the model checkpoint at facebook/dpr-question_encoder-multiset-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [15]:
# Encode the questions in mini-batches: tokenizing and encoding the whole
# dataset in a single forward pass can exhaust memory on large question sets.
ENCODE_BATCH_SIZE = 64
# no_grad makes `.numpy()` valid even if autograd was not disabled in an
# earlier cell.
with torch.no_grad():
    question_embedding = np.concatenate([
        q_encoder(
            **q_tokenizer(
                query[start:start + ENCODE_BATCH_SIZE],
                return_tensors="pt",
                padding=True,
            )
        )[0].numpy()
        for start in range(0, len(query), ENCODE_BATCH_SIZE)
    ])
# Retrieve the 20 nearest passages for every question from the 'embeddings' index.
scores, retrieved_examples = wiki.get_nearest_examples_batch('embeddings', question_embedding, k=20)

## Set up the open-source model

In [None]:
# Load the generator; the helper returns the model, a generation pipeline and
# the tokenizer as a 3-tuple.
model, pipeline, tokenizer = opensource.setup_openmodel(
    model='lmsys/vicuna-7b-v1.5-16k',
)

In [None]:
import pickle
def write_list(a_list, file_name):
    """Pickle ``a_list`` into ``file_name``, replacing any existing file.

    Args:
        a_list: Object to serialize (a list in this notebook).
        file_name: Destination path; opened in binary write mode.
    """
    with open(file_name, 'wb') as sink:
        pickle.dump(a_list, sink)
def read_list(file_name):
    """Load and return the object pickled in ``file_name``.

    Args:
        file_name: Path of a pickle file; opened in binary read mode.

    Returns:
        The unpickled object.
    """
    # Unpickling can execute arbitrary code: only load files this notebook wrote.
    with open(file_name, 'rb') as source:
        return pickle.load(source)

def save_results(task):
    """Checkpoint the notebook-level result lists to pickle files.

    Each list is written to ``uncertain_<list name>_<task>.p`` in the current
    working directory. The lists are read from the notebook's global
    namespace, so they must exist before this function is called.

    Args:
        task: Task name embedded in every output file name.
    """
    def _dump(values, stem):
        # One pickle file per result list.
        write_list(values, f'uncertain_{stem}_{task}.p')

    _dump(retrieved_scores, 'retrieved_scores')
    _dump(retrieved_true_scores, 'retrieved_true_scores')
    _dump(queries, 'queries')
    _dump(answers, 'answers')
    _dump(passages, 'passages')
    _dump(opensource_true_scores, 'opensource_true_scores')
    # opensource_texts is intentionally not written.
    _dump(opensource_answers, 'opensource_answers')
    _dump(opensource_semantics, 'opensource_semantics')
    _dump(occurances, 'occurances')
    _dump(semantic_ids, 'semantic_ids')
    _dump(probs, 'probs')

## Start collection

In [None]:
%%time
queries = []
answers = []
passages = []
retrieved_scores = []
retrieved_true_scores = []
opensource_true_scores = []
opensource_texts = []
opensource_answers = []
opensource_semantics = []
semantic_probs = []
feasibilities = []
occurances = []
semantic_ids = []
probs = []
with torch.no_grad():
    for idx, (element, score, retrieved) in enumerate(zip(elements, scores, retrieved_examples)):
        print(f'{idx}, {task}', file=open(f'uncertaint_{task}.txt', 'a'))
        feasible = False
        if idx % 10 == 0:
            print(idx)
            save_results(task)
        query, answer, passage_id, passage_title, passage_text = \
            utils.dataset_info(element, dataset=task)
        if len(passage_id) == 0:
            continue
        retrieved_ids, retrieved_texts, retrieved_title, true_score = \
            utils.retrieved_info(score, retrieved, passage_id[0])
        if len(true_score) == 0:
            continue
        
        prompt = utils.get_prompt_template(query, passage_text[0], task='Natural Questions')
        sequences = opensource.ask_openmodel(prompt, pipeline, tokenizer, return_sequences=30)
        generated_texts = []
        for seq in sequences:
            generated_texts.append(seq['generated_text'][len(prompt):].strip())
        semantic_set_ids, semantic_probs, item_occurance = \
            utils.clustering(generated_texts, prompt, scorer=scorer)
        true_scores, matched_answer, semantics = utils.processing_answers(
            semantic_set_ids, semantic_probs, 
            item_occurance, answer, scorer,
            threshold=0.3
        )
        if len(true_scores) == 0:
            retrieved_scores.append(score)
            retrieved_true_scores.append(true_score)
            queries.append(query)
            answers.append(answer)
            passages.append(passage_text)
            opensource_true_scores.append(true_scores)
            opensource_answers.append(generated_texts)
            occurances.append(item_occurance)
            semantic_ids.append(semantic_set_ids)
            probs.append(semantic_probs)