In [1]:
import torch

# Environment check: is a CUDA GPU visible to PyTorch, how many, and which one.
cuda_available = torch.cuda.is_available()
print(cuda_available)
print(torch.cuda.device_count())
# get_device_name() needs an initialized CUDA device and raises otherwise, so
# only query it when CUDA is actually available.
print(torch.cuda.get_device_name() if cuda_available else "No CUDA device")

True
1
NVIDIA GeForce RTX 3080


In [2]:
import os
import re 
import argparse
import warnings
from tqdm import tqdm
from typing import Tuple, Dict, Optional

import torch
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer

from utils import *
from llm import LLM
from default_prompts import *
from prompt_dataset import PromptDataset


# Process-wide runtime settings for the generation run.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
warnings.filterwarnings('ignore')
SEED = 10

# Input files, keyed by dataset name and then split: the gold QA data file and
# the precomputed Contriever search results (NOTE: hard-coded local paths).
info = {
    "nq": {
        "test": dict(
            data_path=r'C:\Users\franc\Documents\Bridge_the_GAP\data\test_dataset.json',
            contriever_search_results_path=r"C:\Users\franc\Documents\Bridge_the_GAP\data\processed\contriever_test_search_results_at150.pkl",
        ),
    },
}

class DotDict:
    """
    Minimal attribute-access container used in place of ``argparse.Namespace``
    in notebooks: ``DotDict(a=1).a == 1``.
    """

    def __init__(self, **kwargs):
        # Every keyword becomes an instance attribute.
        self.__dict__.update(kwargs)

    def __repr__(self):
        # Show the stored settings, e.g. ``DotDict(a=1, b='x')``, for easier debugging.
        fields = ", ".join(f"{key}={value!r}" for key, value in self.__dict__.items())
        return f"{type(self).__name__}({fields})"

def parse_arguments(custom_args=None):
    """
    Build the LLM-generation settings for notebook use (a stand-in for argparse).

    Args:
        custom_args: Optional dict whose entries override the defaults below.

    Returns:
        A ``DotDict`` exposing every setting as an attribute.

    Raises:
        ValueError: if ``num_retrieved_documents`` is ``None`` or not positive,
            or if ``gold_position`` lies outside ``[0, num_retrieved_documents)``.
    """
    settings = dict(
        output_dir=r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_res_example_llm',
        llm_id='google/gemma-2-2b-it',
        dataset='nq',
        model_max_length=4096,
        quantization_bits=4,
        use_model_chat_template=True,
        gold_position=None,
        num_retrieved_documents=5,
        use_test=True,
        padding_strategy='longest',
        max_new_tokens=50,
        use_task_with_proof=False,
        batch_size=None,  # passed straight to DataLoader(batch_size=...)
        save_every=250,  # write a checkpoint pickle every N batches
    )
    if custom_args:
        settings.update(custom_args)

    num_docs = settings['num_retrieved_documents']
    if num_docs is None:
        raise ValueError("'num_retrieved_documents' must be specified.")
    if num_docs <= 0:
        raise ValueError("'num_retrieved_documents' must be a positive integer.")

    gold_position = settings['gold_position']
    if gold_position is not None and (gold_position < 0 or gold_position >= num_docs):
        raise ValueError("'gold_position' must be within the range of 'num_retrieved_documents'.")

    return DotDict(**settings)


def load_corpus(
    args: argparse.Namespace
) -> Tuple[List[Dict], Optional[Dict[int, int]]]:
    """
    Load the test corpus containing the random and Contriever documents.

    ``args`` is accepted for a uniform loader signature but is not used.

    Returns:
        ``(corpus, full_to_subset_idx_map)`` as produced by
        ``read_test_corpus_with_random_and_contriever``.
    """
    loaded_corpus, subset_index_map = read_test_corpus_with_random_and_contriever()
    return loaded_corpus, subset_index_map

def load_search_results(args: argparse.Namespace) -> List[Tuple[List[int], List[float]]]:
    """
    Load the precomputed Contriever search results for ``args.dataset`` (test split).

    The path is looked up in the module-level ``info`` table under
    ``info[args.dataset]['test']``, the same way ``initialize_dataset_and_loader``
    resolves the data file, so both always refer to the same dataset.

    Returns:
        The unpickled search results. Per the annotation, presumably one
        ``(document ids, scores)`` pair per query; TODO confirm against the producer.

    Raises:
        KeyError: if ``args.dataset`` has no test entry in ``info``.
    """
    search_results_path = info[args.dataset]['test']['contriever_search_results_path']
    return read_pickle(search_results_path)


def get_prompt_template(args: argparse.Namespace):
    """
    Build the prompt template for ``args.dataset``.

    With ``use_model_chat_template`` the dataset's task instruction is wrapped
    in the chat template registered for ``args.llm_id``. Otherwise the plain
    task template object builds the prompt. ``use_task_with_proof`` switches
    either path to the ``'qa_proof'`` variant.
    """
    config_key = args.dataset

    if not args.use_model_chat_template:
        template_obj = task_templates[config_key]
        if args.use_task_with_proof:
            template_obj = task_templates['qa_proof'][config_key]
        return template_obj.create_prompt_template()

    chat_template = chat_task_templates[args.llm_id]['template']
    instruction = task_instructions[config_key]
    if args.use_task_with_proof:
        instruction = task_instructions['qa_proof'][config_key]
    return apply_chat_task_template(chat_template, instruction)


def initialize_dataset_and_loader(
    args: argparse.Namespace, 
    corpus: List[Dict], 
    full_to_subset_idx_map: Optional[Dict[int, int]], 
    retriever_search_results: List[Tuple[List[int], List[float]]], 
    tokenizer: PreTrainedTokenizer
) -> DataLoader:
    """
    Build the ``PromptDataset`` for the test split and wrap it in a DataLoader.

    The dataset gets the prompt template from ``get_prompt_template`` and places
    ``args.num_retrieved_documents`` documents in each prompt. The loader keeps
    dataset order (``shuffle=False``) and uses ``args.batch_size`` as given.
    PyTorch treats ``batch_size=None`` as "no automatic batching".
    """
    template = get_prompt_template(args)

    dataset = PromptDataset(
        corpus=corpus,
        data_path=info[args.dataset]['test']['data_path'],
        tokenizer=tokenizer,
        # Two tokens of headroom below the model limit (reason not visible here; TODO confirm).
        max_tokenized_length=args.model_max_length - 2,
        search_results=retriever_search_results,
        prompt_template=template,
        full_to_subset_idx_map=full_to_subset_idx_map,
        do_normalize_query=True,
        num_documents_in_context=args.num_retrieved_documents,
        gold_position=args.gold_position,  # None in these experiments
    )

    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
    )


def print_info(args: argparse.Namespace):
    """Print a one-setting-per-line summary of the generation run to stdout."""
    print("INFO:")
    print(f"DATA: {info[args.dataset]['test']['data_path']}")
    # (label, attribute) pairs; each value is read only when its line is printed.
    for label, attr_name in (
        ("USE TEST", "use_test"),
        ("MODEL", "llm_id"),
        ("MODEL MAX LENGTH", "model_max_length"),
        ("MAX NEW TOKENS", "max_new_tokens"),
        ("USE MODEL CHAT TEMPLATE", "use_model_chat_template"),
        ("TASK WITH PROOF", "use_task_with_proof"),
        ("GOLD POSITION", "gold_position"),
        ("NUM DOCUMENTS IN CONTEXT", "num_retrieved_documents"),
        ("BATCH SIZE", "batch_size"),
        ("SAVE EVERY", "save_every"),
    ):
        print(f"{label}: {getattr(args, attr_name)}")


def extract_generate_answers(
    args: argparse.Namespace, 
    generated_output: List[str]
) -> List[str]:
    """
    Strip the echoed prompt from each decoded generation and return the answer text.

    The answer is the stripped text after an answer marker:

    - plain prompts: the literal ``"Answer:"``;
    - chat-template prompts: ``"Answer:"`` followed by a newline and ``"model"``.
      This is presumably what remains of the chat turn boundary once special
      tokens are dropped during decoding. It is hard-coded and needs adjusting
      for other chat templates.

    With ``use_task_with_proof`` and no chat template, the prompt contains a
    one-shot example that already includes ``"Answer:"``, so the second match is
    used. In chat mode the marker is specific enough that the first match is used.

    A generation that lacks the required marker yields ``""`` and emits a warning.
    It no longer raises ``IndexError``, which would abort the whole generation run.

    Args:
        args: Run settings; reads ``use_model_chat_template`` and ``use_task_with_proof``.
        generated_output: Decoded generations, each including the prompt.

    Returns:
        One answer string per element of ``generated_output``.
    """
    if args.use_model_chat_template:
        answer_pattern = re.compile(re.escape("Answer:") + r"\nmodel")  # adjust this
        match_idx = 0
    else:
        answer_pattern = re.compile("Answer:")
        # The proof prompt's one-shot example already contains "Answer:".
        match_idx = 1 if args.use_task_with_proof else 0

    generated_answers = []
    for output in generated_output:
        matches = list(answer_pattern.finditer(output))
        if len(matches) <= match_idx:
            warnings.warn(
                f"Answer marker {answer_pattern.pattern!r} not found "
                f"(needed match #{match_idx + 1}); storing an empty answer."
            )
            generated_answers.append("")
            continue

        answer_end = matches[match_idx].end()
        generated_answers.append(output[answer_end:].strip())

    return generated_answers


def generate_and_save(
    args: argparse.Namespace, 
    llm: LLM, 
    prompt_dataloader: DataLoader
):
    """
    Generate answers for every batch of ``prompt_dataloader`` and pickle them in chunks.

    Each batch dict gets a ``'generated_answer'`` entry. After every
    ``args.save_every`` batches, and after the last batch, the buffered batches
    are written to
    ``<output_dir>/<dataset>/<model>/test/<prompt_type>/contriever/<n>_doc/``
    as ``numdoc<n>_retr<n>[_<padding><max_len>][_template]_info_<batches_done>.pkl``.
    The buffer is then cleared.
    """
    llm_id = args.llm_id
    num_doc = args.num_retrieved_documents
    save_every = args.save_every
    retriever_str = "contriever"
    if args.padding_strategy == "longest":
        padding_str = ""
    else:
        padding_str = f"_{args.padding_strategy}{args.model_max_length}"
    chat_template_str = "_template" if args.use_model_chat_template else ""
    prompt_type = "retrieved_proof" if args.use_task_with_proof else "retrieved"

    # Keep only the part after the org prefix ("google/gemma-2-2b-it" -> "gemma-2-2b-it").
    llm_folder = llm_id.split("/")[1] if '/' in llm_id else llm_id
    saving_dir = f"{args.output_dir}/{args.dataset}/{llm_folder}/test/{prompt_type}/{retriever_str}/{num_doc}_doc"
    os.makedirs(saving_dir, exist_ok=True)

    buffered_batches = []
    for batch_number, prompt_batch in enumerate(tqdm(prompt_dataloader), start=1):
        generated_output = llm.generate(
            prompt_batch['prompt'],
            padding_strategy=args.padding_strategy,
            max_new_tokens=args.max_new_tokens,
        )
        prompt_batch['generated_answer'] = extract_generate_answers(args, generated_output)
        buffered_batches.append(prompt_batch)

        at_checkpoint = batch_number % save_every == 0
        if at_checkpoint or batch_number == len(prompt_dataloader):
            print(f"Saving at {batch_number}...")
            file_name = (
                f"{saving_dir}/numdoc{num_doc}_retr{args.num_retrieved_documents}"
                f"{padding_str}{chat_template_str}_info_{batch_number}.pkl"
            )
            write_pickle(buffered_batches, file_name)
            buffered_batches = []


def main():
    """Run retrieval-augmented generation end to end and save the outputs in chunks."""
    args = parse_arguments()

    print("Loading LLM...")
    llm = LLM(
        args.llm_id,
        device,
        quantization_bits=args.quantization_bits,
        model_max_length=args.model_max_length,
    )
    tokenizer = llm.tokenizer
    print("LLM loaded")

    print("Loading corpus and search results...")
    corpus, full_to_subset_idx_map = load_corpus(args)
    retriever_search_results = load_search_results(args)
    print("Corpus and search results loaded")

    print("Loading prompt dataset...")
    prompt_dataloader = initialize_dataset_and_loader(
        args,
        corpus,
        full_to_subset_idx_map,
        retriever_search_results,
        tokenizer,
    )
    print("Prompt dataset loaded")

    print_info(args)
    generate_and_save(args, llm, prompt_dataloader)


if __name__ == "__main__":
    # Seed all RNGs before anything is loaded so runs are reproducible.
    seed_everything(SEED)
    main()

  from .autonotebook import tqdm as notebook_tqdm


Loading LLM...


Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.05s/it]


LLM loaded
Loading corpus and search results...
Corpus and search results loaded
Loading prompt dataset...
Prompt dataset loaded
INFO:
DATA: C:\Users\franc\Documents\Bridge_the_GAP\data\test_dataset.json
USE TEST: True
MODEL: google/gemma-2-2b-it
MODEL MAX LENGTH: 4096
MAX NEW TOKENS: 50
USE MODEL CHAT TEMPLATE: True
TASK WITH PROOF: False
GOLD POSITION: None
NUM DOCUMENTS IN CONTEXT: 5
BATCH SIZE: None
SAVE EVERY: 250


  9%|▊         | 250/2889 [13:08<2:03:57,  2.82s/it]

Saving at 250...


 17%|█▋        | 500/2889 [25:08<2:17:12,  3.45s/it]

Saving at 500...


 26%|██▌       | 750/2889 [36:47<1:14:54,  2.10s/it]

Saving at 750...


 35%|███▍      | 1000/2889 [48:34<1:36:16,  3.06s/it]

Saving at 1000...


 43%|████▎     | 1250/2889 [1:01:31<1:52:55,  4.13s/it]

Saving at 1250...


 52%|█████▏    | 1500/2889 [1:13:57<59:12,  2.56s/it]  

Saving at 1500...


 61%|██████    | 1750/2889 [1:27:05<1:08:19,  3.60s/it]

Saving at 1750...


 69%|██████▉   | 2000/2889 [1:39:34<41:00,  2.77s/it]  

Saving at 2000...


 78%|███████▊  | 2250/2889 [1:52:16<33:42,  3.16s/it]

Saving at 2250...


 87%|████████▋ | 2500/2889 [2:05:08<15:20,  2.37s/it]

Saving at 2500...


 95%|█████████▌| 2750/2889 [2:17:26<07:29,  3.24s/it]

Saving at 2750...


100%|██████████| 2889/2889 [2:24:33<00:00,  3.74s/it]

Saving at 2889...


100%|██████████| 2889/2889 [2:24:37<00:00,  3.00s/it]


In [17]:
# Show the results of generate_answer_llm
# Load one saved generation checkpoint (the "_info_250.pkl" file, i.e. the
# chunk written after the first 250 batches) and display it.
from utils import *
from prompt_dataset import *

result_path=r"C:\Users\franc\Documents\Bridge_the_GAP\data\gen_res_example_llm\nq\gemma-2-2b-it\test\retrieved\contriever\5_doc\numdoc5_retr5_template_info_250.pkl"
data=read_pickle(result_path)

# Final bare expression: the notebook renders the loaded list of batch records.
data

[{'example_id': '-3290814144789249484',
  'query': 'who got the first nobel prize in physics',
  'prompt': '<bos><start_of_turn>user\nYou are given a question and you must respond based on the provided documents. You must always provide an answer.\nDocuments:\nDocument [628506](Title: Nobel Prize in Physics) receive a diploma, a medal and a document confirming the prize amount. Nobel Prize in Physics The Nobel Prize in Physics () is a yearly award given by the Royal Swedish Academy of Sciences for those who have made the most outstanding contributions for mankind in the field of physics. It is one of the five Nobel Prizes established by the will of Alfred Nobel in 1895 and awarded since 1901; the others being the Nobel Prize in Chemistry, Nobel Prize in Literature, Nobel Peace Prize, and Nobel Prize in Physiology or Medicine. The first Nobel Prize in Physics was\nDocument [3546609](Title: E. C. George Sudarshan) had developed the breakthrough. In 2007, Sudarshan told the "Hindustan Tim

In [18]:
import os
import re
import json
import pickle
import argparse

import torch
import pandas as pd
from typing import List, Dict

from utils import str2bool
from normalize_answers import *
from read_negative_rejection import *


def are_answers_matching(prediction: str, ground_truths: List[str]) -> bool:
    """
    Return True if any normalized ground-truth answer occurs inside the normalized prediction.

    Matching is substring containment after ``normalize_answer``, not exact
    equality, so a verbose prediction that contains the answer still counts.

    NOTE(review): if ``normalize_answer`` maps some ground truth to ``""``, the
    containment test is vacuously true; confirm that cannot happen for real answers.

    Args:
        prediction: The model's generated answer.
        ground_truths: Accepted gold answers.

    Returns:
        ``True`` on the first matching ground truth, otherwise ``False``.
    """
    normalized_prediction = normalize_answer(prediction)
    return any(
        normalize_answer(ground_truth) in normalized_prediction
        for ground_truth in ground_truths
    )


def extract_proof_from_text(text: str) -> str:
    """
    Return the first line of text that follows the first ``"Proof:"`` marker.

    Surrounding whitespace after the marker is stripped before the first line
    is taken. Returns the sentinel ``"NO-PROOF"`` when the marker is absent.
    """
    marker = re.search("Proof:", text)
    if marker is None:
        return "NO-PROOF"

    remainder = text[marker.end():].strip()
    return remainder.partition('\n')[0]


def compute_df_accuracy(df: pd.DataFrame, attribute: str) -> float:
    """
    Return the percentage of rows whose boolean ``attribute`` column is true.

    The result is rounded to two decimals *after* scaling to a percentage.
    Rounding the fraction first and then multiplying by 100 leaves binary
    floating-point artifacts such as ``3.4099999999999997``.

    Returns ``0.0`` for an empty DataFrame instead of dividing by zero.
    """
    if len(df) == 0:
        return 0.0
    # float() up front so Python's correctly-rounded round() is used, not numpy's.
    return round(float(df[attribute].sum()) / len(df) * 100, 2)


def read_generation_results(file_path: str, df: pd.DataFrame) -> List[Dict]:
    """
    Score retrieval-augmented generations against the gold answers.

    Args:
        file_path: JSON file of merged generation records (a list of record
            dicts, or a single record dict).
        df: Gold dataset with at least ``example_id`` and ``answers`` columns.

    Returns:
        One dict per record with the query, prompt, document indices, the first
        line of the generated answer, the gold answers, and the evaluation flags
        ``ans_match_after_norm``, ``gold_in_retrieved`` and ``ans_in_documents``.
        When ``'proof'`` occurs in ``file_path``, ``proof`` and ``ans_in_proof``
        are added as well. Records whose id is missing from ``df`` get
        ``answers=None`` and a printed warning.
    """
    with open(file_path, "r", encoding="utf-8") as fin:
        file_data = json.load(fin)

    # Handle list-based or dict-based structures
    examples = file_data if isinstance(file_data, list) else [file_data]

    # Build the id -> answers lookup once, instead of re-filtering (and re-casting)
    # the whole gold DataFrame for every example. The first row per id is kept,
    # matching a `.iloc[0]` lookup.
    answers_by_id = {}
    for gold_id, gold_answers in zip(df['example_id'].astype(str), df['answers']):
        answers_by_id.setdefault(gold_id, gold_answers)

    is_proof_run = 'proof' in file_path
    data = []
    for example in examples:
        example_id = example.get('example_id', [])
        query = example.get('query', [])
        prompt = example.get('prompt', [])
        document_indices = example.get('document_indices', [])
        gold_document_idx = example.get('gold_document_idx', [])
        generated_answer = example.get('generated_answer', [])
        prompt_tokens_len = example.get('prompt_tokens_len', [])

        documents_idx = list(document_indices) if isinstance(document_indices, list) else [document_indices]
        # Presumably a one-element list per record; keep the first line of its
        # first entry, since text after the first newline is usually noise.
        generated_answer = generated_answer[0].split('\n', 1)[0]

        key = str(example_id)
        if key in answers_by_id:
            answers = answers_by_id[key]
        else:
            print(f"Warning: No matching entry found for example_id {example_id}")
            answers = None

        gold_in_retrieved = int(gold_document_idx) in map(int, documents_idx)
        ans_match_after_norm = are_answers_matching(generated_answer, answers) if answers else False
        ans_in_documents = is_answer_in_text(prompt, answers) if answers else False

        data.append({
            'example_id': key,
            'query': query,
            'prompt': prompt,
            'document_indices': documents_idx,
            'gold_document_idx': gold_document_idx,
            'generated_answer': generated_answer,
            'answers': answers,
            'ans_match_after_norm': ans_match_after_norm,
            'gold_in_retrieved': gold_in_retrieved,
            'ans_in_documents': ans_in_documents,
            "prompt_tokens_len": prompt_tokens_len,
        })

        if is_proof_run:
            # NOTE(review): `generated_answer` was already cut to its first line
            # above, so a "Proof:" written on a later line cannot be found here.
            # Confirm the expected output layout of proof prompts.
            proof = extract_proof_from_text(generated_answer)
            data[-1]['proof'] = proof
            data[-1]['ans_in_proof'] = is_answer_in_text(proof, [generated_answer])

    return data


def read_generation_results_only_query(file_path: str, df: pd.DataFrame) -> List[Dict]:
    """
    Score query-only (no retrieved documents) generations against the gold answers.

    Each entry of the JSON file is a batch holding parallel lists under
    ``example_id``, ``query``, ``prompt`` and ``generated_answer``. One result
    record is produced per element.

    Args:
        file_path: JSON file of merged, batched generation records.
        df: Gold dataset with at least ``example_id`` and ``answers`` columns.

    Returns:
        One dict per generated example with the evaluation flags
        ``ans_match_after_norm`` and ``ans_in_documents``.

    Raises:
        KeyError: if an ``example_id`` has no row in ``df``.
    """
    with open(file_path, "r", encoding="utf-8") as fin:
        file_data = json.load(fin)

    # id -> gold answers, built once; the first row per id is kept (as `.iloc[0]` would).
    answers_by_id = {}
    for gold_id, gold_answers in zip(df['example_id'].astype(str), df['answers']):
        answers_by_id.setdefault(gold_id, gold_answers)

    data = []
    for example in file_data:
        example_ids = example['example_id']
        queries = example['query']
        prompts = example['prompt']
        generated_answers = example['generated_answer']

        for i, example_id in enumerate(example_ids):
            key = str(example_id)
            if key not in answers_by_id:
                raise KeyError(f"No gold answers found for example_id {example_id}")
            answers = answers_by_id[key]

            query = queries[i]
            # After the first new line, LLMs usually generate random text,
            # so it is skipped in the matching comparison
            first_line = generated_answers[i].split('\n', 1)[0]
            prompt = prompts[i]

            data.append({
                'example_id': key,
                'query': query,
                'prompt': prompt,
                # The full, untruncated generation is stored; only matching uses the first line.
                'generated_answer': generated_answers[i],
                'answers': answers,
                'ans_match_after_norm': are_answers_matching(first_line, answers),
                'ans_in_documents': is_answer_in_text(prompt, answers),
            })

    return data


def convert_tensors(cell):
    """
    Turn tensors stored one level inside a list of rows into plain Python values.

    A list input is returned as a new list of lists, with every tensor element
    replaced by ``tensor.tolist()``. Any non-list input is returned unchanged.
    """
    if not isinstance(cell, list):
        return cell

    def _to_plain(value):
        return value.tolist() if torch.is_tensor(value) else value

    return [list(map(_to_plain, row)) for row in cell]


def extract_number_from_filename(filename: str, pattern: re.Pattern) -> int:
    """Return the integer captured by group 1 of ``pattern`` in ``filename``, or 0 if there is no match."""
    found = pattern.search(filename)
    if found is None:
        return 0
    return int(found.group(1))


def load_pickle_files(directory: str, filename_prefix: str) -> pd.DataFrame:
    """
    Load and concatenate the pickled generation chunks found in ``directory``.

    A file is used when its name ends with ``.pkl`` and contains
    ``filename_prefix``. Files are ordered by the number right before ``.pkl``
    (e.g. ``..._info_250.pkl``), so chunks are concatenated in generation order.
    Each file holds a list, and its items become DataFrame rows.

    SECURITY NOTE: ``pickle.load`` can execute arbitrary code; only point this
    at directories containing files you produced yourself.

    Returns:
        A DataFrame with one row per loaded item. For ``only_query`` runs a
        non-object ``example_id`` column is converted with ``.tolist()``.
    """
    # Escape the dot so only a literal ".pkl" suffix is matched.
    pattern = re.compile(r'(\d+)\.pkl')
    files = [f for f in os.listdir(directory) if f.endswith('.pkl') and filename_prefix in f]
    files.sort(key=lambda f: extract_number_from_filename(f, pattern))
    print("I'm using the following files: ", files)

    data_list = []
    for file in files:
        with open(os.path.join(directory, file), 'rb') as f:
            data_list.extend(pickle.load(f))

    data_df = pd.DataFrame(data_list)
    if 'only_query' in directory:
        if data_df['example_id'].dtype != "O":
            data_df['example_id'] = data_df['example_id'].apply(lambda x: x.tolist())

    return data_df


def save_data_to_json(data_df: pd.DataFrame, directory: str, filename_prefix: str):
    """
    Save the merged generation records to ``<directory>/<filename_prefix>all.json``.

    If that file already exists, the user is asked on stdin whether to
    overwrite it. If they decline, the metrics stored in
    ``<filename_prefix>all_extended.json`` (when that file exists) are printed
    instead, and nothing is written.

    Returns:
        The path of the written JSON file, or ``None`` when nothing was written.
        The caller then skips re-evaluation.
    """
    data_path = os.path.join(directory, f'{filename_prefix}all.json')
    # Check if the file already exists
    if os.path.exists(data_path):
        overwrite = input(f"File {data_path} already exists. Overwrite? (y/n): ")
        if overwrite.strip().lower() != 'y':
            print("No overwrite.")

            # Same path construction as `data_path` (and as `main` writes this file).
            extended_path = os.path.join(directory, f'{filename_prefix}all_extended.json')
            if not os.path.exists(extended_path):
                print(f"No previous evaluation found at {extended_path}.")
                return None

            results_df = pd.read_json(extended_path)
            accuracy = compute_df_accuracy(results_df, 'ans_match_after_norm')
            print("ACCURACY: ", accuracy)

            if 'proof' in directory:
                accuracy_ans_in_proof = compute_df_accuracy(results_df, 'ans_in_proof')
                print("ACCURACY ANS IN PROOF", accuracy_ans_in_proof)

            correct_ans_not_in_context_accuracy = compute_accuracy_correct_answer_not_in_context(results_df)
            print(f"Correct Answer Not in Context Accuracy: {correct_ans_not_in_context_accuracy}")

            return None

    data_df.to_json(data_path, orient='records')
    return data_path


def get_retrieved_path(args):
    """
    Return the chunk-file prefix used for retrieved-document runs.

    Example: ``numdoc5_retr5_template_info_``. A ``_<strategy><max_len>`` tag is
    added for non-"longest" padding, and ``_template`` when the chat template is used.
    """
    suffix = ""
    if args.padding_strategy != "longest":
        suffix += f"_{args.padding_strategy}{args.model_max_length}"
    if args.use_model_chat_template:
        suffix += "_template"
    return f"numdoc{args.num_doc}_retr{args.num_retrieved_documents}{suffix}_info_"


def get_only_query_path(args):
    """Return the chunk-file prefix used for query-only runs (no retrieved documents)."""
    if args.use_model_chat_template:
        return "only_query_template_info_"
    return "only_query_info_"

class DotDict:
    """
    Minimal attribute-access container used in place of ``argparse.Namespace``
    in notebooks: ``DotDict(a=1).a == 1``.
    """

    def __init__(self, **kwargs):
        # Every keyword becomes an instance attribute.
        self.__dict__.update(kwargs)

    def __repr__(self):
        # Show the stored settings, e.g. ``DotDict(a=1, b='x')``, for easier debugging.
        fields = ", ".join(f"{key}={value!r}" for key, value in self.__dict__.items())
        return f"{type(self).__name__}({fields})"

def parse_arguments(custom_args=None):
    """
    Build the evaluation settings for notebook use (a stand-in for argparse).

    Args:
        custom_args: Optional dict whose entries override the defaults below.

    Returns:
        A ``DotDict`` exposing every setting as an attribute.

    Raises:
        ValueError: if ``prompt_type`` is not one of
            ``'retrieved'``, ``'retrieved_proof'`` or ``'only_query'``.
    """
    allowed_prompt_types = ['retrieved', 'retrieved_proof', 'only_query']
    settings = dict(
        output_dir=r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_res_example_llm',
        llm_id='google/gemma-2-2b-it',
        dataset='nq',
        model_max_length=4096,
        use_model_chat_template=True,
        gold_position=None,
        num_retrieved_documents=5,
        use_test=True,
        padding_strategy='longest',
        max_new_tokens=50,
        prompt_type='retrieved',
    )
    if custom_args:
        settings.update(custom_args)

    if settings['prompt_type'] not in allowed_prompt_types:
        raise ValueError("Invalid prompt type. Must be one of ['retrieved', 'retrieved_proof', 'only_query']")

    return DotDict(**settings)


# Gold-answer files for evaluation, keyed by dataset name and then split.
# NOTE: each split maps directly to a JSON path string. This differs from the
# generation cell's `info`, where each split maps to a dict of paths.
info = dict(
    nq=dict(test=r'C:\Users\franc\Documents\Bridge_the_GAP\data\test_dataset.json'),
)

def main():
    """
    Merge one run's pickled generation chunks into JSON, score them against
    the gold answers, and save and print the evaluation results.
    """
    args = parse_arguments()
    retriever_str = "contriever/"
    prompt_type = args.prompt_type

    # Each prompt type uses its own chunk-file naming scheme.
    if 'retrieved' in prompt_type:
        args.num_doc = args.num_retrieved_documents
        filename_prefix = get_retrieved_path(args)
    elif prompt_type == 'only_query':
        filename_prefix = get_only_query_path(args)
    else:
        raise ValueError("Invalid prompt type")

    llm_id = args.llm_id
    split = "test" if args.use_test else "train"
    llm_folder = llm_id.split("/")[1] if '/' in llm_id else llm_id
    if 'only_query' in prompt_type:
        doc_str = ""
    else:
        doc_str = f"{args.num_doc}_doc"
    directory = f'{args.output_dir}/{args.dataset}/{llm_folder}/{split}/{prompt_type}/{retriever_str}{doc_str}'
    print("Directory: ", directory)

    gold_df = pd.read_json(info[args.dataset][split], dtype={'example_id': str})

    merged_df = load_pickle_files(directory, filename_prefix)
    data_path = save_data_to_json(merged_df, directory, filename_prefix)
    if data_path is None:
        # Nothing was written (the user declined to overwrite), so skip re-evaluation.
        return

    if 'only_query' in directory:
        reader = read_generation_results_only_query
    else:
        reader = read_generation_results
    results_df = pd.DataFrame(reader(data_path, gold_df))

    accuracy = compute_df_accuracy(results_df, 'ans_match_after_norm')
    print("ACCURACY: ", accuracy)
    if 'proof' in directory:
        accuracy_ans_in_proof = compute_df_accuracy(results_df, 'ans_in_proof')
        print("ACCURACY ANS IN PROOF", accuracy_ans_in_proof)

    results_df.to_json(os.path.join(directory, f'{filename_prefix}all_extended.json'), orient='records')

    correct_ans_not_in_context_accuracy = compute_accuracy_correct_answer_not_in_context(results_df)
    print(f"Correct Answer Not in Context Accuracy: {correct_ans_not_in_context_accuracy}")


if __name__ == "__main__":
    main()

Directory:  C:\Users\franc\Documents\Bridge_the_GAP\data\gen_res_example_llm/nq/gemma-2-2b-it/test/retrieved/contriever/5_doc
I'm using the following files:  ['numdoc5_retr5_template_info_250.pkl', 'numdoc5_retr5_template_info_500.pkl', 'numdoc5_retr5_template_info_750.pkl', 'numdoc5_retr5_template_info_1000.pkl', 'numdoc5_retr5_template_info_1250.pkl', 'numdoc5_retr5_template_info_1500.pkl', 'numdoc5_retr5_template_info_1750.pkl', 'numdoc5_retr5_template_info_2000.pkl', 'numdoc5_retr5_template_info_2250.pkl', 'numdoc5_retr5_template_info_2500.pkl', 'numdoc5_retr5_template_info_2750.pkl', 'numdoc5_retr5_template_info_2889.pkl']
ACCURACY:  34.65
Number of samples with no answer in context: 1406
Number of samples with correct answer not in context: 48
Correct Answer Not in Context Accuracy: 3.4099999999999997


In [46]:
from utils import *

# Count how many evaluated examples were judged correct.
path = r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_res_example_llm\nq\gemma-2-2b-it\test\retrieved\contriever\5_doc\numdoc5_retr5_template_info_all_extended.json'

read_results = read_json(path)

# One `ans_match_after_norm` flag per evaluated example.
results = [record['ans_match_after_norm'] for record in read_results]

# Keep only the positive flags. The evaluation stores them as booleans, so a
# truthiness test replaces the original `== True` comparison.
true_results = [flag for flag in results if flag]

print(f'Results len: {len(results)}')
print(f'True Results len: {len(true_results)}')

Results len: 2889
True Results len: 1001


In [None]:
# Display the evaluation records loaded into `read_results` (via read_json) in the previous cell.
read_results

[{'example_id': '-3290814144789249484',
  'query': 'who got the first nobel prize in physics',
  'prompt': '<bos><start_of_turn>user\nYou are given a question and you must respond based on the provided documents. You must always provide an answer.\nDocuments:\nDocument [628506](Title: Nobel Prize in Physics) receive a diploma, a medal and a document confirming the prize amount. Nobel Prize in Physics The Nobel Prize in Physics () is a yearly award given by the Royal Swedish Academy of Sciences for those who have made the most outstanding contributions for mankind in the field of physics. It is one of the five Nobel Prizes established by the will of Alfred Nobel in 1895 and awarded since 1901; the others being the Nobel Prize in Chemistry, Nobel Prize in Literature, Nobel Peace Prize, and Nobel Prize in Physiology or Medicine. The first Nobel Prize in Physics was\nDocument [3546609](Title: E. C. George Sudarshan) had developed the breakthrough. In 2007, Sudarshan told the "Hindustan Tim

In [None]:
import torch
from transformers import (
    AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM
)
from typing import List, Tuple, Dict


class BridgeModel:
    """
    Bridge Model for selecting and ranking documents by generating document IDs.

    Thin wrapper around a Hugging Face seq2seq checkpoint. A prompt string is
    encoded, the model decodes greedily, and the decoded text of the first
    sequence is split on whitespace into a list of document-ID tokens.
    """
    def __init__(
        self, 
        model_id: str, 
        device: str = 'cuda', 
        model_max_length: int = 512
    ):
        """
        Args:
            model_id: Hugging Face model id or local path of a seq2seq model.
            device: Torch device that both the model and the encoded inputs are moved to.
            model_max_length: Maximum number of input tokens; longer prompts are
                truncated from the left.
        """
        self.device = device
        self.model_max_length = model_max_length

        # Initialize the seq2seq model and tokenizer
        self.model, self.tokenizer = self._initialize_model_tokenizer(model_id)

    def _initialize_model_tokenizer(self, model_id: str) -> Tuple[AutoModelForSeq2SeqLM, AutoTokenizer]:
        """
        Initializes the seq2seq model and tokenizer with the given model ID.

        Returns:
            ``(model, tokenizer)``. The model is on ``self.device`` in eval mode
            and bfloat16. The tokenizer pads and truncates on the left.
        """
        # SECURITY NOTE: trust_remote_code=True runs Python code shipped with the
        # checkpoint; only use model ids you trust.
        model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
        # NOTE(review): `max_seq_len` is not a standard field of T5/BART configs and
        # is likely ignored by them; the input limit is enforced by the tokenizer below.
        model_config.max_seq_len = self.model_max_length

        # Place the whole model on `self.device`, which is also where `generate`
        # puts the inputs. Do not combine this with device_map="auto": a model
        # dispatched by accelerate must not be moved with .to() afterwards.
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            config=model_config,
            torch_dtype=torch.bfloat16,
        ).to(self.device)
        model.eval()  # Set the model to evaluation mode

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, 
            model_max_length=self.model_max_length,
            padding_side="left",
            truncation_side="left",  # keep the end of over-long prompts
        )
        # Fall back to EOS only when no pad token exists. Seq2seq tokenizers such
        # as T5's usually ship one, and overwriting it would make padding
        # indistinguishable from end-of-sequence.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        return model, tokenizer

    def generate(
        self, 
        prompt: str, 
        max_new_tokens: int = 15
    ) -> List[str]:
        """
        Generates the ordered document IDs based on the query and documents.

        Args:
            prompt: Full input text (query plus candidate documents).
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            Whitespace-separated tokens of the first decoded sequence, with
            special tokens skipped.
        """
        # Tokenize input
        inputs = self.tokenizer(
            prompt, 
            return_tensors="pt", 
            max_length=self.model_max_length, 
            padding=True, 
            truncation=True
        ).to(self.device)

        # Generate output
        generated_ids = self.model.generate(
            **inputs,
            do_sample=False,  # Deterministic output
            max_new_tokens=max_new_tokens,
            repetition_penalty=1.1,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        # Decode output
        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].split()

# Example Usage
if __name__ == "__main__":
    model_id = "t5-small"  # Replace with a seq2seq model like T5 or BART
    bridge = BridgeModel(model_id=model_id)

    # Example input: a single prompt string. Replace it with a real
    # query + candidate-documents prompt built by your pipeline.
    prompt = (
        "Query: who got the first nobel prize in physics "
        "Documents: [628506] Nobel Prize in Physics ... [3546609] E. C. George Sudarshan ..."
    )
    output = bridge.generate(prompt)
    print("Output document IDs:", output)