# BERT Finetuning Tutorial for Sentiment Analysis

In [1]:
import torch
from torch import nn
from transformers import BertTokenizer, BertModel
from torch.optim import AdamW  
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.manual_seed(42)
    
# Set device (GPU if available, else CPU)
device = (
    "mps" 
    if torch.backends.mps.is_available() 
    else "cuda" 
    if torch.cuda.is_available() 
    else "cpu"
)
device = torch.device(device)
print(f"Using device: {device}")

Using device: cpu


In [4]:
# demo data from https://www.kaggle.com/datasets/ankurzing/sentiment-analysis-for-financial-news?resource=download
data = pd.read_csv('all-data.csv', names=['label', 'text'], encoding='utf-8', encoding_errors='ignore')

In [8]:
# data

In [None]:
# Create a mapping dictionary
label_map = {
    'neutral': 0,
    'positive': 1,
    'negative': 2
}

# Apply the mapping to the 'label' column
data['label'] = data['label'].map(label_map)


In [9]:
# Split data ensuring indices are reset
train_texts, val_texts, train_labels, val_labels = train_test_split(
    data['text'].reset_index(drop=True),
    data['label'].reset_index(drop=True),
    test_size=0.2,
    random_state=42
)

In [19]:
# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained("ProsusAI/finbert")

In [20]:
class NewsDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        # Convert to list to ensure sequential indexing
        self.texts = texts.tolist() if hasattr(texts, 'tolist') else list(texts)
        self.labels = labels.tolist() if hasattr(labels, 'tolist') else list(labels)
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

In [12]:
# Create datasets
train_dataset = NewsDataset(train_texts, train_labels, tokenizer)
val_dataset = NewsDataset(val_texts, val_labels, tokenizer)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

In [13]:
# Create the BERT-based model class
class SentimentClassifier(nn.Module):
    """
    Our sentiment classifier model.
    It uses BERT as the base model and adds a classification head on top.
    """
    def __init__(self, n_classes=3):
        super(SentimentClassifier, self).__init__()
        # Load the pre-trained BERT model
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Add a dropout layer for regularization
        self.dropout = nn.Dropout(p=0.3)
        # Add a linear layer for classification
        # BERT base output dimension is 768
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
        
    def forward(self, input_ids, attention_mask):
        # Get BERT outputs
        # We only need the [CLS] token output for classification
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # Get the [CLS] token output
        pooled_output = outputs.pooler_output
        # Apply dropout
        output = self.dropout(pooled_output)
        # Get logits through the classifier
        return self.classifier(output)


In [14]:
# Initialize model
model = SentimentClassifier()
model = model.to(device)

# Initialize optimizer
optimizer = AdamW([
    {'params': model.bert.parameters(), 'lr': 2e-5},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
])

In [15]:
# Training function
def train_epoch(model, data_loader, optimizer, device):
    """
    Trains the model for one epoch and returns the average loss.
    """
    model.train()
    total_loss = 0
    
    # Use tqdm for a nice progress bar
    for batch in tqdm(data_loader, desc="Training"):
        # Move batch to device (CPU/GPU)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        # Clear gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        
        # Calculate loss
        loss_fn = nn.CrossEntropyLoss()
        loss = loss_fn(outputs, labels)
        
        # Backward pass
        loss.backward()
        
        # Update weights
        optimizer.step()
        
        total_loss += loss.item()
    
    # Return average loss
    return total_loss / len(data_loader)

In [16]:
# Evaluation function
def evaluate(model, data_loader, device):
    """
    Evaluates the model on the provided data loader.
    Returns accuracy and average loss.
    """
    model.eval()
    total_loss = 0
    correct_predictions = 0
    total_predictions = 0
    
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(outputs, labels)
            
            _, predictions = torch.max(outputs, dim=1)
            
            total_loss += loss.item()
            correct_predictions += torch.sum(predictions == labels)
            total_predictions += labels.shape[0]
    
    # Use float32 instead of double/float64
    accuracy = (correct_predictions.float() / total_predictions) * 100  
    average_loss = total_loss / len(data_loader)
    
    return accuracy, average_loss

In [17]:
# Main training loop
def train_model(model, train_loader, val_loader, optimizer, device, epochs=3):
    """
    Main training loop that handles the entire training process.
    """
    best_accuracy = 0
    
    for epoch in range(epochs):
        print(f'\nEpoch {epoch + 1}/{epochs}')
        
        # Train one epoch
        train_loss = train_epoch(model, train_loader, optimizer, device)
        
        # Evaluate
        val_accuracy, val_loss = evaluate(model, val_loader, device)
        
        # Print metrics
        print(f'Training Loss: {train_loss:.4f}')
        print(f'Validation Loss: {val_loss:.4f}')
        print(f'Validation Accuracy: {val_accuracy:.4f}')
        
        # Save best model
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), 'best_model.pt')
            print('Best model saved!')

In [21]:
%%time

# Train the model
train_model(model, train_loader, val_loader, optimizer, device)


Epoch 1/3


Training: 100%|██████████| 122/122 [20:07<00:00,  9.90s/it]
Evaluating: 100%|██████████| 31/31 [01:27<00:00,  2.83s/it]


Training Loss: 0.0914
Validation Loss: 0.4894
Validation Accuracy: 85.5670
Best model saved!

Epoch 2/3


Training: 100%|██████████| 122/122 [49:15<00:00, 24.22s/it]  
Evaluating: 100%|██████████| 31/31 [02:06<00:00,  4.08s/it]


Training Loss: 0.0489
Validation Loss: 0.7387
Validation Accuracy: 84.1237

Epoch 3/3


Training: 100%|██████████| 122/122 [58:32<00:00, 28.79s/it]   
Evaluating: 100%|██████████| 31/31 [01:59<00:00,  3.85s/it]


Training Loss: 0.0480
Validation Loss: 0.5661
Validation Accuracy: 86.7010
Best model saved!
CPU times: total: 7h 5min 10s
Wall time: 2h 13min 29s
