In [2]:
from datasets import load_dataset
import pandas as pd
import numpy as np
import requests
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForSequenceClassification, get_linear_schedule_with_warmup
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_
from tqdm.notebook import tqdm
from sklearn.metrics import classification_report, confusion_matrix

In [5]:
# Initialise BERT tokenizer
tokeniser = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [29]:
# Method to encode text 
def encode(corpus):
    encoded = tokeniser.batch_encode_plus(corpus, max_length=128,
                                         add_special_tokens=True,
                                         return_attention_mask=True,
                                         truncation=True,
                                         return_tensors='pt',
                                         padding='max_length')

    return encoded['input_ids'], encoded['attention_mask']

# Method to get the data loader for tokens
def get_dataloader(ids, masks, data, isRandom):
    tensored = TensorDataset(ids, masks, data)
    sampler = RandomSampler(tensored) if isRandom else SequentialSampler(tensored)
    dataloader = DataLoader(tensored, sampler=sampler, batch_size=16)
    return dataloader

In [7]:
# Load tweet dataset labelled with emotions
dataset = load_dataset('dair-ai/emotion', name='unsplit', split='train')

Found cached dataset emotion (/home/zum/.cache/huggingface/datasets/dair-ai___emotion/unsplit/1.0.0/cca5efe2dfeb58c1d098e0f9eeb200e9927d889b5a03c67097275dfb5fe463bd)


In [8]:
# Get unique emotion labels from data 
data_labels = dataset.features['label'].names
data_labels

['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']

In [9]:
# Format the dataset
dataset.set_format(type='pandas')
N = len(set(dataset['label']))
all_data = dataset[:]

In [15]:
# Sample dataset to make it smaller
all_data = all_data.groupby('label').apply(lambda x: x.sample(500)).reset_index(drop=True)

In [16]:
# Split the data into training data for development and testing data for evaluation
x_train, x_test, y_train, y_test = train_test_split(all_data['text'], all_data['label'], test_size=0.15, random_state=1)

In [19]:
# Encode training and testing data
train_ids, train_masks = encode(list(x_train))
test_ids, test_masks = encode(list(x_test))

In [20]:
# Create 
y_train = torch.LongTensor(list(y_train))
y_test = torch.LongTensor(list(y_test))

In [21]:
# Get dataloaders for training and testing data
train_loader = get_dataloader(train_ids, train_masks, y_train, True)
test_loader = get_dataloader(test_ids, test_masks, y_test, False)

In [23]:
# Get BERT sequence classifier
classifier = BertForSequenceClassification.from_pretrained('bert-base-uncased',
                                                          num_labels=N,
                                                          output_attentions=False,
                                                          output_hidden_states=False)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [23]:
device='cpu'
classifier = classifier.cpu()

In [31]:
# Add AdamW optimizer
optimiser = AdamW(classifier.parameters(), lr=1e-4)
scheduler = get_linear_schedule_with_warmup(optimiser,
                                           num_training_steps=len(train_loader) * 5,
                                           num_warmup_steps=0)

In [26]:
# Method to save a model
def save_model(model, path):
    torch.save(model, path)

In [None]:
# 
train_losses = []

for epoch in range(5):
    classifier.train()
    train_loss = 0
    
    for step, data in enumerate(tqdm(train_loader, desc='Training')):
        ids, masks, labels = [x.to(device) for x in data]
        output = classifier(input_ids=ids, attention_mask=masks, labels=labels)
        loss = output.loss
        train_loss += loss.item()
        
        classifier.zero_grad()
        loss.backward()
        del loss
        
        clip_grad_norm_(parameters=classifier.parameters(), max_norm=1.0)
        optimiser.step()
        scheduler.step()
    
    if epoch % 5 == 0:
        save_model(classifier, './secondtry' + str(epoch))
    train_losses.append(train_loss / (step + 1))
    print("Loss: {}".format(train_loss / (step + 1)))
      

Training:   0%|          | 0/144 [00:00<?, ?it/s]

In [19]:
# Save the emotion classification model
save_model(classifier, './finalsecondtry')

In [20]:
# Evaluate the classifier
classifier.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [3]:
classifier = torch.load('./finalsmall')

In [4]:
torch.save(classifier.state_dict(), './finalsmalldict')

In [1]:
# Load classifier
classifier = BertForSequenceClassification.from_pretrained('bert-base-uncased',
                                                          num_labels=6,
                                                          output_attentions=False,

                                                          output_hidden_states=False)
classifier.load_state_dict(torch.load('./finalsmalldict'))

NameError: name 'BertForSequenceClassification' is not defined

<All keys matched successfully>

In [30]:
# Testing
test_pred = []
test_loss = 0

with torch.no_grad():
    for data in tqdm(test_loader):
        ids, mask, labels = [x.to(device) for x in data]
        output = classifier(input_ids=ids, attention_mask=mask, labels=labels)

        loss = output.loss
        test_loss += loss.item()
        
        print(output.logits)

        test_pred.append(np.argmax(output.logits.cpu().detach().numpy(),axis=-1))
test_pred = np.concatenate(test_pred)

  0%|          | 0/29 [00:00<?, ?it/s]

tensor([[ 3.7086, -3.0280,  6.1392, -1.2748, -3.1719, -4.2372],
        [-1.9102, -1.8621, -2.0427, -1.6673,  9.7336, -1.8912],
        [-2.4590, -1.9066,  9.0883, -2.4602, -2.7496, -2.4426],
        [-2.0003, -2.1581, -2.2166, 10.0003, -1.3689, -2.2349],
        [-2.4013, -1.9714,  9.0996, -2.3069, -2.8207, -2.5105],
        [10.2609, -1.4980, -1.8343, -1.6523, -2.0606, -2.0741],
        [10.2670, -1.7032, -1.7703, -1.6318, -2.0036, -2.0484],
        [ 9.8163, -1.8577, -2.0557,  0.0927, -2.3360, -2.5791],
        [10.2675, -1.6993, -1.7094, -1.7718, -1.9906, -2.0284],
        [-1.9358, -1.3699, -2.0861, -1.9279,  9.6949, -1.9792],
        [-2.3173, -0.6737,  8.9541, -2.6834, -3.5079, -3.0603],
        [-2.4459, -1.4340,  9.0900, -2.4420, -3.1687, -2.7328],
        [-1.8446, -1.6700, -1.7400, -1.7368, -1.4478, 10.2996],
        [-2.3795, -2.0111,  9.0389, -2.0184, -2.8081, -2.7712],
        [10.2553, -1.6192, -1.7171, -1.5916, -2.1077, -2.1383],
        [10.1410, -1.7035, -2.0490, -1.3

tensor([[-2.4709, -2.1126,  9.0120, -2.1447, -2.3463, -2.8506],
        [-1.6749, -1.5437, -1.8486, -1.9074, -1.4186, 10.3265],
        [-1.0877, -2.0377, -2.5241, 10.0361, -1.6798, -2.4672],
        [-1.7897, -1.4318, -1.7786, -1.9267, -1.5143, 10.3377],
        [-2.5295,  9.5641, -1.5361, -2.1474, -2.2188, -1.9188],
        [-1.9796, -2.1773, -2.2308, 10.0019, -1.3850, -2.1958],
        [-2.4939,  9.5738, -1.5907, -2.3176, -2.0171, -1.9931],
        [-1.7649, -2.2522, -2.0014, -1.5389,  9.7055, -1.8579],
        [-2.0858, -2.1735, -2.1771,  9.9882, -1.3744, -2.1790],
        [-2.3491,  9.5513, -1.4743, -2.0605, -2.1917, -2.3549],
        [-2.0242,  9.5393, -1.7564, -1.9835, -2.4546, -2.1135],
        [-2.3904,  9.5936, -1.5731, -2.3040, -2.1515, -2.0099],
        [-1.8582, -1.6753, -1.6704, -1.8666, -1.3747, 10.2950],
        [ 9.3887, -1.1777, -2.6679,  0.7192, -2.0082, -2.7870],
        [-1.3676, -1.5579, -1.9609, -2.1361, -1.3053, 10.2845],
        [-1.9484, -2.0799, -2.2510, 10.0

tensor([[-1.9772, -1.8244, -1.9992, -1.7009,  9.7260, -1.8928],
        [-2.4025, -0.8398,  8.9772, -2.8868, -3.0256, -3.0958],
        [10.2671, -1.8676, -1.7518, -1.7925, -1.7596, -2.0159],
        [-1.7156, -1.4879, -1.8026, -1.9616, -1.4600, 10.3356],
        [-1.7238, -1.8530, -2.1376, -1.7891,  9.7235, -1.8615],
        [-1.7510, -1.6187, -1.8251, -1.8149, -1.4107, 10.3173],
        [10.2472, -1.8875, -1.7193, -1.7699, -1.9869, -1.8666],
        [-2.2183,  9.3672, -1.4716, -2.1268, -2.5623, -2.0022],
        [-1.7731, -1.6352, -1.7159, -1.5966, -1.7124, 10.2815],
        [ 9.5003, -2.5876, -0.2654, -1.3088, -2.0283, -2.9226],
        [-2.5263,  9.2292, -1.3857, -2.1872, -2.6521, -1.5439],
        [-1.6646, -1.9425, -2.0379, -1.8245,  9.7255, -1.9102],
        [-1.6546, -1.8497, -2.1122, -1.8277,  9.7395, -1.8822],
        [-1.7078, -1.6204, -1.7738, -1.7490, -1.5674, 10.3141],
        [-1.9868, -2.0558, -2.2390,  9.9967, -1.4788, -2.2312],
        [-1.8116, -1.8804, -2.0174, -1.7

tensor([[-1.7742, -1.8755, -2.0388, -1.7960,  9.7325, -1.8851],
        [10.2672, -1.7698, -1.7075, -1.7823, -1.9759, -2.0041],
        [-1.7721, -1.5145, -1.7562, -1.8981, -1.5039, 10.3244],
        [-2.3568,  9.5828, -1.5726, -2.1702, -2.0836, -2.1937],
        [-2.4150, -2.2601,  9.0450, -2.2027, -2.4050, -2.6045],
        [-1.6051, -1.4051, -1.8694, -1.8871, -1.6330, 10.3352],
        [-2.4545, -1.9199,  9.0889, -2.2690, -2.8294, -2.5506],
        [10.2491, -1.6735, -1.7536, -1.7303, -2.1653, -1.8993],
        [-2.1682, -2.1279,  8.9970, -2.5023, -3.0114, -2.0866],
        [-1.8376, -1.7248, -2.0553, -1.8000,  9.7304, -1.9417],
        [-1.7090, -1.3299, -1.8372, -1.9978, -1.5266, 10.3509],
        [-3.4456,  8.1274,  0.3624, -2.4592, -2.7254, -1.3512],
        [-1.6019, -1.5701, -1.7169, -1.8422, -1.6720, 10.2992],
        [-2.4759, -2.2328,  9.0569, -2.0802, -2.7790, -2.3619],
        [-1.8080, -1.9116, -1.9848, -1.7901,  9.7168, -1.9358],
        [-1.9619, -2.1459, -2.2881, 10.0

In [28]:
print(classification_report(test_pred, list(y_test),target_names=data_labels))

              precision    recall  f1-score   support

     sadness       0.95      0.92      0.93        76
         joy       0.81      0.98      0.89        58
        love       0.97      0.86      0.91        70
       anger       0.87      0.96      0.91        68
        fear       0.92      0.87      0.89        91
    surprise       0.94      0.90      0.92        87

    accuracy                           0.91       450
   macro avg       0.91      0.91      0.91       450
weighted avg       0.91      0.91      0.91       450

