#TODO: Maybe try this - https://youtu.be/0VLAoVGf_74?si=zuJ8AL_wLbsbRdd5
Compare MHA, MQA, MLA (DeepSeek)

## Part 2: Apply TransformerClassifier (Encoder Only) 

In [10]:
from common_utils import *
from datasets import load_dataset
from sklearn.preprocessing import LabelEncoder
from transformer import TransformerClassifierMHA
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import time
from sklearn.metrics import accuracy_score, f1_score
from torch.optim.lr_scheduler import OneCycleLR

### Step 1: Prepare Dataset Loader

In [16]:
# CrowdFlower
# dataset = load_dataset("csv", data_files="./dataset/text_emotion.csv")
# dataset = crowd_dataset.rename_column('content', 'text')
# dataset_dict = create_train_validation_test(crowd_dataset['train'])

# Wassa
dataset = load_dataset("csv", data_files="./dataset/wassa_combined_data.csv")
dataset = dataset.rename_column('tweet', 'text')
dataset_dict = create_train_validation_test(dataset['train'])

embedding_matrix = np.load(EMBEDDING_PATH)

with open(WORD2IDX_PATH, "r", encoding="utf-8") as f:
    word2idx = json.load(f)

Generating train split: 0 examples [00:00, ? examples/s]

Train size: 4970
Validation size: 711
Test size: 1421


In [17]:
label_encoder = LabelEncoder()
train_labels = label_encoder.fit_transform(dataset_dict['train']['sentiment'])
val_labels = label_encoder.transform(dataset_dict['validation']['sentiment'])
test_labels = label_encoder.transform(dataset_dict['test']['sentiment'])
num_classes = len(label_encoder.classes_)
print(f"Number of sentiment classes: {num_classes}")
print(f"Emotion classes: {label_encoder.classes_}")

crowd_labels_dict = {
    'train': train_labels,
    'validation': val_labels,
    'test': test_labels
}

Number of sentiment classes: 4
Emotion classes: ['anger' 'fear' 'joy' 'sadness']


In [None]:
dataloaders_dict = create_dataloaders(dataset_dict, crowd_labels_dict, word2idx)

Created DataLoaders with 156 training batches, 23 validation batches, and 45 test batches.


### Step 2: Train the TransformerClassifier

In [None]:
from common_utils import *
from datasets import load_dataset
from sklearn.preprocessing import LabelEncoder
from transformer import TransformerClassifierMHA
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import time
from sklearn.metrics import accuracy_score, f1_score
from sklearn.utils.class_weight import compute_class_weight
from torch.optim.lr_scheduler import OneCycleLR

TransformerClassifier_MHA = TransformerClassifierMHA(len(word2idx), 13, 100, 10, 6, 200, 100, 0.15)
# vocab_size, num_classes, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

TransformerClassifier_MHA.to(device)

embedding_matrix_tensor = torch.FloatTensor(embedding_matrix)
TransformerClassifier_MHA.embedding.weight.data.copy_(embedding_matrix_tensor)
TransformerClassifier_MHA.embedding.weight.requires_grad = True

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(TransformerClassifier_MHA.parameters(), lr=0.001)

total_steps = 100 * len(dataloaders_dict['train'])
scheduler = OneCycleLR(
    optimizer,
    max_lr=0.001,
    total_steps=total_steps,
    pct_start=0.1,
    anneal_strategy='cos'
)

Using device: cuda


In [None]:
def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=100):
    since = time.time()
    best_model_wts = model.state_dict()
    best_val_f1 = 0.0

    patience = 5
    no_improve_epochs = 0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        for phase in ['train', 'validation']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            all_preds = []
            all_labels = []

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                        scheduler.step()

                running_loss += loss.item() * inputs.size(0)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = accuracy_score(all_labels, all_preds)
            epoch_f1 = f1_score(all_labels, all_preds, average='weighted')

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} F1: {epoch_f1:.4f}')

            if phase == 'validation':
                if epoch_f1 > best_val_f1:
                    best_val_f1 = epoch_f1
                    best_model_wts = model.state_dict().copy()
                    no_improve_epochs = 0

                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'val_f1': epoch_f1,
                    }, 'best_transformer_MHA_model.pt')
                else:
                    no_improve_epochs += 1
                    print(f'No improvement for {no_improve_epochs} epochs')

                if no_improve_epochs >= patience:
                    print(f'Early stopping triggered after {epoch+1} epochs')
                    time_elapsed = time.time() - since
                    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
                    print(f'Best val F1: {best_val_f1:.4f}')
                    model.load_state_dict(best_model_wts)
                    return model

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val F1: {best_val_f1:.4f}')

    model.load_state_dict(best_model_wts)
    return model

trained_model = train_model(TransformerClassifier_MHA, dataloaders_dict, criterion, optimizer, scheduler, num_epochs=100)

def evaluate_model(model, test_dataloader):
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    test_acc = accuracy_score(all_labels, all_preds)
    test_f1 = f1_score(all_labels, all_preds, average='weighted')

    print(f'Test Accuracy: {test_acc:.4f}')
    print(f'Test F1 Score: {test_f1:.4f}')

    return test_acc, test_f1

test_acc, test_f1 = evaluate_model(trained_model, dataloaders_dict['test'])

Epoch 1/100
----------
train Loss: 2.0954 Acc: 0.2908 F1: 0.1704
validation Loss: 1.7011 Acc: 0.3116 F1: 0.1481

Epoch 2/100
----------
train Loss: 1.5498 Acc: 0.2926 F1: 0.2330
validation Loss: 1.4207 Acc: 0.3116 F1: 0.1481
No improvement for 1 epochs

Epoch 3/100
----------
train Loss: 1.4345 Acc: 0.2825 F1: 0.2515
validation Loss: 1.3855 Acc: 0.3116 F1: 0.1481
No improvement for 2 epochs

Epoch 4/100
----------
train Loss: 1.2922 Acc: 0.3943 F1: 0.3678
validation Loss: 0.8700 Acc: 0.6471 F1: 0.6053

Epoch 5/100
----------
train Loss: 0.8483 Acc: 0.6631 F1: 0.6599
validation Loss: 0.6553 Acc: 0.7434 F1: 0.7186

Epoch 6/100
----------
train Loss: 0.6446 Acc: 0.7627 F1: 0.7625
validation Loss: 0.4780 Acc: 0.8260 F1: 0.8286

Epoch 7/100
----------
train Loss: 0.5210 Acc: 0.8231 F1: 0.8231
validation Loss: 0.4742 Acc: 0.8461 F1: 0.8455

Epoch 8/100
----------
train Loss: 0.4224 Acc: 0.8539 F1: 0.8538
validation Loss: 0.3383 Acc: 0.8698 F1: 0.8703

Epoch 9/100
----------
train Loss: 0.373