In [None]:
import pickle
from tqdm import tqdm

import pandas as pd
import numpy as np

import torch
from torchmetrics import F1Score, ConfusionMatrix

from scipy.special import softmax
from sklearn import metrics

# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Parameters
#   PATH:    root directory of the dataset
#   weights: weight applied to each individual classifier's output

PATH = './dataset/KEMDy20_v1_1'
weights = [
    3,    # text classifier
    2,    # audio classifier
    1.0,  # sensor classifier
]

In [None]:
# Load the per-modality classifier outputs and fuse them into one prediction.
# NOTE(review): the first column of every table is skipped below; presumably it is a
# sample identifier and all three tables list the samples in the same order -- confirm.


def _load_scores(path):
    """Load a score table, choosing the reader from the file content.

    The audio results carry a ``.pkl`` extension but were read with ``read_csv``.
    Rather than trust the extension, look at the first byte: pickle streams of
    protocol 2 and newer start with 0x80, a byte that cannot begin a UTF-8 text
    file, so it separates real pickles from CSV text.
    """
    with open(path, 'rb') as fh:
        is_pickle = fh.read(1) == b'\x80'
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    return pd.read_pickle(path) if is_pickle else pd.read_csv(path)


text = pd.read_csv(PATH+'/new/text/text_result.csv')
wav = _load_scores(PATH+'/new/wav/wav_result.pkl')
bio = pd.read_pickle(PATH+'/new/sensor/bio_result.pkl')

columns = text.columns

# The text scores were already passed through softmax upstream, so only the audio and
# sensor scores are normalised here. The arrays are written back positionally.
wav.iloc[:, 1:] = softmax(wav.iloc[:, 1:].to_numpy(), axis=1)
bio.iloc[:, 1:] = softmax(bio.iloc[:, 1:].to_numpy(), axis=1)

# Late fusion: scale each modality's class scores by its weight (each weight is applied
# exactly once) and average over the three modalities. The mean divides by 3 rather
# than by the sum of the weights, which does not change the arg-max below.
final_prediction = np.stack((
    text.iloc[:, 1:].to_numpy() * weights[0],
    wav.iloc[:, 1:].to_numpy() * weights[1],
    bio.iloc[:, 1:].to_numpy() * weights[2],
))  # (3, n_samples, n_classes)
final_prediction = final_prediction.mean(axis=0)

# Predicted class index per sample, stored as a float tensor.
final_emotion = torch.Tensor(final_prediction.argmax(axis=1))


# Evaluate the fused prediction: F1 score, confusion matrix, per-class report, accuracy.
NUM_CLASSES = 7

# Micro averaging is the torchmetrics default for the multiclass task; it is spelled
# out because, for single-label data, micro-F1 is the same number as accuracy.
f1 = F1Score(task='multiclass', num_classes=NUM_CLASSES, average='micro')
confmat = ConfusionMatrix(task='multiclass', num_classes=NUM_CLASSES)

# Close the annotation file deterministically instead of leaking the handle.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
with open(PATH+'/new/annotation/test_origin.pkl', 'rb') as fh:
    test_y = pickle.load(fh).emotion_id.to_numpy()

# Class ids are categorical labels, so compare them as integer tensors.
test_y = torch.as_tensor(test_y, dtype=torch.long)
pred_y = final_emotion.long()

# torchmetrics takes (preds, target).
f1_score = f1(pred_y, test_y)
print(f'F1 score: {f1_score:.4f}')
display(confmat(pred_y, test_y))

# scikit-learn takes (y_true, y_pred). Passing them the other way round swaps
# precision with recall and reports predicted counts as "support".
print(metrics.classification_report(test_y.numpy(), pred_y.numpy()))

acc = (pred_y == test_y).sum() / len(pred_y)
print(f'Accuracy: {acc:.2f}')

F1 score: 0.8768


tensor([[   1,    0,    0,    1,   36,    0,    0],
        [   0,    2,    1,    0,   12,    0,    0],
        [   0,    0,    3,    0,    3,    0,    0],
        [   2,    0,    0,   25,  166,    1,    0],
        [   4,    2,    2,   38, 2256,    6,    2],
        [   0,    0,    0,    1,   21,    2,    0],
        [   0,    0,    0,    0,   22,    2,    3]])

              precision    recall  f1-score   support

         0.0       0.03      0.14      0.04         7
         1.0       0.13      0.50      0.21         4
         2.0       0.50      0.50      0.50         6
         3.0       0.13      0.38      0.19        65
         4.0       0.98      0.90      0.93      2516
         5.0       0.08      0.18      0.11        11
         6.0       0.11      0.60      0.19         5

    accuracy                           0.88      2614
   macro avg       0.28      0.46      0.31      2614
weighted avg       0.95      0.88      0.91      2614

Accuracy: 0.88
