In [1]:
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from eloguessr import EloGuessr
from utils import load_data, plot_losses

torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [2]:
data_dir = 'USE PATH TO YOUR PROCESSED DATA'
model_dir = 'USE PATH TO YOUR MODELS'

In [3]:
fnames = ['chess_train_both_medium.pt', 'chess_val_both_medium.pt', 'chess_test_both_medium.pt']

BATCH_SIZE = 1024
_, _, test_dloader = load_data(data_dir, fnames, batch_size=BATCH_SIZE)

In [8]:
models = [torch.load(os.path.join(model_dir, 'bestmodel_unfrozen_all_60.pt')),
          torch.load(os.path.join(model_dir, 'best_no_emb_60.pt')),
          torch.load(os.path.join(model_dir, 'bigmodelunfrozen_99epcs.pt'))
          ]

In [9]:
def eval(model, dataloader, device):
    pred_scores = []
    true_scores = []
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)

            scores = model(inputs)
            pred_scores.append(scores.squeeze())
            true_scores.append(targets.squeeze())

    pred_scores = torch.cat(pred_scores, dim=0)
    true_scores = torch.cat(true_scores, dim=0)

    assert pred_scores.shape == true_scores.shape, f"Shape mismatch: {pred_scores.shape} vs {true_scores.shape}"

    hits = [25, 50, 100, 250, 500]
    hit_percentages = {}
    for hit in hits:
        hits_x = torch.abs(pred_scores - true_scores) <= hit
        hit_percentage = hits_x.sum().item() / len(hits_x) * 100
        hit_percentages[f'Accuracy@{hit}'] = hit_percentage

    for key, value in hit_percentages.items():
        print(f'{key}: {value:.2f}%')

    # Ranges eval
    ranges = [(1000, 1499), (1500, 1999), (2000, 2499), (2500, float('inf'))]
    range_accuracies = {}

    for start, end in ranges:
        mask = (true_scores >= start) & (true_scores < end)
        range_pred = pred_scores[mask]
        range_true = true_scores[mask]
        
        if len(range_true) == 0:
            print(f"No samples in range {start}-{end}")
            continue

        range_hit_percentages = {}
        for hit in hits:
            hits_x = torch.abs(range_pred - range_true) <= hit
            hit_percentage = hits_x.sum().item() / len(hits_x) * 100
            range_hit_percentages[f'Accuracy@{hit}'] = hit_percentage

        range_accuracies[f'{start}-{end}'] = range_hit_percentages

        print(f"\nAccuracy for range {start}-{end} (n={len(range_true)}):")
        for key, value in range_hit_percentages.items():
            print(f'{key}: {value:.2f}%')

    return hit_percentages, range_accuracies

In [10]:
for i, model in enumerate(models):
    print(i)
    eval(model, test_dloader, device)

0


Accuracy@25: 15.24%
Accuracy@50: 29.56%
Accuracy@100: 51.89%
Accuracy@250: 82.18%
Accuracy@500: 94.48%

Accuracy for range 1000-1499 (n=17311):
Accuracy@25: 7.54%
Accuracy@50: 15.15%
Accuracy@100: 31.42%
Accuracy@250: 75.15%
Accuracy@500: 95.60%

Accuracy for range 1500-1999 (n=34395):
Accuracy@25: 12.91%
Accuracy@50: 25.43%
Accuracy@100: 47.47%
Accuracy@250: 81.00%
Accuracy@500: 93.37%

Accuracy for range 2000-2499 (n=29193):
Accuracy@25: 26.92%
Accuracy@50: 51.45%
Accuracy@100: 75.99%
Accuracy@250: 87.96%
Accuracy@500: 95.11%

Accuracy for range 2500-inf (n=8237):
Accuracy@25: 0.00%
Accuracy@50: 0.00%
Accuracy@100: 27.90%
Accuracy@250: 81.53%
Accuracy@500: 94.57%
1
Accuracy@25: 14.71%
Accuracy@50: 28.78%
Accuracy@100: 52.21%
Accuracy@250: 82.45%
Accuracy@500: 94.33%

Accuracy for range 1000-1499 (n=17311):
Accuracy@25: 7.79%
Accuracy@50: 15.79%
Accuracy@100: 32.22%
Accuracy@250: 75.63%
Accuracy@500: 95.62%

Accuracy for range 1500-1999 (n=34395):
Accuracy@25: 13.46%
Accuracy@50: 25.8