# RAG

Build a simple Retrieval-Augmented Generation (RAG) pipeline to demonstrate how it works.

Steps:
1. Document Store: Use an in-memory key-value store.
2. Retrieval: Use embeddings from GPT-2.
3. Generation: Use GPT-2 to generate a response.

In [20]:
# Compute and store the BERT embeddings for the dataset splits (dev / test / train)
import torch
from torch.utils.data.dataloader import DataLoader
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import sys
sys.path.append("/home/varun/projects/experiments-with-gpt2/")

import json
from transformers import BertTokenizer,BertModel
from transformers import DataCollatorWithPadding
from bert import BERT
from bert_config import BERTConfig
from rag.snliDataset import snliDataset, snliEmbeddings
from tqdm import tqdm

# Load the project's BERT implementation with pre-trained weights, plus the
# Hugging Face "bert-base-uncased" tokenizer used to encode its inputs.
model = BERT.from_pretrained(config=BERTConfig())
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# The model is placed on the GPU; dynamic_padding below defaults to the same device.
device="cuda"
model.to(device)
# Presumably puts the model in inference mode (assumes BERT is an nn.Module) — TODO confirm.
model.eval()

def dynamic_padding(data,device="cuda"):
    """
    Collate function: tokenize a batch of sentence pairs and move it to `device`.

    Args:
        data: list of dataset items, each providing "sentence1", "sentence2" and "label".
        device: target device for every tensor in the tokenized batch.

    Returns:
        (batch, labels) where `batch` is a plain dict of the tokenizer's tensors
        (attention mask converted to bool) on `device`, and `labels` is the list
        of the items' "label" values in batch order.
    """
    def column(key):
        # Pull one field out of every item, preserving batch order.
        return [row[key] for row in data]

    first_sentences = column("sentence1")
    second_sentences = column("sentence2")
    targets = column("label")

    # padding=True pads to the longest pair in this batch; pairs longer than
    # 512 tokens are truncated.
    tokenized = tokenizer(
        first_sentences,
        second_sentences,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )
    tokenized["attention_mask"] = tokenized["attention_mask"].bool()

    on_device = {}
    for name, tensor in tokenized.items():
        on_device[name] = tensor.to(device)
    return on_device, targets

def prepare_data(split: str, output_filename:str):
    """
    Embed every example of a dataset split with BERT and save the result.

    The output is written with ``torch.save`` (not JSON) as ``{"data": [item, ...]}``,
    where each item has the format:
    {
        'input_ids': tokenized input ids for (sentence1, sentence2), without padding
        'embedding': output of bert for the sentence1, sentence2 input,
        'label': int (0,1,2)
    }

    Args:
        split: name of the split handed to ``snliDataset``.
        output_filename: path of the file to write.
    """
    sd = snliDataset(split)
    batch_size = 64
    dl = DataLoader(sd,batch_size=batch_size,collate_fn=dynamic_padding)

    data = []
    with torch.no_grad():
        for encoded,labels in tqdm(dl):
            embeddings = model(**encoded).cpu()
            input_ids = encoded["input_ids"].cpu()
            # Number of real (non-padding) tokens per example.
            seq_lens = torch.sum(encoded["attention_mask"],dim=1).tolist()

            # Iterate over the examples of the batch. `encoded` is a dict, so
            # len(encoded) is the number of tokenizer outputs, not the batch
            # size; looping over it would keep only the first few examples.
            for ids, seq_len, embedding, label in zip(input_ids, seq_lens, embeddings, labels):
                data.append({
                    # Assumes padding is on the right, so the first `seq_len`
                    # ids are the real tokens. clone() keeps only the
                    # unpadded slice in the saved file.
                    "input_ids": ids[:seq_len].clone(),
                    "embedding": embedding,
                    "label": label,
                })

    torch.save({"data": data}, output_filename)

    print(f"Wrote {split} data to {output_filename}")

# Only the train split is (re)generated by this cell; uncomment the calls
# below to rebuild the dev and test files as well.
# prepare_data("dev", "dev_data.pt")
# prepare_data("test", "test_data.pt")
prepare_data("train","train_data.pt")

Loading pre-trained weights for BERT


In [3]:
# Build an MLP classifier
from torch import nn
import torch.nn.functional as F
class MLP(nn.Module):
    """
    Two-layer perceptron: Linear -> ReLU -> Dropout -> Linear.

    Args:
        input_size: number of input features.
        hidden_size: width of the hidden layer.
        output_size: number of outputs (raw scores, no softmax applied).
        dropout: dropout probability applied after the hidden activation.
            Defaults to 0.1, the value that was previously hard-coded.
    """
    def __init__(self,input_size,hidden_size,output_size,dropout=0.1):
        super().__init__()
        self.hidden_layer = nn.Linear(input_size,hidden_size)
        self.output_layer = nn.Linear(hidden_size,output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self,x):
        """Return the output-layer scores for input `x` (last dim = input_size)."""
        x = self.dropout(F.relu(self.hidden_layer(x)))
        x = self.output_layer(x)
        return x


In [4]:
# Train loop
import torch
import sys
sys.path.append("/home/varun/projects/experiments-with-gpt2/")

from snliDataset import snliDataset, snliEmbeddings
from torch.utils.data.dataloader import DataLoader

# Load the precomputed embeddings that the training cell below consumes.
# NOTE(review): despite the `se_train` name, this loads the "dev" split —
# confirm that training on dev is intentional.
split = "dev"
se_train = snliEmbeddings(split=split)


In [18]:
# Train the MLP classifier on the precomputed embeddings.
device = "cuda"
# 768 inputs (presumably the BERT embedding width — confirm), 256 hidden units, 3 classes.
mlp = MLP(768,256,3)
mlp.to(device)
criterion = nn.CrossEntropyLoss()
learning_rate = 1e-4
optimizer = torch.optim.Adam(mlp.parameters(),lr=learning_rate)
n_epochs = 500
train_loader = DataLoader(se_train,batch_size=64,shuffle=True)

for epoch in range(n_epochs):
    mlp.train()
    running_loss = 0.0

    for batch in train_loader:
        optimizer.zero_grad()
        outputs = mlp(batch["embedding"].to(device))
        # The collated labels are already a tensor; wrapping them in
        # torch.tensor() again only copies them and emits a UserWarning.
        loss = criterion(outputs,batch["label"].to(device))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Report the mean loss over the batches of this epoch.
    print(f"Epoch [{epoch+1}/{n_epochs}], Loss: {running_loss/len(train_loader):.4f}")



  loss = criterion(outputs,torch.tensor(batch["label"]).to(device))


Epoch [1/500], Loss: 1.1131
Epoch [2/500], Loss: 1.1030
Epoch [3/500], Loss: 1.0953
Epoch [4/500], Loss: 1.0983
Epoch [5/500], Loss: 1.0935
Epoch [6/500], Loss: 1.0955
Epoch [7/500], Loss: 1.0884
Epoch [8/500], Loss: 1.0844
Epoch [9/500], Loss: 1.0917
Epoch [10/500], Loss: 1.0839
Epoch [11/500], Loss: 1.0950
Epoch [12/500], Loss: 1.0894
Epoch [13/500], Loss: 1.0937
Epoch [14/500], Loss: 1.0766
Epoch [15/500], Loss: 1.0751
Epoch [16/500], Loss: 1.0789
Epoch [17/500], Loss: 1.0756
Epoch [18/500], Loss: 1.0881
Epoch [19/500], Loss: 1.0775
Epoch [20/500], Loss: 1.0838
Epoch [21/500], Loss: 1.0759
Epoch [22/500], Loss: 1.0717
Epoch [23/500], Loss: 1.0759
Epoch [24/500], Loss: 1.0701
Epoch [25/500], Loss: 1.0690
Epoch [26/500], Loss: 1.0701
Epoch [27/500], Loss: 1.0654
Epoch [28/500], Loss: 1.0708
Epoch [29/500], Loss: 1.0670
Epoch [30/500], Loss: 1.0652
Epoch [31/500], Loss: 1.0644
Epoch [32/500], Loss: 1.0679
Epoch [33/500], Loss: 1.0710
Epoch [34/500], Loss: 1.0672
Epoch [35/500], Loss: 1

In [19]:
# Evaluation: measure classification accuracy of the trained MLP on the test embeddings.
mlp.eval()
correct = 0
total = 0
se_test = snliEmbeddings(split="test")
test_loader = DataLoader(se_test,batch_size=32)
with torch.no_grad():
    for batch in test_loader:
        labels = batch["label"]
        logits = mlp(batch["embedding"].to(device))
        # Predicted class = index of the highest score in each row.
        predicted = logits.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels.to(device)).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")

Test Accuracy: 51.30%


In [7]:
# Sanity check: print the first stored record of the dev file.
# NOTE(review): torch.load unpickles the file — only load files produced by this notebook.
saved = torch.load("dev_data.pt")
a = saved["data"]
print(a[0])

{'input_ids': tensor([  101,  2048,  2308,  2024, 23581,  2096,  3173,  2000,  2175, 14555,
         1012,   102,  1996,  5208,  2024, 17662,  9119,  2096,  3173,  2000,
         2175, 14555,  2044,  2074,  5983,  6265,  1012,   102]), 'embedding': tensor([-4.9645e-01, -1.7457e-01, -8.3766e-01,  7.9489e-01,  7.5655e-01,
        -2.4049e-01,  5.3153e-01,  1.7647e-01, -1.8890e-01, -1.1814e+00,
        -3.0919e-01,  1.0288e+00,  1.0833e+00,  7.8304e-02,  1.0192e+00,
        -4.1620e-01, -4.0028e-01, -1.1613e-01,  1.2654e-01, -6.3370e-02,
         9.0735e-01,  3.1044e+00, -1.5455e-01,  3.5114e-02,  9.9734e-02,
         1.2293e+00, -5.2205e-01,  1.0538e+00,  8.8199e-01,  5.2332e-01,
        -4.4482e-01, -1.2841e-02, -1.4690e+00, -1.2647e-01, -1.2444e+00,
        -1.0411e+00,  1.4478e-01, -2.5737e-01, -1.3852e-01, -4.3224e-02,
        -1.0746e+00,  8.0413e-02,  1.5950e+00,  1.3573e-01,  5.4307e-01,
        -7.0265e-03, -3.0560e+00,  2.2700e-01, -6.7728e-01,  2.1787e-01,
         5.6539e-01, 

  a = torch.load("dev_data.pt")["data"]
