In [1]:
import torch
import numpy as np
import librosa
from sklearn.metrics.pairwise import euclidean_distances
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn.functional as F

from hparams import create_hparams
from test_data_class import TestTextMelCollate, TestTextMelLoader
from model_utils import load_generator, load_model_checkpoint, parse_test_batch_generator

In [2]:
def calculate_mcd(generated_mels, ground_truth_mels, mfcc_dim=25, exclude_zeroth=True):
    mcd_scores = []
    for mel1, mel2 in tqdm(
        zip(generated_mels, ground_truth_mels), total=len(generated_mels)
    ):
        # Compute the MFCCs
        mfcc1 = librosa.feature.mfcc(
            S=librosa.power_to_db(mel1),
            n_mfcc=mfcc_dim,
            n_fft=2048,
            hop_length=200,
            win_length=800,
            window="hann",
            center=True,
            pad_mode="reflect",
        )
        mfcc2 = librosa.feature.mfcc(
            S=librosa.power_to_db(mel2),
            n_mfcc=mfcc_dim,
            n_fft=2048,
            hop_length=200,
            win_length=800,
            window="hann",
            center=True,
            pad_mode="reflect",
        )

        if exclude_zeroth:
            mfcc1 = mfcc1[4:]
            mfcc2 = mfcc2[4:]

        mcd_value = euclidean_distances(mfcc1, mfcc2)
        mcd_scores.append(mcd_value)

    return np.mean(mcd_scores) / 100

In [3]:
def pad_mel_list(gen_list, truth_list):
    mel_list = gen_list + truth_list
    all_mels = [mel for batch in mel_list for mel in batch]
    gen_mels = [mel for batch in gen_list for mel in batch]
    truth_mels = [mel for batch in truth_list for mel in batch]

    max_length = max(mel.size(1) for mel in all_mels)
    padded_gen = [
        F.pad(mel, (0, max_length - mel.size(1)), "constant", 0) for mel in gen_mels
    ]
    padded_truth = [
        F.pad(mel, (0, max_length - mel.size(1)), "constant", 0) for mel in truth_mels
    ]

    return padded_gen, padded_truth

In [4]:
hparams = create_hparams()
test_dataset = TestTextMelLoader("./data/ESD/test.csv", hparams)
collate_fn = TestTextMelCollate(hparams.n_frames_per_step)

test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    collate_fn=collate_fn,
    shuffle=False,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [5]:
model = load_generator(hparams, device)
model_file = "/home/xzodia/dev/emo-gan/outputs/train_gan/train_6/generator/checkpoint_epoch_100.pth.tar"
model = load_model_checkpoint(model_file, model)

In [6]:
ground_truth_mels = []
generated_mels = []

model.eval()
with torch.no_grad():
    for batch in tqdm(test_loader):
        x, y, z = parse_test_batch_generator(batch)

        (
            text_padded,
            _,
            emotion_vectors,
            _,
            _,
            
        ) = x

        real_mel, _, _ = y

        mel_outputs, mel_outputs_postnet, gate_outputs, alignments = model.inference(
            text_padded, emotion_vectors
        )

        ground_truth_mels.append(real_mel)
        generated_mels.append(mel_outputs_postnet)

 50%|█████     | 80/160 [00:12<00:21,  3.65it/s]



100%|██████████| 160/160 [00:22<00:00,  7.19it/s]


In [7]:
generated_mels_padded, ground_truth_mels_padded = pad_mel_list(
    generated_mels, ground_truth_mels
)

In [8]:
# convert to numpy
ground_truth_mels_np = [mel.cpu().numpy() for mel in ground_truth_mels_padded]
generated_mels_np = [mel.cpu().numpy() for mel in generated_mels_padded]

In [9]:
mcd = calculate_mcd(generated_mels_np, ground_truth_mels_np)
print(f"MCD: {mcd}")

100%|██████████| 160/160 [00:00<00:00, 297.36it/s]

MCD: 6.199351196289062



