EmoTalk predictions from the WAV files on the network drive

In [53]:
import os
import numpy as np
import pandas as pd
from pprint import pprint
from pathlib import Path
import sys
import torch
import torch.nn as nn
from torchmetrics.regression import R2Score
# Put tensors on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [11]:
# Input recordings (one sub-folder per user) and the EmoTalk demo script / output folder.
path_to_data = "Z:/entered_NAIST_in_2021/Ryosuke_Miyawaki/blendshapes-analysis-tmp/user_data/"
emotalk = "../EmoTalk_release/demo.py"
result_path = "../EmoTalk_release/result/"
# Select all the user folders. The "result" filter is applied to the folder NAME only,
# so it cannot be triggered by a parent directory in path_to_data.
usr_folders = [
    f.path.replace("\\", "/")
    for f in os.scandir(path_to_data)
    if f.is_dir() and "result" not in f.name
]
# For each user folder, keep only the original recordings: regular files whose NAME ends
# in ".wav" and is not a derived "mono"/"denoise" version. Testing the name (not the full
# path) avoids dropping every file of a user whose folder name happens to contain
# "mono"/"denoise", and avoids matching files such as "x.wav.bak".
wav_files = [
    [
        wav_f.path.replace("\\", "/")
        for wav_f in os.scandir(folder)
        if wav_f.is_file()
        and wav_f.name.endswith(".wav")
        and "mono" not in wav_f.name
        and "denoise" not in wav_f.name
    ]
    for folder in usr_folders
]

In [None]:
# Sort each user's recordings, then flatten them into a single list of paths.
wav_files_sorted = list(map(sorted, wav_files))
flatten_wav_file_list = [path for per_user in wav_files_sorted for path in per_user]
# File names found in the test-set folder.
test_data_files = [entry.name for entry in os.scandir("whole_test_dataset/")]


def _recording_key(wav_path):
    """Return the WAV file name with everything after its last "!" (the date part) removed."""
    return Path(wav_path).name.rsplit("!", 1)[0]


def _in_test_set(wav_path):
    """Return True if the recording key and some test-set file name contain one another."""
    key = _recording_key(wav_path)
    return any(key in name or name in key for name in test_data_files)


# Intersect both lists: keep the recordings that have a counterpart in the test set.
wavs = [wav for wav in flatten_wav_file_list if _in_test_set(wav)]
pprint(wavs)

In [12]:

def run_emotalk(wav_list):
    for wav in wav_list:
        usr_dir = Path(wav).name #split the whole path, get only the name of the wav file
        usr_dir = usr_dir.split("!")[0] #removes the suffix !voice_type!script_number!date.wav
        if not os.path.exists(result_path + usr_dir):
            os.mkdir(result_path + usr_dir)
            print("created folder:", usr_dir)
        r_path = result_path + usr_dir
        print(wav)
        %run "../EmoTalk_release/demo.py" "--wav_path" {wav} "--result_path" {r_path} "--model_path" "../EmoTalk_release/pretrain_model/EmoTalk.pth"


In [71]:
def evaluate_true_positive(prediction, target):
    """Compare which entries of ``prediction`` and ``target`` are non-zero ("activated").

    An entry is a true positive when it is non-zero in both tensors.

    Returns:
        (activation_precision, activation_recall): the true-positive count divided by the
        number of non-zero predictions, and by the number of non-zero targets. When a
        denominator is 0 the raw true-positive count is returned instead, which is then
        necessarily the integer 0.
    """
    predicted_on = prediction != 0
    target_on = target != 0
    hits = torch.logical_and(predicted_on, target_on).count_nonzero().item()
    n_predicted = predicted_on.count_nonzero().item()
    n_target = target_on.count_nonzero().item()
    activation_precision = hits / n_predicted if n_predicted else hits
    activation_recall = hits / n_target if n_target else hits
    return activation_precision, activation_recall

class RMSELoss(nn.Module):
    """Root-mean-square-error loss with a small epsilon added inside the square root.

    Args:
        reduction: how squared errors are combined before the square root:
            "mean" -> sqrt(mean((inp - target) ** 2) + eps), a scalar tensor;
            "sum"  -> sqrt(sum((inp - target) ** 2) + eps), a scalar tensor;
            "none" -> sqrt((inp - target) ** 2 + eps), elementwise, same shape as the
            broadcast inputs.
        eps: added before the square root so that the value and its gradient stay
            finite when the error is exactly zero. Defaults to the previous constant 1e-7.

    Raises:
        ValueError: from ``forward`` if ``reduction`` is not one of the values above
            (previously the "none"/"sum"/unknown cases silently returned ``None``).
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, reduction="mean", eps=1e-7):
        super().__init__()
        self.reduction = reduction
        self.eps = eps

    def forward(self, inp, target):
        squared_error = (inp - target) ** 2
        if self.reduction == "mean":
            squared_error = torch.mean(squared_error)
        elif self.reduction == "sum":
            squared_error = torch.sum(squared_error)
        elif self.reduction != "none":
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {self.reduction!r}"
            )
        return torch.sqrt(squared_error + self.eps)

In [73]:
def results_emotalk_method_gt(
    emotalk_dir=None,
    method_dir="GRUNetPack_RMSELoss_whole/",
    gt_dir="whole_test_dataset",
):
    """Load (EmoTalk, our method, ground truth) tables, trimmed to a common length.

    EmoTalk outputs are loaded with ``np.load`` from the ``.npy`` files inside every
    sub-folder of ``emotalk_dir`` (defaults to the module-level ``result_path``). Our
    method's predictions and the ground truth are loaded with ``pd.read_csv`` from the
    files in ``method_dir`` and ``gt_dir``.

    The three file lists are sorted by name (EmoTalk files by folder, then file) and
    paired positionally, so the i-th file of each list must describe the same recording.
    NOTE(review): pairing is positional, not checked by name. Confirm that the naming
    conventions of the three directories sort identically.

    Returns:
        list of (emo_table, method_table, gt_table) NumPy arrays, each cut to the
        smallest row count of the three so they can be compared element-wise.
    """
    import warnings

    if emotalk_dir is None:
        emotalk_dir = result_path

    def _sorted_entries(directory):
        # os.scandir yields entries in arbitrary, filesystem-dependent order; sort by
        # name so the positional pairing below is deterministic. The context manager
        # closes the directory handle.
        with os.scandir(directory) as it:
            return sorted(it, key=lambda entry: entry.name)

    emotalk_files = [
        f
        for folder in _sorted_entries(emotalk_dir)
        if folder.is_dir()
        for f in _sorted_entries(folder)
        if f.is_file() and f.name.endswith(".npy")
    ]
    method_files = [f for f in _sorted_entries(method_dir) if f.is_file()]
    gt_files = [f for f in _sorted_entries(gt_dir) if f.is_file()]

    if not len(emotalk_files) == len(method_files) == len(gt_files):
        # zip() below stops at the shortest list, so extra files are ignored and the
        # positional pairing is probably wrong.
        warnings.warn(
            f"file counts differ: emotalk={len(emotalk_files)}, "
            f"method={len(method_files)}, gt={len(gt_files)}"
        )

    results = []
    for emo, method, gt in zip(emotalk_files, method_files, gt_files):
        emo_table = np.load(emo)
        method_table = pd.read_csv(method).values
        gt_table = pd.read_csv(gt).values
        # Trim all three to the same number of rows (frames). Previously gt_table was
        # left untrimmed when EmoTalk produced at least as many rows as our method.
        n_rows = min(emo_table.shape[0], method_table.shape[0], gt_table.shape[0])
        results.append((emo_table[:n_rows], method_table[:n_rows], gt_table[:n_rows]))
    return results

def evaluate(criterion, tuple_results):
    """Score EmoTalk and our method against the ground truth with ``criterion``.

    Each (emo, method, gt) triple of NumPy arrays is converted to tensors on the
    module-level ``device``, flattened to 1-D, and scored. The mean of each method's
    scores is printed.

    Returns:
        (emo_loss_array, method_loss_array): per-triple losses as Python scalars
        (``.item()`` of the criterion output).
    """
    emo_loss_array, method_loss_array = [], []
    for emo, method, gt in tuple_results:
        emo_flat, method_flat, gt_flat = (
            torch.from_numpy(arr).to(device).reshape(-1) for arr in (emo, method, gt)
        )
        emo_loss_array.append(criterion(emo_flat, gt_flat).item())
        method_loss_array.append(criterion(method_flat, gt_flat).item())
    print(f"emotalk score: {np.mean(emo_loss_array)}")
    print(f"method score: {np.mean(method_loss_array)}")
    return emo_loss_array, method_loss_array

def evaluate_activation_score(tuple_results):
    """Activation precision/recall of EmoTalk and our method versus the ground truth.

    For every (emo, method, gt) triple, the arrays are moved to the module-level
    ``device`` and compared with ``evaluate_true_positive``. The mean precision and
    recall of each method are printed.

    Returns:
        (act_precision_emotalk, act_recall_emotalk, act_precision_method,
        act_recall_method): per-triple scores, in that order.
    """
    emo_prec, emo_rec, meth_prec, meth_rec = [], [], [], []
    for emo, method, gt in tuple_results:
        emo_t, method_t, gt_t = (torch.from_numpy(arr).to(device) for arr in (emo, method, gt))
        # Score EmoTalk first, then our method, against the same ground-truth tensor.
        for predicted, prec_out, rec_out in (
            (emo_t, emo_prec, emo_rec),
            (method_t, meth_prec, meth_rec),
        ):
            precision, recall = evaluate_true_positive(predicted, gt_t)
            prec_out.append(precision)
            rec_out.append(recall)
    print(f"emotalk\n\t activation precision: {np.mean(emo_prec)}, activation recall: {np.mean(emo_rec)}")
    print(f"method\n\t activation precision: {np.mean(meth_prec)}, activation recall: {np.mean(meth_rec)}")
    return emo_prec, emo_rec, meth_prec, meth_rec


In [18]:
# Load the (EmoTalk, our method, ground truth) table triples once; the scoring cell reuses them.
tuple_results = results_emotalk_method_gt()

In [75]:
# RMSE of each method against the ground truth, then activation precision/recall.
criterion = RMSELoss(reduction="mean")
evaluate(criterion, tuple_results)
# Kept as the cell's final expression so the notebook also displays the returned lists.
evaluate_activation_score(tuple_results)

emotalk score: 0.14712120265607453
method score: 0.031193852197046026
emotalk
	 activation precision: 0.6243331395490344, activation recall: 0.971917282703744
method
	 activation precision: 0.9038539956983865, activation recall: 0.8487175011944117


([0.44936541512427286,
  0.46330056598153113,
  0.42642140468227424,
  0.4760298723199229,
  0.5419092505313742,
  0.49152755641305257,
  0.4431104543775397,
  0.45406278855032317,
  0.4393096751412429,
  0.46377882909520995,
  0.46894431158182437,
  0.4975469508335092,
  0.46662972540327546,
  0.4388160325083118,
  0.41778814846972,
  0.4112531295295823,
  0.3639342070072212,
  0.436728587457923,
  0.5153286168358454,
  0.49795989741198415,
  0.457347723240686,
  0.4609551738583997,
  0.5103283898305084,
  0.46484168865435355,
  0.4508974239747047,
  0.3912079793128925,
  0.4520778364116095,
  0.518127532522926,
  0.5028001191540066,
  0.5179039301310043,
  0.35072020951549543,
  0.4891293797922935,
  0.4238336143901068,
  0.40897269180754225,
  0.4152843601895735,
  0.4282568546668561,
  0.48499133073255307,
  0.3892239879364343,
  0.4058588138723579,
  0.40175763182238666,
  0.3683517640863612,
  0.38728297007757667,
  0.439446227929374,
  0.4138350696134748,
  0.4182526298883803,
 