In [67]:
import json
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from sklearn.metrics import classification_report
import wandb
from transformers import BertPreTrainedModel, BertModel, BertConfig
import pickle
from torch.nn.functional import leaky_relu
from sklearn.utils.class_weight import compute_class_weight
import numpy as np

In [68]:
# Source: https://github.com/LCS2-IIITD/Emotion-Flip-Reasoning/blob/main/Dataloaders/nlp_utils.py
import string
import nltk
import re

numbers = {
    "0":"zero",
    "1":"one",
    "2":"two",
    "3":"three",
    "4":"four",
    "5":"five",
    "6":"six",
    "7":"seven",
    "8":"eight",
    "9":"nine"
}

def remove_puntuations(txt):
    punct = set(string.punctuation)
    txt = " ".join(txt.split("."))
    txt = " ".join(txt.split("!"))
    txt = " ".join(txt.split("?"))
    txt = " ".join(txt.split(":"))
    txt = " ".join(txt.split(";"))
    
    txt = "".join(ch for ch in txt if ch not in punct)
    return txt

def number_to_words(txt):
    for k in numbers.keys():
        txt = txt.replace(k,numbers[k]+" ")
    return txt

def preprocess_text(text):
    text = text.lower()
    text = re.sub(r'_',' ',text)
    text = number_to_words(text)
    text = remove_puntuations(text)
    text = ''.join([i if ord(i) < 128 else '' for i in text])
    text = ' '.join(text.split())
    return text

In [69]:
train_data = json.load(open('/kaggle/input/semeval3-task-3-dataset/Dataset/ERC_utterance_level/train_utterance_level.json'))
val_data = json.load(open('/kaggle/input/semeval3-task-3-dataset/Dataset/ERC_utterance_level/val_utterance_level.json'))

In [70]:
original_train_data = json.load(open('/kaggle/input/semeval3-task-3-dataset/Dataset/Original_Dataset/Subtask_1_train.json'))

In [71]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [72]:
emotion2int = {
    'anger': 0,
    'joy': 1,
    'fear': 2,
    'disgust': 3,
    'neutral': 4,
    'surprise': 5,
    'sadness': 6
}

In [73]:
L = []
for conversation in original_train_data:
    for utterance in conversation['conversation']:
        L.append(emotion2int[utterance['emotion']])
class_weights=compute_class_weight(class_weight='balanced',classes=list(emotion2int.values()), y=np.array(L))

In [74]:
utterance2vec = pickle.load(open('/kaggle/input/semeval3-task-3-dataset/Dataset/Embeddings/sentence_transformer_utterance2vec_768.pkl', 'rb'))

In [75]:
MAX_CONV_LEN = 35
# Defined index 7 for padding
class ERC_Dataset_Utt_Level(Dataset):
    def __init__(self, data, utterance2vec, device):
        self.data = data
        self.utterance2vec = utterance2vec
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[f'id_{idx+1}']['text']
        emotion = self.data[f'id_{idx+1}']['emotion']
        context = self.data[f'id_{idx+1}']['context']

        context_embeddings = [torch.tensor(self.utterance2vec[preprocess_text(utterance)]).to(self.device) for utterance in context]
        target_embedding = torch.tensor(self.utterance2vec[preprocess_text(text)]).to(self.device)
        context_embeddings.append(target_embedding)
                
        if(len(context_embeddings)<MAX_CONV_LEN):
            num_pads = MAX_CONV_LEN - len(context_embeddings)
            attention_mask = [1]*len(context_embeddings) + [0]*num_pads
            context_embeddings = context_embeddings + [torch.zeros(768).to(self.device)]*num_pads
        else:
            context_embeddings = context_embeddings[len(context_embeddings)-MAX_CONV_LEN:]
            attention_mask = [1]*MAX_CONV_LEN

        context_embeddings = torch.stack(context_embeddings)
        attention_mask = torch.tensor(attention_mask)

        return {
            'context_embeddings': context_embeddings,
            'target_embedding': target_embedding,
            'attention_mask': attention_mask,   
            'emotion': emotion2int[emotion]
        }

In [76]:
class BertForSentenceClassificationGivenContext(BertPreTrainedModel):
    def __init__(self, config, weights):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.weights = weights
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size*2, config.num_labels)

        self.post_init()

    def forward(self, context_embeds, target_embeds, attention_mask, labels=None):
        output = self.bert(inputs_embeds=context_embeds, attention_mask=attention_mask)
        out_target = [output.last_hidden_state[b][int(sum(attention_mask[b]))-1] for b in range(output.last_hidden_state.shape[0])]
        out_target = torch.stack(out_target)
        out_target = self.dropout(out_target)
        out_target_cat = torch.cat((out_target, target_embeds), 1)
        logits = self.classifier(out_target_cat)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(weight=self.weights.to(device))
            loss = loss_fct(logits.view(-1, self.num_labels).to(device), labels.view(-1).to(device))

        return {'loss': loss, 'logits': logits}
        

In [77]:
train_dataset = ERC_Dataset_Utt_Level(train_data, utterance2vec, device)
val_dataset = ERC_Dataset_Utt_Level(val_data, utterance2vec, device)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [78]:
config = BertConfig.from_pretrained('bert-base-uncased', num_labels=7)
model = BertForSentenceClassificationGivenContext.from_pretrained('bert-base-uncased', config=config, weights=torch.from_numpy(class_weights).float()).to(device)

Some weights of BertForSentenceClassificationGivenContext were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [79]:
epochs = 30
optimizer = AdamW(model.parameters(), lr=1e-6)

In [80]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb_login_key")
wandb.login(key=secret_value_0)

[34m[1mwandb[0m: Currently logged in as: [33mparth-dholariya9221[0m ([33mTECPEC[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [81]:
wandb.init(project='TECPEC', name='BERT_Utt_Level', config={
    'Embedding': 'Sentence-Transformer',
    'Level': 'Utterance Level',
    'Epochs': epochs,
    'Optimizer': 'AdamW',
    'Learning Rate': 1e-6,
    'Batch Size': 16
})

In [82]:
for epoch in range(epochs):
    model.train()
    train_pred, train_true, train_loss = [], [], 0.0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        context_embeddings, target_embedding, emotions, attention_mask = batch['context_embeddings'].to(device), batch['target_embedding'].to(device), batch['emotion'].to(device), batch['attention_mask'].to(device)
        outputs = model(context_embeds=context_embeddings, target_embeds=target_embedding, attention_mask=attention_mask, labels=emotions)
        loss = outputs['loss']
        loss.backward()
        optimizer.step()
        train_pred.extend(torch.argmax(outputs['logits'], 1).tolist())
        train_true.extend(emotions.tolist())
        train_loss += loss.item()
    train_loss /= len(train_loader) 
    model.eval()
    val_pred, val_true, val_loss = [], [], 0.0
    with torch.no_grad():
        for batch in tqdm(val_loader):
            context_embeddings, target_embedding, emotions, attention_mask = batch['context_embeddings'].to(device), batch['target_embedding'].to(device), batch['emotion'].to(device), batch['attention_mask'].to(device)
            outputs = model(context_embeds=context_embeddings, target_embeds=target_embedding, attention_mask=attention_mask, labels=emotions)
            loss = outputs['loss']
            val_pred.extend(torch.argmax(outputs['logits'], 1).tolist())
            val_true.extend(emotions.tolist())
            val_loss += loss.item()
            
    val_loss /= len(val_loader)
    train_report = classification_report(train_true, train_pred, target_names=emotion2int.keys(), zero_division=0)
    val_report = classification_report(val_true, val_pred, target_names=emotion2int.keys(), zero_division=0)

    train_report_dict = classification_report(train_true, train_pred, target_names=emotion2int.keys(), output_dict=True, zero_division=0)
    val_report_dict = classification_report(val_true, val_pred, target_names=emotion2int.keys(), output_dict=True, zero_division=0)
    wandb.log({
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_accuracy': train_report_dict['accuracy'],
        'val_accuracy': val_report_dict['accuracy'],
        'Macro train_f1': train_report_dict['macro avg']['f1-score'],
        'Macro val_f1': val_report_dict['macro avg']['f1-score'],
        'Weighted train_f1': train_report_dict['weighted avg']['f1-score'],
        'Weighted val_f1': val_report_dict['weighted avg']['f1-score'],
    })
    print(f"Epoch: {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}")
    print(f"Train Report: \n{train_report}")
    print(f"Val Report: \n{val_report}")


100%|██████████| 759/759 [01:33<00:00,  8.10it/s]
100%|██████████| 93/93 [00:03<00:00, 24.58it/s]


Epoch: 1, Train Loss: 1.950147879610577, Val Loss: 1.9272488150545346
Train Report: 
              precision    recall  f1-score   support

       anger       0.13      0.14      0.14      1423
         joy       0.16      0.21      0.18      2047
        fear       0.03      0.04      0.03       336
     disgust       0.02      0.01      0.01       372
     neutral       0.44      0.31      0.37      5299
    surprise       0.15      0.18      0.16      1656
     sadness       0.10      0.15      0.12      1011

    accuracy                           0.23     12144
   macro avg       0.15      0.15      0.14     12144
weighted avg       0.26      0.23      0.24     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.14      0.20      0.17       192
         joy       0.27      0.26      0.27       254
        fear       0.00      0.00      0.00        37
     disgust       0.04      0.10      0.06        42
     neutral       0.46      0.26 

100%|██████████| 759/759 [01:33<00:00,  8.10it/s]
100%|██████████| 93/93 [00:03<00:00, 24.04it/s]


Epoch: 2, Train Loss: 1.9347428466648608, Val Loss: 1.9189553619712911
Train Report: 
              precision    recall  f1-score   support

       anger       0.14      0.21      0.17      1423
         joy       0.20      0.24      0.22      2047
        fear       0.04      0.01      0.02       336
     disgust       0.05      0.05      0.05       372
     neutral       0.46      0.30      0.36      5299
    surprise       0.16      0.20      0.18      1656
     sadness       0.10      0.16      0.12      1011

    accuracy                           0.24     12144
   macro avg       0.16      0.17      0.16     12144
weighted avg       0.28      0.24      0.25     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.17      0.38      0.23       192
         joy       0.25      0.34      0.29       254
        fear       0.14      0.03      0.05        37
     disgust       0.03      0.07      0.05        42
     neutral       0.43      0.21

100%|██████████| 759/759 [01:33<00:00,  8.09it/s]
100%|██████████| 93/93 [00:03<00:00, 24.18it/s]


Epoch: 3, Train Loss: 1.9215461923513801, Val Loss: 1.9028389992252472
Train Report: 
              precision    recall  f1-score   support

       anger       0.17      0.24      0.20      1423
         joy       0.22      0.27      0.24      2047
        fear       0.04      0.02      0.03       336
     disgust       0.03      0.07      0.05       372
     neutral       0.48      0.29      0.36      5299
    surprise       0.17      0.21      0.19      1656
     sadness       0.12      0.17      0.14      1011

    accuracy                           0.24     12144
   macro avg       0.18      0.18      0.17     12144
weighted avg       0.30      0.24      0.26     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.19      0.32      0.24       192
         joy       0.28      0.58      0.37       254
        fear       0.04      0.03      0.03        37
     disgust       0.02      0.05      0.03        42
     neutral       0.48      0.23

100%|██████████| 759/759 [01:33<00:00,  8.09it/s]
100%|██████████| 93/93 [00:03<00:00, 24.22it/s]


Epoch: 4, Train Loss: 1.9063160355069106, Val Loss: 1.8826393030023063
Train Report: 
              precision    recall  f1-score   support

       anger       0.17      0.29      0.21      1423
         joy       0.24      0.30      0.26      2047
        fear       0.06      0.05      0.06       336
     disgust       0.06      0.10      0.08       372
     neutral       0.49      0.28      0.36      5299
    surprise       0.18      0.22      0.19      1656
     sadness       0.14      0.16      0.15      1011

    accuracy                           0.25     12144
   macro avg       0.19      0.20      0.19     12144
weighted avg       0.31      0.25      0.27     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.20      0.22      0.21       192
         joy       0.23      0.74      0.35       254
        fear       0.07      0.05      0.06        37
     disgust       0.00      0.00      0.00        42
     neutral       0.51      0.16

100%|██████████| 759/759 [01:33<00:00,  8.09it/s]
100%|██████████| 93/93 [00:03<00:00, 24.31it/s]


Epoch: 5, Train Loss: 1.8807842255581038, Val Loss: 1.8605477886815225
Train Report: 
              precision    recall  f1-score   support

       anger       0.20      0.28      0.23      1423
         joy       0.25      0.40      0.31      2047
        fear       0.07      0.07      0.07       336
     disgust       0.07      0.13      0.09       372
     neutral       0.52      0.26      0.34      5299
    surprise       0.20      0.19      0.20      1656
     sadness       0.14      0.25      0.18      1011

    accuracy                           0.26     12144
   macro avg       0.21      0.22      0.20     12144
weighted avg       0.34      0.26      0.28     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.21      0.31      0.25       192
         joy       0.33      0.51      0.40       254
        fear       0.03      0.05      0.03        37
     disgust       0.00      0.00      0.00        42
     neutral       0.51      0.21

100%|██████████| 759/759 [01:33<00:00,  8.09it/s]
100%|██████████| 93/93 [00:03<00:00, 24.17it/s]


Epoch: 6, Train Loss: 1.8586429506736624, Val Loss: 1.860182700618621
Train Report: 
              precision    recall  f1-score   support

       anger       0.21      0.32      0.26      1423
         joy       0.27      0.39      0.32      2047
        fear       0.07      0.15      0.10       336
     disgust       0.07      0.09      0.08       372
     neutral       0.52      0.27      0.35      5299
    surprise       0.21      0.21      0.21      1656
     sadness       0.17      0.26      0.20      1011

    accuracy                           0.28     12144
   macro avg       0.22      0.24      0.22     12144
weighted avg       0.34      0.28      0.29     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.20      0.32      0.24       192
         joy       0.32      0.58      0.41       254
        fear       0.00      0.00      0.00        37
     disgust       0.04      0.29      0.07        42
     neutral       0.51      0.11 

100%|██████████| 759/759 [01:33<00:00,  8.12it/s]
100%|██████████| 93/93 [00:03<00:00, 24.26it/s]


Epoch: 7, Train Loss: 1.8297990846382615, Val Loss: 1.8499751821641
Train Report: 
              precision    recall  f1-score   support

       anger       0.23      0.33      0.27      1423
         joy       0.27      0.43      0.34      2047
        fear       0.08      0.12      0.10       336
     disgust       0.08      0.15      0.11       372
     neutral       0.54      0.24      0.33      5299
    surprise       0.23      0.23      0.23      1656
     sadness       0.18      0.30      0.22      1011

    accuracy                           0.28     12144
   macro avg       0.23      0.26      0.23     12144
weighted avg       0.36      0.28      0.29     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.18      0.33      0.23       192
         joy       0.33      0.59      0.42       254
        fear       0.05      0.14      0.07        37
     disgust       0.04      0.10      0.06        42
     neutral       0.53      0.16   

100%|██████████| 759/759 [01:33<00:00,  8.10it/s]
100%|██████████| 93/93 [00:03<00:00, 24.42it/s]


Epoch: 8, Train Loss: 1.8081820235736128, Val Loss: 1.8507854182233092
Train Report: 
              precision    recall  f1-score   support

       anger       0.22      0.29      0.25      1423
         joy       0.29      0.45      0.35      2047
        fear       0.10      0.20      0.13       336
     disgust       0.10      0.20      0.13       372
     neutral       0.54      0.24      0.34      5299
    surprise       0.22      0.21      0.22      1656
     sadness       0.19      0.30      0.23      1011

    accuracy                           0.28     12144
   macro avg       0.24      0.27      0.24     12144
weighted avg       0.36      0.28      0.29     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.18      0.16      0.17       192
         joy       0.33      0.55      0.41       254
        fear       0.04      0.08      0.05        37
     disgust       0.04      0.26      0.08        42
     neutral       0.50      0.12

100%|██████████| 759/759 [01:33<00:00,  8.09it/s]
100%|██████████| 93/93 [00:03<00:00, 24.04it/s]


Epoch: 9, Train Loss: 1.7843365956672095, Val Loss: 1.8326878932214552
Train Report: 
              precision    recall  f1-score   support

       anger       0.25      0.31      0.28      1423
         joy       0.31      0.46      0.37      2047
        fear       0.09      0.20      0.13       336
     disgust       0.11      0.24      0.15       372
     neutral       0.56      0.24      0.34      5299
    surprise       0.25      0.24      0.25      1656
     sadness       0.20      0.35      0.26      1011

    accuracy                           0.29     12144
   macro avg       0.25      0.29      0.25     12144
weighted avg       0.38      0.29      0.30     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.19      0.28      0.23       192
         joy       0.34      0.59      0.43       254
        fear       0.10      0.11      0.10        37
     disgust       0.07      0.19      0.10        42
     neutral       0.50      0.20

100%|██████████| 759/759 [01:33<00:00,  8.12it/s]
100%|██████████| 93/93 [00:03<00:00, 24.04it/s]


Epoch: 10, Train Loss: 1.7619772400309446, Val Loss: 1.8421191079642183
Train Report: 
              precision    recall  f1-score   support

       anger       0.25      0.31      0.28      1423
         joy       0.31      0.45      0.37      2047
        fear       0.11      0.22      0.15       336
     disgust       0.11      0.30      0.16       372
     neutral       0.57      0.24      0.34      5299
    surprise       0.25      0.27      0.26      1656
     sadness       0.21      0.36      0.26      1011

    accuracy                           0.30     12144
   macro avg       0.26      0.31      0.26     12144
weighted avg       0.39      0.30      0.31     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.21      0.30      0.25       192
         joy       0.34      0.56      0.42       254
        fear       0.06      0.24      0.10        37
     disgust       0.08      0.17      0.10        42
     neutral       0.50      0.1

100%|██████████| 759/759 [01:33<00:00,  8.10it/s]
100%|██████████| 93/93 [00:03<00:00, 24.28it/s]


Epoch: 11, Train Loss: 1.7389725545649473, Val Loss: 1.8479280433347147
Train Report: 
              precision    recall  f1-score   support

       anger       0.27      0.33      0.29      1423
         joy       0.32      0.48      0.38      2047
        fear       0.11      0.24      0.15       336
     disgust       0.11      0.29      0.16       372
     neutral       0.58      0.25      0.35      5299
    surprise       0.26      0.26      0.26      1656
     sadness       0.23      0.37      0.28      1011

    accuracy                           0.31     12144
   macro avg       0.27      0.32      0.27     12144
weighted avg       0.40      0.31      0.32     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.21      0.20      0.20       192
         joy       0.33      0.57      0.42       254
        fear       0.07      0.24      0.10        37
     disgust       0.07      0.31      0.11        42
     neutral       0.54      0.1

100%|██████████| 759/759 [01:33<00:00,  8.11it/s]
100%|██████████| 93/93 [00:03<00:00, 24.36it/s]


Epoch: 12, Train Loss: 1.713396757644784, Val Loss: 1.8469729410704745
Train Report: 
              precision    recall  f1-score   support

       anger       0.28      0.32      0.30      1423
         joy       0.33      0.45      0.38      2047
        fear       0.12      0.30      0.17       336
     disgust       0.11      0.31      0.16       372
     neutral       0.58      0.25      0.35      5299
    surprise       0.27      0.30      0.28      1656
     sadness       0.23      0.40      0.29      1011

    accuracy                           0.31     12144
   macro avg       0.27      0.33      0.28     12144
weighted avg       0.40      0.31      0.32     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.23      0.26      0.24       192
         joy       0.34      0.57      0.42       254
        fear       0.06      0.22      0.09        37
     disgust       0.07      0.36      0.12        42
     neutral       0.54      0.16

100%|██████████| 759/759 [01:33<00:00,  8.11it/s]
100%|██████████| 93/93 [00:03<00:00, 24.41it/s]


Epoch: 13, Train Loss: 1.6989038903094407, Val Loss: 1.8659686811508671
Train Report: 
              precision    recall  f1-score   support

       anger       0.29      0.35      0.32      1423
         joy       0.33      0.49      0.39      2047
        fear       0.12      0.31      0.17       336
     disgust       0.11      0.32      0.16       372
     neutral       0.59      0.23      0.33      5299
    surprise       0.26      0.26      0.26      1656
     sadness       0.26      0.42      0.32      1011

    accuracy                           0.32     12144
   macro avg       0.28      0.34      0.28     12144
weighted avg       0.41      0.32      0.32     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.21      0.18      0.19       192
         joy       0.32      0.58      0.41       254
        fear       0.06      0.19      0.09        37
     disgust       0.07      0.40      0.11        42
     neutral       0.51      0.1

100%|██████████| 759/759 [01:33<00:00,  8.11it/s]
100%|██████████| 93/93 [00:03<00:00, 24.32it/s]


Epoch: 14, Train Loss: 1.6769749967477068, Val Loss: 1.8500237054722284
Train Report: 
              precision    recall  f1-score   support

       anger       0.29      0.34      0.31      1423
         joy       0.33      0.48      0.39      2047
        fear       0.13      0.34      0.19       336
     disgust       0.12      0.37      0.19       372
     neutral       0.61      0.23      0.34      5299
    surprise       0.27      0.29      0.28      1656
     sadness       0.26      0.43      0.32      1011

    accuracy                           0.32     12144
   macro avg       0.29      0.35      0.29     12144
weighted avg       0.42      0.32      0.33     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.22      0.27      0.24       192
         joy       0.35      0.57      0.43       254
        fear       0.04      0.14      0.07        37
     disgust       0.10      0.31      0.15        42
     neutral       0.52      0.2

100%|██████████| 759/759 [01:33<00:00,  8.11it/s]
100%|██████████| 93/93 [00:03<00:00, 24.33it/s]


Epoch: 15, Train Loss: 1.6633129495248806, Val Loss: 1.8560477008101761
Train Report: 
              precision    recall  f1-score   support

       anger       0.31      0.33      0.32      1423
         joy       0.35      0.48      0.40      2047
        fear       0.14      0.39      0.21       336
     disgust       0.13      0.41      0.20       372
     neutral       0.61      0.25      0.35      5299
    surprise       0.29      0.31      0.30      1656
     sadness       0.24      0.42      0.31      1011

    accuracy                           0.33     12144
   macro avg       0.30      0.37      0.30     12144
weighted avg       0.43      0.33      0.34     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.28      0.23      0.25       192
         joy       0.34      0.56      0.43       254
        fear       0.04      0.24      0.07        37
     disgust       0.08      0.29      0.12        42
     neutral       0.56      0.1

100%|██████████| 759/759 [01:33<00:00,  8.12it/s]
100%|██████████| 93/93 [00:03<00:00, 24.26it/s]


Epoch: 16, Train Loss: 1.6292067139516235, Val Loss: 1.8662432534720308
Train Report: 
              precision    recall  f1-score   support

       anger       0.31      0.34      0.32      1423
         joy       0.34      0.49      0.40      2047
        fear       0.15      0.45      0.22       336
     disgust       0.14      0.43      0.21       372
     neutral       0.61      0.24      0.34      5299
    surprise       0.29      0.29      0.29      1656
     sadness       0.25      0.43      0.32      1011

    accuracy                           0.33     12144
   macro avg       0.30      0.38      0.30     12144
weighted avg       0.43      0.33      0.33     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.22      0.33      0.26       192
         joy       0.36      0.55      0.43       254
        fear       0.04      0.08      0.05        37
     disgust       0.09      0.24      0.13        42
     neutral       0.56      0.1

100%|██████████| 759/759 [01:33<00:00,  8.10it/s]
100%|██████████| 93/93 [00:03<00:00, 24.21it/s]


Epoch: 17, Train Loss: 1.5974428816275164, Val Loss: 1.8739386912315124
Train Report: 
              precision    recall  f1-score   support

       anger       0.31      0.36      0.33      1423
         joy       0.36      0.50      0.42      2047
        fear       0.15      0.40      0.22       336
     disgust       0.16      0.46      0.23       372
     neutral       0.62      0.25      0.35      5299
    surprise       0.30      0.32      0.31      1656
     sadness       0.27      0.46      0.34      1011

    accuracy                           0.34     12144
   macro avg       0.31      0.39      0.32     12144
weighted avg       0.44      0.34      0.35     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.25      0.24      0.25       192
         joy       0.32      0.59      0.42       254
        fear       0.04      0.16      0.07        37
     disgust       0.10      0.38      0.16        42
     neutral       0.53      0.1

100%|██████████| 759/759 [01:33<00:00,  8.11it/s]
100%|██████████| 93/93 [00:03<00:00, 24.42it/s]


Epoch: 18, Train Loss: 1.5804573128660049, Val Loss: 1.8735531683891051
Train Report: 
              precision    recall  f1-score   support

       anger       0.33      0.37      0.35      1423
         joy       0.35      0.49      0.41      2047
        fear       0.15      0.46      0.22       336
     disgust       0.16      0.48      0.23       372
     neutral       0.63      0.25      0.36      5299
    surprise       0.31      0.32      0.32      1656
     sadness       0.27      0.46      0.34      1011

    accuracy                           0.34     12144
   macro avg       0.31      0.41      0.32     12144
weighted avg       0.45      0.34      0.35     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.28      0.24      0.26       192
         joy       0.39      0.51      0.44       254
        fear       0.05      0.27      0.09        37
     disgust       0.10      0.31      0.15        42
     neutral       0.57      0.2

100%|██████████| 759/759 [01:33<00:00,  8.12it/s]
100%|██████████| 93/93 [00:03<00:00, 24.32it/s]


Epoch: 19, Train Loss: 1.572630481365013, Val Loss: 1.8877526278136878
Train Report: 
              precision    recall  f1-score   support

       anger       0.31      0.37      0.34      1423
         joy       0.37      0.51      0.43      2047
        fear       0.16      0.49      0.24       336
     disgust       0.16      0.50      0.25       372
     neutral       0.62      0.25      0.36      5299
    surprise       0.31      0.31      0.31      1656
     sadness       0.27      0.45      0.34      1011

    accuracy                           0.35     12144
   macro avg       0.31      0.41      0.32     12144
weighted avg       0.44      0.35      0.35     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.25      0.26      0.25       192
         joy       0.35      0.55      0.43       254
        fear       0.05      0.14      0.07        37
     disgust       0.10      0.40      0.16        42
     neutral       0.54      0.22

100%|██████████| 759/759 [01:33<00:00,  8.10it/s]
100%|██████████| 93/93 [00:03<00:00, 24.43it/s]


Epoch: 20, Train Loss: 1.5423421561325334, Val Loss: 1.8814446785116707
Train Report: 
              precision    recall  f1-score   support

       anger       0.33      0.37      0.35      1423
         joy       0.37      0.50      0.43      2047
        fear       0.16      0.50      0.25       336
     disgust       0.17      0.50      0.26       372
     neutral       0.62      0.25      0.36      5299
    surprise       0.32      0.36      0.34      1656
     sadness       0.28      0.47      0.35      1011

    accuracy                           0.36     12144
   macro avg       0.32      0.42      0.33     12144
weighted avg       0.45      0.36      0.36     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.24      0.29      0.26       192
         joy       0.36      0.55      0.44       254
        fear       0.07      0.24      0.11        37
     disgust       0.07      0.14      0.09        42
     neutral       0.56      0.1

100%|██████████| 759/759 [01:33<00:00,  8.10it/s]
100%|██████████| 93/93 [00:03<00:00, 24.41it/s]


Epoch: 21, Train Loss: 1.5313582690494019, Val Loss: 1.8873810857854865
Train Report: 
              precision    recall  f1-score   support

       anger       0.32      0.38      0.35      1423
         joy       0.38      0.51      0.43      2047
        fear       0.17      0.51      0.25       336
     disgust       0.18      0.52      0.27       372
     neutral       0.64      0.27      0.38      5299
    surprise       0.31      0.33      0.32      1656
     sadness       0.29      0.47      0.36      1011

    accuracy                           0.36     12144
   macro avg       0.33      0.43      0.34     12144
weighted avg       0.46      0.36      0.37     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.27      0.27      0.27       192
         joy       0.39      0.52      0.45       254
        fear       0.04      0.30      0.07        37
     disgust       0.09      0.29      0.13        42
     neutral       0.58      0.2

100%|██████████| 759/759 [01:33<00:00,  8.12it/s]
100%|██████████| 93/93 [00:03<00:00, 24.44it/s]


Epoch: 22, Train Loss: 1.5112209898838098, Val Loss: 1.8911135568413684
Train Report: 
              precision    recall  f1-score   support

       anger       0.35      0.40      0.37      1423
         joy       0.37      0.50      0.43      2047
        fear       0.18      0.55      0.27       336
     disgust       0.17      0.52      0.26       372
     neutral       0.63      0.26      0.37      5299
    surprise       0.32      0.32      0.32      1656
     sadness       0.29      0.47      0.36      1011

    accuracy                           0.36     12144
   macro avg       0.33      0.43      0.34     12144
weighted avg       0.45      0.36      0.37     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.26      0.27      0.27       192
         joy       0.34      0.56      0.43       254
        fear       0.04      0.16      0.07        37
     disgust       0.10      0.40      0.16        42
     neutral       0.57      0.1

100%|██████████| 759/759 [01:33<00:00,  8.11it/s]
100%|██████████| 93/93 [00:03<00:00, 24.44it/s]


Epoch: 23, Train Loss: 1.5010850506494795, Val Loss: 1.89159405872386
Train Report: 
              precision    recall  f1-score   support

       anger       0.34      0.40      0.37      1423
         joy       0.38      0.51      0.43      2047
        fear       0.18      0.55      0.27       336
     disgust       0.18      0.55      0.27       372
     neutral       0.64      0.27      0.38      5299
    surprise       0.32      0.33      0.33      1656
     sadness       0.31      0.50      0.38      1011

    accuracy                           0.37     12144
   macro avg       0.34      0.44      0.35     12144
weighted avg       0.46      0.37      0.37     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.27      0.25      0.26       192
         joy       0.38      0.54      0.45       254
        fear       0.05      0.22      0.08        37
     disgust       0.09      0.31      0.14        42
     neutral       0.55      0.24 

100%|██████████| 759/759 [01:33<00:00,  8.10it/s]
100%|██████████| 93/93 [00:03<00:00, 24.25it/s]


Epoch: 24, Train Loss: 1.472172100160747, Val Loss: 1.911321569514531
Train Report: 
              precision    recall  f1-score   support

       anger       0.36      0.42      0.39      1423
         joy       0.38      0.52      0.44      2047
        fear       0.19      0.57      0.28       336
     disgust       0.19      0.56      0.28       372
     neutral       0.64      0.27      0.38      5299
    surprise       0.32      0.34      0.33      1656
     sadness       0.31      0.50      0.38      1011

    accuracy                           0.37     12144
   macro avg       0.34      0.45      0.35     12144
weighted avg       0.47      0.37      0.38     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.27      0.21      0.24       192
         joy       0.39      0.52      0.45       254
        fear       0.04      0.24      0.07        37
     disgust       0.09      0.24      0.13        42
     neutral       0.57      0.22 

100%|██████████| 759/759 [01:33<00:00,  8.12it/s]
100%|██████████| 93/93 [00:03<00:00, 24.30it/s]


Epoch: 25, Train Loss: 1.4623276480572969, Val Loss: 1.910038484040127
Train Report: 
              precision    recall  f1-score   support

       anger       0.36      0.40      0.38      1423
         joy       0.39      0.51      0.44      2047
        fear       0.18      0.57      0.27       336
     disgust       0.19      0.60      0.29       372
     neutral       0.64      0.27      0.37      5299
    surprise       0.33      0.35      0.34      1656
     sadness       0.31      0.50      0.38      1011

    accuracy                           0.37     12144
   macro avg       0.34      0.46      0.35     12144
weighted avg       0.47      0.37      0.38     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.25      0.21      0.23       192
         joy       0.38      0.56      0.46       254
        fear       0.05      0.19      0.08        37
     disgust       0.09      0.29      0.13        42
     neutral       0.57      0.24

100%|██████████| 759/759 [01:33<00:00,  8.12it/s]
100%|██████████| 93/93 [00:03<00:00, 24.30it/s]


Epoch: 26, Train Loss: 1.4358819163992156, Val Loss: 1.9267870149304789
Train Report: 
              precision    recall  f1-score   support

       anger       0.35      0.41      0.38      1423
         joy       0.38      0.51      0.44      2047
        fear       0.21      0.63      0.32       336
     disgust       0.18      0.55      0.28       372
     neutral       0.65      0.27      0.38      5299
    surprise       0.34      0.37      0.35      1656
     sadness       0.31      0.52      0.39      1011

    accuracy                           0.38     12144
   macro avg       0.35      0.47      0.36     12144
weighted avg       0.47      0.38      0.38     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.27      0.22      0.25       192
         joy       0.39      0.56      0.46       254
        fear       0.05      0.24      0.08        37
     disgust       0.08      0.21      0.12        42
     neutral       0.57      0.1

100%|██████████| 759/759 [01:33<00:00,  8.11it/s]
100%|██████████| 93/93 [00:03<00:00, 24.54it/s]


Epoch: 27, Train Loss: 1.4228783734388188, Val Loss: 1.932991373923517
Train Report: 
              precision    recall  f1-score   support

       anger       0.35      0.40      0.37      1423
         joy       0.40      0.51      0.45      2047
        fear       0.19      0.60      0.29       336
     disgust       0.20      0.62      0.31       372
     neutral       0.66      0.28      0.39      5299
    surprise       0.34      0.36      0.35      1656
     sadness       0.31      0.51      0.39      1011

    accuracy                           0.38     12144
   macro avg       0.35      0.47      0.36     12144
weighted avg       0.48      0.38      0.39     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.25      0.23      0.24       192
         joy       0.38      0.56      0.45       254
        fear       0.05      0.22      0.08        37
     disgust       0.07      0.24      0.11        42
     neutral       0.55      0.25

100%|██████████| 759/759 [01:33<00:00,  8.12it/s]
100%|██████████| 93/93 [00:03<00:00, 24.37it/s]


Epoch: 28, Train Loss: 1.4030019496426438, Val Loss: 1.9594216064740253
Train Report: 
              precision    recall  f1-score   support

       anger       0.36      0.41      0.38      1423
         joy       0.39      0.52      0.45      2047
        fear       0.22      0.62      0.32       336
     disgust       0.21      0.63      0.32       372
     neutral       0.65      0.29      0.40      5299
    surprise       0.35      0.36      0.36      1656
     sadness       0.32      0.53      0.40      1011

    accuracy                           0.39     12144
   macro avg       0.36      0.48      0.38     12144
weighted avg       0.48      0.39      0.40     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.26      0.26      0.26       192
         joy       0.45      0.47      0.46       254
        fear       0.04      0.35      0.07        37
     disgust       0.09      0.17      0.12        42
     neutral       0.58      0.2

100%|██████████| 759/759 [01:33<00:00,  8.12it/s]
100%|██████████| 93/93 [00:03<00:00, 24.20it/s]


Epoch: 29, Train Loss: 1.3909649617273032, Val Loss: 1.918106238047282
Train Report: 
              precision    recall  f1-score   support

       anger       0.37      0.45      0.41      1423
         joy       0.40      0.53      0.46      2047
        fear       0.21      0.63      0.31       336
     disgust       0.22      0.66      0.34       372
     neutral       0.66      0.29      0.40      5299
    surprise       0.36      0.37      0.36      1656
     sadness       0.33      0.52      0.41      1011

    accuracy                           0.40     12144
   macro avg       0.37      0.49      0.38     12144
weighted avg       0.49      0.40      0.40     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.28      0.27      0.27       192
         joy       0.40      0.52      0.46       254
        fear       0.05      0.22      0.09        37
     disgust       0.08      0.26      0.12        42
     neutral       0.57      0.24

100%|██████████| 759/759 [01:33<00:00,  8.12it/s]
100%|██████████| 93/93 [00:03<00:00, 24.21it/s]

Epoch: 30, Train Loss: 1.372480718985848, Val Loss: 1.9485335055217947
Train Report: 
              precision    recall  f1-score   support

       anger       0.38      0.46      0.42      1423
         joy       0.40      0.52      0.45      2047
        fear       0.22      0.65      0.32       336
     disgust       0.22      0.63      0.33       372
     neutral       0.66      0.29      0.40      5299
    surprise       0.35      0.37      0.36      1656
     sadness       0.33      0.54      0.41      1011

    accuracy                           0.40     12144
   macro avg       0.37      0.49      0.38     12144
weighted avg       0.49      0.40      0.40     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.26      0.27      0.27       192
         joy       0.35      0.61      0.45       254
        fear       0.07      0.16      0.10        37
     disgust       0.08      0.24      0.12        42
     neutral       0.56      0.21




In [83]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Macro train_f1,▁▁▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇███
Macro val_f1,▁▂▂▁▄▂▄▃▅▅▅▅▄▆▆▅▆▇▇▇▇▆▇▇▇▇▇▇██
Weighted train_f1,▁▁▂▂▃▃▃▃▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
Weighted val_f1,▃▂▃▁▅▁▃▁▅▄▄▄▃▅▅▄▄▇▆▆▇▅▇▇▇▆▇██▇
train_accuracy,▁▁▂▂▃▃▃▃▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train_loss,███▇▇▇▇▆▆▆▅▅▅▅▅▄▄▄▃▃▃▃▃▂▂▂▂▁▁▁
val_accuracy,▁▂▄▃▅▁▃▁▅▄▄▄▂▅▄▄▅▇▆▆▅▅▇▆▇▇▇▇██
val_loss,▆▆▅▄▃▃▂▂▁▂▂▂▃▂▂▃▃▃▄▄▄▄▄▅▅▆▇█▆▇

0,1
Macro train_f1,0.38476
Macro val_f1,0.26761
Weighted train_f1,0.40351
Weighted val_f1,0.31744
train_accuracy,0.40036
train_loss,1.37248
val_accuracy,0.31864
val_loss,1.94853


In [84]:
torch.save(model, '/kaggle/working/BERT_Utt_Level.pth')