## BERT finetuning for sentiment classification

In [1]:
# Let's finetune BERT for the task of sentiment classification.
# Here are some popular datasets for sentiment classification and their sizes: https://paperswithcode.com/sota/sentiment-analysis-on-imdb
# Let's load the SST-2 dataset from Hugging Face and fine-tune BERT on it.

In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from datasets import load_dataset
from tqdm import tqdm
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

In [3]:
# Fix the RNG seed so the randomly initialised classifier head and the
# shuffled training batches are the same on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Fetch the SST-2 sentiment dataset (all of its splits) from the Hugging Face Hub.
dataset = load_dataset(path="stanfordnlp/sst2")

In [5]:
# Peek at the first training example to see the available fields.
dataset["train"][0]

{'idx': 0,
 'sentence': 'hide new secretions from the parental units ',
 'label': 0}

In [6]:
# Instantiate the tokenizer that belongs to the "bert-base-uncased" checkpoint.
tokenizer = BertTokenizer.from_pretrained(
    "bert-base-uncased",
)



In [7]:
# Tokenize the dataset
def tokenize_function(examples, max_length=128):
    """Tokenize a batch of examples with the notebook-level ``tokenizer``.

    Args:
        examples: A batch mapping that contains a ``"sentence"`` entry holding
            the raw texts (this is what ``Dataset.map(..., batched=True)``
            passes in).
        max_length: Length every sequence is padded / truncated to. Without an
            explicit value, ``padding="max_length"`` pads to the tokenizer's
            maximum length, which makes every forward pass needlessly
            expensive for short sentences. NOTE(review): 128 is assumed to be
            comfortably longer than any SST-2 sentence - confirm before
            reusing on longer texts.

    Returns:
        Whatever ``tokenizer`` returns for the batch. ``return_tensors`` is
        deliberately not requested: ``Dataset.map`` stores the result in its
        own format, and the conversion to torch tensors is done by the
        ``set_format('torch', ...)`` call that follows the mapping.
    """
    return tokenizer(
        examples["sentence"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )

# Apply the tokenizer to every split, one batch of examples at a time.
tokenized_dataset = dataset.map(tokenize_function, batched=True)

# Expose only the columns used by the loops below, returned as torch tensors.
tokenized_dataset.set_format(
    'torch',
    columns=['input_ids', 'attention_mask', 'label'],
)

Map: 100%|██████████| 67349/67349 [00:31<00:00, 2106.63 examples/s]
Map: 100%|██████████| 872/872 [00:00<00:00, 1682.43 examples/s]
Map: 100%|██████████| 1821/1821 [00:01<00:00, 1636.59 examples/s]


In [8]:
# Sanity check: after set_format, indexing an example should yield torch tensors.
first_train_example = tokenized_dataset["train"][0]
type(first_train_example["input_ids"])

torch.Tensor

In [9]:
# Load pretrained BERT with a 2-label sequence-classification head,
# then move it onto the selected device.
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2,
)
model = model.to(device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Build the data loaders.
# Training uses only the first 1000 training examples, reshuffled each epoch;
# evaluation iterates the validation split in its stored order.
train_subset = tokenized_dataset["train"].select(range(1000))
train_dataloader = DataLoader(train_subset, batch_size=32, shuffle=True)
eval_dataloader = DataLoader(tokenized_dataset["validation"], batch_size=32)

In [12]:
# Evaluate the model as it stands at this point in the notebook
# (this cell sits before the training cell, so it serves as a baseline).
model.eval()

# Collected ground-truth labels and predicted classes
all_labels, all_predictions = [], []

# Only this many batches are scored
num_batches_to_evaluate = 1
batch_counter = 0

for batch in tqdm(eval_dataloader):
    # Stop once the batch budget is used up
    if not batch_counter < num_batches_to_evaluate:
        break

    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["label"].to(device)

    # Inference only: no gradients needed
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # Predicted class = index of the largest logit for each example
    predictions = outputs.logits.argmax(dim=1)

    all_labels.extend(labels.cpu().numpy())
    all_predictions.extend(predictions.cpu().numpy())
    batch_counter += 1


  4%|▎         | 1/28 [00:31<14:07, 31.38s/it]


In [13]:
# Summarise the collected predictions: overall accuracy first,
# then the per-class report and the confusion matrix.
accuracy = accuracy_score(y_true=all_labels, y_pred=all_predictions)
print(f"Accuracy: {accuracy}")

report = classification_report(y_true=all_labels, y_pred=all_predictions)
print(report)
matrix = confusion_matrix(y_true=all_labels, y_pred=all_predictions)
print(matrix)

Accuracy: 0.5625
              precision    recall  f1-score   support

           0       0.57      0.94      0.71        18
           1       0.50      0.07      0.12        14

    accuracy                           0.56        32
   macro avg       0.53      0.51      0.42        32
weighted avg       0.54      0.56      0.45        32

[[17  1]
 [13  1]]


In [13]:
# Prepare the optimizer.
# Use PyTorch's own AdamW: the `AdamW` class exported by `transformers` is
# deprecated in favour of it (the now-unused `AdamW` name in the import cell
# can be dropped once nothing else needs it).
# NOTE(review): eps and weight_decay are pinned to what are believed to be the
# defaults of the deprecated class (1e-6 and 0.0) so training behaves as
# before; torch's own defaults differ - confirm and tune as needed.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-5,
    eps=1e-6,
    weight_decay=0.0,
)



In [14]:
# Fine-tune the model: 3 passes over the training loader.
model.train()

for epoch in range(3):
    for batch in tqdm(train_dataloader):
        # Move the tensors for this step onto the model's device
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Passing `labels` makes the model return the loss alongside the logits
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss

        # Backpropagate and update the weights
        loss.backward()
        optimizer.step()

  0%|                                                                                                                                                      | 0/32 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:27<00:00,  1.19it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:27<00:00,  1.18it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:27<00:00,  1.18it/s]


In [None]:
# Evaluate the model on the validation set.
model.eval()

# Initialize lists to store labels and predictions
all_labels = []
all_predictions = []

# Maximum number of batches to evaluate; None means "no cap", i.e. the whole
# validation loader. The final metrics are meant to cover every validation
# example, so no cap is applied here. Set an integer for a quick smoke test.
num_batches_to_evaluate = None
batch_counter = 0

for batch in tqdm(eval_dataloader):
    # Honour the optional batch cap
    if num_batches_to_evaluate is not None and batch_counter >= num_batches_to_evaluate:
        break
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["label"]
    # Inference only: skip gradient tracking
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    # Predicted class = index of the largest logit for each example
    predictions = torch.argmax(outputs.logits, dim=1)
    all_labels.extend(labels.cpu().numpy())
    all_predictions.extend(predictions.cpu().numpy())
    batch_counter += 1

# Now you can use all_labels and all_predictions for further evaluation

  4%|█████                                                                                                                                         | 1/28 [00:00<00:07,  3.43it/s]

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:08<00:00,  3.50it/s]


In [16]:
# Report accuracy, the per-class precision/recall/F1 table and the
# confusion matrix for the predictions collected by the evaluation loop.
accuracy = accuracy_score(y_true=all_labels, y_pred=all_predictions)
print(f"Accuracy: {accuracy}")

for metric_fn in (classification_report, confusion_matrix):
    print(metric_fn(all_labels, all_predictions))

Accuracy: 0.8784403669724771
              precision    recall  f1-score   support

           0       0.90      0.84      0.87       428
           1       0.86      0.91      0.88       444

    accuracy                           0.88       872
   macro avg       0.88      0.88      0.88       872
weighted avg       0.88      0.88      0.88       872

[[361  67]
 [ 39 405]]
