In [None]:
import os
import sys

# Make the project's source tree importable from this notebook.
# Assumes the notebook runs from a "notebooks" directory that sits next to
# "src" -- TODO confirm against the repository layout.
# Only the LAST occurrence of "notebooks" in the working directory is swapped,
# so a parent directory that also contains "notebooks" is left untouched.
_cwd = os.path.abspath("")
_head, _marker, _tail = _cwd.rpartition("notebooks")
lib_path = _head + "src" + _tail if _marker else _cwd
# Re-running this cell must not keep appending the same entry to sys.path.
if lib_path not in sys.path:
    sys.path.append(lib_path)

import torch
import pickle
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import balanced_accuracy_score, accuracy_score, confusion_matrix
from transformers import BertTokenizer
from data.dataloader import build_train_test_dataset
from tqdm.auto import tqdm
import numpy as np
from models import networks
from transformers import BertTokenizer, RobertaTokenizer


ModuleNotFoundError: No module named 'seaborn'

In [None]:
# Display the device that will be selected for evaluation (GPU when available).
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def eval(opt, checkpoint_path, tokenizer):
    """Evaluate a checkpoint on the test split and plot its confusion matrix.

    Builds the network named by ``opt.model_type``, loads the weights stored
    under the ``"state_dict_backbone"`` key of the checkpoint, runs the test
    loader without gradients, prints balanced accuracy, accuracy and the raw
    confusion matrix, then draws a row-normalised (percentage) heatmap that is
    saved to ``opt.name + '.png'`` and shown.

    NOTE(review): this function shadows the builtin ``eval``; the name is kept
    so existing calls keep working.

    Args:
        opt: Configuration object. Attributes read here: ``model_type``,
            ``num_classes``, ``num_attention_head``, ``dropout``,
            ``text_encoder_type``, ``text_encoder_dim``, ``text_unfreeze``,
            ``audio_encoder_type``, ``audio_encoder_dim``, ``audio_unfreeze``,
            ``audio_norm_type``, ``fusion_head_output_type``, ``data_root``,
            ``batch_size``, ``audio_max_length``, ``text_max_length``, ``name``.
        checkpoint_path: Path to a file loadable with ``torch.load`` that
            contains a ``"state_dict_backbone"`` entry.
        tokenizer: Tokenizer forwarded unchanged to ``build_train_test_dataset``.

    Returns:
        None. Results are printed, plotted and written to disk.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network = getattr(networks, opt.model_type)(
        num_classes=opt.num_classes,
        num_attention_head=opt.num_attention_head,
        dropout=opt.dropout,
        text_encoder_type=opt.text_encoder_type,
        text_encoder_dim=opt.text_encoder_dim,
        text_unfreeze=opt.text_unfreeze,
        audio_encoder_type=opt.audio_encoder_type,
        audio_encoder_dim=opt.audio_encoder_dim,
        audio_unfreeze=opt.audio_unfreeze,
        audio_norm_type=opt.audio_norm_type,
        fusion_head_output_type=opt.fusion_head_output_type,
    )
    network.to(device)

    # Only the test split is needed; the train split is discarded.
    _, test_ds = build_train_test_dataset(
        opt.data_root,
        opt.batch_size,
        tokenizer,
        opt.audio_max_length,
        text_max_length=opt.text_max_length,
        audio_encoder_type=opt.audio_encoder_type,
    )

    # Checkpoints written with save_all_states = True are dicts holding the
    # weights under "state_dict_backbone". For save_all_states = False use:
    # network.load_state_dict(torch.load(checkpoint_path, map_location=device).state_dict())
    print(checkpoint_path)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    network.load_state_dict(checkpoint["state_dict_backbone"])
    network.eval()

    y_actu = []
    y_pred = []

    with torch.no_grad():
        for input_ids, audio, label in tqdm(test_ds):
            input_ids = input_ids.to(device)
            audio = audio.to(device)
            # The first element of the network output is used as class scores.
            output = network(input_ids, audio)[0]
            _, preds = torch.max(output, 1)
            # Record EVERY sample of the batch, not just element 0, so the
            # metrics stay correct when the loader yields more than one
            # sample per batch.
            # Assumes one integer class id per sample in `label` -- TODO confirm.
            y_actu.extend(label.detach().cpu().numpy().reshape(-1).tolist())
            y_pred.extend(preds.detach().cpu().numpy().reshape(-1).tolist())

    balanced_acc = balanced_accuracy_score(y_actu, y_pred)
    acc = accuracy_score(y_actu, y_pred)
    print("Balanced Accuracy: ", balanced_acc)
    print("Accuracy: ", acc)

    cm = confusion_matrix(y_actu, y_pred)
    print(cm)
    # Row-normalise to percentages. A class that is predicted but never occurs
    # in the ground truth has an all-zero row; leave it at 0 instead of
    # dividing by zero and producing NaN cells.
    row_totals = cm.sum(axis=1, keepdims=True)
    cmn = np.divide(
        cm.astype('float'),
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    ) * 100

    # Assumes class ids 0..3 map to these names in this order -- TODO confirm
    # against the dataset's label encoding.
    class_names = ["Anger", "Happiness", "Sadness", "Neutral"]

    fig, ax = plt.subplots(figsize=(8, 5.5))
    sns.heatmap(cmn, cmap='YlOrBr', annot=True, square=True, linecolor='black', linewidths=0.75, ax=ax, fmt='.2f', annot_kws={'size': 16})
    ax.set_xlabel('Predicted', fontsize=18, fontweight='bold')
    ax.xaxis.set_label_position('bottom')
    ax.xaxis.set_ticklabels(class_names, fontsize=16)
    ax.set_ylabel('Ground Truth', fontsize=18, fontweight='bold')
    ax.yaxis.set_ticklabels(class_names, fontsize=16)
    plt.tight_layout()
    plt.savefig(opt.name + '.png', format='png', dpi=1200)
    plt.show()

In [None]:
# NOTE(review): the alias says roberta/wav2vec2 but the imported module is
# configs.bert_vggish; the alias is kept so other cells that use it still work.
from configs.bert_vggish import Config as roberta_wav2vec2_config


def _parse_opt_value(text):
    """Convert one textual value from opt.log into a Python value.

    "True"/"False"/"None" become the matching literals, numeric text becomes
    ``int`` or ``float`` (including negatives and scientific notation such as
    ``1e-05``), and anything else is returned unchanged as a string.
    """
    literals = {"True": True, "False": False, "None": None}
    if text in literals:
        return literals[text]
    # Require an ASCII digit and no underscore so that words like "nan"/"inf"
    # and identifiers like "1_000" stay strings instead of becoming numbers.
    if any(ch in "0123456789" for ch in text) and "_" not in text:
        for cast in (int, float):
            try:
                return cast(text)
            except ValueError:
                continue
    return text


# NOTE(review): machine-specific absolute path; adjust for your environment.
checkpoint_path = "D:/SER_ICIIT_2024/notebooks/checkpoints/bert_vggish_MMSERA/20230927-152624" #check point
opt_path = os.path.join(checkpoint_path,"opt.log")
with open(opt_path, "r") as f:
    # Keep only lines with visible content.
    data = [line for line in f.read().split("\n") if line.strip()]

# Convert the "key: value" lines into a dict.
data_dict = {}
for line in data:
    # Split on the FIRST colon only, so values that themselves contain ":"
    # (e.g. Windows paths such as "D:/...") are not truncated.
    key, sep, value = line.partition(":")
    if not sep:
        print(f"Skipping malformed line in {opt_path}: {line!r}")
        continue
    data_dict[key.strip()] = _parse_opt_value(value.strip())

# Load checkpoint with save_all_states = True
ckpt_path = os.path.join(checkpoint_path,"weights/best_acc/checkpoint_0_0.pt")
# Load checkpoint with save_all_states = False
# ckpt_path = os.path.join(checkpoint_path,"weights/best_acc/checkpoint_0.pt")
opt = roberta_wav2vec2_config()
# Replace the default config with the loaded config
for key, value in data_dict.items():
    setattr(opt, key, value)

# Set dataset path
opt.data_root="data/IEMOCAP/"
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


In [None]:
# Run the evaluation with the loaded config, checkpoint file and tokenizer.
eval(opt=opt, checkpoint_path=ckpt_path, tokenizer=tokenizer)