#### Finetuning BERT for Paraphrase Detection

Similar to how we finetuned a BERT model to perform sentiment analysis on a single sentence, we can also perform paraphrase detection on a pair of sentences. We feed in an input sequence containing a concatenation of both sentences separated by the `[SEP]` token. Then we can perform binary classification using the encoded `[CLS]` token, exactly as we did for sentiment classification.

We will use the Quora dataset for this task, which contains question pairs with binary labels.

In [1]:
import torch
from transformers import BertTokenizer, BertModel
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import csv
from tqdm import tqdm
import psutil
import wandb
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtanzids[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

#### Load the dataset

In [2]:
def load_data_quora(split="train", data_dir="data"):
    """Load one split of the Quora question-pair data from a tab-separated file.

    Sentences and ids are lower-cased and stripped of surrounding whitespace.

    Args:
        split: One of ``"train"``, ``"dev"`` or ``"test"``.
        data_dir: Directory that contains the ``quora-*.csv`` files.

    Returns:
        For ``"test"``: a list of ``(sentence1, sentence2, id)`` tuples.
        For ``"train"`` / ``"dev"``: a ``(data, label_distribution)`` pair, where
        ``data`` is a list of ``((sentence1, sentence2), label, id)`` tuples and
        ``label_distribution`` is a ``Counter`` over the integer labels. Rows with
        an empty sentence or an empty label are skipped.

    Raises:
        ValueError: If ``split`` is not one of the supported names.
    """
    filenames = {
        "train": "quora-train.csv",
        "dev": "quora-dev.csv",
        "test": "quora-test-student.csv",
    }
    if split not in filenames:
        # Fail early with a clear message instead of hitting an unbound `filename` later.
        raise ValueError(
            f"Unknown split {split!r}; expected one of {sorted(filenames)}"
        )
    filename = f"{data_dir}/{filenames[split]}"
    is_test = split == "test"

    data = []
    labels = []
    # Explicit encoding avoids locale-dependent decoding; newline='' is what the
    # csv module expects so that embedded line breaks inside fields are handled.
    with open(filename, 'r', encoding='utf-8', newline='') as f:
        for record in csv.DictReader(f, delimiter='\t'):
            sent1 = record['sentence1'].lower().strip()
            sent2 = record['sentence2'].lower().strip()
            sent_id = record['id'].lower().strip()
            if is_test:
                # The test split carries no labels, so nothing is filtered out.
                data.append((sent1, sent2, sent_id))
                continue
            label = record['is_duplicate'].strip()
            if not (sent1 and sent2 and label):
                continue
            # Labels may be written as floats (e.g. "1.0"), hence the double cast.
            label = int(float(label))
            data.append(((sent1, sent2), label, sent_id))
            labels.append(label)

    if is_test:
        return data
    return data, Counter(labels)

In [3]:
# Read the labelled splits; each call yields the examples plus a Counter of labels.
quora_train, train_label_distribution = load_data_quora(split="train")
quora_dev, dev_label_distribution = load_data_quora(split="dev")

# Summarize dataset sizes and class balance for both splits.
num_train_examples = len(quora_train)
num_dev_examples = len(quora_dev)
print(f"Number of training examples: {num_train_examples}")
print(f"Train Label distribution: {train_label_distribution}")
print(f"Number of dev examples: {num_dev_examples}")
print(f"Dev Label distribution: {dev_label_distribution}")

Number of training examples: 141497
Train Label distribution: Counter({0: 89224, 1: 52273})
Number of dev examples: 20212
Dev Label distribution: Counter({0: 12627, 1: 7585})


In [4]:
class QuoraDataset(Dataset):
    """Dataset wrapper around Quora examples with a BERT-tokenizing collate function.

    Each element of ``data`` is indexed positionally by ``collate_fn``: position 0
    holds the sentence pair and position 1 holds the label.
    """

    def __init__(self, data, max_length=128):
        """Store the examples and load the ``bert-base-uncased`` tokenizer.

        Args:
            data: Sequence of examples (sentence pair at index 0, label at index 1).
            max_length: Length every tokenized sequence is padded/truncated to.
        """
        self.max_length = max_length
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.data = data

    def __len__(self):
        """Return the number of stored examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the raw (untokenized) example at position ``idx``."""
        return self.data[idx]

    def collate_fn(self, batch):
        """Tokenize a batch of examples and pad every sequence to ``max_length``.

        Args:
            batch: List of examples as returned by ``__getitem__``.

        Returns:
            Tuple ``(input_ids, labels, attention_mask)`` of tensors.
        """
        sent_pairs = [example[0] for example in batch]
        label_values = [example[1] for example in batch]
        encoding = self.tokenizer.batch_encode_plus(
            sent_pairs,
            add_special_tokens=True,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        # NOTE(review): encoding['token_type_ids'] is not returned. The inputs here are
        # sentence pairs, so the segment ids would distinguish the two sentences;
        # forwarding them would require changing the tuple consumers unpack - confirm
        # whether dropping them is intentional.
        label_tensor = torch.tensor(label_values)
        return encoding['input_ids'], label_tensor, encoding['attention_mask']

#### Define the model

In [5]:
class BERTParaphraseDetector(torch.nn.Module):
    """BERT encoder followed by dropout and a linear head over the `[CLS]` encoding."""

    def __init__(self, hidden_size=768, num_classes=2, dropout_rate=0.1, finetune=False):
        """Build the classifier on top of pretrained ``bert-base-uncased``.

        Args:
            hidden_size: Input width of the classifier head; it must match the width
                of the encoder's hidden states.
            num_classes: Number of output classes.
            dropout_rate: Dropout probability applied to the `[CLS]` encoding.
            finetune: If true, the BERT weights are trainable; otherwise they are
                frozen and only the classifier head receives gradients.
        """
        super().__init__()
        # load pretrained BERT model
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = torch.nn.Dropout(dropout_rate)
        # define classifier head
        self.classifier_head = torch.nn.Linear(hidden_size, num_classes)

        # train the encoder only when finetuning; otherwise freeze every BERT parameter
        trainable = bool(finetune)
        for param in self.bert.parameters():
            param.requires_grad = trainable

    def forward(self, input_idx, labels=None, attn_mask=None, token_type_ids=None):
        """Encode the inputs and classify from the `[CLS]` position.

        Args:
            input_idx: Token ids, shape (batch_size, sequence_length).
            labels: Optional class indices, shape (batch_size,). When omitted
                (e.g. inference on unlabeled data) no loss is computed.
            attn_mask: Optional attention mask passed to the encoder.
            token_type_ids: Optional segment ids passed to the encoder.

        Returns:
            Tuple ``(logits, loss)``; ``loss`` is the cross-entropy against
            ``labels``, or ``None`` when ``labels`` is ``None``.
        """
        # compute BERT encodings
        bert_output = self.bert(
            input_idx,
            attention_mask=attn_mask,
            token_type_ids=token_type_ids,
        )
        # extract the `[CLS]` encoding (first element of the sequence)
        hidden_states = bert_output.last_hidden_state  # shape: (batch_size, sequence_length, hidden_size)
        cls_encoding = hidden_states[:, 0]  # shape: (batch_size, hidden_size)
        # apply dropout
        cls_encoding = self.dropout(cls_encoding)
        # compute classifier logits
        logits = self.classifier_head(cls_encoding)  # shape: (batch_size, num_classes)
        # compute loss only when targets are available
        loss = F.cross_entropy(logits, labels) if labels is not None else None

        return logits, loss
    

# progress-bar text shared by the per-batch and post-validation updates
def _progress_description(epoch, avg_loss, train_acc, val_loss, val_acc):
    """Format the tqdm description for the given (zero-based) epoch and metrics."""
    return f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}"


# training loop
def train(model, optimizer, train_dataloader, val_dataloader, scheduler=None, device="cpu", num_epochs=10, val_every=1, save_every=None, log_metrics=None):
    """Train ``model`` for ``num_epochs`` epochs with optional validation and checkpoints.

    Args:
        model: Callable as ``model(inputs, targets, attn_mask)`` returning
            ``(logits, loss)``.
        optimizer: Optimizer stepped once per batch.
        train_dataloader: Iterable of ``(inputs, targets, attn_mask)`` batches.
        val_dataloader: Batches passed to ``validation``.
        scheduler: Optional LR scheduler, stepped once per epoch.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``train_dataloader``.
        val_every: Validate after every epoch whose zero-based index is a multiple
            of this value; ``None`` disables validation.
        save_every: Save a checkpoint after every ``save_every``-th epoch
            (one-based); ``None`` disables checkpointing.
        log_metrics: Optional callable receiving a metrics dict after every batch.
    """
    avg_loss = 0
    ema_started = False
    train_acc = 0
    val_loss = 0
    val_acc = 0
    model.train()
    for epoch in range(num_epochs):
        num_correct = 0
        num_total = 0
        pbar = tqdm(train_dataloader, desc="Epochs")
        for batch in pbar:
            inputs, targets, attn_mask = batch
            # move batch to device
            inputs, targets, attn_mask = inputs.to(device), targets.to(device), attn_mask.to(device)
            # forward pass
            logits, loss = model(inputs, targets, attn_mask)
            # reset gradients
            optimizer.zero_grad()
            # backward pass
            loss.backward()
            # optimizer step
            optimizer.step()

            # read the scalar once; every .item() call copies the value back to the host
            batch_loss = loss.item()
            if ema_started:
                avg_loss = 0.9 * avg_loss + 0.1 * batch_loss
            else:
                # seed the moving average with the first loss so that early values
                # are not dragged towards the zero initial value
                avg_loss = batch_loss
                ema_started = True

            y_pred = logits.argmax(dim=-1).view(-1)  # shape (B,)
            num_correct += y_pred.eq(targets.view(-1)).sum().item()
            num_total += inputs.shape[0]
            train_acc = num_correct / num_total

            pbar.set_description(_progress_description(epoch, avg_loss, train_acc, val_loss, val_acc))

            if log_metrics:
                metrics = {"Batch loss": batch_loss, "Moving Avg Loss": avg_loss, "Val Loss": val_loss}
                log_metrics(metrics)

        if scheduler is not None:
            scheduler.step()

        if val_every is not None:
            if epoch % val_every == 0:
                # compute validation loss
                val_loss, val_acc = validation(model, val_dataloader, device=device)
                pbar.set_description(_progress_description(epoch, avg_loss, train_acc, val_loss, val_acc))

        if save_every is not None:
            if (epoch + 1) % save_every == 0:
                save_model_checkpoint(model, optimizer, epoch, avg_loss)

def validation(model, val_dataloader, device="cpu"):
    """Evaluate ``model`` on ``val_dataloader`` without tracking gradients.

    Args:
        model: Module callable as ``model(inputs, targets, attn_mask)`` and
            returning ``(logits, loss)``. The loss is assumed to be averaged over
            the batch, so it is weighted by batch size when aggregating.
        val_dataloader: Iterable of ``(inputs, targets, attn_mask)`` batches.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(val_loss, val_accuracy)``: the per-example mean loss and the
        fraction of correct argmax predictions. Both are ``nan`` when the loader
        yields no examples.
    """
    # remember the caller's mode so an eval-mode model is not flipped to training
    was_training = model.training
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_total = 0
    try:
        with torch.no_grad():
            for inputs, targets, attn_mask in val_dataloader:
                inputs, targets, attn_mask = inputs.to(device), targets.to(device), attn_mask.to(device)
                logits, loss = model(inputs, targets, attn_mask)
                batch_size = inputs.shape[0]
                y_pred = logits.argmax(dim=-1).view(-1)  # shape (B,)
                num_correct += y_pred.eq(targets.view(-1)).sum().item()
                num_total += batch_size
                # weight by batch size so a short final batch is not over-represented
                loss_sum += loss.item() * batch_size
    finally:
        # restore the original mode even if evaluation raised
        model.train(was_training)

    if num_total == 0:
        return float("nan"), float("nan")
    return loss_sum / num_total, num_correct / num_total


def save_model_checkpoint(model, optimizer, epoch=None, loss=None, filename=None):
    """Write the model and optimizer state dicts, plus epoch and loss, to disk.

    Args:
        model: Object exposing ``state_dict()``.
        optimizer: Object exposing ``state_dict()``.
        epoch: Epoch value stored alongside the states.
        loss: Loss value stored alongside the states.
        filename: Destination path; falls back to ``'paraphrase_checkpoint.pth'``
            when not given (or empty).
    """
    # Resolve the destination first; an empty/None filename selects the default path.
    destination = filename if filename else 'paraphrase_checkpoint.pth'

    state = {}
    state['epoch'] = epoch
    state['model_state_dict'] = model.state_dict()
    state['optimizer_state_dict'] = optimizer.state_dict()
    state['loss'] = loss

    torch.save(state, destination)
    print("Saved model checkpoint!")


def load_model_checkpoint(model, optimizer, filename=None, map_location=None):
    """Restore model and optimizer state from a checkpoint file.

    Args:
        model: Module whose ``load_state_dict`` receives ``'model_state_dict'``.
        optimizer: Optimizer whose ``load_state_dict`` receives
            ``'optimizer_state_dict'``.
        filename: Checkpoint path; falls back to ``'paraphrase_checkpoint.pth'``
            when not given (or empty).
        map_location: Forwarded to ``torch.load``; pass e.g. ``"cpu"`` to restore a
            checkpoint that was saved from GPU tensors on a machine without a GPU.

    Returns:
        The same ``(model, optimizer)`` objects, with the model put in train mode.
    """
    path = filename if filename else 'paraphrase_checkpoint.pth'
    # NOTE: torch.load deserializes with pickle; only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # checkpoints are loaded to resume training, so leave the model in train mode
    model.train()
    print("Loaded model from checkpoint!")
    return model, optimizer

#### First, let's train the model without finetuning the BERT base

In [6]:
# Hyperparameters for the frozen-BERT run.
max_length = 128
B = 64
# Use the GPU when one is visible, otherwise fall back to the CPU so the cell still runs.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
learning_rate = 5e-3

# Pinned host memory is only useful when batches are copied to a CUDA device.
use_pinned_memory = DEVICE == "cuda"

train_dataset = QuoraDataset(quora_train, max_length=max_length)
val_dataset = QuoraDataset(quora_dev, max_length=max_length)
train_dataloader = DataLoader(train_dataset, batch_size=B, shuffle=True, pin_memory=use_pinned_memory, num_workers=2, collate_fn=train_dataset.collate_fn)
# Validation does not need shuffling; a fixed order keeps evaluation reproducible.
val_dataloader = DataLoader(val_dataset, batch_size=B, shuffle=False, pin_memory=use_pinned_memory, num_workers=2, collate_fn=val_dataset.collate_fn)

# model with finetuning disabled
model = BERTParaphraseDetector(finetune=False).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.95)
# To resume from a saved checkpoint instead, uncomment:
#model, optimizer = load_model_checkpoint(model, optimizer)

# Counts every parameter, including the frozen BERT weights.
num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")


Total number of parameters in transformer network: 109.483778 M
RAM used: 1133.46 MB


In [7]:
# Draw a single collated batch and print the tensor shapes as a sanity check.
first_batch = next(iter(train_dataloader))
input_idx, labels, attn_mask = first_batch
print(input_idx.shape, labels.shape, attn_mask.shape)

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.


torch.Size([64, 128]) torch.Size([64]) torch.Size([64, 128])


In [8]:
# NOTE: with num_epochs=3, save_every=50 never satisfies (epoch + 1) % save_every == 0,
# so no checkpoint is written during this run. No scheduler is passed either, so the
# learning rate stays at its initial value.
train(
    model,
    optimizer,
    train_dataloader,
    val_dataloader,
    device=DEVICE,
    num_epochs=3,
    save_every=50,
    val_every=1,
)

Epoch 1, EMA Train Loss: 0.565, Train Accuracy:  0.641, Val Loss:  0.000, Val Accuracy:  0.000:   1%|▏         | 28/2211 [00:06<08:19,  4.37it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Epoch 1, EMA Train Loss: 0.590, Train Accuracy:  0.654, Val Loss:  0.000, Val Accuracy:  0.000:   3%|▎         | 62/2211 [00:14<08:12,  4.36it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Epoch 1, EMA Train Loss: 0.533, Train Accuracy:  0.682, Val Loss:  0.000, Val Accuracy:  0.000:   8%|▊         | 182/2211 [00:42<07:48,  4.33it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'lon