In [1]:
import json
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from sklearn.metrics import classification_report
import wandb
from transformers import BertForSequenceClassification
import pickle

In [2]:
# Source: https://github.com/LCS2-IIITD/Emotion-Flip-Reasoning/blob/main/Dataloaders/nlp_utils.py
import string
import nltk
import re

numbers = {
    "0":"zero",
    "1":"one",
    "2":"two",
    "3":"three",
    "4":"four",
    "5":"five",
    "6":"six",
    "7":"seven",
    "8":"eight",
    "9":"nine"
}

def remove_puntuations(txt):
    punct = set(string.punctuation)
    txt = " ".join(txt.split("."))
    txt = " ".join(txt.split("!"))
    txt = " ".join(txt.split("?"))
    txt = " ".join(txt.split(":"))
    txt = " ".join(txt.split(";"))
    
    txt = "".join(ch for ch in txt if ch not in punct)
    return txt

def number_to_words(txt):
    for k in numbers.keys():
        txt = txt.replace(k,numbers[k]+" ")
    return txt

def preprocess_text(text):
    text = text.lower()
    text = re.sub(r'_',' ',text)
    text = number_to_words(text)
    text = remove_puntuations(text)
    text = ''.join([i if ord(i) < 128 else '' for i in text])
    text = ' '.join(text.split())
    return text

In [3]:
train_data = json.load(open('/kaggle/input/semeval3-task-3-dataset/Dataset/ERC_utterance_level/train_utterance_level.json'))
val_data = json.load(open('/kaggle/input/semeval3-task-3-dataset/Dataset/ERC_utterance_level/val_utterance_level.json'))

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
emotion2int = {
    'anger': 0,
    'joy': 1,
    'fear': 2,
    'disgust': 3,
    'neutral': 4,
    'surprise': 5,
    'sadness': 6
}

In [6]:
utterance2vec = pickle.load(open('/kaggle/input/semeval3-task-3-dataset/Dataset/Embeddings/bert_utterance2vec.pkl', 'rb'))

In [7]:
MAX_CONV_LEN = 35
# Defined index 7 for padding
class ERC_BERT_Dataset(Dataset):
    def __init__(self, data, utterance2vec, device):
        self.data = data
        self.utterance2vec = utterance2vec
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[f'id_{idx+1}']['text']
        emotion = self.data[f'id_{idx+1}']['emotion']
        context = self.data[f'id_{idx+1}']['context']
        context.append(text)
        embeddings = [torch.tensor(self.utterance2vec[preprocess_text(utterance)]).to(self.device) for utterance in context]
        embeddings = [embeddings[i] + embeddings[-1] for i in range(len(embeddings))]        
        if(len(embeddings)<MAX_CONV_LEN):
            num_pads = MAX_CONV_LEN - len(embeddings)
            attention_mask = [1]*len(embeddings) + [0]*num_pads
            embeddings = embeddings + [torch.zeros(768).to(self.device)]*num_pads  
        else:
            embeddings = embeddings[:MAX_CONV_LEN]
            attention_mask = [1]*MAX_CONV_LEN

        embeddings = torch.stack(embeddings)
        attention_mask = torch.tensor(attention_mask)

        return {
            'embeddings': embeddings,
            'attention_mask': attention_mask,   
            'emotion': emotion2int[emotion]
        }
        

In [8]:
train_dataset = ERC_BERT_Dataset(train_data, utterance2vec, device)
val_dataset = ERC_BERT_Dataset(val_data, utterance2vec, device)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [9]:
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=7).to(device)

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
epochs = 15
optimizer = AdamW(model.parameters(), lr=1e-6)

In [11]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb_login_key")
wandb.login(key=secret_value_0) 

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [12]:
wandb.init(project='TECPEC', name='ERC_BERT_Utterance_Level', config={
    'Embedding': 'BERT',
    'Level': 'Utterance Level',
    'Approach': 'Added each utterance embedding with the target utterance embedding',
    'Epochs': epochs,
    'Optimizer': 'AdamW',
    'Learning Rate': 1e-6,
    'Batch Size': 32
})

[34m[1mwandb[0m: Currently logged in as: [33mshreyas21563[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112879866666642, max=1.0…

In [13]:
for epoch in range(epochs):
    model.train()
    train_pred, train_true, train_loss = [], [], 0.0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        embeddings = batch['embeddings'].to(device)
        emotions = batch['emotion'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        outputs = model(inputs_embeds=embeddings, attention_mask=attention_mask, labels=emotions)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        train_pred.extend(torch.argmax(outputs.logits, 1).tolist())
        train_true.extend(emotions.tolist())
        train_loss += loss.item()
    train_loss /= len(train_loader) 
    model.eval()
    val_pred, val_true, val_loss = [], [], 0.0
    with torch.no_grad():
        for batch in tqdm(val_loader):
            embeddings = batch['embeddings'].to(device)
            emotions = batch['emotion'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            outputs = model(inputs_embeds=embeddings, attention_mask=attention_mask, labels=emotions)
            val_pred.extend(torch.argmax(outputs.logits, 1).tolist())
            val_true.extend(emotions.tolist())
            val_loss += outputs.loss.item()
    val_loss /= len(val_loader)
    train_report = classification_report(train_true, train_pred, target_names=emotion2int.keys(), zero_division=0)
    val_report = classification_report(val_true, val_pred, target_names=emotion2int.keys(), zero_division=0)

    train_report_dict = classification_report(train_true, train_pred, target_names=emotion2int.keys(), output_dict=True, zero_division=0)
    val_report_dict = classification_report(val_true, val_pred, target_names=emotion2int.keys(), output_dict=True, zero_division=0)
    wandb.log({
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_accuracy': train_report_dict['accuracy'],
        'val_accuracy': val_report_dict['accuracy'],
        'Macro train_f1': train_report_dict['macro avg']['f1-score'],
        'Macro val_f1': val_report_dict['macro avg']['f1-score'],
        'Weighted train_f1': train_report_dict['weighted avg']['f1-score'],
        'Weighted val_f1': val_report_dict['weighted avg']['f1-score'],
    })
    print(f"Epoch: {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}")
    print(f"Train Report: \n{train_report}")
    print(f"Val Report: \n{val_report}")


100%|██████████| 380/380 [01:18<00:00,  4.84it/s]
100%|██████████| 47/47 [00:03<00:00, 13.42it/s]


Epoch: 1, Train Loss: 1.693370968730826, Val Loss: 1.5863463396721698
Train Report: 
              precision    recall  f1-score   support

       anger       0.15      0.02      0.04      1423
         joy       0.17      0.07      0.10      2047
        fear       0.00      0.00      0.00       336
     disgust       0.01      0.00      0.00       372
     neutral       0.44      0.90      0.59      5299
    surprise       0.11      0.00      0.00      1656
     sadness       0.07      0.00      0.00      1011

    accuracy                           0.41     12144
   macro avg       0.13      0.14      0.11     12144
weighted avg       0.26      0.41      0.28     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00       192
         joy       0.00      0.00      0.00       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.43      1.00 

100%|██████████| 380/380 [01:22<00:00,  4.63it/s]
100%|██████████| 47/47 [00:03<00:00, 13.17it/s]


Epoch: 2, Train Loss: 1.5783514656518636, Val Loss: 1.540742186789817
Train Report: 
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00      1423
         joy       0.00      0.00      0.00      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.44      1.00      0.61      5299
    surprise       0.58      0.03      0.06      1656
     sadness       0.00      0.00      0.00      1011

    accuracy                           0.44     12144
   macro avg       0.15      0.15      0.10     12144
weighted avg       0.27      0.44      0.27     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00       192
         joy       0.00      0.00      0.00       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.45      0.98 

100%|██████████| 380/380 [01:22<00:00,  4.62it/s]
100%|██████████| 47/47 [00:03<00:00, 13.14it/s]


Epoch: 3, Train Loss: 1.527363647912678, Val Loss: 1.4843349101695609
Train Report: 
              precision    recall  f1-score   support

       anger       0.17      0.00      0.00      1423
         joy       0.22      0.00      0.01      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.46      0.97      0.62      5299
    surprise       0.48      0.23      0.31      1656
     sadness       0.00      0.00      0.00      1011

    accuracy                           0.46     12144
   macro avg       0.19      0.17      0.14     12144
weighted avg       0.32      0.46      0.32     12144

Val Report: 
              precision    recall  f1-score   support

       anger       1.00      0.01      0.01       192
         joy       0.00      0.00      0.00       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.46      0.97 

100%|██████████| 380/380 [01:22<00:00,  4.61it/s]
100%|██████████| 47/47 [00:03<00:00, 12.77it/s]


Epoch: 4, Train Loss: 1.4807159646561272, Val Loss: 1.444395438153693
Train Report: 
              precision    recall  f1-score   support

       anger       0.30      0.01      0.03      1423
         joy       0.38      0.02      0.04      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.47      0.95      0.63      5299
    surprise       0.47      0.38      0.42      1656
     sadness       0.50      0.00      0.00      1011

    accuracy                           0.47     12144
   macro avg       0.30      0.19      0.16     12144
weighted avg       0.41      0.47      0.34     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.39      0.08      0.14       192
         joy       0.80      0.02      0.03       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.48      0.95 

100%|██████████| 380/380 [01:23<00:00,  4.57it/s]
100%|██████████| 47/47 [00:03<00:00, 12.66it/s]


Epoch: 5, Train Loss: 1.4339297169133236, Val Loss: 1.4053166825720604
Train Report: 
              precision    recall  f1-score   support

       anger       0.32      0.08      0.13      1423
         joy       0.53      0.06      0.11      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.49      0.93      0.64      5299
    surprise       0.47      0.43      0.45      1656
     sadness       0.41      0.02      0.05      1011

    accuracy                           0.48     12144
   macro avg       0.32      0.22      0.20     12144
weighted avg       0.44      0.48      0.38     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.40      0.10      0.17       192
         joy       0.59      0.05      0.09       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.49      0.95

100%|██████████| 380/380 [01:24<00:00,  4.51it/s]
100%|██████████| 47/47 [00:03<00:00, 12.51it/s]


Epoch: 6, Train Loss: 1.392144123347182, Val Loss: 1.3846622122094987
Train Report: 
              precision    recall  f1-score   support

       anger       0.33      0.15      0.20      1423
         joy       0.54      0.13      0.22      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.52      0.90      0.66      5299
    surprise       0.51      0.47      0.49      1656
     sadness       0.39      0.09      0.15      1011

    accuracy                           0.50     12144
   macro avg       0.33      0.25      0.24     12144
weighted avg       0.46      0.50      0.43     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.36      0.09      0.14       192
         joy       0.61      0.19      0.29       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.54      0.84 

100%|██████████| 380/380 [01:24<00:00,  4.47it/s]
100%|██████████| 47/47 [00:03<00:00, 12.19it/s]


Epoch: 7, Train Loss: 1.3609769010230115, Val Loss: 1.3668753682298864
Train Report: 
              precision    recall  f1-score   support

       anger       0.35      0.19      0.24      1423
         joy       0.52      0.23      0.32      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.54      0.87      0.67      5299
    surprise       0.52      0.50      0.51      1656
     sadness       0.39      0.13      0.20      1011

    accuracy                           0.52     12144
   macro avg       0.33      0.27      0.28     12144
weighted avg       0.47      0.52      0.46     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.37      0.14      0.20       192
         joy       0.56      0.21      0.31       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.55      0.79

100%|██████████| 380/380 [01:25<00:00,  4.43it/s]
100%|██████████| 47/47 [00:03<00:00, 12.05it/s]


Epoch: 8, Train Loss: 1.3332701372472864, Val Loss: 1.3417471850172003
Train Report: 
              precision    recall  f1-score   support

       anger       0.34      0.22      0.27      1423
         joy       0.54      0.28      0.37      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.55      0.86      0.67      5299
    surprise       0.54      0.49      0.51      1656
     sadness       0.42      0.18      0.25      1011

    accuracy                           0.53     12144
   macro avg       0.34      0.29      0.30     12144
weighted avg       0.48      0.53      0.48     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.31      0.18      0.22       192
         joy       0.52      0.37      0.43       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.57      0.74

100%|██████████| 380/380 [01:26<00:00,  4.39it/s]
100%|██████████| 47/47 [00:03<00:00, 11.80it/s]


Epoch: 9, Train Loss: 1.308814726691497, Val Loss: 1.331839142961705
Train Report: 
              precision    recall  f1-score   support

       anger       0.36      0.24      0.28      1423
         joy       0.54      0.32      0.41      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.57      0.84      0.68      5299
    surprise       0.55      0.52      0.54      1656
     sadness       0.42      0.23      0.30      1011

    accuracy                           0.54     12144
   macro avg       0.35      0.31      0.31     12144
weighted avg       0.49      0.54      0.49     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.29      0.16      0.20       192
         joy       0.53      0.39      0.45       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.57      0.75  

100%|██████████| 380/380 [01:27<00:00,  4.34it/s]
100%|██████████| 47/47 [00:04<00:00, 11.61it/s]


Epoch: 10, Train Loss: 1.2913127464683432, Val Loss: 1.3255293141020106
Train Report: 
              precision    recall  f1-score   support

       anger       0.36      0.24      0.29      1423
         joy       0.55      0.35      0.43      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.57      0.83      0.68      5299
    surprise       0.56      0.53      0.55      1656
     sadness       0.42      0.25      0.31      1011

    accuracy                           0.54     12144
   macro avg       0.35      0.31      0.32     12144
weighted avg       0.49      0.54      0.50     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.33      0.20      0.25       192
         joy       0.58      0.34      0.43       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.55      0.8

100%|██████████| 380/380 [01:28<00:00,  4.30it/s]
100%|██████████| 47/47 [00:04<00:00, 11.21it/s]


Epoch: 11, Train Loss: 1.2686321344814802, Val Loss: 1.3387428179700325
Train Report: 
              precision    recall  f1-score   support

       anger       0.37      0.26      0.31      1423
         joy       0.55      0.36      0.44      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.58      0.83      0.69      5299
    surprise       0.57      0.55      0.56      1656
     sadness       0.44      0.27      0.34      1011

    accuracy                           0.55     12144
   macro avg       0.36      0.33      0.33     12144
weighted avg       0.50      0.55      0.51     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.30      0.20      0.24       192
         joy       0.57      0.35      0.43       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.57      0.7

100%|██████████| 380/380 [01:29<00:00,  4.26it/s]
100%|██████████| 47/47 [00:04<00:00, 11.11it/s]


Epoch: 12, Train Loss: 1.2505783637887553, Val Loss: 1.3276141184441588
Train Report: 
              precision    recall  f1-score   support

       anger       0.39      0.28      0.32      1423
         joy       0.55      0.40      0.46      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.59      0.83      0.69      5299
    surprise       0.58      0.55      0.57      1656
     sadness       0.47      0.30      0.37      1011

    accuracy                           0.56     12144
   macro avg       0.37      0.34      0.34     12144
weighted avg       0.51      0.56      0.52     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.35      0.17      0.23       192
         joy       0.49      0.44      0.46       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.57      0.7

100%|██████████| 380/380 [01:30<00:00,  4.21it/s]
100%|██████████| 47/47 [00:04<00:00, 10.98it/s]


Epoch: 13, Train Loss: 1.230191365982357, Val Loss: 1.321953867344146
Train Report: 
              precision    recall  f1-score   support

       anger       0.39      0.29      0.33      1423
         joy       0.56      0.41      0.47      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.60      0.83      0.70      5299
    surprise       0.59      0.58      0.59      1656
     sadness       0.49      0.32      0.39      1011

    accuracy                           0.57     12144
   macro avg       0.38      0.35      0.35     12144
weighted avg       0.52      0.57      0.53     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.34      0.20      0.25       192
         joy       0.53      0.40      0.45       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.56      0.75 

100%|██████████| 380/380 [01:31<00:00,  4.16it/s]
100%|██████████| 47/47 [00:04<00:00, 10.68it/s]


Epoch: 14, Train Loss: 1.2166089514368459, Val Loss: 1.3365645636903478
Train Report: 
              precision    recall  f1-score   support

       anger       0.39      0.30      0.34      1423
         joy       0.56      0.42      0.48      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.60      0.82      0.70      5299
    surprise       0.60      0.59      0.59      1656
     sadness       0.47      0.33      0.39      1011

    accuracy                           0.57     12144
   macro avg       0.38      0.35      0.36     12144
weighted avg       0.53      0.57      0.54     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.32      0.24      0.27       192
         joy       0.55      0.37      0.44       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.57      0.7

100%|██████████| 380/380 [01:32<00:00,  4.12it/s]
100%|██████████| 47/47 [00:04<00:00, 10.50it/s]

Epoch: 15, Train Loss: 1.1954192490954147, Val Loss: 1.3294243127741712
Train Report: 
              precision    recall  f1-score   support

       anger       0.40      0.33      0.36      1423
         joy       0.58      0.43      0.49      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.62      0.82      0.71      5299
    surprise       0.60      0.61      0.60      1656
     sadness       0.49      0.34      0.41      1011

    accuracy                           0.58     12144
   macro avg       0.38      0.36      0.37     12144
weighted avg       0.54      0.58      0.55     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.35      0.21      0.26       192
         joy       0.48      0.44      0.46       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.58      0.7




In [14]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Macro train_f1,▁▁▂▃▄▅▆▆▇▇▇▇███
Macro val_f1,▁▃▃▄▅▇▇████████
Weighted train_f1,▁▁▂▃▄▅▆▆▇▇▇▇███
Weighted val_f1,▁▃▃▄▅▇▇████████
train_accuracy,▁▂▃▄▄▅▆▆▆▆▇▇███
train_loss,█▆▆▅▄▄▃▃▃▂▂▂▁▁▁
val_accuracy,▁▃▄▅▆▇▆▇▇█▆█▇▇▇
val_loss,█▇▅▄▃▃▂▂▁▁▁▁▁▁▁

0,1
Macro train_f1,0.36674
Macro val_f1,0.31957
Weighted train_f1,0.54921
Weighted val_f1,0.48629
train_accuracy,0.58177
train_loss,1.19542
val_accuracy,0.51458
val_loss,1.32942


In [15]:
torch.save(model, '/kaggle/working/add_target.pth')