In [1]:
from datasets import load_dataset
dataset = load_dataset("trivia_qa", "unfiltered", split="train", streaming=True)  # subset
from itertools import islice
dataset = list(islice(dataset, 0, 3000))

  from .autonotebook import tqdm as notebook_tqdm


In [1]:
from datasets import load_dataset
dataset = load_dataset("trivia_qa", "unfiltered", split="train", streaming=True)  # subset
from itertools import islice
dataset = list(islice(dataset, 0, 1000))

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration

retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq",
    index_name="exact",  # FAISS index trained on DPR-wiki passages
    use_dummy_dataset=True  # loads built-in Wikipedia index
)

model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

from transformers import BartTokenizerFast
rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
generator_tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-large")


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizerFast'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'BartTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called fr

In [3]:
# Prepare dataset
def process_example(example):
    # Tokenize input (question)
    input_encodings = rag_tokenizer(example['question'], truncation=True, padding="max_length", max_length=512)

    # Use just the "value" field of the answer if it exists
    if isinstance(example['answer'], dict) and 'value' in example['answer']:
        answer_text = example['answer']['value']
    else:
        answer_text = "No answer provided"

    # Tokenize answer (target)
    target_encodings = generator_tokenizer(answer_text, truncation=True, padding="max_length", max_length=128)

    # Return tokenized input and target
    return {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        'labels': target_encodings['input_ids']
    }


In [4]:
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
import torch
from torch.utils.data.dataloader import default_collate


def custom_collate_fn(batch):
    # Collate the batch by padding the sequences to the max length in the batch
    input_ids = [item['input_ids'] for item in batch]
    attention_mask = [item['attention_mask'] for item in batch]
    labels = [item['labels'] for item in batch]

    # Pad sequences to the max length in each batch (or use a fixed size)
    input_ids_padded = torch.nn.utils.rnn.pad_sequence([torch.tensor(seq) for seq in input_ids], batch_first=True, padding_value=0)
    attention_mask_padded = torch.nn.utils.rnn.pad_sequence([torch.tensor(seq) for seq in attention_mask], batch_first=True, padding_value=0)
    labels_padded = torch.nn.utils.rnn.pad_sequence([torch.tensor(seq) for seq in labels], batch_first=True, padding_value=-100)

    return {
        'input_ids': input_ids_padded,
        'attention_mask': attention_mask_padded,
        'labels': labels_padded
    }

processed_dataset = [process_example(example) for example in dataset]

# When initializing your DataLoader
BATCH_SIZE = 1
train_dataloader = DataLoader(processed_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=custom_collate_fn, num_workers=0, pin_memory=True)


In [5]:
from sklearn.metrics import f1_score
import numpy as np
from bert_score import score
import torch
import re
import string

def normalize_text(s):
    """Lowercase and remove punctuation, articles, and extra whitespace."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)
    def white_space_fix(text):
        return ' '.join(text.split())
    def remove_punc(text):
        return ''.join(ch for ch in text if ch not in string.punctuation)
    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

def compute_exact_match(prediction, ground_truth):
    return int(normalize_text(prediction) == normalize_text(ground_truth))

def compute_f1(prediction, ground_truth):
    pred_tokens = normalize_text(prediction).split()
    gt_tokens = normalize_text(ground_truth).split()
    common = set(pred_tokens) & set(gt_tokens)
    if len(common) == 0:
        return 0.0
    precision = len(common) / len(pred_tokens)
    recall = len(common) / len(gt_tokens)
    return 2 * (precision * recall) / (precision + recall)

def compute_reward(predictions, references, device="cpu", alpha=0.3, beta=0.3, gamma=0.4):
    """
    Compute reward as weighted sum of:
    alpha * Exact Match + beta * F1 Score + gamma * BERTScore
    """
    # Convert predictions & references to normalized strings for EM/F1
    em_scores = []
    f1_scores = []
    for pred, ref in zip(predictions, references):
        em = compute_exact_match(pred, ref)
        f1 = compute_f1(pred, ref)
        em_scores.append(em)
        f1_scores.append(f1)

    em_tensor = torch.tensor(em_scores, dtype=torch.float, device=device)
    f1_tensor = torch.tensor(f1_scores, dtype=torch.float, device=device)

    # Compute BERTScore F1 (returns a tensor)
    _, _, bert_f1 = score(predictions, references, lang='en', verbose=False, device=device)

    # Combine all three
    reward = alpha * em_tensor + beta * f1_tensor + gamma * bert_f1
    return reward  # shape: [batch_size]


Matplotlib is building the font cache; this may take a moment.


In [None]:
from torch.nn import functional as F

# Optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# Training loop
device = torch.device("cpu")
model.to(device)

rl_every_n_batches = 4  # Compute RL loss every 4 batches

alpha = 0.9  # Weight for supervised loss vs RL loss
accumulation_steps = 8  # Accumulate gradients over 8 steps

for epoch in range(1):  # Training for 1 epochs
    model.train()
    total_loss = 0
    optimizer.zero_grad()   # new change

    # Loop over batches
    for i, batch in enumerate(tqdm(train_dataloader)):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # ======== Supervised Loss (no gradients into question encoder) ====
        with torch.no_grad():
            # Compute embeddings for the questions
            output = model.question_encoder(input_ids=input_ids, attention_mask=attention_mask)
            last_hidden_state = output[0]  # Since it's a tuple with only one element
            question_hidden_states = last_hidden_state.cpu().numpy()  # ✅ Works for 2D tensors

            # 🔁 Multihop change: Retrieve top-k documents and concatenate them into a single context
            n_docs = 5  # Number of documents per hop
            n_hops = 2  # Number of hops (i.e., chained sets of documents)
            all_contexts = []

            # Retrieve contexts for multiple hops
            question_hidden_states_hop = question_hidden_states
            for hop in range(n_hops):
                _, _, doc_dicts = retriever.retrieve(question_hidden_states_hop, n_docs=n_docs)
                hop_contexts = [doc["text"] for doc in doc_dicts]
                hop_contexts = [item for sublist in hop_contexts for item in sublist]
                all_contexts.append(hop_contexts)

                # 🔁 Multihop change: For the next hop, re-encode the concatenated docs from this hop as the new query
                flat_hop_contexts = [" ".join(docs) for docs in hop_contexts]
                hop_encodings = rag_tokenizer(flat_hop_contexts, return_tensors="pt", padding=True, truncation=True, max_length=512)
                with torch.no_grad():
                    hop_output = model.question_encoder(input_ids=hop_encodings['input_ids'].to(device),
                                                        attention_mask=hop_encodings['attention_mask'].to(device))
                    question_hidden_states_hop = hop_output[0].cpu().numpy()

            # 🔁 Multihop change: Concatenate all retrieved documents from both hops
            combined_contexts = [" ".join(docs) for docs in all_contexts]

            context_encodings = generator_tokenizer.batch_encode_plus(
                combined_contexts,
                truncation=True,
                padding="max_length",
                max_length=512,
                return_tensors="pt"
            )

            context_input_ids = context_encodings['input_ids'].to(device).unsqueeze(1)  # shape: [B, 1, seq_len]
            context_attention_mask = context_encodings['attention_mask'].to(device).unsqueeze(1)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            context_input_ids=context_input_ids,
            context_attention_mask=context_attention_mask
        )

        supervised_loss = outputs.loss.mean()
        
        # ======== RL Loss ========
        if i % rl_every_n_batches == 0:
            output = model.question_encoder(input_ids=input_ids, attention_mask=attention_mask)
            last_hidden_state = output[0]  # Since it's a tuple with only one element
            question_hidden_states = last_hidden_state.detach().cpu().numpy()  # ✅ Works for 2D tensors

            # 🔁 Multihop change: Retrieve top-k documents and concatenate them into a single context
            n_docs = 5  # Number of documents per hop
            n_hops = 2  # Number of hops (i.e., chained sets of documents)
            all_contexts = []

            # Retrieve contexts for multiple hops
            question_hidden_states_hop = question_hidden_states
            for hop in range(n_hops):
                _, _, doc_dicts = retriever.retrieve(question_hidden_states_hop, n_docs=n_docs)
                hop_contexts = [doc["text"] for doc in doc_dicts]
                hop_contexts = [item for sublist in hop_contexts for item in sublist]
                all_contexts.append(hop_contexts)

                # 🔁 Multihop change: For the next hop, re-encode the concatenated docs from this hop as the new query
                flat_hop_contexts = [" ".join(docs) for docs in hop_contexts]
                hop_encodings = rag_tokenizer(flat_hop_contexts, return_tensors="pt", padding=True, truncation=True, max_length=512)
                hop_output = model.question_encoder(input_ids=hop_encodings['input_ids'].to(device),
                                                    attention_mask=hop_encodings['attention_mask'].to(device))
                question_hidden_states_hop = hop_output[0].detach().cpu().numpy()

            # 🔁 Multihop change: Concatenate all retrieved documents from both hops
            combined_contexts = [" ".join(docs) for docs in all_contexts]

            context_encodings = generator_tokenizer.batch_encode_plus(
                combined_contexts,
                truncation=True,
                padding="max_length",
                max_length=512,
                return_tensors="pt"
            )

            context_input_ids = context_encodings['input_ids'].to(device).unsqueeze(1)  # shape: [B, 1, seq_len]
            context_attention_mask = context_encodings['attention_mask'].to(device).unsqueeze(1)

            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                num_return_sequences=1,
                num_beams=1,
                n_docs=4,  # ← So that Batch Size is divisible by n_docs
                max_length=16
            )

            generated_texts = generator_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            reference_texts = generator_tokenizer.batch_decode(labels, skip_special_tokens=True)

            rewards = compute_reward(generated_texts, reference_texts)  # → Tensor of shape [batch_size]
            
            # Compute log-probs manually
            decoder_input_ids = generated_ids[:, :-1].contiguous()
            labels_for_logprob = generated_ids[:, 1:].contiguous()

            model_outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                context_input_ids=context_input_ids,
                context_attention_mask=context_attention_mask,
                use_cache=False
            )

            logits = model_outputs.logits  # [batch_size, seq_len, vocab_size]
            log_probs = F.log_softmax(logits, dim=-1)
            token_log_probs = torch.gather(log_probs, dim=2, index=labels_for_logprob.unsqueeze(2)).squeeze(2)

            # Mask padding
            label_mask = (labels_for_logprob != model.config.pad_token_id).float()
            sequence_log_probs = (token_log_probs * label_mask).sum(dim=1) / label_mask.sum(dim=1)

            rl_loss = (-rewards * sequence_log_probs).mean()

        else:
            rl_loss = 0.0
            
        # ======== Combined Loss ========
        combined_loss = alpha * supervised_loss + (1 - alpha) * rl_loss
        combined_loss = combined_loss / accumulation_steps
        combined_loss.backward()

        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        total_loss += combined_loss.item()
   
    print(f"Epoch {epoch+1} - Combined Loss: {total_loss / len(train_dataloader)}")

# Save the model
model.save_pretrained("multihop_rl_finetuned")        
 

  0%|          | 0/1000 [00:00<?, ?it/s]Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 4/1000 [01:44<5:38:24, 20.39s/it] Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  1%|          | 8/1000 [02:56<4:44:29, 17.21s/it]Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  1%|          | 12/1000 [04:04

Epoch 1 - Combined Loss: 2.819846179574728


In [7]:
k = list(dataset)
print(len(k))

1000


In [8]:
import random


import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from bert_score import score as bert_score_fn
from tqdm import tqdm

# (If you haven't already downloaded NLTK punkt tokenizer)
nltk.download('punkt')

# Select 100 datapoints for evaluation
eval_dataset = random.sample(k, 100)

# Initialize metric accumulators
bleu_scores = []
rouge1_scores = []
rouge2_scores = []
rougeL_scores = []
em_scores = []

# To compute BERTScore in batch later:
all_references = []
all_hypotheses = []

# Initialize ROUGE scorer and BLEU smoothing function
rouge_evaluator = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
smooth_fn = SmoothingFunction().method1

# Set the model to evaluation mode
model.eval()

print("\n--- Evaluation on 100 TriviaQA datapoints ---\n")
for example in tqdm(eval_dataset, desc="Evaluating"):
    question = example['question']
    ground_truth = example['answer']
    # Process ground truth to get a clean string answer
    if isinstance(ground_truth, dict) and 'value' in ground_truth:
        ref = ground_truth['value']
    else:
        ref = str(ground_truth)
    
    # Tokenize question using the RAG tokenizer
    input_dict = rag_tokenizer.prepare_seq2seq_batch([question], return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=input_dict['input_ids'],
            attention_mask=input_dict['attention_mask'],
            max_length=64,
            num_beams=4,
            num_return_sequences=1
        )
    hyp = generator_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    
    # Save references and hypotheses for BERTScore later
    all_references.append(ref)
    all_hypotheses.append(hyp)
    
    # Compute BLEU score (using whitespace tokenization here)
    ref_tokens = ref.split()
    hyp_tokens = hyp.split()
    bleu = sentence_bleu([ref_tokens], hyp_tokens, smoothing_function=smooth_fn)
    bleu_scores.append(bleu)
    
    # Compute ROUGE scores
    rouge_scores = rouge_evaluator.score(ref, hyp)
    rouge1_scores.append(rouge_scores['rouge1'].fmeasure)
    rouge2_scores.append(rouge_scores['rouge2'].fmeasure)
    rougeL_scores.append(rouge_scores['rougeL'].fmeasure)
    
    # Compute Exact Match (EM) metric (case-insensitive exact match)
    em = 1 if ref.lower().strip() == hyp.lower().strip() else 0
    em_scores.append(em)

# Compute average metrics
avg_bleu = sum(bleu_scores) / len(bleu_scores)
avg_rouge1 = sum(rouge1_scores) / len(rouge1_scores)
avg_rouge2 = sum(rouge2_scores) / len(rouge2_scores)
avg_rougeL = sum(rougeL_scores) / len(rougeL_scores)
avg_em = sum(em_scores) / len(em_scores)

# Compute BERTScore (F1) over all examples
P, R, F1 = bert_score_fn(all_hypotheses, all_references, lang="en", verbose=True)
avg_bert_f1 = F1.mean().item()
avg_bert_p = P.mean().item()
avg_bert_r = R.mean().item()

print("\n--- Evaluation Metrics ---")
print(f"BLEU: {avg_bleu:.4f}")
print(f"ROUGE-1: {avg_rouge1:.4f}")
print(f"ROUGE-2: {avg_rouge2:.4f}")
print(f"ROUGE-L: {avg_rougeL:.4f}")
print(f"Exact Match (EM): {avg_em*100:.2f}%")
print(f"BERT F1: {avg_bert_f1:.4f}")
print(f"BERT Precision: {avg_bert_p:.4f}")
print(f"BERT Recall: {avg_bert_r:.4f}")

[nltk_data] Downloading package punkt to /home/anishkav/nltk_data...
[nltk_data]   Package punkt is already up-to-date!



--- Evaluation on 100 TriviaQA datapoints ---



Evaluating:  19%|█▉        | 19/100 [00:38<02:50,  2.11s/it]

Evaluating: 100%|██████████| 100/100 [03:38<00:00,  2.18s/it]
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


calculating scores...
computing bert embedding.


100%|██████████| 3/3 [00:02<00:00,  1.30it/s]


computing greedy matching.


100%|██████████| 2/2 [00:00<00:00,  4.14it/s]

done in 2.80 seconds, 35.75 sentences/sec

--- Evaluation Metrics ---
BLEU: 0.0466
ROUGE-1: 0.1570
ROUGE-2: 0.0690
ROUGE-L: 0.1570
Exact Match (EM): 9.00%
BERT F1: 0.8642
BERT Precision: 0.8645
BERT Recall: 0.8670



