<a href="https://colab.research.google.com/github/vishal-burman/PyTorch-Architectures/blob/master/misc/SentenceSimilarity(DistilRoBERTa_DistilRoBERTa).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! nvidia-smi

In [None]:
! pip install transformers
! pip install datasets
! pip install wget

In [1]:
! git clone https://github.com/vishal-burman/PyTorch-Architectures.git
%cd PyTorch-Architectures/

fatal: destination path 'PyTorch-Architectures' already exists and is not an empty directory.
/content/PyTorch-Architectures


In [2]:
from tqdm.auto import tqdm
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from toolkit.utils import get_optimal_batchsize, get_linear_schedule_with_warmup
from toolkit.utils import dict_to_device, EarlyStopping

In [3]:
dataset = load_dataset("quora")

Using custom data configuration default
Reusing dataset quora (/root/.cache/huggingface/datasets/quora/default/0.0.0/36ba4cd42107f051a158016f1bea6ae3f4685c5df843529108a54e42d86c1e04)


In [4]:
train_p = []
train_n = []
test_list = []
count_p, count_n = 0, 0
for idx, sample in enumerate(dataset["train"]):
  text_1, text_2 = sample["questions"]["text"][0], sample["questions"]["text"][1]
  if len(train_p) < 20000 and sample["is_duplicate"]:
    train_p.append((text_1, text_2, 1))
  elif len(train_n) < 20000 and not sample["is_duplicate"]:
    train_n.append((text_1, text_2, 0))
  elif len(test_list) < 10000:
    is_duplicate = 1 if sample["is_duplicate"] else 0
    test_list.append((text_1, text_2, is_duplicate))
train_list = []
train_list.extend(train_p)
train_list.extend(train_n)
random.shuffle(train_list)
print(f"No. of Train Samples: {len(train_list)} || No. of Test Samples: {len(test_list)}")

No. of Train Samples: 40000 || No. of Test Samples: 10000


In [5]:
# path_str = "distilroberta-base"
path_str = "distilbert-base-uncased"

In [6]:
class CustomDataset(Dataset):
  def __init__(self, path_str: str, list_samples: list, max_input_length: int = 16):
    self.tokenizer = AutoTokenizer.from_pretrained(path_str)
    self.list_samples = list_samples
    self.max_input_length = max_input_length
  
  def __len__(self,):
    return len(self.list_samples)
  
  def __getitem__(self, idx):
    sample = self.list_samples[idx]
    return {
        'text_1': sample[0],
        'text_2': sample[1],
        'is_duplicate': sample[2],
    }
  
  def collate_fn(self, batch):
    text_1 = []
    text_2 = []
    labels = []
    for sample in batch:
      text_1.append(sample["text_1"])
      text_2.append(sample["text_2"])
      labels.append(sample["is_duplicate"])
    tokens_1 = self.tokenizer(text_1,
                              max_length=self.max_input_length,
                              padding=True,
                              truncation=True,
                              return_tensors="pt",
                              )
    tokens_2 = self.tokenizer(text_2,
                              max_length=self.max_input_length,
                              padding=True,
                              truncation=True,
                              return_tensors="pt",
                              )
    ids_1, att_1 = tokens_1["input_ids"], tokens_1["attention_mask"]
    ids_2, att_2 = tokens_2["input_ids"], tokens_2["attention_mask"]
    labels = torch.tensor(labels, dtype=torch.long)
    return {
        "input_ids_1": ids_1,
        "attention_mask_1": att_1,
        "input_ids_2": ids_2,
        "attention_mask_2": att_2,
        "labels": labels,
    }

In [7]:
class Attention(nn.Module):
  def __init__(self, in_size: int = 768, hidden_size: int = 512):
    super().__init__()
    self.W = nn.Linear(in_size, hidden_size)
    self.V = nn.Linear(hidden_size, 1)
    self.dropout = nn.Dropout(0.3)
  
  def forward(self, x):
    x = torch.tanh(self.W(x))
    score = self.V(x)
    attention_weights = score.softmax(dim=1)
    context_vector = x * attention_weights
    context_vector = torch.sum(context_vector, dim=1)
    output = self.dropout(context_vector)
    return output

In [8]:
class SentenceSimilarity(nn.Module):
  def __init__(self, path_str: str, in_size: int = 768, hidden_size: int = 768):
    super().__init__()
    self.encoder = AutoModel.from_pretrained(path_str)
    self.attention = Attention(in_size, hidden_size)
    self.ff = nn.Linear(hidden_size * 2, 2)
  
  def forward(self,
              input_ids_1,
              attention_mask_1,
              input_ids_2,
              attention_mask_2,
              labels=None,
              ):
    
    outputs_1 = self.encoder(input_ids=input_ids_1,
                               attention_mask=attention_mask_1)
    outputs_2 = self.encoder(input_ids=input_ids_2,
                               attention_mask=attention_mask_2)
    
    enc_weights_1 = self.attention(outputs_1.last_hidden_state)
    enc_weights_2 = self.attention(outputs_2.last_hidden_state)

    output = torch.cat([enc_weights_1, enc_weights_2], dim=1)

    logits = self.ff(output)
    loss = None
    if labels is not None:
      loss_fct = nn.CrossEntropyLoss()
      loss = loss_fct(logits.view(logits.size(0), -1), labels.view(-1))
    return (loss, logits)

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = SentenceSimilarity(path_str=path_str)
model.to(device)

In [10]:
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable Parameters: {params}")

Trainable Parameters: 66957315


In [None]:
# Create Datasets
dataset_train = CustomDataset(path_str=path_str,
                              list_samples=train_list,
                              max_input_length=32,
                              )
dataset_valid = CustomDataset(path_str=path_str,
                              list_samples=test_list,
                              max_input_length=32,
                              )

In [12]:
# get_optimal_batchsize(dataset_train, model, fp16=False)

In [12]:
# Hyperparameter section
BS = 256
EPOCHS = 100
LR = 3e-5

In [13]:
train_loader = DataLoader(dataset=dataset_train,
                          batch_size=BS,
                          shuffle=True,
                          collate_fn=dataset_train.collate_fn,
                          )
valid_loader = DataLoader(dataset=dataset_valid,
                          batch_size=BS,
                          shuffle=True,
                          collate_fn=dataset_valid.collate_fn,
                          )
print(f"Length of Train Loader: {len(train_loader)} || Length of Valid Loader: {len(valid_loader)}")

Length of Train Loader: 157 || Length of Valid Loader: 40


In [14]:
early_stop = EarlyStopping(metric="val_accuracy", patience=5, verbose=True)
num_training_steps = len(train_loader) * EPOCHS
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=num_training_steps,
                                            )

In [15]:
progress_bar = tqdm(range(num_training_steps))

def get_accuracy(model, data_loader, device):
  if model.training:
    print("Model is in train mode...switching to eval mode!")
    model.eval()
  
  correct, total = 0, 0
  with torch.set_grad_enabled(False):
    for sample in data_loader:
      sample = dict_to_device(sample, device=device)
      labels = sample["labels"].view(-1)
      
      
      outputs = model(**sample)

      _, logits = outputs
      
      probs = torch.softmax(logits, dim=-1)
      _, preds = torch.max(probs, dim=-1)
      correct += (preds == labels).sum()
      total += labels.size(0)
  return (correct / total * 100).item()

for epoch in range(EPOCHS):
  model.train()
  for sample in train_loader:
    outputs = model(**dict_to_device(sample, device=device))

    loss, _ = outputs
    
    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
    progress_bar.update(1)
  model.eval()
  with torch.set_grad_enabled(False):
    valid_acc = get_accuracy(model, valid_loader, device)
    early_stop(valid_acc, model)
    if early_stop.early_stop:
      print("Early Stopping!")
      break

  0%|          | 0/15700 [00:00<?, ?it/s]

Validation accuracy increased from -inf% to 61.01%
Validation accuracy increased from 61.01% to 62.78%
Validation accuracy increased from 62.78% to 67.57%
EarlyStopping counter: 1 out of 5
EarlyStopping counter: 2 out of 5
Validation accuracy increased from 67.57% to 74.11%
EarlyStopping counter: 1 out of 5
EarlyStopping counter: 2 out of 5
EarlyStopping counter: 3 out of 5
EarlyStopping counter: 4 out of 5
EarlyStopping counter: 5 out of 5
Early Stopping!
