In [None]:
import re
import string
from collections import Counter
from itertools import islice

import numpy as np
import torch
import torch
from bert_score import score
from datasets import load_dataset
from rank_bm25 import BM25Okapi
from torch.nn import functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration, BartTokenizerFast


In [None]:
# TriviaQA (unfiltered) training split; rows provide a "question" and an "answer" field.
dataset = load_dataset("trivia_qa", "unfiltered", split="train")

# Dense (DPR) retriever. `index_name="exact"` selects an exact (flat, untrained) FAISS index;
# `use_dummy_dataset=True` loads only a small dummy variant of the DPR Wikipedia passages,
# not the full index -- fine for debugging, too small for meaningful training.
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq",
    index_name="exact",
    use_dummy_dataset=True,
)

model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
generator_tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-large")

# Passage table behind the dense index. BM25 below is built over the same rows, so a
# BM25 hit at position i and a dense hit with doc id i refer to the same passage
# (assumes the retriever's doc ids are row indices of this table -- verify on upgrades).
dense_corpus = retriever.index.dataset

# Sparse (BM25) retriever over the same passages. Only the "text" column is read, so the
# embedding vectors are not materialized. Tokens are lower-cased because queries decoded
# from the (presumably uncased) DPR question tokenizer are lower-case; queries must be
# lower-cased and whitespace-split the same way before calling `bm25.get_scores`.
sparse_corpus = [text.lower().split() for text in dense_corpus["text"]]
bm25 = BM25Okapi(sparse_corpus)

In [None]:
# Prepare dataset
def process_example(example, question_max_length=512, answer_max_length=128):
    """Tokenize one TriviaQA row into RAG training features.

    Args:
        example: dataset row with a ``question`` string and an ``answer`` field
            (a dict carrying a ``value`` string in TriviaQA).
        question_max_length: truncation length for the question tokens.
        answer_max_length: length the answer tokens are truncated / padded to.

    Returns:
        dict with ``input_ids`` and ``attention_mask`` (question-encoder tokens,
        left unpadded) and ``labels`` (generator tokens padded to
        ``answer_max_length``).
    """
    # Questions are not padded here: the collate function pads each batch to its own
    # longest question, so the question encoder no longer runs on ~512 pad tokens.
    input_encodings = rag_tokenizer(example['question'], truncation=True, max_length=question_max_length)

    answer = example['answer']
    if isinstance(answer, dict) and 'value' in answer:
        answer_text = answer['value']
    else:
        # NOTE(review): this placeholder becomes a real training target; consider
        # filtering such rows out instead.
        answer_text = "No answer provided"

    # Labels keep fixed-length padding with the generator tokenizer's own pad token,
    # which is what RAG's loss masks on.
    target_encodings = generator_tokenizer(answer_text, truncation=True, padding="max_length", max_length=answer_max_length)

    return {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        'labels': target_encodings['input_ids']
    }


In [None]:
def custom_collate_fn(batch, input_pad_value=0, label_pad_value=1):
    """Right-pad a list of feature dicts into batched LongTensors.

    Args:
        batch: list of dicts with integer sequences under 'input_ids',
            'attention_mask' and 'labels'.
        input_pad_value: pad id for question tokens (0 is the [PAD] id of
            BERT-style DPR tokenizers -- presumably the one in use).
        label_pad_value: pad id for labels. RagTokenForGeneration masks label
            positions equal to the generator's pad_token_id and also feeds labels
            in as decoder inputs, so the usual -100 ignore index is not valid here.
            1 is BART's pad id; pass ``generator_tokenizer.pad_token_id`` to be explicit.

    Returns:
        dict of tensors, each of shape (len(batch), longest sequence for that key).
    """
    def pad(key, value):
        sequences = [torch.as_tensor(item[key], dtype=torch.long) for item in batch]
        return torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=value)

    return {
        'input_ids': pad('input_ids', input_pad_value),
        'attention_mask': pad('attention_mask', 0),  # padded positions are never attended
        'labels': pad('labels', label_pad_value),
    }

In [None]:
def normalize_text(s):
    """Canonicalize an answer string before EM / F1 comparison.

    Applied in order: lower-case, drop ASCII punctuation, blank out the
    articles "a" / "an" / "the" (whole words only), collapse whitespace.
    """
    lowered = s.lower()
    without_punct = ''.join(c for c in lowered if c not in string.punctuation)
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct)
    return ' '.join(without_articles.split())

def compute_exact_match(prediction, ground_truth):
    """Return 1 if both strings are equal after `normalize_text`, otherwise 0."""
    matches = normalize_text(prediction) == normalize_text(ground_truth)
    return 1 if matches else 0

def compute_f1(prediction, ground_truth):
    """SQuAD-style token F1 between a prediction and a reference answer.

    Both strings are normalized with `normalize_text` and whitespace-tokenized.
    Overlap is a multiset intersection, so a repeated token counts as many times
    as it appears in both strings (e.g. "new new york" vs "new york new" -> 1.0).

    Returns:
        float in [0, 1]; 0.0 when there is no overlap (including when either
        side normalizes to an empty string).
    """
    pred_tokens = normalize_text(prediction).split()
    gt_tokens = normalize_text(ground_truth).split()
    num_same = sum((Counter(pred_tokens) & Counter(gt_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gt_tokens)
    return 2 * (precision * recall) / (precision + recall)

def compute_reward(predictions, references, retrieved_contexts, device, alpha=0.2, beta=0.4, gamma=0.3, delta=0.1):
    """
    Compute a per-sample reward as a weighted sum of:
    alpha * Exact Match + beta * F1 Score + gamma * BERTScore + delta * Consistency Score

    BERTScore is the F1 of each prediction against its reference. Consistency is
    the BERTScore F1 of each prediction against its retrieved passages, using
    bert_score's multi-reference mode (the best-matching passage counts).

    Args:
        predictions: list of generated answer strings.
        references: list of gold answer strings aligned with ``predictions``.
        retrieved_contexts: one entry per prediction: a list of passage strings
            (a single string is treated as one passage). Samples without any
            passage get a consistency score of 0.
        device: device of the returned tensor; also passed to bert_score.
        alpha, beta, gamma, delta: weights of the four terms.

    Returns:
        1-D float tensor of shape (len(predictions),) on ``device``.
    """
    if len(predictions) == 0:
        return torch.zeros(0, dtype=torch.float, device=device)

    pairs = list(zip(predictions, references))
    em_tensor = torch.tensor([compute_exact_match(p, r) for p, r in pairs], dtype=torch.float, device=device)
    f1_tensor = torch.tensor([compute_f1(p, r) for p, r in pairs], dtype=torch.float, device=device)

    # bert_score hands its scores back as CPU tensors (in current releases); move
    # them so the weighted sum below never mixes CPU and CUDA tensors.
    _, _, bert_f1 = score(predictions, references, lang='en', verbose=False, device=device)
    bert_f1 = bert_f1.to(device=device, dtype=torch.float)

    # Consistency: one batched multi-reference bert_score call instead of one call
    # (and one model load) per sample. Samples with no passages keep a score of 0.
    consistency_tensor = torch.zeros(len(predictions), dtype=torch.float, device=device)
    scored_idx, scored_preds, scored_refs = [], [], []
    for idx, (pred, contexts) in enumerate(zip(predictions, retrieved_contexts)):
        if isinstance(contexts, str):
            ref_group = [contexts] if contexts else []
        else:
            ref_group = list(contexts)
        if ref_group:
            scored_idx.append(idx)
            scored_preds.append(pred)
            scored_refs.append(ref_group)
    if scored_idx:
        _, _, consistency_f1 = score(scored_preds, scored_refs, lang='en', verbose=False, device=device)
        consistency_tensor[scored_idx] = consistency_f1.to(device=device, dtype=torch.float)

    reward = alpha * em_tensor + beta * f1_tensor + gamma * bert_f1 + delta * consistency_tensor
    return reward


In [None]:
# ======== Hyperparameters ========
SEED = 42               # fixes torch RNG (DataLoader shuffling, dropout)
BATCH_SIZE = 4
EPOCHS = 10
LEARNING_RATE = 5e-5
S = 0.7                 # share of each question's context slots given to sparse (BM25) passages
D = 0.3                 # share of each question's context slots given to dense (DPR) passages
accumulation_steps = 8  # Accumulate gradients over 8 batches per optimizer step
rl_every_n_batches = 4  # Compute RL loss every 4 batches
Lambda = 0.9            # Weight for supervised loss vs RL loss

torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load dataset and process it
processed_dataset = [process_example(example) for example in dataset]
# NOTE(review): with num_workers > 0, worker processes must be able to pickle
# custom_collate_fn; on spawn-based platforms (macOS/Windows) functions defined in a
# notebook may fail to pickle -- use num_workers=0 there.
train_dataloader = DataLoader(
    processed_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=custom_collate_fn,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),  # pinning only helps host->GPU copies
)

# Move the model before building the optimizer (PyTorch's recommended order), so the
# optimizer is constructed over the parameters in their final placement.
model.to(device)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)


In [None]:
# ======== Hybrid (dense + BM25) retrieval helpers ========
N_DOCS = 5  # passages per question fed to the generator
GENERATOR_PAD_ID = model.config.generator.pad_token_id  # pad id RAG's loss masks on


@torch.no_grad()
def hybrid_retrieve(input_ids, attention_mask, n_docs=N_DOCS):
    """Build RAG generator inputs from a per-question mix of DPR and BM25 passages.

    For every question, the top ``round(n_docs * D / (S + D))`` DPR passages are kept
    (2 of 5 with the default S/D), and the remaining slots are filled with the best
    BM25 passages not already chosen. The function runs under ``no_grad``, so
    ``doc_scores`` is a constant and the question encoder receives no gradient.

    Args:
        input_ids, attention_mask: question-encoder tokens of shape (batch, q_len), on ``device``.
        n_docs: passages per question.

    Returns:
        context_input_ids, context_attention_mask: (batch * n_docs, ctx_len) tensors on
            ``device``, built by the retriever's own "title / text // question" formatting.
        doc_scores: (batch, n_docs) question-passage inner products that RAG uses to
            marginalize over passages.
        contexts: per question, the list of chosen passage texts (for the reward).
    """
    total_weight = S + D
    if total_weight <= 0:
        raise ValueError("S + D must be positive")
    n_dense = min(n_docs, int(round(n_docs * D / total_weight)))

    question_hidden = model.question_encoder(input_ids=input_ids, attention_mask=attention_mask)[0]
    _, dense_ids, _ = retriever.retrieve(question_hidden.cpu().numpy().astype(np.float32), n_docs=n_docs)

    # Decode with the question-encoder tokenizer: RagTokenizer.decode / batch_decode
    # delegate to the BART generator tokenizer, which cannot decode DPR question ids.
    questions = rag_tokenizer.question_encoder.batch_decode(input_ids, skip_special_tokens=True)

    docs, doc_embeds = [], []
    for question, dense_row in zip(questions, dense_ids):
        # FAISS pads with -1 when it has fewer hits than requested.
        chosen = [int(doc_id) for doc_id in dense_row[:n_dense] if doc_id >= 0]
        bm25_scores = bm25.get_scores(question.lower().split())
        for doc_id in np.argsort(-bm25_scores):
            if len(chosen) == n_docs:
                break
            if int(doc_id) not in chosen:
                chosen.append(int(doc_id))
        if len(chosen) < n_docs:
            raise ValueError(f"corpus has fewer than n_docs={n_docs} passages")
        rows = dense_corpus[chosen]  # BM25 positions and dense ids index the same rows
        docs.append({"title": list(rows["title"]), "text": list(rows["text"])})
        doc_embeds.append(np.asarray(rows["embeddings"], dtype=np.float32))

    context_input_ids, context_attention_mask = retriever.postprocess_docs(
        docs, questions, getattr(model.config.generator, "prefix", None), n_docs, return_tensors="pt"
    )
    doc_embeds_t = torch.from_numpy(np.stack(doc_embeds)).to(device=question_hidden.device, dtype=question_hidden.dtype)
    doc_scores = torch.bmm(question_hidden.unsqueeze(1), doc_embeds_t.transpose(1, 2)).squeeze(1)
    contexts = [doc["text"] for doc in docs]
    return context_input_ids.to(device), context_attention_mask.to(device), doc_scores, contexts


def sequence_log_probs(context_input_ids, context_attention_mask, doc_scores, generated_ids, n_docs=N_DOCS):
    """Length-normalized log-probability (with grad) of each generated sequence.

    Scores ``generated_ids[:, 1:]`` with teacher forcing on ``generated_ids[:, :-1]``
    over the same passages used to generate them; pad positions are excluded.
    Returns a (batch,) tensor.
    """
    decoder_input_ids = generated_ids[:, :-1].contiguous()
    target_ids = generated_ids[:, 1:].contiguous()

    rl_outputs = model(
        context_input_ids=context_input_ids,
        context_attention_mask=context_attention_mask,
        doc_scores=doc_scores,
        decoder_input_ids=decoder_input_ids,
        n_docs=n_docs,
        use_cache=False,
    )
    # log_softmax keeps this correct whether or not `.logits` is already normalized.
    log_probs = F.log_softmax(rl_outputs.logits, dim=-1)  # [batch, tgt_len, vocab]
    token_log_probs = log_probs.gather(2, target_ids.unsqueeze(2)).squeeze(2)

    mask = target_ids.ne(GENERATOR_PAD_ID).to(token_log_probs.dtype)
    # clamp avoids 0/0 when a generated sequence is nothing but padding.
    return (token_log_probs * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)


# ======== Training ========
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    num_batches = len(train_dataloader)
    optimizer.zero_grad()

    for step, batch in enumerate(tqdm(train_dataloader)):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        # RAG masks label padding by the generator pad id and feeds labels to the
        # decoder, so any -100 ignore-index padding must become the pad id.
        labels = labels.masked_fill(labels.eq(-100), GENERATOR_PAD_ID)

        # Shared by the supervised and RL passes. Because doc_scores is passed in,
        # the model skips its own internal retrieval and uses these passages.
        context_input_ids, context_attention_mask, doc_scores, contexts = hybrid_retrieve(input_ids, attention_mask)

        # ======== Supervised loss ========
        outputs = model(
            context_input_ids=context_input_ids,
            context_attention_mask=context_attention_mask,
            doc_scores=doc_scores,
            labels=labels,
            n_docs=N_DOCS,
        )
        supervised_loss = outputs.loss.mean()

        # ======== RL loss ========
        if step % rl_every_n_batches == 0:
            model.eval()  # no dropout during beam search
            generated_ids = model.generate(
                context_input_ids=context_input_ids,
                context_attention_mask=context_attention_mask,
                doc_scores=doc_scores,
                n_docs=N_DOCS,
                num_beams=4,
                num_return_sequences=1,
                max_length=16,
            )
            model.train()

            generated_texts = generator_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            reference_texts = generator_tokenizer.batch_decode(labels, skip_special_tokens=True)

            rewards = compute_reward(
                generated_texts,
                reference_texts,
                retrieved_contexts=contexts,
                device=device)

            seq_log_probs = sequence_log_probs(context_input_ids, context_attention_mask, doc_scores, generated_ids)
            # NOTE(review): REINFORCE-style objective without a baseline, applied to beam
            # search output rather than samples; consider sampling and subtracting
            # rewards.mean() to reduce variance.
            rl_loss = (-rewards.detach() * seq_log_probs).mean()
        else:
            rl_loss = 0.0

        # ======== Combined loss ========
        combined_loss = Lambda * supervised_loss + (1 - Lambda) * rl_loss
        (combined_loss / accumulation_steps).backward()

        # Step every `accumulation_steps` batches and flush any remainder at epoch end.
        if (step + 1) % accumulation_steps == 0 or (step + 1) == num_batches:
            optimizer.step()
            optimizer.zero_grad()

        # Track the unscaled loss so the printed value is a true per-batch mean.
        total_loss += combined_loss.item()

    print(f"Epoch {epoch+1} - Combined Loss: {total_loss / max(num_batches, 1)}")

# Save the model
model.save_pretrained("rag_model_rl_hr_finetuned")
