## Setup

In [1]:
import sys
if "ViT-pytorch" not in sys.path:
    sys.path.append("ViT-pytorch")

from albumentations.pytorch import ToTensorV2
import albumentations as A

import json
import os
import pickle
import random

import torch
import numpy as np
import matplotlib.pyplot as plt

from urllib.request import urlretrieve
from tqdm.notebook import tqdm

from PIL import Image
from torchvision import transforms

from models.modeling import VisionTransformer, CONFIGS

## Utilities

In [2]:
with open("imagenet_class_index.json", "r") as read_file:
    imagenet_labels = json.load(read_file)

# Each entry maps a string class index to a pair whose element 0 is the synset
# id (the directory name that appears in the validation paths) and element 1
# is the human-readable class name.
#   MAPPING_DICT: synset id -> integer class index
#   LABEL_NAMES:  integer class index -> class name
MAPPING_DICT = {}
LABEL_NAMES = {}
for label_id, label_entry in imagenet_labels.items():
    class_idx = int(label_id)
    MAPPING_DICT[label_entry[0]] = class_idx
    LABEL_NAMES[class_idx] = label_entry[1]

# Validation image paths evaluated below.
# NOTE: unpickling can execute arbitrary code -- only load a trusted file.
with open("IMAGENET_VAL_PATHS_1k.pkl", "rb") as paths_file:
    IMAGENET_VAL_PATHS_1k = pickle.load(paths_file)

In [3]:
# Fetch the ViT-L/16 (224px) checkpoint once and cache it under attention_data/.
os.makedirs("attention_data", exist_ok=True)
if not os.path.isfile("attention_data/ViT-L_16-224.npz"):
    # Download to a temporary name and rename only on success. An interrupted
    # download therefore never leaves a truncated file at the final path,
    # which the isfile() guard above would otherwise treat as complete.
    partial_path = "attention_data/ViT-L_16-224.npz.part"
    urlretrieve("https://storage.googleapis.com/vit_models/imagenet21k%2Bimagenet2012/ViT-L_16-224.npz",
                partial_path)
    os.replace(partial_path, "attention_data/ViT-L_16-224.npz")

In [4]:
# Prepare Model: ViT-L/16 at 224x224 with a 1000-class head, initialised from
# the downloaded checkpoint and switched to eval mode for inference.
config = CONFIGS["ViT-L_16"]
vit_model = VisionTransformer(config, num_classes=1000, zero_head=False, img_size=224)
# np.load on an .npz returns an archive object that keeps the file open; the
# context manager closes it afterwards.
# NOTE(review): assumes load_from copies every array it needs before returning
# and keeps no reference to the archive -- confirm.
with np.load("attention_data/ViT-L_16-224.npz") as vit_weights:
    vit_model.load_from(vit_weights)
vit_model.eval()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vit_model = vit_model.to(DEVICE)

In [5]:
def get_transforms_cutout(factor=0.05, total_area=224*224):
    """Build the albumentations pipeline used for Cutout-occluded evaluation.

    Args:
        factor: Fraction of ``total_area`` the single square hole should cover.
        total_area: Reference area in pixels (defaults to a 224x224 image).

    Returns:
        An ``A.Compose`` that does four things, in order:

        1. Resizes to 224x224.
        2. Cuts out one square hole of side ``int(sqrt(total_area * factor))``.
        3. Normalizes every channel with mean 0.5 / std 0.5.
        4. Converts the result to a tensor.

        Call it as ``transform(image=array)["image"]``.
    """
    hole_side = int(np.sqrt(total_area * factor))
    # NOTE(review): hole placement is decided by albumentations (presumably at
    # random). Depending on the installed version, a hole near the border may
    # be clipped, covering less than `factor` of the image -- confirm.
    cutout = A.Cutout(
        num_holes=1,
        max_h_size=hole_side,
        max_w_size=hole_side,
        always_apply=True,
        p=1.,
    )
    steps = [
        A.Resize(224, 224),
        cutout,
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2(),
    ]
    return A.Compose(steps)

def get_transforms():
    """Build the torchvision pipeline used for the un-occluded baseline.

    Resizes a PIL image to 224x224, converts it to a tensor and normalizes
    every channel with mean 0.5 / std 0.5. Call it as ``transform(pil_image)``.

    NOTE(review): this pipeline resizes with torchvision, while the Cutout
    pipeline resizes with albumentations. The two libraries may interpolate
    differently, so part of any accuracy gap could come from preprocessing
    rather than occlusion -- confirm.
    """
    pipeline = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(pipeline)

In [18]:
def run_prediction(im, factor=None):
    """Classify one image with the global ``vit_model``.

    Args:
        im: A PIL image, or an array convertible with ``Image.fromarray``.
            Any non-RGB input (grayscale, palette, RGBA, CMYK, ...) is converted
            to 3-channel RGB to match the 3-channel normalization.
        factor: If truthy, apply Cutout covering this fraction of the image
            area (via ``get_transforms_cutout``). Otherwise use the plain
            ``get_transforms`` pipeline.

    Returns:
        The logits from ``vit_model`` for a batch of one image, on ``DEVICE``.
        They are computed without autograd tracking.
    """
    if not isinstance(im, Image.Image):
        im = Image.fromarray(np.asarray(im))
    # For grayscale this replicates the single channel three times. Unlike
    # hand-tiling 2-D arrays, it also handles RGBA, palette and CMYK inputs.
    if im.mode != "RGB":
        im = im.convert("RGB")

    if factor:
        x = get_transforms_cutout(factor)(image=np.array(im))["image"]
    else:
        x = get_transforms()(im)

    # Inference only: skipping the autograd graph saves memory and time for
    # the large ViT-L model.
    with torch.no_grad():
        logits, _ = vit_model(x.unsqueeze(0).to(DEVICE))

    return logits

## Evaluate without Cutout (baseline)

In [15]:
def evaluate(factor=None, image_paths=None):
    """Compute top-1 accuracy of ``vit_model`` on ImageNet validation images.

    Args:
        factor: Passed to ``run_prediction``. A truthy value applies Cutout
            covering that fraction of the image area; ``None`` evaluates the
            un-occluded images.
        image_paths: Paths to evaluate. Defaults to ``IMAGENET_VAL_PATHS_1k``.

    Returns:
        Fraction of images whose predicted class index equals the label index.

    Raises:
        ValueError: If ``image_paths`` is empty.
    """
    if image_paths is None:
        image_paths = IMAGENET_VAL_PATHS_1k
    if len(image_paths) == 0:
        raise ValueError("evaluate() needs at least one image path")

    correct_image_paths = []

    for image_path in tqdm(image_paths):
        # The synset id is the second "/"-separated path component.
        # NOTE(review): assumes paths look like "<root>/<synset>/<file>" -- confirm.
        label_idx = MAPPING_DICT[image_path.split("/")[1]]

        # Close the underlying file once the prediction is done.
        with Image.open(image_path) as im:
            logits = run_prediction(im, factor)

        # Softmax is monotonic, so argmax over raw logits gives the same class.
        pred = logits.argmax(dim=-1).item()
        if pred == label_idx:
            correct_image_paths.append(image_path)

    print(f"Total corrects: {len(correct_image_paths)} out of {len(image_paths)}")
    return len(correct_image_paths) / len(image_paths)

In [16]:
baseline_top_1_acc = evaluate()  # un-occluded reference accuracy
print(baseline_top_1_acc)

  0%|          | 0/1000 [00:00<?, ?it/s]

Total corrects: 830 out of 1000
0.83


## Evaluate with Cutout at varying levels

Each `factor` is the fraction of the 224×224 image area covered by a single square Cutout hole.

In [19]:
# Sweep Cutout sizes: each factor is the fraction of the image area occluded.
# Cutout hole placement is random, so reseed before every level. Each reported
# accuracy is then reproducible and does not depend on the order of `factors`.
# NOTE(review): older albumentations releases draw Cutout positions from
# Python's `random` module; newer ones may use their own seeded generator --
# confirm against the installed version.
SEED = 0
factors = [0.05, 0.1, 0.2, 0.5]
factors_dict = {}

for factor in factors:
    random.seed(SEED)
    np.random.seed(SEED)
    top_1_acc = evaluate(factor)
    factors_dict[factor] = top_1_acc
    print(f"{factor}: {top_1_acc}")

  0%|          | 0/1000 [00:00<?, ?it/s]

Total corrects: 823 out of 1000
0.05: 0.823


  0%|          | 0/1000 [00:00<?, ?it/s]

Total corrects: 814 out of 1000
0.1: 0.814


  0%|          | 0/1000 [00:00<?, ?it/s]

Total corrects: 779 out of 1000
0.2: 0.779


  0%|          | 0/1000 [00:00<?, ?it/s]

Total corrects: 604 out of 1000
0.5: 0.604
