In [None]:
import os
from functools import partial

import torch
from torch.nn import MSELoss, HuberLoss
from torch.utils.data import DataLoader

from core.audio_model import AudioModel
from core.sisdr_loss import SISDRLoss
from core.spectro_feature_loss import SpectroFeatureLoss
from utils.device_utils import device_collate_fn, to_device_fn

In [None]:
dataset_test_path = '../_datasets/test_valentini_clean_noisy_dataset.pt'
batch_size = 50

# Device selection. CUDA is enabled only when it is actually available instead
# of being hard-coded to True, which would presumably fail as soon as
# device_utils tries to place tensors on CUDA on a machine without a GPU.
# With both flags False, device_utils' fallback device is used (presumably
# CPU — TODO confirm in utils/device_utils.py).
use_mps = False
use_cuda = torch.cuda.is_available()

# Trained weights live under model_dir; the directory is created up front so
# the path exists even on a fresh checkout.
model_dir = '../_models/'
os.makedirs(model_dir, exist_ok=True)
weights_file_name = os.path.join(model_dir, "weights_speech_denoiser_model.pth")

In [None]:
# Bind the device flags once so batch collation and model placement always
# target the same device configuration.
_device_flags = {"use_cuda": use_cuda, "use_mps": use_mps}
custom_collate_fn = partial(device_collate_fn, **_device_flags)
custom_to_device_fn = partial(to_device_fn, **_device_flags)

In [None]:
# The file holds a whole pickled Dataset object (it is handed straight to
# DataLoader), not just tensors. It must therefore be unpickled with
# weights_only=False: PyTorch >= 2.6 defaults to weights_only=True and would
# refuse it.
# SECURITY: unpickling can execute arbitrary code — only load dataset files
# you produced or otherwise trust.
test_dataset = torch.load(dataset_test_path, weights_only=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)

# Number of samples (not batches) in the test set.
test_loader_len = len(test_loader.dataset)

print(f"Finished test data preparation, test loader size: {test_loader_len}")

In [None]:
model = AudioModel()

# The checkpoint is a state_dict (passed to load_state_dict below), so
# weights_only=True is sufficient and avoids unpickling arbitrary objects.
# map_location='cpu' lets a checkpoint saved on a GPU load on any machine.
# load_state_dict copies the values into the (still CPU-resident) model
# either way, and the model is moved to the target device afterwards.
state_dict = torch.load(weights_file_name, map_location='cpu', weights_only=True)
model.load_state_dict(state_dict)
# nn.Module.to() moves parameters in place, so the return value is not needed
# here — assumes to_device_fn calls .to() on its argument; TODO confirm.
custom_to_device_fn(model)

print('Model initialized')

In [None]:
# Metrics reported on the test set: two stock torch.nn losses plus the two
# project-specific criteria imported from core/.
mse_criterion, huber_criterion = MSELoss(), HuberLoss()
sisdr_criterion, spectro_feature_criterion = SISDRLoss(), SpectroFeatureLoss()

print('Criteria initialized')

In [None]:
model.eval()

# All four metrics are computed on the same (outputs, targets) pair for each
# batch. One name -> criterion mapping replaces four copy-pasted
# accumulate/average/print sequences. Insertion order fixes both the
# evaluation order and the print order.
test_criteria = {
    "MSE": mse_criterion,
    "Huber": huber_criterion,
    "SISDR": sisdr_criterion,
    "Spectro feature": spectro_feature_criterion,
}
loss_sums = dict.fromkeys(test_criteria, 0.0)
samples_seen = 0

with torch.no_grad():
    for inputs, targets in test_loader:
        outputs = model(inputs)
        # assumes inputs is a tensor whose dim 0 is the batch — TODO confirm
        # against device_collate_fn
        batch_len = inputs.size(0)
        samples_seen += batch_len
        for metric_name, criterion in test_criteria.items():
            # Each criterion is assumed to return a per-batch mean. This holds
            # for MSELoss/HuberLoss with the default reduction='mean'; TODO
            # confirm for SISDRLoss and SpectroFeatureLoss. Weighting by batch
            # size makes the final figure a per-sample mean even when the
            # last batch is smaller than batch_size.
            loss_sums[metric_name] += criterion(outputs, targets).item() * batch_len

# Divide by the samples actually evaluated rather than len(dataset). Fail with
# a clear message instead of ZeroDivisionError on an empty test set.
if samples_seen == 0:
    raise ValueError("Test loader yielded no samples; cannot compute test losses")

mean_losses = {name: total / samples_seen for name, total in loss_sums.items()}

# Keep the per-metric globals the cell has always exposed, for later cells.
mse_total_loss = mean_losses["MSE"]
huber_total_loss = mean_losses["Huber"]
sisdr_total_loss = mean_losses["SISDR"]
spectro_feature_total_loss = mean_losses["Spectro feature"]

for metric_name, mean_loss in mean_losses.items():
    print(f"{metric_name} loss on test data: {mean_loss:.3f}")