In [16]:
import json
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from sklearn.metrics import classification_report
import wandb
from transformers import RobertaPreTrainedModel, RobertaModel, RobertaConfig
import pickle
from sklearn.utils.class_weight import compute_class_weight
import numpy as np

In [17]:
# Source: https://github.com/LCS2-IIITD/Emotion-Flip-Reasoning/blob/main/Dataloaders/nlp_utils.py
import string
import nltk
import re

# Digit character -> English word, listed in ascending digit order.
# number_to_words walks this table to spell out every digit in a string.
numbers = dict(zip(
    "0123456789",
    ("zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"),
))

def remove_puntuations(txt):
    """Strip ASCII punctuation from ``txt``.

    The sentence-level marks ``. ! ? : ;`` are each replaced by a single space
    so the words on either side stay separated; every other character in
    ``string.punctuation`` is deleted outright. Non-punctuation characters
    (including existing whitespace) are left as they are.
    """
    # Pass 1: separators become spaces.
    spaced = txt.translate(str.maketrans({mark: " " for mark in ".!?:;"}))
    # Pass 2: drop whatever punctuation is left.
    return spaced.translate(str.maketrans("", "", string.punctuation))

def number_to_words(txt):
    """Spell out every digit in ``txt`` using the module-level ``numbers`` table.

    Each digit is replaced by its word followed by one space, so adjacent
    digits come out as separate words (and a trailing space is left behind).
    """
    for digit, word in numbers.items():
        txt = txt.replace(digit, f"{word} ")
    return txt

def preprocess_text(text):
    """Normalise an utterance into the canonical form used as an embedding key.

    Steps, in order: lowercase, turn underscores into spaces, spell out digits
    (``number_to_words``), strip punctuation (``remove_puntuations``), drop
    every non-ASCII character, and collapse runs of whitespace to one space.
    """
    lowered = text.lower().replace("_", " ")
    cleaned = remove_puntuations(number_to_words(lowered))
    ascii_only = "".join(filter(lambda ch: ord(ch) < 128, cleaned))
    # split()/join also trims leading and trailing whitespace.
    return " ".join(ascii_only.split())

In [18]:
# Load the utterance-level train / validation splits.
# Context managers close the file handles as soon as parsing finishes instead
# of leaving them open for the lifetime of the kernel.
with open('/kaggle/input/Dataset/ERC_utterance_level/train_utterance_level.json') as train_file:
    train_data = json.load(train_file)
with open('/kaggle/input/Dataset/ERC_utterance_level/val_utterance_level.json') as val_file:
    val_data = json.load(val_file)

In [19]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [20]:
# Emotion label -> integer class id. The position in this tuple is the id, so
# the order must not change: it is used both to encode dataset labels and to
# name the rows of the classification reports.
emotion2int = {
    label: class_id
    for class_id, label in enumerate(
        ("anger", "joy", "fear", "disgust", "neutral", "surprise", "sadness")
    )
}

In [None]:
# Balanced per-class weights for the loss, computed from the training labels.
#
# train_data is the utterance-level split: a dict keyed 'id_<n>' whose values
# carry an 'emotion' field (the same layout ERC_Dataset_Utt_Level indexes).
# Labels are therefore read straight from the records; iterating it as a list
# of conversations would yield the string keys and fail.
train_labels = np.array([emotion2int[sample['emotion']] for sample in train_data.values()])
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.array(list(emotion2int.values())),
    y=train_labels,
)

In [21]:
# Precomputed sentence embeddings, looked up by preprocess_text(utterance).
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load this artifact from a source you trust.
with open('/kaggle/input/Dataset/Embeddings/sentence_transformer_utterance2vec_786.pkl', 'rb') as embeddings_file:
    utterance2vec = pickle.load(embeddings_file)

In [22]:
MAX_CONV_LEN = 35  # number of context slots every sample is padded / truncated to


class ERC_Dataset_Utt_Level(Dataset):
    """Utterance-level emotion dataset backed by precomputed embeddings.

    Args:
        data: mapping keyed ``'id_1' .. 'id_N'``; each record provides
            ``'text'`` (target utterance), ``'emotion'`` (a key of
            ``emotion2int``) and ``'context'`` (list of preceding utterances).
        utterance2vec: mapping from ``preprocess_text(utterance)`` to a 1-D
            embedding vector.
        device: device the embedding tensors are moved to.

    Each item is a dict with:
        ``context_embeddings_cat``: ``(MAX_CONV_LEN, 2 * d)`` - every context
            embedding concatenated with the target embedding.
        ``context_embeddings``: ``(MAX_CONV_LEN, d)`` - context embeddings alone.
        ``target_embedding``: ``(d,)``.
        ``attention_mask``: ``(MAX_CONV_LEN,)`` - 1 for real context, 0 for padding.
        ``emotion``: integer class id.
    where ``d`` is the embedding width found in ``utterance2vec``.
    """

    def __init__(self, data, utterance2vec, device):
        self.data = data
        self.utterance2vec = utterance2vec
        self.device = device

    def __len__(self):
        return len(self.data)

    def _embed(self, utterance):
        """Return the stored embedding of ``utterance`` as a tensor on ``self.device``."""
        return torch.tensor(self.utterance2vec[preprocess_text(utterance)]).to(self.device)

    def __getitem__(self, idx):
        # Records are keyed with 1-based ids.
        sample = self.data[f'id_{idx+1}']

        target_embedding = self._embed(sample['text'])
        context_embeddings = [self._embed(utterance) for utterance in sample['context']]
        context_embeddings_cat = [torch.cat((emb, target_embedding)) for emb in context_embeddings]

        # Keep only the most recent MAX_CONV_LEN context utterances.
        context_embeddings = context_embeddings[-MAX_CONV_LEN:]
        context_embeddings_cat = context_embeddings_cat[-MAX_CONV_LEN:]

        num_real = len(context_embeddings)
        num_pads = MAX_CONV_LEN - num_real
        attention_mask = [1]*num_real + [0]*num_pads

        if num_pads:
            # Size the zero padding from the actual embedding (rather than a
            # hard-coded width) so it always stacks with the real rows, and so
            # it shares their dtype and device.
            embed_dim = target_embedding.shape[0]
            context_embeddings += [target_embedding.new_zeros(embed_dim)]*num_pads
            context_embeddings_cat += [target_embedding.new_zeros(2*embed_dim)]*num_pads

        return {
            'context_embeddings_cat': torch.stack(context_embeddings_cat),
            'context_embeddings': torch.stack(context_embeddings),
            'target_embedding': target_embedding,
            'attention_mask': torch.tensor(attention_mask),
            'emotion': emotion2int[sample['emotion']]
        }

In [23]:
class RobertaForSentenceClassificationGivenContext(RobertaPreTrainedModel):
    """RoBERTa encoder over context embeddings plus a linear emotion classifier.

    The encoder consumes precomputed embeddings (``inputs_embeds``) instead of
    token ids. Its pooled output is concatenated with the target utterance
    embedding and passed to a single linear layer.

    Args:
        config: RoBERTa configuration; ``config.num_labels`` sets the number
            of output classes.
        weights: optional 1-D tensor of per-class weights for the
            cross-entropy loss. ``None`` means an unweighted loss.
    """

    def __init__(self, config, weights=None):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        # Kept as a plain attribute, so Module.to() does not move it;
        # forward() moves it next to the logits before building the loss.
        self.weights = weights
        self.roberta = RobertaModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        # Input is [pooled output (hidden_size) ; target embedding], so this
        # width requires the target embedding to be hidden_size / 2 wide.
        self.classifier = nn.Linear(int(config.hidden_size*1.5), config.num_labels)
        self.post_init()

    def forward(self, context_embeds, target_embeds, attention_mask, labels=None):
        """Classify the target utterance given its context.

        Args:
            context_embeds: context embeddings passed to the encoder as
                ``inputs_embeds``.
            target_embeds: target utterance embeddings, concatenated with the
                pooled encoder output along dim 1.
            attention_mask: mask passed through to the encoder.
            labels: optional class ids; when given, a (weighted)
                cross-entropy loss is computed.

        Returns:
            dict with ``'loss'`` (``None`` when ``labels`` is ``None``) and
            ``'logits'``.
        """
        output = self.roberta(inputs_embeds=context_embeds, attention_mask=attention_mask)
        pooled_output = self.dropout(output.pooler_output)
        pooled_output_cat = torch.cat((pooled_output, target_embeds), 1)
        logits = self.classifier(pooled_output_cat)

        loss = None
        if labels is not None:
            class_weight = self.weights
            if class_weight is not None:
                # The loss requires its weight on the same device as the logits.
                class_weight = class_weight.to(device=logits.device, dtype=logits.dtype)
            loss_fct = nn.CrossEntropyLoss(weight=class_weight)
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return {'loss': loss, 'logits': logits}
        

In [24]:
# Wrap both splits in the same dataset class; they share the embedding table and device.
train_dataset, val_dataset = (
    ERC_Dataset_Utt_Level(split, utterance2vec, device)
    for split in (train_data, val_data)
)

# Only the training batches are shuffled; validation keeps its stored order.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=16)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=16)

In [25]:
# One output unit per emotion class.
config = RobertaConfig.from_pretrained('roberta-base', num_labels=len(emotion2int))

# The model stores the class weights as a plain tensor attribute, which
# model.to(device) does not move. Put them on the target device up front so
# the weighted loss never mixes CPU weights with GPU logits.
loss_weights = torch.from_numpy(class_weights).float().to(device)
model = RobertaForSentenceClassificationGivenContext.from_pretrained('roberta-base', config=config, weights=loss_weights).to(device)

Some weights of RobertaForSentenceClassificationGivenContext were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [26]:
# Training schedule: number of full passes over the training set.
epochs = 30
# AdamW over every model parameter (encoder, pooler and classifier head alike).
optimizer = AdamW(
    params=model.parameters(),
    lr=1e-6,
)

In [27]:
from kaggle_secrets import UserSecretsClient
# Fetch the Weights & Biases API key from the Kaggle secrets client so the key
# itself never appears in the notebook source.
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb_login_key")
# Authenticate this session with wandb using the retrieved key.
wandb.login(key=secret_value_0)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [28]:
# Start the wandb run. Learning rate and batch size are read from the live
# optimizer / loader objects so the logged config cannot drift from the values
# actually used for training.
wandb.init(project='TECPEC', name='RoBERTa_cat_Utt_Level', config={
    'Embedding': 'Sentence-Transformer',
    'Level': 'Utterance Level',
    'Approach': 'Concat each utterance embedding with the target utterance embedding',
    'Epochs': epochs,
    'Optimizer': 'AdamW',
    'Learning Rate': optimizer.param_groups[0]['lr'],
    'Batch Size': train_loader.batch_size
})

In [29]:
def _run_epoch(loader, train):
    """Run one full pass over ``loader`` with the notebook-level model.

    Uses the module-level ``model``, ``optimizer`` and ``device``. When
    ``train`` is true the model is put in training mode and the optimizer is
    stepped after every batch; otherwise the model is put in eval mode and
    gradients are disabled.

    Returns:
        (mean loss per batch, true labels, predicted labels)
    """
    model.train(train)  # model.train(False) is equivalent to model.eval()
    y_pred, y_true, total_loss = [], [], 0.0
    with torch.set_grad_enabled(train):
        for batch in tqdm(loader):
            context_embeddings_cat = batch['context_embeddings_cat'].to(device)
            target_embedding = batch['target_embedding'].to(device)
            emotions = batch['emotion'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            if train:
                optimizer.zero_grad()
            outputs = model(context_embeds=context_embeddings_cat, target_embeds=target_embedding, attention_mask=attention_mask, labels=emotions)
            loss = outputs['loss']
            if train:
                loss.backward()
                optimizer.step()

            y_pred.extend(torch.argmax(outputs['logits'], 1).tolist())
            y_true.extend(emotions.tolist())
            total_loss += loss.item()
    return total_loss / len(loader), y_true, y_pred


def _classification_reports(y_true, y_pred):
    """Return the classification report as (printable text, nested dict)."""
    report_kwargs = dict(target_names=list(emotion2int.keys()), zero_division=0)
    text_report = classification_report(y_true, y_pred, **report_kwargs)
    dict_report = classification_report(y_true, y_pred, output_dict=True, **report_kwargs)
    return text_report, dict_report


for epoch in range(epochs):
    train_loss, train_true, train_pred = _run_epoch(train_loader, train=True)
    val_loss, val_true, val_pred = _run_epoch(val_loader, train=False)

    train_report, train_report_dict = _classification_reports(train_true, train_pred)
    val_report, val_report_dict = _classification_reports(val_true, val_pred)

    wandb.log({
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_accuracy': train_report_dict['accuracy'],
        'val_accuracy': val_report_dict['accuracy'],
        'Macro train_f1': train_report_dict['macro avg']['f1-score'],
        'Macro val_f1': val_report_dict['macro avg']['f1-score'],
        'Weighted train_f1': train_report_dict['weighted avg']['f1-score'],
        'Weighted val_f1': val_report_dict['weighted avg']['f1-score'],
    })
    print(f"Epoch: {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}")
    print(f"Train Report: \n{train_report}")
    print(f"Val Report: \n{val_report}")

100%|██████████| 759/759 [01:36<00:00,  7.88it/s]
100%|██████████| 93/93 [00:03<00:00, 24.11it/s]


Epoch: 1, Train Loss: 1.6718021756889634, Val Loss: 1.600694219271342
Train Report: 
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00      1423
         joy       0.10      0.01      0.01      2047
        fear       0.03      0.04      0.04       336
     disgust       0.00      0.00      0.00       372
     neutral       0.44      0.88      0.58      5299
    surprise       0.04      0.00      0.00      1656
     sadness       0.07      0.06      0.06      1011

    accuracy                           0.39     12144
   macro avg       0.10      0.14      0.10     12144
weighted avg       0.22      0.39      0.26     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00       192
         joy       0.00      0.00      0.00       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.43      1.00 

100%|██████████| 759/759 [01:36<00:00,  7.90it/s]
100%|██████████| 93/93 [00:03<00:00, 24.13it/s]


Epoch: 2, Train Loss: 1.5994040923627468, Val Loss: 1.5952196557034728
Train Report: 
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00      1423
         joy       0.00      0.00      0.00      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.44      1.00      0.61      5299
    surprise       0.00      0.00      0.00      1656
     sadness       0.00      0.00      0.00      1011

    accuracy                           0.44     12144
   macro avg       0.06      0.14      0.09     12144
weighted avg       0.19      0.44      0.27     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00       192
         joy       0.00      0.00      0.00       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.43      1.00

100%|██████████| 759/759 [01:36<00:00,  7.90it/s]
100%|██████████| 93/93 [00:03<00:00, 24.16it/s]


Epoch: 3, Train Loss: 1.5901452224402246, Val Loss: 1.5770513908837431
Train Report: 
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00      1423
         joy       0.00      0.00      0.00      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.44      1.00      0.61      5299
    surprise       1.00      0.00      0.00      1656
     sadness       0.00      0.00      0.00      1011

    accuracy                           0.44     12144
   macro avg       0.21      0.14      0.09     12144
weighted avg       0.33      0.44      0.27     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00       192
         joy       0.00      0.00      0.00       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.43      1.00

100%|██████████| 759/759 [01:35<00:00,  7.91it/s]
100%|██████████| 93/93 [00:03<00:00, 24.07it/s]


Epoch: 4, Train Loss: 1.5743744492687883, Val Loss: 1.5562053201019124
Train Report: 
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00      1423
         joy       0.00      0.00      0.00      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.44      1.00      0.61      5299
    surprise       0.69      0.01      0.01      1656
     sadness       0.00      0.00      0.00      1011

    accuracy                           0.44     12144
   macro avg       0.16      0.14      0.09     12144
weighted avg       0.28      0.44      0.27     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00       192
         joy       0.00      0.00      0.00       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.43      1.00

100%|██████████| 759/759 [01:35<00:00,  7.92it/s]
100%|██████████| 93/93 [00:03<00:00, 24.50it/s]


Epoch: 5, Train Loss: 1.5511426942935889, Val Loss: 1.537509766317183
Train Report: 
              precision    recall  f1-score   support

       anger       0.31      0.00      0.01      1423
         joy       0.50      0.00      0.00      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.44      1.00      0.61      5299
    surprise       0.58      0.05      0.09      1656
     sadness       0.00      0.00      0.00      1011

    accuracy                           0.44     12144
   macro avg       0.26      0.15      0.10     12144
weighted avg       0.39      0.44      0.28     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00       192
         joy       0.00      0.00      0.00       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.43      0.99 

100%|██████████| 759/759 [01:35<00:00,  7.92it/s]
100%|██████████| 93/93 [00:03<00:00, 24.29it/s]


Epoch: 6, Train Loss: 1.529186386678844, Val Loss: 1.5169675061779637
Train Report: 
              precision    recall  f1-score   support

       anger       0.37      0.01      0.03      1423
         joy       0.45      0.01      0.01      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.45      0.98      0.62      5299
    surprise       0.50      0.14      0.22      1656
     sadness       0.43      0.00      0.01      1011

    accuracy                           0.45     12144
   macro avg       0.31      0.16      0.13     12144
weighted avg       0.42      0.45      0.30     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.44      0.04      0.07       192
         joy       0.67      0.01      0.02       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.45      0.95 

100%|██████████| 759/759 [01:35<00:00,  7.91it/s]
100%|██████████| 93/93 [00:03<00:00, 24.15it/s]


Epoch: 7, Train Loss: 1.509188142768322, Val Loss: 1.5158832771803743
Train Report: 
              precision    recall  f1-score   support

       anger       0.27      0.03      0.05      1423
         joy       0.47      0.02      0.03      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.45      0.96      0.62      5299
    surprise       0.48      0.18      0.26      1656
     sadness       0.43      0.02      0.04      1011

    accuracy                           0.45     12144
   macro avg       0.30      0.17      0.14     12144
weighted avg       0.41      0.45      0.32     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.38      0.06      0.10       192
         joy       0.48      0.05      0.09       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.47      0.89 

100%|██████████| 759/759 [01:35<00:00,  7.91it/s]
100%|██████████| 93/93 [00:03<00:00, 24.18it/s]


Epoch: 8, Train Loss: 1.4909585403359455, Val Loss: 1.500956178352397
Train Report: 
              precision    recall  f1-score   support

       anger       0.36      0.06      0.11      1423
         joy       0.50      0.04      0.07      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.46      0.95      0.62      5299
    surprise       0.48      0.23      0.31      1656
     sadness       0.44      0.04      0.07      1011

    accuracy                           0.46     12144
   macro avg       0.32      0.19      0.17     12144
weighted avg       0.43      0.46      0.34     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.38      0.15      0.22       192
         joy       0.49      0.17      0.26       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.48      0.86 

100%|██████████| 759/759 [01:35<00:00,  7.93it/s]
100%|██████████| 93/93 [00:03<00:00, 24.31it/s]


Epoch: 9, Train Loss: 1.4716182733555871, Val Loss: 1.4886454254068353
Train Report: 
              precision    recall  f1-score   support

       anger       0.32      0.06      0.11      1423
         joy       0.55      0.10      0.16      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.47      0.93      0.63      5299
    surprise       0.47      0.25      0.32      1656
     sadness       0.43      0.07      0.11      1011

    accuracy                           0.47     12144
   macro avg       0.32      0.20      0.19     12144
weighted avg       0.44      0.47      0.37     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.33      0.17      0.23       192
         joy       0.49      0.21      0.30       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.48      0.86

100%|██████████| 759/759 [01:35<00:00,  7.92it/s]
100%|██████████| 93/93 [00:03<00:00, 24.18it/s]


Epoch: 10, Train Loss: 1.4616928377013276, Val Loss: 1.4922468604580048
Train Report: 
              precision    recall  f1-score   support

       anger       0.36      0.10      0.15      1423
         joy       0.49      0.12      0.19      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.48      0.92      0.63      5299
    surprise       0.47      0.25      0.33      1656
     sadness       0.44      0.09      0.15      1011

    accuracy                           0.48     12144
   macro avg       0.32      0.21      0.21     12144
weighted avg       0.43      0.48      0.38     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.39      0.08      0.14       192
         joy       0.48      0.22      0.31       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.49      0.8

100%|██████████| 759/759 [01:35<00:00,  7.91it/s]
100%|██████████| 93/93 [00:03<00:00, 24.47it/s]


Epoch: 11, Train Loss: 1.4465552534667556, Val Loss: 1.4871996557840736
Train Report: 
              precision    recall  f1-score   support

       anger       0.37      0.10      0.15      1423
         joy       0.51      0.14      0.22      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.48      0.91      0.63      5299
    surprise       0.48      0.26      0.34      1656
     sadness       0.41      0.13      0.20      1011

    accuracy                           0.48     12144
   macro avg       0.32      0.22      0.22     12144
weighted avg       0.44      0.48      0.39     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.35      0.17      0.23       192
         joy       0.48      0.25      0.33       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.50      0.7

100%|██████████| 759/759 [01:35<00:00,  7.91it/s]
100%|██████████| 93/93 [00:03<00:00, 24.09it/s]


Epoch: 12, Train Loss: 1.4318437385464846, Val Loss: 1.4797452444671302
Train Report: 
              precision    recall  f1-score   support

       anger       0.36      0.11      0.17      1423
         joy       0.47      0.15      0.22      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.49      0.91      0.63      5299
    surprise       0.50      0.27      0.35      1656
     sadness       0.40      0.13      0.20      1011

    accuracy                           0.48     12144
   macro avg       0.32      0.22      0.23     12144
weighted avg       0.44      0.48      0.40     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.34      0.16      0.22       192
         joy       0.47      0.24      0.32       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.49      0.8

100%|██████████| 759/759 [01:36<00:00,  7.90it/s]
100%|██████████| 93/93 [00:03<00:00, 24.12it/s]


Epoch: 13, Train Loss: 1.4191835437053435, Val Loss: 1.4813449100781513
Train Report: 
              precision    recall  f1-score   support

       anger       0.38      0.13      0.20      1423
         joy       0.48      0.18      0.26      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.49      0.90      0.64      5299
    surprise       0.51      0.30      0.37      1656
     sadness       0.42      0.15      0.22      1011

    accuracy                           0.49     12144
   macro avg       0.33      0.24      0.24     12144
weighted avg       0.45      0.49      0.41     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.36      0.14      0.20       192
         joy       0.45      0.26      0.33       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.50      0.8

100%|██████████| 759/759 [01:35<00:00,  7.91it/s]
100%|██████████| 93/93 [00:03<00:00, 24.17it/s]


Epoch: 14, Train Loss: 1.4121413984788735, Val Loss: 1.485440561848302
Train Report: 
              precision    recall  f1-score   support

       anger       0.36      0.14      0.20      1423
         joy       0.49      0.18      0.27      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.50      0.89      0.64      5299
    surprise       0.50      0.30      0.38      1656
     sadness       0.43      0.17      0.25      1011

    accuracy                           0.49     12144
   macro avg       0.33      0.24      0.25     12144
weighted avg       0.45      0.49      0.42     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.30      0.11      0.16       192
         joy       0.43      0.28      0.34       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.51      0.78

100%|██████████| 759/759 [01:35<00:00,  7.91it/s]
100%|██████████| 93/93 [00:03<00:00, 24.06it/s]


Epoch: 15, Train Loss: 1.3936398898659959, Val Loss: 1.478259458336779
Train Report: 
              precision    recall  f1-score   support

       anger       0.39      0.15      0.21      1423
         joy       0.50      0.20      0.29      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.50      0.89      0.64      5299
    surprise       0.53      0.32      0.40      1656
     sadness       0.44      0.20      0.27      1011

    accuracy                           0.50     12144
   macro avg       0.34      0.25      0.26     12144
weighted avg       0.46      0.50      0.43     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.35      0.18      0.24       192
         joy       0.46      0.24      0.32       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.49      0.86

100%|██████████| 759/759 [01:36<00:00,  7.90it/s]
100%|██████████| 93/93 [00:03<00:00, 24.30it/s]


Epoch: 16, Train Loss: 1.3813883192611462, Val Loss: 1.4793699710599837
Train Report: 
              precision    recall  f1-score   support

       anger       0.42      0.18      0.26      1423
         joy       0.51      0.22      0.31      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.51      0.88      0.65      5299
    surprise       0.54      0.33      0.41      1656
     sadness       0.45      0.20      0.28      1011

    accuracy                           0.51     12144
   macro avg       0.35      0.26      0.27     12144
weighted avg       0.47      0.51      0.44     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.31      0.20      0.25       192
         joy       0.47      0.24      0.32       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.51      0.7

100%|██████████| 759/759 [01:35<00:00,  7.91it/s]
100%|██████████| 93/93 [00:03<00:00, 24.22it/s]


Epoch: 17, Train Loss: 1.3718116037616304, Val Loss: 1.484023961328691
Train Report: 
              precision    recall  f1-score   support

       anger       0.38      0.18      0.24      1423
         joy       0.50      0.22      0.31      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.51      0.88      0.65      5299
    surprise       0.54      0.34      0.42      1656
     sadness       0.46      0.21      0.29      1011

    accuracy                           0.51     12144
   macro avg       0.34      0.26      0.27     12144
weighted avg       0.47      0.51      0.44     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.31      0.14      0.19       192
         joy       0.51      0.23      0.32       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.51      0.81

100%|██████████| 759/759 [01:35<00:00,  7.91it/s]
100%|██████████| 93/93 [00:03<00:00, 24.03it/s]


Epoch: 18, Train Loss: 1.361909548912752, Val Loss: 1.4945767079630206
Train Report: 
              precision    recall  f1-score   support

       anger       0.41      0.20      0.27      1423
         joy       0.55      0.24      0.33      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.52      0.88      0.66      5299
    surprise       0.55      0.35      0.43      1656
     sadness       0.43      0.23      0.30      1011

    accuracy                           0.52     12144
   macro avg       0.35      0.27      0.28     12144
weighted avg       0.48      0.52      0.46     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.27      0.22      0.24       192
         joy       0.41      0.30      0.34       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.53      0.73

100%|██████████| 759/759 [01:35<00:00,  7.92it/s]
100%|██████████| 93/93 [00:03<00:00, 24.33it/s]


Epoch: 19, Train Loss: 1.3515316337464827, Val Loss: 1.4983473580370668
Train Report: 
              precision    recall  f1-score   support

       anger       0.42      0.20      0.27      1423
         joy       0.52      0.24      0.33      2047
        fear       0.00      0.00      0.00       336
     disgust       0.00      0.00      0.00       372
     neutral       0.52      0.88      0.66      5299
    surprise       0.56      0.36      0.44      1656
     sadness       0.47      0.25      0.33      1011

    accuracy                           0.52     12144
   macro avg       0.36      0.28      0.29     12144
weighted avg       0.48      0.52      0.46     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.28      0.23      0.26       192
         joy       0.39      0.32      0.35       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.52      0.7

100%|██████████| 759/759 [01:35<00:00,  7.94it/s]
100%|██████████| 93/93 [00:03<00:00, 24.36it/s]


Epoch: 20, Train Loss: 1.3374424824915698, Val Loss: 1.4890805316227738
Train Report: 
              precision    recall  f1-score   support

       anger       0.43      0.22      0.29      1423
         joy       0.53      0.26      0.35      2047
        fear       1.00      0.00      0.01       336
     disgust       1.00      0.00      0.01       372
     neutral       0.53      0.87      0.66      5299
    surprise       0.57      0.36      0.44      1656
     sadness       0.46      0.25      0.32      1011

    accuracy                           0.52     12144
   macro avg       0.64      0.28      0.30     12144
weighted avg       0.54      0.52      0.47     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.29      0.21      0.24       192
         joy       0.42      0.28      0.34       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.52      0.7

100%|██████████| 759/759 [01:35<00:00,  7.95it/s]
100%|██████████| 93/93 [00:03<00:00, 24.51it/s]


Epoch: 21, Train Loss: 1.3269000983992112, Val Loss: 1.5107266960605499
Train Report: 
              precision    recall  f1-score   support

       anger       0.42      0.24      0.30      1423
         joy       0.55      0.27      0.36      2047
        fear       0.00      0.00      0.00       336
     disgust       1.00      0.00      0.01       372
     neutral       0.53      0.87      0.66      5299
    surprise       0.58      0.36      0.44      1656
     sadness       0.48      0.26      0.33      1011

    accuracy                           0.52     12144
   macro avg       0.51      0.29      0.30     12144
weighted avg       0.52      0.52      0.47     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.27      0.22      0.24       192
         joy       0.45      0.26      0.33       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.53      0.7

100%|██████████| 759/759 [01:35<00:00,  7.93it/s]
100%|██████████| 93/93 [00:03<00:00, 24.13it/s]


Epoch: 22, Train Loss: 1.3151847752343682, Val Loss: 1.5020532037622185
Train Report: 
              precision    recall  f1-score   support

       anger       0.42      0.24      0.30      1423
         joy       0.54      0.27      0.36      2047
        fear       0.00      0.00      0.00       336
     disgust       0.33      0.00      0.01       372
     neutral       0.53      0.87      0.66      5299
    surprise       0.58      0.38      0.46      1656
     sadness       0.48      0.27      0.34      1011

    accuracy                           0.53     12144
   macro avg       0.41      0.29      0.30     12144
weighted avg       0.50      0.53      0.48     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.30      0.15      0.20       192
         joy       0.50      0.22      0.30       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.51      0.8

100%|██████████| 759/759 [01:35<00:00,  7.92it/s]
100%|██████████| 93/93 [00:03<00:00, 24.15it/s]


Epoch: 23, Train Loss: 1.3036655624236357, Val Loss: 1.5077380852032733
Train Report: 
              precision    recall  f1-score   support

       anger       0.47      0.25      0.32      1423
         joy       0.55      0.28      0.37      2047
        fear       1.00      0.00      0.01       336
     disgust       0.80      0.01      0.02       372
     neutral       0.54      0.87      0.67      5299
    surprise       0.57      0.38      0.46      1656
     sadness       0.47      0.29      0.36      1011

    accuracy                           0.53     12144
   macro avg       0.63      0.30      0.32     12144
weighted avg       0.55      0.53      0.48     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.27      0.19      0.22       192
         joy       0.42      0.29      0.34       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.52      0.7

100%|██████████| 759/759 [01:35<00:00,  7.92it/s]
100%|██████████| 93/93 [00:03<00:00, 24.19it/s]


Epoch: 24, Train Loss: 1.296358489157811, Val Loss: 1.501633325571655
Train Report: 
              precision    recall  f1-score   support

       anger       0.45      0.26      0.33      1423
         joy       0.54      0.29      0.38      2047
        fear       1.00      0.01      0.02       336
     disgust       0.62      0.01      0.03       372
     neutral       0.54      0.87      0.67      5299
    surprise       0.58      0.39      0.47      1656
     sadness       0.49      0.28      0.35      1011

    accuracy                           0.54     12144
   macro avg       0.60      0.30      0.32     12144
weighted avg       0.55      0.54      0.49     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.27      0.19      0.22       192
         joy       0.42      0.29      0.34       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.51      0.76 

100%|██████████| 759/759 [01:35<00:00,  7.91it/s]
100%|██████████| 93/93 [00:03<00:00, 23.87it/s]


Epoch: 25, Train Loss: 1.2812524717472915, Val Loss: 1.5095000946393577
Train Report: 
              precision    recall  f1-score   support

       anger       0.44      0.28      0.34      1423
         joy       0.56      0.31      0.40      2047
        fear       0.50      0.01      0.01       336
     disgust       0.70      0.02      0.04       372
     neutral       0.54      0.87      0.67      5299
    surprise       0.60      0.40      0.48      1656
     sadness       0.49      0.29      0.37      1011

    accuracy                           0.54     12144
   macro avg       0.55      0.31      0.33     12144
weighted avg       0.54      0.54      0.50     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.29      0.21      0.24       192
         joy       0.46      0.25      0.32       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.50      0.7

100%|██████████| 759/759 [01:35<00:00,  7.92it/s]
100%|██████████| 93/93 [00:03<00:00, 24.17it/s]


Epoch: 26, Train Loss: 1.2740288441988477, Val Loss: 1.5162971865746282
Train Report: 
              precision    recall  f1-score   support

       anger       0.45      0.28      0.34      1423
         joy       0.57      0.31      0.41      2047
        fear       0.50      0.00      0.01       336
     disgust       0.43      0.02      0.03       372
     neutral       0.55      0.87      0.68      5299
    surprise       0.60      0.41      0.48      1656
     sadness       0.50      0.31      0.38      1011

    accuracy                           0.55     12144
   macro avg       0.51      0.31      0.33     12144
weighted avg       0.54      0.55      0.50     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.28      0.21      0.24       192
         joy       0.41      0.29      0.34       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.51      0.7

100%|██████████| 759/759 [01:35<00:00,  7.92it/s]
100%|██████████| 93/93 [00:03<00:00, 24.07it/s]


Epoch: 27, Train Loss: 1.2595881293570728, Val Loss: 1.5212213538026298
Train Report: 
              precision    recall  f1-score   support

       anger       0.47      0.29      0.36      1423
         joy       0.58      0.32      0.41      2047
        fear       0.43      0.01      0.02       336
     disgust       0.58      0.02      0.04       372
     neutral       0.55      0.87      0.68      5299
    surprise       0.60      0.42      0.49      1656
     sadness       0.51      0.31      0.38      1011

    accuracy                           0.55     12144
   macro avg       0.53      0.32      0.34     12144
weighted avg       0.55      0.55      0.51     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.28      0.19      0.23       192
         joy       0.45      0.28      0.34       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.51      0.7

100%|██████████| 759/759 [01:35<00:00,  7.92it/s]
100%|██████████| 93/93 [00:03<00:00, 24.19it/s]


Epoch: 28, Train Loss: 1.250493832727666, Val Loss: 1.521008312702179
Train Report: 
              precision    recall  f1-score   support

       anger       0.47      0.30      0.37      1423
         joy       0.59      0.33      0.42      2047
        fear       0.50      0.00      0.01       336
     disgust       0.53      0.04      0.08       372
     neutral       0.56      0.87      0.68      5299
    surprise       0.61      0.41      0.49      1656
     sadness       0.52      0.33      0.40      1011

    accuracy                           0.55     12144
   macro avg       0.54      0.33      0.35     12144
weighted avg       0.55      0.55      0.51     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.27      0.19      0.22       192
         joy       0.40      0.30      0.34       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.51      0.75 

100%|██████████| 759/759 [01:35<00:00,  7.93it/s]
100%|██████████| 93/93 [00:03<00:00, 24.29it/s]


Epoch: 29, Train Loss: 1.233226150627664, Val Loss: 1.5538828385773527
Train Report: 
              precision    recall  f1-score   support

       anger       0.48      0.31      0.38      1423
         joy       0.59      0.34      0.43      2047
        fear       0.62      0.01      0.03       336
     disgust       0.70      0.07      0.13       372
     neutral       0.57      0.87      0.69      5299
    surprise       0.62      0.43      0.51      1656
     sadness       0.51      0.33      0.40      1011

    accuracy                           0.56     12144
   macro avg       0.58      0.34      0.37     12144
weighted avg       0.57      0.56      0.52     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.28      0.22      0.25       192
         joy       0.39      0.31      0.34       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.52      0.69

100%|██████████| 759/759 [01:35<00:00,  7.93it/s]
100%|██████████| 93/93 [00:03<00:00, 23.95it/s]

Epoch: 30, Train Loss: 1.2320747579667566, Val Loss: 1.5405513323763365
Train Report: 
              precision    recall  f1-score   support

       anger       0.48      0.32      0.38      1423
         joy       0.59      0.34      0.43      2047
        fear       0.67      0.01      0.02       336
     disgust       0.53      0.05      0.09       372
     neutral       0.56      0.87      0.68      5299
    surprise       0.61      0.43      0.50      1656
     sadness       0.52      0.33      0.40      1011

    accuracy                           0.56     12144
   macro avg       0.57      0.33      0.36     12144
weighted avg       0.56      0.56      0.52     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.27      0.21      0.24       192
         joy       0.41      0.30      0.35       254
        fear       0.00      0.00      0.00        37
     disgust       0.00      0.00      0.00        42
     neutral       0.52      0.7




In [None]:
# Class display names and their integer ids, taken from `emotion2int` in the
# same order, so that each report row is tied to an explicit class id instead
# of to whichever classes happen to occur in a given split.
label_names = list(emotion2int.keys())
label_ids = list(emotion2int.values())


def _run_epoch(model, loader, device, optimizer=None):
    """Run one full pass over `loader` and collect predictions and the mean loss.

    Args:
        model: module called with `context_embeds`, `target_embeds`,
            `attention_mask` and `labels`; its output must expose the keys
            'loss' and 'logits'.
        loader: iterable of batches providing 'context_embeddings',
            'target_embedding', 'emotion' and 'attention_mask' tensors.
        device: device every batch tensor is moved to before the forward pass.
        optimizer: when given, the pass trains the model (zero_grad, backward
            and step on every batch). When None, the model is put in eval
            mode and the pass runs with gradients disabled.

    Returns:
        Tuple `(predictions, targets, mean_loss)`: the argmax class id per
        sample, the matching gold labels, and the sum of the batch losses
        divided by the number of batches.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is the same switch as eval()
    predictions, targets, total_loss = [], [], 0.0
    with torch.set_grad_enabled(training):
        for batch in tqdm(loader):
            if training:
                optimizer.zero_grad()
            context_embeddings_cat = batch['context_embeddings'].to(device)
            target_embedding = batch['target_embedding'].to(device)
            emotions = batch['emotion'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            outputs = model(
                context_embeds=context_embeddings_cat,
                target_embeds=target_embedding,
                attention_mask=attention_mask,
                labels=emotions,
            )
            loss = outputs['loss']
            if training:
                loss.backward()
                optimizer.step()
            predictions.extend(torch.argmax(outputs['logits'], 1).tolist())
            targets.extend(emotions.tolist())
            total_loss += loss.item()
    return predictions, targets, total_loss / len(loader)


def _summarize(y_true, y_pred):
    """Build the printable report, its dict form, and the plain accuracy.

    `labels` is passed explicitly so `target_names` stays aligned with the
    class ids even if some class is absent from both `y_true` and `y_pred`;
    without it the names are matched against the observed classes only.

    Returns:
        Tuple `(text_report, report_dict, accuracy)`.
    """
    shared = dict(labels=label_ids, target_names=label_names, zero_division=0)
    text_report = classification_report(y_true, y_pred, **shared)
    report_dict = classification_report(y_true, y_pred, output_dict=True, **shared)
    # Computed directly so the logged value does not depend on which summary
    # rows classification_report decides to emit for the given labels.
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    return text_report, report_dict, accuracy


for epoch in range(epochs):
    # One optimisation pass over the training data, then a gradient-free pass
    # over the validation data.
    train_pred, train_true, train_loss = _run_epoch(model, train_loader, device, optimizer)
    val_pred, val_true, val_loss = _run_epoch(model, val_loader, device)

    train_report, train_report_dict, train_accuracy = _summarize(train_true, train_pred)
    val_report, val_report_dict, val_accuracy = _summarize(val_true, val_pred)

    wandb.log({
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_accuracy': train_accuracy,
        'val_accuracy': val_accuracy,
        'Macro train_f1': train_report_dict['macro avg']['f1-score'],
        'Macro val_f1': val_report_dict['macro avg']['f1-score'],
        'Weighted train_f1': train_report_dict['weighted avg']['f1-score'],
        'Weighted val_f1': val_report_dict['weighted avg']['f1-score'],
    })
    print(f"Epoch: {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}")
    print(f"Train Report: \n{train_report}")
    print(f"Val Report: \n{val_report}")

In [30]:
# Close the Weights & Biases run that the training loop logged to; the cell
# output shows the per-metric history and the final summary values.
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Macro train_f1,▁▁▁▁▁▂▂▃▄▄▄▄▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
Macro val_f1,▁▁▁▂▂▄▄▆▇▇▇▇▇▇▇██████▇▇█▇█████
Weighted train_f1,▁▁▁▁▁▂▃▃▄▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇████
Weighted val_f1,▁▁▁▂▂▄▄▇▇▇████████████████████
train_accuracy,▁▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train_loss,█▇▇▆▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▁▁▁▁
val_accuracy,▁▁▁▂▃▄▃▆▇▇▆▇▇▆█▇▇▆▅▆▅▇▅▆▆▆▆▅▃▅
val_loss,██▇▅▄▃▃▂▂▂▂▁▁▁▁▁▁▂▂▂▃▂▃▂▃▃▃▃▅▅

0,1
Macro train_f1,0.35964
Macro val_f1,0.25937
Weighted train_f1,0.52056
Weighted val_f1,0.42044
train_accuracy,0.56036
train_loss,1.23207
val_accuracy,0.45559
val_loss,1.54055


In [31]:
# Persist the trained model under the Kaggle working directory.
# NOTE(review): this serialises the whole model object rather than
# model.state_dict(); loading it presumably requires the model class to be
# importable under the same name — confirm how downstream code loads the file.
# NOTE(review): the weights saved here are those of the last epoch; the logged
# val_loss was rising over the final epochs, so an earlier checkpoint may
# generalise better.
torch.save(model, '/kaggle/working/RoBERTa_cat_Utt_Level.pth')