## Import necessary packages

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0, 1, 2, 3"
from kilt import retrieval
from kilt import kilt_utils as utils
import tasks
import utils
from rouge_score import rouge_scorer
import random
import numpy as np
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from datasets import load_dataset
import json
from tqdm import tqdm
import opensource

## Set up the DPR retriever and Wikipedia passage index

In [2]:
# Everything in this notebook is inference, so autograd is switched off globally.
torch.set_grad_enabled(False)
# NOTE(review): ctx_encoder / ctx_tokenizer are not used by any later cell;
# retrieval below searches the 'embeddings' index that `wiki` already provides.
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-multiset-base", device_map='cuda')
# Tokenizers have no device placement, so no `device_map` is passed here.
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
# Wikipedia passages (psgs_w100, multiset config) searched with get_nearest_examples_batch.
wiki = load_dataset(path='wiki_dpr', name='psgs_w100.multiset.compressed', split='train')

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizer'.


In [3]:
# Dataset key: selects the RQA task and suffixes every JSON file written below.
task = 'nq'
dataset_dpr = tasks.RQA_dpr(task=task)

In [None]:
# Toggle for the optional DeBERTa-MNLI model (disabled; the model is not
# referenced by later cells in this notebook).
semantic = False

# ROUGE scorer passed to every utils.clustering / utils.processing_answers call.
scorer = rouge_scorer.RougeScorer(
    ["rouge1", "rouge2", "rougeL"], use_stemmer=True
)

if semantic:
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    SEMANTIC_MODEL_NAME = "microsoft/deberta-large-mnli"
    semantic_tokenizer = AutoTokenizer.from_pretrained(SEMANTIC_MODEL_NAME)
    semantic_model = AutoModelForSequenceClassification.from_pretrained(
        SEMANTIC_MODEL_NAME
    ).cuda()

## Collect data

In [4]:
# Seeded 50/50 split of question indices into calibration / test halves, so the
# split is identical on every Restart & Run All.
# NOTE(review): cal_indices / test_indices are not used before the split cell
# further down recomputes them; only `elements` and `query` feed the next cells.
SEED = 0
rng = np.random.default_rng(SEED)
indices = rng.permutation(len(dataset_dpr.elements))
n_cal = int(len(indices) * 0.5)
cal_indices = indices[:n_cal]
test_indices = indices[n_cal:]

elements = dataset_dpr.elements
# Question string of every element, in dataset order.
query = [element['question'] for element in elements]

In [5]:
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

# Question-side DPR encoder (multiset checkpoint, same family as the context encoder).
Q_ENCODER_NAME = "facebook/dpr-question_encoder-multiset-base"
q_encoder = DPRQuestionEncoder.from_pretrained(Q_ENCODER_NAME)
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(Q_ENCODER_NAME)

In [6]:
# Embed every question and fetch its 50 nearest passages from `wiki`'s
# 'embeddings' index.
question_inputs = q_tokenizer(query, return_tensors="pt", padding=True)
question_embedding = q_encoder(**question_inputs)[0].numpy()
scores, retrieved_examples = wiki.get_nearest_examples_batch(
    'embeddings', question_embedding, k=50
)

In [None]:
# Persist the raw retrieval results.
# `get_nearest_examples_batch` can return the scores (and, depending on the
# dataset format, embedding columns) as numpy objects, which `json` cannot
# serialize directly. Numpy arrays and numpy scalars both provide `.tolist()`.
with open(f'dpr_scores_{task}.json', 'w') as f:
    json.dump([np.asarray(s).tolist() for s in scores], f)
with open(f'dpr_retrieved_examples_{task}.json', 'w') as f:
    json.dump(retrieved_examples, f, default=lambda obj: obj.tolist())

## Set up the open-source model

In [None]:
# Load the open-source model; `pipeline` and `tokenizer` are what the
# generation calls below pass to `opensource.ask_openmodel`.
model, pipeline, tokenizer = opensource.setup_openmodel()

## Collect open-source model answers (gold and retrieved passages)

In [None]:
# For every question, sample answers from the open-source model conditioned on
# the gold passage and on each retrieved passage, and record the scores.
queries = []
answers = []
passages = []
retrieved_true_scores = []
opensource_true_scores = []
opensource_texts = []
opensource_answers = []
opensource_semantics = []
collect_errors = []  # (idx, repr(error)) for questions dropped after an exception


def _opensource_generate(prompt, top_k=30):
    """Sample ``top_k`` answers for ``prompt`` from the open-source model.

    Uses the notebook-level ``pipeline`` / ``tokenizer`` from
    ``opensource.setup_openmodel()``. Each ``generated_text`` starts with the
    prompt, which is sliced off before stripping whitespace.
    """
    sequences = opensource.ask_openmodel(prompt, pipeline, tokenizer, top_k=top_k)
    return [seq['generated_text'][len(prompt):].strip() for seq in sequences]


def _score_generations(generated, prompt, answer):
    """Cluster ``generated`` and score the clusters against ``answer``.

    Returns the ``(true_scores, matched_answer, semantics)`` triple produced by
    ``utils.processing_answers``.
    """
    semantic_set_ids, semantic_probs, item_occurance = \
        utils.clustering(generated, prompt, scorer=scorer)
    return utils.processing_answers(
        semantic_set_ids, semantic_probs,
        item_occurance, answer, scorer,
        threshold=0.3,
    )


with torch.no_grad():
    for idx, (element, score, retrieved) in enumerate(
            tqdm(zip(elements, scores, retrieved_examples), total=len(elements))):
        try:
            query, answer, passage_id, passage_title, passage_text = \
                utils.dataset_info(element)
            retrieved_ids, retrieved_texts, retrieved_title, true_score = \
                utils.retrieved_info(score, retrieved, passage_id[0])
            if len(true_score) == 0:
                continue

            # Answers conditioned on the gold passage.
            prompt = utils.get_prompt_template(query, passage_text[0], task='Natural Questions')
            generated_texts = _opensource_generate(prompt)
            true_scores = _score_generations(generated_texts, prompt, answer)[0]
            if len(true_scores) == 0:
                continue

            # Answers conditioned on each retrieved passage. Every context gets
            # its own fresh list; previously generations were appended to
            # `generated_texts`, so each context re-clustered all earlier
            # generations and the list stored in `opensource_texts` was mutated.
            answers_tmp = []
            semantics_tmp = []
            # zip keeps the original truncation to len(score).
            for context, _ctx_score in zip(retrieved_texts, score):
                ctx_prompt = utils.get_prompt_template(query, context, task='Natural Questions')
                ctx_texts = _opensource_generate(ctx_prompt)
                semantics = _score_generations(ctx_texts, ctx_prompt, answer)[2]
                answers_tmp.extend(ctx_texts)
                semantics_tmp.extend(semantics)
        except Exception as exc:  # keep collecting, but record what was dropped
            collect_errors.append((idx, repr(exc)))
            continue

        # Append only after the whole question succeeded, so all lists stay aligned.
        retrieved_true_scores.append(true_score)
        queries.append(query)
        answers.append(answer)
        passages.append(passage_text)
        opensource_true_scores.append(true_scores)
        opensource_texts.append(generated_texts)
        opensource_answers.append(answers_tmp)
        # One entry per question (was `extend`, which flattened semantics across
        # questions and broke alignment with `opensource_answers`).
        opensource_semantics.append(semantics_tmp)

print(f'collected {len(queries)} questions; {len(collect_errors)} dropped after errors')

## Save results

In [None]:
def _json_default(obj):
    """Fallback for ``json.dump``: convert numpy arrays/scalars to Python types.

    The collected lists may contain numpy values (e.g. retrieval scores), which
    the ``json`` module cannot serialize on its own.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


# Each collected list is written to `<name>_<task>.json`.
collected_results = {
    'retrieved_true_scores': retrieved_true_scores,
    'queries': queries,
    'answers': answers,
    'passages': passages,
    'opensource_true_scores': opensource_true_scores,
    'opensource_texts': opensource_texts,
    'opensource_answers': opensource_answers,
    'opensource_semantics': opensource_semantics,
}
for result_name, result_values in collected_results.items():
    with open(f'{result_name}_{task}.json', 'w') as f:
        json.dump(result_values, f, default=_json_default)


In [None]:
# Split question indices 50/50 into calibration and test sets.
# A fixed seed makes the split, and every threshold and coverage number computed
# from it, reproducible across kernel restarts.
SEED = 0
CAL_FRACTION = 0.5
rng = np.random.default_rng(SEED)
indices = rng.permutation(len(dataset_dpr.elements))
n_cal = int(len(indices) * CAL_FRACTION)
cal_indices = indices[:n_cal]
test_indices = indices[n_cal:]

test_elements = utils.split(dataset_dpr.elements, test_indices)
cal_elements = utils.split(dataset_dpr.elements, cal_indices)

test_query = [element['question'] for element in test_elements]
cal_query = [element['question'] for element in cal_elements]

In [None]:
# Re-load the DPR question encoder/tokenizer (same checkpoint as the earlier load).
from transformers import DPRQuestionEncoderTokenizer, DPRQuestionEncoder

_dpr_question_ckpt = "facebook/dpr-question_encoder-multiset-base"
q_encoder = DPRQuestionEncoder.from_pretrained(_dpr_question_ckpt)
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(_dpr_question_ckpt)

In [None]:
# Calibration retrieval: embed the calibration questions and fetch the 50
# nearest passages for each (this overwrites `scores` / `retrieved_examples`).
question_embedding = q_encoder(
    **q_tokenizer(cal_query, return_tensors="pt", padding=True)
)[0].numpy()
scores, retrieved_examples = wiki.get_nearest_examples_batch(
    'embeddings', question_embedding, k=50,
)

In [None]:
# Keep the calibration questions for which utils.retrieved_info returns a
# non-empty `true_score` (presumably the retrieval scores of the gold passage,
# TODO confirm), and record those scores for the retrieval threshold.
queries, answers, passages = [], [], []
retrieved_texts, retrieved_scores = [], []
retrieved_true_scores = []
for element, score, retrieved in zip(cal_elements, scores, retrieved_examples):
    query, answer, passage_id, passage_title, passage_text = \
        utils.dataset_info(element)
    retrieved_ids, retrieved_texts, retrieved_title, true_score = \
        utils.retrieved_info(score, retrieved, passage_id[0])
    if len(true_score) > 0:
        retrieved_true_scores.append(true_score)
        queries.append(query)
        answers.append(answer)
        passages.append(passage_text)

## Set up ChatGPT

In [None]:
# Configure the OpenAI client, presumably the one used by utils.ask (confirm;
# the API key should come from the environment, never be hardcoded).
utils.setup_openai()

In [None]:
# Calibration: ChatGPT answers from the gold passage for every calibration question.
chatgpt_true_scores = []
chat = True  # also read as a global by the ChatGPT test-time loop (utils.ask(..., chat, ...))
semantic = False
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"],
                                  use_stemmer=True)
with torch.no_grad():
    for idx, element in enumerate(tqdm(cal_elements)):
        query, answer, passage_id, passage_title, passage_text = \
            utils.dataset_info(element)

        # Use the same task prompt as at test time (`check_valid` passes
        # task='Natural Questions'). Otherwise the threshold calibrated from
        # these scores is applied to answers produced by a different prompt.
        sequences, prompt = utils.ask(query, passage_text[0], chat, task='Natural Questions')

        semantic_set_ids, semantic_probs, item_occurance = \
            utils.clustering(sequences, prompt, scorer=scorer)
        true_scores, matched_answer, semantics = utils.processing_answers(
            semantic_set_ids, semantic_probs,
            item_occurance, answer, scorer,
            threshold=0.3
        )
        chatgpt_true_scores.append(true_scores)

In [None]:
# ChatGPT answer-score threshold from the calibration scores.
# NOTE(review): alpha=0.05 here, while the open-source threshold uses
# alpha=0.1. Confirm that the different risk levels are intentional.
chatgpt_thr_qa = utils.compute_threshold(chatgpt_true_scores, alpha=0.05, shuffle=True)

## Set up the open-source model (calibration)

In [None]:
# NOTE(review): `opensource` is already imported in the first cell; Python
# caches modules, so this re-import is a no-op (use importlib.reload to pick
# up source edits).
import opensource

In [None]:
# NOTE(review): this repeats the earlier opensource.setup_openmodel() call.
# Loading a second copy while the first is still referenced may need extra
# GPU memory, so consider skipping it when the model is already loaded.
model, pipeline, tokenizer = opensource.setup_openmodel()

In [None]:
# Calibration: open-source model answers from the gold passage for every
# calibration question.
opensource_true_scores = []
opensource_cal_failures = 0  # generation errors are counted, not silently ignored
with torch.no_grad():
    for idx, element in enumerate(tqdm(cal_elements)):
        query, answer, passage_id, passage_title, passage_text = \
            utils.dataset_info(element)

        prompt = utils.get_prompt_template(query, passage_text[0], task='Natural Questions')
        try:
            sequences = opensource.ask_openmodel(prompt, pipeline, tokenizer, top_k=30)
        except Exception:
            # `except Exception` rather than a bare `except:`, so Ctrl-C still
            # interrupts the loop.
            opensource_cal_failures += 1
            continue
        # The pipeline output starts with the prompt; keep only the continuation.
        generated_texts = [seq['generated_text'][len(prompt):].strip() for seq in sequences]
        sequences = generated_texts

        semantic_set_ids, semantic_probs, item_occurance = \
            utils.clustering(sequences, prompt, scorer=scorer)
        true_scores, matched_answer, semantics = utils.processing_answers(
            semantic_set_ids, semantic_probs,
            item_occurance, answer, scorer,
            threshold=0.3
        )
        opensource_true_scores.append(true_scores)

print(f'{opensource_cal_failures} calibration questions skipped after generation errors')

In [None]:
# Retrieval-score threshold computed from the calibration questions'
# `retrieved_true_scores`; at test time, retrieved passages scoring below it
# are skipped.
retrieved_threshold = utils.compute_threshold(retrieved_true_scores, alpha=0.05, shuffle=True)

In [None]:
# Open-source answer-score threshold from the calibration scores.
# NOTE(review): alpha=0.1 differs from the alpha=0.05 used for ChatGPT and
# for the retrieval threshold. Confirm this is intentional.
opensource_thr_qa = utils.compute_threshold(opensource_true_scores, alpha=0.1, shuffle=True)

In [None]:
# Embed the test questions and retrieve their 50 nearest passages. Results are
# kept separate from the calibration `scores` / `retrieved_examples`.
test_question_batch = q_tokenizer(test_query, return_tensors="pt", padding=True)
question_embedding = q_encoder(**test_question_batch)[0].numpy()
test_scores, test_retrieved_examples = wiki.get_nearest_examples_batch(
    'embeddings', question_embedding, k=50
)

In [None]:
# Development leftover: re-executes utils.py so edits to it take effect
# without restarting the kernel.
import importlib
importlib.reload(utils)

In [None]:
def check_valid(element, score, retrieved, thr_qa, chat=True):
    """Check whether ChatGPT's gold-passage answers clear the answer threshold.

    Args:
        element: one test example, unpacked with ``utils.dataset_info``.
        score: retrieval scores for this example's retrieved passages.
        retrieved: retrieved passages for this example.
        thr_qa: answer-score threshold; the example counts as covered when at
            least one of its answer scores is ``>= thr_qa``.
        chat: forwarded to ``utils.ask``.

    Returns:
        ``(valid, covered)``:
        ``(False, True)`` when the example is skipped (empty ``true_score``
        from ``utils.retrieved_info``, or no scores from
        ``utils.processing_answers``); ``(True, False)`` when no answer score
        reaches ``thr_qa``; ``(True, True)`` otherwise.

    The function has no side effects. The caller records coverage; appending to
    ``chatgpt_covered`` here as well counted every miss twice.
    """
    query, answer, passage_id, passage_title, passage_text = \
        utils.dataset_info(element)
    retrieved_ids, retrieved_texts, retrieved_title, true_score = \
        utils.retrieved_info(score, retrieved, passage_id[0])

    if len(true_score) == 0:
        return False, True

    sequences, prompt = utils.ask(query, passage_text[0], chat, task='Natural Questions')
    semantic_set_ids, semantic_probs, item_occurance = \
        utils.clustering(sequences, prompt, scorer=scorer)
    true_scores, matched_answer, semantics = utils.processing_answers(
        semantic_set_ids, semantic_probs,
        item_occurance, answer, scorer,
        threshold=0.3
    )
    if len(true_scores) == 0:
        return False, True
    # Compare against the threshold passed in, not the global chatgpt_thr_qa.
    if not np.any(np.asarray(true_scores) >= thr_qa):
        return True, False
    return True, True

In [None]:
def check_valid_opensource(element, score, retrieved, thr_qa):
    """Check whether the open-source model's gold-passage answers clear ``thr_qa``.

    Returns ``(valid, covered)``:
      * ``(False, True)``: skip. Either ``utils.retrieved_info`` returned an
        empty ``true_score``, or ``utils.processing_answers`` returned no scores.
      * ``(True, False)``: no answer score reaches ``thr_qa``.
      * ``(True, True)``: at least one answer score is ``>= thr_qa``.

    Errors raised by ``opensource.ask_openmodel`` propagate to the caller.
    """
    query, answer, passage_id, _title, passage_text = utils.dataset_info(element)
    _ids, _texts, _titles, true_score = utils.retrieved_info(score, retrieved, passage_id[0])
    if not len(true_score):
        return False, True

    prompt = utils.get_prompt_template(query, passage_text[0], task='Natural Questions')
    raw_outputs = opensource.ask_openmodel(prompt, pipeline, tokenizer, top_k=30)
    # Each output echoes the prompt; keep only the stripped continuation.
    sequences = [item['generated_text'][len(prompt):].strip() for item in raw_outputs]

    set_ids, set_probs, occurrences = utils.clustering(sequences, prompt, scorer=scorer)
    true_scores = utils.processing_answers(
        set_ids, set_probs,
        occurrences, answer, scorer,
        threshold=0.3
    )[0]

    if not len(true_scores):
        return False, True
    reaches_threshold = np.sum(np.array(true_scores) >= thr_qa) != 0
    return True, bool(reaches_threshold)

In [None]:
# Test-time evaluation for ChatGPT. For each valid test question whose
# gold-passage answers clear `chatgpt_thr_qa`, query ChatGPT with the retrieved
# passages scoring at least `retrieved_threshold`. Stop at the first passage for
# which utils.processing_answers returns any scores.
MAX_TEST_IDX = 100  # only test questions with idx 0..MAX_TEST_IDX are evaluated
chatgpt_covered = []
chatgpt_sizes = []
queries = []  # NOTE: reuses the name; here it holds per-question ChatGPT query counts
with torch.no_grad():
    test_iter = zip(test_elements, test_scores, test_retrieved_examples)
    for idx, (element, score, retrieved) in enumerate(
            tqdm(test_iter, total=min(len(test_elements), MAX_TEST_IDX + 1))):
        if idx > MAX_TEST_IDX:
            break
        query, answer, passage_id, passage_title, passage_text = \
            utils.dataset_info(element)
        retrieved_ids, retrieved_texts, retrieved_title, true_score = \
            utils.retrieved_info(score, retrieved, passage_id[0])

        valid, covered = check_valid(element, score, retrieved, chatgpt_thr_qa)
        if not valid:
            continue
        if not covered:
            chatgpt_covered.append(False)
            continue

        cover = False
        tmp = []
        query_count = 0
        # The passages for this question are `retrieved_texts`; the original
        # zipped an undefined name `contexts` here and raised NameError.
        for context, s in zip(retrieved_texts, score):
            if s < retrieved_threshold:
                continue
            query_count += 1
            sequences, prompt = utils.ask(query, context, chat, task='Natural Questions')
            semantic_set_ids, semantic_probs, item_occurance = \
                utils.clustering(sequences, prompt, scorer=scorer)
            true_scores, matched_answer, semantics = utils.processing_answers(
                semantic_set_ids, semantic_probs,
                item_occurance, answer, scorer,
                threshold=0.3)

            tmp.extend(semantics)
            if len(true_scores) >= 1:
                cover = True
                break
        chatgpt_covered.append(cover)
        chatgpt_sizes.append(len(tmp))
        queries.append(query_count)

In [None]:
# ChatGPT test summary: fraction covered, mean number of collected semantics
# per question, and mean number of retrieved passages queried per question.
chatgpt_summary = (
    ('coverage', chatgpt_covered),
    ('average size', chatgpt_sizes),
    ('average query count', queries),
)
for label, values in chatgpt_summary:
    print(label, np.mean(values))

In [None]:
# Semantic-set size for ChatGPT: cluster each entry of `semantics` and count
# the resulting cluster ids.
# NOTE(review): `semantics` and `prompt` here are just the globals left over
# from the last utils.processing_answers call / prompt of the previous loop
# (last question, last context). They are not the per-question sets
# accumulated in `tmp`. Each `semantic_meaning` is passed to utils.clustering
# as a sequence of answers; confirm its type, because if it is a string,
# clustering iterates its characters.
chatgpt_semantic_sizes = []
for semantic_meaning in semantics:
    semantic_set_ids, semantic_probs, item_occurance = \
        utils.clustering(semantic_meaning, prompt, scorer=scorer)
    chatgpt_semantic_sizes.append(len(semantic_set_ids.keys()))

In [None]:
# Mean number of cluster ids per entry in chatgpt_semantic_sizes.
print('average semantic size', np.mean(chatgpt_semantic_sizes))

In [None]:
# Test-time evaluation for the open-source model (mirrors the ChatGPT cell).
MAX_TEST_IDX = 100  # only test questions with idx 0..MAX_TEST_IDX are evaluated
opensource_covered = []
opensource_sizes = []
queries = []  # per-question count of retrieved passages sent to the model
opensource_test_errors = []  # (idx, repr(error)) for questions dropped after an exception
with torch.no_grad():
    test_iter = zip(test_elements, test_scores, test_retrieved_examples)
    for idx, (element, score, retrieved) in enumerate(
            tqdm(test_iter, total=min(len(test_elements), MAX_TEST_IDX + 1))):
        if idx > MAX_TEST_IDX:
            break
        try:
            query, answer, passage_id, passage_title, passage_text = \
                utils.dataset_info(element)
            retrieved_ids, retrieved_texts, retrieved_title, true_score = \
                utils.retrieved_info(score, retrieved, passage_id[0])

            valid, covered = check_valid_opensource(element, score, retrieved, opensource_thr_qa)
            if not valid:
                continue
            if not covered:
                opensource_covered.append(False)
                continue

            cover = False
            tmp = []
            query_count = 0
            for context, s in zip(retrieved_texts, score):
                if s < retrieved_threshold:
                    continue
                query_count += 1
                prompt = utils.get_prompt_template(query, context, task='Natural Questions')
                sequences = opensource.ask_openmodel(prompt, pipeline, tokenizer, top_k=30)
                # Fresh list per context. Previously this appended to a stale
                # global `generated_texts`, so clustering also saw every earlier
                # question's and context's generations.
                generated_texts = [seq['generated_text'][len(prompt):].strip() for seq in sequences]
                sequences = generated_texts
                semantic_set_ids, semantic_probs, item_occurance = \
                    utils.clustering(sequences, prompt, scorer=scorer)
                true_scores, matched_answer, semantics = utils.processing_answers(
                    semantic_set_ids, semantic_probs,
                    item_occurance, answer, scorer,
                    threshold=0.3)

                tmp.extend(semantics)
                if len(true_scores) >= 1:
                    cover = True
                    break
        except Exception as exc:  # record instead of silently dropping the question
            opensource_test_errors.append((idx, repr(exc)))
            continue
        opensource_covered.append(cover)
        opensource_sizes.append(len(tmp))
        queries.append(query_count)

print(f'{len(opensource_test_errors)} test questions dropped after errors')

In [None]:
# Number of test questions that received a coverage decision for the open-source model.
len(opensource_covered)

In [None]:
# Open-source test summary: fraction covered, mean number of collected
# semantics per question, and mean number of retrieved passages queried.
opensource_summary = {
    'coverage': opensource_covered,
    'average size': opensource_sizes,
    'average query count': queries,
}
for label, values in opensource_summary.items():
    print(label, np.mean(values))

In [None]:
# Semantic-set size for the open-source model: cluster each entry of
# `semantics` and count the resulting cluster ids.
# NOTE(review): as in the ChatGPT cell, `semantics` and `prompt` are only the
# leftovers of the last utils.processing_answers call / prompt in the previous
# loop, not per-question collections. Confirm this is the intended input.
opensource_semantic_sizes = []
for semantic_meaning in semantics:
    semantic_set_ids, semantic_probs, item_occurance = \
        utils.clustering(semantic_meaning, prompt, scorer=scorer)
    opensource_semantic_sizes.append(len(semantic_set_ids.keys()))

In [None]:
# Mean number of cluster ids per entry in opensource_semantic_sizes.
print('average semantic size', np.mean(opensource_semantic_sizes))