# Top-K Evaluation

In [1]:
import os
import torch

import loader as loader
import utility as utility
import transformers as transformers
import v2_trainer as v2

# Run on the GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Device: {device}")

Device: cuda


In [2]:
# Options handed to loader.get_dataloader for the evaluation pass.
loader_params = dict(
    batch_size=1,
    pad_images=False,
    percent_mask=0.0,
    shuffle=True,
    evaluate=False,
    place_central=True,
)

# Single dataloader shared by every evaluation cell below.
dataloader = loader.get_dataloader('dict_traindata.txt', loader_params)

In [3]:
def _problem_of(sample_id):
    """Return the part of ``sample_id`` before its first '-' (the whole id if it has none)."""
    # str.partition never truncates: unlike id[:id.find('-')], an id without a
    # '-' is returned unchanged instead of losing its last character.
    return sample_id.partition('-')[0]


def _ratio(numerator, denominator):
    """Divide, returning NaN instead of raising when the denominator is zero."""
    return numerator / denominator if denominator else float('nan')


def compute_topk(model_paths, k=5):
    """Print top-k retrieval metrics for every checkpoint in ``model_paths``.

    Each checkpoint is loaded from ``v2_trainer/trained_models_v2/`` and run over
    every batch of the module-level ``dataloader``. The first position of the
    model's first output (``cls_logits[:, 0]``) is kept as the sample embedding.
    For every sample, ``utility.top_k_cosine_similarity`` is asked for its
    ``k + 1`` most similar embeddings; hits carrying the query's own id are
    skipped and at most ``k`` remaining neighbours are scored. Two samples
    belong to the same problem when the text before the first '-' of their ids
    matches.

    Printed per model:
      * Total Accuracy -- fraction of samples with at least one same-problem
        neighbour.
      * Avg. Sim -- mean similarity over all scored neighbours.
      * Total Correct Accuracy -- mean number of same-problem neighbours per
        sample.
    A ratio whose denominator is zero (e.g. a one-sample dataset) prints as nan.

    Args:
        model_paths: iterable of checkpoint file names.
        k: number of neighbours scored per sample.

    Returns:
        None; results are printed.
    """
    print(f"Embedding cosine-similarity at top-{k} samples:")
    for model_filename in model_paths:
        ViT = v2.VisionTransformer.load_model(f'v2_trainer/trained_models_v2/{model_filename}', print_statements=False, device=device)
        ViT = ViT.to(device)
        ViT.eval()

        # This will take some time -- one forward pass per batch over the whole dataset.
        ids_list = []
        cls_chunks = []
        with torch.no_grad():
            for ids, u, _, _, _, _ in dataloader:
                cls_logits, _, _ = ViT(u.to(device), save_attn=False, temperature=1)
                ids_list.extend(ids)
                cls_chunks.append(cls_logits[:, 0].cpu())
        cls_tensor = torch.cat(cls_chunks, dim=0)

        # Eval pipeline
        attempts, correct, tot_correct, avg_sim, nums_checked = 0, 0, 0, 0, 0
        # Row idx of cls_tensor belongs to ids_list[idx]; enumerate keeps each
        # sample tied to its own row even if an id were to appear twice.
        for idx, target_id in enumerate(ids_list):
            # Request k + 1 hits because the query itself is normally one of them.
            closest_embeddings, closest_sims = utility.top_k_cosine_similarity(cls_tensor, idx, k+1, largest=True)
            target_problem = _problem_of(target_id)

            attempts += 1
            scored = 0
            correct_for_round = 0
            for i, sim_id_num in enumerate(closest_embeddings):
                if scored == k:
                    # The query was not among the hits (e.g. tied similarities);
                    # stop so that exactly k neighbours are scored.
                    break
                sim_id = ids_list[sim_id_num.item()]
                if sim_id == target_id:
                    continue
                avg_sim += closest_sims[i].item()
                nums_checked += 1
                scored += 1
                if _problem_of(sim_id) == target_problem:
                    correct_for_round += 1
            tot_correct += correct_for_round
            correct += min(correct_for_round, 1)

        print(f"Model: {model_filename:>35} || Total Accuracy: {_ratio(correct, attempts):.3f} | Avg. Sim: {_ratio(avg_sim, nums_checked):.3f} | Total Correct Accuracy {_ratio(tot_correct, attempts):.3f}")

# Checkpoints to score. compute_topk resolves each name under
# 'v2_trainer/trained_models_v2/'. To score every checkpoint in a directory
# instead, build the list from it, e.g.:
# pth_files = [file for file in os.listdir('trained_models/') if file.endswith('.pth')]
pth_files = [
    'vit_12-3-24_100k_v1.pth',
    'vit_12-3-24_100k_v2.pth',
]
compute_topk(model_paths=pth_files, k=5)

Embedding cosine-similarity at top-5 samples:


  checkpoint = torch.load(path, map_location=torch.device(device))


Model:             vit_12-3-24_100k_v1.pth || Total Accuracy: 0.338 | Avg. Sim: 0.969 | Total Correct Accuracy 0.512
Model:             vit_12-3-24_100k_v2.pth || Total Accuracy: 0.288 | Avg. Sim: 0.967 | Total Correct Accuracy 0.411


# Evaluate TorchVision Models for Comparison

In [4]:
from torchvision.models import resnet18
from torchvision.models import resnet50
from torchvision.models import mobilenet_v2
from torchvision.models import squeezenet1_0
from torchvision.models import efficientnet_b0
from torchvision.models import shufflenet_v2_x1_0

import torch.nn.functional as F
import matplotlib.pyplot as plt

In [5]:
# Lookup from integer cell code to display colour (hex string).
COLOR_TO_HEX = {
    # Sentinel code (same colour as the active grid border).
    -1: '#FF6700',  # blaze orange

    # Cell colours 0-9.
    0: '#000000',   # black
    1: '#1E93FF',   # blue
    2: '#F93C31',   # red
    3: '#4FCC30',   # green
    4: '#FFDC00',   # yellow
    5: '#999999',   # grey
    6: '#E53AA3',   # pink
    7: '#FF851B',   # light orange
    8: '#87D8F1',   # cyan
    9: '#921231',   # dark red / maroon

    # Non-cell codes.
    10: '#555555',  # border
    11: '#FF6700',  # active grid border
    12: '#D2B48C',  # image padding
}

def hex_to_rgb(hex_color):
    """ Turn a hex colour such as '#1E93FF' into an (r, g, b) tuple of floats in [0, 1]. """
    digits = hex_color.lstrip('#')
    channels = []
    # Two hex digits per channel, in red, green, blue order.
    for start in range(0, 6, 2):
        channels.append(int(digits[start:start + 2], 16) / 255.0)
    return tuple(channels)

def get_embedding(images, encoder, display=True):
    """Colourise integer-coded grids, resize them to 224x224 and run ``encoder`` on them.

    Only channel 0 of ``images`` is read. Every value in it is truncated to an
    integer, looked up in ``COLOR_TO_HEX`` and replaced by the matching RGB
    triple (floats in [0, 1]). The resulting 3-channel float32 batch is resized
    to 224x224 with nearest-neighbour interpolation and passed to ``encoder``.

    The colour lookup is done with one palette tensor and a single indexing
    operation rather than a Python loop over pixels.

    Args:
        images: 4D tensor (B x N x H x W).
        encoder: module with at least one parameter; all work is done on the
            device of its first parameter.
        display: when True, show the first resized image with matplotlib.

    Returns:
        Whatever ``encoder`` returns for the (B, 3, 224, 224) batch.

    Raises:
        KeyError: if channel 0 holds a value that is not a key of ``COLOR_TO_HEX``.
    """
    assert images.dim() == 4, "Input images must be a 4D tensor with shape (B x N x H x W)."
    device = next(encoder.parameters()).device
    images = images.to(device)  # Move images to the same device as the encoder

    # Integer colour codes of channel 0; .long() truncates toward zero like int().
    codes = images[:, 0].long()

    # Reject unknown codes before indexing: an out-of-range index would either
    # wrap around silently (negative values) or fail with an opaque error.
    unknown = set(torch.unique(codes).tolist()) - set(COLOR_TO_HEX)
    if unknown:
        raise KeyError(min(unknown))

    # One RGB row per code from the smallest to the largest key. Gaps (if any)
    # get a placeholder row that is never read because of the check above.
    lowest, highest = min(COLOR_TO_HEX), max(COLOR_TO_HEX)
    palette = torch.tensor(
        [hex_to_rgb(COLOR_TO_HEX.get(code, '#000000')) for code in range(lowest, highest + 1)],
        dtype=torch.float32,
        device=device,
    )

    # (B, H, W) codes -> (B, H, W, 3) colours -> (B, 3, H, W).
    mapped_images = palette[codes - lowest].permute(0, 3, 1, 2).contiguous()

    resized_images = F.interpolate(mapped_images, size=(224, 224), mode='nearest')
    embeddings = encoder(resized_images)
    if display:
        image_to_display = resized_images[0].permute(1, 2, 0).cpu()  # Move channels to last dimension for display
        plt.imshow(image_to_display)
        plt.axis("off")
        plt.show()
    return embeddings

In [6]:
# torchvision baselines to score, keyed by the name printed in the results.
# Uncomment entries to include them; with the dict empty the loop below is a no-op.
torch_models = {
    # "ResNet18": resnet18,
    # "ResNet50": resnet50,
    # "MobileNet_v2": mobilenet_v2,
}

# Number of neighbours scored per sample.
k = 5

for model_name, model_fn in torch_models.items():
    # NOTE(review): newer torchvision releases prefer the `weights=` argument
    # over `pretrained=True` -- confirm against the installed version.
    model = model_fn(pretrained=True)
    if "ResNet" in model_name:
        model = torch.nn.Sequential(*list(model.children())[:-1])  # Remove FC layer
    elif "MobileNet" in model_name or "EfficientNet" in model_name:
        model = torch.nn.Sequential(model.features, torch.nn.AdaptiveAvgPool2d((1, 1)))  # Use features, add pooling
    # Any other name keeps the full model, so its final output is used as the embedding.

    model.to(device)
    model.eval()

    ids_list = []
    cls_chunks = []
    with torch.no_grad():
        for ids, u, u_masks, v, v_masks, compute_patch in dataloader:
            # get_embedding reads images[b, 0], so the channel axis goes at
            # dim 1: (B, H, W) -> (B, 1, H, W). Identical to unsqueeze(0) when
            # B == 1, and correct for larger batches too.
            u = u.to(device).unsqueeze(1)
            embs = get_embedding(u, model, False)
            # One flat embedding row per sample in the batch.
            embs = embs.reshape(u.shape[0], -1)
            ids_list.extend(ids)
            cls_chunks.append(embs.cpu())
    cls_tensor = torch.cat(cls_chunks, dim=0)

    attempts, correct, avg_sim, nums_checked, tot_correct = 0, 0, 0, 0, 0
    # Row idx of cls_tensor belongs to ids_list[idx].
    for idx, target_id in enumerate(ids_list):
        # Request k + 1 hits because the query itself is normally one of them.
        closest_embeddings, closest_sims = utility.top_k_cosine_similarity(cls_tensor, idx, k+1, largest=True)
        # partition() keeps ids without a '-' intact; slicing with find('-')
        # would drop their last character because find returns -1.
        target_problem = target_id.partition('-')[0]

        attempts += 1
        scored = 0
        correct_for_round = 0
        for i, sim_id_num in enumerate(closest_embeddings):
            if scored == k:
                # The query was not among the hits; stop so exactly k neighbours are scored.
                break
            sim_id = ids_list[sim_id_num.item()]
            if sim_id == target_id:
                continue
            avg_sim += closest_sims[i].item()
            nums_checked += 1
            scored += 1
            if sim_id.partition('-')[0] == target_problem:
                correct_for_round += 1

        tot_correct += correct_for_round
        correct += min(correct_for_round, 1)

    # Ratios with a zero denominator (e.g. a one-sample dataset) print as nan.
    accuracy = correct / attempts if attempts else float('nan')
    mean_sim = avg_sim / nums_checked if nums_checked else float('nan')
    mean_correct = tot_correct / attempts if attempts else float('nan')
    print(f"Model: {model_name:>25} || Total Accuracy: {accuracy:.3f} | Avg. Sim: {mean_sim:.3f} | Total Correct Accuracy {mean_correct:.3f}")