In [None]:
import json
import pandas as pd

def read_json(file_path: str) -> "dict | list":
    """Read a JSON file and return the decoded top-level value.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        Whatever the document's top-level value decodes to: a dict for a
        JSON object, a list for a JSON array (``dev.json`` is sliced like a
        list by its caller in this notebook).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Decode explicitly as UTF-8 so non-ASCII text (e.g. accented titles)
    # does not depend on the platform's default locale encoding.
    with open(file_path, encoding="utf-8") as f:
        return json.load(f)

# Retriever output: iterated later as question id -> {corpus id: score}.
response = read_json('/kaggle/input/retriever-responses/cosine_sim_response.json')
# Corpus entries are looked up by id and expose 'title' and 'text'.
corpus = read_json('/kaggle/input/wiki-data/wiki_musique_corpus.json')
# Only the first 1200 dev examples are kept.
dev = read_json('/kaggle/input/wiki-data/dev.json')[:1200]

# Index the dev examples by `_id` so one example can be fetched with `.loc`.
dev_df = pd.DataFrame(dev).set_index('_id')

def get_ground_truth(idx, supporting_facts=False):
    """Look up one dev example and return its question, contexts and answer.

    Args:
        idx: Value of ``_id`` identifying the row in ``dev_df``.
        supporting_facts: When true, keep only the contexts whose title is
            listed (as first element) in the row's ``supporting_facts``;
            otherwise keep every context of the row.

    Returns:
        ``(question, oracle, answer)`` where ``oracle`` is a list of
        ``(title, text)`` tuples and ``text`` is the context's sentences
        joined with single spaces.
    """
    record = dev_df.loc[idx]

    # None means "no filtering"; a list restricts the contexts to those titles.
    gold_titles = None
    if supporting_facts:
        gold_titles = [fact_title for fact_title, _ in record['supporting_facts']]

    oracle = []
    for title, sentences in record['context']:
        if gold_titles is None or title in gold_titles:
            oracle.append((title, ' '.join(sentences)))

    return record['question'], oracle, record['answer']

def get_top_k_contexts(responses, k=5):
    """Return the first ``k`` retrieved contexts as ``(title, text)`` tuples.

    Args:
        responses: Mapping of corpus id -> score. Its iteration order is
            taken as the ranking; nothing is sorted here.
            NOTE(review): assumes the retriever stored the most similar
            ids first -- confirm against the code that wrote the file.
        k: Maximum number of contexts to return; a value <= 0 yields an
            empty list.

    Returns:
        A list of at most ``k`` ``(title, text)`` tuples read from the
        global ``corpus``, in the iteration order of ``responses``.
    """
    top_k_contexts = []
    if k <= 0:
        return top_k_contexts

    for idx, _score in responses.items():
        entry = corpus[idx]
        top_k_contexts.append((entry['title'], entry['text']))
        # Stop as soon as k contexts are collected instead of walking the
        # rest of the retrieved ids.
        if len(top_k_contexts) >= k:
            break
    return top_k_contexts

def create_prompt(contexts, question):
    """Render the contexts and the question as one prompt string.

    Args:
        contexts: Iterable of ``(title, text)`` pairs; they are numbered
            from 1 in the order given.
        question: Question text placed after the documents.

    Returns:
        A "Documents:" header, one "Document[n](Title: ...) ..." line per
        context, a blank line, then "Question: ...".
    """
    document_lines = [
        f"Document[{rank}](Title: {title}) {text}\n"
        for rank, (title, text) in enumerate(contexts, start=1)
    ]
    return "Documents:\n" + "".join(document_lines) + f"\nQuestion: {question}"

results = []
questions = []
for _id, similar_ctx_ids in response.items():
    question, oracle_ctxs, answer = get_ground_truth(_id, supporting_facts=True)
    questions.append(question)

    # One prompt from the supporting-fact contexts, then one per retrieval
    # depth (top 1, 3 and 5); the reference answer is stored alongside.
    results.append({
        "prompt_oracle": create_prompt(oracle_ctxs, question),
        "prompt_top1_similar": create_prompt(get_top_k_contexts(similar_ctx_ids, k=1), question),
        "prompt_top3_similar": create_prompt(get_top_k_contexts(similar_ctx_ids, k=3), question),
        "prompt_top5_similar": create_prompt(get_top_k_contexts(similar_ctx_ids, k=5), question),
        "answer": answer,
    })

# Persist every prompt record as indented JSON.
with open('/kaggle/working/prompts.json', 'w') as prompts_file:
    json.dump(results, prompts_file, indent=4)

In [None]:
from time import time
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, TextDataset, DataCollatorForLanguageModeling
from IPython.display import display, Markdown
from torch.utils.data import DataLoader

In [None]:
# Local checkpoint directory handed to the text-generation pipeline.
model = "/kaggle/input/llama-3/transformers/8b-chat-hf/1"

# Weights are loaded as float16; device placement is left to
# `device_map="auto"`.
pipeline_kwargs = dict(
    model=model,
    torch_dtype=torch.float16,
    device_map="auto",
)
pipeline = transformers.pipeline("text-generation", **pipeline_kwargs)

In [None]:
def query_model_batch(system_message, user_messages, temperature=0.7, max_length=1024):
    """Generate one completion per user message with the global ``pipeline``.

    Each user message gets " Answer:" appended, is paired with
    ``system_message`` and rendered through the tokenizer's chat template
    before being passed to the text-generation pipeline.

    Args:
        system_message: Content of the system turn shared by every prompt.
        user_messages: Iterable of user-turn strings, one per completion.
        temperature: Sampling temperature. Positive values (and ``None``)
            are forwarded with ``do_sample=True`` and ``top_p=0.9``. Zero
            or negative values select greedy decoding instead, because
            they are not valid sampling temperatures.
        max_length: Forwarded as ``max_new_tokens``.

    Returns:
        ``(responses, ttime)``: the generated strings in input order and a
        "Total time: ... sec." summary of the elapsed wall-clock time. An
        empty ``user_messages`` returns an empty list without calling the
        pipeline.
    """
    start_time = time()
    tokenizer = pipeline.tokenizer

    batch_prompts = [
        tokenizer.apply_chat_template(
            [
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message + " Answer:"},
            ],
            tokenize=False,
            add_generation_prompt=True,
        )
        for user_message in user_messages
    ]

    responses = []
    if batch_prompts:
        # Stop on the regular EOS token or on the chat end-of-turn token.
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        generation_kwargs = dict(
            eos_token_id=terminators,
            max_new_tokens=max_length,
            return_full_text=False,
            pad_token_id=pipeline.model.config.eos_token_id,
        )
        if temperature is not None and temperature <= 0:
            # Greedy decoding: sampling parameters are left out entirely.
            generation_kwargs["do_sample"] = False
        else:
            generation_kwargs.update(do_sample=True, top_p=0.9, temperature=temperature)

        # NOTE(review): no `batch_size` is passed, so the pipeline may still
        # run the prompts one at a time -- confirm before relying on
        # batched throughput.
        sequences = pipeline(batch_prompts, **generation_kwargs)
        responses = [seq[0]['generated_text'] for seq in sequences]

    end_time = time()
    ttime = f"Total time: {round(end_time-start_time, 2)} sec."

    return responses, ttime


In [None]:
def colorize_text(text):
    """Wrap known section labels in bold, coloured ``<font>`` markup.

    Every occurrence of "Documents:", "Question:", "Answer:" and
    "Total time:" is replaced, in that order, by two newlines followed by
    the label in bold with its colour; all other text is left untouched.
    """
    label_colors = {
        "Documents": "blue",
        "Question": "red",
        "Answer": "green",
        "Total time": "magenta",
    }
    for label, color in label_colors.items():
        text = text.replace(f"{label}:", f"\n\n**<font color='{color}'>{label}:</font>**")
    return text

In [None]:
# Split the per-question records into parallel lists, one per prompt variant
# plus the reference answers, all in the order of `results`.
prompts_oracle = [record['prompt_oracle'] for record in results]
prompts_top_1 = [record['prompt_top1_similar'] for record in results]
prompts_top_3 = [record['prompt_top3_similar'] for record in results]
prompts_top_5 = [record['prompt_top5_similar'] for record in results]
answers = [record['answer'] for record in results]


In [None]:
class CustomDataset:
    """Pairs each prompt with its reference answer.

    Implements ``__len__`` and index-based ``__getitem__``.
    """

    def __init__(self, prompts, answers):
        # Parallel sequences: item i of `prompts` goes with item i of `answers`.
        self.prompts, self.answers = prompts, answers

    def __len__(self):
        """Number of prompts; ``answers`` is not consulted."""
        return len(self.prompts)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict with 'prompt' and 'answer' keys."""
        return dict(prompt=self.prompts[idx], answer=self.answers[idx])

# One loader per prompt variant, each paired with the same `answers`;
# batches of 16, no shuffling.
oracle_dataloader, top1_dataloader, top3_dataloader, top5_dataloader = [
    DataLoader(CustomDataset(prompt_set, answers), batch_size=16, shuffle=False)
    for prompt_set in (prompts_oracle, prompts_top_1, prompts_top_3, prompts_top_5)
]
# `dataset` stays bound to the top-5 dataset, the last one created.
dataset = top5_dataloader.dataset

In [None]:
!pip install langchain

In [None]:
from langchain.evaluation import ExactMatchStringEvaluator

# Exact-match scorer configured to ignore letter case and punctuation.
evaluator = ExactMatchStringEvaluator(ignore_case=True, ignore_punctuation=True)

# System message sent with every prompt. It is spelled out line by line so
# the leading newline and the trailing space on the second sentence line
# stay visible.
premessage = (
    "\n"
    "You have to answer complex questions based on the provided contexts.\n"
    "Respond with as few tokens as possible. You don't need to explain your answer. \n"
    "Don't add any extra information.\n"
)

In [None]:
# Exact-match evaluation of the answers generated from the oracle prompts.
exact_match_result = 0
n_evaluated = 0
dictionary = {}  # prompt -> [generated answer, reference answer]

for batch in oracle_dataloader:
    batch_prompts = batch['prompt']
    batch_answers = batch['answer']
    responses, ttime = query_model_batch(premessage, batch_prompts, temperature=0.1, max_length=32)

    # Named `generated` rather than `response` so the retriever output
    # stored in `response` by the first cell is not overwritten.
    for prompt, generated, answer in zip(batch_prompts, responses, batch_answers):
        dictionary[prompt] = [generated, answer]
        ex = float(evaluator.evaluate_strings(prediction=generated, reference=answer)['score'])
        exact_match_result += ex
        n_evaluated += 1

# Average over the examples actually scored instead of a hard-coded 1200,
# so the figure stays correct when the dataset size differs.
print("Oracle: ", exact_match_result / n_evaluated if n_evaluated else 0.0)

with open('/kaggle/working/oracle_esit.json', 'w') as f:
    json.dump(dictionary, f, indent=4)


In [None]:
# Exact-match evaluation of the answers generated from the top-1 prompts.
exact_match_result = 0
n_evaluated = 0
dictionary = {}  # prompt -> [generated answer, reference answer]

for batch in top1_dataloader:
    batch_prompts = batch['prompt']
    batch_answers = batch['answer']
    responses, ttime = query_model_batch(premessage, batch_prompts, temperature=0.1, max_length=32)

    # Named `generated` rather than `response` so the retriever output
    # stored in `response` by the first cell is not overwritten.
    for prompt, generated, answer in zip(batch_prompts, responses, batch_answers):
        dictionary[prompt] = [generated, answer]
        ex = float(evaluator.evaluate_strings(prediction=generated, reference=answer)['score'])
        exact_match_result += ex
        n_evaluated += 1

# Average over the examples actually scored instead of a hard-coded 1200,
# so the figure stays correct when the dataset size differs.
print("Top-1: ", exact_match_result / n_evaluated if n_evaluated else 0.0)

with open('/kaggle/working/top1_esit.json', 'w') as f:
    json.dump(dictionary, f, indent=4)

In [None]:
# Exact-match evaluation of the answers generated from the top-3 prompts.
exact_match_result = 0
n_evaluated = 0
dictionary = {}  # prompt -> [generated answer, reference answer]

for batch in top3_dataloader:
    batch_prompts = batch['prompt']
    batch_answers = batch['answer']
    responses, ttime = query_model_batch(premessage, batch_prompts, temperature=0.1, max_length=32)

    # Named `generated` rather than `response` so the retriever output
    # stored in `response` by the first cell is not overwritten.
    for prompt, generated, answer in zip(batch_prompts, responses, batch_answers):
        dictionary[prompt] = [generated, answer]
        ex = float(evaluator.evaluate_strings(prediction=generated, reference=answer)['score'])
        exact_match_result += ex
        n_evaluated += 1

# Average over the examples actually scored instead of a hard-coded 1200,
# so the figure stays correct when the dataset size differs.
print("Top-3: ", exact_match_result / n_evaluated if n_evaluated else 0.0)

with open('/kaggle/working/top3_esit.json', 'w') as f:
    json.dump(dictionary, f, indent=4)

In [None]:
# Exact-match evaluation of the answers generated from the top-5 prompts.
exact_match_result = 0
n_evaluated = 0
dictionary = {}  # prompt -> [generated answer, reference answer]

for batch in top5_dataloader:
    batch_prompts = batch['prompt']
    batch_answers = batch['answer']
    responses, ttime = query_model_batch(premessage, batch_prompts, temperature=0.1, max_length=32)

    # Named `generated` rather than `response` so the retriever output
    # stored in `response` by the first cell is not overwritten.
    for prompt, generated, answer in zip(batch_prompts, responses, batch_answers):
        # This cell prints every generated/reference pair.
        print("Generated: ", generated)
        print("Correct: ", answer)

        dictionary[prompt] = [generated, answer]
        ex = float(evaluator.evaluate_strings(prediction=generated, reference=answer)['score'])
        exact_match_result += ex
        n_evaluated += 1

# Average over the examples actually scored instead of a hard-coded 1200,
# so the figure stays correct when the dataset size differs.
print("Top-5: ", exact_match_result / n_evaluated if n_evaluated else 0.0)

with open('/kaggle/working/top5_esit.json', 'w') as f:
    json.dump(dictionary, f, indent=4)