In [None]:
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from eloguessr import EloGuessr
from utils import load_data, plot_losses

# Seed torch's global RNG so any stochastic step in this notebook is repeatable.
torch.manual_seed(42)

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

In [None]:
# Repository root. Set the ELOGUESSR_ROOT environment variable to run this notebook
# on a machine where the project lives elsewhere; the default reproduces the
# original hard-coded location.
PROJECT_ROOT = os.environ.get('ELOGUESSR_ROOT', '/Home/siv33/vbo084/EloGuessr')
data_dir = os.path.join(PROJECT_ROOT, 'data', 'processed')
model_dir = os.path.join(PROJECT_ROOT, 'models', 'final_models')

In [None]:
# map_location lets a checkpoint that was saved from a GPU load on a CPU-only
# machine (`device` falls back to CPU when CUDA is unavailable).
# NOTE(review): this checkpoint presumably stores a whole pickled model rather than a
# state_dict. PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects
# such files; pass weights_only=False there, and only for trusted checkpoints.
# NOTE(review): `model` looks unused before it is rebound by the evaluation loop;
# consider dropping this load to avoid reading the checkpoint twice.
model = torch.load(os.path.join(model_dir, 'model_both_medium_final_68.pt'), map_location=device)

# Pre-split dataset files, in (train, val, test) order.
fnames = ['chess_train_both_medium.pt', 'chess_val_both_medium.pt', 'chess_test_both_medium.pt']

BATCH_SIZE = 1024
# load_data yields three loaders; only the third (test) split is evaluated here.
_, _, test_dloader = load_data(data_dir, fnames, batch_size=BATCH_SIZE)

In [None]:
def eval(model, dataloader, device, hits=(25, 50, 100, 250, 500)):
    """Evaluate a score-regression model and report threshold accuracies.

    Runs ``model`` over every batch of ``dataloader`` without gradient tracking,
    flattens predictions and targets to one value per sample, and for each
    threshold in ``hits`` computes the percentage of samples whose absolute
    error is at most that threshold. The results are printed and returned.

    NOTE: the name shadows the builtin ``eval``; it is kept because other cells
    call it by this name.

    Args:
        model: callable module producing one score per sample (e.g. shape
            ``(B,)`` or ``(B, 1)``). It is moved to ``device`` and left in eval
            mode.
        dataloader: iterable yielding ``(inputs, targets)`` pairs.
        device: torch device to run inference on.
        hits: absolute-error thresholds to report. Defaults to
            ``(25, 50, 100, 250, 500)``.

    Returns:
        dict mapping ``'Accuracy@<threshold>'`` to a percentage in [0, 100].

    Raises:
        ValueError: if ``dataloader`` yields no batches.
        AssertionError: if the flattened predictions and targets differ in shape.
    """
    pred_scores = []
    true_scores = []
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        for inputs, targets in dataloader:
            scores = model(inputs.to(device))
            # reshape(-1) rather than squeeze(): squeeze() turns a final batch of
            # size 1 into a 0-d tensor, which torch.cat refuses to concatenate.
            # Moving to CPU keeps the full set of predictions off the GPU.
            pred_scores.append(scores.reshape(-1).cpu())
            true_scores.append(targets.reshape(-1).cpu())

    if not pred_scores:
        raise ValueError('dataloader yielded no batches; nothing to evaluate')

    pred_scores = torch.cat(pred_scores, dim=0)
    true_scores = torch.cat(true_scores, dim=0)

    assert pred_scores.shape == true_scores.shape, f"Shape mismatch: {pred_scores.shape} vs {true_scores.shape}"

    # Absolute error is threshold-independent, so compute it once.
    abs_err = torch.abs(pred_scores - true_scores)
    n_samples = abs_err.numel()
    hit_percentages = {}
    for hit in hits:
        hit_percentage = (abs_err <= hit).sum().item() / n_samples * 100
        hit_percentages[f'Accuracy@{hit}'] = hit_percentage

    for key, value in hit_percentages.items():
        print(f'{key}: {value}%')

    return hit_percentages

In [None]:
# Checkpoints to compare on the test split, in evaluation order.
MODEL_FILES = [
    'model_both_medium_final_68.pt',
    'model_both_medium_final_61.pt',
    'eloguessr_earlystop_58epcs.pt',
]
# map_location lets checkpoints saved from a GPU load on a CPU-only machine.
# NOTE(review): these presumably are whole pickled models; on PyTorch >= 2.6
# torch.load needs weights_only=False for that (only for trusted files).
models = [torch.load(os.path.join(model_dir, fname), map_location=device)
          for fname in MODEL_FILES]

In [6]:
# Evaluate every checkpoint on the same test loader. The returned metric dicts are
# kept in `results` (index-aligned with `models`) instead of being discarded, so
# they can be compared or tabulated afterwards.
results = []
for i, model in enumerate(models):
    print(i)
    results.append(eval(model, test_dloader, device))