In [None]:
import os
import re
from functools import partial

import torch
from torch.nn import MSELoss, HuberLoss
from torch.utils.data import DataLoader

from core.audio_model import AudioModel
from core.sisdr_loss import SISDRLoss
from core.spectro_feature_loss import SpectroFeatureLoss
from utils.device_utils import device_collate_fn, to_device_fn

In [None]:
# Evaluation settings; everything a reader might want to tune lives in this cell.
batch_size = 50  # number of samples per DataLoader batch

# Device flags bound into device_collate_fn / to_device_fn further down.
use_mps = False
use_cuda = True

dataset_path = '../_datasets/test_valentini_clean_noisy_dataset.pt'
model_dir = "../_models"
model_path = model_dir + "/speech_denoiser/speech_denoiser_model.pth"
checkpoint_dir = model_dir + "/speech_denoiser/checkpoints"

# Regular expression (not a shell glob) tested against each checkpoint file
# name. The dot in front of "pt" is escaped so that it only matches a literal
# "."; unescaped it matched any character (e.g. "checkpoint_v2_excerpt").
# The pattern is intentionally left unanchored at the end.
checkpoint_format = r"checkpoint_v2_.*\.pt"

In [None]:
# Pre-bind the device flags so the rest of the notebook can call the collate
# and to-device helpers without repeating use_cuda / use_mps everywhere.
custom_collate_fn, custom_to_device_fn = (
    partial(device_fn, use_cuda=use_cuda, use_mps=use_mps)
    for device_fn in (device_collate_fn, to_device_fn)
)

In [None]:
def get_file_paths(base_path, file_format):
    """Recursively collect files below ``base_path`` whose name matches a regex.

    Args:
        base_path: Directory to walk; subdirectories are included.
        file_format: Regular expression applied to each file name with
            ``re.match``, i.e. anchored at the start of the name only.

    Returns:
        A list of ``os.path.join(root, file_name)`` paths in natural order:
        digit runs compare by numeric value, so ``..._2.pt`` comes before
        ``..._10.pt``. The list is empty when nothing matches.
    """
    def natural_key(path):
        # re.split with a capturing group alternates text and digit tokens,
        # always starting with a (possibly empty) text token. Odd positions
        # are therefore exactly the digit runs, so keys compare str-to-str
        # and int-to-int position by position.
        tokens = re.split(r"([0-9]+)", path)
        return [int(tok) if pos % 2 else tok for pos, tok in enumerate(tokens)]

    file_paths = [
        os.path.join(root, file_name)
        for root, _dirs, files in os.walk(base_path)
        for file_name in files
        if re.match(file_format, file_name)
    ]

    # os.walk reports entries in an arbitrary, filesystem-dependent order;
    # sort so callers that rely on position get a reproducible sequence.
    file_paths.sort(key=natural_key)
    return file_paths

In [None]:
def load_dataset(dataset_file_path, collate_fn):
    """Load a serialized dataset and wrap it in a non-shuffling DataLoader.

    Args:
        dataset_file_path: Path handed to ``torch.load``.
        collate_fn: Collate callable forwarded to the DataLoader.

    Returns:
        A ``DataLoader`` over the loaded object that uses the notebook-level
        ``batch_size`` and keeps the stored sample order (``shuffle=False``).
    """
    return DataLoader(
        torch.load(dataset_file_path),
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )

def load_model_from_checkpoint(checkpoint_file_path):
    """Rebuild an ``AudioModel`` from the weights stored in a checkpoint file.

    Args:
        checkpoint_file_path: Path to a checkpoint readable by ``torch.load``
            that contains a ``'model_state_dict'`` entry.

    Returns:
        A freshly constructed ``AudioModel`` with the checkpoint weights
        loaded, passed through ``custom_to_device_fn``.
    """
    # Deserialize onto the CPU first; device placement happens at the end.
    weights = torch.load(
        checkpoint_file_path, map_location=torch.device('cpu')
    )['model_state_dict']

    restored = AudioModel()
    restored.load_state_dict(weights)
    return custom_to_device_fn(restored)

def load_model(model_file_path):
    """Load a fully serialized model object and place it on the target device.

    Args:
        model_file_path: Path to a file readable by ``torch.load``.

    Returns:
        The object produced by ``torch.load`` (mapped to the CPU while
        deserializing), passed through ``custom_to_device_fn``.
    """
    return custom_to_device_fn(
        torch.load(model_file_path, map_location=torch.device('cpu'))
    )

In [None]:
def evaluate(model, loader):
    """Compute per-sample average losses of ``model`` over ``loader``.

    Four criteria are applied to every batch, in this order: ``MSELoss``,
    ``HuberLoss``, ``SISDRLoss`` and ``SpectroFeatureLoss``. Each batch loss
    is weighted by the batch size and the sums are divided by the number of
    samples that were actually processed. Counting processed samples (rather
    than using ``len(loader.dataset)``) keeps the average correct when the
    loader does not yield every sample, e.g. with ``drop_last=True`` or a
    sampler, and avoids requiring the dataset to define ``__len__``.

    Args:
        model: Callable module; ``model.eval()`` is invoked before the loop.
        loader: Iterable yielding ``(inputs, targets)`` pairs, where
            ``inputs.size(0)`` is the batch size.

    Returns:
        Tuple ``(mse, huber, sisdr, spectro_feature)`` of Python floats.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    # Insertion order of this dict fixes both the evaluation order and the
    # order of the returned tuple.
    criteria = {
        "mse": MSELoss(),
        "huber": HuberLoss(),
        "sisdr": SISDRLoss(),
        "spectro_feature": SpectroFeatureLoss(transform=custom_to_device_fn),
    }
    totals = dict.fromkeys(criteria, 0.0)
    sample_count = 0

    model.eval()

    with torch.no_grad():
        for inputs, targets in loader:
            batch_len = inputs.size(0)
            outputs = model(inputs)

            for name, criterion in criteria.items():
                # MSELoss / HuberLoss average over the batch by default, so
                # weighting by the batch size restores a per-batch sum.
                # NOTE(review): assumes the project losses (SISDRLoss,
                # SpectroFeatureLoss) also return a batch mean - confirm.
                totals[name] += criterion(outputs, targets).item() * batch_len

            sample_count += batch_len

    return tuple(totals[name] / sample_count for name in criteria)

In [None]:
# Build the evaluation loader once and score every matching checkpoint on it.
loader = load_dataset(dataset_path, custom_collate_fn)

# The position in this list is printed as the epoch number below, so the
# order has to be reproducible. Sort naturally: digit runs compare by numeric
# value, which puts "..._2.pt" before "..._10.pt".
# NOTE(review): assumes checkpoint names sort in epoch order and that there is
# one checkpoint per epoch starting at 1 - confirm against the training code.
checkpoint_files = sorted(
    get_file_paths(checkpoint_dir, checkpoint_format),
    key=lambda path: [
        # re.split with a capturing group puts digit runs at odd positions.
        int(tok) if pos % 2 else tok
        for pos, tok in enumerate(re.split(r"([0-9]+)", path))
    ],
)

for idx, checkpoint_file in enumerate(checkpoint_files):
    model = load_model_from_checkpoint(checkpoint_file)
    mse_loss, huber_loss, sisdr_loss, spectro_feature_loss = evaluate(model, loader)
    print(f"Epoch: {idx + 1:2}, MSE: {mse_loss:5.3f}, Huber: {huber_loss:5.3f}, SISDR: {sisdr_loss:6.3f}, Spectro Feature: {spectro_feature_loss:5.3f}")