In [None]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from string import Template
from torch.utils.data import DataLoader

In [None]:
# Local Kaggle copy of the FLAN-T5 XL checkpoint.
llm = '/kaggle/input/flan-t5/pytorch/xl/3'

# Tokenizer and seq2seq model are loaded from the same checkpoint; weight
# placement is delegated to transformers via device_map="auto".
tokenizer = T5Tokenizer.from_pretrained(llm)
model = T5ForConditionalGeneration.from_pretrained(llm, device_map="auto")

# generate() moves the tokenized inputs onto this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

In [None]:
def generate(prompt):
    """Run the global seq2seq `model` on a single prompt string.

    Args:
        prompt: full text fed to the tokenizer.

    Returns:
        The list produced by `tokenizer.batch_decode` on the generated ids,
        with special tokens stripped (callers take element 0).
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    encoded = encoded.to(device)
    generated_ids = model.generate(**encoded)
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)


def query_model_batch(premessage, batch_prompts):
    """Call `generate` once per prompt, prefixing each prompt with `premessage`.

    Args:
        premessage: text prepended to every prompt.
        batch_prompts: iterable of prompt strings.

    Returns:
        A list with one `generate` result per prompt, in input order.
    """
    return [generate(premessage + prompt) for prompt in batch_prompts]

In [None]:
class CustomDataset:
    """Map-style dataset pairing each prompt with its gold answer.

    It provides `__len__` and `__getitem__`, which is what the DataLoader
    further down uses to batch it.
    """

    def __init__(self, prompts, answers):
        # Two parallel sequences; the dataset length is taken from `prompts`.
        self.prompts, self.answers = prompts, answers

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return dict(prompt=self.prompts[idx], answer=self.answers[idx])
    

In [None]:
# Instruction text placed in front of each question prompt before it is given
# to the model (it is concatenated with the retrieved prompts further down).
premessage = """
You have to answer complex questions based on the provided contexts.
Respond with as few tokens as possible. You don't need to explain your answer. 
Don't add any extra information.
"""

import json
import pandas as pd

def read_json(file_path: str) -> "dict | list":
    """Load and return the JSON document stored at `file_path`.

    The top-level value is returned as parsed. In this notebook, the
    retrieval-response files are used as dicts, while dev.json is sliced
    like a list.

    Args:
        file_path: path to a UTF-8 encoded JSON file.

    Returns:
        The parsed top-level JSON value.
    """
    # JSON text is UTF-8 (RFC 8259). Don't fall back to the platform/locale
    # default encoding, which breaks on non-ASCII titles on some systems.
    with open(file_path, encoding="utf-8") as f:
        return json.load(f)

# Retrieval outputs: one file per prompt configuration. Each maps a dev
# question id to an entry holding a ready-made 'prompt'.
RETRIEVAL_DIR = '/kaggle/input/d/thanosgeorgoutsos/rag-retrieve-responses'
WIKI_DIR = '/kaggle/input/wiki-data'

oracle = read_json(f'{RETRIEVAL_DIR}/oracle.json')
oracle_random = read_json(f'{RETRIEVAL_DIR}/oracle_random_3.json')
oracle_adore = read_json(f'{RETRIEVAL_DIR}/oracle_adore_3.json')
top3 = read_json(f'{RETRIEVAL_DIR}/top3_3.json')
top3_random = read_json(f'{RETRIEVAL_DIR}/top3_random_3.json')
top3_adore = read_json(f'{RETRIEVAL_DIR}/top3_adore_3.json')
response = read_json('/kaggle/input/r-and-h/random_contexts.json')

# Document corpus (indexed by id in get_top_k_contexts) and the dev split,
# truncated to its first 1200 questions.
corpus = read_json(f'{WIKI_DIR}/wiki_musique_corpus.json')
dev = read_json(f'{WIKI_DIR}/dev.json')[:1200]

# Index the dev questions by '_id' so get_ground_truth can fetch a row with .loc.
# verify_integrity=True fails fast on duplicate ids, which would otherwise make
# .loc return a multi-row DataFrame instead of a single row.
# Built with a chained call rather than inplace mutation, so re-running the cell is safe.
dev_df = pd.DataFrame(dev).set_index('_id', verify_integrity=True)

def get_ground_truth(idx, supporting_facts=False):
    """Look up the question, contexts and gold answer of dev question `idx`.

    Args:
        idx: an `_id` value (the index of `dev_df`).
        supporting_facts: if True, keep only the contexts whose title occurs
            in the row's supporting facts; otherwise keep every context.

    Returns:
        (question, contexts, answer), where contexts is a list of
        (title, text) tuples and text is the context's sentences joined by
        single spaces.
    """
    row = dev_df.loc[idx]
    if supporting_facts:
        wanted_titles = {title for title, _ in row['supporting_facts']}
        kept = [(title, texts) for title, texts in row['context'] if title in wanted_titles]
    else:
        kept = row['context']
    contexts = [(title, ' '.join(texts)) for title, texts in kept]
    return row['question'], contexts, row['answer']

def get_top_k_contexts(responses, k=5):
    """Return the first `k` documents of `responses` as (title, text) tuples.

    Args:
        responses: mapping whose keys are corpus document ids. The values are
            ignored (presumably similarity scores). Iteration order decides
            which documents are kept, so this assumes the mapping is already
            sorted best-first — TODO confirm with the retriever output.
        k: maximum number of documents to return.

    Returns:
        A list of (title, text) tuples looked up in the global `corpus`.
    """
    return [
        (corpus[doc_id]['title'], corpus[doc_id]['text'])
        for rank, (doc_id, _score) in enumerate(responses.items())
        if rank < k
    ]

def create_prompt(contexts, question):
    """Render (title, text) documents followed by a question as one prompt.

    The output looks like:
        Documents:
        Document[1](Title: <title>) <text>
        ...

        Question: <question>
    """
    document_lines = "".join(
        f"Document[{number}](Title: {title}) {text}\n"
        for number, (title, text) in enumerate(contexts, start=1)
    )
    return "Documents:\n" + document_lines + f"\nQuestion: {question}"

In [None]:
# Only the first N prompts of every configuration are sent to the model.
N_EVAL_PROMPTS = 100


def _build_prompt_records(retrieved, prompt_key, out_path):
    """Prefix each retrieved prompt with `premessage`, pair it with its answer, and save.

    Replaces six copy-pasted loops that differed only in the source mapping,
    the record key and the output path.

    Args:
        retrieved: mapping of dev question id -> entry holding a ready-made
            'prompt' string (one of the retrieval-response files).
        prompt_key: key under which the prefixed prompt is stored in each record.
        out_path: JSON file the records are written to (indent=4).

    Returns:
        A list of {prompt_key: prompt, "answer": gold_answer} dicts, in the
        iteration order of `retrieved`.
    """
    records = []
    for qid, entry in retrieved.items():
        # Only the gold answer is needed; the contexts are already inside entry['prompt'].
        _, _, answer = get_ground_truth(qid)
        records.append({prompt_key: premessage + entry['prompt'], "answer": answer})
    with open(out_path, 'w') as f:
        json.dump(records, f, indent=4)
    return records


oracle_list = _build_prompt_records(
    oracle, 'prompt_oracle', '/kaggle/working/prompts_oracle.json')
oracle_random_list = _build_prompt_records(
    oracle_random, 'prompt_oracle_random', '/kaggle/working/prompts_oracle_random.json')
oracle_adore_list = _build_prompt_records(
    oracle_adore, 'prompt_oracle_adore', '/kaggle/working/prompts_oracle_adore.json')
top3_list = _build_prompt_records(
    top3, 'prompt_top3', '/kaggle/working/prompts_top3.json')
top3_random_list = _build_prompt_records(
    top3_random, 'prompt_top3_random', '/kaggle/working/prompts_top3_random.json')
top3_adore_list = _build_prompt_records(
    top3_adore, 'prompt_top3_adore', '/kaggle/working/prompts_top3_adore.json')

# Dev questions and gold answers, in the iteration order of the oracle file.
questions = [get_ground_truth(qid)[0] for qid in oracle]
answers = [rec['answer'] for rec in oracle_list]

prompts_oracle = [rec['prompt_oracle'] for rec in oracle_list[:N_EVAL_PROMPTS]]
prompts_oracle_random = [rec['prompt_oracle_random'] for rec in oracle_random_list[:N_EVAL_PROMPTS]]
prompts_oracle_adore = [rec['prompt_oracle_adore'] for rec in oracle_adore_list[:N_EVAL_PROMPTS]]
prompts_top3 = [rec['prompt_top3'] for rec in top3_list[:N_EVAL_PROMPTS]]
prompts_top3_random = [rec['prompt_top3_random'] for rec in top3_random_list[:N_EVAL_PROMPTS]]
prompts_top3_adore = [rec['prompt_top3_adore'] for rec in top3_adore_list[:N_EVAL_PROMPTS]]

In [None]:
titles = ['oracle', 'oracle_random', 'oracle_adore', 'top3', 'top3_random', 'top3_adore']
prs = [prompts_oracle, prompts_oracle_random, prompts_oracle_adore, prompts_top3, prompts_top3_random, prompts_top3_adore]
# Full record lists aligned with `titles`/`prs`: prs[j][i] was built from record_lists[j][i].
record_lists = [oracle_list, oracle_random_list, oracle_adore_list, top3_list, top3_random_list, top3_adore_list]

for title, prompts, records in zip(titles, prs, record_lists):
    # Pair every prompt with the answer stored in its own record. Reusing the
    # oracle answers for every configuration mislabels results whenever a
    # retrieval file lists the questions in a different order.
    gold_answers = [rec['answer'] for rec in records[:len(prompts)]]
    loader = DataLoader(CustomDataset(prompts, gold_answers), batch_size=16, shuffle=False)

    predictions = {}
    counter = 0
    for batch in loader:
        batch_prompts = batch['prompt']
        batch_answers = batch['answer']
        # The stored prompts already begin with `premessage` (added when they
        # were built), so pass an empty prefix. This avoids sending the
        # instructions to the model twice.
        batch_outputs = query_model_batch("", batch_prompts)
        counter += len(batch_prompts)
        print(f"{title}: {counter}/{len(prompts)}")

        for prompt, decoded, gold in zip(batch_prompts, batch_outputs, batch_answers):
            # generate() returns a list of decoded sequences; keep the first one.
            # NOTE: identical prompts share one key, so later ones overwrite earlier ones.
            predictions[prompt] = [decoded[0], gold]

    with open(f'/kaggle/working/{title}.json', 'w') as f:
        json.dump(predictions, f, indent=4)
    print("Done!")