In [1]:
import torch

# Report CUDA availability, the number of visible GPUs and the name of each one.
# Iterating over device_count() instead of hard-coding indices 0 and 1 keeps the
# cell runnable on machines with fewer than two GPUs (or none at all).
print(torch.cuda.is_available())
print(torch.cuda.device_count())
for gpu_index in range(torch.cuda.device_count()):
    print(torch.cuda.get_device_name(gpu_index))

True
2
NVIDIA GeForce RTX 3080
NVIDIA GeForce RTX 3060


In [3]:
from utils import *

def analyze_json_responses(my_path):
    """
    Summarise answer correctness for a generation-results JSON file.

    Each entry is split by its "ans_in_documents" flag (missing -> False) and
    counted as correct when "ans_match_after_norm" is truthy (missing -> False).

    Args:
        my_path: Path handed to read_json(); the loaded value is iterated as a
            sequence of dict-like entries.

    Returns:
        dict with the total number of entries, the number of entries with and
        without the answer in context, the correct counts for both groups, the
        per-group accuracy and the overall accuracy. A ratio whose denominator
        is zero is reported as 0.
    """
    records = read_json(my_path)

    # Both tallies are keyed by "is the answer present in the context?".
    seen = {True: 0, False: 0}
    hits = {True: 0, False: 0}
    for record in records:
        in_context = bool(record.get("ans_in_documents", False))
        is_correct = record.get("ans_match_after_norm", False)
        seen[in_context] += 1
        if is_correct:
            hits[in_context] += 1

    def ratio(numerator, denominator):
        # Guard against empty groups instead of dividing by zero.
        return numerator / denominator if denominator > 0 else 0

    total = len(records)
    return {
        "total_examples": total,
        "examples_with_answer_in_context": seen[True],
        "examples_without_answer_in_context": seen[False],
        "correct_with_context": hits[True],
        "correct_without_context": hits[False],
        "average_correct_with_context": ratio(hits[True], seen[True]),
        "average_correct_without_context": ratio(hits[False], seen[False]),
        "overall_accuracy": ratio(hits[True] + hits[False], total),
    }


# Example usage: analyse one results file and print the summary dict.
# NOTE(review): the path is absolute and machine-specific; adjust it before
# running this cell on another machine.
path = r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_res_example_llm\nq\gemma-2-2b-it\test\retrieved\contriever\1_doc\numdoc1_retr1_template_info_all_extended.json'
result = analyze_json_responses(path)
print(result)

{'total_examples': 2889, 'examples_with_answer_in_context': 721, 'examples_without_answer_in_context': 2168, 'correct_with_context': 539, 'correct_without_context': 87, 'average_correct_with_context': 0.7475728155339806, 'average_correct_without_context': 0.04012915129151291, 'overall_accuracy': 0.2166839736933195}


In [6]:
import os
import argparse
import warnings
import pandas as pd
import re
from tqdm import tqdm
from typing import Tuple, Dict, Optional

import torch
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer

from utils import *
from bgm import BGM
from default_prompts import *
from prompt_dataset import PromptDataset

# Set TOKENIZERS_PARALLELISM to "false" in the process environment.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Use the first CUDA device when CUDA is available, otherwise fall back to the CPU.
device = torch.device(f"cuda:0" if torch.cuda.is_available() else "cpu")
# Suppress every warning category from here on.
warnings.filterwarnings('ignore')
# Seed handed to seed_everything() in the __main__ guard of this cell.
SEED=10

# Input files per dataset and split; read as info[args.dataset][args.split][...].
# Only the "train" split of "nq_bgm" is registered, so a run with use_test=True
# (which selects the "test" split in main) would raise KeyError on lookup.
# NOTE(review): the paths are absolute and machine-specific.
info = {
    "nq_bgm": {
        "train": {
            "data_path": r'C:\Users\franc\Documents\Bridge_the_GAP\data\10k_train_dataset.json',
            "contriever_search_results_path": r"C:\Users\franc\Documents\Bridge_the_GAP\data\processed\contriever_search_results_at150.pkl",
        }
    },
}

def save_dataloader_to_json(dataloader, output_file, num_examples=15):
    """
    Dump the first batches of a DataLoader-like iterable to a JSON file.

    Args:
        dataloader: Iterable yielding dict-like batches (anything with .items()).
        output_file: Destination path; written as UTF-8, indented JSON.
        num_examples: Maximum number of batches to store.

    Tensors are converted with .tolist() wherever they occur, including inside
    nested lists, tuples and dicts (collated batches commonly hold lists of
    tensors, which json.dump cannot serialise). Other values are stored as-is
    and must already be JSON-serialisable.
    """
    # Imported locally: this cell never imports `json` itself and would
    # otherwise depend on the name leaking in through `from utils import *`.
    import json

    def to_jsonable(value):
        # Recursively replace tensors by plain Python lists.
        if isinstance(value, torch.Tensor):
            return value.tolist()
        if isinstance(value, dict):
            return {key: to_jsonable(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [to_jsonable(item) for item in value]
        return value

    all_batches = []

    print("Saving DataLoader contents to JSON...")
    for idx, batch in enumerate(dataloader):
        if idx >= num_examples:  # Stop after saving the specified number of examples
            break
        all_batches.append({key: to_jsonable(value) for key, value in batch.items()})

    # Save the entire list of dictionaries to a JSON file
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(all_batches, f, ensure_ascii=False, indent=4)

    print(f"DataLoader contents saved to {output_file}")

class DotDict:
    """Minimal attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)

def parse_arguments(custom_args=None):
    """
    Build the run configuration without argparse so it can be used from a notebook.

    Args:
        custom_args: Optional mapping of overrides applied on top of the
            built-in defaults; keys that are not among the defaults are added
            unchanged.

    Returns:
        DotDict exposing every setting as an attribute.

    Raises:
        ValueError: if 'num_retrieved_documents' is None or not positive, or if
            'gold_position' is set but lies outside
            [0, num_retrieved_documents).
    """
    settings = dict(
        output_dir=r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_id_document_bgm',
        llm_id='google/flan-t5-large',
        dataset='nq_bgm',
        model_max_length=4096,
        quantization_bits=4,
        use_model_chat_template=False,
        gold_position=None,
        num_retrieved_documents=5,
        use_test=False,
        max_new_tokens=50,
        use_task_with_proof=False,
        batch_size=None,
        save_every=250,
    )

    # An empty or missing override mapping leaves the defaults untouched.
    settings.update(custom_args or {})

    num_docs = settings['num_retrieved_documents']
    if num_docs is None:
        raise ValueError("'num_retrieved_documents' must be specified.")
    if num_docs <= 0:
        raise ValueError("'num_retrieved_documents' must be a positive integer.")

    gold = settings['gold_position']
    if gold is not None and (gold < 0 or gold >= num_docs):
        raise ValueError("'gold_position' must be within the range of 'num_retrieved_documents'.")

    return DotDict(**settings)


def load_corpus(
    args: argparse.Namespace
) -> Tuple[List[Dict], Optional[Dict[int, int]]]:
    """
    Load the Contriever corpus through read_corpus_with_contriever().

    Args:
        args: Run configuration; accepted for signature parity, not read here.

    Returns:
        The (corpus, full_to_subset_idx_map) pair produced by the reader.
    """
    loaded = read_corpus_with_contriever()
    documents, index_map = loaded
    return documents, index_map

def load_search_results(args: argparse.Namespace) -> List[Tuple[List[int], List[float]]]:
    """
    Read the retriever search results registered for the active dataset/split.

    The pickle path is looked up in the module-level `info` mapping under
    info[args.dataset][args.split]['contriever_search_results_path'] and loaded
    with read_pickle().
    """
    split_paths = info[args.dataset][args.split]
    return read_pickle(split_paths['contriever_search_results_path'])


def get_prompt_template(args: argparse.Namespace):
    """
    Return the prompt template for the configured dataset.

    Without a model chat template, the template comes from
    task_templates[args.dataset].create_prompt_template(). With one, the chat
    template string registered for args.llm_id is combined with the task
    instruction for args.dataset via apply_chat_task_template().
    """
    config_name = args.dataset

    if not args.use_model_chat_template:
        plain_template = task_templates[config_name]
        return plain_template.create_prompt_template()

    chat_template = chat_task_templates[args.llm_id]['template']
    instruction = task_instructions[config_name]
    return apply_chat_task_template(chat_template, instruction)


def initialize_dataset_and_loader(
    args: argparse.Namespace, 
    corpus: List[Dict], 
    full_to_subset_idx_map: Optional[Dict[int, int]], 
    retriever_search_results: List[Tuple[List[int], List[float]]], 
    tokenizer: PreTrainedTokenizer
) -> DataLoader:
    """
    Build the PromptDataset for the active split and wrap it in a DataLoader.

    Args:
        args: Run configuration; reads dataset, split, model_max_length,
            num_retrieved_documents, gold_position and batch_size.
        corpus: Documents forwarded to PromptDataset.
        full_to_subset_idx_map: Index mapping forwarded to PromptDataset.
        retriever_search_results: Search results forwarded to PromptDataset.
        tokenizer: Tokenizer forwarded to PromptDataset.

    Returns:
        Only the DataLoader (unshuffled, 8 workers, pinned memory); the dataset
        it wraps is not returned separately.
    """
    dataset_kwargs = dict(
        corpus=corpus,
        data_path=info[args.dataset][args.split]['data_path'],
        tokenizer=tokenizer,
        # Two tokens fewer than the model maximum, as in the original call.
        max_tokenized_length=args.model_max_length - 2,
        search_results=retriever_search_results,
        prompt_template=get_prompt_template(args),
        full_to_subset_idx_map=full_to_subset_idx_map,
        do_normalize_query=True,
        num_documents_in_context=args.num_retrieved_documents,
        gold_position=args.gold_position,  # None in these experiments
    )
    prompt_ds = PromptDataset(**dataset_kwargs)

    return DataLoader(
        prompt_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
    )


def print_info(args: argparse.Namespace):
    """Print the data path and the main run settings, one "LABEL: value" line each."""
    print("INFO:")
    print(f"DATA: {info[args.dataset][args.split]['data_path']}")

    # (label, attribute) pairs in the order they are reported.
    reported_settings = (
        ("USE TEST", "use_test"),
        ("MODEL", "llm_id"),
        ("MODEL MAX LENGTH", "model_max_length"),
        ("MAX NEW TOKENS", "max_new_tokens"),
        ("USE MODEL CHAT TEMPLATE", "use_model_chat_template"),
        ("TASK WITH PROOF", "use_task_with_proof"),
        ("GOLD POSITION", "gold_position"),
        ("NUM DOCUMENTS IN CONTEXT", "num_retrieved_documents"),
        ("BATCH SIZE", "batch_size"),
        ("SAVE EVERY", "save_every"),
    )
    for label, attribute in reported_settings:
        print(f"{label}: {getattr(args, attribute)}")


def extract_generate_answers(
    args: argparse.Namespace, 
    generated_output: List[str]
) -> List[str]:
    """
    Extract the answer text that follows the answer prefix in each generated output.

    Args:
        args: Run configuration; reads use_model_chat_template, llm_id and
            (when present) use_task_with_proof.
        generated_output: Decoded model outputs, one string per example.

    Returns:
        One stripped answer string per output, in the same order.

    Behaviour:
        * The prefix is "Answer:" by default, or the (regex-escaped) answer
          prefix registered for the model when the chat template is used.
        * With use_task_with_proof the prompt carries a one-shot example that
          already contains "Answer:", so the second occurrence is used - unless
          the chat template supplies a different prefix, in which case the
          first occurrence is still the right one.
        * If the requested occurrence does not exist, the last available one is
          used; if the prefix does not occur at all (e.g. the model returned
          only the completion), the whole output is returned stripped instead
          of raising IndexError.
    """
    answer_prefix = "Answer:"
    if args.use_model_chat_template:
        answer_prefix = re.escape(chat_task_templates[args.llm_id]['answer_prefix'])
    prefix_pattern = re.compile(answer_prefix)

    # Which occurrence of the prefix introduces the real answer (see docstring).
    wanted_match = 0
    if getattr(args, "use_task_with_proof", False):
        wanted_match = 1
        if args.use_model_chat_template and answer_prefix != "Answer:":
            wanted_match = 0

    generated_answers = []
    for output in generated_output:
        matches = list(prefix_pattern.finditer(output))
        if not matches:
            # No marker in the output: treat the whole text as the answer.
            generated_answers.append(output.strip())
            continue

        match_idx = min(wanted_match, len(matches) - 1)
        answer_end = matches[match_idx].end()
        generated_answers.append(output[answer_end:].strip())

    return generated_answers


def BGMTraining(
    args: argparse.Namespace, 
    prompt_ds: PromptDataset,
    llm: BGM, 
    prompt_dataloader: DataLoader
):
    """
    Run generation over every batch of the dataloader and checkpoint the results.

    Args:
        args: Run configuration; reads llm_id, num_retrieved_documents,
            save_every, use_model_chat_template, use_task_with_proof,
            output_dir, dataset, split and max_new_tokens.
        prompt_ds: Dataset behind the dataloader. Currently unused; kept in the
            signature for the candidate-document selection step (see TODO).
        llm: Model wrapper whose generate(prompts, max_new_tokens=...) is called.
        prompt_dataloader: Iterable of dict-like batches with a 'prompt' entry.

    Side effects:
        Creates the output directory and writes one pickle per `save_every`
        batches (plus one for the final partial chunk) with write_pickle().

    Changes with respect to the previous revision:
        * Removed the per-document loop that used `candidate_docs` and
          `candidate_prompt` before assigning them, which failed on the first
          batch.
        * Removed the duplicated `prompts` assignment and unused locals.
        * Re-enabled the periodic save that had been disabled inside a string
          literal; without it every generated answer was discarded.
    """
    num_doc = args.num_retrieved_documents
    save_every = args.save_every
    retriever_str = "contriever"
    chat_template_str = "_template" if args.use_model_chat_template else ""
    prompt_type = "retrieved_proof" if args.use_task_with_proof else "retrieved"

    # Create the saving directory; "org/model" ids are reduced to "model".
    llm_id = args.llm_id
    llm_folder = llm_id.split("/")[1] if '/' in llm_id else llm_id
    saving_dir = f"{args.output_dir}/{args.dataset}/{llm_folder}/{args.split}/{prompt_type}/{retriever_str}/{num_doc}_doc"
    os.makedirs(saving_dir, exist_ok=True)

    num_batches = len(prompt_dataloader)
    all_info = []
    for idx, prompt_batch in enumerate(tqdm(prompt_dataloader)):
        prompts = prompt_batch['prompt']

        # TODO(review): the unfinished candidate-document step belongs here. It
        # was meant to walk prompt_batch['document_indices'], fetch documents
        # via prompt_ds._get_documents_from_indices(...) and build a candidate
        # prompt ending in '\nAnswer:', but it never defined its accumulators.

        generated_output = llm.generate(
            prompts, 
            max_new_tokens=args.max_new_tokens
        )

        prompt_batch['generated_answer'] = extract_generate_answers(args, generated_output)
        all_info.append(prompt_batch)

        # Flush every `save_every` batches and once more after the last batch.
        if (idx + 1) % save_every == 0 or (idx + 1) == num_batches:
            print(f"Saving at {idx + 1}...")
            file_name = f"{saving_dir}/numdoc{num_doc}_retr{args.num_retrieved_documents}{chat_template_str}_info_{idx+1}.pkl"
            write_pickle(all_info, file_name)
            all_info = []


def main():
    """
    Run the pipeline: parse settings, load the model, corpus and search
    results, build the prompt dataloader, print the settings and start
    BGMTraining.
    """
    args = parse_arguments()

    args.split = "test" if args.use_test else "train"

    print("Loading LLM...")
    bgm = BGM(
        args.llm_id, device,  
        model_max_length=args.model_max_length
    )
    tokenizer = bgm.tokenizer
    print("LLM loaded")

    print("Loading corpus and search results...")
    corpus, full_to_subset_idx_map = load_corpus(args)
    retriever_search_results = load_search_results(args)
    print("Corpus and search results loaded")

    print("Loading prompt dataset...")
    # initialize_dataset_and_loader() returns only the DataLoader, so unpacking
    # its result into two names fails. The dataset it wraps is taken from the
    # loader's `dataset` attribute instead.
    prompt_dataloader = initialize_dataset_and_loader(
        args, corpus, full_to_subset_idx_map, 
        retriever_search_results, tokenizer
    )
    prompt_ds = prompt_dataloader.dataset
    print("Prompt dataset loaded")

    print_info(args)

    # For debugging, a sample of batches can be dumped at this point with
    # save_dataloader_to_json(prompt_dataloader, <output path>, num_examples=15).

    BGMTraining(args, prompt_ds, bgm, prompt_dataloader)



# Entry point of the cell: seed with SEED first, then run the pipeline.
# seed_everything is not defined in this cell; presumably it is provided by
# `from utils import *` - confirm.
if __name__ == "__main__":
    seed_everything(SEED)
    main()

Loading LLM...
LLM loaded
Loading corpus and search results...
Corpus and search results loaded
Loading prompt dataset...
Prompt dataset loaded
INFO:
DATA: C:\Users\franc\Documents\Bridge_the_GAP\data\10k_train_dataset.json
USE TEST: False
MODEL: google/flan-t5-large
MODEL MAX LENGTH: 4096
MAX NEW TOKENS: 50
USE MODEL CHAT TEMPLATE: False
TASK WITH PROOF: False
GOLD POSITION: None
NUM DOCUMENTS IN CONTEXT: 3
BATCH SIZE: None
SAVE EVERY: 250
Saving DataLoader contents to JSON...
DataLoader contents saved to C:\Users\franc\Documents\Bridge_the_GAP\data\dataloader_contents.json


In [None]:
import json

def match_example_ids(file1_path, file2_path, output_path):
    """
    Copy example ids from file2 into the entries of file1 whose query matches.

    An entry of file1 receives 'example_id' when its 'query' equals the
    'question' of an item in file2; entries without a match are left unchanged.
    The updated list is written to output_path.

    All three files are handled as UTF-8 and non-ASCII text is written
    verbatim (ensure_ascii=False), matching
    update_queries_with_document_indices. Relying on the platform default
    encoding breaks on Windows as soon as a question contains non-ASCII
    characters.

    Args:
        file1_path (str): Path of the JSON file to update (list of dicts).
        file2_path (str): Path of the JSON file providing 'question' and
            'example_id' for each item.
        output_path (str): Path of the updated JSON file to write.

    Errors are reported on stdout rather than raised, as before.
    """
    try:
        # Load both JSON files
        with open(file1_path, 'r', encoding='utf-8') as f1:
            file1 = json.load(f1)

        with open(file2_path, 'r', encoding='utf-8') as f2:
            file2 = json.load(f2)

        # Map each question to its example_id (the last duplicate wins)
        question_to_example_id = {item['question']: item['example_id'] for item in file2}

        # Update the first file in memory
        for entry in file1:
            query = entry.get('query')
            if query in question_to_example_id:
                entry['example_id'] = question_to_example_id[query]

        # Write the updated file
        with open(output_path, 'w', encoding='utf-8') as f1_updated:
            json.dump(file1, f1_updated, indent=4, ensure_ascii=False)

        print(f"File aggiornato salvato in: {output_path}")
    except FileNotFoundError as e:
        print(f"Errore: {e}")
    except json.JSONDecodeError as e:
        print(f"Errore nel parsing del file JSON: {e}")
    except Exception as e:
        print(f"Errore imprevisto: {e}")

def update_queries_with_document_indices(file1_path, file2_path, output_path):
    """
    Attach the document indices found in file2 to the matching entries of file1.

    Entries are matched on their 'query' value. A file1 entry whose query also
    appears in file2 gets 'document_indices' set to the value stored in file2
    (an empty list when the file2 entry has none); other entries are left
    unchanged. The result is written to output_path as UTF-8, indented JSON
    with non-ASCII characters kept verbatim.
    """
    # Read both inputs
    with open(file1_path, 'r', encoding='utf-8') as target_fh:
        with open(file2_path, 'r', encoding='utf-8') as source_fh:
            targets = json.load(target_fh)
            sources = json.load(source_fh)

    # query -> document_indices, as recorded in file2 (the last duplicate wins)
    indices_by_query = {}
    for source in sources:
        indices_by_query[source['query']] = source.get('document_indices', [])

    # Update only the file1 entries whose query is known
    known_targets = (target for target in targets if target['query'] in indices_by_query)
    for target in known_targets:
        target['document_indices'] = indices_by_query[target['query']]

    # Write the merged result
    with open(output_path, 'w', encoding='utf-8') as output_file:
        json.dump(targets, output_file, indent=4, ensure_ascii=False)


# Step 1: add example ids to the results file and write the "_updated" copy.
# NOTE(review): all paths are absolute and machine-specific.
path_output=r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_ids_document_training_set_bgm\nq_training\gemma-2-2b-it\train\retrieved\contriever\5_doc\numdoc5_retr5_template_info_all_extended_updated.json'
file_da_modificare = r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_ids_document_training_set_bgm\nq_training\gemma-2-2b-it\train\retrieved\contriever\5_doc\numdoc5_retr5_template_info_all_extended.json'
file_di_confronto = r'C:\Users\franc\Documents\Bridge_the_GAP\data\10k_train_dataset.json'

match_example_ids(file_da_modificare, file_di_confronto, path_output)

# Step 2: write the "_updated_last" copy with document indices attached.
# NOTE(review): the first two arguments are the same "_updated" file, so no
# indices from a different file are merged here - confirm the intended second input.
update_queries_with_document_indices(r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_ids_document_training_set_bgm\nq_training\gemma-2-2b-it\train\retrieved\contriever\5_doc\numdoc5_retr5_template_info_all_extended_updated.json', r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_ids_document_training_set_bgm\nq_training\gemma-2-2b-it\train\retrieved\contriever\5_doc\numdoc5_retr5_template_info_all_extended_updated.json', r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_ids_document_training_set_bgm\nq_training\gemma-2-2b-it\train\retrieved\contriever\5_doc\numdoc5_retr5_template_info_all_extended_updated_last.json')

File aggiornato salvato in: C:\Users\franc\Documents\Bridge_the_GAP\data\gen_ids_document_training_set_bgm\nq_training\gemma-2-2b-it\train\retrieved\contriever\5_doc\numdoc5_retr5_template_info_all_extended_updated.json


In [11]:
import json
import random

# Share of the input examples allotted to each case. process_data() in this
# cell multiplies each share by the total number of examples to get a quota.
percentages = {
    "case_1_single_doc": 0.1,
    "case_2_multiple_docs": 0.2,
    "case_3_no_docs": 0.1,
    "case_4_less_docs": 0.4,
    "case_5_reranking": 0.2,
}

# Task instruction prepended to every query
task_instruction = "Output only the document IDs relevant to the query. Use this format: [ID1, ID2, ...]."

def process_data(input_file, output_file):
    """
    Build the training dataset from a results file, using per-case quotas.

    The examples are shuffled, then each example with a truthy 'are_answer' is
    assigned to the first case (in the order below) that fits the number of
    selected documents and still has room. The quota of a case is
    int(total_examples * percentages[case]).

        case 1: exactly one selected doc -> input: retrieved docs, output: selected
        case 2: more than one            -> input: retrieved docs, output: selected
        case 3: none                     -> input: [],             output: []
        case 4: more than two            -> input: selected,       output: selected
        case 5: more than two            -> input: selected in random order,
                                            output: selected

    Args:
        input_file: JSON list of examples with 'query', 'document_indices',
            'selected_documents' and 'are_answer'.
        output_file: Destination of the generated JSON dataset.
    """
    with open(input_file, "r", encoding="utf-8") as source:
        examples = json.load(source)

    # Shuffle so that the quotas are filled from a random sample
    random.shuffle(examples)

    limits = {case: int(len(examples) * share) for case, share in percentages.items()}
    taken = dict.fromkeys(percentages, 0)
    dataset = []

    def has_room(case):
        return taken[case] < limits[case]

    def emit(case, prompt, docs_in, docs_out):
        dataset.append({
            "input": {
                "query": prompt,
                "retrieved_docs": docs_in,
            },
            "output": docs_out,
        })
        taken[case] += 1

    for example in examples:
        # Stop as soon as no case has room left
        if not any(taken[case] < limits[case] for case in taken):
            break

        query = f"Task Instruction: {task_instruction}\nQuestion:{example['query']}"  # prepend the task instruction
        retrieved_docs = example["document_indices"]
        selected_docs = example["selected_documents"]
        are_answer = example["are_answer"]

        if not are_answer:
            continue

        n_selected = len(selected_docs)
        if n_selected == 1 and has_room("case_1_single_doc"):
            emit("case_1_single_doc", query, retrieved_docs, selected_docs)
        elif n_selected > 1 and has_room("case_2_multiple_docs"):
            emit("case_2_multiple_docs", query, retrieved_docs, selected_docs)
        elif n_selected == 0 and has_room("case_3_no_docs"):
            emit("case_3_no_docs", query, [], [])
        elif n_selected > 2 and has_room("case_4_less_docs"):
            emit("case_4_less_docs", query, selected_docs, selected_docs)
        elif n_selected > 2 and has_room("case_5_reranking"):
            # Random order of the same documents
            emit("case_5_reranking", query, random.sample(selected_docs, n_selected), selected_docs)

    # Save the dataset to a file
    with open(output_file, "w") as sink:
        json.dump(dataset, sink, indent=4)

# Path to input and output files
# NOTE(review): absolute, machine-specific paths; the output file is overwritten.
input_file = r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_ids_document_training_set_bgm\nq_training\gemma-2-2b-it\train\retrieved\contriever\5_doc\numdoc5_retr5_info_all_extended_training_set.json'
output_file = r'C:\Users\franc\Documents\Bridge_the_GAP\data\training_dataset.json'

process_data(input_file, output_file)

In [33]:
import json
import random

# Share allotted to each case. process_data() in this cell applies each share
# to the size of the case's own group (0, 1 or more than 1 selected documents),
# not to the whole input.
# NOTE(review): cases 2, 4 and 5 all draw from the "more than 1" group and
# their shares add up to 1.25, so those quotas cannot all be filled; in the
# recorded run case_5_reranking reached 34 of the planned 65.
percentages = {
    "case_1_single_doc": 0.07,
    "case_2_multiple_docs": 0.4,
    "case_3_no_docs": 0.1,
    "case_4_multi_doc_unchanged": 0.35,
    "case_5_reranking": 0.5,
    "case_6_single_doc_unchanged": 0.05,
}

# Task instruction prepended to every query
task_instruction = "Output only the document IDs relevant to the query. Use this format: [ID1, ID2, ...]."

def _rerank_differently(docs):
    """Return a random permutation of `docs` that differs from `docs` whenever one exists."""
    if all(doc == docs[0] for doc in docs):
        # Every ordering of identical ids is the same list; retrying until the
        # order "differs" would never terminate, so keep the order as it is.
        return docs[:]
    reranked = docs[:]
    while reranked == docs:
        reranked = random.sample(docs, len(docs))
    return reranked


def process_data(input_file, output_file):
    """
    Build the training dataset from the examples whose 'are_answer' is True.

    Valid examples are grouped by the number of selected documents (0, 1, more
    than 1). Each case gets a quota of int(group_size * percentages[case]) and
    every example goes to the first case, in the order below, that matches its
    group and still has room:

        case 1 (1 doc)   : input = retrieved docs, output = selected docs
        case 2 (>1 docs) : input = retrieved docs, output = selected docs
        case 3 (0 docs)  : input = [],             output = []
        case 4 (>1 docs) : input = selected docs,  output = selected docs
        case 6 (1 doc)   : input = selected docs,  output = selected docs
        case 5 (>1 docs) : input = selected docs in a different random order,
                           output = selected docs

    For case 5 the order is left unchanged when all selected ids are equal,
    because no different ordering exists (the earlier retry loop never
    terminated on such input).

    Args:
        input_file: UTF-8 JSON list of examples with 'query',
            'document_indices', 'selected_documents' and 'are_answer'.
        output_file: Destination of the generated dataset (UTF-8 JSON).

    Progress and the planned/actual distribution are printed to stdout.
    """
    with open(input_file, "r", encoding="utf-8") as f:
        examples = json.load(f)

    # Keep only the examples flagged are_answer = true
    valid_examples = [ex for ex in examples if ex["are_answer"] is True]

    print(f"Totale esempi nel file di input: {len(examples)}")
    print(f"Esempi con 'are_answer=True': {len(valid_examples)}")

    # Group sizes by number of selected documents
    group_sizes = {"len_0": 0, "len_1": 0, "len_gt_1": 0}
    for ex in valid_examples:
        n_selected = len(ex["selected_documents"])
        group = "len_0" if n_selected == 0 else "len_1" if n_selected == 1 else "len_gt_1"
        group_sizes[group] += 1

    print(f"Esempi con 'selected_documents == 0': {group_sizes['len_0']}")
    print(f"Esempi con 'selected_documents == 1': {group_sizes['len_1']}")
    print(f"Esempi con 'selected_documents > 1': {group_sizes['len_gt_1']}")

    # Quota of each case, relative to the size of the group it draws from
    case_group = {
        "case_1_single_doc": "len_1",
        "case_2_multiple_docs": "len_gt_1",
        "case_3_no_docs": "len_0",
        "case_4_multi_doc_unchanged": "len_gt_1",
        "case_5_reranking": "len_gt_1",
        "case_6_single_doc_unchanged": "len_1",
    }
    group_case_limits = {
        case: int(group_sizes[group] * percentages[case])
        for case, group in case_group.items()
    }

    print("Distribuzione pianificata degli esempi nel dataset creato:")
    for case, limit in group_case_limits.items():
        print(f"{case}: {limit}")

    dataset = []
    case_counters = {case: 0 for case in group_case_limits}

    def has_room(case):
        return case_counters[case] < group_case_limits[case]

    def emit(case, query, docs_in, docs_out):
        dataset.append({
            "input": {
                "query": query,
                "retrieved_docs": docs_in,
            },
            "output": docs_out,
        })
        case_counters[case] += 1

    # Assign every valid example to the first matching case with room left
    for example in valid_examples:
        query = f"Task Instruction: {task_instruction}\nQuestion:{example['query']}"
        retrieved_docs = example["document_indices"]
        selected_docs = example["selected_documents"]
        n_selected = len(selected_docs)

        if n_selected == 1 and has_room("case_1_single_doc"):
            emit("case_1_single_doc", query, retrieved_docs, selected_docs)
        elif n_selected > 1 and has_room("case_2_multiple_docs"):
            emit("case_2_multiple_docs", query, retrieved_docs, selected_docs)
        elif n_selected == 0 and has_room("case_3_no_docs"):
            emit("case_3_no_docs", query, [], [])
        elif n_selected > 1 and has_room("case_4_multi_doc_unchanged"):
            emit("case_4_multi_doc_unchanged", query, selected_docs, selected_docs)
        elif n_selected == 1 and has_room("case_6_single_doc_unchanged"):
            emit("case_6_single_doc_unchanged", query, selected_docs, selected_docs)
        elif n_selected > 1 and has_room("case_5_reranking"):
            emit("case_5_reranking", query, _rerank_differently(selected_docs), selected_docs)

    print("Esempi effettivamente inclusi nel dataset creato:")
    tot = 0
    for case, count in case_counters.items():
        tot += count
        print(f"{case}: {count}")

    print(f"Totale degli Esempi inclusi nel training dataset creato: {tot}")

    # Save the dataset to a file
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(dataset, f, indent=4)

# Path to input and output files
# NOTE(review): absolute, machine-specific paths; the output file is overwritten.
input_file = r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_ids_document_training_set_bgm\nq_training\gemma-2-2b-it\train\retrieved\contriever\5_doc\numdoc5_retr5_info_all_extended_training_set.json'
output_file = r'C:\Users\franc\Documents\Bridge_the_GAP\data\training_dataset.json'

process_data(input_file, output_file)

Totale esempi nel file di input: 3000
Esempi con 'are_answer=True': 1233
Esempi con 'selected_documents == 0': 366
Esempi con 'selected_documents == 1': 736
Esempi con 'selected_documents > 1': 131
Distribuzione pianificata degli esempi nel dataset creato:
case_1_single_doc: 51
case_2_multiple_docs: 52
case_3_no_docs: 36
case_4_multi_doc_unchanged: 45
case_5_reranking: 65
case_6_single_doc_unchanged: 36
Esempi effettivamente inclusi nel dataset creato:
case_1_single_doc: 51
case_2_multiple_docs: 52
case_3_no_docs: 36
case_4_multi_doc_unchanged: 45
case_5_reranking: 34
case_6_single_doc_unchanged: 36
Totale degli Esempi inclusi nel training dataset creato: 254


Test del BGM Model

In [None]:
import os
import re 
import argparse
import warnings
from tqdm import tqdm
from typing import Tuple, Dict, Optional, Union, List

import torch
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer, AutoModelForCausalLM
from trl import setup_chat_format

from utils import *
from bgm import BGM
from default_prompts import *
from prompt_dataset import PromptDataset
from datasets import Dataset

# Set TOKENIZERS_PARALLELISM to "false" in the process environment.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Use the first CUDA device when CUDA is available, otherwise fall back to the CPU.
device = torch.device(f"cuda:0" if torch.cuda.is_available() else "cpu")
# Suppress every warning category from here on.
warnings.filterwarnings('ignore')
SEED=10

# Input files per dataset and split; read as info[args.dataset][args.split][...].
# Only the "test" split of "nq" is registered here.
# NOTE(review): the paths are absolute and machine-specific.
info = {
    "nq": {
        "test": {
            "data_path": r'C:\Users\franc\Documents\Bridge_the_GAP\data\test_dataset.json',
            "contriever_search_results_path": r"C:\Users\franc\Documents\Bridge_the_GAP\data\processed\contriever_test_search_results_at150.pkl",
        }
    },
}

class DotDict:
    """Simple settings holder: each keyword argument is stored as an attribute."""

    def __init__(self, **kwargs):
        for option_name, option_value in kwargs.items():
            setattr(self, option_name, option_value)

def parse_arguments(custom_args=None):
    """
    Build the configuration for the BGM test run without argparse.

    Args:
        custom_args: Optional mapping of overrides applied on top of the
            built-in defaults; keys that are not among the defaults are added
            unchanged.

    Returns:
        DotDict exposing every setting as an attribute.

    Raises:
        ValueError: if 'num_retrieved_documents' is None or not positive, or if
            'gold_position' is set but lies outside
            [0, num_retrieved_documents).
    """
    settings = dict(
        output_dir=r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_id_res_example_bgm',
        llm_id='meta-llama/Llama-3.2-1B',
        dataset='nq',
        task_instruction="Output only the document IDs relevant to the query. Use this format: [Id_1, Id_2, ...].",
        model_max_length=4096,
        quantization_bits=4,
        gold_position=None,
        use_model_chat_template=False,
        num_retrieved_documents=5,
        use_test=True,
        padding_strategy='longest',
        max_new_tokens=15,
        use_task_with_proof=False,
        batch_size=None,
        save_every=250,
    )

    # An empty or missing override mapping leaves the defaults untouched.
    settings.update(custom_args or {})

    num_docs = settings['num_retrieved_documents']
    if num_docs is None:
        raise ValueError("'num_retrieved_documents' must be specified.")
    if num_docs <= 0:
        raise ValueError("'num_retrieved_documents' must be a positive integer.")

    gold = settings['gold_position']
    if gold is not None and (gold < 0 or gold >= num_docs):
        raise ValueError("'gold_position' must be within the range of 'num_retrieved_documents'.")

    return DotDict(**settings)


def load_corpus(
    args: argparse.Namespace
) -> Tuple[List[Dict], Optional[Dict[int, int]]]:
    """
    Load the test corpus through read_test_corpus_with_random_and_contriever().

    Args:
        args: Run configuration; accepted for signature parity, not read here.

    Returns:
        The (corpus, full_to_subset_idx_map) pair produced by the reader.
    """
    loaded = read_test_corpus_with_random_and_contriever()
    documents, index_map = loaded
    return documents, index_map

def load_search_results(args: argparse.Namespace) -> List[Tuple[List[int], List[float]]]:
    """
    Read the retriever search results registered for the active dataset/split.

    The pickle path is looked up in the module-level `info` mapping under
    info[args.dataset][args.split]['contriever_search_results_path'] and loaded
    with read_pickle().
    """
    split_paths = info[args.dataset][args.split]
    return read_pickle(split_paths['contriever_search_results_path'])


def get_prompt_template(args: argparse.Namespace):
    """
    Return the prompt template for the configured dataset.

    Without a model chat template, the template comes from
    task_templates[args.dataset].create_prompt_template(). With one, the chat
    template string registered for args.llm_id is combined with the task
    instruction for args.dataset via apply_chat_task_template().
    """
    config_name = args.dataset

    if not args.use_model_chat_template:
        plain_template = task_templates[config_name]
        return plain_template.create_prompt_template()

    chat_template = chat_task_templates[args.llm_id]['template']
    instruction = task_instructions[config_name]
    return apply_chat_task_template(chat_template, instruction)

def process_dataset(dataset, task_instruction):
    """
    Wrap an instruction and a payload into a two-message chat list.

    Args:
        dataset: Value placed, unchanged, as the content of the "user" message.
        task_instruction: Value placed as the content of the "system" message.

    Returns:
        List[Dict]: [system message, user message], each a dict with the keys
        "role" and "content". No chat template is applied here.
    """
    system_turn = {"role": "system", "content": task_instruction}
    user_turn = {"role": "user", "content": dataset}
    return [system_turn, user_turn]

def initialize_dataset_and_loader(
    args: argparse.Namespace, 
    corpus: List[Dict], 
    full_to_subset_idx_map: Optional[Dict[int, int]], 
    retriever_search_results: List[Tuple[List[int], List[float]]], 
    tokenizer: PreTrainedTokenizer,
) -> DataLoader:
    """
    Build the `PromptDataset` for the configured split and wrap it in a
    `DataLoader`.

    Args:
        args (argparse.Namespace): Must expose `dataset`, `split`,
            `model_max_length`, `num_retrieved_documents`, `gold_position` and
            `batch_size`, plus whatever `get_prompt_template` reads.
        corpus (List[Dict]): Passed to the dataset as `corpus`.
        full_to_subset_idx_map (Optional[Dict[int, int]]): Passed to the
            dataset as `full_to_subset_idx_map`.
        retriever_search_results: Passed to the dataset as `search_results`.
        tokenizer (PreTrainedTokenizer): Passed to the dataset as `tokenizer`.

    Returns:
        DataLoader: Non-shuffling loader over the prompt dataset, using 8
            worker processes and pinned memory.
    """
    template = get_prompt_template(args)

    dataset_kwargs = dict(
        corpus=corpus,
        data_path=info[args.dataset][args.split]['data_path'],
        tokenizer=tokenizer,
        # Two tokens are held back from the model's maximum length.
        max_tokenized_length=args.model_max_length - 2,
        search_results=retriever_search_results,
        prompt_template=template,
        full_to_subset_idx_map=full_to_subset_idx_map,
        do_normalize_query=True,
        num_documents_in_context=args.num_retrieved_documents,
        gold_position=args.gold_position,  # None in these experiments
    )
    prompt_dataset = PromptDataset(**dataset_kwargs)

    return DataLoader(
        prompt_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
    )


def check_documents_before_model(
    document_indices: List[str],
    answers: List[str],
    prompt: str
) -> Tuple[bool, List[str]]:
    """
    Check whether any document passed to the model contains one of the answers.

    Documents are parsed out of `prompt`, where each one is expected to look
    like `Document [<digits>](<anything>) <text>`. Matching is a
    case-insensitive substring test.

    Args:
        document_indices (List[str]): IDs of the documents passed to the model.
            Each element is converted with `str()` and stripped before lookup.
        answers (List[str]): Answers to search for.
        prompt (str): Full prompt text.

    Returns:
        Tuple[bool, List[str]]:
            - True if at least one listed document contains an answer.
            - IDs of the documents in which an answer was found, in the order
              of `document_indices` (empty when nothing matched).
    """
    doc_pattern = r"Document \[(\d+)\]\(.*?\)\s(.*?)(?=Document \[\d+\]|$)"
    text_by_id = {
        doc_id: body.strip()
        for doc_id, body in re.findall(doc_pattern, prompt, re.DOTALL)
    }

    wanted_ids = [str(doc).strip() for doc in document_indices]

    # `any` stops at the first matching answer, so each ID is listed at most
    # once per occurrence in `wanted_ids`.
    matching_ids = [
        doc_id
        for doc_id in wanted_ids
        if doc_id in text_by_id
        and any(answer.lower() in text_by_id[doc_id].lower() for answer in answers)
    ]

    return bool(matching_ids), matching_ids


def check_document_with_answer(
    original_ids: str,
    answers: List[str],
    prompt: str
) -> Tuple[bool, Optional[str], List[str], List[str]]:
    """
    Check whether any of the documents named in `original_ids` contains one of
    the answers.

    Documents are parsed out of `prompt`, where each one is expected to look
    like `Document [<digits>](<anything>) <text>`. Matching is a
    case-insensitive substring test.

    Args:
        original_ids (str): Comma-separated document IDs. Entries are stripped
            and empty entries are ignored; IDs not present in the prompt are
            simply skipped.
        answers (List[str]): Answers to search for.
        prompt (str): Full prompt text.

    Returns:
        Tuple[bool, Optional[str], List[str], List[str]]:
            - True if at least one answer was found.
            - The first answer found, or None.
            - IDs of the documents containing an answer; each ID appears once.
            - All IDs parsed from `original_ids` (duplicates preserved).
    """
    # Extract the documents from the prompt.
    documents = re.findall(r"Document \[(\d+)\]\(.*?\)\s(.*?)(?=Document \[\d+\]|$)", prompt, re.DOTALL)
    documents_dict = {doc_id: text.strip() for doc_id, text in documents}

    # Split and clean the IDs.
    ids_to_check = [id_.strip() for id_ in original_ids.split(",") if id_.strip()]

    documents_with_answer = []
    first_answer = None

    for id_doc in ids_to_check:
        # Skip unknown IDs, and IDs already recorded (the model may repeat an ID).
        if id_doc not in documents_dict or id_doc in documents_with_answer:
            continue

        document_text = documents_dict[id_doc].lower()
        for answer in answers:
            if answer.lower() in document_text:
                documents_with_answer.append(id_doc)
                if first_answer is None:
                    first_answer = answer
                # One hit is enough: without this, a document matching several
                # answers was appended once per matching answer.
                break

    if documents_with_answer:
        return True, first_answer, documents_with_answer, ids_to_check

    # No answer found in any of the listed documents.
    return False, None, [], ids_to_check

    
def map_document_indices(prompt: str, document_indices: list) -> tuple:
    """
    Rewrite the prompt so that the original document IDs become sequential IDs.

    Every `[<original_id>]` occurrence in the prompt is replaced by `[Id_n]`,
    where `n` is the 1-based position of that ID in `document_indices`.

    Args:
        prompt (str): The original prompt containing the document IDs.
        document_indices (list): IDs of the documents in the prompt. Each is
            converted with `str()`; if an ID occurs more than once, its last
            position wins.

    Returns:
        tuple: The modified prompt and the mapping `{original_id: new_id}`
            (keys are strings). With no document indices the prompt is
            returned unchanged together with an empty mapping.
    """
    id_mapping = {
        str(doc_id): f"Id_{position}"
        for position, doc_id in enumerate(document_indices, start=1)
    }

    # Nothing to map: an empty alternation would match "[]" and fail the lookup.
    if not id_mapping:
        return prompt, id_mapping

    # A single pass over the prompt. IDs are escaped so that regex
    # metacharacters in an ID are matched literally, and replaced text is
    # never re-scanned by a later substitution.
    alternatives = "|".join(re.escape(original_id) for original_id in id_mapping)
    bracketed_id = re.compile(rf"\[({alternatives})\]")
    modified_prompt = bracketed_id.sub(
        lambda match: f"[{id_mapping[match.group(1)]}]", prompt
    )

    return modified_prompt, id_mapping

def extract_and_convert_answer_indices(generated_output: str, id_mapping: dict) -> str:
    """
    Extract the IDs the model generated after '<|im_start|>assistant' and
    convert them back to the original document IDs.

    The answer is the text between the first '<|im_start|>assistant' marker and
    the next '<|im_end|>' (or the end of the text). It is split on commas; each
    token is stripped of whitespace and of surrounding square brackets, so both
    `Id_1, Id_2` and `[Id_1, Id_2]` are accepted.

    Args:
        generated_output (str): Text containing the model's generated answer.
        id_mapping (dict): Mapping `{original_id: Id_n}`.

    Returns:
        str: The original IDs joined by commas. A generated ID missing from the
            mapping is rendered as 'Unknown(<id>)'. If the assistant marker is
            absent, the sentinel string 'Nessuna risposta trovata' is returned.
    """
    # Invert the mapping to obtain {Id_n: original_id}.
    inverse_mapping = {v: k for k, v in id_mapping.items()}

    match = re.search(r'<\|im_start\|>assistant\s*(.*)', generated_output, re.DOTALL)
    if not match:
        return "Nessuna risposta trovata"

    # Keep only what precedes the end-of-turn marker.
    answer_string = match.group(1).strip().split("<|im_end|>")[0].strip()

    # The instructed output format wraps the list in brackets; without removing
    # them the first and last IDs would never be found in the mapping.
    cleaned_tokens = (
        token.strip().strip("[]").strip() for token in answer_string.split(",")
    )
    generated_ids = [token for token in cleaned_tokens if token]
    original_ids = [inverse_mapping.get(id_, f"Unknown({id_})") for id_ in generated_ids]

    return ",".join(original_ids)
    

def print_info(args: argparse.Namespace):
    """
    Print a summary of the run configuration.

    Args:
        args (argparse.Namespace): Must expose `dataset`, `task_instruction`,
            `use_test`, `llm_id`, `model_max_length`, `max_new_tokens`,
            `use_model_chat_template`, `use_task_with_proof`, `gold_position`,
            `num_retrieved_documents`, `batch_size` and `save_every`.
    """
    # Report the data file of the split that is actually used: previously the
    # 'test' path was printed even when `use_test` was False.
    split = "test" if args.use_test else "train"

    print("INFO:")
    print(f"DATA: {info[args.dataset][split]['data_path']}")
    print(f"TASK INSTRUCTION: {args.task_instruction}")
    print(f"USE TEST: {args.use_test}")
    print(f"MODEL: {args.llm_id}")
    print(f"MODEL MAX LENGTH: {args.model_max_length}")
    print(f"MAX NEW TOKENS: {args.max_new_tokens}")
    print(f"USE MODEL CHAT TEMPLATE: {args.use_model_chat_template}")
    print(f"TASK WITH PROOF: {args.use_task_with_proof}")
    print(f"GOLD POSITION: {args.gold_position}")
    print(f"NUM DOCUMENTS IN CONTEXT: {args.num_retrieved_documents}")
    print(f"BATCH SIZE: {args.batch_size}")
    print(f"SAVE EVERY: {args.save_every}")


def generate_and_save(
    args: argparse.Namespace, 
    model_weights_path, 
    tokenizer,
    dataset,
    num_examples=10,
    max_length=50
):
    """
    Run the fine-tuned model on the first `num_examples` prompts and save, for
    each one, which documents it selected and whether they contain an answer.

    For every example the document IDs in the prompt are replaced by sequential
    `Id_n` labels, the prompt is rendered with the tokenizer's chat template,
    the model generates with beam search (5 beams), and the generated labels
    are mapped back to the original IDs and checked against the answers.

    Args:
        args (argparse.Namespace): Must expose `llm_id`,
            `num_retrieved_documents`, `save_every`, `output_dir`, `dataset`,
            `split` and `task_instruction`.
        model_weights_path: Passed to `AutoModelForCausalLM.from_pretrained`.
        tokenizer: Used for `apply_chat_template`, tokenization and decoding.
        dataset: Iterable yielding items with the keys `'prompt'`,
            `'document_indices'` and `'answers'`.
        num_examples (int): Maximum number of items to process.
        max_length (int): Passed to `generate` as `max_new_tokens`.

    Side effects:
        Writes `generated_results_with_best_weights.json` under
        `{output_dir}/{dataset}/{llm_folder}/{split}/{num_doc}_doc`. The file
        is rewritten with the results collected so far every `save_every`
        examples (when `save_every` is a positive number) and once more at the
        end, so an interrupted run keeps its last checkpoint.
    """
    # Info from arguments
    llm_id = args.llm_id
    num_doc = args.num_retrieved_documents
    save_every = args.save_every 

    # Create the saving directory
    llm_folder = llm_id.split("/")[1] if '/' in llm_id else llm_id
    saving_dir = f"{args.output_dir}/{args.dataset}/{llm_folder}/{args.split}/{num_doc}_doc"
    os.makedirs(saving_dir, exist_ok=True)

    # Path of the .json output file
    json_file_path = os.path.join(saving_dir, "generated_results_with_best_weights.json")

    def _write_results(results):
        # Overwrite the output file with everything collected so far.
        with open(json_file_path, "w", encoding="utf-8") as json_file:
            json.dump(results, json_file, indent=4, ensure_ascii=False)

    # Load the trained model
    model = AutoModelForCausalLM.from_pretrained(model_weights_path)
    model.to(device)
    model.eval()

    all_info = []  
    for idx, prompt_batch in enumerate(tqdm(dataset)):
        if idx >= num_examples:
            break

        # Drop the default instruction sentence; the task instruction is
        # supplied separately as the system message below.
        prompt = prompt_batch['prompt'].replace('You are given a question and you must respond based on the provided documents. You must always provide an answer.', "")
        document_indices = prompt_batch['document_indices']
        answers = prompt_batch['answers']

        # Map the document IDs in the prompt to sequential labels
        modified_prompt, id_mapping = map_document_indices(prompt, document_indices)
        print(f"Processing Example {idx+1}:\n")
        print("Mappatura ID:", id_mapping)

        # Check whether the documents passed to the model contain an answer at all
        has_answer_before, documents_with_answer_before = check_documents_before_model(
            document_indices, answers, prompt
        )

        if has_answer_before:
            print(f"✅ I documenti passati al modello contengono risposte. ID documenti: {documents_with_answer_before}")
        else:
            print("❌ I documenti passati al modello NON contengono alcuna risposta.")

        prompt_formatted = process_dataset(modified_prompt, args.task_instruction)
        prompt_formatted = tokenizer.apply_chat_template(
            prompt_formatted, tokenize=False, add_generation_prompt=False, add_special_tokens=False
        )

        # Tokenize the input
        inputs = tokenizer(prompt_formatted, return_tensors="pt", truncation=False).to(device)

        # Generate the model's response
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=max_length, num_beams=5)

        # Decode the generated response; special tokens are kept because the
        # answer is located via the '<|im_start|>assistant' marker.
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=False)

        # Convert the generated labels back to the original document IDs
        original_ids = extract_and_convert_answer_indices(generated_text, id_mapping)
        print(f"I migliori indici secondo il modello: {original_ids}\n")

        # Check whether the selected documents contain an answer
        has_answer_after, found_answer, documents_with_answer_after, _all_generated_ids = check_document_with_answer(
            original_ids, answers, prompt
        )

        if has_answer_after:
            print(f"✅ Risposta trovata nel documento {documents_with_answer_after}")
            print(f"📌 Prima risposta trovata: '{found_answer}'")
        else:
            print("❌ Nessuna risposta trovata in nessuno dei documenti generati.")
            print(f"La risposta era {answers}")

        result = {
            "prompt": prompt,
            "all_document_indices": document_indices,
            "generated_indices": original_ids,
            "has_answer_in_passed_documents": has_answer_before,
            "documents_with_answer_in_passed_documents": documents_with_answer_before,
            "generated_id_document_has_answer": has_answer_after,
            "answer_in_the_document": found_answer,
            "documents_with_answer_in_generated": documents_with_answer_after,
            "answers_target": answers
        }
        all_info.append(result)

        # Periodic checkpoint: `save_every` used to be read but ignored, so an
        # interrupted run lost every result.
        if save_every and len(all_info) % save_every == 0:
            _write_results(all_info)

    _write_results(all_info)

    print(f"Risultati salvati in: {json_file_path}")


def main(model_weights_path: Optional[str] = None):
    """
    Run the whole pipeline: parse arguments, load the tokenizer, corpus, search
    results and prompt dataset, then generate with the fine-tuned checkpoint
    and save the results.

    Args:
        model_weights_path (Optional[str]): Directory of the fine-tuned
            checkpoint to evaluate. When None, the checkpoint path previously
            hard-coded in this function is used.
    """
    args = parse_arguments()

    args.split = "test" if args.use_test else "train"

    print("Loading LLM...")
    bgm = BGM(
        args.llm_id, device, 
        quantization_bits=args.quantization_bits, 
        model_max_length=args.model_max_length,
    )
    # Only the tokenizer returned here is used afterwards; the weights used for
    # generation are loaded from the checkpoint inside generate_and_save.
    _, tokenizer = setup_chat_format(bgm.model, bgm.tokenizer)
    print("LLM loaded")

    print("Loading corpus and search results...")
    corpus, full_to_subset_idx_map = load_corpus(args)
    retriever_search_results = load_search_results(args)
    print("Corpus and search results loaded")

    print("Loading prompt dataset...")
    prompt_dataloader = initialize_dataset_and_loader(
        args, corpus, full_to_subset_idx_map, 
        retriever_search_results, tokenizer
    )
    print("Prompt dataset loaded")

    print_info(args)

    if model_weights_path is None:
        # Best checkpoint was between 700 and 800.
        model_weights_path = r'C:\Users\franc\Documents\Bridge_the_GAP\data\SFT_training_bgm\meta-llama-Llama-3.2-1B\checkpoint-800'

    # Generate with the same token budget that print_info reports, instead of a
    # separate hard-coded value that could silently diverge from the arguments.
    generate_and_save(
        args, model_weights_path, tokenizer, prompt_dataloader,
        num_examples=100, max_length=args.max_new_tokens,
    )



# Entry point: runs when this code is executed directly rather than imported.
# `seed_everything` and `SEED` are defined elsewhere in the file; the call is
# made before `main()` - presumably to seed the random number generators so
# runs are repeatable (confirm against its definition).
if __name__ == "__main__":
    seed_everything(SEED)
    main()

Loading LLM...
LLM loaded
Loading corpus and search results...
Corpus and search results loaded
Loading prompt dataset...
Prompt dataset loaded
INFO:
DATA: C:\Users\franc\Documents\Bridge_the_GAP\data\test_dataset.json
TASK INSTRUCTION: Output only the document IDs relevant to the query. Use this format: [Id_1, Id_2, ...].
USE TEST: True
MODEL: meta-llama/Llama-3.2-1B
MODEL MAX LENGTH: 4096
MAX NEW TOKENS: 15
USE MODEL CHAT TEMPLATE: False
TASK WITH PROOF: False
GOLD POSITION: None
NUM DOCUMENTS IN CONTEXT: 5
BATCH SIZE: None
SAVE EVERY: 250


  0%|          | 0/2889 [00:53<?, ?it/s]


KeyboardInterrupt: 