In [29]:
import math
import random
import pickle
import time
import datetime
import itertools
from tqdm import tqdm

import librosa
import librosa.display

import numpy as np
import pandas as pd
from sklearn import metrics

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split
import torch.optim as optim
from torchmetrics import F1Score, ConfusionMatrix
from torchsummary import summary

from wav_preprocess import *
from wav_classifier import *

# Use the first CUDA GPU when PyTorch reports one as available; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# evaluation functions
def valid_model(model, dataset, batch_size, loss_func, device=device, num_classes=7):
  """Evaluate ``model`` on ``dataset`` and return ``(val_loss, val_f1)``.

  Args:
    model: network invoked as ``model(xb)``; it is put into eval mode here.
    dataset: dataset yielding ``(x, y)`` pairs; it is wrapped in a ``DataLoader``.
    batch_size: batch size of the evaluation loader.
    loss_func: callable ``loss_func(outputs, yb)`` returning a scalar tensor.
    device: device each batch and the F1 metric are moved to.
    num_classes: number of classes given to the multiclass F1 metric.

  Returns:
    ``(val_loss, val_f1)`` as Python floats. ``val_loss`` is the mean of the
    per-batch losses. ``val_f1`` is computed once over every sample seen, so it
    does not depend on how the samples happen to be split into batches.

  Raises:
    ValueError: if the loader yields no batches (empty dataset).
  """
  model.eval()
  # Evaluation order does not matter, so the loader is left unshuffled.
  dl = DataLoader(dataset, batch_size=batch_size, shuffle=False)
  num_batches = len(dl)
  if num_batches == 0:
    raise ValueError('valid_model received an empty dataset')

  f1_func = F1Score(task='multiclass', num_classes=num_classes).to(device)

  val_loss = 0.0
  with torch.no_grad():
    for xb, yb in dl:
      xb, yb = xb.to(device), yb.to(device)
      outputs = model(xb)
      val_loss += loss_func(outputs, yb).item() / num_batches
      # Accumulate statistics only; the score is computed once after the loop.
      f1_func.update(outputs, yb)

  return val_loss, f1_func.compute().item()

def test_model(model, x, y, batch_size, device=device):
  """Run ``model`` over ``x`` in its given order and return the raw outputs.

  Args:
    model: network invoked as ``model(xb)``; it is put into eval mode here.
    x: input tensor, batched along dimension 0.
    y: target tensor; only used to build the ``TensorDataset``, never scored.
    batch_size: batch size of the inference loader.
    device: device each input batch is moved to before the forward pass.

  Returns:
    The per-batch model outputs concatenated along dimension 0, in the same
    order as ``x``. An empty tensor on ``device`` is returned when ``x`` has
    no samples.
  """
  model.eval()
  dl = DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=False)

  with torch.no_grad():
    # Collect the batches and concatenate once, instead of growing a tensor
    # with torch.cat on every iteration (which re-copies all earlier outputs).
    batch_outputs = [model(xb.to(device)) for xb, _ in dl]

  if not batch_outputs:
    return torch.empty(0).to(device)
  return torch.cat(batch_outputs, 0)


# Save the model outputs and report performance (F1 score and confusion matrix).
def save_modelFile(path, x, y, model, annotation, batch_size, fold='train', device=device):
  """Score ``model`` on ``(x, y)`` and optionally pickle its raw outputs.

  Args:
    path: output path without extension; ``'.pkl'`` is appended. When falsy,
      nothing is written and ``annotation`` is not read.
    x: input tensor, batched along dimension 0.
    y: target tensor aligned with ``x``.
    model: network invoked as ``model(xb)``; it is put into eval mode here.
    annotation: object exposing a ``segment_id`` column with one entry per
      row of ``x`` (only read when ``path`` is truthy).
    batch_size: batch size of the inference loader.
    fold: accepted for backward compatibility; not used by this function.
    device: device used for the forward passes and the metrics.

  Returns:
    ``(f1_score, conf_mat)`` as produced by the torchmetrics ``F1Score`` and
    ``ConfusionMatrix`` objects over the whole input.

  Raises:
    ValueError: if ``path`` is given and the number of ``segment_id`` entries
      differs from the number of model outputs.
  """
  class_names = ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']
  f1_func = F1Score(task='multiclass', num_classes=len(class_names)).to(device)
  confmat = ConfusionMatrix(task='multiclass', num_classes=len(class_names)).to(device)

  model.eval()
  batch_outputs = []
  with torch.no_grad():
    dl = DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=False)
    for xb, _ in tqdm(dl):
      batch_outputs.append(model(xb.to(device)))
  # Concatenate once rather than re-copying the growing tensor on every batch.
  all_outputs = torch.cat(batch_outputs) if batch_outputs else torch.empty(0).to(device)

  targets = y.to(device)
  f1_score = f1_func(all_outputs, targets)
  conf_mat = confmat(all_outputs, targets)

  # save
  if path:
    segment_ids = annotation.segment_id.tolist()
    if len(segment_ids) != len(all_outputs):
      raise ValueError(
        f'annotation has {len(segment_ids)} segment ids but the model produced {len(all_outputs)} outputs')
    outputs_csv = pd.DataFrame(all_outputs.cpu().numpy(), columns=class_names)
    # segment_id goes first, followed by the per-class outputs.
    outputs_csv.insert(0, 'segment_id', segment_ids)
    outputs_csv.to_pickle(path+'.pkl')

  return f1_score, conf_mat

### 1. Data Pre-processing

In [2]:
# ---------------------------------------------------------------
# Parameters
#   aug_dict  : augmentation multiplier for each emotion class id
#               (kept in case the classes are augmented up to the
#               size of the neutral class)
#   PATH      : dataset root directory; the audio files are read
#               from PATH + '/wav'
#   save_PATH : directory where the (augmented) datasets are saved
#
# The `annotation` table (emotion labels looked up per session_id,
# saved by the common preprocessing step) is loaded in a later cell.
# ---------------------------------------------------------------

PATH = '../dataset/KEMDy20_v1_1'
save_PATH = '../dataset/KEMDy20_v1_1/new/wav'

aug_dict = {
    0: 90,
    1: 200,
    2: 260,
    3: 10,
    4: 1,
    5: 100,
    6: 90,
}

In [15]:
# Load the annotation table; its fold / emotion / emotion_id columns are used below.
annotation = pd.read_pickle(PATH+'/new/annotation/all_annotation.pkl')

dl = data_load(annotation, data_path=PATH+'/wav')
mfcc_train_aug_x, mfcc_train_aug_y = dl.get_data(method='mfcc', fold='train', aug=True, aug_dict=aug_dict)
mfcc_train_x, mfcc_train_y = dl.get_data(method='mfcc', fold='train', aug=False)

print(mfcc_train_aug_x.shape, mfcc_train_aug_y.shape)
print(mfcc_train_x.shape, mfcc_train_y.shape)

## Draw augmented samples so every other class reaches the Happy count (Happy and Neutral are excluded)
train_anno = annotation.loc[annotation.fold=='train', :]
happy_num = len(train_anno.loc[train_anno.emotion=='happy'])
aug_number_happy = dict() # number of samples to add per emotion id
for e_id in train_anno.emotion_id.unique():
    # ids 3 and 4 are skipped; presumably these are happy and neutral -- TODO confirm against the label mapping
    if e_id in (3, 4):
        continue
    emotion_number = len(train_anno.loc[train_anno.emotion_id==e_id])
    shortfall = happy_num - emotion_number
    # A class that already has at least as many samples as Happy needs nothing;
    # a negative count would make np.random.choice raise.
    if shortfall > 0:
        aug_number_happy[e_id] = shortfall
#print(aug_number_happy)

# Random sampling without replacement from the augmented pool of each class
aug_labels = mfcc_train_aug_y.numpy()
aug_happy_x = []
aug_happy_y = []
for e, extra_count in aug_number_happy.items():
    emotion_index = np.random.choice(np.where(aug_labels==e)[0], extra_count, replace=False)
    # Keep each selection as one chunk; the chunks are concatenated once below.
    aug_happy_x.append(mfcc_train_aug_x[emotion_index])
    aug_happy_y += [e] * extra_count

if aug_happy_x:
    aug_happy_x = torch.cat(aug_happy_x, dim=0)
else:
    # Nothing to add: use an empty slice so the concatenation below still works.
    aug_happy_x = mfcc_train_x[:0]
aug_happy_y = torch.LongTensor(aug_happy_y)

# concat (original + augmentation data)
aug_x_h = torch.cat((mfcc_train_x, aug_happy_x), dim=0)
aug_y_h = torch.cat((mfcc_train_y, aug_happy_y), dim=0)
print(aug_x_h.shape, aug_y_h.shape)

# save (audio preprocessing finished)
# NOTE: the file name carries the sample count of the reference run (14612) and is
# the name the training cell loads, so it is kept even if the count changes.
torch.save((aug_x_h, aug_y_h), save_PATH+'/train_mfcc_happy_14612.pt')
torch.save((mfcc_train_x, mfcc_train_y), save_PATH+'/train_mfcc.pt')

# test data save
mfcc_test_x, mfcc_test_y = dl.get_data(method='mfcc', fold='test')
#print(mfcc_test_x.shape, mfcc_test_y.shape)
torch.save((mfcc_test_x, mfcc_test_y), save_PATH+'/test_mfcc.pt')

torch.Size([14612, 1, 40, 251]) torch.Size([14612])


### 2. Model training

In [18]:
# Training hyper-parameters
epochs, batch_size = 100, 64

# Load the MFCC datasets written by the preprocessing step
train_x, train_y = torch.load(save_PATH+'/train_mfcc_happy_14612.pt')  # augmented training data
# train_origin_x, train_origin_y = torch.load(save_PATH+'/train_mfcc.pt')  # original train set (used for performance checks)
test_x, test_y = torch.load(save_PATH+'/test_mfcc.pt')

# Build the classifier and move it to the selected device
classifier = wavNet(in_channels=64, num_classes=7, num_features=128)
classifier = classifier.to(device)

In [None]:
# z-score normalization using statistics of the training tensor only.
# NOTE: train_x / test_x are overwritten in place, so re-running this cell without
# reloading the data normalizes tensors that are already normalized.
train_mu = train_x.mean()
train_std = train_x.std()
train_x = (train_x - train_mu) / train_std
#train_origin_x = (train_origin_x - train_mu) / train_std
test_x = (test_x - train_mu) / train_std

train_avg = TensorDataset(train_x, train_y)

# 90 / 10 train / validation split. A seeded generator makes the split (and so the
# reported validation scores) reproducible across kernel restarts.
# NOTE(review): train_x is the augmented set; if augmented variants of one recording
# fall on both sides of the split, validation scores may be optimistic -- confirm.
SPLIT_SEED = 42
train_len = int(len(train_x)*0.9)
train, valid = random_split(train_avg, [train_len, len(train_x)-train_len],
                            generator=torch.Generator().manual_seed(SPLIT_SEED))
print('Train set:', len(train), 'Validation set:', len(valid))

train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
# Validation batches do not need shuffling.
valid_loader = DataLoader(valid, batch_size=batch_size, shuffle=False)

criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(classifier.parameters(), lr=0.001)

start = time.time()
for epoch in tqdm(range(epochs)):
  classifier.train()
  running_loss = 0.0
  correct_prediction = 0
  total_prediction = 0

  for x, y in train_loader:
    x, y = x.to(device), y.to(device)

    # reset gradients left over from the previous step
    optimizer.zero_grad()

    outputs = classifier(x)
    loss = criterion(outputs, y)

    # compute gradients of each parameter
    loss.backward()
    # apply the update
    optimizer.step()

    running_loss += loss.item()

    prediction = torch.max(outputs, 1)[1]
    # .item() keeps the running count a plain Python int instead of a device tensor
    correct_prediction += (prediction == y).sum().item()
    total_prediction += prediction.shape[0]

  # Checkpoint once per epoch, after the updates. Saving inside the batch loop wrote
  # the whole model to disk on every step, and because it ran before the step, the
  # file on disk never contained the final update.
  torch.save(classifier, save_PATH+'/wav_classifier.pt')

  num_batches = len(train_loader)
  avg_loss = running_loss / num_batches
  acc = correct_prediction / total_prediction

  # evaluation using validation data (first epoch, then every 10th epoch);
  # valid_model disables gradient tracking itself
  classifier.eval()
  if epoch == 0 or (epoch+1) % 10 == 0:
    val_loss, val_f1 = valid_model(classifier, valid, batch_size, criterion, device)
    print(f'\nEpoch: {epoch+1}, Loss: {avg_loss:.2f}, Accuracy: {acc:.2f}, Val_Loss: {val_loss:.2f}, Val_F1 score: {val_f1:.2f}')
print('Finished Training')

end = time.time()
spen_time = end - start
# drop the fractional seconds from the H:MM:SS string
result = str(datetime.timedelta(seconds=spen_time)).split(".")[0]
print('Time:', result)

# Evaluation on the held-out test set
with torch.no_grad():
  confmat = ConfusionMatrix(task='multiclass', num_classes=7).to(device)
  outputs = test_model(classifier, test_x, test_y, batch_size, device)

  conf = confmat(outputs.to(device), test_y.to(device))
  # classification_report expects (y_true, y_pred). Passing the predictions first
  # swaps precision and recall in the printed report.
  predicted_labels = torch.argmax(outputs, 1).cpu()
  all_metrics = metrics.classification_report(test_y.cpu(), predicted_labels)

  display(conf)
  print(all_metrics)

In [None]:
# Score the trained classifier on both splits; only the test-split outputs are written to disk.

batch_size = 128  # replaces the value of 64 set when the dataset was loaded

# --- train split: path is None, so metrics are computed but no file is written ---
print('--------[Train set]--------')
train_anno = pd.read_pickle(PATH+'/new/annotation/train_origin.pkl')
f1, conf = save_modelFile(None, train_x, train_y, classifier, train_anno, batch_size, fold='train')
print(f1)
display(conf)

# --- test split: raw outputs are pickled to save_PATH + '/wav_result' + '.pkl' ---
print('--------[Test set]--------')
test_anno = pd.read_pickle(PATH+'/new/annotation/test_origin.pkl')
f1, conf = save_modelFile(save_PATH+'/wav_result', test_x, test_y, classifier, test_anno, batch_size, fold='test')
print(f1)
display(conf)