In [None]:
import os
from functools import partial

import torch
from torch.nn import L1Loss
from torch.utils.data import DataLoader

from core.abs_loss import AbsLoss
from core.recurrent_attention_model import RecurrentAttentionModel
from utils.device_utils import device_collate_fn, to_device_fn

In [None]:
# Evaluation configuration: data/model locations, batching, and target device.
dataset_test_path = '../_datasets/test_valentini_speech_syllables_dataset.pt'
batch_size = 50
# Device selection flags forwarded to utils.device_utils helpers.
use_mps = False
use_cuda = True

model_dir = '../_models/'
os.makedirs(model_dir, exist_ok=True)
# os.path.join works whether or not model_dir ends with a path separator,
# unlike plain string concatenation.
weights_file_name = os.path.join(model_dir, "weights_syllable_counter_model.pth")

In [None]:
# Bind the device flags once so the DataLoader collate function and the
# model-placement helper are always configured for the same device.
_device_flags = {"use_cuda": use_cuda, "use_mps": use_mps}
custom_to_device_fn = partial(to_device_fn, **_device_flags)
custom_collate_fn = partial(device_collate_fn, **_device_flags)

In [None]:
# The saved file is presumably a pickled Dataset object rather than plain tensors,
# so it needs weights_only=False. PyTorch >= 2.6 defaults to weights_only=True
# and would refuse to load it; passing it explicitly keeps the notebook working
# across versions.
# SECURITY: weights_only=False unpickles arbitrary objects - only load trusted files.
test_dataset = torch.load(dataset_test_path, weights_only=False)
# shuffle=False keeps evaluation order deterministic.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)

# Number of samples in the test set (not the number of batches).
test_loader_len = len(test_loader.dataset)

print(f"Finished test data preparation, test loader size: {test_loader_len}")

In [None]:
# NOTE(review): the constructor arguments (256, 64, 2) are hyperparameters whose
# meaning is not visible here; they must match those used when the weights were saved.
model = RecurrentAttentionModel(256, 64, 2)
# Load the checkpoint onto the CPU first so weights saved from a CUDA run also load
# on a CPU-only machine; the model is moved to the target device afterwards.
# weights_only=True is sufficient for a state_dict and avoids arbitrary unpickling.
state_dict = torch.load(weights_file_name, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
# NOTE(review): the return value is ignored, so this relies on to_device_fn moving
# the module in place (nn.Module.to does) - confirm against utils.device_utils.
custom_to_device_fn(model)

print('Model initialized')

In [None]:
# Two criteria evaluated on the same predictions: torch's L1Loss
# (mean absolute error with its default reduction) and the project's AbsLoss.
l1_criterion = L1Loss()
abs_criterion = AbsLoss()

print('Criteria initialized')

In [None]:
# Evaluate the model on the whole test set and report per-sample average losses.
model.eval()

abs_total_loss = 0.0
l1_total_loss = 0.0
# Count samples actually processed so the average uses the same batch sizes
# that weighted the sums (also correct if len(dataset) differs from iteration).
num_samples = 0

with torch.no_grad():
    for inputs, targets in test_loader:
        outputs = model(inputs)
        # assumes dim 0 of `inputs` is the batch dimension - TODO confirm for the model's input layout
        batch_n = inputs.size(0)

        # Criteria presumably return a per-batch mean (L1Loss does by default),
        # so scale back to a per-batch sum before accumulating.
        abs_total_loss += abs_criterion(outputs, targets).item() * batch_n
        l1_total_loss += l1_criterion(outputs, targets).item() * batch_n
        num_samples += batch_n

if num_samples == 0:
    raise RuntimeError("Test loader yielded no samples; cannot compute average losses")

abs_total_loss = abs_total_loss / num_samples
l1_total_loss = l1_total_loss / num_samples

print(f"Abs loss on test data: {abs_total_loss:.3f}")
print(f"L1 loss on test data: {l1_total_loss:.3f}")