In [1]:
import json
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from sklearn.metrics import classification_report
import wandb
from transformers import BertPreTrainedModel, BertModel, BertConfig
import pickle
from torch.nn.functional import leaky_relu

In [2]:
# Source: https://github.com/LCS2-IIITD/Emotion-Flip-Reasoning/blob/main/Dataloaders/nlp_utils.py
import string
import nltk
import re

# Digit -> English word lookup consumed by number_to_words (keys "0".."9" in order).
numbers = dict(zip(
    "0123456789",
    ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"],
))

def remove_puntuations(txt):
    """Strip ASCII punctuation from ``txt``.

    The sentence-level marks ``. ! ? : ;`` are turned into single spaces, so the
    words they separated stay apart. Every other character in
    ``string.punctuation`` is deleted outright. All other characters, including
    whitespace and non-ASCII, pass through unchanged.
    """
    # Delete everything in string.punctuation by default ...
    translation = {ord(mark): None for mark in string.punctuation}
    # ... except the separators, which become spaces instead.
    translation.update({ord(mark): " " for mark in ".!?:;"})
    return txt.translate(translation)

def number_to_words(txt):
    """Spell out every ASCII digit in ``txt`` using ``numbers``.

    Each digit becomes its word followed by a space, e.g. ``"a42"`` ->
    ``"afour two "``. Non-digit characters are copied through unchanged.
    """
    # A single left-to-right pass is enough: no replacement word contains a digit.
    return "".join(numbers[ch] + " " if ch in numbers else ch for ch in txt)

def preprocess_text(text):
    """Normalise an utterance into the key format used to look up its embedding.

    Pipeline:
    1. Lower-case the text and turn underscores into spaces.
    2. Spell out digits (``number_to_words``).
    3. Strip punctuation (``remove_puntuations``).
    4. Drop every non-ASCII character.
    5. Collapse whitespace runs into single spaces and trim the ends.
    """
    normalised = text.lower().replace('_', ' ')
    normalised = remove_puntuations(number_to_words(normalised))
    # Encoding with errors='ignore' silently discards every code point >= 128.
    ascii_only = normalised.encode('ascii', errors='ignore').decode('ascii')
    return ' '.join(ascii_only.split())

In [3]:
# Utterance-level ERC train/validation splits.
# `with` closes each handle as soon as it is parsed; json.load(open(...)) left
# closing to the garbage collector.
with open('/kaggle/input/Dataset/ERC_utterance_level/train_utterance_level.json') as train_file:
    train_data = json.load(train_file)
with open('/kaggle/input/Dataset/ERC_utterance_level/val_utterance_level.json') as val_file:
    val_data = json.load(val_file)

In [4]:
# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [5]:
# Integer class id for each of the seven emotion labels. The classifier's logit
# columns are indexed by these ids, because the targets are emotion2int[label].
emotion2int = {
    label: class_id
    for class_id, label in enumerate(
        ['anger', 'joy', 'fear', 'disgust', 'neutral', 'surprise', 'sadness']
    )
}

In [6]:
# Precomputed sentence-transformer embeddings, keyed by preprocess_text(utterance).
# NOTE: pickle.load can execute arbitrary code - only load this file from a trusted source.
# `with` closes the handle once loading finishes (the bare open(...) never closed it).
with open('/kaggle/input/Dataset/Embeddings/sentence_transformer_utterance2vec_768.pkl', 'rb') as embeddings_file:
    utterance2vec = pickle.load(embeddings_file)

In [7]:
# Each sample's context is cut to its most recent MAX_CONV_LEN utterances, and
# shorter contexts are right-padded with zero vectors up to this length.
MAX_CONV_LEN = 35


class ERC_Dataset_Utt_Level(Dataset):
    """Utterance-level emotion-recognition dataset over precomputed sentence embeddings.

    Each item pairs one target utterance with its conversational context. Every
    context embedding has the target embedding added to it. The resulting
    sequence is truncated / zero-padded to ``MAX_CONV_LEN``, and an attention
    mask marks the real positions.

    Args:
        data: mapping expected to have contiguous 1-based keys ``'id_1'`` ..
            ``'id_<len(data)>'``. Each value holds ``'text'`` (target utterance),
            ``'emotion'`` (a key of ``emotion2int``) and ``'context'`` (list of
            utterance strings).
        utterance2vec: mapping from ``preprocess_text(utterance)`` to an embedding
            accepted by ``torch.tensor`` (presumably 1-D numpy arrays - TODO confirm).
        device: device every returned tensor is moved to.
    """

    def __init__(self, data, utterance2vec, device):
        self.data = data
        self.utterance2vec = utterance2vec
        self.device = device

    def __len__(self):
        return len(self.data)

    def _embed(self, utterance):
        """Return the embedding of ``utterance`` (after preprocess_text) on ``self.device``."""
        return torch.tensor(self.utterance2vec[preprocess_text(utterance)]).to(self.device)

    def __getitem__(self, idx):
        """Build the padded context tensors, attention mask and label for sample ``idx``.

        Returns a dict with:
            'context_embeddings_add': (MAX_CONV_LEN, D) context + target embeddings.
            'context_embeddings': (MAX_CONV_LEN, D) raw context embeddings.
            'target_embedding': (D,) target-utterance embedding.
            'attention_mask': (MAX_CONV_LEN,) int tensor, 1 = real, 0 = padding.
            'emotion': int class id from emotion2int.
        """
        sample = self.data[f'id_{idx+1}']
        target_embedding = self._embed(sample['text'])

        # Keep only the most recent MAX_CONV_LEN utterances; older ones are never embedded.
        context = sample['context']
        recent_context = context[max(len(context) - MAX_CONV_LEN, 0):]
        context_embeddings = [self._embed(utterance) for utterance in recent_context]

        num_real = len(context_embeddings)
        num_pads = MAX_CONV_LEN - num_real
        # Pad with zeros matching the real embeddings' shape, dtype and device,
        # rather than a hard-coded 768-dim float32 vector.
        padding = [torch.zeros_like(target_embedding)] * num_pads

        context_embeddings_add = [emb + target_embedding for emb in context_embeddings] + padding
        context_embeddings = context_embeddings + padding
        attention_mask = torch.tensor([1] * num_real + [0] * num_pads)

        return {
            'context_embeddings_add': torch.stack(context_embeddings_add),
            'context_embeddings': torch.stack(context_embeddings),
            'target_embedding': target_embedding,
            'attention_mask': attention_mask,
            'emotion': emotion2int[sample['emotion']],
        }

In [8]:
class BertForSentenceClassificationGivenContext(BertPreTrainedModel):
    """BERT encoder run over a sequence of utterance embeddings, plus a linear head.

    The context sequence is passed to BERT through ``inputs_embeds`` (no token ids).
    BERT's pooled output, after dropout, is concatenated with the target-utterance
    embedding and mapped to ``config.num_labels`` logits.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config)
        # Use the generic hidden dropout when no classifier-specific rate is configured.
        if config.classifier_dropout is None:
            dropout_prob = config.hidden_dropout_prob
        else:
            dropout_prob = config.classifier_dropout
        self.dropout = nn.Dropout(dropout_prob)
        # Input is [pooled BERT output ; target embedding], hence 2 * hidden_size.
        # This assumes the target embedding also has hidden_size dims.
        self.classifier = nn.Linear(2 * config.hidden_size, config.num_labels)

        self.post_init()

    def forward(self, context_embeds, target_embeds, attention_mask, labels=None):
        """Classify the target utterance given its context.

        Args:
            context_embeds: embedding sequence fed to BERT via ``inputs_embeds``.
            target_embeds: target-utterance embeddings, concatenated on dim 1.
            attention_mask: 1 for real context positions, 0 for padding.
            labels: optional class ids; when given, cross-entropy loss is computed.

        Returns:
            dict with ``'loss'`` (None when ``labels`` is None) and ``'logits'``.
        """
        encoded = self.bert(inputs_embeds=context_embeds, attention_mask=attention_mask)
        # NOTE(review): BERT's pooler is understood to read only the first sequence
        # position. Here that is the first context embedding, not a [CLS] token -
        # confirm this is intended.
        features = torch.cat((self.dropout(encoded.pooler_output), target_embeds), 1)
        logits = self.classifier(features)

        loss = (
            nn.CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
            if labels is not None
            else None
        )
        return {'loss': loss, 'logits': logits}
        

In [9]:
# Both splits share one batch size; only the training loader is shuffled.
# Leave num_workers at its default (0): ERC_Dataset_Utt_Level moves tensors to
# `device` inside __getitem__, and CUDA typically cannot be initialised in forked
# DataLoader worker processes.
BATCH_SIZE = 16

train_dataset, val_dataset = (
    ERC_Dataset_Utt_Level(split, utterance2vec, device) for split in (train_data, val_data)
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [10]:
# Pretrained BERT backbone plus a freshly initialised classifier head.
# The label count is derived from emotion2int, so the head cannot drift from the
# label mapping (it was a hand-copied 7).
PRETRAINED_NAME = 'bert-base-uncased'
config = BertConfig.from_pretrained(PRETRAINED_NAME, num_labels=len(emotion2int))
model = BertForSentenceClassificationGivenContext.from_pretrained(PRETRAINED_NAME, config=config).to(device)

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSentenceClassificationGivenContext were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
# Training schedule: `epochs` full passes over train_loader. Every model parameter
# (the whole BERT encoder included) is optimised with AdamW at a very small learning rate.
LEARNING_RATE = 1e-6
epochs = 30
optimizer = AdamW(params=model.parameters(), lr=LEARNING_RATE)

In [12]:
from kaggle_secrets import UserSecretsClient

# Authenticate with Weights & Biases using the API key stored as the Kaggle secret
# "wandb_login_key", so the key itself never appears in the notebook.
wandb_api_key = UserSecretsClient().get_secret("wandb_login_key")
wandb.login(key=wandb_api_key)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [13]:
# Start the W&B run. Optimizer name, learning rate and batch size are read from
# the objects actually used for training, so the logged config cannot drift from
# the real hyperparameters (they were hand-copied literals before).
wandb.init(project='TECPEC', name='BERT_add_Utt_Level', config={
    'Embedding': 'Sentence-Transformer',
    'Level': 'Utterance Level',
    'Approach': 'Added each utterance embedding with the target utterance embedding',
    'Epochs': epochs,
    'Optimizer': type(optimizer).__name__,
    'Learning Rate': optimizer.param_groups[0]['lr'],
    'Batch Size': train_loader.batch_size,
})

[34m[1mwandb[0m: Currently logged in as: [33mshreyas21563[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [15]:
# Class names ordered by their integer id, plus the ids themselves. Passing the ids
# explicitly as `labels` stops classification_report from raising ValueError when a
# class is absent from both y_true and y_pred. Ordering by id (not dict order)
# keeps names aligned with the label ids.
_EMOTION_NAMES = sorted(emotion2int, key=emotion2int.get)
_EMOTION_IDS = [emotion2int[name] for name in _EMOTION_NAMES]


def _run_epoch(loader, train):
    """Run one pass of ``model`` over ``loader``.

    Args:
        loader: DataLoader yielding the dicts produced by ERC_Dataset_Utt_Level.
        train: if True, use train mode and take one optimizer step per batch.
            Otherwise evaluate in eval mode with gradients disabled.

    Returns:
        (mean batch loss, predicted class ids, true class ids)
    """
    model.train(train)
    preds, trues, total_loss = [], [], 0.0
    with torch.set_grad_enabled(train):
        for batch in tqdm(loader):
            emotions = batch['emotion'].to(device)
            outputs = model(
                context_embeds=batch['context_embeddings_add'].to(device),
                target_embeds=batch['target_embedding'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                labels=emotions,
            )
            loss = outputs['loss']
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            preds.extend(torch.argmax(outputs['logits'], 1).tolist())
            trues.extend(emotions.tolist())
            total_loss += loss.item()
    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    return total_loss / max(len(loader), 1), preds, trues


def _emotion_report(y_true, y_pred, output_dict=False):
    """sklearn classification report over all emotion classes (zero_division=0)."""
    return classification_report(
        y_true,
        y_pred,
        labels=_EMOTION_IDS,
        target_names=_EMOTION_NAMES,
        output_dict=output_dict,
        zero_division=0,
    )


for epoch in range(epochs):
    train_loss, train_pred, train_true = _run_epoch(train_loader, train=True)
    val_loss, val_pred, val_true = _run_epoch(val_loader, train=False)

    train_report = _emotion_report(train_true, train_pred)
    val_report = _emotion_report(val_true, val_pred)
    train_report_dict = _emotion_report(train_true, train_pred, output_dict=True)
    val_report_dict = _emotion_report(val_true, val_pred, output_dict=True)

    wandb.log({
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_accuracy': train_report_dict['accuracy'],
        'val_accuracy': val_report_dict['accuracy'],
        'Macro train_f1': train_report_dict['macro avg']['f1-score'],
        'Macro val_f1': val_report_dict['macro avg']['f1-score'],
        'Weighted train_f1': train_report_dict['weighted avg']['f1-score'],
        'Weighted val_f1': val_report_dict['weighted avg']['f1-score'],
    })
    print(f"Epoch: {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}")
    print(f"Train Report: \n{train_report}")
    print(f"Val Report: \n{val_report}")


100%|██████████| 759/759 [01:35<00:00,  7.92it/s]
100%|██████████| 93/93 [00:03<00:00, 24.84it/s]


Epoch: 26, Train Loss: 1.1710549161054087, Val Loss: 1.5348472409350897
Train Report: 
              precision    recall  f1-score   support

       anger       0.46      0.40      0.43      1423
         joy       0.63      0.44      0.51      2047
        fear       1.00      0.01      0.01       336
     disgust       0.58      0.02      0.04       372
     neutral       0.61      0.85      0.71      5299
    surprise       0.60      0.52      0.56      1656
     sadness       0.50      0.34      0.40      1011

    accuracy                           0.59     12144
   macro avg       0.63      0.37      0.38     12144
weighted avg       0.60      0.59      0.56     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.27      0.25      0.26       192
         joy       0.48      0.43      0.46       254
        fear       0.00      0.00      0.00        37
     disgust       0.50      0.05      0.09        42
     neutral       0.52      0.6

100%|██████████| 759/759 [01:35<00:00,  7.95it/s]
100%|██████████| 93/93 [00:03<00:00, 24.77it/s]


Epoch: 27, Train Loss: 1.160197956759939, Val Loss: 1.5424878283213543
Train Report: 
              precision    recall  f1-score   support

       anger       0.48      0.40      0.43      1423
         joy       0.62      0.43      0.51      2047
        fear       0.60      0.01      0.02       336
     disgust       0.50      0.02      0.05       372
     neutral       0.61      0.85      0.71      5299
    surprise       0.62      0.53      0.57      1656
     sadness       0.51      0.35      0.42      1011

    accuracy                           0.59     12144
   macro avg       0.56      0.37      0.39     12144
weighted avg       0.58      0.59      0.56     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.26      0.26      0.26       192
         joy       0.52      0.41      0.45       254
        fear       0.00      0.00      0.00        37
     disgust       0.33      0.05      0.08        42
     neutral       0.52      0.67

100%|██████████| 759/759 [01:35<00:00,  7.96it/s]
100%|██████████| 93/93 [00:03<00:00, 24.87it/s]


Epoch: 28, Train Loss: 1.14800644965354, Val Loss: 1.5495244412011997
Train Report: 
              precision    recall  f1-score   support

       anger       0.45      0.39      0.42      1423
         joy       0.63      0.45      0.52      2047
        fear       0.60      0.01      0.02       336
     disgust       0.50      0.04      0.08       372
     neutral       0.62      0.85      0.71      5299
    surprise       0.62      0.52      0.57      1656
     sadness       0.51      0.37      0.43      1011

    accuracy                           0.59     12144
   macro avg       0.56      0.38      0.39     12144
weighted avg       0.59      0.59      0.56     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.24      0.28      0.25       192
         joy       0.51      0.38      0.43       254
        fear       0.00      0.00      0.00        37
     disgust       0.33      0.05      0.08        42
     neutral       0.52      0.65 

100%|██████████| 759/759 [01:35<00:00,  7.95it/s]
100%|██████████| 93/93 [00:03<00:00, 24.82it/s]


Epoch: 29, Train Loss: 1.1253829558376267, Val Loss: 1.5543633487916761
Train Report: 
              precision    recall  f1-score   support

       anger       0.51      0.45      0.47      1423
         joy       0.64      0.45      0.53      2047
        fear       0.67      0.01      0.02       336
     disgust       0.60      0.06      0.10       372
     neutral       0.63      0.86      0.72      5299
    surprise       0.64      0.55      0.59      1656
     sadness       0.52      0.39      0.45      1011

    accuracy                           0.61     12144
   macro avg       0.60      0.39      0.41     12144
weighted avg       0.61      0.61      0.58     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.26      0.21      0.23       192
         joy       0.49      0.39      0.44       254
        fear       0.00      0.00      0.00        37
     disgust       0.44      0.10      0.16        42
     neutral       0.51      0.6

100%|██████████| 759/759 [01:35<00:00,  7.97it/s]
100%|██████████| 93/93 [00:03<00:00, 24.80it/s]

Epoch: 30, Train Loss: 1.1159638771112415, Val Loss: 1.5802325343572965
Train Report: 
              precision    recall  f1-score   support

       anger       0.50      0.43      0.46      1423
         joy       0.65      0.47      0.55      2047
        fear       0.71      0.03      0.06       336
     disgust       0.57      0.08      0.15       372
     neutral       0.63      0.86      0.73      5299
    surprise       0.64      0.55      0.59      1656
     sadness       0.53      0.40      0.45      1011

    accuracy                           0.61     12144
   macro avg       0.61      0.40      0.43     12144
weighted avg       0.61      0.61      0.59     12144

Val Report: 
              precision    recall  f1-score   support

       anger       0.26      0.23      0.24       192
         joy       0.49      0.42      0.45       254
        fear       0.00      0.00      0.00        37
     disgust       0.42      0.12      0.19        42
     neutral       0.52      0.6




In [16]:
# Close the W&B run opened by wandb.init; the remaining logged data is uploaded before the run ends.
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Macro train_f1,▁▁▁▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇██
Macro val_f1,▁▂▂▄▄▅▅▆▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇█████
Weighted train_f1,▁▁▁▂▃▃▃▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇██
Weighted val_f1,▁▂▂▄▅▆▆▇▆▇▇▇▇▇█▇████▇█████████
train_accuracy,▁▁▁▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▇▇▇▇▇██
train_loss,█▇▇▇▆▆▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁
val_accuracy,▁▂▃▆▅▇▆▄▇▇▇█▇▇▇▄▆▆▆▆▆▆▆▅▅▆▆▄▅▅
val_loss,█▇▆▄▃▂▂▃▁▁▁▁▂▁▂▂▂▂▃▃▃▃▄▅▅▅▆▆▆█

0,1
Macro train_f1,0.42607
Macro val_f1,0.30559
Weighted train_f1,0.58792
Weighted val_f1,0.4375
train_accuracy,0.61495
train_loss,1.11596
val_accuracy,0.45627
val_loss,1.58023


In [17]:
# Persist the final (last-epoch) weights.
# NOTE(review): in the W&B summary of this run, val_loss bottoms out mid-training
# and rises afterwards. This last-epoch checkpoint is therefore likely not the best
# one - consider tracking and saving the best-validation weights instead.

# Whole-module pickle, kept for existing consumers. Loading it requires
# BertForSentenceClassificationGivenContext to be importable under the same
# module path it was pickled from.
torch.save(model, '/kaggle/working/BERT_add_Utt_Level.pth')
# Portable weights-only copy. Reload with model.load_state_dict(torch.load(path))
# on a freshly constructed model.
torch.save(model.state_dict(), '/kaggle/working/BERT_add_Utt_Level_state_dict.pth')