# RAG

Build a simple Retrieval-Augmented Generation (RAG) pipeline to demonstrate how it works.

Steps:
1. Document Store: Use an in-memory key-value store.
2. Retrieval: Use embeddings from GPT-2. (Note: the cells below currently compute embeddings with BERT, `bert-base-uncased`.)
3. Generation: Use GPT-2 to generate a response.

In [None]:
# Store the embeddings for the dev set
import torch
from torch.utils.data.dataloader import DataLoader
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import sys
sys.path.append("/home/varun/projects/experiments-with-gpt2/")

import json
from transformers import BertTokenizer,BertModel
from transformers import DataCollatorWithPadding
from bert import BERT
from bert_config import BERTConfig
from rag.snliDataset import snliDataset, snliEmbeddings
from tqdm import tqdm

# Build the project's BERT model with the default BERTConfig (presumably loading
# pre-trained weights — confirm in bert.py) and the Hugging Face tokenizer for
# "bert-base-uncased".
model = BERT.from_pretrained(config=BERTConfig())
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# NOTE(review): the device is hard-coded; dynamic_padding below also defaults to "cuda".
device="cuda"
model.to(device)
# The model is only used for inference in this cell (see torch.no_grad() in prepare_data).
model.eval()

def dynamic_padding(data,device="cuda"):
    """
    Collate function: tokenize a batch of sentence pairs and pad them to a common length.

    Args:
        data: list of dataset items, each indexable by "sentence1", "sentence2"
            and "label".
        device: device the encoded tensors are moved to.

    Returns:
        (batch, labels): `batch` is a plain dict of the tokenizer's output tensors
        on `device`, with "attention_mask" converted to bool; `labels` is the list
        of the items' "label" values in batch order.
    """
    first_sentences, second_sentences, labels = [], [], []
    for example in data:
        first_sentences.append(example["sentence1"])
        second_sentences.append(example["sentence2"])
        labels.append(example["label"])

    # padding=True pads to the longest pair in this batch; pairs are truncated at 512 tokens.
    tokenized = tokenizer(
        first_sentences,
        second_sentences,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )

    batch = {}
    for name, tensor in tokenized.items():
        if name == "attention_mask":
            tensor = tensor.bool()
        batch[name] = tensor.to(device)
    return batch,labels

def prepare_data(split: str, output_filename:str, batch_size: int = 64):
    """
    Run the model over one dataset split and store the per-example results.

    The result is written with torch.save (not JSON) as {"data": [item, ...]},
    where every item has the format:
    {
        'input_ids': token ids for the (sentence1, sentence2) input, cut to the
                     number of attended tokens (CPU tensor),
        'embedding': output of the model for that input (CPU tensor),
        'label': int (0,1,2)
    }

    Args:
        split: split name passed to snliDataset (e.g. "train", "dev", "test").
        output_filename: path of the file written by torch.save.
        batch_size: number of examples per forward pass (default 64).
    """
    sd = snliDataset(split)
    dl = DataLoader(sd,batch_size=batch_size,collate_fn=dynamic_padding)

    data = []
    with torch.no_grad():
        for encoded,labels in tqdm(dl):
            # assumes the model returns a single tensor whose first dimension
            # indexes the examples of the batch — TODO confirm against bert.py
            embedding = model(**encoded).cpu()
            # Number of attended (non-padding) tokens per example.
            seq_lens = torch.sum(encoded["attention_mask"],dim=1)
            # Iterate over the examples of the batch. len(encoded) would be the
            # number of keys in the tokenizer output dict, which silently kept
            # only the first few examples of every batch.
            for i in range(len(labels)):
                data.append({
                    "input_ids": encoded["input_ids"][i][:seq_lens[i]].cpu(),
                    "embedding": embedding[i],
                    "label": labels[i],
                })

    torch.save({"data": data}, output_filename)

    print(f"Wrote {split} data to {output_filename}")

# Only the train split is extracted in this run; the dev/test calls are left disabled.
# NOTE(review): the MLP evaluation cell further down loads snliEmbeddings(split="test"),
# which presumably relies on a previously generated test file — confirm.
# prepare_data("dev", "dev_data.pt")
# prepare_data("test", "test_data.pt")
prepare_data("train","train_data.pt")

In [None]:
# Build a MLP classifier
from torch import nn
import torch.nn.functional as F
class MLP(nn.Module):
    """
    Two-layer perceptron: Linear -> ReLU -> Dropout -> Linear.

    Args:
        input_size: size of the last dimension of the input.
        hidden_size: number of hidden units.
        output_size: size of the last dimension of the output.
        dropout: dropout probability applied after the hidden activation
            (default 0.1, the value that used to be hard-coded).
    """
    def __init__(self,input_size,hidden_size,output_size,dropout=0.1):
        super().__init__()
        self.hidden_layer = nn.Linear(input_size,hidden_size)
        self.output_layer = nn.Linear(hidden_size,output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self,x):
        """Return the raw (un-normalized) output scores for `x`."""
        hidden = self.dropout(F.relu(self.hidden_layer(x)))
        return self.output_layer(hidden)


In [None]:
# Train loop
import torch
import sys
sys.path.append("/home/varun/projects/experiments-with-gpt2/")

from snliDataset import snliDataset, snliEmbeddings
from torch.utils.data.dataloader import DataLoader

# Load the embedding dataset for the training split.
# NOTE(review): presumably this reads the file written by prepare_data("train", ...) — confirm in snliDataset.py.
split = "train"
se_train = snliEmbeddings(split=split)


In [None]:
# Train the MLP classifier on the pre-computed embeddings.
device = "cuda"
mlp = MLP(768,256,3)  # 768 input features, 256 hidden units, 3 output scores
mlp.to(device)
criterion = nn.CrossEntropyLoss()
learning_rate = 1e-3
optimizer = torch.optim.Adam(mlp.parameters(),lr=learning_rate)
# scheduler.step() is called once per epoch, so the learning rate is halved every 100 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=100,gamma=0.5)
n_epochs = 500
train_loader = DataLoader(se_train,batch_size=64,shuffle=True)

for epoch in range(n_epochs):
    mlp.train()
    running_loss = 0.0

    for i,batch in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = mlp(batch["embedding"].to(device))
        # torch.as_tensor passes an existing tensor through unchanged, whereas
        # torch.tensor(tensor) copies it and emits a UserWarning on every step.
        # The evaluation cell already treats batch["label"] as a tensor.
        loss = criterion(outputs,torch.as_tensor(batch["label"]).to(device))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    scheduler.step()
    # Mean of the per-batch losses over the epoch.
    print(f"Epoch [{epoch+1}/{n_epochs}], Loss: {running_loss/len(train_loader):.4f}")



In [None]:
# Evaluation
# Measure classification accuracy of the trained MLP on the test embeddings.
se_test = snliEmbeddings(split="test")
test_loader = DataLoader(se_test,batch_size=32)

mlp.eval()
correct, total = 0, 0
with torch.no_grad():
    for batch in test_loader:
        labels = batch["label"].to(device)
        outputs = mlp(batch["embedding"].to(device))
        # Predicted class = index of the largest score in each row.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")

## Sentence-BERT Training

In [1]:
import sys
sys.path.append("/home/varun/projects/experiments-with-gpt2/")

from rag.train import Trainer
from rag.snliDataset import sentenceBERTDataset
from bert_config import BERTConfig, BERTTrainConfig

# Datasets from rag.snliDataset: the "train" split for optimization, "dev" for validation.
train_set = sentenceBERTDataset("train")
val_set = sentenceBERTDataset("dev")
# Default-constructed training and model configurations from bert_config.
train_config = BERTTrainConfig()
model_config = BERTConfig()

# The Trainer receives both datasets and both configs; training itself starts in the next cell.
trainer = Trainer(train_set, val_set,model_config, train_config)

  from .autonotebook import tqdm as notebook_tqdm


Loading extracted data from /home/varun/Downloads/snli_1.0/sentenceBERT_dev.pt


  self.data = torch.load(cache_fpath)["data"]


In [2]:
# Start training. Per the recorded output below, this run took 300000 steps
# (about 5h17m), reported train/validation loss every 2000 steps, and wrote
# out/bert_ckpt_train.pt only at steps where the validation loss reached a new
# minimum (presumably best-checkpoint saving inside Trainer — confirm).
trainer.train()

Loading pre-trained weights for BERT


  0%|          | 1/300000 [00:05<480:24:55,  5.77s/it]

Step: 0
Train Loss: 1.345630407333374
Validation Loss:1.3962757587432861


  1%|          | 1999/300000 [02:05<4:49:21, 17.16it/s]

Step: 2000
Train Loss: 0.9277328848838806
Validation Loss:0.9481871724128723
Saving checkpoint to out/bert_ckpt_train.pt


  1%|▏         | 3999/300000 [04:15<4:43:36, 17.40it/s]  

Step: 4000
Train Loss: 0.9178657531738281
Validation Loss:0.8213080167770386
Saving checkpoint to out/bert_ckpt_train.pt


  2%|▏         | 5999/300000 [06:24<4:44:17, 17.24it/s]  

Step: 6000
Train Loss: 0.7877076268196106
Validation Loss:0.7275485992431641
Saving checkpoint to out/bert_ckpt_train.pt


  3%|▎         | 8003/300000 [08:40<50:02:53,  1.62it/s] 

Step: 8000
Train Loss: 0.7771178483963013
Validation Loss:0.7609997391700745


  3%|▎         | 9999/300000 [10:41<4:52:06, 16.55it/s] 

Step: 10000
Train Loss: 0.7720151543617249
Validation Loss:0.6933119297027588
Saving checkpoint to out/bert_ckpt_train.pt


  4%|▍         | 12003/300000 [12:57<52:11:19,  1.53it/s] 

Step: 12000
Train Loss: 0.7479739189147949
Validation Loss:0.7340375185012817


  5%|▍         | 13999/300000 [14:58<4:51:11, 16.37it/s] 

Step: 14000
Train Loss: 0.7674320936203003
Validation Loss:0.653600811958313
Saving checkpoint to out/bert_ckpt_train.pt


  5%|▌         | 16003/300000 [17:14<49:06:41,  1.61it/s] 

Step: 16000
Train Loss: 0.6873766779899597
Validation Loss:0.6889321804046631


  6%|▌         | 18003/300000 [19:19<48:25:44,  1.62it/s]

Step: 18000
Train Loss: 0.6119562983512878
Validation Loss:0.699036180973053


  7%|▋         | 20003/300000 [21:26<50:00:38,  1.56it/s]

Step: 20000
Train Loss: 0.7430108189582825
Validation Loss:0.6937074065208435


  7%|▋         | 22003/300000 [23:31<48:22:07,  1.60it/s]

Step: 22000
Train Loss: 0.6567671895027161
Validation Loss:0.6687963604927063


  8%|▊         | 23999/300000 [25:32<4:28:32, 17.13it/s] 

Step: 24000
Train Loss: 0.606904923915863
Validation Loss:0.5758583545684814
Saving checkpoint to out/bert_ckpt_train.pt


  9%|▊         | 26003/300000 [27:48<47:28:33,  1.60it/s] 

Step: 26000
Train Loss: 0.5327774882316589
Validation Loss:0.7549976706504822


  9%|▉         | 28003/300000 [29:53<47:07:38,  1.60it/s]

Step: 28000
Train Loss: 0.7031863331794739
Validation Loss:0.6199772953987122


 10%|█         | 30003/300000 [32:01<48:43:10,  1.54it/s]

Step: 30000
Train Loss: 0.6591920256614685
Validation Loss:0.6686368584632874


 11%|█         | 32003/300000 [34:07<46:17:50,  1.61it/s]

Step: 32000
Train Loss: 0.6026660203933716
Validation Loss:0.7927854061126709


 11%|█▏        | 34003/300000 [36:13<46:00:32,  1.61it/s]

Step: 34000
Train Loss: 0.6167892217636108
Validation Loss:0.7285677194595337


 12%|█▏        | 36003/300000 [38:18<45:35:04,  1.61it/s]

Step: 36000
Train Loss: 0.7224059104919434
Validation Loss:0.6390355229377747


 13%|█▎        | 37999/300000 [40:18<4:14:14, 17.18it/s] 

Step: 38000
Train Loss: 0.5646682977676392
Validation Loss:0.5504401922225952
Saving checkpoint to out/bert_ckpt_train.pt


 13%|█▎        | 40003/300000 [42:33<44:54:45,  1.61it/s] 

Step: 40000
Train Loss: 0.5973913073539734
Validation Loss:0.6216921806335449


 14%|█▍        | 42003/300000 [44:42<44:49:13,  1.60it/s]

Step: 42000
Train Loss: 0.6266039609909058
Validation Loss:0.5773576498031616


 15%|█▍        | 44003/300000 [46:49<44:10:48,  1.61it/s]

Step: 44000
Train Loss: 0.5036266446113586
Validation Loss:0.6436732411384583


 15%|█▌        | 46003/300000 [48:54<43:47:39,  1.61it/s]

Step: 46000
Train Loss: 0.5928137302398682
Validation Loss:0.6320009231567383


 16%|█▌        | 48003/300000 [51:01<45:51:01,  1.53it/s]

Step: 48000
Train Loss: 0.4904387593269348
Validation Loss:0.713444709777832


 17%|█▋        | 50003/300000 [53:07<43:01:10,  1.61it/s]

Step: 50000
Train Loss: 0.5920626521110535
Validation Loss:0.7506657242774963


 17%|█▋        | 52003/300000 [55:13<42:50:08,  1.61it/s]

Step: 52000
Train Loss: 0.5396454334259033
Validation Loss:0.6250673532485962


 18%|█▊        | 54003/300000 [57:18<42:38:43,  1.60it/s]

Step: 54000
Train Loss: 0.5341496467590332
Validation Loss:0.5551284551620483


 19%|█▊        | 56003/300000 [59:24<42:11:27,  1.61it/s]

Step: 56000
Train Loss: 0.7520700097084045
Validation Loss:0.5815581679344177


 19%|█▉        | 57999/300000 [1:01:24<3:51:39, 17.41it/s]

Step: 58000
Train Loss: 0.6114374399185181
Validation Loss:0.5423187017440796
Saving checkpoint to out/bert_ckpt_train.pt


 20%|██        | 60003/300000 [1:03:39<41:15:40,  1.62it/s]

Step: 60000
Train Loss: 0.5535390973091125
Validation Loss:0.6691470742225647


 21%|██        | 62003/300000 [1:05:46<43:32:07,  1.52it/s]

Step: 62000
Train Loss: 0.605380117893219
Validation Loss:0.7054019570350647


 21%|██▏       | 64003/300000 [1:07:53<42:52:36,  1.53it/s]

Step: 64000
Train Loss: 0.5070094466209412
Validation Loss:0.5917572379112244


 22%|██▏       | 66003/300000 [1:09:59<40:26:25,  1.61it/s]

Step: 66000
Train Loss: 0.6413533091545105
Validation Loss:0.5636429190635681


 23%|██▎       | 67999/300000 [1:11:58<3:47:00, 17.03it/s] 

Step: 68000
Train Loss: 0.519533097743988
Validation Loss:0.5364391803741455
Saving checkpoint to out/bert_ckpt_train.pt


 23%|██▎       | 70003/300000 [1:14:15<41:44:10,  1.53it/s]

Step: 70000
Train Loss: 0.4837356507778168
Validation Loss:0.6066639423370361


 24%|██▍       | 72003/300000 [1:16:23<41:05:47,  1.54it/s]

Step: 72000
Train Loss: 0.5611888766288757
Validation Loss:0.554851233959198


 25%|██▍       | 74003/300000 [1:18:30<39:13:42,  1.60it/s]

Step: 74000
Train Loss: 0.6362836956977844
Validation Loss:0.5548858642578125


 25%|██▌       | 76003/300000 [1:20:38<40:47:00,  1.53it/s]

Step: 76000
Train Loss: 0.6098135709762573
Validation Loss:0.5994907021522522


 26%|██▌       | 78003/300000 [1:22:43<38:21:22,  1.61it/s]

Step: 78000
Train Loss: 0.7299414873123169
Validation Loss:0.5772035717964172


 27%|██▋       | 80003/300000 [1:24:49<38:23:19,  1.59it/s]

Step: 80000
Train Loss: 0.537574291229248
Validation Loss:0.6231258511543274


 27%|██▋       | 81999/300000 [1:26:50<3:40:11, 16.50it/s] 

Step: 82000
Train Loss: 0.590459942817688
Validation Loss:0.528691828250885
Saving checkpoint to out/bert_ckpt_train.pt


 28%|██▊       | 83999/300000 [1:29:00<3:27:31, 17.35it/s] 

Step: 84000
Train Loss: 0.4251180589199066
Validation Loss:0.4893195629119873
Saving checkpoint to out/bert_ckpt_train.pt


 29%|██▊       | 85999/300000 [1:31:09<3:27:06, 17.22it/s] 

Step: 86000
Train Loss: 0.5405653715133667
Validation Loss:0.4811805486679077
Saving checkpoint to out/bert_ckpt_train.pt


 29%|██▉       | 88003/300000 [1:33:25<36:36:48,  1.61it/s]

Step: 88000
Train Loss: 0.5389195680618286
Validation Loss:0.6535629034042358


 30%|███       | 90003/300000 [1:35:31<36:09:46,  1.61it/s]

Step: 90000
Train Loss: 0.5466469526290894
Validation Loss:0.6481184959411621


 31%|███       | 92003/300000 [1:37:36<35:59:17,  1.61it/s]

Step: 92000
Train Loss: 0.49778544902801514
Validation Loss:0.6464508175849915


 31%|███▏      | 94003/300000 [1:39:42<35:21:13,  1.62it/s]

Step: 94000
Train Loss: 0.4770573675632477
Validation Loss:0.5983986258506775


 32%|███▏      | 96003/300000 [1:41:48<36:44:47,  1.54it/s]

Step: 96000
Train Loss: 0.5423885583877563
Validation Loss:0.5649168491363525


 33%|███▎      | 98003/300000 [1:43:54<34:45:10,  1.61it/s]

Step: 98000
Train Loss: 0.5822342038154602
Validation Loss:0.49053552746772766


 33%|███▎      | 100003/300000 [1:45:59<34:28:06,  1.61it/s]

Step: 100000
Train Loss: 0.5399509072303772
Validation Loss:0.5416966080665588


 34%|███▍      | 102003/300000 [1:48:04<34:11:03,  1.61it/s]

Step: 102000
Train Loss: 0.5697510242462158
Validation Loss:0.6018663048744202


 35%|███▍      | 104003/300000 [1:50:10<33:48:21,  1.61it/s]

Step: 104000
Train Loss: 0.5953311920166016
Validation Loss:0.49249008297920227


 35%|███▌      | 105999/300000 [1:52:13<3:14:12, 16.65it/s] 

Step: 106000
Train Loss: 0.472395658493042
Validation Loss:0.46317899227142334
Saving checkpoint to out/bert_ckpt_train.pt


 36%|███▌      | 108003/300000 [1:54:30<33:08:09,  1.61it/s]

Step: 108000
Train Loss: 0.5688108801841736
Validation Loss:0.5664777755737305


 37%|███▋      | 110003/300000 [1:56:35<32:43:28,  1.61it/s]

Step: 110000
Train Loss: 0.5071233510971069
Validation Loss:0.6206128001213074


 37%|███▋      | 112003/300000 [1:58:41<32:29:10,  1.61it/s]

Step: 112000
Train Loss: 0.5266244411468506
Validation Loss:0.570122241973877


 38%|███▊      | 113999/300000 [2:00:41<2:57:55, 17.42it/s] 

Step: 114000
Train Loss: 0.5468488335609436
Validation Loss:0.44948625564575195
Saving checkpoint to out/bert_ckpt_train.pt


 39%|███▊      | 116003/300000 [2:02:56<31:34:25,  1.62it/s]

Step: 116000
Train Loss: 0.5297086834907532
Validation Loss:0.5638591051101685


 39%|███▉      | 118003/300000 [2:05:02<31:35:01,  1.60it/s]

Step: 118000
Train Loss: 0.4632229208946228
Validation Loss:0.5348275303840637


 40%|████      | 120003/300000 [2:07:09<31:31:16,  1.59it/s]

Step: 120000
Train Loss: 0.5314221978187561
Validation Loss:0.4886157810688019


 41%|████      | 122003/300000 [2:09:16<32:11:08,  1.54it/s]

Step: 122000
Train Loss: 0.49538856744766235
Validation Loss:0.5237554907798767


 41%|████▏     | 124003/300000 [2:11:22<30:23:33,  1.61it/s]

Step: 124000
Train Loss: 0.46839937567710876
Validation Loss:0.4775594472885132


 42%|████▏     | 126003/300000 [2:13:27<30:06:12,  1.61it/s]

Step: 126000
Train Loss: 0.5157274603843689
Validation Loss:0.5575260519981384


 43%|████▎     | 128003/300000 [2:15:33<29:41:46,  1.61it/s]

Step: 128000
Train Loss: 0.5392510294914246
Validation Loss:0.5697758793830872


 43%|████▎     | 130003/300000 [2:17:38<29:28:15,  1.60it/s]

Step: 130000
Train Loss: 0.5199003219604492
Validation Loss:0.5549852252006531


 44%|████▍     | 132003/300000 [2:19:44<28:52:27,  1.62it/s]

Step: 132000
Train Loss: 0.4762638807296753
Validation Loss:0.565432071685791


 45%|████▍     | 134003/300000 [2:21:49<28:34:43,  1.61it/s]

Step: 134000
Train Loss: 0.42055419087409973
Validation Loss:0.63033527135849


 45%|████▌     | 136003/300000 [2:23:58<29:49:55,  1.53it/s]

Step: 136000
Train Loss: 0.4456484615802765
Validation Loss:0.543994128704071


 46%|████▌     | 137999/300000 [2:25:58<2:36:21, 17.27it/s] 

Step: 138000
Train Loss: 0.5844992995262146
Validation Loss:0.44763267040252686
Saving checkpoint to out/bert_ckpt_train.pt


 47%|████▋     | 140003/300000 [2:28:14<27:30:07,  1.62it/s]

Step: 140000
Train Loss: 0.43185123801231384
Validation Loss:0.5387284159660339


 47%|████▋     | 142003/300000 [2:30:21<27:18:20,  1.61it/s]

Step: 142000
Train Loss: 0.5332781076431274
Validation Loss:0.5737150311470032


 48%|████▊     | 144003/300000 [2:32:26<28:12:19,  1.54it/s]

Step: 144000
Train Loss: 0.4157370626926422
Validation Loss:0.4955263137817383


 49%|████▊     | 146003/300000 [2:34:33<26:40:29,  1.60it/s]

Step: 146000
Train Loss: 0.4471611976623535
Validation Loss:0.45860734581947327


 49%|████▉     | 148003/300000 [2:36:39<27:36:16,  1.53it/s]

Step: 148000
Train Loss: 0.4621451497077942
Validation Loss:0.4782772362232208


 50%|█████     | 150003/300000 [2:38:47<27:14:16,  1.53it/s]

Step: 150000
Train Loss: 0.5099005699157715
Validation Loss:0.5837936401367188


 51%|█████     | 152003/300000 [2:40:55<26:51:33,  1.53it/s]

Step: 152000
Train Loss: 0.5618845820426941
Validation Loss:0.4813927114009857


 51%|█████▏    | 154003/300000 [2:43:01<25:16:39,  1.60it/s]

Step: 154000
Train Loss: 0.7395942807197571
Validation Loss:0.5811088681221008


 52%|█████▏    | 156003/300000 [2:45:08<24:44:58,  1.62it/s]

Step: 156000
Train Loss: 0.5995473861694336
Validation Loss:0.50574791431427


 53%|█████▎    | 158003/300000 [2:47:14<25:45:27,  1.53it/s]

Step: 158000
Train Loss: 0.38471701741218567
Validation Loss:0.552077054977417


 53%|█████▎    | 160003/300000 [2:49:19<24:20:40,  1.60it/s]

Step: 160000
Train Loss: 0.44382965564727783
Validation Loss:0.5559290051460266


 54%|█████▍    | 162003/300000 [2:51:24<23:57:33,  1.60it/s]

Step: 162000
Train Loss: 0.44400206208229065
Validation Loss:0.5312600135803223


 55%|█████▍    | 164003/300000 [2:53:32<23:37:29,  1.60it/s]

Step: 164000
Train Loss: 0.5653117299079895
Validation Loss:0.5938172340393066


 55%|█████▌    | 165999/300000 [2:55:31<2:10:00, 17.18it/s] 

Step: 166000
Train Loss: 0.41678497195243835
Validation Loss:0.43853890895843506
Saving checkpoint to out/bert_ckpt_train.pt


 56%|█████▌    | 168003/300000 [2:57:47<22:50:52,  1.60it/s]

Step: 168000
Train Loss: 0.5323543548583984
Validation Loss:0.48456576466560364


 57%|█████▋    | 170003/300000 [2:59:53<23:34:09,  1.53it/s]

Step: 170000
Train Loss: 0.48141083121299744
Validation Loss:0.5561127066612244


 57%|█████▋    | 172003/300000 [3:01:58<22:12:26,  1.60it/s]

Step: 172000
Train Loss: 0.48937201499938965
Validation Loss:0.5604078769683838


 58%|█████▊    | 173999/300000 [3:03:57<2:00:42, 17.40it/s] 

Step: 174000
Train Loss: 0.4249516725540161
Validation Loss:0.4230937957763672
Saving checkpoint to out/bert_ckpt_train.pt


 59%|█████▊    | 176003/300000 [3:06:14<21:26:23,  1.61it/s]

Step: 176000
Train Loss: 0.5146735310554504
Validation Loss:0.5022866725921631


 59%|█████▉    | 178003/300000 [3:08:19<20:58:18,  1.62it/s]

Step: 178000
Train Loss: 0.3705000877380371
Validation Loss:0.4371287524700165


 60%|██████    | 180003/300000 [3:10:27<21:39:44,  1.54it/s]

Step: 180000
Train Loss: 0.40041303634643555
Validation Loss:0.5203750729560852


 61%|██████    | 181999/300000 [3:12:27<1:53:18, 17.36it/s] 

Step: 182000
Train Loss: 0.45282745361328125
Validation Loss:0.39269232749938965
Saving checkpoint to out/bert_ckpt_train.pt


 61%|██████▏   | 184003/300000 [3:14:42<20:05:04,  1.60it/s]

Step: 184000
Train Loss: 0.31684690713882446
Validation Loss:0.41984400153160095


 62%|██████▏   | 186003/300000 [3:16:49<20:45:44,  1.53it/s]

Step: 186000
Train Loss: 0.48258960247039795
Validation Loss:0.4698801636695862


 63%|██████▎   | 188003/300000 [3:18:57<19:23:03,  1.60it/s]

Step: 188000
Train Loss: 0.4828051030635834
Validation Loss:0.4158710241317749


 63%|██████▎   | 190003/300000 [3:21:02<18:58:22,  1.61it/s]

Step: 190000
Train Loss: 0.3151511251926422
Validation Loss:0.467922568321228


 64%|██████▍   | 192003/300000 [3:23:10<19:38:02,  1.53it/s]

Step: 192000
Train Loss: 0.4414497911930084
Validation Loss:0.4485511779785156


 65%|██████▍   | 194003/300000 [3:25:18<19:17:51,  1.53it/s]

Step: 194000
Train Loss: 0.45568275451660156
Validation Loss:0.5416445136070251


 65%|██████▌   | 196003/300000 [3:27:25<18:19:12,  1.58it/s]

Step: 196000
Train Loss: 0.38939914107322693
Validation Loss:0.5819482207298279


 66%|██████▌   | 198003/300000 [3:29:31<17:44:31,  1.60it/s]

Step: 198000
Train Loss: 0.4617302715778351
Validation Loss:0.5025505423545837


 67%|██████▋   | 200003/300000 [3:31:36<18:06:56,  1.53it/s]

Step: 200000
Train Loss: 0.518556535243988
Validation Loss:0.5495485663414001


 67%|██████▋   | 202003/300000 [3:33:42<16:59:21,  1.60it/s]

Step: 202000
Train Loss: 0.4740409255027771
Validation Loss:0.5131639242172241


 68%|██████▊   | 204003/300000 [3:35:49<16:27:20,  1.62it/s]

Step: 204000
Train Loss: 0.5001590847969055
Validation Loss:0.44082385301589966


 69%|██████▊   | 206003/300000 [3:37:54<16:10:28,  1.61it/s]

Step: 206000
Train Loss: 0.4449859857559204
Validation Loss:0.5019554495811462


 69%|██████▉   | 208003/300000 [3:39:59<15:55:12,  1.61it/s]

Step: 208000
Train Loss: 0.39769065380096436
Validation Loss:0.4489014744758606


 70%|███████   | 210003/300000 [3:42:04<15:45:02,  1.59it/s]

Step: 210000
Train Loss: 0.43222278356552124
Validation Loss:0.6079518795013428


 71%|███████   | 212003/300000 [3:44:09<15:12:55,  1.61it/s]

Step: 212000
Train Loss: 0.42072024941444397
Validation Loss:0.4374060034751892


 71%|███████▏  | 214003/300000 [3:46:16<14:50:13,  1.61it/s]

Step: 214000
Train Loss: 0.3950793445110321
Validation Loss:0.4925355315208435


 72%|███████▏  | 216003/300000 [3:48:22<14:37:56,  1.59it/s]

Step: 216000
Train Loss: 0.41905882954597473
Validation Loss:0.5505979657173157


 73%|███████▎  | 218003/300000 [3:50:29<14:14:47,  1.60it/s]

Step: 218000
Train Loss: 0.4071515202522278
Validation Loss:0.5404424071311951


 73%|███████▎  | 220003/300000 [3:52:35<13:46:49,  1.61it/s]

Step: 220000
Train Loss: 0.45157790184020996
Validation Loss:0.5169483423233032


 74%|███████▍  | 222003/300000 [3:54:41<14:04:48,  1.54it/s]

Step: 222000
Train Loss: 0.3623031675815582
Validation Loss:0.51188725233078


 75%|███████▍  | 224003/300000 [3:56:48<13:49:26,  1.53it/s]

Step: 224000
Train Loss: 0.490761399269104
Validation Loss:0.5873671770095825


 75%|███████▌  | 226003/300000 [3:58:55<13:30:24,  1.52it/s]

Step: 226000
Train Loss: 0.5747248530387878
Validation Loss:0.4937607944011688


 76%|███████▌  | 228003/300000 [4:01:01<12:42:34,  1.57it/s]

Step: 228000
Train Loss: 0.4264153242111206
Validation Loss:0.4693658649921417


 77%|███████▋  | 230003/300000 [4:03:08<12:07:08,  1.60it/s]

Step: 230000
Train Loss: 0.47540712356567383
Validation Loss:0.4336985647678375


 77%|███████▋  | 232003/300000 [4:05:13<11:47:09,  1.60it/s]

Step: 232000
Train Loss: 0.5265156626701355
Validation Loss:0.4431438446044922


 78%|███████▊  | 234003/300000 [4:07:18<11:20:40,  1.62it/s]

Step: 234000
Train Loss: 0.45441383123397827
Validation Loss:0.42819324135780334


 79%|███████▊  | 236003/300000 [4:09:25<11:40:50,  1.52it/s]

Step: 236000
Train Loss: 0.4560371935367584
Validation Loss:0.4268049895763397


 79%|███████▉  | 238003/300000 [4:11:31<10:45:36,  1.60it/s]

Step: 238000
Train Loss: 0.4428126811981201
Validation Loss:0.4026949107646942


 80%|████████  | 240003/300000 [4:13:37<10:25:09,  1.60it/s]

Step: 240000
Train Loss: 0.44823363423347473
Validation Loss:0.525060772895813


 81%|████████  | 242003/300000 [4:15:44<10:34:27,  1.52it/s]

Step: 242000
Train Loss: 0.4719600975513458
Validation Loss:0.6377857327461243


 81%|████████▏ | 244003/300000 [4:17:50<9:53:20,  1.57it/s] 

Step: 244000
Train Loss: 0.4381667375564575
Validation Loss:0.5474753975868225


 82%|████████▏ | 246003/300000 [4:19:56<9:19:06,  1.61it/s] 

Step: 246000
Train Loss: 0.37694475054740906
Validation Loss:0.4526347517967224


 83%|████████▎ | 247999/300000 [4:21:57<50:38, 17.11it/s]  

Step: 248000
Train Loss: 0.4402392506599426
Validation Loss:0.34450018405914307
Saving checkpoint to out/bert_ckpt_train.pt


 83%|████████▎ | 250003/300000 [4:24:12<8:37:24,  1.61it/s] 

Step: 250000
Train Loss: 0.43769508600234985
Validation Loss:0.49297890067100525


 84%|████████▍ | 252003/300000 [4:26:17<8:20:19,  1.60it/s] 

Step: 252000
Train Loss: 0.39479994773864746
Validation Loss:0.45552897453308105


 85%|████████▍ | 254003/300000 [4:28:23<7:54:56,  1.61it/s] 

Step: 254000
Train Loss: 0.3387233018875122
Validation Loss:0.46814265847206116


 85%|████████▌ | 256003/300000 [4:30:28<7:38:24,  1.60it/s] 

Step: 256000
Train Loss: 0.4828459322452545
Validation Loss:0.5028723478317261


 86%|████████▌ | 258003/300000 [4:32:37<7:37:23,  1.53it/s] 

Step: 258000
Train Loss: 0.37815842032432556
Validation Loss:0.5255323052406311


 87%|████████▋ | 260003/300000 [4:34:45<6:56:04,  1.60it/s]

Step: 260000
Train Loss: 0.34347862005233765
Validation Loss:0.45995596051216125


 87%|████████▋ | 262003/300000 [4:36:51<6:31:39,  1.62it/s]

Step: 262000
Train Loss: 0.4417198896408081
Validation Loss:0.540111780166626


 88%|████████▊ | 264003/300000 [4:38:56<6:12:13,  1.61it/s]

Step: 264000
Train Loss: 0.3721959590911865
Validation Loss:0.5623093843460083


 89%|████████▊ | 266003/300000 [4:41:02<6:10:44,  1.53it/s]

Step: 266000
Train Loss: 0.43193116784095764
Validation Loss:0.46273699402809143


 89%|████████▉ | 268003/300000 [4:43:08<5:49:48,  1.52it/s]

Step: 268000
Train Loss: 0.44425147771835327
Validation Loss:0.6195017099380493


 90%|█████████ | 270003/300000 [4:45:14<5:11:54,  1.60it/s]

Step: 270000
Train Loss: 0.3490417003631592
Validation Loss:0.46674299240112305


 91%|█████████ | 272003/300000 [4:47:20<4:50:48,  1.60it/s]

Step: 272000
Train Loss: 0.4057326018810272
Validation Loss:0.4279691278934479


 91%|█████████▏| 274003/300000 [4:49:27<4:31:04,  1.60it/s]

Step: 274000
Train Loss: 0.6203534603118896
Validation Loss:0.5672678351402283


 92%|█████████▏| 276003/300000 [4:51:35<4:09:32,  1.60it/s]

Step: 276000
Train Loss: 0.3652838170528412
Validation Loss:0.5374311804771423


 93%|█████████▎| 278003/300000 [4:53:43<3:49:05,  1.60it/s]

Step: 278000
Train Loss: 0.41721415519714355
Validation Loss:0.5065081119537354


 93%|█████████▎| 280003/300000 [4:55:48<3:28:10,  1.60it/s]

Step: 280000
Train Loss: 0.4083442687988281
Validation Loss:0.47443172335624695


 94%|█████████▍| 281999/300000 [4:57:48<17:08, 17.50it/s]  

Step: 282000
Train Loss: 0.49784544110298157
Validation Loss:0.3317492604255676
Saving checkpoint to out/bert_ckpt_train.pt


 95%|█████████▍| 284003/300000 [5:00:03<2:45:35,  1.61it/s]

Step: 284000
Train Loss: 0.4849894344806671
Validation Loss:0.4857332110404968


 95%|█████████▌| 286003/300000 [5:02:08<2:25:16,  1.61it/s]

Step: 286000
Train Loss: 0.4332057237625122
Validation Loss:0.6089721322059631


 96%|█████████▌| 288003/300000 [5:04:14<2:04:45,  1.60it/s]

Step: 288000
Train Loss: 0.4164702594280243
Validation Loss:0.5200754404067993


 97%|█████████▋| 290003/300000 [5:06:22<1:49:06,  1.53it/s]

Step: 290000
Train Loss: 0.5124045610427856
Validation Loss:0.50863116979599


 97%|█████████▋| 292003/300000 [5:08:29<1:22:54,  1.61it/s]

Step: 292000
Train Loss: 0.3377377390861511
Validation Loss:0.6454980373382568


 98%|█████████▊| 294003/300000 [5:10:35<1:01:53,  1.61it/s]

Step: 294000
Train Loss: 0.33839014172554016
Validation Loss:0.4385831654071808


 99%|█████████▊| 296003/300000 [5:12:43<43:42,  1.52it/s]  

Step: 296000
Train Loss: 0.483060359954834
Validation Loss:0.5036545395851135


 99%|█████████▉| 298003/300000 [5:14:52<21:45,  1.53it/s]

Step: 298000
Train Loss: 0.37724998593330383
Validation Loss:0.4879929721355438


100%|██████████| 300000/300000 [5:16:55<00:00, 15.78it/s]


In [5]:
# Evaluation
import torch
from rag.snliDataset import sentenceBERTDataset
from bert_config import BERTConfig, BERTTrainConfig
from sentenceBERT import sentenceBERT
from torch.utils.data.dataloader import DataLoader
from bert_utils import dynamic_padding

# Evaluate the checkpoint written during training on the test split.
device = "cuda"
test_set = sentenceBERTDataset("test")
ckpt_path = "out/bert_ckpt_train.pt"
ckpt = torch.load(ckpt_path)
model = sentenceBERT(BERTConfig())
model.load_state_dict(ckpt["model"])
model.to(device)
# Put the model in evaluation mode so that train-only layers (e.g. dropout)
# do not perturb the measured accuracy.
model.eval()

correct, total = 0, 0
test_loader = DataLoader(test_set,batch_size=8,collate_fn=dynamic_padding)
with torch.no_grad():
    for batch in test_loader:
        # NOTE(review): the recorded run of this cell stopped with
        # "AttributeError: 'sentence' object has no attribute 'to'". The traceback
        # is not preserved, so the failing call (collate function or model
        # forward) still has to be located and fixed — TODO.
        outputs = model(batch["sentence_1"],batch["sentence_2"])
        # Predicted class = index of the largest score in each row.
        _, predicted = torch.max(outputs, 1)
        total += batch["label"].size(0)
        correct += (predicted == batch["label"].to(device)).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")

Loading extracted data from /home/varun/Downloads/snli_1.0/sentenceBERT_test.pt


  self.data = torch.load(cache_fpath)["data"]
  ckpt = torch.load(ckpt_path)


Loading pre-trained weights for BERT


AttributeError: 'sentence' object has no attribute 'to'