In [None]:
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
import torch
from datasets import load_dataset
from itertools import islice
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration, BartTokenizerFast


In [None]:
# Stream TriviaQA ("unfiltered" config) so the full train split is not downloaded up front,
# then materialise only the first N_TRAIN_EXAMPLES rows as an in-memory list of examples.
N_TRAIN_EXAMPLES = 1000  # size of the training subset; raise for a larger (slower) run

trivia_qa_stream = load_dataset("trivia_qa", "unfiltered", split="train", streaming=True)
dataset = list(islice(trivia_qa_stream, N_TRAIN_EXAMPLES))

In [None]:
# Pretrained RAG-Token checkpoint (fine-tuned on Natural Questions) plus its passage retriever.
# NOTE: use_dummy_dataset=True loads a small *dummy* variant of the passage dataset/index,
# meant for quick experiments and testing. It is NOT the full Wikipedia index, so retrieval
# quality in this run will be far below the real thing. Set it to False for a real run
# (the full index is a very large download).
RAG_CHECKPOINT = "facebook/rag-token-nq"

retriever = RagRetriever.from_pretrained(
    RAG_CHECKPOINT,
    index_name="exact",      # exact (uncompressed) index variant, as opposed to "compressed"
    use_dummy_dataset=True,  # dummy index only -- see the note above
)

# Attaching the retriever lets the model retrieve passages itself inside forward().
model = RagTokenForGeneration.from_pretrained(RAG_CHECKPOINT, retriever=retriever)


In [None]:
# Tokenizers bundled with the RAG checkpoint:
#   * calling rag_tokenizer(...) uses its question-encoder tokenizer -> question input_ids
#   * rag_tokenizer.generator is the generator-side tokenizer        -> answer labels
# Taking the generator tokenizer from the same checkpoint guarantees its vocabulary matches
# the model's generator, rather than relying on a separately downloaded "facebook/bart-large".
rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
generator_tokenizer = rag_tokenizer.generator


In [None]:
# Prepare dataset
def process_example(example, max_question_length=512, max_answer_length=128):
    """Tokenize one QA example into RAG question inputs and generator labels.

    Args:
        example: A dataset row with a ``"question"`` string and an ``"answer"``
            field. A dict answer (as in TriviaQA) contributes its ``"value"``
            string; a plain string answer is used as-is.
        max_question_length: Truncation length for question tokens.
        max_answer_length: Length answer tokens are truncated / padded to.

    Returns:
        dict with ``input_ids`` and ``attention_mask`` (question tokens, unpadded
        lists of ints; the collate function pads each batch) and ``labels``
        (answer token ids padded to ``max_answer_length``).
    """
    # Questions are short. Padding every one to 512 tokens made the question encoder
    # (and the retriever's question decoding) process mostly padding. Leave padding to
    # the collate function, which pads each batch only to its longest question.
    question_enc = rag_tokenizer(
        example["question"],
        truncation=True,
        max_length=max_question_length,
        padding=False,
    )

    answer = example["answer"]
    if isinstance(answer, dict) and "value" in answer:
        answer_text = answer["value"]
    elif isinstance(answer, str):
        answer_text = answer
    else:
        answer_text = "No answer provided"

    # Pad labels with the generator tokenizer's own pad token, not -100.
    # RagTokenForGeneration feeds labels to the decoder and masks the loss by comparing
    # against the generator's pad_token_id, so the pad id is what the model expects here.
    target_enc = generator_tokenizer(
        answer_text,
        truncation=True,
        padding="max_length",
        max_length=max_answer_length,
    )

    return {
        "input_ids": question_enc["input_ids"],
        "attention_mask": question_enc["attention_mask"],
        "labels": target_enc["input_ids"],
    }
    



In [None]:
def custom_collate_fn(batch, input_pad_token_id=0, label_pad_token_id=1):
    """Collate tokenized examples into right-padded ``torch.long`` tensors.

    Each field is padded to the longest sequence of that field in the batch.

    Args:
        batch: list of dicts (as produced by ``process_example``) holding
            list-of-int ``input_ids``, ``attention_mask`` and ``labels``.
        input_pad_token_id: pad id for ``input_ids`` (0 is ``[PAD]`` in the
            BERT-style vocabulary of the DPR question-encoder tokenizer).
        label_pad_token_id: pad id for ``labels``. This must be the generator's
            pad token (1 for BART). RagTokenForGeneration uses the labels as
            decoder inputs and masks the loss by ``pad_token_id``, so ``-100``
            padding would index out of range.

    Returns:
        dict of ``input_ids``, ``attention_mask`` and ``labels`` tensors, each of
        shape ``(len(batch), longest_sequence_for_that_field)``.
    """
    def _pad_field(key, pad_value):
        # One long tensor per example, right-padded to the batch maximum.
        return torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(item[key], dtype=torch.long) for item in batch],
            batch_first=True,
            padding_value=pad_value,
        )

    return {
        "input_ids": _pad_field("input_ids", input_pad_token_id),
        "attention_mask": _pad_field("attention_mask", 0),  # padded positions are not attended
        "labels": _pad_field("labels", label_pad_token_id),
    }



In [None]:
# Hyperparameters
SEED = 42
BATCH_SIZE = 4
EPOCHS = 1
LEARNING_RATE = 5e-5
accumulation_steps = 8  # optimizer steps every 8 batches (effective batch = BATCH_SIZE * accumulation_steps)
# DataLoader worker processes must import `custom_collate_fn`. That fails for functions
# defined in a notebook under the "spawn" start method (the default on macOS/Windows).
# The data is already tokenized in memory, so loading in the main process is cheap.
NUM_WORKERS = 0

torch.manual_seed(SEED)

# Move the model to its final device BEFORE constructing the optimizer, as the PyTorch
# optimizer docs recommend.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Tokenize the subset once and build a shuffled loader with a seeded shuffle order.
processed_dataset = [process_example(example) for example in dataset]
train_dataloader = DataLoader(
    processed_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=custom_collate_fn,
    num_workers=NUM_WORKERS,
    pin_memory=device.type == "cuda",  # pinned host memory only helps host->GPU copies
    generator=torch.Generator().manual_seed(SEED),
)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)


In [None]:
# Training loop.
# The model was built with `retriever=retriever`, so RagModel.forward retrieves passages itself
# unless context_input_ids, context_attention_mask AND doc_scores are all supplied. The previous
# hand-rolled retrieval never passed doc_scores, so its contexts were discarded and retrieval
# ran twice. Passing only the question keeps the same training signal without the duplicate
# work. The model's own retrieval also computes doc_scores from the question-encoder output
# with gradients enabled, so the question encoder is fine-tuned too.
n_docs = 5  # passages retrieved per question
OUTPUT_DIR = "rag_model_finetuned"

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0  # sum of unscaled per-batch losses, for reporting
    optimizer.zero_grad()
    num_batches = len(train_dataloader)

    for step, batch in enumerate(tqdm(train_dataloader, desc=f"epoch {epoch + 1}")):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            n_docs=n_docs,
        )

        # RagTokenForGeneration returns one loss value per example (unless reduce_loss=True);
        # average over the batch.
        loss = outputs.loss.mean()
        total_loss += loss.item()

        # Scale so the gradients summed over an accumulation window average over it.
        (loss / accumulation_steps).backward()

        # Step at the end of every window AND after the last batch, so a trailing partial
        # window's gradients are applied instead of silently dropped.
        if (step + 1) % accumulation_steps == 0 or step + 1 == num_batches:
            optimizer.step()
            optimizer.zero_grad()

    print(f"Epoch {epoch+1} - Loss: {total_loss / max(num_batches, 1)}")

# Save the fine-tuned model together with its tokenizer so the checkpoint can be reloaded.
model.save_pretrained(OUTPUT_DIR)
rag_tokenizer.save_pretrained(OUTPUT_DIR)
