In [None]:
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration, BartTokenizerFast
from datasets import load_dataset
from itertools import islice
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from bert_score import score as bert_score_fn
from tqdm import tqdm
import torch

In [None]:
# Open the TriviaQA "unfiltered" validation split as a stream so rows are
# pulled lazily while iterating.
dataset = load_dataset(
    "trivia_qa",
    "unfiltered",
    split="validation",
    streaming=True,
)

# Materialise just the first 100 streamed examples as the evaluation subset.
eval_dataset = list(islice(dataset, 100))

In [None]:
# Retriever configured from the base "facebook/rag-token-nq" checkpoint, using
# the "exact" index variant.
# NOTE(review): per the RagConfig docs, use_dummy_dataset=True selects a "dummy"
# variant of the retrieval dataset rather than the full one -- confirm that this
# is what the evaluation is meant to retrieve from.
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
)

# RAG-Token model restored from the "rag_model_finetuned" checkpoint and wired
# to the retriever above.
model = RagTokenForGeneration.from_pretrained(
    "rag_model_finetuned",
    retriever=retriever,
)

# Tokenizers: the RAG tokenizer encodes the questions; the BART tokenizer is
# used to decode the generated ids back into text.
rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
generator_tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-large")


In [None]:
# =====================================
# Evaluation on 100 TriviaQA datapoints
# =====================================

# Fetch the NLTK 'punkt' tokenizer data in case it is not installed yet.
nltk.download('punkt')

# Per-example metric values; the evaluation loop appends one entry per question.
bleu_scores, rouge1_scores, rouge2_scores, rougeL_scores, em_scores = [], [], [], [], []

# Raw reference / hypothesis strings, kept so that BERTScore can be computed in
# one batch once the loop has finished.
all_references, all_hypotheses = [], []

# ROUGE-1 / ROUGE-2 / ROUGE-L scorer (with stemming) and the smoothing function
# handed to sentence-level BLEU.
rouge_evaluator = rouge_scorer.RougeScorer(
    ['rouge1', 'rouge2', 'rougeL'],
    use_stemmer=True,
)
smooth_fn = SmoothingFunction().method1

# Put the model into evaluation mode before generating.
model.eval()

print("\n--- Evaluation on 100 TriviaQA datapoints ---\n")

# Take the target device from the model's own parameters instead of relying on
# a `device` variable defined in some other cell: the inputs then always land
# on the same device as the weights, and the cell works on a fresh kernel.
model_device = next(model.parameters()).device

for example in tqdm(eval_dataset, desc="Evaluating"):
    question = example['question']
    ground_truth = example['answer']
    # Reduce the ground truth to one reference string: the 'value' entry when
    # the answer is a dict that has one, otherwise its string form.
    # NOTE(review): only this single string is scored; if the answer dict also
    # carries alias lists they are ignored here -- confirm that is intended.
    if isinstance(ground_truth, dict) and 'value' in ground_truth:
        ref = ground_truth['value']
    else:
        ref = str(ground_truth)

    # Encode the question by calling the RAG tokenizer directly (the
    # `prepare_seq2seq_batch` helper is deprecated); padding/truncation are
    # spelled out to keep the same encoding settings.
    input_dict = rag_tokenizer(
        [question],
        return_tensors="pt",
        padding="longest",
        truncation=True,
    ).to(model_device)

    # Beam-search generation without gradient tracking; one sequence is kept.
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=input_dict['input_ids'],
            attention_mask=input_dict['attention_mask'],
            max_length=64,
            num_beams=4,
            num_return_sequences=1
        )
    hyp = generator_tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Save references and hypotheses for BERTScore later
    all_references.append(ref)
    all_hypotheses.append(hyp)

    # Sentence-level BLEU on whitespace tokens, smoothed with `smooth_fn`
    ref_tokens = ref.split()
    hyp_tokens = hyp.split()
    bleu = sentence_bleu([ref_tokens], hyp_tokens, smoothing_function=smooth_fn)
    bleu_scores.append(bleu)

    # ROUGE-1 / ROUGE-2 / ROUGE-L: keep the F-measure of each
    rouge_scores = rouge_evaluator.score(ref, hyp)
    rouge1_scores.append(rouge_scores['rouge1'].fmeasure)
    rouge2_scores.append(rouge_scores['rouge2'].fmeasure)
    rougeL_scores.append(rouge_scores['rougeL'].fmeasure)

    # Exact Match: 1 when reference and hypothesis are equal after lower-casing
    # and stripping surrounding whitespace, else 0
    em = 1 if ref.lower().strip() == hyp.lower().strip() else 0
    em_scores.append(em)

# Arithmetic mean of every per-example metric list, in reporting order.
avg_bleu, avg_rouge1, avg_rouge2, avg_rougeL, avg_em = (
    sum(values) / len(values)
    for values in (bleu_scores, rouge1_scores, rouge2_scores, rougeL_scores, em_scores)
)

# BERTScore over the whole hypothesis/reference set in one batched call; each
# returned tensor (precision, recall, F1) is then reduced to its mean.
P, R, F1 = bert_score_fn(all_hypotheses, all_references, lang="en", verbose=True)
avg_bert_p, avg_bert_r, avg_bert_f1 = (
    tensor.mean().item() for tensor in (P, R, F1)
)

# Final report: one metric per line, EM shown as a percentage.
print(
    "\n--- Evaluation Metrics ---",
    f"BLEU: {avg_bleu:.4f}",
    f"ROUGE-1: {avg_rouge1:.4f}",
    f"ROUGE-2: {avg_rouge2:.4f}",
    f"ROUGE-L: {avg_rougeL:.4f}",
    f"Exact Match (EM): {avg_em*100:.2f}%",
    f"BERT F1: {avg_bert_f1:.4f}",
    f"BERT Precision: {avg_bert_p:.4f}",
    f"BERT Recall: {avg_bert_r:.4f}",
    sep="\n",
)