In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os
from torchvision import transforms
import pandas as pd
from model import Model

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class TestSet(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        row = self.dataframe.iloc[index]
        image_name = row["img_filename"]
        img = Image.open(
            os.path.join(
                "screen_spot_images",
                image_name,
            )
        )
        img = img.convert("RGB")
        img = self.transform(img)
        instruction = row["instruction"]
        label = row["label"]
        return img, instruction, label
    

In [3]:
df = pd.read_csv("seeclick_web_test.csv")
dataset = TestSet(df)
dataloader =  DataLoader(dataset, batch_size=1, shuffle=True)

df.head(10)

Unnamed: 0,img_filename,bbox,instruction,data_type,data_source,normalized_bbox,label
0,web_213f816e-8e80-4d13-970d-1347bbc7a2a8.png,"[2321, 129, 208, 70]",create a new project,text,gitlab,"(0.906640625, 0.08958333333333333, 0.987890625...",1194
1,web_213f816e-8e80-4d13-970d-1347bbc7a2a8.png,"[2401, 14, 111, 68]",view my account,icon,gitlab,"(0.937890625, 0.009722222222222222, 0.98125, 0...",395
2,web_e40f1b3f-0f26-4313-a6a2-d79e1047951b.png,"[194, 15, 645, 66]",search in gitlab,text,gitlab,"(0.07578125, 0.010416666666666666, 0.327734375...",320
3,web_e40f1b3f-0f26-4313-a6a2-d79e1047951b.png,"[1753, 8, 112, 77]",add a new one,icon,gitlab,"(0.684765625, 0.005555555555555556, 0.72851562...",370
4,web_fd8d71f6-4229-4458-a77e-7d8a6347c8e9.png,"[2044, 96, 481, 187]",go to personal homepage,icon,gitlab,"(0.7984375, 0.06666666666666667, 0.986328125, ...",1389
5,web_fd8d71f6-4229-4458-a77e-7d8a6347c8e9.png,"[2043, 492, 483, 89]",sign out,text,gitlab,"(0.798046875, 0.3416666666666667, 0.98671875, ...",3789
6,web_4e1d5837-4731-43f3-8101-52375498c4ad.png,"[427, 234, 150, 96]",switch to explore projects,text,gitlab,"(0.166796875, 0.1625, 0.225390625, 0.229166666...",1919
7,web_4e1d5837-4731-43f3-8101-52375498c4ad.png,"[1601, 350, 116, 66]",star the project with 56 stars,icon,gitlab,"(0.625390625, 0.24305555555555555, 0.670703125...",2664
8,web_4e1d5837-4731-43f3-8101-52375498c4ad.png,"[1704, 678, 83, 59]",fork the a11y project,icon,gitlab,"(0.665625, 0.4708333333333333, 0.698046875, 0....",4968
9,web_bcce7aec-b36a-42c5-8beb-ead23f5ada2c.png,"[197, 232, 1279, 68]",view issues i've created,text,gitlab,"(0.076953125, 0.16111111111111112, 0.5765625, ...",1832


In [10]:
for images, instructions, labels in dataloader:
    print(images)
    print(instructions)
    print(labels)
    breakpoint()

tensor([[[[2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          [2.2489, 2.2489, 2.2489,  ..., 2.1633, 2.2489, 2.2489],
          ...,
          [2.2489, 2.2489, 2.2489,  ..., 0.7077, 0.6734, 1.8722],
          [2.0777, 2.2489, 2.2489,  ..., 1.3070, 1.0673, 1.7523],
          [0.8789, 2.2318, 2.2489,  ..., 1.9749, 1.9578, 0.7762]],

         [[2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          [2.4286, 2.4286, 2.4286,  ..., 2.3410, 2.4286, 2.4286],
          ...,
          [2.4286, 2.4286, 2.4286,  ..., 1.4132, 1.3957, 2.2010],
          [2.2710, 2.4286, 2.4286,  ..., 1.8158, 1.6408, 2.0609],
          [1.0805, 2.4111, 2.4286,  ..., 2.2710, 2.2535, 1.0805]],

         [[2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400],
          [2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400],
          [2.6400, 2.6400, 2.6400,  ..., 2

KeyboardInterrupt: 

In [21]:
def evaluate(model, test_loader, device):
    print(device)
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():  # Disable gradient computation
        for images, instructions, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            with torch.no_grad():
                outputs = model(images, instructions)
            # print('outputs', outputs)
            # print('outputs.data', outputs.data)
            _, predicted = torch.max(outputs.data, 1)
            print(predicted)
            # print(predicted)
            # print('predicted', predicted)
            # print('labels: ', labels)
            total += labels.size(0)
            # print('labels.size(0)', labels.size(0))
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    print(f'Accuracy of the model on the test images: {accuracy:.2f}%')
    return accuracy



In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "mps:0" if torch.backends.mps.is_available() else "cpu")

model = torch.load('model.path')
# model = Model(device=device)
# print(next(model.parameters()).data)
# checkpoint = torch.load("model.ckpt", map_location=device)  # Ensure checkpoint is loaded to the correct device
# model.load_state_dict(checkpoint)
# model.to(device)  # Ensure model parameters and buffers are on the right device.
# print(next(model.parameters()).data)
model.to(device)
model.eval()
print(device)
print(model.device)

mps:0
mps


In [23]:
evaluate(model, dataloader, device)

mps:0
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')
tensor([0], device='mps:0')


KeyboardInterrupt: 