In [None]:
import json
import os
import pickle

import torch
import torch.nn as nn

# The star import stays ahead of the imports below, exactly as before, so any
# name it happens to export is still overridden by the explicit imports.
# `emotion2int` is used later in the notebook without an explicit import;
# presumably it comes from here - TODO confirm and import it by name.
from load_pickel import *
from dataset import LSTM_Dataset_ERC
from torch.utils.data import DataLoader
from model import LSTM_Model_ERC
from torch.optim import AdamW
from tqdm import tqdm
from sklearn.metrics import classification_report
import wandb

In [None]:
# Pick the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device  # last expression of the cell, so the chosen device is displayed

In [None]:
def _read_json(path):
    """Parse the JSON file at ``path`` and return the decoded object.

    The file is opened inside a ``with`` block so the handle is closed as soon
    as parsing finishes (the bare ``json.load(open(...))`` form leaves it open
    until garbage collection), and it is decoded as UTF-8 instead of the
    platform default encoding.

    Args:
        path: location of the JSON file.

    Returns:
        Whatever ``json.load`` produces for the file's contents.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the contents are not valid JSON.
    """
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


# The three dataset splits, read from paths relative to the notebook's
# working directory.
train_data = _read_json("../dataset/train.json")
test_data = _read_json("../dataset/test.json")
val_data = _read_json("../dataset/val.json")

In [None]:
BATCH_SIZE = 4  # items per batch, shared by the training and validation loaders

# Wrap the raw JSON splits in the project's dataset class.
train_dataset = LSTM_Dataset_ERC(train_data)
val_dataset = LSTM_Dataset_ERC(val_data)

# Only the training split is shuffled. The validation loss is computed as a
# mean over per-batch losses, so reshuffling the validation set each epoch
# would change the batch composition - and therefore the reported loss - for
# reasons unrelated to the model, making best-checkpoint selection noisy.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Build the network and place it on the selected device.
# NOTE(review): the three positional arguments (768, 250, 4) are defined by
# LSTM_Model_ERC's constructor, which is not visible here - confirm their
# meaning in model.py before changing them.
model = LSTM_Model_ERC(768, 250, 4)
model = model.to(device)

In [None]:
NUM_EMOTIONS = 7  # emotion class ids run 0..6 (the loss uses 7 as its ignore index)

# Count how often each emotion id occurs across every utterance of the
# training split.
freq_emotion = [0] * NUM_EMOTIONS
for conversation in train_data:
    for utterance in conversation['conversation']:
        freq_emotion[emotion2int[utterance['emotion']]] += 1

# Inverse-frequency class weights, so rare emotions contribute more to the loss
# and the model is not rewarded for always predicting the most frequent class.
# A class that never appears gets weight 0.0 instead of raising
# ZeroDivisionError.
weight_to_labels = [1 / count if count > 0 else 0.0 for count in freq_emotion]
weight_to_labels = torch.tensor(weight_to_labels).to(device)

In [None]:
# Number of full passes over the training data.
num_epochs = 100

# Class-weighted cross-entropy. Targets equal to 7 are skipped, so padded
# positions contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=7, weight=weight_to_labels)

# AdamW over every model parameter with a fixed learning rate.
optimizer = AdamW(params=model.parameters(), lr=1e-4)

In [None]:
# Best-so-far bookkeeping for the training loop: -1 / 1e9 / None mean
# "no epoch has been recorded yet".
best_epoch = -1
best_val_loss = best_train_loss = 1e9
best_val_classification_report = best_train_classification_report = None

In [None]:
# Authenticate with Weights & Biases before a run is created in the next cell.
wandb.login()

In [None]:
# Start the Weights & Biases run. The hyper-parameters in `config` are read
# from the live objects instead of being retyped as literals, so the logged
# values cannot drift from what the notebook actually trains with.
# NOTE: "classificaiton" in the model description is misspelled; the string is
# left unchanged so this run's config still matches earlier runs.
wandb.init(
    project="NLP-Project",
    name="LSTM_ERC",
    config={
        "model": "LSTM with classificaiton head",
        "task": "ERC",
        "optimizer": type(optimizer).__name__,
        "learning_rate": optimizer.param_groups[0]["lr"],
        "batch_size": train_dataloader.batch_size,
        "num_epochs": num_epochs,
    }
)

In [None]:
def _run_epoch(dataloader, training):
    """Run one full pass over ``dataloader`` and collect the loss and labels.

    Relies on the notebook-level ``model``, ``criterion``, ``optimizer`` and
    ``device``.

    Args:
        dataloader: iterable of dict batches holding 'text', 'emotion_labels'
            and 'attention_mask' tensors.
        training: when True the model is put in train mode and the optimizer
            is stepped after every batch; when False the model is put in eval
            mode and gradient tracking is disabled.

    Returns:
        A ``(mean_loss, labels_true, labels_pred)`` tuple: the sum of the
        per-batch losses divided by the number of batches, followed by the
        true and predicted class ids at the positions whose attention mask
        equals 1.
    """
    model.train(training)  # train(False) is the same as eval()
    running_loss = 0.0
    labels_true, labels_pred = [], []

    with torch.set_grad_enabled(training):
        for batch in tqdm(dataloader):
            texts = batch['text'].to(device)
            emotion_indices = batch['emotion_labels'].to(device).view(-1)
            attention_masks = batch['attention_mask'].to(device).view(-1)

            # Flatten the logits to (positions, classes) so they line up with
            # the flattened label vector.
            emotion_logits = model(texts)
            emotion_logits = emotion_logits.view(-1, emotion_logits.size(-1))
            loss = criterion(emotion_logits, emotion_indices)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            running_loss += loss.item()

            # Drop masked-out positions per batch so the metrics only see
            # entries whose attention mask is 1.
            keep = attention_masks == 1
            predicted_emotions = torch.argmax(emotion_logits, dim=1)
            labels_pred.extend(predicted_emotions[keep].tolist())
            labels_true.extend(emotion_indices[keep].tolist())

    return running_loss / len(dataloader), labels_true, labels_pred


# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before the first save.
CHECKPOINT_DIR = "ERC_checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in range(num_epochs):
    train_loss, train_labels_true_filtered, train_labels_pred_filtered = _run_epoch(train_dataloader, training=True)
    val_loss, val_labels_true_filtered, val_labels_pred_filtered = _run_epoch(val_dataloader, training=False)

    # Text reports are printed; the dict form is only used to read the accuracy.
    training_classification_rep = classification_report(train_labels_true_filtered, train_labels_pred_filtered, zero_division=0)
    validation_classification_rep = classification_report(val_labels_true_filtered, val_labels_pred_filtered, zero_division=0)

    training_classification_rep_dict = classification_report(train_labels_true_filtered, train_labels_pred_filtered, output_dict=True, zero_division=0)
    validation_classification_rep_dict = classification_report(val_labels_true_filtered, val_labels_pred_filtered, output_dict=True, zero_division=0)

    wandb.log({
        "Train Loss": train_loss,
        "Validation Loss": val_loss,
        "Train Accuracy": training_classification_rep_dict['accuracy'],
        "Validation Accuracy": validation_classification_rep_dict['accuracy']
    })

    print("Epoch: ", epoch)
    print(f"Train Loss: {train_loss:.4f} | Validation Loss: {val_loss:.4f}")
    print("Training Classification Report: ", training_classification_rep)
    print("Validation Classification Report: ", validation_classification_rep)

    # Keep a checkpoint every time the validation loss improves.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_train_loss = train_loss
        best_epoch = epoch
        best_val_classification_report = validation_classification_rep
        best_train_classification_report = training_classification_rep
        # The whole module object is pickled (not just a state_dict), so the
        # model class must be importable when the checkpoint is loaded.
        torch.save(model, f"{CHECKPOINT_DIR}/epoch_{epoch}.pth")

print("Best Epoch: ", best_epoch)
# The tracked quantity is the validation loss, so it is labelled as a loss.
print("Best Validation Loss: ", best_val_loss)
print("Best Training Loss: ", best_train_loss)
print("Best Validation Classification Report: ", best_val_classification_report)
print("Best Training Classification Report: ", best_train_classification_report)