In [1]:
import torch

print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))
print(torch.cuda.get_device_name(1))

True
2
NVIDIA GeForce RTX 3080
NVIDIA GeForce RTX 3060


In [3]:
from utils import *

def analyze_json_responses(my_path):
    data = read_json(my_path)  # Legge il file JSON

    total_examples = len(data)  # Numero totale di esempi
    correct_with_context = 0  # Risposte corrette con risposta nel contesto
    correct_without_context = 0  # Risposte corrette senza risposta nel contesto
    has_answer_in_context = 0  # Esempi con risposta nel contesto
    no_answer_in_context = 0  # Esempi senza risposta nel contesto
    total_correct = 0  # Totale delle risposte corrette

    for entry in data:
        ans_in_documents = entry.get("ans_in_documents", False)
        ans_match_after_norm = entry.get("ans_match_after_norm", False)

        if ans_in_documents:
            has_answer_in_context += 1
            if ans_match_after_norm:
                correct_with_context += 1
        else:
            no_answer_in_context += 1
            if ans_match_after_norm:
                correct_without_context += 1

        # Conta ogni risposta corretta
        if ans_match_after_norm:
            total_correct += 1

    # Calcola le medie
    avg_correct_with_context = correct_with_context / has_answer_in_context if has_answer_in_context > 0 else 0
    avg_correct_without_context = correct_without_context / no_answer_in_context if no_answer_in_context > 0 else 0
    overall_accuracy = total_correct / total_examples if total_examples > 0 else 0

    return {
        "total_examples": total_examples,
        "examples_with_answer_in_context": has_answer_in_context,
        "examples_without_answer_in_context": no_answer_in_context,
        "correct_with_context": correct_with_context,
        "correct_without_context": correct_without_context,
        "average_correct_with_context": avg_correct_with_context,
        "average_correct_without_context": avg_correct_without_context,
        "overall_accuracy": overall_accuracy
    }


# Esempio di utilizzo
path = r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_res_example_llm\nq\gemma-2-2b-it\test\retrieved\contriever\1_doc\numdoc1_retr1_template_info_all_extended.json'
result = analyze_json_responses(path)
print(result)

{'total_examples': 2889, 'examples_with_answer_in_context': 721, 'examples_without_answer_in_context': 2168, 'correct_with_context': 539, 'correct_without_context': 87, 'average_correct_with_context': 0.7475728155339806, 'average_correct_without_context': 0.04012915129151291, 'overall_accuracy': 0.2166839736933195}


In [4]:
from utils import *

path=r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_res_example_llm\nq\gemma-2-2b-it\test\retrieved\contriever\5_doc\numdoc5_retr5_template_info_250.pkl'

data_results=read_pickle(path)

data_results

[{'example_id': '-3290814144789249484',
  'query': 'who got the first nobel prize in physics',
  'prompt': '<start_of_turn>user\nYou are given a question and you must respond based on the provided documents. You must always provide an answer.\nDocuments:\nDocument [628506](Title: Nobel Prize in Physics) receive a diploma, a medal and a document confirming the prize amount. Nobel Prize in Physics The Nobel Prize in Physics () is a yearly award given by the Royal Swedish Academy of Sciences for those who have made the most outstanding contributions for mankind in the field of physics. It is one of the five Nobel Prizes established by the will of Alfred Nobel in 1895 and awarded since 1901; the others being the Nobel Prize in Chemistry, Nobel Prize in Literature, Nobel Peace Prize, and Nobel Prize in Physiology or Medicine. The first Nobel Prize in Physics was\nDocument [3546609](Title: E. C. George Sudarshan) had developed the breakthrough. In 2007, Sudarshan told the "Hindustan Times", 

In [None]:
import os
import argparse
import warnings
import re
from tqdm import tqdm
from typing import Tuple, Dict, Optional

import torch
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer

from utils import *
from bgm import BGM
from default_prompts import *
from prompt_dataset import PromptDataset

os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = torch.device(f"cuda:0" if torch.cuda.is_available() else "cpu")
warnings.filterwarnings('ignore')
SEED=10

info = {
    "nq_bgm": {
        "test": {
            "data_path": r'C:\Users\franc\Documents\Bridge_the_GAP\data\test_dataset.json',
            "contriever_search_results_path": r"C:\Users\franc\Documents\Bridge_the_GAP\data\processed\contriever_test_search_results_at150.pkl",
        }
    },
}

def save_dataloader_to_json(dataloader, output_file, num_examples=15):
    all_batches = []

    print("Saving DataLoader contents to JSON...")
    for idx, batch in enumerate(dataloader):
        if idx >= num_examples:  # Stop after saving the specified number of examples
            break

        batch_dict = {}
        for key, value in batch.items():
            # Convert tensors to lists for JSON serialization
            if isinstance(value, torch.Tensor):
                batch_dict[key] = value.tolist()
            else:
                batch_dict[key] = value
        all_batches.append(batch_dict)
    
    # Save the entire list of dictionaries to a JSON file
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(all_batches, f, ensure_ascii=False, indent=4)

    print(f"DataLoader contents saved to {output_file}")

class DotDict:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

def parse_arguments(custom_args=None):
    """
    Mimics argparse to parse arguments for LLM generation. Accepts custom arguments as a dictionary for notebooks.
    """
    # Define default values
    default_args = {
        'output_dir': r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_id_document_bgm',
        'llm_id': 'google/flan-t5-large',
        'dataset': 'nq_bgm',
        'model_max_length': 4096,
        'quantization_bits': 4,
        'use_model_chat_template': False, 
        'gold_position': None,
        'num_retrieved_documents': 3,
        'use_test': True,
        'max_new_tokens': 50,
        'use_task_with_proof': False,
        'batch_size': None,
        'save_every': 250,
    }

    # If custom_args is provided, update defaults
    if custom_args:
        default_args.update(custom_args)

    # Perform validation
    if default_args['num_retrieved_documents'] is None:
        raise ValueError("'num_retrieved_documents' must be specified.")
    if default_args['num_retrieved_documents'] <= 0:
        raise ValueError("'num_retrieved_documents' must be a positive integer.")
    if default_args['gold_position'] is not None:
        if (default_args['gold_position'] < 0 or 
            default_args['gold_position'] >= default_args['num_retrieved_documents']):
            raise ValueError("'gold_position' must be within the range of 'num_retrieved_documents'.")

    return DotDict(**default_args)


def load_corpus(
    args: argparse.Namespace
) -> Tuple[List[Dict], Optional[Dict[int, int]]]:
    
    # Corpus with documents from Contriever
    corpus, full_to_subset_idx_map = read_test_corpus_with_random_and_contriever()

    return corpus, full_to_subset_idx_map

def load_search_results(args: argparse.Namespace) -> List[Tuple[List[int], List[float]]]:

    search_results_path = info[args.dataset][args.split]['contriever_search_results_path']
    retriever_search_results = read_pickle(search_results_path)

    return retriever_search_results


def get_prompt_template(args: argparse.Namespace):
    prompt_configuration = args.dataset
    if args.use_model_chat_template:
        chat_task_template_str = chat_task_templates[args.llm_id]['template']
        
        task_instruction = task_instructions[prompt_configuration]
        if args.use_task_with_proof:
            task_instruction = task_instructions['qa_proof'][prompt_configuration]

        prompt_template = apply_chat_task_template(chat_task_template_str, task_instruction)
    else:
        task_template = task_templates[prompt_configuration]

        if args.use_task_with_proof:
            task_template = task_templates['qa_proof'][prompt_configuration]

        prompt_template = task_template.create_prompt_template()

    return prompt_template


def initialize_dataset_and_loader(
    args: argparse.Namespace, 
    corpus: List[Dict], 
    full_to_subset_idx_map: Optional[Dict[int, int]], 
    retriever_search_results: List[Tuple[List[int], List[float]]], 
    tokenizer: PreTrainedTokenizer
) -> DataLoader:
    
    prompt_template = get_prompt_template(args)
    
    prompt_ds = PromptDataset(
        corpus=corpus, data_path=info[args.dataset][args.split]['data_path'], 
        tokenizer=tokenizer, 
        max_tokenized_length=args.model_max_length - 2, 
        search_results=retriever_search_results,
        prompt_template=prompt_template,
        full_to_subset_idx_map=full_to_subset_idx_map,
        do_normalize_query=True, 
        num_documents_in_context=args.num_retrieved_documents,
        gold_position=args.gold_position, # None in these experiments
    )
        
    prompt_dataloader = DataLoader(
        prompt_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
    )
    return prompt_dataloader


def print_info(args: argparse.Namespace):
    print("INFO:")    
    print(f"DATA: {info[args.dataset]['test']['data_path']}")
    print(f"USE TEST: {args.use_test}")
    print(f"MODEL: {args.llm_id}")
    print(f"MODEL MAX LENGTH: {args.model_max_length}")
    print(f'MAX NEW TOKENS: {args.max_new_tokens}')
    print(f"USE MODEL CHAT TEMPLATE: {args.use_model_chat_template}")
    print(f"TASK WITH PROOF:", args.use_task_with_proof)
    print(f"GOLD POSITION: {args.gold_position}")
    print(f"NUM DOCUMENTS IN CONTEXT: {args.num_retrieved_documents}")
    print(f"BATCH SIZE: {args.batch_size}")
    print(f"SAVE EVERY: {args.save_every}")


def extract_generate_answers(
    args: argparse.Namespace, 
    generated_output: List[str]
) -> List[str]:
    answer_prefix = "Answer:"
    if args.use_model_chat_template:
        answer_prefix = re.escape(chat_task_templates[args.llm_id]['answer_prefix'])

    generated_answers = []
    for output in generated_output:
        matches = list(re.finditer(answer_prefix, output))
        match_idx = 0

        # When using the proof there is a one-shot example that already 
        # contains the string "Answer:". Thus, we should get the second (match_idx=1) match.
        if args.use_task_with_proof:
            match_idx = 1
            if args.use_model_chat_template and answer_prefix != "Answer:":
                match_idx = 0
 
        answer_end = matches[match_idx].end()
        response = output[answer_end:].strip()
        generated_answers.append(response)
    
    return generated_answers


def generate_and_save(
    args: argparse.Namespace, 
    llm: BGM, 
    prompt_dataloader: DataLoader
):
    # Info from arguments
    llm_id = args.llm_id
    num_doc = args.num_retrieved_documents
    save_every = args.save_every
    retriever_str = "contriever" 
    chat_template_str = "_template" if args.use_model_chat_template else ""
    prompt_type = "retrieved_proof" if args.use_task_with_proof else "retrieved"

    # Create the saving directory
    llm_folder = llm_id.split("/")[1] if '/' in llm_id else llm_id
    saving_dir = f"{args.output_dir}/{args.dataset}/{llm_folder}/{args.split}/{prompt_type}/{retriever_str}/{num_doc}_doc"
    os.makedirs(saving_dir, exist_ok=True)

    all_info = []  
    for idx, prompt_batch in enumerate(tqdm(prompt_dataloader)):
        prompts = prompt_batch['prompt']
        generated_output = llm.generate(
            prompts, 
            max_new_tokens=args.max_new_tokens
        )

        generated_answers = extract_generate_answers(args, generated_output)
        prompt_batch['generated_answer'] = generated_answers
        all_info.append(prompt_batch)
        
        if (idx + 1) % save_every == 0 or (idx + 1) == len(prompt_dataloader):
            print(f"Saving at {idx + 1}...")
            file_name = f"{saving_dir}/numdoc{num_doc}_retr{args.num_retrieved_documents}{chat_template_str}_info_{idx+1}.pkl"
            write_pickle(all_info, file_name)
            all_info = []


def main():
    args = parse_arguments()

    args.split = "test" if args.use_test else "train"

    print("Loading LLM...")
    llm_id = args.llm_id
    bgm = BGM(
        llm_id, device,  
        model_max_length=args.model_max_length
    )
    tokenizer = bgm.tokenizer
    print("LLM loaded")


    print("Loading corpus and search results...")
    corpus, full_to_subset_idx_map = load_corpus(args)
    retriever_search_results = load_search_results(args)
    print("Corpus and search results loaded")


    print("Loading prompt dataset...")
    prompt_dataloader = initialize_dataset_and_loader(
        args, corpus, full_to_subset_idx_map, 
        retriever_search_results, tokenizer
    )
    print("Prompt dataset loaded")

    print_info(args)

    output_json_path = r'C:\Users\franc\Documents\Bridge_the_GAP\data\dataloader_contents.json'
    save_dataloader_to_json(prompt_dataloader, output_json_path, num_examples=15)
        
    #generate_and_save(args, llm, prompt_dataloader)



if __name__ == "__main__":
    seed_everything(SEED)
    main()

Loading LLM...
LLM loaded
Loading corpus and search results...
Corpus and search results loaded
Loading prompt dataset...
nq_bgm
<default_prompts.TaskTemplate object at 0x00000176F1E2B850>
Prompt dataset loaded
INFO:
DATA: C:\Users\franc\Documents\Bridge_the_GAP\data\test_dataset.json
USE TEST: True
MODEL: google/flan-t5-large
MODEL MAX LENGTH: 4096
MAX NEW TOKENS: 50
USE MODEL CHAT TEMPLATE: False
TASK WITH PROOF: False
GOLD POSITION: None
NUM DOCUMENTS IN CONTEXT: 3
BATCH SIZE: None
SAVE EVERY: 250
Saving DataLoader contents to JSON...


KeyboardInterrupt: 

In [10]:
import os
import argparse
import warnings
import re
import pandas as pd
from tqdm import tqdm
from typing import Tuple, Dict, Optional

import torch
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer

from utils import *
from llm import LLM
from default_prompts import *
from normalize_answers import *
from prompt_dataset import PromptDataset
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = torch.device(f"cuda:0" if torch.cuda.is_available() else "cpu")
warnings.filterwarnings('ignore')
SEED=10

info = {

    "nq": {
        "test": {
            "data_path": r'C:\Users\franc\Documents\Bridge_the_GAP\data\test_dataset.json',
            "contriever_search_results_path": r"C:\Users\franc\Documents\Bridge_the_GAP\data\processed\contriever_test_search_results_at150.pkl",
        },
    },
    "nq_training":{
        "train": {
            "data_path": r"C:\Users\franc\Documents\Bridge_the_GAP\data\10k_train_dataset.json",
            "contriever_search_results_path": r"C:\Users\franc\Documents\Bridge_the_GAP\data\processed\contriever_search_results_at150.pkl",
        }
    }
}



def save_dataloader_to_json(dataloader, output_file, num_examples=15):
    all_batches = []

    print("Saving DataLoader contents to JSON...")
    for idx, batch in enumerate(dataloader):
        if idx >= num_examples:  # Stop after saving the specified number of examples
            break

        batch_dict = {}
        for key, value in batch.items():
            # Convert tensors to lists for JSON serialization
            if isinstance(value, torch.Tensor):
                batch_dict[key] = value.tolist()
            else:
                batch_dict[key] = value
        all_batches.append(batch_dict)
    
    # Save the entire list of dictionaries to a JSON file
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(all_batches, f, ensure_ascii=False, indent=4)

    print(f"DataLoader contents saved to {output_file}")

class DotDict:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

def parse_arguments(custom_args=None):
    """
    Mimics argparse to parse arguments for LLM generation. Accepts custom arguments as a dictionary for notebooks.
    """
    # Define default values
    default_args = {
        'output_dir': r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_id_document_bgm',
        'llm_id': 'google/gemma-2-2b-it',
        'dataset': 'nq_training',
        'model_max_length': 4096,
        'quantization_bits': 4,
        'use_model_chat_template': False, 
        'gold_position': None,
        'num_retrieved_documents': 5,
        'use_test': False,
        'max_new_tokens': 50,
        'batch_size': None,
        'save_every': 250,
    }

    # If custom_args is provided, update defaults
    if custom_args:
        default_args.update(custom_args)

    # Perform validation
    if default_args['num_retrieved_documents'] is None:
        raise ValueError("'num_retrieved_documents' must be specified.")
    if default_args['num_retrieved_documents'] <= 0:
        raise ValueError("'num_retrieved_documents' must be a positive integer.")
    if default_args['gold_position'] is not None:
        if (default_args['gold_position'] < 0 or 
            default_args['gold_position'] >= default_args['num_retrieved_documents']):
            raise ValueError("'gold_position' must be within the range of 'num_retrieved_documents'.")

    return DotDict(**default_args)


def load_corpus(
    args: argparse.Namespace
) -> Tuple[List[Dict], Optional[Dict[int, int]]]:
    
    # Corpus with documents from Contriever
    corpus, full_to_subset_idx_map = read_corpus_with_contriever()

    return corpus, full_to_subset_idx_map

def load_search_results(args: argparse.Namespace) -> List[Tuple[List[int], List[float]]]:

    search_results_path = info[args.dataset][args.split]['contriever_search_results_path']
    retriever_search_results = read_pickle(search_results_path)

    return retriever_search_results


def get_prompt_template(args: argparse.Namespace):
    prompt_configuration = args.dataset
    if args.use_model_chat_template:
        chat_task_template_str = chat_task_templates[args.llm_id]['template']
        
        task_instruction = task_instructions[prompt_configuration]

        prompt_template = apply_chat_task_template(chat_task_template_str, task_instruction)
    else:
        task_template = task_templates[prompt_configuration]


        prompt_template = task_template.create_prompt_template()

    return prompt_template


def initialize_dataset_and_loader(
    args: argparse.Namespace, 
    corpus: List[Dict], 
    full_to_subset_idx_map: Optional[Dict[int, int]], 
    retriever_search_results: List[Tuple[List[int], List[float]]], 
    tokenizer: PreTrainedTokenizer
) -> Tuple[PromptDataset, DataLoader]:
    
    prompt_template = get_prompt_template(args)
    
    prompt_ds = PromptDataset(
        corpus=corpus, data_path=info[args.dataset][args.split]['data_path'], 
        tokenizer=tokenizer, 
        max_tokenized_length=args.model_max_length - 2, 
        search_results=retriever_search_results,
        prompt_template=prompt_template,
        full_to_subset_idx_map=full_to_subset_idx_map,
        do_normalize_query=True, 
        num_documents_in_context=args.num_retrieved_documents,
        gold_position=args.gold_position, # None in these experiments
    )
        
    prompt_dataloader = DataLoader(
        prompt_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
    )
    return prompt_ds, prompt_dataloader


def print_info(args: argparse.Namespace):
    print("INFO:")    
    print(f"DATA: {info[args.dataset][args.split]['data_path']}")
    print(f"USE TEST: {args.use_test}")
    print(f"MODEL: {args.llm_id}")
    print(f"MODEL MAX LENGTH: {args.model_max_length}")
    print(f'MAX NEW TOKENS: {args.max_new_tokens}')
    print(f"USE MODEL CHAT TEMPLATE: {args.use_model_chat_template}")
    print(f"GOLD POSITION: {args.gold_position}")
    print(f"NUM DOCUMENTS IN CONTEXT: {args.num_retrieved_documents}")
    print(f"BATCH SIZE: {args.batch_size}")
    print(f"SAVE EVERY: {args.save_every}")


def extract_generate_answers(
    args: argparse.Namespace, 
    generated_output: List[str]
) -> List[str]:
    answer_prefix = "Answer:"
    if args.use_model_chat_template:
        answer_prefix = re.escape(chat_task_templates[args.llm_id]['answer_prefix'])

    generated_answers = []
    for output in generated_output:
        matches = list(re.finditer(answer_prefix, output))
        match_idx = 0

        # When using the proof there is a one-shot example that already 
        # contains the string "Answer:". Thus, we should get the second (match_idx=1) match.
        if args.use_model_chat_template and answer_prefix != "Answer:":
            match_idx = 0
 
        answer_end = matches[match_idx].end()
        response = output[answer_end:].strip()
        generated_answers.append(response)
    
    return generated_answers


def are_answers_matching(prediction: str, ground_truths: List[str]) -> float:
    normalized_prediction = normalize_answer(prediction)

    for ground_truth in ground_truths:
        normalized_ground_truth = normalize_answer(ground_truth)
        if normalized_ground_truth in normalized_prediction:
            return True
    return False


def calculate_bleu(generated_answer, reference_answers):
    """
    Calcola il BLEU score di una risposta generata rispetto a una lista di risposte corrette.

    Args:
        generated_answer (str): La risposta generata dal modello.
        reference_answers (list): Lista di risposte corrette (ogni elemento è una stringa).

    Returns:
        float: Il valore del BLEU score.
    """
    # Tokenizza la risposta generata
    generated_tokens = generated_answer.split()
    
    # Tokenizza le risposte di riferimento
    reference_tokens = [ref.split() for ref in reference_answers]
    
    # Calcolo del BLEU con un smoothing per gestire casi senza corrispondenze
    smoothing_function = SmoothingFunction().method1
    bleu_score = sentence_bleu(reference_tokens, generated_tokens, smoothing_function=smoothing_function)
    
    return bleu_score

def evaluate_accuracy(args: argparse.Namespace, llm, candidate_prompt, answers: List[str], max_new_tokens=50):
    """
    Valuta la generazioner del modello.
    Per semplicità, controlla se la risposta generata dal modello matcha con quella dell'esempio.
    """
    generated_answer = llm.generate(candidate_prompt, max_new_tokens=max_new_tokens)
    response_text = extract_generate_answers(args, generated_answer)
    response_text = response_text[0].split('\n', 1)[0]
    accuracy_answer = are_answers_matching(response_text, answers) if answers else False

    return accuracy_answer, response_text

### DA QUI INIZIALIZZARE L'ALGORITMO DI GENERAZIONE ###
def generate_and_save(
    args: argparse.Namespace, 
    llm: LLM,
    prompt_ds: PromptDataset,
    df: pd.DataFrame,
    prompt_dataloader: DataLoader
):
    # Info from arguments
    max_new_tokens=args.max_new_tokens
    llm_id = args.llm_id
    num_doc = args.num_retrieved_documents
    save_every = args.save_every
    retriever_str = "contriever" 
    chat_template_str = "_template" if args.use_model_chat_template else ""
    prompt_type = "retrieved"

    

    # Create the saving directory
    llm_folder = llm_id.split("/")[1] if '/' in llm_id else llm_id
    saving_dir = f"{args.output_dir}/{args.dataset}/{llm_folder}/{args.split}/{prompt_type}/{retriever_str}/{num_doc}_doc"
    os.makedirs(saving_dir, exist_ok=True)

    ### ALGORITMO DI GENERAZIONE ###
    max_examples = 3000 
    all_info = [] 
    for idx, prompt_batch in enumerate(tqdm(prompt_dataloader)):

        if idx >= max_examples:
            break

        example_id = prompt_batch['example_id']
        prompts = prompt_batch['prompt']
        query = prompt_batch['query']
        document_indices=prompt_batch['document_indices']

        #estraggo la lista di answers per quell'esempio
        answers = df[df['example_id'].astype(str) == str(example_id)].answers.iloc[0]

        # Informazioni da salvare
        query_info = {
            "query": query,
            "document_indices": document_indices,
            "selected_documents": [],
            "generated_responses": []
        }
        
        # Inizializza la sequenza ottimizzata
        d_silv = []
        d_silv_are_answer = False  # Punteggio iniziale neutrale
        d_silv_score = 0
        
        best_candidate = None
        best_candidate_are_answer = False  # Mantieni il punteggio corrente come riferimento
        best_candidate_score = 0
        
        while True: 
            # Valuta i documenti non ancora selezionati
            for doc_idx in document_indices:
                if doc_idx in d_silv:
                    continue  # Salta i documenti già selezionati

                # Crea un candidato aggiungendo il documento alla sequenza corrente
                candidate_docs = d_silv + [doc_idx]
                
                #print(f'Documenti in d_silv: {d_silv}')

                #print(f'Documenti candidati: {candidate_docs}')

                # Salta la funzione `_get_documents_from_indices` se non ci sono documenti
                if not candidate_docs:
                    formatted_docs = []

                else: 
                    formatted_docs, _ = prompt_ds._get_documents_from_indices(candidate_docs)

                # Crea il prompt candidato
                candidate_prompt = prompts
                # Aggiungi i documenti solo se non sono vuoti
                if formatted_docs:
                    candidate_prompt += "\n".join(formatted_docs)  #concatena il prompt di ogni esempio con ogni documento in input in quel momento

                if '\nAnswer:' not in candidate_prompt:
                    candidate_prompt += '\nAnswer:'

                #print(f'Prompt corrente: {candidate_prompt}')

                # Verifica se la risposta matcha con quella corretta dell'esempio in questione
                cur_are_answer, generated_answer = evaluate_accuracy(args, llm, candidate_prompt, answers, max_new_tokens)

                #print(f'Le risposte corrette sono{answers}')
                #print(f"Ce' la risposta in quella generata dal modello? : {cur_are_answer}")
                
                cur_score = calculate_bleu(generated_answer, answers)

                # Salva i dettagli della generazione
                query_info["generated_responses"].append({
                    "candidate_documents": candidate_docs,
                    "generated_answer": generated_answer,
                    "is_correct": cur_are_answer,
                    "bleu_score": cur_score
                })

                # Aggiorna il miglior candidato se il punteggio migliora
                if cur_are_answer == True: 
                    if best_candidate_are_answer == True:
                    
                        if cur_score > best_candidate_score:
                            best_candidate = doc_idx
                            best_candidate_are_answer = cur_are_answer
                            best_candidate_score = cur_score

                            #print(f"La risposta con questo documento {best_candidate} c'e'? {best_candidate_are_answer} e ha questa accuracy {best_candidate_score}")
                
                    else: 
                        best_candidate = doc_idx
                        best_candidate_are_answer = cur_are_answer
                        best_candidate_score = cur_score
                        #print(f"La risposta con questo documento {best_candidate} c'e'? {best_candidate_are_answer} e ha questa accuracy {best_candidate_score}\n")
                
                #else:
                    #print(f"La risposta con il documento {candidate_docs} non e'stata trovata\n")

            
            # Aggiungi il miglior candidato alla sequenza se migliora il punteggio
            if best_candidate is not None and best_candidate not in d_silv and best_candidate_are_answer == True:

                if d_silv_are_answer == False:
                    d_silv.append(best_candidate)
                    d_silv_are_answer = best_candidate_are_answer
                    d_silv_score = best_candidate_score

                    #print(f"\nSono nel secondo If e la risposta con questi documenti {d_silv} c'e'? {d_silv_are_answer} e ha questa accuracy {d_silv_score}\n")

                else:
                    if best_candidate_score > d_silv_score:
                        d_silv.append(best_candidate)
                        d_silv_score = best_candidate_score
                        d_silv_are_answer = best_candidate_are_answer
                    
                    #print(f"\nSono nel secondo If e la risposta con questi documenti {d_silv} c'e'? {d_silv_are_answer} e ha questa accuracy {d_silv_score}\n")
                
            else:
                #print(f"\nNon c'e' piu' nessun documento che migliora la risposta, la risposta migliore e' con questi documenti {d_silv} e ha questa accuracy {d_silv_score}\n")
                break  # Interrompi se nessun documento migliora il punteggio
        
        # Aggiorna i documenti selezionati finali
        query_info["selected_documents"] = d_silv

        # Aggiungi le informazioni del batch corrente a `all_info`
        all_info.append(query_info)
            

        if (idx + 1) % save_every == 0 or (idx + 1) == len(prompt_dataloader):
            print(f"Saving at {idx + 1}...")
            file_name = f"{saving_dir}/numdoc{num_doc}_retr{args.num_retrieved_documents}{chat_template_str}_info_{idx+1}.pkl"
            write_pickle(all_info, file_name)
            all_info = [] 
            


def main():
    args = parse_arguments()

    args.split = "test" if args.use_test else "train"

    print("Loading LLM...")
    llm_id = args.llm_id
    llm = LLM(
        llm_id, device,  
        model_max_length=args.model_max_length
    )
    tokenizer = llm.tokenizer
    print("LLM loaded")


    print("Loading corpus and search results...")
    corpus, full_to_subset_idx_map = load_corpus(args)
    retriever_search_results = load_search_results(args)
    print("Corpus and search results loaded")


    print("Loading prompt dataset...")
    prompt_ds, prompt_dataloader = initialize_dataset_and_loader(
        args, corpus, full_to_subset_idx_map, 
        retriever_search_results, tokenizer
    )
    print("Prompt dataset loaded")

    print_info(args)

    df = pd.read_json(info[args.dataset][args.split]['data_path'], dtype={'example_id': str})
    
    #output_json_path = r'C:\Users\franc\Documents\Bridge_the_GAP\data\dataloader_contents.json'
    #save_dataloader_to_json(prompt_dataloader, output_json_path, num_examples=15)
        
    generate_and_save(args, llm, prompt_ds, df, prompt_dataloader)



if __name__ == "__main__":
    seed_everything(SEED)
    main()

Loading LLM...


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.25s/it]
Some parameters are on the meta device because they were offloaded to the cpu.


LLM loaded
Loading corpus and search results...
Corpus and search results loaded
Loading prompt dataset...
Prompt dataset loaded
INFO:
DATA: C:\Users\franc\Documents\Bridge_the_GAP\data\10k_train_dataset.json
USE TEST: False
MODEL: google/gemma-2-2b-it
MODEL MAX LENGTH: 4096
MAX NEW TOKENS: 50
USE MODEL CHAT TEMPLATE: False
GOLD POSITION: None
NUM DOCUMENTS IN CONTEXT: 5
BATCH SIZE: None
SAVE EVERY: 250


  2%|▎         | 250/10000 [2:24:05<119:49:07, 44.24s/it]

Saving at 250...


  5%|▌         | 500/10000 [4:53:04<112:53:22, 42.78s/it]

Saving at 500...


  8%|▊         | 750/10000 [7:28:14<88:42:27, 34.52s/it] 

Saving at 750...


 10%|█         | 1000/10000 [10:19:30<74:04:02, 29.63s/it]

Saving at 1000...


 12%|█▎        | 1250/10000 [12:35:54<55:45:32, 22.94s/it] 

Saving at 1250...


 15%|█▌        | 1500/10000 [14:48:06<64:28:00, 27.30s/it] 

Saving at 1500...


 18%|█▊        | 1750/10000 [17:00:13<62:41:52, 27.36s/it] 

Saving at 1750...


 20%|██        | 2000/10000 [19:10:28<81:52:50, 36.85s/it] 

Saving at 2000...


 22%|██▎       | 2250/10000 [21:16:22<87:18:32, 40.56s/it] 

Saving at 2250...


 25%|██▌       | 2500/10000 [23:28:06<70:47:55, 33.98s/it] 

Saving at 2500...


 28%|██▊       | 2750/10000 [25:40:02<67:37:13, 33.58s/it] 

Saving at 2750...


 30%|███       | 3000/10000 [27:55:07<58:47:26, 30.24s/it] 

Saving at 3000...


 30%|███       | 3000/10000 [27:55:11<65:08:47, 33.50s/it]


In [17]:
import os
import pickle
import re
import json

def extract_number(filename):
    # Estrai il numero usando una regex
    match = re.search(r'_(\d+)\.pkl$', filename)
    return int(match.group(1)) if match else float('inf')

def process_pkl_files(input_folder):
    combined_data = []
    filtered_data = []
    count_example = 0
    correct_count = 0
    multiple_selected_count = 0

    output_file = os.path.join(input_folder, 'numdoc5_retr5_template_info_all.json')
    filtered_output_file = os.path.join(input_folder, 'numdoc5_retr5_template_info_all_extended.json')

    # Elenco dei file .pkl ordinati per nome
    pkl_files = sorted([filename for filename in os.listdir(input_folder) if filename.endswith('.pkl')], key=extract_number)

    print("Trovati i seguenti file .pkl:")
    for filename in pkl_files:
        print(filename)

    # Leggi tutti i file .pkl nella cartella
    for filename in pkl_files:
        file_path = os.path.join(input_folder, filename)
        with open(file_path, 'rb') as f:
            data = pickle.load(f)
            combined_data.extend(data)

    for example in combined_data:
        # Controlla se almeno una risposta ha 'is_correct' = True
        is_correct = any(response.get('is_correct') for response in example.get('generated_responses', []))
        if is_correct:
            correct_count += 1

        # Controlla se ci sono più di un indice in 'selected_documents'
        selected_documents = example.get('selected_documents', [])
        if len(selected_documents) > 1:
            multiple_selected_count += 1

        # Trova la risposta corretta con il punteggio BLEU più alto
        best_answer = None
        best_bleu_score = 0
        for response in example.get('generated_responses', []):
            if response.get('is_correct') and response.get('bleu_score', 0) > best_bleu_score:
                best_bleu_score = response['bleu_score']
                best_answer = response.get('generated_answer', None)

        filtered_data.append({
            'query': example.get('query'),
            'are_answer': is_correct,
            'generated_answer': best_answer,
            'selected_documents': selected_documents,
            'number_documents_selected': len(selected_documents)
        })

        count_example += 1

    # Salva i dati combinati in un file JSON
    with open(output_file, 'w') as f:
        json.dump(combined_data, f, indent=4)

    # Salva i dati filtrati in un file JSON
    with open(filtered_output_file, 'w') as f:
        json.dump(filtered_data, f, indent=4)

    return correct_count, multiple_selected_count, count_example

# Utilizzo
input_folder = r'C:\Users\franc\Documents\Bridge_the_GAP\data\gen_id_document_bgm\nq_training\gemma-2-2b-it\train\retrieved\contriever\5_doc'
correct, multiple_selected, n_example = process_pkl_files(input_folder)

print(f"Numeri di esempi totali: {n_example}")
print(f"Esempi con almeno una risposta corretta: {correct}")
print(f"Esempi con più di un indice in 'selected_documents': {multiple_selected}")


Trovati i seguenti file .pkl:
numdoc5_retr5_info_250.pkl
numdoc5_retr5_info_500.pkl
numdoc5_retr5_info_750.pkl
numdoc5_retr5_info_1000.pkl
numdoc5_retr5_info_1250.pkl
numdoc5_retr5_info_1500.pkl
numdoc5_retr5_info_1750.pkl
numdoc5_retr5_info_2000.pkl
numdoc5_retr5_info_2250.pkl
numdoc5_retr5_info_2500.pkl
numdoc5_retr5_info_2750.pkl
numdoc5_retr5_info_3000.pkl
Numeri di esempi totali: 3000
Esempi con almeno una risposta corretta: 1062
Esempi con più di un indice in 'selected_documents': 153
