In [1]:
import os

import torch
import librosa
import numpy as np
from pesq import pesq
from tqdm import tqdm
from torch.utils.data import DataLoader

from inference_utils import recover_wav
from hparams import create_hparams
from test_data_class import TestTextMelCollate, TestTextMelLoader
from model_utils import load_generator, load_model_checkpoint, parse_test_batch_generator

In [2]:
def evaluate_audio(audio_ref, audio_deg, sr=16000):
    """Evaluate two sets of audio signals using PESQ."""
    audio_ref = audio_ref.astype(np.float32)
    audio_deg = audio_deg.astype(np.float32)
    pesq_score = pesq(sr, audio_ref, audio_deg, "nb")

    rms = np.mean(librosa.feature.rms(y=audio_deg))
    spectral_flatness = np.mean(librosa.feature.spectral_flatness(y=audio_deg))
    zero_crossing_rate = np.mean(librosa.feature.zero_crossing_rate(y=audio_deg))
    score = rms * 0.3 + (1 - spectral_flatness) * 0.4 + (1 - zero_crossing_rate) * 0.3
    min_score = 0
    max_score = 1
    mos = 1 + (score - min_score) / (max_score - min_score) * 4 / 10

    return pesq_score, mos

In [3]:
# Build hyper-parameters and the test-set loader. Batches of one and no shuffling,
# so utterances are evaluated one at a time in dataset order.
hparams = create_hparams()
test_dataset = TestTextMelLoader("./data/ESD/test.csv", hparams)
collate_fn = TestTextMelCollate(hparams.n_frames_per_step)

test_loader = DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=1,
    collate_fn=collate_fn,
)

# Prefer the GPU when one is visible to torch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [4]:
# Generator checkpoint to evaluate. Set the EMO_GAN_CHECKPOINT environment
# variable to point elsewhere, instead of editing a machine-specific path here.
model_file = os.environ.get(
    "EMO_GAN_CHECKPOINT",
    "/home/xzodia/dev/emo-gan/outputs/train_gan/train_6/generator/checkpoint_epoch_100.pth.tar",
)
# Fail fast with a clear message before spending time building the generator.
if not os.path.isfile(model_file):
    raise FileNotFoundError(f"Generator checkpoint not found: {model_file}")

model = load_generator(hparams, device)
model = load_model_checkpoint(model_file, model)

In [5]:
# Per-utterance metrics, averaged in the next cell.
pesq_scores, mos_scores = [], []

model.eval()
with torch.no_grad():
    for batch in tqdm(test_loader):
        x, y, z = parse_test_batch_generator(batch)

        # Free-running inference only consumes the text and emotion inputs.
        text_padded, _, emotion_vectors, _, _ = x
        real_mel, _, _ = y

        mel_outputs, mel_outputs_postnet, gate_outputs, alignments = model.inference(
            text_padded, emotion_vectors
        )

        # Both waveforms go through recover_wav with the ground-truth mel as the
        # first argument (presumably for shared normalisation -- TODO confirm).
        ref_audio = recover_wav(real_mel, real_mel)
        gen_audio = recover_wav(real_mel, mel_outputs_postnet)

        pesq_score, mos_score = evaluate_audio(ref_audio, gen_audio)
        pesq_scores.append(pesq_score)
        mos_scores.append(mos_score)

100%|██████████| 160/160 [01:44<00:00,  1.54it/s]


In [6]:
# Report the mean of each per-utterance metric over the test set.
for label, scores in (("PESQ score:", pesq_scores), ("MOS score:", mos_scores)):
    print(label, np.mean(scores))

PESQ score: 1.3011498846113683
MOS score: 3.9971423468175984
