In [24]:
import torch

# Report the CUDA setup. The device name is only queried when at least one GPU is
# visible, because torch.cuda.get_device_name(0) raises on CPU-only machines.
print(torch.cuda.is_available())
print(torch.cuda.device_count())
if torch.cuda.device_count() > 0:
    print(torch.cuda.get_device_name(0))

True
1
NVIDIA GeForce RTX 3080


In [25]:
from datasets import load_dataset

dataset = load_dataset('florin-hf/nq_open_gold')

# Only the test split is used in this cell: keep its questions and show the first example.
test_split = dataset['test']
test_questions = test_split['question']
example = test_split[0]
example

{'question': 'who got the first nobel prize in physics',
 'idx_gold_in_corpus': 20994698,
 'answers': ['Wilhelm Conrad Röntgen'],
 'text': 'The first Nobel Prize in Physics was awarded in 1901 to Wilhelm Conrad Röntgen , of Germany , who received 150,782 SEK , which is equal to 7,731,004 SEK in December 2007 . John Bardeen is the only laureate to win the prize twice -- in 1956 and 1972 . Maria Skłodowska - Curie also won two Nobel Prizes , for physics in 1903 and chemistry in 1911 . William Lawrence Bragg was , until October 2014 , the youngest ever Nobel laureate ; he won the prize in 1915 at the age of 25 . Two women have won the prize : Curie and Maria Goeppert - Mayer ( 1963 ) . As of 2017 , the prize has been awarded to 206 individuals . There have been six years in which the Nobel Prize in Physics was not awarded ( 1916 , 1931 , 1934 , 1940 -- 1942 ) .',
 'example_id': -3290814144789249484}

In [26]:
import json
import random
import hashlib
from typing import List, Tuple, Dict, Any, Optional
from torch.utils.data import Dataset
from transformers import AutoTokenizer

import normalize_text
from normalize_answers import *

class QueryDataset(Dataset):
    """
    A dataset class for managing queries data into structured prompts suitable for input to LLMs.

    Each item is a dict with the example id, the (optionally normalized) question and a
    closed-book prompt that contains no documents.

    Attributes:
        data_path (str): Path to a JSON file holding a list of examples; each example needs
            'example_id' and either 'query' or 'question'.
        model_name (str): The name of the language model used for generating answers; names
            containing 'mpt' get the MPT instruction template.
        do_normalize_query (bool): Flag to determine if text normalization is applied to the query.
    """
    def __init__(
        self, 
        data_path: str, 
        model_name: str,
        do_normalize_query: bool = False,
    ):
        super().__init__()
        self.data_path = data_path
        self.model_name = model_name
        self.do_normalize_query = do_normalize_query
        self._load_data()


    def _load_data(self):
        """
        Loads data from the specified path and processes it.

        If the file cannot be read, the error is printed and the dataset is left empty.
        Previously the attributes were never created in that case, so len()/indexing
        raised AttributeError.
        """
        try:
            with open(self.data_path, "r") as fin:
                data = json.load(fin)
        except IOError as e:
            print(f"Error reading file {self.data_path}: {e}")
            data = []  # keep the dataset usable (empty) rather than half-initialised
        self.process_file_data(data)


    def process_file_data(self, data: List[Dict]):
        """
        Processes each example in the dataset to prepare prompts for the LLM.

        Args:
            data: List of example dicts. The question is read from 'query' if present,
                otherwise from 'question'.

        Raises:
            ValueError: if an example has neither a 'query' nor a 'question' key.
        """
        self.questions = []
        self.example_ids = []

        for example in data:
            self.example_ids.append(example['example_id'])

            if 'query' in example:
                question = example['query']
            elif 'question' in example:
                question = example['question']
            else:
                raise ValueError("No 'query' or 'question' key in example")
            
            if self.do_normalize_query:
                question = normalize_text.normalize(question)
            self.questions.append(question)


    def build_qa_prompt(self, query: str) -> str:
        """
        Build the closed-book QA prompt for ``query``.

        For models whose name contains 'mpt', the plain prompt (minus its trailing
        "\\nAnswer:") is wrapped in the MPT instruction/response template.
        """
        # NOTE(review): the instruction mentions "provided documents" although this
        # closed-book prompt includes none; kept verbatim so results stay comparable.
        task_instruction = "You are given a question and you must respond based on the provided documents. You must always provide an answer."
        prompt = f"""{task_instruction}\nQuestion: {query}\nAnswer:"""
        
        # Custom prompt format for mpt models
        if 'mpt' in self.model_name:
            INSTRUCTION_KEY = "### Instruction:"
            RESPONSE_KEY = "### Response:"
            INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
            PROMPT_FOR_GENERATION_FORMAT = """{intro}\n{instruction_key}\n{instruction}\n{response_key}""".format(
                intro=INTRO_BLURB,
                instruction_key=INSTRUCTION_KEY,
                instruction="{instruction}",
                response_key=RESPONSE_KEY,
            )
            # prompt[:-8] strips the trailing "\nAnswer:" (8 characters)
            prompt = PROMPT_FOR_GENERATION_FORMAT.format(
                instruction=prompt[:-8]
            )

        return prompt


    def __getitem__(self, idx: int):
        """Return the example id, the query and its prompt for position ``idx``."""
        prompt = self.build_qa_prompt(self.questions[idx])

        return {
            "example_id": self.example_ids[idx],
            "query": self.questions[idx],
            "prompt": prompt,
        }

    def __len__(self):
        return len(self.example_ids)


def hash_document(text: str) -> str:
    """
    Return the hex-encoded SHA-256 digest of ``text`` (encoded with ``str.encode()``'s default, UTF-8).

    The dataset classes use it to detect duplicate document texts.
    """
    digest = hashlib.sha256()
    digest.update(text.encode())
    return digest.hexdigest()

In [27]:
class PromptDataset(Dataset):
    """
    A dataset class for managing, preprocessing, and organizing document data into structured prompts suitable for input to LLMs.

    Every example of the JSON file at ``data_path`` is turned into a prompt at construction
    time. Examples whose tokenized prompt reaches ``max_tokenized_length`` are dropped and
    recorded in ``excluded_samples_ids``.

    Attributes:
        corpus (List[Dict]): The document corpus; each entry is read for 'full_corpus_idx', 'title' and 'text'.
        data_path (str): Path to the dataset file containing the query and related information.
        tokenizer (AutoTokenizer): The tokenizer used to tokenize the prompt, in order to check its tokenized length.
        max_tokenized_length (int): Exclusive upper bound on the tokenized prompt length; prompts with this many tokens or more are excluded.
        search_results (List[Tuple[List[int], List[float]]]): Per-example (document indices, scores) pairs, indexed by the
            example's position in ``data_path``. The results may come from a retriever.
        full_to_subset_idx_map (Dict[int, int]): Maps indices of the full corpus to positions in ``corpus``; None when ``corpus`` is the full corpus.
        do_normalize_query (bool): Flag to determine if text normalization is applied to the query.
        num_documents_in_context (int): The total number of documents to consider in the context.
        gold_position (int): The specific position (0-indexed) of the gold document in the context.
        randomize_gold_position (bool): Flag to determine if the gold document position should be random.
        get_documents_without_answer (bool): If True, only retrieved documents that contain none of the answers are used,
            plus the gold document when a gold position is set or randomized.
    """
    def __init__(
        self, 
        corpus: List[Dict],
        data_path: str,  
        tokenizer: AutoTokenizer,
        max_tokenized_length: int,
        search_results: List[Tuple[List[int], List[float]]],
        full_to_subset_idx_map: Dict[int, int] = None,
        do_normalize_query: bool = False,
        num_documents_in_context: int = 5,
        gold_position: int = None,
        randomize_gold_position: bool = False,
        get_documents_without_answer: bool = False,
    ):
        super().__init__()
        self.corpus = corpus
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_tokenized_length = max_tokenized_length
        self.search_results = search_results
        self.full_to_subset_idx_map = full_to_subset_idx_map
        self.do_normalize_query = do_normalize_query
        self.num_documents_in_context = num_documents_in_context
        self.gold_position = gold_position
        self.randomize_gold_position = randomize_gold_position
        self.get_documents_without_answer = get_documents_without_answer

        self._validate_initialization_parameters()
        self._load_data()


    def _validate_initialization_parameters(self):
        """
        Validates initialization parameters for logical consistency and correctness.

        Raises:
            ValueError: on a non-positive document count or token budget, a gold position
                outside [0, num_documents_in_context), or when both a fixed and a
                randomized gold position are requested.
        """
        if self.num_documents_in_context <= 0:
            raise ValueError("num_documents_in_context must be positive.")
        
        if self.max_tokenized_length <= 0:
            raise ValueError("max_tokenized_length must be positive.")

        if self.gold_position is not None:
            if self.gold_position < 0 or (self.gold_position >= self.num_documents_in_context):
                raise ValueError(f"Invalid gold position: {self.gold_position}")
        
        if self.gold_position is not None and self.randomize_gold_position:
            raise ValueError("Both 'gold_position' and 'randomize_gold_position' cannot be set at the same time.")


    def _load_data(self):
        """
        Loads data from the specified path and processes it.

        If the file cannot be read, the error is printed and the dataset is left empty.
        Previously its attributes were never created in that case, so len()/indexing
        raised AttributeError.
        """
        try:
            with open(self.data_path, "r") as fin:
                data = json.load(fin)
        except IOError as e:
            print(f"Error reading file {self.data_path}: {e}")
            data = []  # keep the dataset usable (empty) rather than half-initialised
        self.process_file_data(data)


    def process_file_data(self, data: List[Dict]):  
        """
        Processes each example in the dataset to prepare prompts for the LLM.

        This involves assembling document contexts, normalizing text as needed,
        and checking against the maximum token length to ensure compatibility with the LLM's input specifications.

        Args:
            data (List[Dict]): The dataset, where each entry contains information about an example,
            including the example's ID, the gold document index ('idx_gold_in_corpus'), answers, and the question.
        """
        self.example_ids = []
        self.queries = []
        self.prompts = []
        self.gold_document_idxs = []
        self.excluded_samples_ids = []
        self.preprocessed_data = []
        self.prompt_tokens_lengths = []

        for idx, example in enumerate(data):
            example_id = str(example['example_id'])
            # Kept as a string; the document helpers compare/convert it explicitly.
            gold_document_idx = str(example['idx_gold_in_corpus'])
            answers = example['answers']

            formatted_documents, document_indices = self.prepare_documents_for_prompt(
                idx, gold_document_idx, answers
            )

            # Build the prompt
            documents_str = '\n'.join(formatted_documents)
            query = example['question']
            if self.do_normalize_query:
                query = normalize_text.normalize(query)
            prompt = self.build_qa_prompt(query, documents_str)

            # Check if the prompt exceeds 'max_tokenized_length'
            tokens = self.tokenizer.tokenize(prompt)
            tokens_len = len(tokens)
            if tokens_len >= self.max_tokenized_length:
                self.excluded_samples_ids.append((idx, example_id))
                print("Skipping example {} due to prompt length.".format((idx, example_id)))
                continue  # Skip adding this example

            if len(formatted_documents) != self.num_documents_in_context:
                print(f"Warning: Not enough documents for example {idx}.")

            # If prompt is within limit, add to preprocessed data
            self.preprocessed_data.append((formatted_documents, list(document_indices)))
            self.example_ids.append(example_id)
            self.queries.append(query)
            self.prompts.append(prompt)
            self.gold_document_idxs.append(gold_document_idx)
            self.prompt_tokens_lengths.append(tokens_len)


    def prepare_documents_for_prompt(
        self, 
        example_idx: int, 
        gold_document_idx: int, 
        answers: List[str]
    ) -> Tuple[List[str], List[int]]:
        """
        Prepares and formats a set of documents for inclusion in a prompt, including the insertion of a gold document at the appropriate position.

        Steps:
        1. Retrieves document indices based on the example index.
        2. Inserts the gold document index into the retrieved list of indices at a specified or randomized position, if configured.
        3. Formats the documents corresponding to the updated list of indices, optionally keeping only documents without the answer.

        Args:
            example_idx (int): The index of the current example in the dataset, used to look up its search results.
            gold_document_idx (int | str): The index of the gold document within the full corpus
                (``process_file_data`` passes it as a string).
            answers (List[str]): Answers used to filter out answer-bearing documents when
                ``get_documents_without_answer`` is set.

        Returns:
            A tuple containing two lists:
            - The first list contains the formatted documents.
            - The second list contains the full-corpus indices of the included documents.
        """
        indices = self._get_indices(example_idx)
        updated_indices, gold_position = self._insert_gold_document_idx(
            indices, gold_document_idx
        )

        # Get the documents and their indices in the corpus
        formatted_documents, document_indices = self._get_documents(
            updated_indices, answers, gold_document_idx, gold_position
        )
        return formatted_documents, document_indices


    def _get_indices(self, example_idx: int) -> List[int]:
        """ Get the indices in the corpus of the documents retrieved possibly by a retriever. """
        indices, scores = self.search_results[example_idx]
        return indices


    def _insert_gold_document_idx(
        self, 
        indices: List[int], 
        gold_document_idx: int
    ) -> Tuple[List[int], int]:
        """
        Inserts the index of a gold document into the provided list of indices at a specified or random position.

        Args:
            indices: Retrieved document indices (any sliceable sequence, e.g. a list or array).
            gold_document_idx: The index of the gold document to insert.

        Returns:
            A tuple containing:
            - The updated list of indices with the gold document index inserted
              (the input is returned unchanged when no gold position is configured).
            - The position at which the gold document index was inserted, or None.
        """
        if self.gold_position is not None:
            # Direct insertion
            gold_position = self.gold_position
        elif self.randomize_gold_position:
            # Insert at a random position
            gold_position = random.randint(0, self.num_documents_in_context - 1)
        else:
            return indices, None

        # list() so that array-like search results concatenate instead of broadcasting.
        indices = list(indices[:gold_position]) + [gold_document_idx] + list(indices[gold_position:])
        return indices, gold_position


    def _get_documents(    
        self,
        indices: List[int],
        answers: List[str],
        gold_document_idx: Optional[int],
        gold_position: Optional[int]
    ) -> Tuple[List[str], List[int]]:
        """ Dispatch to the answerless or the plain document selection depending on ``get_documents_without_answer``. """
        if self.get_documents_without_answer:
            return self._get_answerless_documents_from_indices(
                indices, answers, gold_document_idx, gold_position
            )
        return self._get_documents_from_indices(indices)


    def _lookup_documents(self, indices: List[int]) -> List[Dict]:
        """ Fetch corpus entries for full-corpus ``indices``, translating through ``full_to_subset_idx_map`` when set. """
        if self.full_to_subset_idx_map is None:
            return [self.corpus[i] for i in map(int, indices)]
        # 'indices' are from the full corpus, so we need to map them to the subset
        return [self.corpus[self.full_to_subset_idx_map[i]] for i in map(int, indices)]


    @staticmethod
    def _format_document(doc_idx, title: str, text: str) -> str:
        """ Render one corpus entry in the prompt's 'Document [idx](Title: ...) text' form. """
        return f"Document [{doc_idx}](Title: {title}) {text}"


    def _get_documents_from_indices(self, indices: List[int]) -> Tuple[List[str], List[int]]:
        """
        Selects documents from the corpus based on provided indices and formats them.

        Documents whose text duplicates an earlier one are skipped, and selection stops
        once ``num_documents_in_context`` documents have been collected.

        Args:
            indices: Full-corpus indices of the documents to retrieve, in context order.

        Returns:
            A tuple containing two lists:
            - The first list contains the formatted documents.
            - The second list contains the full-corpus indices of the included documents.
        """
        documents_info = self._lookup_documents(indices)

        formatted_documents = []
        document_indices = []  # indices of documents actually added
        seen_hashes = set()
        for doc_info in documents_info:
            if len(formatted_documents) == self.num_documents_in_context:
                break
            
            doc_idx = doc_info['full_corpus_idx']
            title = doc_info['title']
            text = doc_info['text']

            doc_hash = hash_document(text)
            # Skip the document if it is a duplicate
            if doc_hash in seen_hashes:
                continue
            seen_hashes.add(doc_hash)
            
            formatted_documents.append(self._format_document(doc_idx, title, text))
            document_indices.append(doc_idx)

        return formatted_documents, document_indices
    

    def _get_answerless_documents_from_indices(
        self,
        indices: List[int],
        answers: List[str],
        gold_document_idx: Optional[int],
        gold_position: Optional[int]
    ) -> Tuple[List[str], List[int]]:
        """
        Selects documents from the corpus that do not contain any of the given answers, optionally including
        the 'gold' document at a designated position.

        The gold document is kept only when ``gold_position`` is not None; it is inserted at
        ``min(gold_position, number of answerless documents)``. Duplicate texts are skipped.

        Args:
            indices: Full-corpus indices of the candidate documents.
            answers: Answers to exclude from the documents.
            gold_document_idx: The index of the gold document in the full corpus (int or str).
            gold_position: The desired position of the gold document within the returned list, if any.

        Returns:
            A tuple containing two lists (both truncated to ``num_documents_in_context``):
            - The documents that do not contain the answer, possibly with the gold one.
            - The full-corpus indices of the included documents, all taken from the
              corpus' 'full_corpus_idx' field (the gold one included), so their type is uniform.
        """
        documents_info = self._lookup_documents(indices)

        # Compare as strings: callers pass the gold index as str while the corpus may store ints.
        gold_key = None if gold_document_idx is None else str(gold_document_idx)

        answerless_documents = []
        document_indices = []  # indices of documents actually added
        gold_document = None
        gold_document_corpus_idx = None
        seen_hashes = set()

        for doc_info in documents_info:
            doc_idx = doc_info['full_corpus_idx']
            title = doc_info['title']
            text = doc_info['text']

            doc_hash = hash_document(text)
            # Skip the document if it's a duplicate
            if doc_hash in seen_hashes:
                continue
            seen_hashes.add(doc_hash)

            if str(doc_idx) == gold_key:
                gold_document = self._format_document(doc_idx, title, text)
                gold_document_corpus_idx = doc_idx
                continue
            
            if not is_answer_in_text(text, answers):
                answerless_documents.append(self._format_document(doc_idx, title, text))
                document_indices.append(doc_idx)

        # Insert gold document at the specified/random position
        if gold_position is not None and gold_document is not None:
            gold_position = min(gold_position, len(answerless_documents))
            answerless_documents.insert(gold_position, gold_document)
            # Use the corpus' own index value (not the caller's str) so every entry has the
            # same type, as in _get_documents_from_indices; mixed str/int entries break collation.
            document_indices.insert(gold_position, gold_document_corpus_idx)

        # Limit the number of documents to the specified context size
        docs = answerless_documents[:self.num_documents_in_context]
        kept_indices = document_indices[:self.num_documents_in_context]
        return docs, kept_indices


    def build_qa_prompt(self, query: str, documents_str: str) -> str:
        """
        Build the open-book QA prompt from the query and the newline-joined documents.

        When the tokenizer's ``name_or_path`` contains 'mpt', the plain prompt (minus its
        trailing "\\nAnswer:") is wrapped in the MPT instruction/response template.
        """
        task_instruction = "You are given a question and you must respond based on the provided documents. You must always provide an answer."
        prompt = f"""{task_instruction}\nDocuments:\n{documents_str}\nQuestion: {query}\nAnswer:"""

        # Custom prompt format for mpt models
        if 'mpt' in self.tokenizer.name_or_path:
            INSTRUCTION_KEY = "### Instruction:"
            RESPONSE_KEY = "### Response:"
            INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
            PROMPT_FOR_GENERATION_FORMAT = """{intro}\n{instruction_key}\n{instruction}\n{response_key}""".format(
                intro=INTRO_BLURB,
                instruction_key=INSTRUCTION_KEY,
                instruction="{instruction}",
                response_key=RESPONSE_KEY,
            )
            # prompt[:-8] strips the trailing "\nAnswer:" (8 characters)
            prompt = PROMPT_FOR_GENERATION_FORMAT.format(
                instruction=prompt[:-8]
            )

        return prompt


    def __getitem__(self, idx: int):
        """Return the precomputed prompt and metadata of the ``idx``-th kept example."""
        _, document_indices = self.preprocessed_data[idx]

        return {
            "example_id": self.example_ids[idx],
            "query": self.queries[idx],
            "prompt": self.prompts[idx],
            "document_indices": document_indices,
            "gold_document_idx": self.gold_document_idxs[idx],
            "prompt_tokens_len": self.prompt_tokens_lengths[idx]
        }
    

    def __len__(self):
        return len(self.example_ids)

In [None]:
import os 
import argparse
import warnings
from tqdm import tqdm
from typing import Tuple, Dict, Optional

import torch
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer

from llm import LLM
from utils import *
from prompt_dataset import PromptDataset


# --- Run-wide configuration ---------------------------------------------------
# Turn off the HF tokenizers' own parallelism; the DataLoader below uses worker processes.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# First GPU when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): this silences every warning for the whole kernel, not just noisy libraries.
warnings.filterwarnings('ignore')

SEED = 10

# Input files of this run; both live under the same local data folder.
_DATA_ROOT = r"C:\Users\franc\Documents\Bridge_the_GAP\data"
info = {
    "data_path": rf"{_DATA_ROOT}\test_dataset.json",
    "contriever_search_results_path": rf"{_DATA_ROOT}\processed\contriever_test_search_results_at150.pkl",
}

class DotDict:
    """Minimal attribute-access container used in place of an argparse result in the notebook."""

    def __init__(self, **kwargs):
        # Each keyword argument becomes an instance attribute: DotDict(a=1).a == 1.
        for key, value in kwargs.items():
            setattr(self, key, value)

def parse_arguments(custom_args=None):
    """
    Mimics argparse to parse arguments for LLM generation. Accepts custom arguments as a dictionary for notebooks.

    Args:
        custom_args (dict, optional): Overrides merged on top of the defaults below.

    Returns:
        DotDict: the merged settings, readable as attributes (e.g. ``args.llm_id``).

    Raises:
        ValueError: if 'num_documents_in_context' is missing or not positive, if
            'gold_position' is outside [0, num_documents_in_context), if 'save_every'
            is not a positive integer, or if 'batch_size' is given but not positive.
    """
    # Define default values
    default_args = {
        'output_dir': r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_res_example_llm',
        'llm_id': 'google/gemma-2-2b-it',
        'model_max_length': 4096,
        'gold_position': None,
        'num_documents_in_context': None,
        'get_documents_without_answer': True,
        'max_new_tokens': 15,
        'batch_size': None,
        'save_every': 250,
    }

    # If custom_args is provided, update defaults
    if custom_args:
        default_args.update(custom_args)

    # Perform validation
    if default_args['num_documents_in_context'] is None:
        raise ValueError("'num_documents_in_context' must be specified.")
    if default_args['num_documents_in_context'] <= 0:
        raise ValueError("'num_documents_in_context' must be a positive integer.")
    if default_args['gold_position'] is not None:
        if (default_args['gold_position'] < 0 or 
            default_args['gold_position'] >= default_args['num_documents_in_context']):
            raise ValueError("'gold_position' must be within the range of 'num_documents_in_context'.")
    # generate_and_save computes (idx + 1) % save_every, so 0 would divide by zero.
    if default_args['save_every'] is None or default_args['save_every'] <= 0:
        raise ValueError("'save_every' must be a positive integer.")
    # None is allowed (the loader builder treats it as its default); anything else must be > 0.
    if default_args['batch_size'] is not None and default_args['batch_size'] <= 0:
        raise ValueError("'batch_size' must be a positive integer when specified.")

    return DotDict(**default_args)


def load_corpus(
    args: argparse.Namespace
) -> Tuple[List[Dict], Optional[Dict[int, int]]]:
    """
    Load the test corpus (random + Contriever documents) together with its index map.

    Args:
        args: Run settings; not used at the moment, accepted so the loaders share a signature.

    Returns:
        (corpus, full_to_subset_idx_map) as produced by ``read_test_corpus_with_random_and_contriever``.
    """
    documents, subset_index_map = read_test_corpus_with_random_and_contriever()
    return documents, subset_index_map


def load_search_results(args: argparse.Namespace) -> List[Tuple[List[int], List[float]]]:
    """
    Load the Contriever search results from ``info['contriever_search_results_path']``.

    Args:
        args: Run settings; not used at the moment, accepted so the loaders share a signature.

    Returns:
        Whatever ``read_pickle`` returns for that file (expected: per-example (indices, scores) pairs).
    """
    # NOTE(review): read_pickle presumably unpickles the file — only point it at trusted data.
    return read_pickle(info['contriever_search_results_path'])


def initialize_dataset_and_loader(
    args: argparse.Namespace, 
    corpus: List[Dict], 
    full_to_subset_idx_map: Optional[Dict[int, int]], 
    search_results: List[Tuple[List[int], List[float]]], 
    tokenizer: PreTrainedTokenizer
) -> DataLoader:
    """
    Build the PromptDataset for this run and wrap it in a sequential (unshuffled) DataLoader.

    Args:
        args: Run settings (model_max_length, num_documents_in_context, gold_position,
            get_documents_without_answer, batch_size, and optionally num_workers).
        corpus: Document corpus passed to PromptDataset.
        full_to_subset_idx_map: Full-corpus -> subset index map, or None.
        search_results: Per-example (indices, scores) retrieval results.
        tokenizer: Tokenizer used by PromptDataset to measure prompt length.

    Returns:
        DataLoader yielding collated batches of PromptDataset items.
    """
    prompt_ds = PromptDataset(
        corpus=corpus, data_path=info['data_path'], 
        tokenizer=tokenizer, 
        max_tokenized_length=args.model_max_length - 2, 
        search_results=search_results,
        full_to_subset_idx_map=full_to_subset_idx_map,
        do_normalize_query=True, 
        num_documents_in_context=args.num_documents_in_context,
        gold_position=args.gold_position,
        get_documents_without_answer=args.get_documents_without_answer,
    )

    # DataLoader interprets batch_size=None as "disable automatic batching": every item
    # would come out un-collated (a bare str prompt, an int token count). The result
    # readers index per-batch lists/tensors (e.g. `.tolist()` on 'prompt_tokens_len'),
    # so None falls back to single-item batches.
    batch_size = args.batch_size if args.batch_size is not None else 1
    num_workers = getattr(args, 'num_workers', 8)

    prompt_dataloader = DataLoader(
        prompt_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    return prompt_dataloader


def print_info(args: argparse.Namespace):
    """Print a one-line-per-setting summary of the run (data file, model, context options)."""
    print("INFO:")
    for label, value in (
        ("DATA", info['data_path']),
        ("MODEL", args.llm_id),
        ("GOLD POSITION", args.gold_position),
        ("NUM DOCUMENTS IN CONTEXT", args.num_documents_in_context),
        ("DOCUMENTS WITHOUT ANSWER", args.get_documents_without_answer),
        ("BATCH SIZE", args.batch_size),
        ("SAVE EVERY", args.save_every),
    ):
        print(f"{label}: {value}")


def generate_and_save(
    args: argparse.Namespace, 
    llm: LLM, 
    prompt_dataloader: DataLoader
):
    """
    Generate answers for every prompt batch and periodically pickle the results.

    Each batch dict from ``prompt_dataloader`` gets a 'generated_answer' entry holding,
    for every output of ``llm.generate``, the stripped text after the answer marker
    ("Answer:", or "### Response:" for MPT models). Batches are buffered and written with
    ``write_pickle`` every ``args.save_every`` batches and after the last batch, to
    ``{output_dir}/{model}/train/classic/contriever/{num_doc}_doc/``
    ``numdoc{num_doc}_gold_at{gold_pos}[_answerless]_info_{batch_number}.pkl``.

    Args:
        args: Run settings (llm_id, num_documents_in_context, save_every, gold_position,
            get_documents_without_answer, output_dir, max_new_tokens).
        llm: Model wrapper exposing ``generate(prompts, max_new_tokens=...)``.
        prompt_dataloader: Yields dicts with at least a 'prompt' entry.
    """
    # Info from arguments
    llm_id = args.llm_id
    num_doc = args.num_documents_in_context
    save_every = args.save_every
    gold_pos = args.gold_position
    retriever_str = "contriever"
    answerless_str = "_answerless" if args.get_documents_without_answer else ""

    # Create the saving directory (exist_ok avoids the check-then-create race).
    # NOTE(review): the path says 'train' although info['data_path'] points at the test
    # set; readers of these files presumably rely on this layout, so it is kept as-is.
    llm_folder = llm_id.split("/")[1] if '/' in llm_id else llm_id
    saving_dir = f"{args.output_dir}/{llm_folder}/train/classic/{retriever_str}/{num_doc}_doc"
    os.makedirs(saving_dir, exist_ok=True)

    # MPT has a different answer string in the prompt
    answer_string_in_prompt = "### Response:" if 'mpt' in llm_id else "Answer:"

    num_batches = len(prompt_dataloader)
    all_info = []
    for idx, prompt_batch in enumerate(tqdm(prompt_dataloader)):
        prompts = prompt_batch['prompt']
        generated_output = llm.generate(prompts, max_new_tokens=args.max_new_tokens)

        generated_answers = []
        for output in generated_output:
            # NOTE(review): find() locates the FIRST marker; a retrieved document that itself
            # contains the marker text would move the cut point into the documents.
            marker_pos = output.find(answer_string_in_prompt)
            if marker_pos == -1:
                # No marker in the output: without this guard, -1 + len(marker) would
                # silently drop the first characters of the text. Keep all of it.
                response = output.strip()
            else:
                response = output[marker_pos + len(answer_string_in_prompt):].strip()
            generated_answers.append(response)

        prompt_batch['generated_answer'] = generated_answers
        all_info.append(prompt_batch)
        
        if (idx + 1) % save_every == 0 or (idx + 1) == num_batches:
            print(f"Saving at {idx + 1}...")
            file_name = f"{saving_dir}/numdoc{num_doc}_gold_at{gold_pos}{answerless_str}_info_{idx+1}.pkl"
            write_pickle(all_info, file_name)
            all_info = []


def main():
    """
    Notebook entry point for one generation run.

    In order: build the run settings, load the 4-bit LLM (reusing its tokenizer), load the
    corpus and Contriever search results, build the prompt DataLoader, print the
    configuration, then generate and save the answers.
    """
    run_settings = {
        'output_dir': r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_res_example_llm',
        'llm_id': 'google/gemma-2-2b-it',
        'model_max_length': 4096,
        'gold_position': None,
        'num_documents_in_context': 5,
        'get_documents_without_answer': True,
        'max_new_tokens': 15,
        'batch_size': None,
        'save_every': 250,
    }
    args = parse_arguments(run_settings)

    print("Loading LLM...")
    llm = LLM(
        args.llm_id,
        device,
        quantization_bits=4,
        model_max_length=args.model_max_length,
    )
    tokenizer = llm.tokenizer
    print("LLM loaded")

    print("Loading corpus and search results...")
    corpus, full_to_subset_idx_map = load_corpus(args)
    search_results = load_search_results(args)
    print("Corpus and search results loaded")

    print("Loading prompt dataset...")
    prompt_dataloader = initialize_dataset_and_loader(
        args, corpus, full_to_subset_idx_map, search_results, tokenizer
    )
    print("Prompt dataset loaded")

    print_info(args)
    generate_and_save(args, llm, prompt_dataloader)



if __name__ == "__main__":
    # Seed the RNGs (seed_everything, presumably from utils) before the run starts,
    # then launch it. In a notebook __name__ is "__main__", so this runs on execution.
    seed_everything(SEED)
    main()

Loading LLM...


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.96s/it]


LLM loaded
Loading corpus and search results...
Corpus and search results loaded
Loading prompt dataset...
Prompt dataset loaded
INFO:
DATA: C:\Users\franc\Documents\Bridge_the_GAP\data\test_dataset.json
MODEL: google/gemma-2-2b-it
GOLD POSITION: None
NUM DOCUMENTS IN CONTEXT: 5
DOCUMENTS WITHOUT ANSWER: True
BATCH SIZE: None
SAVE EVERY: 250


  9%|▊         | 247/2889 [03:50<30:49,  1.43it/s] 

In [None]:
import os
import re
import json
import pickle
import argparse

import torch
import pandas as pd
from typing import List, Dict

from utils import str2bool
from normalize_answers import *


def are_answers_matching(prediction: str, ground_truths: List[str]) -> bool:
    """
    Return True if any ground-truth answer occurs (as a substring) in the prediction after normalization.

    Both the prediction and each ground truth go through ``normalize_answer``. Ground
    truths that normalize to an empty string are ignored: "" is a substring of every
    prediction and would otherwise count as a match unconditionally.

    Args:
        prediction: The generated answer.
        ground_truths: Acceptable gold answers.

    Returns:
        bool: whether at least one non-empty normalized ground truth is contained in the
        normalized prediction.
    """
    normalized_prediction = normalize_answer(prediction)

    for ground_truth in ground_truths:
        normalized_ground_truth = normalize_answer(ground_truth)
        if not normalized_ground_truth:
            continue
        if normalized_ground_truth in normalized_prediction:
            return True
    return False


def read_generation_results(file_path: str, df: pd.DataFrame) -> List[Dict]:
    """
    Flatten saved generation batches into one record per example and score them.

    ``file_path`` is a JSON file holding a list of batch dicts whose values are per-batch
    lists; 'document_indices' is stored position-major and is transposed here so each
    example gets its own list. Gold answers are looked up in ``df`` by 'example_id'.

    Args:
        file_path: JSON file with the batched generation results.
        df: Table with 'example_id' and 'answers' columns.

    Returns:
        One dict per example with the prompt/generation fields plus
        'ans_match_after_norm' (a gold answer occurs in the normalized generation),
        'gold_in_retrieved' (the gold document index is among the context documents) and
        'ans_in_documents' (a gold answer occurs in the prompt text).

    Raises:
        KeyError: if an example_id in the file has no row in ``df``.
    """
    # Build the id -> answers lookup once instead of masking ``df`` for every example
    # (quadratic). setdefault keeps the first row per id, matching the former `.iloc[0]`.
    answers_by_id = {}
    for row_id, row_answers in zip(df['example_id'], df['answers']):
        answers_by_id.setdefault(int(row_id), row_answers)

    with open(file_path, "r") as fin:
        file_data = json.load(fin)

    data = []
    for example in file_data:
        example_ids = example['example_id']
        queries = example['query']
        prompts = example['prompt']
        # Stored position-major (one list per context slot); transpose to per-example tuples.
        document_indices = list(zip(*example['document_indices']))
        gold_document_indices = example['gold_document_idx']
        generated_answers = example['generated_answer']
        prompt_tokens_lens = example['prompt_tokens_len']

        for i in range(len(example_ids)):
            example_id = int(example_ids[i])
            query = queries[i]
            gold_document_idx = gold_document_indices[i]
            documents_idx = list(document_indices[i])
            generated_answer = generated_answers[i]
            prompt = prompts[i]
            prompt_tokens_len = prompt_tokens_lens[i]

            if example_id not in answers_by_id:
                raise KeyError(f"example_id {example_id} from {file_path} not found in df")
            answers = answers_by_id[example_id]

            gold_in_retrieved = int(gold_document_idx) in map(int, documents_idx)
            ans_match_after_norm: bool = are_answers_matching(generated_answer, answers)
            ans_in_documents: bool = is_answer_in_text(prompt, answers)
            data.append({
                'example_id': example_id,
                'query': query,
                'prompt': prompt,
                'document_indices': documents_idx,
                'gold_document_idx': gold_document_idx,
                'generated_answer': generated_answer,
                'answers': answers,
                'ans_match_after_norm': ans_match_after_norm,
                'gold_in_retrieved': gold_in_retrieved,
                'ans_in_documents': ans_in_documents,
                "prompt_tokens_len": prompt_tokens_len,
            })

    return data


def convert_tensors(cell):
    """
    Converts tensors in the given cell to plain Python values, recursing through nested lists/tuples.

    A collated batch stores 'document_indices' as one tensor (or list) per context slot.
    The previous two-level loop crashed on flat lists of ints or 0-d tensors, as produced
    by un-batched records, and split plain strings into characters. Recursion handles
    any nesting and leaves non-tensor leaves untouched.

    Args:
        cell: A DataFrame cell. Non-list values are returned unchanged.

    Returns:
        For a list, a new list in which every tensor is replaced by ``tensor.tolist()``
        (and nested tuples become lists); otherwise ``cell`` itself.
    """
    def _to_plain(value):
        if torch.is_tensor(value):
            return value.tolist()
        if isinstance(value, (list, tuple)):
            return [_to_plain(item) for item in value]
        return value

    if not isinstance(cell, list):
        return cell
    return [_to_plain(item) for item in cell]


def extract_number_from_filename(filename: str, pattern: re.Pattern) -> int:
    """
    Return the integer captured by group 1 of ``pattern`` in ``filename``, or 0 when it does not match.

    Used as a sort key so files come back in numeric (not lexicographic) order.
    """
    found = pattern.search(filename)
    if found is None:
        return 0
    return int(found.group(1))


def load_pickle_files(directory: str, filename_prefix: str) -> pd.DataFrame:
    """Load and concatenate all generation-result pickles in ``directory`` whose name contains ``filename_prefix``.

    Files are read in numeric chunk order, using the number right before the
    ``.pkl`` extension (e.g. ``..._250.pkl`` before ``..._1000.pkl``). The
    contents of every pickle (presumably a list of record dicts) are
    concatenated into one DataFrame.

    Args:
        directory: Folder containing the ``*.pkl`` chunk files.
        filename_prefix: Substring every selected filename must contain.

    Returns:
        DataFrame with one row per pickled record. Tensors in
        ``document_indices`` and ``prompt_tokens_len`` are converted to plain
        Python values so the frame can be serialized to JSON.

    Raises:
        FileNotFoundError: If no matching ``.pkl`` file exists in ``directory``.
    """
    # Escaped dot + end anchor: only the number immediately before ".pkl" is the chunk index.
    chunk_pattern = re.compile(r'(\d+)\.pkl$')
    files = sorted(
        (f for f in os.listdir(directory) if f.endswith('.pkl') and filename_prefix in f),
        key=lambda f: extract_number_from_filename(f, chunk_pattern),
    )
    if not files:
        raise FileNotFoundError(
            f"No .pkl files containing '{filename_prefix}' found in {directory}"
        )
    print("I'm using the following files: ", files)

    records = []
    for file in files:
        # NOTE: pickle.load can execute arbitrary code; only load result files this project produced.
        with open(os.path.join(directory, file), 'rb') as f:
            records.extend(pickle.load(f))

    data_df = pd.DataFrame(records)

    if 'document_indices' in data_df.columns:
        data_df['document_indices'] = data_df['document_indices'].apply(convert_tensors)

    if 'prompt_tokens_len' in data_df.columns:
        # Values may be tensors/arrays (batched records) or plain ints (per-example
        # records); only objects exposing .tolist() need converting.
        data_df['prompt_tokens_len'] = data_df['prompt_tokens_len'].apply(
            lambda x: x.tolist() if hasattr(x, 'tolist') else x
        )
    return data_df


def save_data_to_json(data_df: pd.DataFrame, directory: str, filename_prefix: str,
                      overwrite: Optional[bool] = None):
    """Write ``data_df`` to ``<directory>/<filename_prefix>all.json`` (``orient='records'``).

    If the target already exists, ``overwrite`` decides what happens:
    ``True`` overwrites, ``False`` keeps the existing file, and ``None``
    (the default) asks interactively via ``input()``. When the existing file
    is kept, the accuracy stored in ``<filename_prefix>all_extended.json`` is
    printed if that file exists and is non-empty.

    Args:
        data_df: Frame to serialize.
        directory: Output folder.
        filename_prefix: Prefix of the output filename.
        overwrite: Optional non-interactive overwrite decision. This is useful for
            "Restart & Run All", where an ``input()`` prompt blocks execution.

    Returns:
        The written file path, or ``None`` when an existing file was kept.
    """
    data_path = os.path.join(directory, f'{filename_prefix}all.json')
    if os.path.exists(data_path):
        if overwrite is None:
            answer = input(f"File {data_path} already exists. Overwrite? (y/n): ")
            overwrite = answer.strip().lower() == 'y'
        if not overwrite:
            print("No overwrite.")
            extended_path = os.path.join(directory, f'{filename_prefix}all_extended.json')
            # The extended results may be missing, e.g. if a previous run stopped before scoring.
            if os.path.exists(extended_path):
                results_df = pd.read_json(extended_path)
                if len(results_df) > 0:
                    accuracy = round(results_df['ans_match_after_norm'].sum() / len(results_df), 4)
                    print("ACCURACY: ", accuracy)
            else:
                print(f"No extended results found at {extended_path}; skipping accuracy.")
            return None

    data_df.to_json(data_path, orient='records')
    return data_path


def get_classic_path(args):
    """Build the filename prefix for "classic" prompt results.

    Example: ``numdoc5_gold_atNone_answerless_info_``. Reads ``args.num_doc``,
    ``args.gold_position`` and ``args.get_documents_without_answer``; the
    latter adds ``_answerless`` when truthy.
    """
    answerless_tag = "_answerless" if args.get_documents_without_answer else ""
    return f'numdoc{args.num_doc}_gold_at{args.gold_position}{answerless_tag}_info_'

class DotDict:
    """Minimal attribute bag: ``DotDict(a=1).a == 1``.

    ``parse_arguments`` returns one of these in place of a command-line
    argument namespace.
    """

    def __init__(self, **kwargs):
        # Store each keyword directly in the instance namespace so it reads back as an attribute.
        vars(self).update(kwargs)

def parse_arguments(custom_args=None):
    """Return the run configuration as a ``DotDict``.

    Starts from built-in defaults and overlays every entry of ``custom_args``
    (when given and non-empty). ``custom_args`` may also add keys that are
    not among the defaults.
    """
    # NOTE: output_dir defaults to a machine-specific absolute path.
    settings = {
        'output_dir': r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_res_example_llm',
        'llm_id': 'google/gemma-2-2b-it',
        'model_max_length': 4096,
        'gold_position': None,
        'num_documents_in_context': 5,
        'get_documents_without_answer': True,
    }
    settings.update(custom_args or {})
    return DotDict(**settings)

def main():
    """Merge chunked generation pickles for one run, save them as JSON, score them, and print accuracy.

    Steps:
      1. Build the run configuration.
      2. Locate ``<output_dir>/<model>/test/classic/contriever/<N>_doc``.
      3. Concatenate its ``.pkl`` chunks and write ``<prefix>all.json``.
      4. Score each generation against the gold answers in ``test_dataset.json``.
      5. Write ``<prefix>all_extended.json`` and print the accuracy.

    Returns early when an existing ``all.json`` is kept or there is nothing to score.
    """
    args = parse_arguments({
        'output_dir': r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_res_example_llm',
        'llm_id': 'google/gemma-2-2b-it',
        'gold_position': None,
        'num_documents_in_context': 5,
        'get_documents_without_answer': True,
    })
    # Reference dataset holding the gold answers used for scoring.
    test_dataset_path = r"C:\Users\franc\Documents\Bridge_the_GAP\data\test_dataset.json"

    prompt_type = "classic"
    retriever_str = "contriever/"
    split = "test"

    args.num_doc = args.num_documents_in_context  # get_classic_path reads args.num_doc
    filename_prefix = get_classic_path(args)

    # "google/gemma-2-2b-it" -> "gemma-2-2b-it"; ids without an org prefix are used as-is.
    llm_folder = args.llm_id.split("/")[-1]
    doc_str = f"{args.num_doc}_doc"
    directory = f'{args.output_dir}/{llm_folder}/{split}/{prompt_type}/{retriever_str}{doc_str}'
    print("Directory: ", directory)

    data_df = load_pickle_files(directory, filename_prefix)
    data_path = save_data_to_json(data_df, directory, filename_prefix)
    if data_path is None:
        return

    df = pd.read_json(test_dataset_path)

    # NOTE(review): read_generation_results indexes each field as a per-batch list;
    # the inspection cell shows per-example records (scalar example_id). TODO confirm the layouts agree.
    results = read_generation_results(data_path, df)

    results_df = pd.DataFrame(results)
    if results_df.empty:
        print("No results to score.")
        return
    accuracy = round(results_df['ans_match_after_norm'].sum() / len(results_df), 4)
    print("ACCURACY: ", accuracy)
    results_df.to_json(os.path.join(directory, f'{filename_prefix}all_extended.json'), orient='records')


# In a Jupyter kernel __name__ is "__main__", so executing this cell runs main() immediately.
if __name__ == "__main__":
    main()

Directory:  C:\Users\franc\Documents\Bridge_the_GAP\data\gen_res_example_llm/gemma-2-2b-it/test/classic/contriever/5_doc
I'm using the following files:  ['numdoc5_gold_atNone_answerless_info_250.pkl', 'numdoc5_gold_atNone_answerless_info_500.pkl', 'numdoc5_gold_atNone_answerless_info_750.pkl', 'numdoc5_gold_atNone_answerless_info_1000.pkl', 'numdoc5_gold_atNone_answerless_info_1250.pkl', 'numdoc5_gold_atNone_answerless_info_1500.pkl', 'numdoc5_gold_atNone_answerless_info_1750.pkl', 'numdoc5_gold_atNone_answerless_info_2000.pkl', 'numdoc5_gold_atNone_answerless_info_2250.pkl', 'numdoc5_gold_atNone_answerless_info_2500.pkl', 'numdoc5_gold_atNone_answerless_info_2750.pkl', 'numdoc5_gold_atNone_answerless_info_2889.pkl']


TypeError: 'int' object is not iterable

In [None]:
# Show the results of generate_answer_llm: load one chunk of raw generation
# results (one of the files load_pickle_files merges) and display its first record.
# Per the displayed output, records appear to be per-example (scalar example_id/query),
# unlike the per-batch lists read_generation_results indexes. TODO confirm.
# NOTE(review): read_pickle is not defined in this part of the notebook; presumably it comes
# from an earlier cell or a star import. Verify it exists before running on a fresh kernel.
# NOTE(review): absolute, machine-specific path; only unpickle files this project wrote.

result_path=r"C:\Users\franc\Documents\Bridge_the_GAP\data\gen_res_example_llm\gemma-2-2b-it\test\classic\contriever\5_doc\numdoc5_gold_atNone_answerless_info_250.pkl"
data=read_pickle(result_path)

data[0]

{'example_id': '-3290814144789249484',
 'query': 'who got the first nobel prize in physics',
 'prompt': 'You are given a question and you MUST respond by EXTRACTING the answer (max 5 tokens) from one of the provided documents. If none of the documents contain the answer, respond with NO-RES.\nDocuments:\nDocument [2043329](Title: Universology) become the chief proponent of universology today. "Everything in this universe is part of an uninterrupted sequence of events" Mohri has said. In 1872 Andrews published "The Basic Outline of Universology" which was subtitled "An introduction to the newly discovered science of the universe, its elementary principles, and the first stages of their development in the special sciences." Ilya Romanovich Prigogine (born on January 25, 1917) was a Belgian and American physicist and chemist who was born in Russia and became a Nobel Prize laureate in chemistry. In the book "Order Out of Chaos: Man\'s New Dialogue With Nature", which\nDocument [1860765](Ti

In [None]:
import torch
from transformers import (
    AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM
)
from typing import List, Tuple, Dict


class BridgeModel:
    """
    Bridge Model for selecting and ranking documents by generating document IDs.

    Wraps a Hugging Face seq2seq model (e.g. T5/BART). ``generate`` runs one
    prompt through greedy decoding and returns the decoded text split on
    whitespace, which is intended to be a list of document IDs.
    """
    def __init__(
        self,
        model_id: str,
        device: str = 'cuda',
        model_max_length: int = 512
    ):
        """
        Args:
            model_id: Hugging Face model name or local path of a seq2seq model.
            device: Torch device that both the model and the tokenized inputs are moved to.
            model_max_length: Max prompt length in tokens. Longer prompts are
                truncated from the left, keeping the end of the prompt.
        """
        self.device = device
        self.model_max_length = model_max_length

        # Initialize the seq2seq model and tokenizer
        self.model, self.tokenizer = self._initialize_model_tokenizer(model_id)

    def _initialize_model_tokenizer(self, model_id: str) -> "Tuple[AutoModelForSeq2SeqLM, AutoTokenizer]":
        """
        Load the config, model and tokenizer for ``model_id``.

        The model is loaded in bfloat16, put in eval mode and moved to ``self.device``.
        """
        model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
        # NOTE(review): max_seq_len is not a standard T5/BART config field; it is presumably
        # read by remote-code models. Input length is enforced by the tokenizer below.
        model_config.max_seq_len = self.model_max_length

        # Place the model on one explicit device. Combining device_map="auto" (accelerate
        # dispatch) with a later .to() is unsupported or warned about by transformers.
        # A single device also keeps the model next to the inputs moved in generate().
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            config=model_config,
            torch_dtype=torch.bfloat16,
        ).to(self.device)
        model.eval()  # Set the model to evaluation mode

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            model_max_length=self.model_max_length,
            padding_side="left",
            truncation_side="left"
        )
        # Seq2seq tokenizers such as T5/BART already define a pad token, and T5 also uses
        # the pad id as its decoder start token. Only fall back to EOS when none is defined.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        return model, tokenizer

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 15
    ) -> List[str]:
        """
        Generate the ordered document IDs for ``prompt`` (query plus candidate documents).

        Args:
            prompt: Full input text for the bridge model.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded text of the first (only) generated sequence, split on whitespace.
        """
        # Tokenize input (left-truncated to model_max_length)
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            max_length=self.model_max_length,
            padding=True,
            truncation=True
        ).to(self.device)

        # Greedy decoding -> deterministic output
        generated_ids = self.model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=max_new_tokens,
            repetition_penalty=1.1,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        # Decode the single sequence and split it into ID tokens
        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].split()

# Example Usage
if __name__ == "__main__":
    model_id = "t5-small"  # Replace with a seq2seq model like T5 or BART
    bridge = BridgeModel(
        model_id=model_id,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    # Defined here so this cell does not depend on a `prompt` variable left over from another cell.
    # NOTE(review): placeholder format; replace with the bridge model's real input format.
    example_prompt = (
        "Query: who got the first nobel prize in physics\n"
        "Documents: [20994698] The first Nobel Prize in Physics was awarded in 1901 "
        "to Wilhelm Conrad Röntgen."
    )
    output = bridge.generate(example_prompt)
    print("Output document IDs:", output)

  from .autonotebook import tqdm as notebook_tqdm


KeyboardInterrupt: 