#### Finetuning a BERT Model for Sentiment Classification

We will take the original `BERT` model, pretrained on the masked language modeling task, and `finetune` it for `sentiment classification` on the Stanford Sentiment Treebank (SST) and CFIMDB datasets. The `pretrained` BERT model takes an input sequence of integer token ids and outputs a corresponding sequence of contextualized encoded vectors (768-dimensional). A special `[CLS]` token is prepended to the input sequence, and the corresponding encoded output vector of this token serves as a `pooled representation` of the entire sequence. A feedforward network can then use this pooled representation to perform sentence classification. All parameters in the combined model (BERT plus the feedforward classifier) can be trained together to optimize it for the sentence classification task. This process is called `finetuning` because it adapts the pretrained parameters of BERT to the specialized task.

In [2]:
import torch
from transformers import BertTokenizer, BertModel
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import csv
from tqdm import tqdm
import psutil
import wandb
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtanzids[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

#### We will use the WordPiece tokenizer and the pretrained BERT provided by the Hugging Face transformers library. First, let's try out the tokenizer.

In [3]:
# load the pretrained WordPiece tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# use it on a test sentence
sentence = "Yay, I'm excited to try out this BERT model from Huggingface!"
tokens_subwords = tokenizer.tokenize(sentence)
tokens_idx = tokenizer.encode(sentence)
idx_to_tokens = tokenizer.convert_ids_to_tokens(tokens_idx)
decoded_sentence = tokenizer.decode(tokens_idx)
print("Original sentence: ", sentence)
print("Subword tokens: ", tokens_subwords)
print("Encoded sentence: ", tokens_idx)
print("Idx back to tokens: ", idx_to_tokens)
print("Decoded sentence: ", decoded_sentence)

# let's also take a look at all the special tokens
print("\nSpecial tokens with their integer id:")
special_tokens = tokenizer.all_special_tokens
for t in special_tokens:
    print(t," <--> " ,tokenizer.convert_tokens_to_ids(t))

Original sentence:  Yay, I'm excited to try out this BERT model from Huggingface!
Subword tokens:  ['ya', '##y', ',', 'i', "'", 'm', 'excited', 'to', 'try', 'out', 'this', 'bert', 'model', 'from', 'hugging', '##face', '!']
Encoded sentence:  [101, 8038, 2100, 1010, 1045, 1005, 1049, 7568, 2000, 3046, 2041, 2023, 14324, 2944, 2013, 17662, 12172, 999, 102]
Idx back to tokens:  ['[CLS]', 'ya', '##y', ',', 'i', "'", 'm', 'excited', 'to', 'try', 'out', 'this', 'bert', 'model', 'from', 'hugging', '##face', '!', '[SEP]']
Decoded sentence:  [CLS] yay, i'm excited to try out this bert model from huggingface! [SEP]

Special tokens with their integer id:
[UNK]  <-->  100
[SEP]  <-->  102
[PAD]  <-->  0
[CLS]  <-->  101
[MASK]  <-->  103


#### Note that since we're using the `uncased` version of the tokenizer, everything becomes lowercase.
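
The subword split seen above (`yay` into `ya` and `##y`) comes from WordPiece's greedy longest-match-first rule. The sketch below is a simplified illustration of that rule for a single lowercase word, not the library's implementation. The function `wordpiece_tokenize` and the vocabulary `toy_vocab` are hypothetical; the real `bert-base-uncased` vocabulary has about 30k entries.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first split of a single lowercase word."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # try the longest remaining substring first, then shrink it
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a '##' prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk_token]  # no piece fits, so the whole word is unknown
        pieces.append(match)
        start = end
    return pieces


# toy vocabulary (hypothetical, for illustration only)
toy_vocab = {"ya", "##y", "hugging", "##face", "excited"}

print(wordpiece_tokenize("yay", toy_vocab))
print(wordpiece_tokenize("huggingface", toy_vocab))
print(wordpiece_tokenize("zzz", toy_vocab))
```

With this toy vocabulary the first two words split the same way as in the tokenizer output above, while a word with no matching pieces falls back to `[UNK]`.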

Now let's load the SST dataset from file and package it inside a PyTorch dataset class. Each data instance is a sentence-sentiment pair. There are 5 sentiment labels: NEGATIVE (0), SOMEWHAT NEGATIVE (1), NEUTRAL (2), SOMEWHAT POSITIVE (3), POSITIVE (4).

In [4]:
def load_data_sst(split="train"):
    if split == "test":
        filename = "data/ids-sst-test-student.csv"    
        data = []
        with open(filename, 'r') as f:
            for record in csv.DictReader(f, delimiter='\t'):
                sent = record['sentence'].lower().strip()
                sent_id = record['id'].lower().strip()
                data.append((sent,sent_id))
        return data          
    else:
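        if split == "train":
            filename = "data/ids-sst-train.csv"
        elif split == "dev":
            filename = "data/ids-sst-dev.csv"
        else:
            # fail early instead of hitting an undefined `filename` below
            raise ValueError(f"Unknown split: {split!r}")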
        data = []
        labels = []
        with open(filename, 'r') as f:
            for record in csv.DictReader(f, delimiter='\t'):
                sent = record['sentence'].lower().strip()
                sent_id = record['id'].lower().strip()
                label = int(record['sentiment'].strip())
                data.append((sent,label,sent_id))
                labels.append(label)
        label_distribution = Counter(labels)        
        return data, label_distribution

In [5]:
sst_train, train_label_distribution = load_data_sst(split="train")
sst_dev, dev_label_distribution = load_data_sst(split="dev")

print(f"Number of training examples: {len(sst_train)}")
print(f"Train Label distribution: {train_label_distribution}")
print(f"Number of dev examples: {len(sst_dev)}")
print(f"Dev Label distribution: {dev_label_distribution}")


Number of training examples: 8544
Train Label distribution: Counter({3: 2322, 1: 2218, 2: 1624, 4: 1288, 0: 1092})
Number of dev examples: 1101
Dev Label distribution: Counter({1: 289, 3: 279, 2: 229, 4: 165, 0: 139})
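
The classes are moderately imbalanced, so it helps to know what a trivial classifier would score before reading any accuracy numbers. The sketch below reuses the label counts printed above and computes the dev accuracy of always predicting the most frequent training label. The helper `class_fractions` is a hypothetical name introduced for illustration.

```python
from collections import Counter

# label counts copied from the output above
train_counts = Counter({3: 2322, 1: 2218, 2: 1624, 4: 1288, 0: 1092})
dev_counts = Counter({1: 289, 3: 279, 2: 229, 4: 165, 0: 139})


def class_fractions(counts):
    """Fraction of examples per label, ordered by label id."""
    total = sum(counts.values())
    return {label: counts[label] / total for label in sorted(counts)}


# most frequent label in the training split
majority_label, _ = train_counts.most_common(1)[0]
# accuracy on dev if we always predict that label
baseline_dev_acc = dev_counts[majority_label] / sum(dev_counts.values())

print(class_fractions(train_counts))
print(f"Majority class on train: {majority_label}")
print(f"Dev accuracy of always predicting it: {baseline_dev_acc:.3f}")
```

Any trained model should clear this majority-class baseline by a comfortable margin.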


In [25]:
class SSTDataset(Dataset):
    def __init__(self, data, max_length=128):
        self.data = data
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    
    # collate function for padding the sentences to the same length and creating attention masks
    def collate_fn(self, batch):
        sents = [x[0] for x in batch]
        labels = [x[1] for x in batch]
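        # truncation=True keeps sentences longer than max_length from producing ragged batches
        encoded = self.tokenizer.batch_encode_plus(sents, add_special_tokens=True, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt')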
        input_idx = encoded['input_ids']
        attn_mask = encoded['attention_mask']   
        #token_type_idx = encoded['token_type_ids'] # don't need this since we only have one sentence
        labels = torch.tensor(labels)
        return input_idx, labels, attn_mask
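
To make the padding and attention mask concrete, here is a pure-Python sketch of what the collate step produces for a batch of already-encoded sentences. It is a simplified stand-in for the tokenizer's behaviour, not its actual implementation: the function `pad_batch` is hypothetical, and the truncation rule (cut, then close with `[SEP]`) is an assumption made for illustration. The token ids are taken from the tokenizer output shown earlier.

```python
PAD_ID, SEP_ID = 0, 102  # ids taken from the special-token listing above


def pad_batch(batch_ids, max_length):
    """Pad (or cut) each id list to max_length and build a matching attention mask."""
    input_ids, attention_mask = [], []
    for ids in batch_ids:
        if len(ids) > max_length:
            # simplified truncation: keep the first max_length - 1 ids and close with [SEP]
            ids = ids[:max_length - 1] + [SEP_ID]
        n_pad = max_length - len(ids)
        input_ids.append(ids + [PAD_ID] * n_pad)
        # 1 marks a real token, 0 marks padding that attention should ignore
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return input_ids, attention_mask


batch = [
    [101, 7568, 102],                    # [CLS] excited [SEP]
    [101, 3046, 2041, 2023, 2944, 102],  # [CLS] try out this model [SEP]
]
ids, mask = pad_batch(batch, max_length=5)
for row_ids, row_mask in zip(ids, mask):
    print(row_ids, row_mask)
```

Every row ends up with the same length, which is what allows the batch to be stacked into a single tensor.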

#### Now we define our sentiment classifier model.

In [47]:
class BERTSentimentClassifier(torch.nn.Module):
    def __init__(self, hidden_size=768, num_classes=5, dropout_rate=0.1, finetune=False):
        super().__init__()
        # load pretrained BERT model
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = torch.nn.Dropout(dropout_rate)
        # define classifier head
        self.classifier_head = torch.nn.Linear(hidden_size, num_classes)

        # train BERT's parameters only when finetuning; otherwise freeze them
        for param in self.bert.parameters():
            param.requires_grad = finetune

    def forward(self, input_idx, labels, attn_mask):
        # compute BERT encodings
        bert_output = self.bert(input_idx, attention_mask=attn_mask)
        # extract the `[CLS]` encoding (first element of the sequence)
        bert_output = bert_output.last_hidden_state # shape: (batch_size, sequence_length, hidden_size)
        cls_encoding = bert_output[:, 0] # shape: (batch_size, hidden_size)
        # apply dropout 
        cls_encoding = self.dropout(cls_encoding) 
        # compute classifier logits
        logits = self.classifier_head(cls_encoding)  # shape: (batch_size, num_classes)
        # compute loss
        loss = F.cross_entropy(logits, labels)

        return logits, loss
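
The loss returned by the model is the mean cross-entropy between the softmax of the logits and the true labels. The pure-Python sketch below spells out that computation for small lists of logits. The function name `cross_entropy` is chosen here for illustration, and the sketch assumes the default mean reduction used in the `forward` method above.

```python
import math


def cross_entropy(logits, labels):
    """Mean negative log-likelihood of the true labels under softmax(logits)."""
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[label]  # equals -log softmax(row)[label]
    return total / len(labels)


# five equal logits: the model is maximally unsure among the 5 sentiment classes
uniform = cross_entropy([[0.0] * 5], [2])
# a large logit on the correct class gives a loss close to zero
confident = cross_entropy([[0.0, 0.0, 0.0, 0.0, 10.0]], [4])

print(f"uniform: {uniform:.4f}, confident: {confident:.6f}")
```

For five classes the uniform case equals ln 5, roughly 1.609, which is a useful reference point: a training loss near that value means the classifier is still doing no better than guessing uniformly.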

In [44]:
# training loop
def train(model, optimizer, train_dataloader, val_dataloader, scheduler=None, device="cpu", num_epochs=10, val_every=1, save_every=None, log_metrics=None):
    avg_loss = 0
    train_acc = 0
    val_loss = 0
    val_acc = 0
    model.train()
    for epoch in range(num_epochs):
        num_correct = 0
        num_total = 0
        pbar = tqdm(train_dataloader, desc="Epochs")
        for batch in pbar:
            inputs, targets, attn_mask = batch
            # move batch to device
            inputs, targets, attn_mask = inputs.to(device), targets.to(device), attn_mask.to(device)
            # forward pass
            logits, loss = model(inputs, targets, attn_mask)
            # reset gradients
            optimizer.zero_grad()
            # backward pass
            loss.backward()
            # optimizer step
            optimizer.step()
            avg_loss = 0.9* avg_loss + 0.1*loss.item()
            B, _ = inputs.shape
            y_pred = logits.argmax(dim=-1).view(-1) # shape (B,)
            num_correct += y_pred.eq(targets.view(-1)).sum().item()            
            num_total += B
            train_acc = num_correct / num_total        
            
            pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}")  

            if log_metrics:
                metrics = {"Batch loss" : loss.item(), "Moving Avg Loss" : avg_loss, "Val Loss": val_loss}
                log_metrics(metrics)

        if scheduler is not None:
            scheduler.step()
        
        if val_every is not None:
            if epoch%val_every == 0:
                # compute validation loss; this epoch's bar is already closed, so the
                # new values first show up in the next epoch's progress bar description
                val_loss, val_acc = validation(model, val_dataloader, device=device)
                pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}") 

        if save_every is not None:
            if (epoch+1) % save_every == 0:
                save_model_checkpoint(model, optimizer, epoch, avg_loss)

def validation(model, val_dataloader, device="cpu"):
    model.eval()
    val_losses = torch.zeros(len(val_dataloader))
    with torch.no_grad():
        num_correct = 0
        num_total = 0
        for i,batch in enumerate(val_dataloader):
            inputs, targets, attn_mask = batch
            inputs, targets, attn_mask = inputs.to(device), targets.to(device), attn_mask.to(device)
            logits, loss = model(inputs, targets, attn_mask)
            B, _ = inputs.shape
            y_pred = logits.argmax(dim=-1).view(-1) # shape (B,)
            num_correct += y_pred.eq(targets.view(-1)).sum().item()            
            num_total += B
            val_losses[i] = loss.item()
    model.train()
    val_loss = val_losses.mean().item()
    val_accuracy = num_correct / num_total
    return val_loss, val_accuracy


def save_model_checkpoint(model, optimizer, epoch=None, loss=None, filename=None):
    # Save the model and optimizer state_dict
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }

    # Save the checkpoint to a file
    if filename:
        torch.save(checkpoint, filename)
    else:
        torch.save(checkpoint, 'sentiment_classifier_checkpoint.pth')
    print("Saved model checkpoint!")


def load_model_checkpoint(model, optimizer, filename=None):
    if filename:
        checkpoint = torch.load(filename)
    else:
        checkpoint = torch.load('sentiment_classifier_checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.train()
    print("Loaded model from checkpoint!")
    return model, optimizer     
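
The training loop tracks an exponential moving average (EMA) of the batch loss that starts at 0, so the first few reported values are pulled toward zero. The sketch below reproduces that update in plain Python and shows one common remedy, dividing by `1 - beta ** step`. The bias correction is not part of the training loop above, and the function name `ema_losses` is hypothetical.

```python
def ema_losses(losses, beta=0.9):
    """Return (raw, bias_corrected) moving averages after each step."""
    avg, raw, corrected = 0.0, [], []
    for step, loss in enumerate(losses, start=1):
        avg = beta * avg + (1 - beta) * loss        # same form of update as in train()
        raw.append(avg)
        corrected.append(avg / (1 - beta ** step))  # undo the pull toward the zero start
    return raw, corrected


# a constant loss makes the bias easy to see
raw, corrected = ema_losses([1.6] * 20)
print(f"raw after 1 step: {raw[0]:.3f}, after 20 steps: {raw[-1]:.3f}")
print(f"corrected after 1 step: {corrected[0]:.3f}")
```

Since `avg_loss` is not reset between epochs in `train`, the bias only affects the first few dozen batches of the first epoch.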

In [48]:
B = 32
max_length = 128
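# not used in this section: the architecture is fixed by the pretrained
# bert-base-uncased checkpoint (12 layers, 12 attention heads)
num_heads = 8
num_layers = 4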
learning_rate = 5e-4
DEVICE = "cuda"

train_dataset = SSTDataset(sst_train, max_length=max_length)
val_dataset = SSTDataset(sst_dev, max_length=max_length)
train_dataloader = DataLoader(train_dataset, batch_size=B, shuffle=True, pin_memory=True, num_workers=2, collate_fn=train_dataset.collate_fn)
# no need to shuffle the validation set
val_dataloader = DataLoader(val_dataset, batch_size=B, shuffle=False, pin_memory=True, num_workers=2, collate_fn=val_dataset.collate_fn)

# model with finetuning disabled
model = BERTSentimentClassifier(finetune=False).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# note: train() only steps the scheduler if it is passed via `scheduler=`, and then once per epoch
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.95)
#model, optimizer = load_model_checkpoint(model, optimizer)

num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

Total number of parameters in transformer network: 109.486085 M
RAM used: 1584.21 MB
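
The printed total counts every parameter, including the frozen ones. With `finetune=False`, only the linear classifier head is trainable. The arithmetic below is a quick check that uses only the head dimensions from the model definition and the total printed above.

```python
hidden_size, num_classes = 768, 5

# a Linear layer has a (num_classes x hidden_size) weight matrix plus one bias per class
head_params = hidden_size * num_classes + num_classes
total_params = 109_486_085  # from the printed total above
bert_params = total_params - head_params

print(f"Classifier head parameters: {head_params}")
print(f"BERT parameters (frozen here): {bert_params}")
print(f"Trainable fraction: {head_params / total_params:.6%}")
```

In this frozen configuration the optimizer therefore updates only a few thousand weights, which is one reason each epoch runs quickly.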


In [49]:
train(model, optimizer, train_dataloader, val_dataloader, device=DEVICE, num_epochs=10, save_every=50, val_every=1) 

Epoch 1, EMA Train Loss: 1.378, Train Accuracy:  0.380, Val Loss:  0.000, Val Accuracy:  0.000: 100%|██████████| 267/267 [00:29<00:00,  9.15it/s]
Epoch 2, EMA Train Loss: 1.303, Train Accuracy:  0.433, Val Loss:  1.349, Val Accuracy:  0.412: 100%|██████████| 267/267 [00:29<00:00,  9.08it/s]
Epoch 3, EMA Train Loss: 1.240, Train Accuracy:  0.452, Val Loss:  1.296, Val Accuracy:  0.431: 100%|██████████| 267/267 [00:29<00:00,  8.97it/s]
