In [1]:
from datasets import load_dataset
dataset = load_dataset("trivia_qa", "unfiltered", split="train[:1000]")  # subset

from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration

retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq",
    index_name="exact",  # FAISS index trained on DPR-wiki passages
    use_dummy_dataset=True  # loads built-in Wikipedia index
)

model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

from transformers import BartTokenizerFast
rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
generator_tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-large")

from transformers import BartTokenizerFast
rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
generator_tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-large")

# Prepare dataset
def process_example(example):
    # Tokenize input (question)
    input_encodings = rag_tokenizer(example['question'], truncation=True, padding="max_length", max_length=512)

    # Use just the "value" field of the answer if it exists
    if isinstance(example['answer'], dict) and 'value' in example['answer']:
        answer_text = example['answer']['value']
    else:
        answer_text = "No answer provided"

    # Tokenize answer (target)
    target_encodings = generator_tokenizer(answer_text, truncation=True, padding="max_length", max_length=128)

    # Return tokenized input and target
    return {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        'labels': target_encodings['input_ids']
    }



Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/47 [00:00<?, ?it/s]

2025-04-09 04:07:02.732390: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-09 04:07:02.990818: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744151823.123274 1049844 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744151823.156547 1049844 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-09 04:07:03.452545: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
import torch
from torch.utils.data.dataloader import default_collate


def custom_collate_fn(batch):
    # Collate the batch by padding the sequences to the max length in the batch
    input_ids = [item['input_ids'] for item in batch]
    attention_mask = [item['attention_mask'] for item in batch]
    labels = [item['labels'] for item in batch]

    # Pad sequences to the max length in each batch (or use a fixed size)
    input_ids_padded = torch.nn.utils.rnn.pad_sequence([torch.tensor(seq) for seq in input_ids], batch_first=True, padding_value=0)
    attention_mask_padded = torch.nn.utils.rnn.pad_sequence([torch.tensor(seq) for seq in attention_mask], batch_first=True, padding_value=0)
    labels_padded = torch.nn.utils.rnn.pad_sequence([torch.tensor(seq) for seq in labels], batch_first=True, padding_value=-100)

    return {
        'input_ids': input_ids_padded,
        'attention_mask': attention_mask_padded,
        'labels': labels_padded
    }

processed_dataset = dataset.map(process_example, remove_columns=["question", "answer"])



In [3]:
# When initializing your DataLoader
BATCH_SIZE = 4
train_dataloader = DataLoader(processed_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=custom_collate_fn, num_workers=4, pin_memory=True)

# Optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# Training loop
device = torch.device("cpu")
model.to(device)

accumulation_steps = 8  # Accumulate gradients over 8 steps

for epoch in range(1):  # Training for 1 epochs
    model.train()
    total_loss = 0
    optimizer.zero_grad()   # new change

    # Loop over batches
    for i, batch in enumerate(tqdm(train_dataloader)):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Compute embeddings for the questions
        with torch.no_grad():
            output = model.question_encoder(input_ids=input_ids, attention_mask=attention_mask)
            last_hidden_state = output[0]  # Since it's a tuple with only one element
            question_hidden_states = last_hidden_state.cpu().numpy()  # ✅ Works for 2D tensors

        # 🔁 Multihop change: Retrieve top-k documents and concatenate them into a single context
        n_docs = 5  # Number of documents per hop
        n_hops = 2  # Number of hops (i.e., chained sets of documents)
        all_contexts = []

        # Retrieve contexts for multiple hops
        question_hidden_states_hop = question_hidden_states
        for hop in range(n_hops):
            _, _, doc_dicts = retriever.retrieve(question_hidden_states_hop, n_docs=n_docs)
            hop_contexts = [doc["text"] for doc in doc_dicts]
            hop_contexts = [item for sublist in hop_contexts for item in sublist]
            all_contexts.append(hop_contexts)

            # 🔁 Multihop change: For the next hop, re-encode the concatenated docs from this hop as the new query
            flat_hop_contexts = [" ".join(docs) for docs in hop_contexts]
            hop_encodings = rag_tokenizer(flat_hop_contexts, return_tensors="pt", padding=True, truncation=True, max_length=512)
            with torch.no_grad():
                hop_output = model.question_encoder(input_ids=hop_encodings['input_ids'].to(device),
                                                    attention_mask=hop_encodings['attention_mask'].to(device))
                question_hidden_states_hop = hop_output[0].cpu().numpy()

        # 🔁 Multihop change: Concatenate all retrieved documents from both hops
        combined_contexts = [" ".join(docs) for docs in all_contexts]

        context_encodings = generator_tokenizer.batch_encode_plus(
            combined_contexts,
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt"
        )

        context_input_ids = context_encodings['input_ids'].to(device).unsqueeze(1)  # shape: [B, 1, seq_len]
        context_attention_mask = context_encodings['attention_mask'].to(device).unsqueeze(1)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            context_input_ids=context_input_ids,
            context_attention_mask=context_attention_mask
        )

        loss = outputs.loss.mean()

        loss = loss / accumulation_steps  # New change
        
        # Backward pass
        loss.backward()

        # New change
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} - Loss: {total_loss / len(train_dataloader)}")

# === Save Model ===
model.save_pretrained("rag_model_finetuned_multihop")


100%|██████████| 250/250 [5:17:08<00:00, 76.11s/it]  


Epoch 1 - Loss: 3.308797043323517


In [4]:
# =====================================
# Evaluation on 100 TriviaQA datapoints
# =====================================
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from bert_score import score as bert_score_fn
from tqdm import tqdm

# (If you haven't already downloaded NLTK punkt tokenizer)
nltk.download('punkt')

# Select 100 datapoints for evaluation
eval_dataset = dataset.shuffle(seed=42).select(range(100))

# Initialize metric accumulators
bleu_scores = []
rouge1_scores = []
rouge2_scores = []
rougeL_scores = []
em_scores = []

# To compute BERTScore in batch later:
all_references = []
all_hypotheses = []

# Initialize ROUGE scorer and BLEU smoothing function
rouge_evaluator = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
smooth_fn = SmoothingFunction().method1

# Set the model to evaluation mode
model.eval()

print("\n--- Evaluation on 100 TriviaQA datapoints ---\n")
for example in tqdm(eval_dataset, desc="Evaluating"):
    question = example['question']
    ground_truth = example['answer']
    # Process ground truth to get a clean string answer
    if isinstance(ground_truth, dict) and 'value' in ground_truth:
        ref = ground_truth['value']
    else:
        ref = str(ground_truth)
    
    # Tokenize question using the RAG tokenizer
    input_dict = rag_tokenizer.prepare_seq2seq_batch([question], return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=input_dict['input_ids'],
            attention_mask=input_dict['attention_mask'],
            max_length=64,
            num_beams=4,
            num_return_sequences=1
        )
    hyp = generator_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    
    # Save references and hypotheses for BERTScore later
    all_references.append(ref)
    all_hypotheses.append(hyp)
    
    # Compute BLEU score (using whitespace tokenization here)
    ref_tokens = ref.split()
    hyp_tokens = hyp.split()
    bleu = sentence_bleu([ref_tokens], hyp_tokens, smoothing_function=smooth_fn)
    bleu_scores.append(bleu)
    
    # Compute ROUGE scores
    rouge_scores = rouge_evaluator.score(ref, hyp)
    rouge1_scores.append(rouge_scores['rouge1'].fmeasure)
    rouge2_scores.append(rouge_scores['rouge2'].fmeasure)
    rougeL_scores.append(rouge_scores['rougeL'].fmeasure)
    
    # Compute Exact Match (EM) metric (case-insensitive exact match)
    em = 1 if ref.lower().strip() == hyp.lower().strip() else 0
    em_scores.append(em)

# Compute average metrics
avg_bleu = sum(bleu_scores) / len(bleu_scores)
avg_rouge1 = sum(rouge1_scores) / len(rouge1_scores)
avg_rouge2 = sum(rouge2_scores) / len(rouge2_scores)
avg_rougeL = sum(rougeL_scores) / len(rougeL_scores)
avg_em = sum(em_scores) / len(em_scores)

# Compute BERTScore (F1) over all examples
P, R, F1 = bert_score_fn(all_hypotheses, all_references, lang="en", verbose=True)
avg_bert_f1 = F1.mean().item()
avg_bert_p = P.mean().item()
avg_bert_r = R.mean().item()

print("\n--- Evaluation Metrics ---")
print(f"BLEU: {avg_bleu:.4f}")
print(f"ROUGE-1: {avg_rouge1:.4f}")
print(f"ROUGE-2: {avg_rouge2:.4f}")
print(f"ROUGE-L: {avg_rougeL:.4f}")
print(f"Exact Match (EM): {avg_em*100:.2f}%")
print(f"BERT F1: {avg_bert_f1:.4f}")
print(f"BERT Precision: {avg_bert_p:.4f}")
print(f"BERT Recall: {avg_bert_r:.4f}")

[nltk_data] Downloading package punkt to /home/anishkav/nltk_data...
[nltk_data]   Package punkt is already up-to-date!



--- Evaluation on 100 TriviaQA datapoints ---



Evaluating: 100%|██████████| 100/100 [20:24<00:00, 12.24s/it]
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


calculating scores...
computing bert embedding.


  0%|          | 0/3 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/2 [00:00<?, ?it/s]

done in 1.26 seconds, 79.53 sentences/sec

--- Evaluation Metrics ---
BLEU: 0.0223
ROUGE-1: 0.0930
ROUGE-2: 0.0306
ROUGE-L: 0.0930
Exact Match (EM): 7.00%
BERT F1: 0.8559
BERT Precision: 0.8632
BERT Recall: 0.8526


In [None]:
# ==============================
# Evaluation: Generate Answers
# ==============================

# Load 5 random samples from the TriviaQA dataset
sample_questions = dataset.shuffle(seed=42).select(range(5))

print("\n--- Evaluation on 5 Sample TriviaQA Questions ---\n")

model.eval()
model.to(device)

for example in sample_questions:
    question = example['question']
    print(f"❓ Question: {question}")

    # Tokenize the question
    input_dict = rag_tokenizer.prepare_seq2seq_batch([question], return_tensors="pt").to(device)

    # Generate the answer
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=input_dict['input_ids'],
            attention_mask=input_dict['attention_mask'],
            max_length=64,
            num_return_sequences=1,
            num_beams=4
        )

    # Decode and print the output
    answer = generator_tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Extract clean ground truth answer
    gt_answer = example.get('answer', {})
    if isinstance(gt_answer, dict):
        gt_text = gt_answer.get('value', str(gt_answer))
    elif isinstance(gt_answer, list):
        gt_text = gt_answer[0] if gt_answer else "No answer"
    else:
        gt_text = str(gt_answer)

    print(f"💡 Generated Answer: {answer}")
    print(f"✅ Ground Truth Answer: {gt_text}\n")



--- Evaluation on 5 Sample TriviaQA Questions ---

❓ Question: Who is Julian Lennon's step-mother?
💡 Generated Answer: Lennon's mother-in-law, and his mother-in-law-in-law-in-law-in-law-in-law
✅ Ground Truth Answer: Yoko Ono

❓ Question: What was the name of Bob Fosse's character in All That Jazz?
💡 Generated Answer: Ludwig Göransson
✅ Ground Truth Answer: Joe Gideon

❓ Question: In which 1998 film did Bruce Willis lead a team to confront a deadly threat from outer space?
💡 Generated Answer: Mauvais vietnam
✅ Ground Truth Answer: Armageddon

❓ Question: Which comedy contained the song A Wink and a Smile?
💡 Generated Answer: The Hangover
✅ Ground Truth Answer: Sleepless in Seattle

❓ Question: "What sports activity was originally known in England as ""plank-gliding""?"
💡 Generated Answer: Guitar
✅ Ground Truth Answer: Waterskiing. The first recorded mention of the sport in England was in 1914



: 