# BERT - finetuning for classification

## Model Loading

In [1]:
from transformers import BertModel, AutoTokenizer,BertConfig
from torch.optim import Adam 

base_model=BertModel.from_pretrained("google-bert/bert-base-uncased")
tokenizer=AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")

  from .autonotebook import tqdm as notebook_tqdm


## Constants

In [2]:
import torch 


NUM_OF_CLASSES=3
DEVICE=torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_OF_EPOCHS=20


### defining the collate_func

In [3]:
import torch


def collate_func(list_of_data_points):
    
    tokenized_texts=tokenizer([x["text"] for x in list_of_data_points],
                              max_length=128,
                              truncation=True,
                              padding="max_length",
                              return_tensors="pt"
                              )
    return {"input_ids":tokenized_texts["input_ids"],
            "attention_mask":tokenized_texts["attention_mask"],
            'token_type_ids':tokenized_texts["token_type_ids"],
            "label": torch.tensor([x["label"] for x in list_of_data_points],dtype=torch.long)
            }

### dataloading

In [4]:
from datasets import load_dataset
# sample data
ds = load_dataset("Sp1786/multiclass-sentiment-analysis-dataset")


In [5]:
from torch.utils.data import DataLoader

train_dataloader=DataLoader(dataset=ds["train"],batch_size=16,collate_fn=collate_func)
val_dataloader=DataLoader(dataset=ds["validation"],batch_size=16,collate_fn=collate_func)

## Wrapper Model definition

In [6]:
import torch.nn as nn
import torch.nn.functional as F

class WrapperModel(nn.Module):
    def __init__(self,) -> None:
        super().__init__()
        self.base_model=BertModel.from_pretrained("google-bert/bert-base-uncased")
        self.classification_head=nn.Linear(BertConfig().hidden_size,NUM_OF_CLASSES)
        for param in self.base_model.parameters(): # Freezing the base model parameters
            param.requires_grad=False
        
    
    def forward(self,input_ids,attention_masks,token_type_ids):
        x=self.base_model(input_ids,attention_masks,token_type_ids)["pooler_output"]
        x=F.relu(x)
        logits=self.classification_head(x)
        return logits

In [7]:
model=WrapperModel()
loss_func=nn.CrossEntropyLoss()
optimizer=Adam(model.parameters(),lr=0.001)

# training loop

In [8]:
from sklearn.metrics import confusion_matrix,classification_report
from tqdm import tqdm

In [9]:
# training loop
model=model.to(DEVICE)

for epoch_id in range(NUM_OF_EPOCHS):
    tot_loss=0
    tot_count=0
    model.train()
    for data in tqdm(train_dataloader):
        input_ids,attention_masks,token_type_ids,label=data["input_ids"].to(DEVICE),data["attention_mask"].to(DEVICE),data["token_type_ids"].to(DEVICE),data["label"].to(DEVICE)
        logits=model(input_ids,attention_masks,token_type_ids)
        loss=loss_func(logits,label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        tot_loss+=loss.item()
        tot_count+=label.shape[0]
        
    print(f"avg training loss of {tot_loss/tot_count} at epoch_number {epoch_id+1}")

    if epoch_id%5==0:
        model.eval()
        all_pred,all_truths=[],[]
        with torch.no_grad():
            for data in val_dataloader:
                input_ids,attention_masks,token_type_ids,label=data["input_ids"].to(DEVICE),data["attention_mask"].to(DEVICE),data["token_type_ids"].to(DEVICE),data["label"].to(DEVICE)
                logits=model(input_ids,attention_masks,token_type_ids)
                predicted=torch.argmax(logits,dim=-1).view(-1).tolist()
                truth=label.view(-1).tolist()
                all_pred.extend(predicted)
                all_truths.extend(truth)
            conf_array=confusion_matrix(all_truths,all_pred)
            live_metrics=classification_report(all_truths,all_pred)
            print("\n\n On the Validation Dataset ")
            print("confusion_matrix")
            print(conf_array)
        


100%|██████████| 1952/1952 [01:35<00:00, 20.54it/s]


avg training loss of 0.064557820758935 at epoch_number 1


 On the Validation Dataset 
confusion_matrix
[[ 106 1352   59]
 [  31 1779  118]
 [  10 1198  552]]


100%|██████████| 1952/1952 [01:42<00:00, 19.00it/s]


avg training loss of 0.06021692003428814 at epoch_number 2


100%|██████████| 1952/1952 [01:47<00:00, 18.18it/s]


avg training loss of 0.0584824707160597 at epoch_number 3


100%|██████████| 1952/1952 [01:47<00:00, 18.18it/s]


avg training loss of 0.057577101593898214 at epoch_number 4


100%|██████████| 1952/1952 [01:47<00:00, 18.23it/s]


avg training loss of 0.0568378423874984 at epoch_number 5


100%|██████████| 1952/1952 [01:46<00:00, 18.26it/s]


avg training loss of 0.056513312522734165 at epoch_number 6


 On the Validation Dataset 
confusion_matrix
[[ 412  988  117]
 [ 136 1513  279]
 [  48  754  958]]


100%|██████████| 1952/1952 [01:46<00:00, 18.26it/s]


avg training loss of 0.056132759261479384 at epoch_number 7


100%|██████████| 1952/1952 [01:47<00:00, 18.19it/s]


avg training loss of 0.055921754215798175 at epoch_number 8


100%|██████████| 1952/1952 [01:46<00:00, 18.28it/s]


avg training loss of 0.0556175043867504 at epoch_number 9


100%|██████████| 1952/1952 [01:47<00:00, 18.22it/s]


avg training loss of 0.05539915186723267 at epoch_number 10


100%|██████████| 1952/1952 [01:47<00:00, 18.18it/s]


avg training loss of 0.055372260415499086 at epoch_number 11


 On the Validation Dataset 
confusion_matrix
[[ 445  954  118]
 [ 137 1477  314]
 [  51  684 1025]]


 32%|███▏      | 633/1952 [00:34<01:12, 18.22it/s]