# Top-K Evaluation

In [1]:
import os
import torch

import loader as loader
import utility as utility
import transformers as transformers

# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Device: {device}")

Device: cuda


In [2]:
# Options handed to loader.get_dataloader together with the dataset file.
# NOTE(review): the meaning of each key is defined by the project's `loader`
# module, which is not visible here -- confirm there before changing values.
loader_params = dict(
    batch_size=1,
    pad_images=False,
    percent_mask=0.0,
    shuffle=True,
    place_central=True,
)

# Build the dataloader over 'dict_traindata.txt'; it is iterated once per checkpoint below.
dataloader = loader.get_dataloader('dict_traindata.txt', loader_params)

In [3]:
# Evaluate every checkpoint in trained_models/: embed the whole dataset with the
# model, then check for each sample whether any of its k most cosine-similar
# neighbours shares the sample's problem prefix (the text before the first '-').
#
# os.listdir returns entries in arbitrary order, so sort to make the report
# order reproducible across machines and runs.
pth_files = sorted(file for file in os.listdir('trained_models/') if file.endswith('.pth'))
k = 5
print(f"Embedding cosine-similarity at top-{k} samples:")


def _problem_of(sample_id):
    """Return the part of ``sample_id`` that precedes its first '-'.

    An id without any '-' is returned unchanged.  (Slicing with the result of
    ``str.find`` would use -1 in that case and silently drop the last character.)
    """
    return sample_id.partition('-')[0]


def _embed_dataset(model, batches, target_device):
    """Run ``model`` over every batch and collect ids plus first-token outputs.

    Args:
        model: callable invoked as ``model(u, save_attn=False, temperature=1)``
            whose first return value is indexed with ``[:, 0]``.
        batches: iterable yielding ``(ids, u, u_masks, v, v_masks)`` tuples.
        target_device: device the input ``u`` is moved to before the forward pass.

    Returns:
        ``(ids_list, cls_tensor)`` where ``cls_tensor`` is the concatenation
        (dim 0) of the per-batch outputs moved to the CPU, or ``None`` when
        ``batches`` yielded nothing.
    """
    collected_ids = []
    cls_chunks = []

    model.eval()
    with torch.no_grad():
        for ids, u, _u_masks, _v, _v_masks in batches:
            cls_logits, _, _ = model(u.to(target_device), save_attn=False, temperature=1)
            collected_ids.extend(ids)
            # Keep only index 0 along dim 1 and move it off the accelerator right
            # away so device memory does not grow with the dataset.
            cls_chunks.append(cls_logits[:, 0].cpu())

    if not cls_chunks:
        # torch.cat refuses an empty list; let the caller decide what to do.
        return collected_ids, None
    return collected_ids, torch.cat(cls_chunks, dim=0)


def _score_top_k(sample_ids, embeddings, top_k):
    """Accumulate the top-k retrieval counters over every sample.

    For each sample the ``top_k`` nearest neighbours whose id differs from the
    query's id are inspected; the sample counts as correct when at least one of
    them has the same problem prefix.

    Returns:
        ``(attempts, correct, sim_total, nums_checked)`` -- number of queries,
        number of queries with at least one same-problem neighbour, sum of the
        inspected similarities, and number of inspected neighbours.
    """
    attempts, correct, sim_total, nums_checked = 0, 0, 0.0, 0

    # One extra neighbour is requested because the query normally retrieves
    # itself.  Never request more rows than exist.
    # NOTE(review): assumes utility.top_k_cosine_similarity rejects a count
    # larger than the number of rows (torch.topk-like) -- confirm.
    n_requested = min(top_k + 1, len(sample_ids))

    # Enumerate instead of looking indices up by id, so that duplicated ids are
    # each queried at their own row rather than all at the last occurrence.
    for idx, target_id in enumerate(sample_ids):
        closest_embeddings, closest_sims = utility.top_k_cosine_similarity(embeddings, idx, n_requested, largest=True)
        target_problem = _problem_of(target_id)

        attempts += 1
        inspected = 0
        found_same_problem = False
        for i, sim_id_num in enumerate(closest_embeddings):
            if inspected == top_k:
                # The query itself was not among the results (e.g. tied
                # similarities); still score exactly top_k neighbours.
                break
            sim_id = sample_ids[sim_id_num.item()]
            if sim_id == target_id:
                continue
            inspected += 1
            sim_total += closest_sims[i].item()
            nums_checked += 1
            if _problem_of(sim_id) == target_problem:
                found_same_problem = True
        correct += int(found_same_problem)

    return attempts, correct, sim_total, nums_checked


for model_filename in pth_files:
    ViT = transformers.VisionTransformer.load_model(f'trained_models/{model_filename}', print_statements=False)
    ViT = ViT.to(device)

    # This will take some time -- it is generating CLS token embeddings for all images in the dataset
    ids_list, cls_tensor = _embed_dataset(ViT, dataloader, device)
    if cls_tensor is None:
        print(f"Model: {model_filename:>35} || dataloader produced no samples; skipped")
        continue

    # Index <-> id lookup tables; not needed by the scoring below but still
    # assigned in case other cells read them.
    itos = dict(enumerate(ids_list))
    stoi = {sample_id: index for index, sample_id in enumerate(ids_list)}

    # Eval pipeline
    attempts, correct, avg_sim, nums_checked = _score_top_k(ids_list, cls_tensor, k)

    # Guard the ratios so a one-sample dataset reports nan instead of raising.
    accuracy = correct / attempts if attempts else float('nan')
    mean_sim = avg_sim / nums_checked if nums_checked else float('nan')
    print(f"Model: {model_filename:>35} || Total Accuracy: {accuracy:.3f} | Avg. Sim: {mean_sim:.3f}")

Embedding cosine-similarity at top-5 samples:
Model:       vit_20241108_sinusoid100k.pth || Total Accuracy: 0.477 | Avg. Sim: 0.993
Model:        vit_20241110_sinusoid75k.pth || Total Accuracy: 0.495 | Avg. Sim: 0.998
Model:             vit_20241117_231820.pth || Total Accuracy: 0.533 | Avg. Sim: 0.994
Model:            vit_20241117_rope84k.pth || Total Accuracy: 0.458 | Avg. Sim: 1.000
Model:        vit_20241117_sinusoid13k.pth || Total Accuracy: 0.477 | Avg. Sim: 0.995
Model:    vit_20241117_sinusoid500k_v1.pth || Total Accuracy: 0.494 | Avg. Sim: 0.999
Model:    vit_20241117_sinusoid500k_v2.pth || Total Accuracy: 0.527 | Avg. Sim: 0.999
Model:     vit_20241117_sinusoid60k_vF.pth || Total Accuracy: 0.544 | Avg. Sim: 0.997
Model:    vit_20241118_sinusoid125k_v1.pth || Total Accuracy: 0.562 | Avg. Sim: 0.997
Model:    vit_20241118_sinusoid125k_vF.pth || Total Accuracy: 0.566 | Avg. Sim: 0.998
