In [1]:
import json
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from sklearn.metrics import classification_report
import wandb
from transformers import BertForSequenceClassification
import pickle

In [2]:
# Source: https://github.com/LCS2-IIITD/Emotion-Flip-Reasoning/blob/main/Dataloaders/nlp_utils.py
import string
import nltk
import re

numbers = {
    "0":"zero",
    "1":"one",
    "2":"two",
    "3":"three",
    "4":"four",
    "5":"five",
    "6":"six",
    "7":"seven",
    "8":"eight",
    "9":"nine"
}

def remove_puntuations(txt):
    punct = set(string.punctuation)
    txt = " ".join(txt.split("."))
    txt = " ".join(txt.split("!"))
    txt = " ".join(txt.split("?"))
    txt = " ".join(txt.split(":"))
    txt = " ".join(txt.split(";"))
    
    txt = "".join(ch for ch in txt if ch not in punct)
    return txt

def number_to_words(txt):
    for k in numbers.keys():
        txt = txt.replace(k,numbers[k]+" ")
    return txt

def preprocess_text(text):
    text = text.lower()
    text = re.sub(r'_',' ',text)
    text = number_to_words(text)
    text = remove_puntuations(text)
    text = ''.join([i if ord(i) < 128 else '' for i in text])
    text = ' '.join(text.split())
    return text

In [3]:
train_data = json.load(open('/kaggle/input/semeval3-task-3-dataset/Dataset/ERC_utterance_level/train_utterance_level.json'))
val_data = json.load(open('/kaggle/input/semeval3-task-3-dataset/Dataset/ERC_utterance_level/val_utterance_level.json'))

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
emotion2int = {
    'anger': 0,
    'joy': 1,
    'fear': 2,
    'disgust': 3,
    'neutral': 4,
    'surprise': 5,
    'sadness': 6
}

In [6]:
utterance2vec = pickle.load(open('/kaggle/input/semeval3-task-3-dataset/Dataset/Embeddings/bert_utterance2vec.pkl', 'rb'))

In [7]:
MAX_CONV_LEN = 35
# Defined index 7 for padding
class ERC_BERT_Dataset(Dataset):
    def __init__(self, data, utterance2vec, device):
        self.data = data
        self.utterance2vec = utterance2vec
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[f'id_{idx+1}']['text']
        emotion = self.data[f'id_{idx+1}']['emotion']
        context = self.data[f'id_{idx+1}']['context']
        context.append(text)
        embeddings = [torch.tensor(self.utterance2vec[preprocess_text(utterance)]).to(self.device) for utterance in context]
#         embeddings = [embeddings[i] + embeddings[-1] for i in range(len(embeddings))]        
        if(len(embeddings)<MAX_CONV_LEN):
            num_pads = MAX_CONV_LEN - len(embeddings)
            attention_mask = [1]*len(embeddings) + [0]*num_pads
            embeddings = embeddings + [torch.zeros(768).to(self.device)]*num_pads  
        else:
            embeddings = embeddings[:MAX_CONV_LEN]
            attention_mask = [1]*MAX_CONV_LEN

        embeddings = torch.stack(embeddings)
        attention_mask = torch.tensor(attention_mask)

        return {
            'embeddings': embeddings,
            'attention_mask': attention_mask,   
            'emotion': emotion2int[emotion]
        }
        

In [8]:
train_dataset = ERC_BERT_Dataset(train_data, utterance2vec, device)
val_dataset = ERC_BERT_Dataset(val_data, utterance2vec, device)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [9]:
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=7).to(device)

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
epochs = 15
optimizer = AdamW(model.parameters(), lr=1e-6)

In [11]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb_login_key")
wandb.login(key=secret_value_0) 

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [12]:
wandb.init(project='TECPEC', name='ERC_BERT_Utterance_Level', config={
    'Embedding': 'BERT',
    'Level': 'Utterance Level',
    'Approach': 'Not adding each utterance embedding with the target utterance embedding',
    'Epochs': epochs,
    'Optimizer': 'AdamW',
    'Learning Rate': 1e-6,
    'Batch Size': 32
})

[34m[1mwandb[0m: Currently logged in as: [33mshreyas21563[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [13]:
for epoch in range(epochs):
    model.train()
    train_pred, train_true, train_loss = [], [], 0.0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        embeddings = batch['embeddings'].to(device)
        emotions = batch['emotion'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        outputs = model(inputs_embeds=embeddings, attention_mask=attention_mask, labels=emotions)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        train_pred.extend(torch.argmax(outputs.logits, 1).tolist())
        train_true.extend(emotions.tolist())
        train_loss += loss.item()
    train_loss /= len(train_loader) 
    model.eval()
    val_pred, val_true, val_loss = [], [], 0.0
    with torch.no_grad():
        for batch in tqdm(val_loader):
            embeddings = batch['embeddings'].to(device)
            emotions = batch['emotion'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            outputs = model(inputs_embeds=embeddings, attention_mask=attention_mask, labels=emotions)
            val_pred.extend(torch.argmax(outputs.logits, 1).tolist())
            val_true.extend(emotions.tolist())
            val_loss += outputs.loss.item()
    val_loss /= len(val_loader)
    train_report = classification_report(train_true, train_pred, target_names=emotion2int.keys(), zero_division=0)
    val_report = classification_report(val_true, val_pred, target_names=emotion2int.keys(), zero_division=0)

    train_report_dict = classification_report(train_true, train_pred, target_names=emotion2int.keys(), output_dict=True, zero_division=0)
    val_report_dict = classification_report(val_true, val_pred, target_names=emotion2int.keys(), output_dict=True, zero_division=0)
    wandb.log({
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_accuracy': train_report_dict['accuracy'],
        'val_accuracy': val_report_dict['accuracy'],
        'Macro train_f1': train_report_dict['macro avg']['f1-score'],
        'Macro val_f1': val_report_dict['macro avg']['f1-score'],
        'Weighted train_f1': train_report_dict['weighted avg']['f1-score'],
        'Weighted val_f1': val_report_dict['weighted avg']['f1-score'],
    })
    print(f"Epoch: {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}")
    print(f"Train Report: \n{train_report}")
    print(f"Val Report: \n{val_report}")


100%|██████████| 380/380 [01:19<00:00,  4.79it/s]
100%|██████████| 47/47 [00:03<00:00, 13.28it/s]


Epoch: 1, Train Loss: 1.6744096589715858, Val Loss: 1.6046759853971766
Train Report: 
              precision    recall  f1-score   support

       anger       0.13      0.02      0.03      1423
         joy       0.18      0.04      0.06      2047
        fear       0.03      0.02      0.03       336
     disgust       0.00      0.00      0.00       372
     neutral       0.44      0.91      0.59      5299
    surprise       0.10      0.01      0.02      1656
     sadness       0.20      0.00      0.01      1011

    accuracy                           0.41     12144
   macro avg       0.15      0.14      0.11     12144
weighted avg       0.27      0.41      0.28     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00       192
         joy       0.00      0.00      0.00       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.43      1.00

100%|██████████| 380/380 [01:24<00:00,  4.52it/s]
100%|██████████| 47/47 [00:03<00:00, 13.04it/s]


Epoch: 2, Train Loss: 1.5998838547028993, Val Loss: 1.580490510514442
Train Report: 
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00      1423
         joy       0.15      0.00      0.00      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.44      1.00      0.61      5299
    surprise       0.00      0.00      0.00      1656
     sadness       0.00      0.00      0.00      1011

    accuracy                           0.44     12144
   macro avg       0.08      0.14      0.09     12144
weighted avg       0.22      0.44      0.27     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00       192
         joy       0.00      0.00      0.00       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.43      1.00 

100%|██████████| 380/380 [01:24<00:00,  4.50it/s]
100%|██████████| 47/47 [00:03<00:00, 12.82it/s]


Epoch: 3, Train Loss: 1.5743206751974006, Val Loss: 1.5516766385829195
Train Report: 
              precision    recall  f1-score   support

       anger       0.33      0.00      0.01      1423
         joy       0.50      0.00      0.00      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.44      1.00      0.61      5299
    surprise       0.00      0.00      0.00      1656
     sadness       0.00      0.00      0.00      1011

    accuracy                           0.44     12144
   macro avg       0.18      0.14      0.09     12144
weighted avg       0.31      0.44      0.27     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.24      0.02      0.04       192
         joy       0.00      0.00      0.00       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.43      0.99

100%|██████████| 380/380 [01:25<00:00,  4.45it/s]
100%|██████████| 47/47 [00:03<00:00, 12.49it/s]


Epoch: 4, Train Loss: 1.5395681641603771, Val Loss: 1.517691490497995
Train Report: 
              precision    recall  f1-score   support

       anger       0.28      0.05      0.09      1423
         joy       0.43      0.00      0.00      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.44      0.99      0.61      5299
    surprise       0.50      0.00      0.00      1656
     sadness       0.00      0.00      0.00      1011

    accuracy                           0.44     12144
   macro avg       0.24      0.15      0.10     12144
weighted avg       0.37      0.44      0.28     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.26      0.13      0.17       192
         joy       0.50      0.01      0.02       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.44      0.96 

100%|██████████| 380/380 [01:25<00:00,  4.42it/s]
100%|██████████| 47/47 [00:03<00:00, 12.27it/s]


Epoch: 5, Train Loss: 1.494153747746819, Val Loss: 1.4835275589151586
Train Report: 
              precision    recall  f1-score   support

       anger       0.32      0.13      0.19      1423
         joy       0.55      0.03      0.06      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.45      0.97      0.62      5299
    surprise       0.47      0.03      0.06      1656
     sadness       0.00      0.00      0.00      1011

    accuracy                           0.45     12144
   macro avg       0.26      0.17      0.13     12144
weighted avg       0.39      0.45      0.31     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.27      0.22      0.24       192
         joy       0.46      0.13      0.20       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.47      0.85 

100%|██████████| 380/380 [01:26<00:00,  4.38it/s]
100%|██████████| 47/47 [00:03<00:00, 12.33it/s]


Epoch: 6, Train Loss: 1.447823787362952, Val Loss: 1.4373841704206263
Train Report: 
              precision    recall  f1-score   support

       anger       0.31      0.19      0.24      1423
         joy       0.55      0.12      0.20      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.48      0.92      0.63      5299
    surprise       0.51      0.20      0.29      1656
     sadness       0.00      0.00      0.00      1011

    accuracy                           0.47     12144
   macro avg       0.26      0.21      0.19     12144
weighted avg       0.41      0.47      0.38     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.28      0.20      0.23       192
         joy       0.53      0.16      0.24       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.50      0.83 

100%|██████████| 380/380 [01:27<00:00,  4.36it/s]
100%|██████████| 47/47 [00:03<00:00, 12.18it/s]


Epoch: 7, Train Loss: 1.4105405144001308, Val Loss: 1.412213089618277
Train Report: 
              precision    recall  f1-score   support

       anger       0.33      0.23      0.27      1423
         joy       0.50      0.18      0.27      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.51      0.88      0.64      5299
    surprise       0.48      0.34      0.40      1656
     sadness       0.20      0.00      0.00      1011

    accuracy                           0.49     12144
   macro avg       0.29      0.23      0.23     12144
weighted avg       0.43      0.49      0.41     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.27      0.31      0.29       192
         joy       0.51      0.25      0.33       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.52      0.76 

100%|██████████| 380/380 [01:28<00:00,  4.31it/s]
100%|██████████| 47/47 [00:03<00:00, 11.98it/s]


Epoch: 8, Train Loss: 1.380104339436481, Val Loss: 1.3801858843641077
Train Report: 
              precision    recall  f1-score   support

       anger       0.33      0.26      0.29      1423
         joy       0.50      0.25      0.33      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.53      0.86      0.65      5299
    surprise       0.50      0.40      0.44      1656
     sadness       0.47      0.04      0.07      1011

    accuracy                           0.50     12144
   macro avg       0.33      0.26      0.26     12144
weighted avg       0.46      0.50      0.44     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.29      0.32      0.31       192
         joy       0.61      0.22      0.33       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.51      0.83 

100%|██████████| 380/380 [01:28<00:00,  4.28it/s]
100%|██████████| 47/47 [00:03<00:00, 11.77it/s]


Epoch: 9, Train Loss: 1.3515726153787813, Val Loss: 1.3601111249720796
Train Report: 
              precision    recall  f1-score   support

       anger       0.32      0.26      0.29      1423
         joy       0.54      0.30      0.38      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.54      0.84      0.66      5299
    surprise       0.51      0.43      0.47      1656
     sadness       0.43      0.07      0.12      1011

    accuracy                           0.51     12144
   macro avg       0.33      0.27      0.27     12144
weighted avg       0.47      0.51      0.46     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.33      0.23      0.28       192
         joy       0.53      0.42      0.47       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.55      0.75

100%|██████████| 380/380 [01:29<00:00,  4.24it/s]
100%|██████████| 47/47 [00:04<00:00, 11.74it/s]


Epoch: 10, Train Loss: 1.3261461858686647, Val Loss: 1.3482188633147707
Train Report: 
              precision    recall  f1-score   support

       anger       0.35      0.26      0.30      1423
         joy       0.53      0.31      0.39      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.55      0.83      0.66      5299
    surprise       0.52      0.45      0.48      1656
     sadness       0.41      0.14      0.21      1011

    accuracy                           0.52     12144
   macro avg       0.33      0.29      0.29     12144
weighted avg       0.47      0.52      0.47     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.33      0.23      0.27       192
         joy       0.57      0.39      0.46       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.55      0.7

100%|██████████| 380/380 [01:30<00:00,  4.18it/s]
100%|██████████| 47/47 [00:04<00:00, 11.55it/s]


Epoch: 11, Train Loss: 1.3066109549058111, Val Loss: 1.3439910208925288
Train Report: 
              precision    recall  f1-score   support

       anger       0.37      0.27      0.31      1423
         joy       0.54      0.35      0.43      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.56      0.83      0.67      5299
    surprise       0.53      0.47      0.50      1656
     sadness       0.44      0.20      0.28      1011

    accuracy                           0.54     12144
   macro avg       0.35      0.30      0.31     12144
weighted avg       0.49      0.54      0.49     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.36      0.24      0.29       192
         joy       0.49      0.45      0.47       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.58      0.7

100%|██████████| 380/380 [01:32<00:00,  4.12it/s]
100%|██████████| 47/47 [00:04<00:00, 11.37it/s]


Epoch: 12, Train Loss: 1.2821636154463416, Val Loss: 1.3427521244008491
Train Report: 
              precision    recall  f1-score   support

       anger       0.36      0.29      0.32      1423
         joy       0.54      0.37      0.44      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.57      0.81      0.67      5299
    surprise       0.54      0.50      0.52      1656
     sadness       0.43      0.22      0.29      1011

    accuracy                           0.54     12144
   macro avg       0.35      0.31      0.32     12144
weighted avg       0.49      0.54      0.50     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.31      0.29      0.30       192
         joy       0.60      0.35      0.45       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.56      0.7

100%|██████████| 380/380 [01:33<00:00,  4.09it/s]
100%|██████████| 47/47 [00:04<00:00, 11.22it/s]


Epoch: 13, Train Loss: 1.2677805682546215, Val Loss: 1.3195817889051233
Train Report: 
              precision    recall  f1-score   support

       anger       0.37      0.30      0.33      1423
         joy       0.55      0.38      0.45      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.58      0.81      0.68      5299
    surprise       0.55      0.51      0.53      1656
     sadness       0.41      0.26      0.32      1011

    accuracy                           0.55     12144
   macro avg       0.35      0.32      0.33     12144
weighted avg       0.50      0.55      0.51     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.36      0.21      0.26       192
         joy       0.60      0.43      0.50       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.55      0.7

100%|██████████| 380/380 [01:33<00:00,  4.05it/s]
100%|██████████| 47/47 [00:04<00:00, 11.07it/s]


Epoch: 14, Train Loss: 1.2480714992473, Val Loss: 1.3255903987174338
Train Report: 
              precision    recall  f1-score   support

       anger       0.39      0.33      0.36      1423
         joy       0.55      0.39      0.46      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.58      0.81      0.68      5299
    surprise       0.56      0.52      0.54      1656
     sadness       0.45      0.27      0.34      1011

    accuracy                           0.55     12144
   macro avg       0.36      0.33      0.34     12144
weighted avg       0.51      0.55      0.52     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.40      0.18      0.25       192
         joy       0.58      0.41      0.48       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.56      0.79  

100%|██████████| 380/380 [01:34<00:00,  4.02it/s]
100%|██████████| 47/47 [00:04<00:00, 10.93it/s]

Epoch: 15, Train Loss: 1.2341732671386316, Val Loss: 1.3433352264952152
Train Report: 
              precision    recall  f1-score   support

       anger       0.39      0.31      0.35      1423
         joy       0.56      0.41      0.47      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.60      0.81      0.69      5299
    surprise       0.57      0.54      0.56      1656
     sadness       0.44      0.32      0.37      1011

    accuracy                           0.56     12144
   macro avg       0.37      0.34      0.35     12144
weighted avg       0.52      0.56      0.53     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.36      0.28      0.31       192
         joy       0.57      0.41      0.47       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.59      0.6




In [14]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Macro train_f1,▁▁▁▁▂▄▅▆▆▇▇▇███
Macro val_f1,▁▁▁▂▄▅▆▆▇▇█████
Weighted train_f1,▁▁▁▁▂▄▅▆▆▇▇▇▇██
Weighted val_f1,▁▁▁▂▄▅▆▆▇██████
train_accuracy,▁▂▂▂▃▄▅▅▆▆▇▇▇██
train_loss,█▇▆▆▅▄▄▃▃▂▂▂▂▁▁
val_accuracy,▁▁▁▁▃▄▄▆▇▇▇▇██▇
val_loss,█▇▇▆▅▄▃▂▂▂▂▂▁▁▂

0,1
Macro train_f1,0.34776
Macro val_f1,0.32691
Weighted train_f1,0.52779
Weighted val_f1,0.48965
train_accuracy,0.56094
train_loss,1.23417
val_accuracy,0.51119
val_loss,1.34334


In [15]:
torch.save(model, '/kaggle/working/not_add_target.pth')