In [4]:
%cd ..

/home/wu/repo/gesture-wgan


In [5]:
import numpy as np
import pickle
from tools.takekuchi_dataset_tool.rot_to_pos import rot2pos
from metrics import calculate_kde_score, calculate_mae, calculate_frechet_distance

np.random.seed(1)

In [6]:
def calculate_derivative(x: np.array) -> np.array:
    x_prev = x[:-1]
    x_next = x[1:]
    der = x_next - x_prev
    return der

In [10]:
def evaluate_kde(motions_pred: list[np.array], motions_true: list[np.array], pca=False, dim_pca=10):

    vels_pred = list(map(calculate_derivative, motions_pred))
    vels_true = list(map(calculate_derivative, motions_true))

    accs_pred = list(map(calculate_derivative, vels_pred))
    accs_true = list(map(calculate_derivative, vels_true))

    pose_kde = calculate_kde_score(motions_pred, motions_true, pca=pca, dim_pca=dim_pca)
    vel_kde = calculate_kde_score(vels_pred, vels_true, pca=pca, dim_pca=dim_pca)
    acc_kde = calculate_kde_score(accs_pred, accs_true, pca=pca, dim_pca=dim_pca)

    return pose_kde, vel_kde, acc_kde

In [6]:
# Load data
with open('data/takekuchi/processed/prosody_hip/Y_dev.p', 'rb') as f:
    Y_dev = pickle.load(f)

with open('data/takekuchi/processed/prosody_hip/Y_test.p', 'rb') as f:
    Y_test = pickle.load(f)

In [15]:
"""Mismatch"""

n_samples = 45
ratios = [0, 0.2, 0.5, 1.0]
data_mis = Y_test[:]

pose_scores, vel_scores, acc_scores = [], [], []

for r in ratios:

    n_samples_mis = int(n_samples * r)
    idxs = sorted(np.random.permutation(n_samples_mis))

    for i in idxs:
        data_mis[i] = Y_dev[i]

    pose_score, vel_score, acc_score = evaluate_kde(data_mis, Y_test, pca=True, dim_pca=10)

    pose_scores.append(pose_score)
    vel_scores.append(vel_score)
    acc_scores.append(acc_score)

KDE best bandwidth: 10.0
ll: -38.32777060661347, std: 1.6974812549759781, se: 0.016062653260935908
KDE best bandwidth: 2.3357214690901213
ll: -22.42603373976402, std: 2.840227740103801, se: 0.02693036230513999
KDE best bandwidth: 2.9763514416313175
ll: -24.43620514771498, std: 2.6016891932391335, se: 0.024718649110516255
KDE best bandwidth: 10.0
ll: -41.74997208319298, std: 7.052872806005663, se: 0.06673879316443836
KDE best bandwidth: 2.3357214690901213
ll: -23.274670772130342, std: 5.486596216463335, se: 0.05202259728861316
KDE best bandwidth: 2.9763514416313175
ll: -24.863096486879872, std: 4.037474079363285, se: 0.038360041360794676
KDE best bandwidth: 10.0
ll: -47.29144535598103, std: 18.78099118498652, se: 0.17771774999977916
KDE best bandwidth: 1.8329807108324356
ll: -26.393258259151768, std: 103.98959338065792, se: 0.9860045327220686
KDE best bandwidth: 2.3357214690901213
ll: -27.618803112597273, std: 99.46349723063429, se: 0.9450026904589107
KDE best bandwidth: 10.0
ll: -47.72