In [1]:
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from eloguessr import EloGuessr
from utils import load_data, plot_losses

# Seed PyTorch's default random number generator with a fixed value.
torch.manual_seed(42)

# Use the GPU when CUDA reports one as available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [2]:
# Root of the project checkout. It can be overridden with the ELOGUESSR_ROOT
# environment variable so the notebook is not tied to one user's home
# directory; the default reproduces the originally hard-coded location.
PROJECT_ROOT = os.environ.get('ELOGUESSR_ROOT', '/Home/siv33/vbo084/EloGuessr')

# Directory passed to load_data() together with the dataset file names.
data_dir = os.path.join(PROJECT_ROOT, 'data', 'processed')
# Directory the model checkpoint is read from.
model_dir = os.path.join(PROJECT_ROOT, 'models', 'final_models')

In [3]:
# Load the whole pickled model object (not just a state_dict).
# map_location=device remaps the stored tensors onto the device selected for
# this session, so a checkpoint saved from a GPU run also loads on a CPU-only
# machine instead of failing while deserialising CUDA storages.
# NOTE(review): torch.load unpickles arbitrary Python objects - only load
# checkpoints from a trusted source. On PyTorch >= 2.6 the default is
# weights_only=True, which rejects whole-model pickles; pass
# weights_only=False there if this file is trusted.
# NOTE(review): unpickling presumably needs the EloGuessr class to be
# importable, which is why it is imported at the top - confirm.
model = torch.load(
    os.path.join(model_dir, 'model_both_medium_29epcs.pt'),
    map_location=device,
)

# File names handed to load_data(); only the third returned loader is kept.
# NOTE(review): presumably load_data returns (train, val, test) loaders in the
# same order as fnames - confirm against utils.load_data.
fnames = ['chess_train_both_medium.pt', 'chess_val_both_medium.pt', 'chess_test_both_medium.pt']

BATCH_SIZE = 256
_, _, test_dloader = load_data(data_dir, fnames, batch_size=BATCH_SIZE)

In [4]:
def eval(model, dataloader, device=None, hits=(25, 50, 100, 250)):
    """Evaluate a regression model with Hits@k accuracy.

    Runs ``model`` over every batch of ``dataloader`` with gradient tracking
    disabled, gathers the predictions and targets, and prints and returns, for
    each threshold ``k`` in ``hits``, the percentage of samples whose absolute
    prediction error is at most ``k``.

    Note: the function name shadows Python's built-in ``eval``; it is kept so
    existing calls keep working.

    Args:
        model: A ``torch.nn.Module`` producing one score per sample. It is
            switched to evaluation mode and left in that mode.
        dataloader: Iterable yielding ``(inputs, targets)`` tensor pairs.
        device: Device the batches are moved to. When ``None`` (default) the
            device of the model's first parameter is used, or the CPU if the
            model has no parameters.
        hits: Absolute-error thresholds, in the same units as the targets.

    Returns:
        dict mapping ``'Hits@<k>'`` to the hit percentage (0-100) as a float,
        in the order the thresholds were given.

    Raises:
        AssertionError: If the flattened predictions and targets differ in
            shape.
    """
    if device is None:
        # Follow the model rather than a notebook-level global so inputs
        # always end up on the same device as the weights.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    pred_scores = []
    true_scores = []

    model.eval()
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            scores = model(inputs)
            # reshape(-1) rather than squeeze(): squeeze() collapses a batch
            # holding a single sample into a 0-d tensor, which torch.cat
            # refuses to concatenate (e.g. a final partial batch of size 1).
            pred_scores.append(scores.reshape(-1))
            true_scores.append(targets.reshape(-1))

    pred_scores = torch.cat(pred_scores, dim=0)
    true_scores = torch.cat(true_scores, dim=0)
    print(f'Pred scores before denorm: {pred_scores[:5]}')

    assert pred_scores.shape == true_scores.shape, f"Shape mismatch: {pred_scores.shape} vs {true_scores.shape}"

    # The absolute error is the same for every threshold, so compute it once.
    abs_errors = torch.abs(pred_scores - true_scores)
    num_samples = abs_errors.numel()

    hit_percentages = {}
    for hit in hits:
        hit_count = (abs_errors <= hit).sum().item()
        hit_percentages[f'Hits@{hit}'] = hit_count / num_samples * 100

    for key, value in hit_percentages.items():
        print(f'{key}: {value:.2f}%')

    return hit_percentages

In [5]:
# Score the loaded model on the test loader; `evals` maps each 'Hits@k' label
# to the percentage of samples predicted within k of the target.
evals = eval(model=model, dataloader=test_dloader)

  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)


Pred scores before denorm: tensor([1020.9497, 1020.9500, 1020.9498, 1020.9500, 1020.9500],
       device='cuda:0')
Hits@25: 0.04%
Hits@50: 0.10%
Hits@100: 0.29%
Hits@250: 2.63%
