# RAG

Build a simple Retrieval-Augmented Generation (RAG) pipeline to demonstrate how it works.

Steps:
1. Document store: use an in-memory key-value store.
2. Retrieval: use embeddings from GPT-2 (note: the cells below currently compute embeddings with BERT).
3. Generation: use GPT-2 to generate a response.

In [None]:
# Compute and store BERT embeddings for an SNLI split (the call at the end of this cell runs the train split)
import torch
from torch.utils.data.dataloader import DataLoader
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import sys
sys.path.append("/home/varun/projects/experiments-with-gpt2/")

import json
from transformers import BertTokenizer,BertModel
from transformers import DataCollatorWithPadding
from bert import BERT
from bert_config import BERTConfig
from rag.snliDataset import snliDataset, snliEmbeddings
from tqdm import tqdm

device = "cuda"

# Project-local BERT encoder and the Hugging Face uncased BERT tokenizer.
model = BERT.from_pretrained(config=BERTConfig())
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Move the encoder to the GPU and switch it to evaluation mode
# (presumably an nn.Module, so eval() turns dropout off -- confirm).
model.to(device)
model.eval()

def dynamic_padding(data, device="cuda"):
    """Collate SNLI examples into one dynamically padded batch.

    Intended as a DataLoader ``collate_fn``. Each batch is padded only to the
    length of its longest (sentence1, sentence2) pair, and pairs are
    truncated to at most 512 tokens.

    Args:
        data: list of examples, each a mapping with "sentence1", "sentence2"
            and "label" keys.
        device: device the encoded tensors are moved to. When used as a
            DataLoader ``collate_fn`` only ``data`` is passed, so the default
            applies.

    Returns:
        ``(encoded, labels)``: a dict of tensors produced by the module-level
        ``tokenizer`` (with ``attention_mask`` cast to bool), and the labels
        as a list in input order.
    """
    first_sentences = [example["sentence1"] for example in data]
    second_sentences = [example["sentence2"] for example in data]
    labels = [example["label"] for example in data]

    batch = tokenizer(
        first_sentences,
        second_sentences,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )
    # Boolean mask, presumably the form the project BERT model expects -- confirm.
    batch["attention_mask"] = batch["attention_mask"].bool()

    on_device = {}
    for name, tensor in batch.items():
        on_device[name] = tensor.to(device)
    return on_device, labels

def prepare_data(split: str, output_filename: str):
    """Embed one SNLI split with BERT and save the per-example results.

    The output is written with ``torch.save`` as ``{"data": [item, ...]}``,
    with one ``item`` per example::

        {
            "input_ids": 1-D tensor of token ids for the (sentence1, sentence2)
                         pair with trailing padding removed,
            "embedding": the model output row for this example (on CPU),
            "label": the example's label as produced by snliDataset
                     (presumably an int in {0, 1, 2} -- confirm),
        }

    Args:
        split: split name passed to ``snliDataset`` (e.g. "train", "dev", "test").
        output_filename: destination path for ``torch.save``.
    """
    sd = snliDataset(split)
    batch_size = 64
    dl = DataLoader(sd, batch_size=batch_size, collate_fn=dynamic_padding)

    data = []
    with torch.no_grad():
        for encoded, labels in tqdm(dl):
            embedding = model(**encoded).cpu()
            input_ids = encoded["input_ids"].cpu()
            # Number of real (non-padding) tokens per row. Slicing by these
            # lengths assumes right padding (the BertTokenizer default).
            # Converting to a Python list once avoids a device sync per example.
            seq_lens = encoded["attention_mask"].sum(dim=1).tolist()
            # One entry per example: iterate over the batch (len(labels)),
            # not len(encoded). encoded is a dict, so its length is the number
            # of tokenizer fields, which dropped most of every batch.
            for i, label in enumerate(labels):
                data.append(
                    {
                        "input_ids": input_ids[i, : seq_lens[i]],
                        "embedding": embedding[i],
                        "label": label,
                    }
                )

    torch.save({"data": data}, output_filename)

    print(f"Wrote {split} data to {output_filename}")

# Only the train split is generated here. The dev/test files come from the
# same call with the other split names:
# prepare_data(split="dev", output_filename="dev_data.pt")
# prepare_data(split="test", output_filename="test_data.pt")
prepare_data(split="train", output_filename="train_data.pt")

In [None]:
# Build an MLP classifier
from torch import nn
import torch.nn.functional as F
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Dropout(0.1) -> Linear.

    Args:
        input_size: number of input features.
        hidden_size: width of the hidden layer.
        output_size: number of output logits.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Attribute names are kept stable so saved state_dicts remain loadable.
        self.hidden_layer = nn.Linear(input_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Map ``x`` with last dimension ``input_size`` to unnormalized logits."""
        hidden = F.relu(self.hidden_layer(x))
        hidden = self.dropout(hidden)
        return self.output_layer(hidden)


In [None]:
# Training setup: load the pre-computed embeddings for the train split.
import torch
import sys
sys.path.append("/home/varun/projects/experiments-with-gpt2/")

# Import via the same package path as the other cells (rag.snliDataset).
# Importing the bare module name as well would load the file a second time
# as a separate module whose classes are not identical to rag.snliDataset's.
from rag.snliDataset import snliDataset, snliEmbeddings
from torch.utils.data.dataloader import DataLoader

split = "train"
se_train = snliEmbeddings(split=split)


In [None]:
# Train the MLP classifier on the pre-computed sentence-pair embeddings.
device = "cuda"
mlp = MLP(768, 256, 3)
mlp.to(device)
criterion = nn.CrossEntropyLoss()
learning_rate = 1e-3
optimizer = torch.optim.Adam(mlp.parameters(), lr=learning_rate)
# Halve the learning rate every 100 scheduler steps; scheduler.step() is
# called once per epoch below, so that is every 100 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
n_epochs = 500
train_loader = DataLoader(se_train, batch_size=64, shuffle=True)

for epoch in range(n_epochs):
    mlp.train()
    running_loss = 0.0

    for batch in train_loader:
        optimizer.zero_grad()
        outputs = mlp(batch["embedding"].to(device))
        # After default collation batch["label"] is presumably already a tensor
        # (the evaluation cell calls .size() on it). torch.tensor(tensor) would
        # copy it and emit a UserWarning; as_tensor converts lists too and
        # moves the labels to the device without the extra copy.
        targets = torch.as_tensor(batch["label"], device=device)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    scheduler.step()
    print(f"Epoch [{epoch+1}/{n_epochs}], Loss: {running_loss/len(train_loader):.4f}")



In [None]:
# Evaluation: accuracy of the trained MLP on the test-split embeddings.
mlp.eval()
correct, total = 0, 0
se_test = snliEmbeddings(split="test")
test_loader = DataLoader(se_test, batch_size=32)

with torch.no_grad():
    for batch in test_loader:
        outputs = mlp(batch["embedding"].to(device))
        # Index of the highest logit per row (first one on ties, like torch.max).
        predicted = outputs.argmax(dim=1)
        labels_on_device = batch["label"].to(device)
        total += batch["label"].shape[0]
        correct += (predicted == labels_on_device).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")

In [None]:
# Scratch check: shape of [a, b, |a - b|] concatenated along the feature axis.
a = torch.rand((1, 8))
b = torch.rand((1, 8))
c = torch.cat([a, b, (a - b).abs()], dim=-1)
print(c.shape)

Sentence-BERT Training

In [1]:
import sys
sys.path.append("/home/varun/projects/experiments-with-gpt2/")

from rag.train import Trainer
from rag.snliDataset import sentenceBERTDataset
from bert_config import BERTConfig, BERTTrainConfig

# Build the Sentence-BERT training setup: train/dev datasets plus the model
# and training configs, all passed to the project Trainer.
train_set = sentenceBERTDataset("train")
val_set = sentenceBERTDataset("dev")
# Default configs; the hyperparameters are defined in bert_config.
train_config = BERTTrainConfig()
model_config = BERTConfig()

# NOTE(review): positional order is (train, val, model_config, train_config) --
# confirm it matches Trainer's signature in rag/train.py.
trainer = Trainer(train_set, val_set,model_config, train_config)

  from .autonotebook import tqdm as notebook_tqdm


Loading extracted data from /home/varun/Downloads/snli_1.0/sentenceBERT_dev.pt


  self.data = torch.load(cache_fpath)["data"]


In [2]:
# Run Sentence-BERT training; Trainer reports step losses and saves checkpoints.
trainer.train()

Loading pre-trained weights for BERT


  0%|          | 3/40000 [00:03<10:41:34,  1.04it/s]

Step: 0
Train Loss: 1.1745139360427856
Validation Loss:1.1882182359695435


  5%|▍         | 1999/40000 [01:49<32:22, 19.57it/s]

Step: 2000
Train Loss: 1.1046299934387207
Validation Loss:1.1010029315948486
Saving checkpoint to out/bert_ckpt_train.pt


 10%|▉         | 3998/40000 [03:39<31:20, 19.15it/s]  

Step: 4000
Train Loss: 1.0989958047866821
Validation Loss:1.0960195064544678
Saving checkpoint to out/bert_ckpt_train.pt


 12%|█▏        | 4904/40000 [04:35<32:50, 17.81it/s]  


KeyboardInterrupt: 

In [None]:
import torch
# Scratch check of torch.cat on 1-D tensors. A 1-D tensor only has dim 0, so
# dim=1 raises IndexError; dim=-1 concatenates along the last axis and also
# works unchanged for batched (2-D) inputs.
a = torch.tensor([1, 2])
b = torch.tensor([3, 4])
c = torch.cat([a, b, a - b], dim=-1)
c