In [1]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import sys

lib_path = os.path.abspath("").replace("notebooks", "src")
sys.path.append(lib_path)

import torch
from sklearn import svm
from sklearn.metrics import balanced_accuracy_score, accuracy_score,confusion_matrix, f1_score, precision_score, recall_score
from data.dataloader import build_train_test_dataset
from tqdm.auto import tqdm
from models import networks

  from .autonotebook import tqdm as notebook_tqdm


## Eval scripts

In [2]:
from collections import Counter
def calculate_accuracy(y_pred, y_true):
    class_weights = {cls: 1.0/count for cls, count in Counter(y_true).items()}
    wa = balanced_accuracy_score(y_true, y_pred, sample_weight=[class_weights[cls] for cls in y_true])
    ua = accuracy_score(y_true, y_pred)
    return ua, wa

In [4]:
def eval(cfg, checkpoint_path, all_state_dict=True):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network = getattr(networks, cfg.model_type)(cfg)
    network.to(device)

    # Build dataset
    _, test_ds = build_train_test_dataset(cfg)
    weight = torch.load(checkpoint_path, map_location=torch.device(device))
    if all_state_dict:
        weight = weight['state_dict_network']
    else:
        weight = weight.state_dict()
    
    network.load_state_dict(weight)
    network.eval()
    network.to(device)

    y_actu=[]
    y_pred=[]

    for every_test_list in tqdm(test_ds):
        input_ids, audio, label = every_test_list
        input_ids = input_ids.to(device)
        audio = audio.to(device)
        label = label.to(device)
        with torch.no_grad():
            output = network(input_ids,audio)[0]
            _, preds = torch.max(output, 1)
            y_actu.append(label.detach().cpu().numpy()[0])
            y_pred.append(preds.detach().cpu().numpy()[0])
    bacc = balanced_accuracy_score(y_actu, y_pred)
    print("Balanced Accuracy: ", bacc)
    ua, wa = calculate_accuracy(y_actu, y_pred)
    print("Unweighted Accuracy: ", ua)
    print("Weighted Accuracy: ", wa)
    ua_f1 = f1_score(y_actu, y_pred, average='macro')
    # mean_f1 = np.mean(f1_score(y_actu, y_pred, average=None))
    # w_f1 = f1_score(y_actu, y_pred, average='weighted')
    # f1 = f1_score(y_actu, y_pred, average='micro')
    # print("Micro F1: ", f1)
    print("Macro F1: ", ua_f1)
    # print("Weighted F1: ", w_f1)
    # print("Mean F1:", mean_f1)
    # ua_precision = precision_score(y_actu, y_pred, average='macro')
    # w_precision = precision_score(y_actu, y_pred, average='weighted')
    # precision = precision_score(y_actu, y_pred, average='micro')
    # mean_precision = np.mean(precision_score(y_actu, y_pred, average=None))
    # print("Micro Precision: ", precision)
    # print("Macro Precision: ", ua_precision)
    # print("Weighted Precision: ", w_precision)
    # print("Mean precision:", mean_precision)
    # ua_recall = recall_score(y_actu, y_pred, average='macro')
    # w_recall = recall_score(y_actu, y_pred, average='weighted')
    # recall = recall_score(y_actu, y_pred, average='micro')
    # print("Micro Recall: ", recall)
    # print("Macro Recall: ", ua_recall)
    # print("Weighted Recall: ", w_recall)
    
    # cm = confusion_matrix(y_actu, y_pred)
    # print("Confusion Matrix: \n", cm)
    # cmn = (cm.astype('float') / cm.sum(axis=1)[:, np.newaxis])*100

    # ax = plt.subplots(figsize=(8, 5.5))[1]
    # sns.heatmap(cmn, cmap='YlOrBr', annot=True, square=True, linecolor='black', linewidths=0.75, ax = ax, fmt = '.2f', annot_kws={'size': 16})
    # ax.set_xlabel('Predicted', fontsize=18, fontweight='bold')
    # ax.xaxis.set_label_position('bottom')
    # ax.xaxis.set_ticklabels(["Anger", "Happiness", "Sadness", "Neutral"], fontsize=16)
    # ax.set_ylabel('Ground Truth', fontsize=18, fontweight='bold')
    # ax.yaxis.set_ticklabels(["Anger", "Happiness", "Sadness", "Neutral"], fontsize=16)
    # plt.tight_layout()
    # # plt.savefig(cfg.name + '.png', format='png', dpi=1200)
    # plt.show()

In [5]:
def eval_svm(cfg, checkpoint_path, all_state_dict=True):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network = getattr(networks, cfg.model_type)(cfg)
    network.to(device)

    # Build dataset
    train_ds, test_ds = build_train_test_dataset(cfg)
    weight = torch.load(checkpoint_path, map_location=torch.device(device))
    if all_state_dict:
        weight = weight['state_dict_network']
    else:
        weight = weight.state_dict()
    
    network.load_state_dict(weight)
    network.eval()
    network.to(device)
    
    network2 = getattr(networks, cfg.model_type)(cfg)
    network2.to(device)

    # Get train features
    train_x = []
    train_y = []
    for every_train_list in tqdm(train_ds):
        input_ids, audio, label = every_train_list
        input_ids = input_ids.to(device)
        audio = audio.to(device)
        label = label.to(device)
        with torch.no_grad():
            feature = network(input_ids,audio)[1]
            train_x.append(feature.detach().cpu().numpy()[0])
            train_y.append(label.detach().cpu().numpy()[0])
    
    # SVM
    clf = svm.SVC()
    clf.fit(train_x, train_y)
    
    y_actu=[]
    y_pred=[]

    for every_test_list in tqdm(test_ds):
        input_ids, audio, label = every_test_list
        input_ids = input_ids.to(device)
        audio = audio.to(device)
        label = label.to(device)
        with torch.no_grad():
            feature = network(input_ids,audio)[1]
            preds = clf.predict(feature.detach().cpu().numpy())
            y_actu.append(label.detach().cpu().numpy()[0])
            y_pred.append(preds[0])
    bacc = balanced_accuracy_score(y_actu, y_pred)
    ua, wa = calculate_accuracy(y_actu, y_pred)
    print("Balanced Accuracy: ", bacc)
    print("Unweighted Accuracy: ", ua)
    print("Weighted Accuracy: ", wa)
    
    ua_f1 = f1_score(y_actu, y_pred, average='macro')
    # w_f1 = f1_score(y_actu, y_pred, average='weighted')
    # f1 = f1_score(y_actu, y_pred, average='micro')
    # print("Micro F1: ", f1)
    print("Macro F1: ", ua_f1)
    # print("Weighted F1: ", w_f1)
    # ua_precision = precision_score(y_actu, y_pred, average='macro')
    # w_precision = precision_score(y_actu, y_pred, average='weighted')
    # precision = precision_score(y_actu, y_pred, average='micro')
    # print("Micro Precision: ", precision)
    # print("Macro Precision: ", ua_precision)
    # print("Weighted Precision: ", w_precision)
    # ua_recall = recall_score(y_actu, y_pred, average='macro')
    # w_recall = recall_score(y_actu, y_pred, average='weighted')
    # recall = recall_score(y_actu, y_pred, average='micro')
    # print("Micro Recall: ", recall)
    # print("Macro Recall: ", ua_recall)
    # print("Weighted Recall: ", w_recall)
    
    # cm = confusion_matrix(y_actu, y_pred)
    # print("Confusion Matrix: \n", cm)
    # cmn = (cm.astype('float') / cm.sum(axis=1)[:, np.newaxis])*100

    # ax = plt.subplots(figsize=(8, 5.5))[1]
    # sns.heatmap(cmn, cmap='YlOrBr', annot=True, square=True, linecolor='black', linewidths=0.75, ax = ax, fmt = '.2f', annot_kws={'size': 16})
    # ax.set_xlabel('Predicted', fontsize=18, fontweight='bold')
    # ax.xaxis.set_label_position('bottom')
    # ax.xaxis.set_ticklabels(["Anger", "Happiness", "Sadness", "Neutral"], fontsize=16)
    # ax.set_ylabel('Ground Truth', fontsize=18, fontweight='bold')
    # ax.yaxis.set_ticklabels(["Anger", "Happiness", "Sadness", "Neutral"], fontsize=16)
    # plt.tight_layout()
    # plt.savefig(cfg.name + '.png', format='png', dpi=1200)
    # plt.show()

## Eval

In [None]:
for p1, p2 = zip(model1.parameters(), model2.parameters()):
    if p1.data.ne(p2.data).sum() > 0:
        return False
return True

In [6]:
from configs.base import Config
checkpoint_path = "/home/namphuongtran9196/Code/EmotionClassification/code/3m-ser-private/scripts/checkpoints_latest/IEMOCAP/3M-SER/CrossEntropyLoss_bert_vggish_cls/20240129-160154"
cfg_path = os.path.join(checkpoint_path,"cfg.log")
ckpt_path = os.path.join(checkpoint_path,"weights/best_acc/checkpoint_0_0.pt")

cfg = Config()

cfg.load(cfg_path)
# Set dataset path
cfg.data_root="/home/namphuongtran9196/Code/EmotionClassification/code/3m-ser-private/scripts/data/IEMOCAP_preprocessed"
# Change to test set
cfg.data_valid="test.pkl"

eval(cfg, ckpt_path)

FileNotFoundError: [Errno 2] No such file or directory: '/home/namphuongtran9196/Code/EmotionClassification/code/3m-ser-private/scripts/checkpoints/3M-SER/CrossEntropyLoss_bert_hubert_base_cls/20240127-112454/cfg.log'

In [18]:
from configs.base import Config
checkpoint_path = "/home/namphuongtran9196/Code/EmotionClassification/code/3m-ser-private/scripts/checkpoints/3M-SER/CrossEntropyLoss_bert_vggish_cls/20240127-112233"
cfg_path = os.path.join(checkpoint_path,"cfg.log")
ckpt_path = os.path.join(checkpoint_path,"weights/best_acc/checkpoint_0_0.pt")

cfg = Config()

cfg.load(cfg_path)
# Set dataset path
cfg.data_root="/home/namphuongtran9196/Code/EmotionClassification/code/3m-ser-private/scripts/data/IEMOCAP_preprocessed"
# Change to test set
cfg.data_valid="test.pkl"

eval(cfg, ckpt_path)

100%|██████████| 554/554 [00:08<00:00, 67.65it/s]

Balanced Accuracy:  0.7702337657537457
Unweighted Accuracy:  0.7671480144404332
Weighted Accuracy:  0.7787450643734654
Macro F1:  0.7701193146838053





In [21]:
from configs.base import Config
checkpoint_path = "/home/namphuongtran9196/Code/EmotionClassification/code/3m-ser-private/scripts/checkpoints/3M-SER/CrossEntropyLoss_bert_wav2vec2_base_cls/20240127-112408"
opt_path = os.path.join(checkpoint_path,"cfg.log")
ckpt_path = os.path.join(checkpoint_path,"weights/best_acc/checkpoint_0_0.pt")

cfg = Config()

cfg.load(opt_path)
# Set dataset path
cfg.data_root="/home/namphuongtran9196/Code/EmotionClassification/code/3m-ser-private/scripts/data/IEMOCAP_preprocessed"
# Change to test set
cfg.data_valid="test.pkl"

eval(cfg, ckpt_path)

100%|██████████| 554/554 [00:12<00:00, 44.47it/s]

Balanced Accuracy:  0.7476584744596645
Unweighted Accuracy:  0.7364620938628159
Weighted Accuracy:  0.7503535596999569
Macro F1:  0.7435537095207938





In [22]:
from configs.base import Config
checkpoint_path = "/home/namphuongtran9196/Code/EmotionClassification/code/3m-ser-private/scripts/checkpoints/3M-SER/CrossEntropyLoss_bert_wavlm_base_cls/20240127-112336"
opt_path = os.path.join(checkpoint_path,"cfg.log")
ckpt_path = os.path.join(checkpoint_path,"weights/best_acc/checkpoint_0_0.pt")

cfg = Config()

cfg.load(opt_path)
# Set dataset path
cfg.data_root="/home/namphuongtran9196/Code/EmotionClassification/code/3m-ser-private/scripts/data/IEMOCAP_preprocessed"
# Change to test set
cfg.data_valid="test.pkl"

eval(cfg, ckpt_path)

100%|██████████| 554/554 [00:13<00:00, 41.96it/s]

Balanced Accuracy:  0.7630128329663103
Unweighted Accuracy:  0.7545126353790613
Weighted Accuracy:  0.7696427097256077
Macro F1:  0.764027027268112



