In [39]:
import json
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from sklearn.metrics import classification_report
import wandb
from transformers import BertPreTrainedModel, BertModel, BertConfig
import pickle
from torch.nn.functional import leaky_relu
from sklearn.utils.class_weight import compute_class_weight
import numpy as np

In [40]:
# Source: https://github.com/LCS2-IIITD/Emotion-Flip-Reasoning/blob/main/Dataloaders/nlp_utils.py
import string
import nltk
import re

# Spelled-out English word for each ASCII digit, keyed by the digit character.
numbers = dict(zip(
    "0123456789",
    ("zero", "one", "two", "three", "four",
     "five", "six", "seven", "eight", "nine"),
))

def remove_puntuations(txt):
    """Strip ASCII punctuation from ``txt``.

    The sentence delimiters ``. ! ? : ;`` are turned into a single space each
    (so the words around them stay separated); every other character in
    ``string.punctuation`` is deleted outright.
    """
    # One translation table does both jobs in a single pass over the text.
    table = {ord(ch): None for ch in string.punctuation}
    table.update({ord(ch): " " for ch in ".!?:;"})
    return txt.translate(table)

def number_to_words(txt):
    """Spell out each digit found in ``txt``, adding a space after every word.

    Digits are handled one character at a time using the module-level
    ``numbers`` table, so "10" becomes "one zero ".
    """
    for digit, word in numbers.items():
        txt = (word + " ").join(txt.split(digit))
    return txt

def preprocess_text(text):
    """Normalise an utterance into the canonical form used as an embedding key.

    Steps, in order: lowercase, underscores to spaces, digits spelled out,
    punctuation removed, non-ASCII characters dropped, whitespace collapsed.
    """
    lowered = re.sub(r'_', ' ', text.lower())
    stripped = remove_puntuations(number_to_words(lowered))
    ascii_only = ''.join(ch for ch in stripped if ord(ch) < 128)
    # split()/join collapses runs of whitespace and trims both ends.
    return ' '.join(ascii_only.split())

In [41]:
# Load the utterance-level train/validation splits. The context managers close
# each file handle as soon as it has been parsed.
with open('/kaggle/input/nlp-project/Dataset/ERC_utterance_level/train_utterance_level.json') as train_file:
    train_data = json.load(train_file)
with open('/kaggle/input/nlp-project/Dataset/ERC_utterance_level/val_utterance_level.json') as val_file:
    val_data = json.load(val_file)

In [42]:
# Conversation-level training file; it is only read here, so close the handle right after parsing.
with open('/kaggle/input/nlp-project/Dataset/Original_Dataset/Subtask_1_train.json') as original_train_file:
    original_train_data = json.load(original_train_file)

In [43]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [44]:
# Emotion name -> integer class id. The id is the position in this tuple, and
# the dict keeps the same order.
emotion2int = {
    emotion: class_id
    for class_id, emotion in enumerate(
        ('anger', 'joy', 'fear', 'disgust', 'neutral', 'surprise', 'sadness')
    )
}

In [45]:
# Flatten the emotion label of every utterance in the original training
# conversations into one list of class ids.
L = [
    emotion2int[utterance['emotion']]
    for conversation in original_train_data
    for utterance in conversation['conversation']
]
# One weight per class id, computed with scikit-learn's 'balanced' scheme.
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=list(emotion2int.values()),
    y=np.array(L),
)

In [46]:
# Lookup table from preprocessed utterance text to its embedding vector.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load this file from a trusted source.
with open('/kaggle/input/nlp-project/Dataset/Embeddings/sentence_transformer_utterance2vec_384.pkl', 'rb') as embeddings_file:
    utterance2vec = pickle.load(embeddings_file)

In [47]:
MAX_CONV_LEN = 35
# Contexts shorter than MAX_CONV_LEN are right-padded with all-zero vectors and
# masked out via attention_mask; longer ones keep only their last MAX_CONV_LEN
# utterances.
class ERC_Dataset_Utt_Level(Dataset):
    """Utterance-level emotion-recognition samples backed by pre-computed embeddings.

    Args:
        data: mapping with keys ``'id_1' .. 'id_N'``; each value provides
            ``'text'`` (target utterance), ``'emotion'`` (label name) and
            ``'context'`` (iterable of utterances).
        utterance2vec: mapping from ``preprocess_text(utterance)`` to its
            embedding vector.
        device: device every embedding tensor is moved to.
    """

    def __init__(self, data, utterance2vec, device):
        self.data = data
        self.utterance2vec = utterance2vec
        self.device = device

    def __len__(self):
        return len(self.data)

    def _embed(self, utterance):
        """Look up the embedding of ``utterance`` and return it as a tensor on ``self.device``."""
        return torch.tensor(self.utterance2vec[preprocess_text(utterance)]).to(self.device)

    def __getitem__(self, idx):
        """Build the padded model inputs for sample ``idx`` (0-based).

        Returns a dict with:
            context_embeddings_cat: (MAX_CONV_LEN, 2*d) - each context embedding
                concatenated with the target embedding.
            context_embeddings: (MAX_CONV_LEN, d) - the context embeddings alone.
            target_embedding: (2*d,) - the target embedding concatenated with itself.
            attention_mask: (MAX_CONV_LEN,) - 1 for real positions, 0 for padding.
            emotion: integer class id of the target utterance.
        where ``d`` is the length of the target embedding.
        """
        sample = self.data[f'id_{idx+1}']  # keys are 1-based
        text, emotion, context = sample['text'], sample['emotion'], sample['context']

        context_embeddings = [self._embed(utterance) for utterance in context]
        target_embedding = self._embed(text)
        context_embeddings_cat = [torch.cat((emb, target_embedding)) for emb in context_embeddings]

        # Keep at most the last MAX_CONV_LEN utterances (a no-op for shorter contexts).
        context_embeddings = context_embeddings[-MAX_CONV_LEN:]
        context_embeddings_cat = context_embeddings_cat[-MAX_CONV_LEN:]

        num_real = len(context_embeddings)
        num_pads = MAX_CONV_LEN - num_real
        attention_mask = [1]*num_real + [0]*num_pads

        if num_pads:
            # Size the padding from the actual embedding instead of hard-coding
            # 384/768, so other embedding widths work as well.
            pad = torch.zeros_like(target_embedding)
            pad_cat = torch.cat((pad, pad))
            context_embeddings = context_embeddings + [pad]*num_pads
            context_embeddings_cat = context_embeddings_cat + [pad_cat]*num_pads

        context_embeddings_cat = torch.stack(context_embeddings_cat)
        context_embeddings = torch.stack(context_embeddings)
        attention_mask = torch.tensor(attention_mask)
        target_embedding = torch.cat((target_embedding, target_embedding))

        return {
            'context_embeddings_cat': context_embeddings_cat,
            'context_embeddings': context_embeddings,
            'target_embedding': target_embedding,
            'attention_mask': attention_mask,
            'emotion': emotion2int[emotion]
        }

In [48]:
class BertForSentenceClassificationGivenContext(BertPreTrainedModel):
    """BERT encoder over a sequence of utterance embeddings, followed by a linear classifier.

    The encoder output at the last unmasked position of each sequence is
    concatenated with ``target_embeds`` and fed to the classification head.

    Args:
        config: ``BertConfig``; ``num_labels`` sets the number of output classes.
        weights: 1-D tensor of per-class weights for the cross-entropy loss
            (``None`` gives an unweighted loss).
    """

    def __init__(self, config, weights):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.weights = weights
        self.dropout = nn.Dropout(classifier_dropout)
        # Input is the encoder state (hidden_size) concatenated with target_embeds,
        # so target_embeds must also be hidden_size wide.
        self.classifier = nn.Linear(config.hidden_size*2, config.num_labels)

        self.post_init()

    def forward(self, context_embeds, target_embeds, attention_mask, labels=None):
        """Return ``{'loss': ..., 'logits': ...}`` for a batch.

        Args:
            context_embeds: (batch, seq, hidden_size) inputs passed to BERT as ``inputs_embeds``.
            target_embeds: (batch, hidden_size) vector appended to the encoder output.
            attention_mask: (batch, seq) mask, 1 for real positions and 0 for padding.
            labels: optional (batch,) class ids; when given, the weighted
                cross-entropy loss is computed, otherwise ``loss`` is ``None``.
        """
        output = self.bert(inputs_embeds=context_embeds, attention_mask=attention_mask)
        hidden = output.last_hidden_state

        # Pick the encoder state at the last unmasked position of every sequence
        # with one vectorised gather. An all-zero mask gives index -1, i.e. the
        # final position.
        last_index = (attention_mask.sum(dim=1).long() - 1).to(hidden.device)
        batch_index = torch.arange(hidden.shape[0], device=hidden.device)
        out_target = hidden[batch_index, last_index]

        out_target = self.dropout(out_target)
        out_target_cat = torch.cat((out_target, target_embeds), 1)
        logits = self.classifier(out_target_cat)

        loss = None
        if labels is not None:
            # Follow the logits' device instead of relying on a notebook-level global.
            loss_weights = self.weights.to(logits.device) if self.weights is not None else None
            loss_fct = nn.CrossEntropyLoss(weight=loss_weights)
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1).to(logits.device))

        return {'loss': loss, 'logits': logits}
        

In [49]:
# Wrap both splits in the utterance-level dataset, sharing one embedding table and device.
train_dataset, val_dataset = (
    ERC_Dataset_Utt_Level(split, utterance2vec, device)
    for split in (train_data, val_data)
)

# Only the training batches are shuffled; validation keeps a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False)

In [50]:
# Start from the pretrained bert-base-uncased weights with a 7-way
# classification head, and pass the class weights as float32 for the loss.
config = BertConfig.from_pretrained('bert-base-uncased', num_labels=7)
model = BertForSentenceClassificationGivenContext.from_pretrained(
    'bert-base-uncased',
    config=config,
    weights=torch.from_numpy(class_weights).float(),
)
model = model.to(device)

Some weights of BertForSentenceClassificationGivenContext were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [51]:
# Training schedule: number of passes over the training set, and AdamW over
# all model parameters.
epochs = 30
optimizer = AdamW(params=model.parameters(), lr=1e-6)

In [52]:
# Authenticate with Weights & Biases using a key stored in Kaggle's user
# secrets, so the key itself never appears in the notebook.
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
# Fetch the secret stored under the label "wandb_login_key".
secret_value_0 = user_secrets.get_secret("wandb_login_key")
wandb.login(key=secret_value_0)



True

In [53]:
# Start the W&B run. Optimizer name, learning rate and batch size are read from
# the live optimizer/loader so the logged config cannot drift from what is
# actually used.
wandb.init(project='TECPEC', name='BERT_with_target_cat_Utt_Level', config={
    'Embedding': 'Sentence-Transformer',
    'Level': 'Utterance Level',
    'Epochs': epochs,
    'Optimizer': type(optimizer).__name__,
    'Learning Rate': optimizer.param_groups[0]['lr'],
    'Batch Size': train_loader.batch_size,
})

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

In [54]:
def _run_epoch(loader, training):
    """Run one full pass over ``loader`` and return ``(mean_loss, predictions, targets)``.

    With ``training=True`` the model is put in train mode and the optimizer
    steps after every batch. Otherwise the model is put in eval mode and
    gradient tracking is disabled. The loss is averaged over batches.
    """
    model.train(training)
    predictions, targets, running_loss = [], [], 0.0
    with torch.set_grad_enabled(training):
        for batch in tqdm(loader):
            if training:
                optimizer.zero_grad()
            emotions = batch['emotion'].to(device)
            outputs = model(
                context_embeds=batch['context_embeddings_cat'].to(device),
                target_embeds=batch['target_embedding'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                labels=emotions,
            )
            loss = outputs['loss']
            if training:
                loss.backward()
                optimizer.step()
            predictions.extend(torch.argmax(outputs['logits'], 1).tolist())
            targets.extend(emotions.tolist())
            running_loss += loss.item()
    return running_loss / len(loader), predictions, targets


def _emotion_reports(true, pred):
    """Return the per-emotion classification report as ``(text, dict)``.

    ``labels`` is pinned to every known class id so the report keeps one row
    per emotion even when a class is missing from both ``true`` and ``pred``.
    """
    report_kwargs = dict(
        labels=list(emotion2int.values()),
        target_names=list(emotion2int.keys()),
        zero_division=0,
    )
    text = classification_report(true, pred, **report_kwargs)
    as_dict = classification_report(true, pred, output_dict=True, **report_kwargs)
    return text, as_dict


def _accuracy(true, pred):
    """Fraction of exact matches between ``true`` and ``pred`` (0.0 for empty input)."""
    if not true:
        return 0.0
    return sum(t == p for t, p in zip(true, pred)) / len(true)


for epoch in range(epochs):
    train_loss, train_pred, train_true = _run_epoch(train_loader, training=True)
    val_loss, val_pred, val_true = _run_epoch(val_loader, training=False)

    train_report, train_report_dict = _emotion_reports(train_true, train_pred)
    val_report, val_report_dict = _emotion_reports(val_true, val_pred)

    wandb.log({
        'train_loss': train_loss,
        'val_loss': val_loss,
        # Computed directly so it does not depend on the report exposing an 'accuracy' entry.
        'train_accuracy': _accuracy(train_true, train_pred),
        'val_accuracy': _accuracy(val_true, val_pred),
        'Macro train_f1': train_report_dict['macro avg']['f1-score'],
        'Macro val_f1': val_report_dict['macro avg']['f1-score'],
        'Weighted train_f1': train_report_dict['weighted avg']['f1-score'],
        'Weighted val_f1': val_report_dict['weighted avg']['f1-score'],
    })
    print(f"Epoch: {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}")
    print(f"Train Report: \n{train_report}")
    print(f"Val Report: \n{val_report}")


100%|██████████| 759/759 [01:38<00:00,  7.68it/s]
100%|██████████| 93/93 [00:04<00:00, 22.51it/s]


Epoch: 1, Train Loss: 1.957941618204431, Val Loss: 1.9348950578320412
Train Report: 
              precision    recall  f1-score   support

       anger       0.13      0.13      0.13      1423
         joy       0.17      0.20      0.18      2047
        fear       0.02      0.03      0.03       336
     disgust       0.03      0.14      0.05       372
     neutral       0.43      0.21      0.28      5299
    surprise       0.14      0.16      0.15      1656
     sadness       0.10      0.17      0.13      1011

    accuracy                           0.18     12144
   macro avg       0.15      0.15      0.14     12144
weighted avg       0.26      0.18      0.20     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.15      0.17      0.16       192
         joy       0.21      0.27      0.23       254
        fear       0.00      0.00      0.00        37
     disgust       0.05      0.10      0.07        42
     neutral       0.41      0.21 

100%|██████████| 759/759 [01:38<00:00,  7.68it/s]
100%|██████████| 93/93 [00:04<00:00, 22.36it/s]


Epoch: 2, Train Loss: 1.937777233375077, Val Loss: 1.9203401675788305
Train Report: 
              precision    recall  f1-score   support

       anger       0.14      0.16      0.15      1423
         joy       0.20      0.27      0.23      2047
        fear       0.04      0.04      0.04       336
     disgust       0.04      0.06      0.04       372
     neutral       0.46      0.26      0.33      5299
    surprise       0.17      0.23      0.20      1656
     sadness       0.11      0.16      0.13      1011

    accuracy                           0.23     12144
   macro avg       0.17      0.17      0.16     12144
weighted avg       0.29      0.23      0.24     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.18      0.16      0.17       192
         joy       0.25      0.29      0.27       254
        fear       0.06      0.03      0.04        37
     disgust       0.04      0.10      0.06        42
     neutral       0.46      0.27 

100%|██████████| 759/759 [01:38<00:00,  7.69it/s]
100%|██████████| 93/93 [00:04<00:00, 22.44it/s]


Epoch: 3, Train Loss: 1.9172747154009673, Val Loss: 1.900863077050896
Train Report: 
              precision    recall  f1-score   support

       anger       0.17      0.18      0.17      1423
         joy       0.21      0.29      0.25      2047
        fear       0.05      0.07      0.05       336
     disgust       0.04      0.07      0.05       372
     neutral       0.48      0.26      0.34      5299
    surprise       0.22      0.29      0.25      1656
     sadness       0.13      0.21      0.16      1011

    accuracy                           0.24     12144
   macro avg       0.18      0.19      0.18     12144
weighted avg       0.31      0.24      0.26     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.20      0.18      0.19       192
         joy       0.29      0.33      0.31       254
        fear       0.07      0.05      0.06        37
     disgust       0.05      0.14      0.07        42
     neutral       0.52      0.20 

100%|██████████| 759/759 [01:38<00:00,  7.67it/s]
100%|██████████| 93/93 [00:04<00:00, 22.40it/s]


Epoch: 4, Train Loss: 1.8896190513892293, Val Loss: 1.876508605095648
Train Report: 
              precision    recall  f1-score   support

       anger       0.17      0.20      0.18      1423
         joy       0.24      0.30      0.27      2047
        fear       0.09      0.12      0.10       336
     disgust       0.05      0.11      0.07       372
     neutral       0.48      0.23      0.31      5299
    surprise       0.24      0.34      0.28      1656
     sadness       0.14      0.24      0.17      1011

    accuracy                           0.25     12144
   macro avg       0.20      0.22      0.20     12144
weighted avg       0.32      0.25      0.26     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.21      0.20      0.21       192
         joy       0.30      0.46      0.36       254
        fear       0.06      0.22      0.09        37
     disgust       0.06      0.05      0.05        42
     neutral       0.47      0.18 

100%|██████████| 759/759 [01:38<00:00,  7.67it/s]
100%|██████████| 93/93 [00:04<00:00, 22.39it/s]


Epoch: 5, Train Loss: 1.8559317072075503, Val Loss: 1.8587350165972145
Train Report: 
              precision    recall  f1-score   support

       anger       0.20      0.22      0.21      1423
         joy       0.25      0.38      0.31      2047
        fear       0.07      0.17      0.10       336
     disgust       0.07      0.11      0.08       372
     neutral       0.50      0.22      0.30      5299
    surprise       0.27      0.36      0.31      1656
     sadness       0.17      0.26      0.21      1011

    accuracy                           0.26     12144
   macro avg       0.22      0.25      0.22     12144
weighted avg       0.34      0.26      0.27     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.19      0.19      0.19       192
         joy       0.33      0.41      0.37       254
        fear       0.04      0.24      0.06        37
     disgust       0.10      0.10      0.10        42
     neutral       0.51      0.17

100%|██████████| 759/759 [01:38<00:00,  7.68it/s]
100%|██████████| 93/93 [00:04<00:00, 22.44it/s]


Epoch: 6, Train Loss: 1.8110860008181948, Val Loss: 1.845969125788699
Train Report: 
              precision    recall  f1-score   support

       anger       0.22      0.24      0.23      1423
         joy       0.28      0.38      0.32      2047
        fear       0.10      0.26      0.14       336
     disgust       0.10      0.20      0.14       372
     neutral       0.52      0.24      0.33      5299
    surprise       0.30      0.37      0.33      1656
     sadness       0.17      0.27      0.21      1011

    accuracy                           0.28     12144
   macro avg       0.24      0.28      0.24     12144
weighted avg       0.36      0.28      0.30     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.19      0.12      0.15       192
         joy       0.29      0.58      0.38       254
        fear       0.03      0.27      0.06        37
     disgust       0.05      0.07      0.06        42
     neutral       0.53      0.08 

100%|██████████| 759/759 [01:38<00:00,  7.67it/s]
100%|██████████| 93/93 [00:04<00:00, 22.44it/s]


Epoch: 7, Train Loss: 1.7771239748742427, Val Loss: 1.8222479909978888
Train Report: 
              precision    recall  f1-score   support

       anger       0.22      0.21      0.21      1423
         joy       0.28      0.41      0.33      2047
        fear       0.09      0.27      0.13       336
     disgust       0.11      0.23      0.15       372
     neutral       0.53      0.24      0.33      5299
    surprise       0.32      0.39      0.35      1656
     sadness       0.20      0.31      0.24      1011

    accuracy                           0.29     12144
   macro avg       0.25      0.29      0.25     12144
weighted avg       0.37      0.29      0.30     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.17      0.16      0.16       192
         joy       0.36      0.44      0.39       254
        fear       0.07      0.16      0.10        37
     disgust       0.05      0.24      0.09        42
     neutral       0.50      0.22

100%|██████████| 759/759 [01:38<00:00,  7.68it/s]
100%|██████████| 93/93 [00:04<00:00, 22.45it/s]


Epoch: 8, Train Loss: 1.7502471481858506, Val Loss: 1.8187291006888113
Train Report: 
              precision    recall  f1-score   support

       anger       0.23      0.22      0.23      1423
         joy       0.29      0.43      0.35      2047
        fear       0.12      0.34      0.18       336
     disgust       0.11      0.30      0.16       372
     neutral       0.55      0.25      0.34      5299
    surprise       0.33      0.38      0.35      1656
     sadness       0.21      0.31      0.25      1011

    accuracy                           0.30     12144
   macro avg       0.26      0.32      0.26     12144
weighted avg       0.38      0.30      0.31     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.25      0.11      0.15       192
         joy       0.33      0.48      0.39       254
        fear       0.04      0.22      0.07        37
     disgust       0.07      0.21      0.11        42
     neutral       0.54      0.16

100%|██████████| 759/759 [01:38<00:00,  7.67it/s]
100%|██████████| 93/93 [00:04<00:00, 22.54it/s]


Epoch: 9, Train Loss: 1.7147922094945693, Val Loss: 1.8037936392650809
Train Report: 
              precision    recall  f1-score   support

       anger       0.25      0.22      0.24      1423
         joy       0.30      0.44      0.35      2047
        fear       0.11      0.36      0.17       336
     disgust       0.14      0.34      0.20       372
     neutral       0.56      0.25      0.35      5299
    surprise       0.35      0.41      0.37      1656
     sadness       0.21      0.33      0.26      1011

    accuracy                           0.31     12144
   macro avg       0.27      0.33      0.28     12144
weighted avg       0.40      0.31      0.32     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.21      0.12      0.16       192
         joy       0.35      0.49      0.41       254
        fear       0.06      0.22      0.10        37
     disgust       0.06      0.21      0.10        42
     neutral       0.52      0.23

100%|██████████| 759/759 [01:38<00:00,  7.68it/s]
100%|██████████| 93/93 [00:04<00:00, 22.15it/s]


Epoch: 10, Train Loss: 1.6761113507317302, Val Loss: 1.81237788995107
Train Report: 
              precision    recall  f1-score   support

       anger       0.26      0.23      0.25      1423
         joy       0.30      0.43      0.36      2047
        fear       0.13      0.40      0.20       336
     disgust       0.13      0.38      0.19       372
     neutral       0.57      0.26      0.36      5299
    surprise       0.36      0.39      0.37      1656
     sadness       0.23      0.38      0.29      1011

    accuracy                           0.32     12144
   macro avg       0.28      0.35      0.29     12144
weighted avg       0.41      0.32      0.33     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.21      0.19      0.20       192
         joy       0.40      0.46      0.43       254
        fear       0.05      0.24      0.08        37
     disgust       0.07      0.17      0.10        42
     neutral       0.52      0.29 

100%|██████████| 759/759 [01:38<00:00,  7.67it/s]
100%|██████████| 93/93 [00:04<00:00, 22.47it/s]


Epoch: 11, Train Loss: 1.656958412904199, Val Loss: 1.8165773922397244
Train Report: 
              precision    recall  f1-score   support

       anger       0.27      0.24      0.25      1423
         joy       0.31      0.43      0.36      2047
        fear       0.13      0.45      0.21       336
     disgust       0.15      0.40      0.21       372
     neutral       0.56      0.26      0.36      5299
    surprise       0.37      0.40      0.38      1656
     sadness       0.25      0.39      0.30      1011

    accuracy                           0.33     12144
   macro avg       0.29      0.37      0.30     12144
weighted avg       0.41      0.33      0.33     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.22      0.18      0.20       192
         joy       0.36      0.49      0.41       254
        fear       0.07      0.24      0.11        37
     disgust       0.07      0.19      0.10        42
     neutral       0.54      0.30

100%|██████████| 759/759 [01:38<00:00,  7.68it/s]
100%|██████████| 93/93 [00:04<00:00, 22.49it/s]


Epoch: 12, Train Loss: 1.6279909617816035, Val Loss: 1.8165957594430575
Train Report: 
              precision    recall  f1-score   support

       anger       0.28      0.26      0.27      1423
         joy       0.31      0.44      0.37      2047
        fear       0.14      0.45      0.21       336
     disgust       0.15      0.43      0.22       372
     neutral       0.57      0.27      0.36      5299
    surprise       0.38      0.41      0.39      1656
     sadness       0.24      0.37      0.29      1011

    accuracy                           0.33     12144
   macro avg       0.30      0.37      0.30     12144
weighted avg       0.42      0.33      0.34     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.18      0.21      0.19       192
         joy       0.39      0.44      0.41       254
        fear       0.08      0.22      0.11        37
     disgust       0.06      0.19      0.09        42
     neutral       0.52      0.2

100%|██████████| 759/759 [01:38<00:00,  7.68it/s]
100%|██████████| 93/93 [00:04<00:00, 22.41it/s]


Epoch: 13, Train Loss: 1.6018821797352063, Val Loss: 1.81428543021602
Train Report: 
              precision    recall  f1-score   support

       anger       0.27      0.26      0.27      1423
         joy       0.34      0.44      0.39      2047
        fear       0.15      0.49      0.23       336
     disgust       0.15      0.44      0.23       372
     neutral       0.58      0.29      0.39      5299
    surprise       0.38      0.42      0.40      1656
     sadness       0.26      0.39      0.31      1011

    accuracy                           0.35     12144
   macro avg       0.31      0.39      0.32     12144
weighted avg       0.43      0.35      0.36     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.24      0.19      0.21       192
         joy       0.38      0.45      0.41       254
        fear       0.04      0.24      0.08        37
     disgust       0.07      0.19      0.11        42
     neutral       0.54      0.26 

100%|██████████| 759/759 [01:38<00:00,  7.69it/s]
100%|██████████| 93/93 [00:04<00:00, 22.27it/s]


Epoch: 14, Train Loss: 1.5755595358307026, Val Loss: 1.828031211770991
Train Report: 
              precision    recall  f1-score   support

       anger       0.29      0.27      0.28      1423
         joy       0.34      0.46      0.39      2047
        fear       0.16      0.50      0.25       336
     disgust       0.16      0.48      0.24       372
     neutral       0.58      0.27      0.37      5299
    surprise       0.38      0.41      0.40      1656
     sadness       0.26      0.39      0.31      1011

    accuracy                           0.35     12144
   macro avg       0.31      0.40      0.32     12144
weighted avg       0.42      0.35      0.36     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.23      0.20      0.22       192
         joy       0.42      0.41      0.41       254
        fear       0.05      0.22      0.08        37
     disgust       0.10      0.17      0.13        42
     neutral       0.53      0.30

100%|██████████| 759/759 [01:38<00:00,  7.68it/s]
100%|██████████| 93/93 [00:04<00:00, 22.41it/s]


Epoch: 15, Train Loss: 1.5368017160845369, Val Loss: 1.8415947018131134
Train Report: 
              precision    recall  f1-score   support

       anger       0.30      0.29      0.29      1423
         joy       0.33      0.46      0.38      2047
        fear       0.17      0.54      0.25       336
     disgust       0.17      0.49      0.26       372
     neutral       0.59      0.29      0.39      5299
    surprise       0.40      0.43      0.41      1656
     sadness       0.28      0.42      0.34      1011

    accuracy                           0.36     12144
   macro avg       0.32      0.41      0.33     12144
weighted avg       0.44      0.36      0.37     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.21      0.20      0.20       192
         joy       0.42      0.42      0.42       254
        fear       0.04      0.22      0.07        37
     disgust       0.07      0.19      0.11        42
     neutral       0.54      0.2

100%|██████████| 759/759 [01:38<00:00,  7.68it/s]
100%|██████████| 93/93 [00:04<00:00, 22.57it/s]


Epoch: 16, Train Loss: 1.5142213627597874, Val Loss: 1.8583620793075972
Train Report: 
              precision    recall  f1-score   support

       anger       0.31      0.30      0.30      1423
         joy       0.34      0.47      0.40      2047
        fear       0.18      0.53      0.27       336
     disgust       0.20      0.55      0.29       372
     neutral       0.60      0.29      0.39      5299
    surprise       0.41      0.43      0.42      1656
     sadness       0.28      0.45      0.35      1011

    accuracy                           0.37     12144
   macro avg       0.33      0.43      0.35     12144
weighted avg       0.45      0.37      0.38     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.19      0.20      0.20       192
         joy       0.43      0.39      0.41       254
        fear       0.04      0.22      0.07        37
     disgust       0.08      0.24      0.11        42
     neutral       0.52      0.2

100%|██████████| 759/759 [01:38<00:00,  7.68it/s]
100%|██████████| 93/93 [00:04<00:00, 22.57it/s]


Epoch: 17, Train Loss: 1.4833797063280942, Val Loss: 1.8595144370550751
Train Report: 
              precision    recall  f1-score   support

       anger       0.33      0.32      0.32      1423
         joy       0.34      0.49      0.40      2047
        fear       0.20      0.60      0.30       336
     disgust       0.19      0.57      0.29       372
     neutral       0.62      0.29      0.40      5299
    surprise       0.41      0.44      0.42      1656
     sadness       0.30      0.44      0.36      1011

    accuracy                           0.38     12144
   macro avg       0.34      0.45      0.36     12144
weighted avg       0.46      0.38      0.38     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.21      0.19      0.20       192
         joy       0.41      0.42      0.42       254
        fear       0.06      0.19      0.09        37
     disgust       0.08      0.21      0.11        42
     neutral       0.53      0.3

100%|██████████| 759/759 [01:38<00:00,  7.69it/s]
100%|██████████| 93/93 [00:04<00:00, 22.27it/s]


Epoch: 18, Train Loss: 1.4636637936781831, Val Loss: 1.866604846010926
Train Report: 
              precision    recall  f1-score   support

       anger       0.33      0.33      0.33      1423
         joy       0.35      0.49      0.41      2047
        fear       0.20      0.59      0.30       336
     disgust       0.21      0.59      0.31       372
     neutral       0.60      0.30      0.40      5299
    surprise       0.42      0.45      0.43      1656
     sadness       0.31      0.43      0.36      1011

    accuracy                           0.38     12144
   macro avg       0.35      0.45      0.36     12144
weighted avg       0.45      0.38      0.39     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.19      0.21      0.20       192
         joy       0.38      0.46      0.42       254
        fear       0.06      0.19      0.09        37
     disgust       0.07      0.21      0.11        42
     neutral       0.55      0.32

100%|██████████| 759/759 [01:38<00:00,  7.68it/s]
100%|██████████| 93/93 [00:04<00:00, 22.59it/s]


Epoch: 19, Train Loss: 1.4521102037511482, Val Loss: 1.8778959858802058
Train Report: 
              precision    recall  f1-score   support

       anger       0.33      0.34      0.34      1423
         joy       0.36      0.50      0.42      2047
        fear       0.21      0.61      0.32       336
     disgust       0.21      0.58      0.31       372
     neutral       0.61      0.30      0.40      5299
    surprise       0.43      0.45      0.43      1656
     sadness       0.31      0.46      0.37      1011

    accuracy                           0.39     12144
   macro avg       0.35      0.46      0.37     12144
weighted avg       0.46      0.39      0.39     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.19      0.24      0.21       192
         joy       0.41      0.41      0.41       254
        fear       0.07      0.16      0.09        37
     disgust       0.07      0.17      0.10        42
     neutral       0.54      0.3

100%|██████████| 759/759 [01:38<00:00,  7.68it/s]
100%|██████████| 93/93 [00:04<00:00, 22.28it/s]


Epoch: 20, Train Loss: 1.4124457596633118, Val Loss: 1.8939954285980554
Train Report: 
              precision    recall  f1-score   support

       anger       0.35      0.36      0.35      1423
         joy       0.38      0.49      0.43      2047
        fear       0.23      0.62      0.34       336
     disgust       0.23      0.60      0.33       372
     neutral       0.61      0.34      0.43      5299
    surprise       0.43      0.44      0.43      1656
     sadness       0.32      0.47      0.38      1011

    accuracy                           0.41     12144
   macro avg       0.36      0.48      0.38     12144
weighted avg       0.47      0.41      0.41     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.18      0.19      0.19       192
         joy       0.39      0.46      0.42       254
        fear       0.06      0.14      0.08        37
     disgust       0.08      0.21      0.11        42
     neutral       0.54      0.3

100%|██████████| 759/759 [01:38<00:00,  7.69it/s]
100%|██████████| 93/93 [00:04<00:00, 22.50it/s]


Epoch: 21, Train Loss: 1.3934991623573152, Val Loss: 1.9172395666440327
Train Report: 
              precision    recall  f1-score   support

       anger       0.35      0.37      0.36      1423
         joy       0.37      0.51      0.43      2047
        fear       0.23      0.64      0.33       336
     disgust       0.24      0.65      0.35       372
     neutral       0.63      0.32      0.42      5299
    surprise       0.43      0.45      0.44      1656
     sadness       0.33      0.47      0.39      1011

    accuracy                           0.40     12144
   macro avg       0.37      0.49      0.39     12144
weighted avg       0.48      0.40      0.41     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.18      0.16      0.17       192
         joy       0.39      0.45      0.42       254
        fear       0.05      0.24      0.08        37
     disgust       0.09      0.19      0.12        42
     neutral       0.54      0.3

100%|██████████| 759/759 [01:38<00:00,  7.68it/s]
100%|██████████| 93/93 [00:04<00:00, 22.36it/s]


Epoch: 22, Train Loss: 1.3704370572004707, Val Loss: 1.9045645363869206
Train Report: 
              precision    recall  f1-score   support

       anger       0.35      0.37      0.36      1423
         joy       0.36      0.50      0.42      2047
        fear       0.25      0.67      0.36       336
     disgust       0.24      0.64      0.35       372
     neutral       0.64      0.32      0.43      5299
    surprise       0.43      0.47      0.45      1656
     sadness       0.33      0.48      0.39      1011

    accuracy                           0.41     12144
   macro avg       0.37      0.49      0.39     12144
weighted avg       0.48      0.41      0.41     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.22      0.19      0.20       192
         joy       0.40      0.46      0.43       254
        fear       0.06      0.22      0.09        37
     disgust       0.09      0.19      0.12        42
     neutral       0.54      0.3

100%|██████████| 759/759 [01:38<00:00,  7.68it/s]
100%|██████████| 93/93 [00:04<00:00, 22.54it/s]


Epoch: 23, Train Loss: 1.3591339474296067, Val Loss: 1.9241803307687082
Train Report: 
              precision    recall  f1-score   support

       anger       0.37      0.40      0.38      1423
         joy       0.39      0.48      0.43      2047
        fear       0.25      0.66      0.36       336
     disgust       0.24      0.64      0.35       372
     neutral       0.62      0.35      0.45      5299
    surprise       0.45      0.46      0.46      1656
     sadness       0.34      0.50      0.41      1011

    accuracy                           0.42     12144
   macro avg       0.38      0.50      0.41     12144
weighted avg       0.48      0.42      0.43     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.20      0.26      0.22       192
         joy       0.38      0.48      0.42       254
        fear       0.06      0.24      0.09        37
     disgust       0.10      0.19      0.13        42
     neutral       0.54      0.2

100%|██████████| 759/759 [01:38<00:00,  7.68it/s]
100%|██████████| 93/93 [00:04<00:00, 22.35it/s]


Epoch: 24, Train Loss: 1.3313359527091577, Val Loss: 1.9471025543828164
Train Report: 
              precision    recall  f1-score   support

       anger       0.35      0.40      0.37      1423
         joy       0.38      0.52      0.44      2047
        fear       0.26      0.68      0.38       336
     disgust       0.26      0.65      0.37       372
     neutral       0.63      0.33      0.43      5299
    surprise       0.46      0.47      0.46      1656
     sadness       0.34      0.50      0.41      1011

    accuracy                           0.42     12144
   macro avg       0.38      0.51      0.41     12144
weighted avg       0.49      0.42      0.43     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.20      0.21      0.20       192
         joy       0.40      0.46      0.42       254
        fear       0.06      0.24      0.09        37
     disgust       0.09      0.17      0.11        42
     neutral       0.54      0.3

100%|██████████| 759/759 [01:38<00:00,  7.67it/s]
100%|██████████| 93/93 [00:04<00:00, 22.43it/s]


Epoch: 25, Train Loss: 1.3014808917548188, Val Loss: 1.9583943652850326
Train Report: 
              precision    recall  f1-score   support

       anger       0.37      0.42      0.39      1423
         joy       0.37      0.55      0.45      2047
        fear       0.28      0.73      0.41       336
     disgust       0.28      0.67      0.39       372
     neutral       0.63      0.30      0.41      5299
    surprise       0.45      0.46      0.46      1656
     sadness       0.34      0.50      0.41      1011

    accuracy                           0.42     12144
   macro avg       0.39      0.52      0.42     12144
weighted avg       0.49      0.42      0.42     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.20      0.19      0.20       192
         joy       0.41      0.47      0.44       254
        fear       0.06      0.19      0.09        37
     disgust       0.09      0.19      0.12        42
     neutral       0.54      0.3

100%|██████████| 759/759 [01:38<00:00,  7.69it/s]
100%|██████████| 93/93 [00:04<00:00, 22.56it/s]


Epoch: 26, Train Loss: 1.2727802123319805, Val Loss: 1.9795721762923784
Train Report: 
              precision    recall  f1-score   support

       anger       0.39      0.43      0.41      1423
         joy       0.39      0.52      0.45      2047
        fear       0.31      0.76      0.44       336
     disgust       0.29      0.69      0.41       372
     neutral       0.64      0.35      0.45      5299
    surprise       0.47      0.49      0.48      1656
     sadness       0.35      0.52      0.42      1011

    accuracy                           0.44     12144
   macro avg       0.41      0.54      0.44     12144
weighted avg       0.50      0.44      0.45     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.21      0.21      0.21       192
         joy       0.45      0.41      0.43       254
        fear       0.06      0.24      0.09        37
     disgust       0.09      0.17      0.12        42
     neutral       0.53      0.3

100%|██████████| 759/759 [01:38<00:00,  7.69it/s]
100%|██████████| 93/93 [00:04<00:00, 22.27it/s]


Epoch: 27, Train Loss: 1.2688150051711262, Val Loss: 1.9913663229634684
Train Report: 
              precision    recall  f1-score   support

       anger       0.38      0.43      0.41      1423
         joy       0.39      0.51      0.44      2047
        fear       0.30      0.74      0.43       336
     disgust       0.28      0.70      0.40       372
     neutral       0.63      0.34      0.44      5299
    surprise       0.46      0.47      0.47      1656
     sadness       0.36      0.53      0.43      1011

    accuracy                           0.44     12144
   macro avg       0.40      0.53      0.43     12144
weighted avg       0.49      0.44      0.44     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.21      0.24      0.23       192
         joy       0.43      0.45      0.44       254
        fear       0.06      0.22      0.10        37
     disgust       0.10      0.17      0.13        42
     neutral       0.54      0.3

100%|██████████| 759/759 [01:38<00:00,  7.68it/s]
100%|██████████| 93/93 [00:04<00:00, 22.45it/s]


Epoch: 28, Train Loss: 1.2434481361952068, Val Loss: 1.9963045639376487
Train Report: 
              precision    recall  f1-score   support

       anger       0.39      0.45      0.42      1423
         joy       0.38      0.52      0.44      2047
        fear       0.31      0.76      0.44       336
     disgust       0.29      0.72      0.42       372
     neutral       0.64      0.34      0.44      5299
    surprise       0.48      0.51      0.49      1656
     sadness       0.38      0.55      0.45      1011

    accuracy                           0.45     12144
   macro avg       0.41      0.55      0.44     12144
weighted avg       0.51      0.45      0.45     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.21      0.24      0.22       192
         joy       0.39      0.47      0.43       254
        fear       0.04      0.14      0.06        37
     disgust       0.09      0.21      0.12        42
     neutral       0.55      0.3

100%|██████████| 759/759 [01:38<00:00,  7.67it/s]
100%|██████████| 93/93 [00:04<00:00, 22.64it/s]


Epoch: 29, Train Loss: 1.2056947810298055, Val Loss: 2.025422869190093
Train Report: 
              precision    recall  f1-score   support

       anger       0.41      0.47      0.44      1423
         joy       0.41      0.53      0.46      2047
        fear       0.34      0.77      0.47       336
     disgust       0.35      0.75      0.47       372
     neutral       0.65      0.36      0.46      5299
    surprise       0.47      0.52      0.49      1656
     sadness       0.36      0.54      0.43      1011

    accuracy                           0.46     12144
   macro avg       0.43      0.56      0.46     12144
weighted avg       0.51      0.46      0.46     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.21      0.27      0.24       192
         joy       0.38      0.49      0.43       254
        fear       0.05      0.16      0.08        37
     disgust       0.10      0.14      0.12        42
     neutral       0.57      0.29

100%|██████████| 759/759 [01:38<00:00,  7.67it/s]
100%|██████████| 93/93 [00:04<00:00, 22.47it/s]

Epoch: 30, Train Loss: 1.2017905665402042, Val Loss: 2.0461592366618495
Train Report: 
              precision    recall  f1-score   support

       anger       0.40      0.49      0.44      1423
         joy       0.39      0.56      0.46      2047
        fear       0.33      0.77      0.47       336
     disgust       0.33      0.73      0.46       372
     neutral       0.66      0.33      0.44      5299
    surprise       0.49      0.51      0.50      1656
     sadness       0.39      0.55      0.45      1011

    accuracy                           0.46     12144
   macro avg       0.43      0.57      0.46     12144
weighted avg       0.52      0.46      0.46     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.20      0.26      0.23       192
         joy       0.45      0.41      0.43       254
        fear       0.05      0.16      0.08        37
     disgust       0.09      0.14      0.11        42
     neutral       0.55      0.3




In [55]:
# End the current Weights & Biases run. The run history / run summary tables
# displayed under this cell are emitted as part of that shutdown.
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.059 MB uploaded\r'), FloatProgress(value=0.02339660211324148, max=1.…

0,1
Macro train_f1,▁▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
Macro val_f1,▁▃▄▅▄▃▆▅▆▆▇▇▆▇▇▆▇▇▇▇▇▇▇▇█▇█▇▇▇
Weighted train_f1,▁▂▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇█▇███
Weighted val_f1,▁▃▃▃▃▁▅▃▅▆▇▆▆▇▆▆▇▇▇▇▇▇▆▇▇▇█▇▇█
train_accuracy,▁▂▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇█▇███
train_loss,███▇▇▇▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▁▁▁
val_accuracy,▁▃▃▄▃▂▅▄▅▆▇▆▅▆▆▆▇▇▇▇▆▇▆▇▇▇█▇▇▇
val_loss,▅▄▄▃▃▂▂▁▁▁▁▁▁▂▂▃▃▃▃▄▄▄▄▅▅▆▆▇▇█

0,1
Macro train_f1,0.45998
Macro val_f1,0.26058
Weighted train_f1,0.4553
Weighted val_f1,0.3512
train_accuracy,0.45693
train_loss,1.20179
val_accuracy,0.32814
val_loss,2.04616


In [56]:
# Write `model` to the Kaggle working directory so it outlives the session.
# NOTE(review): the whole `model` object is serialized here (not a state_dict),
# so reloading presumably requires the same model class to be importable --
# confirm before reusing this checkpoint outside this notebook.
torch.save(
    obj=model,
    f='/kaggle/working/BERT_add_Utt_Level.pth',
)