## Import necessary packages

In [None]:
from kilt import retrieval
from kilt import kilt_utils as utils
import tasks
from kilt.retrievers import DPR_connector
import utils
from rouge_score import rouge_scorer

In [203]:
import importlib
# Re-import the local utils.py so edits to it take effect without restarting
# the kernel. Note that ``utils`` is the local utils module here: the later
# ``import utils`` in the import cell rebinds the name and shadows
# ``from kilt import kilt_utils as utils``.
importlib.reload(utils)

<module 'utils' from '/home/lishuo1/retriever_uncertainty/TRAC/utils.py'>

## Set up CUDA

In [None]:
import os

# Expose only GPUs 4-7 to this process. CUDA reads this variable when it is
# initialized, so it only takes effect if set before the first CUDA call.
visible_gpus = ["4", "5", "6", "7"]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(visible_gpus)

## Set up the DPR retriever

In [None]:
# Build the DPR retriever from KILT's default DPR configuration file.
dpr_config_path = "kilt/configs/retriever/default_dpr.json"
retriever = DPR_connector.DPR.from_config_file("dpr", dpr_config_path)

## Set up the dataset and collect, per query:
- the query text
- the gold passage titles
- the retrieved passages
- the answers

In [None]:
# Task to evaluate; one of 'Natural Questions', 'TriviaQA', 'FEVER'.
task = 'Natural Questions'
dataset = tasks.RQA(task=task)
# Feed the dataset's queries to the retriever before running it.
retriever.feed_data(dataset.query_data)
# ``provenance`` is indexed by query id below; each entry is a list of
# retrieved passages with 'wikipedia_id', 'wikipedia_title' and 'text' keys.
provenance = retriever.run()

In [None]:
# ``elements`` holds one record per query with 'id', 'input' and 'output'
# keys; it is consumed by the next cell.
query_data, validated_data, elements = dataset.load_dataset()

In [None]:
# Keep only queries for which the retriever returned at least one gold
# (provenance) passage, and record for each such query: its text, its gold
# answers, the text of every retrieved passage, and the retriever scores of
# the retrieved passages that are gold.
queries = []
answers = []
retrieved_texts = []
scores = []
for element in elements:
    query_id = element['id']
    query = element['input']
    answer = [ans['answer'] for ans in element['output'] if "answer" in ans]
    # Flatten the Wikipedia ids of all gold provenance passages into a set:
    # O(1) membership tests below, and no loop variable named ``id`` that
    # would shadow the builtin in the notebook's global namespace.
    gold_ids = {
        wiki['wikipedia_id']
        for ans in element['output'] if 'provenance' in ans
        for wiki in ans['provenance']
    }
    retrieved = provenance[query_id]
    retrieved_text = [ans['text'] for ans in retrieved]
    # NOTE(review): assumes convert_list_to_dict maps wikipedia_id -> retriever
    # score for the retrieved passages — confirm in utils.py.
    convert = utils.convert_list_to_dict(retrieved)
    score = [convert[key] for key in convert if key in gold_ids]
    if not score:
        # No gold passage was retrieved for this query; skip it.
        continue

    queries.append(query)
    answers.append(answer)
    retrieved_texts.append(retrieved_text)
    scores.append(score)

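## Set up answer clustering
The `semantic` flag selects NLI-based semantic clustering; when it is `False`, generated answers are clustered by ROUGE keyword overlap instead.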

In [192]:
# False -> cluster answers with ROUGE keyword overlap (the scorer below);
# True  -> cluster with the NLI model loaded in the ChatGPT cell.
semantic = False
scorer = rouge_scorer.RougeScorer(
    rouge_types=["rouge1", "rouge2", "rougeL"],
    use_stemmer=True,
)

[140665223045504] 2023-09-18 13:12:48,250 [INFO] absl: Using default tokenizer.


## Set up the open-source model

In [None]:
import opensource

In [None]:
# Load the open-source generation pipeline and its tokenizer (defined in
# opensource.py); both are passed to opensource.ask_openmodel below.
pipeline, tokenizer = opensource.setup_openmodel()

In [204]:
# Pick up edits to opensource.py without restarting the kernel.
importlib.reload(opensource)

<module 'opensource' from '/home/lishuo1/retriever_uncertainty/TRAC/opensource.py'>

## Build prompts and query the open-source model

In [206]:
# Smoke test: query the open-source model on the first retrieved passage of
# the first query only (both ``break``s exit after a single iteration) and
# print how its sampled answers cluster.
for query, answer, contexts, score in zip(
    queries, answers, retrieved_texts, scores
):
    # NOTE(review): ``score`` holds the retriever scores of the gold passages
    # only, so pairing it index-by-index with ``contexts`` (all retrieved
    # passages) may misalign them, and zip stops after len(score) contexts.
    # Confirm this pairing is intended.
    for context, s in zip(contexts, score):
        # Use the task selected when the dataset was built rather than a
        # hard-coded name, so the prompt template matches the dataset.
        prompt = utils.get_prompt_template(query, context, task=task)
        sequences = opensource.ask_openmodel(prompt, pipeline, tokenizer)
        # Drop the prompt prefix from each generation; assumes
        # ``generated_text`` echoes the prompt verbatim — TODO confirm.
        generated_texts = [
            seq['generated_text'][len(prompt):].strip() for seq in sequences
        ]

        if semantic:
            # semantic_model / semantic_tokenizer are only created by the
            # ChatGPT cell (when ``semantic`` is True); run that cell first.
            semantic_set_ids, semantic_probs, item_occurance = \
                utils.compute_semantic_clusterring(
                    semantic_model,
                    semantic_tokenizer,
                    prompt,
                    generated_texts,
                )
        else:
            # Keyword clustering with the ROUGE scorer built earlier.
            semantic_set_ids, semantic_probs, item_occurance = \
                utils.compute_keyword_clusterring(generated_texts, scorer)
        print(semantic_set_ids)
        break
    break

{'therefore': 0, 'Therefore': 0, 'Triangle': 2, 'triangle': 2}


## Set up ChatGPT and query it

In [190]:
# Configure OpenAI access for utils.ask_chatgpt; presumably sets the API key /
# client from the environment — see utils.setup_openai to confirm.
utils.setup_openai()

In [191]:
# Smoke test: query ChatGPT on the first retrieved passage of the first query
# only (both ``break``s exit after a single iteration), print the sampled
# answers, then print how they cluster.
chat = True
if semantic:
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    # MNLI classifier consumed by utils.compute_semantic_clusterring.
    semantic_tokenizer = \
        AutoTokenizer.from_pretrained("microsoft/deberta-large-mnli")
    semantic_model = \
        AutoModelForSequenceClassification.from_pretrained(
            "microsoft/deberta-large-mnli"
        ).cuda()

for query, answer, contexts, score in zip(
    queries, answers, retrieved_texts, scores
):
    # NOTE(review): ``score`` holds the retriever scores of the gold passages
    # only, so pairing it index-by-index with ``contexts`` (all retrieved
    # passages) may misalign them, and zip stops after len(score) contexts.
    # Confirm this pairing is intended.
    for context, s in zip(contexts, score):
        # Use the task selected when the dataset was built rather than a
        # hard-coded name, so the prompt template matches the dataset.
        prompt = utils.get_prompt_template(query, context, task=task)
        if chat:
            sequences = utils.ask_chatgpt(prompt)
        else:
            # NOTE(review): this branch calls the same function as the chat
            # branch but unpacks two values (sequences, probs); it likely
            # should call a different (completion/logprob) helper — confirm.
            sequences, probs = utils.ask_chatgpt(prompt)

        for seq in sequences:
            print(seq)

        if semantic:
            semantic_set_ids, semantic_probs, item_occurance = \
                utils.compute_semantic_clusterring(
                    semantic_model,
                    semantic_tokenizer,
                    prompt,
                    sequences,
                )
        else:
            # Keyword clustering with the ROUGE scorer built earlier.
            semantic_set_ids, semantic_probs, item_occurance = \
                utils.compute_keyword_clusterring(sequences, scorer)
        print(semantic_set_ids)
        break
    break

[140665223045504] 2023-09-18 13:10:38,788 [INFO] absl: Using default tokenizer.


conclusion
consequence
conclusion
consequence
conclusion
{'conclusion': 0, 'consequence': 1}
