In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np

from types import SimpleNamespace
from collections import OrderedDict
from utils.dataloader import get_dataloaders
from utils.vision_transformer import vit_base, DINOHead
import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Options handed to get_dataloaders. Which fields it actually reads is defined
# in utils.dataloader, so every field from the original configuration is kept.
# NOTE(review): model='mlp_mixer' does not match the ViT built below, and
# dataset='c10' does not match the 100-class "cifar100" checkpoint -- confirm
# these are intentional.
args = SimpleNamespace(
    dataset='c10',
    model='mlp_mixer',
    batch_size=32,
    eval_batch_size=10,
    num_workers=4,
    seed=0,
    epochs=300,
    patch_size=4,
    autoaugment=False,
    use_cuda=False,
    size=224,
    split='index',
)

train_dataloader, test_dataloader = get_dataloaders(args)

model = vit_base(num_classes=100)

# map_location='cpu' lets a checkpoint that was saved from CUDA tensors load on
# a machine without a GPU; load_state_dict copies the values into the model's
# own parameters, so the model can still be moved to any device afterwards.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
state_dict = torch.load('../artifacts/cifar100_ViT_B_dino.pth', map_location='cpu')

# Drop a leading 'module.' wrapper prefix from parameter names.
# Only the prefix is removed: stripping the substring anywhere in the key would
# also corrupt names that merely contain it (e.g. 'submodule.').
wrapper_prefix = 'module.'
stripped_keys = OrderedDict()
for k, v in state_dict.items():
    if k.startswith(wrapper_prefix):
        k = k[len(wrapper_prefix):]
    stripped_keys[k] = v

# Kept as the last expression so the cell displays the key-matching report.
model.load_state_dict(stripped_keys)

Files already downloaded and verified
Files already downloaded and verified


<All keys matched successfully>

In [3]:
# Accumulator filled by the extraction loop: one attention vector per sample index.
atten_dict = {}

# Use the first CUDA device when one is available, otherwise stay on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# The model is only used for inference from here on, so switch it to eval mode:
# this disables dropout-style stochastic layers and makes the extracted
# attention maps deterministic. eval() returns the module itself.
model = model.to(device).eval()

In [4]:
# Collect, for every test sample, the last-layer attention averaged over axis 1
# and store it in atten_dict under the sample's index.
# NOTE(review): the slicing below assumes get_last_selfattention returns a
# 4-D tensor laid out as (batch, heads, query_token, key_token) with token 0
# being the CLS token -- confirm against utils.vision_transformer.
for j, (img, label, idx) in enumerate(test_dataloader):
    with torch.no_grad():
        # Only the images are needed; labels are not used by this loop, and the
        # batch is not reused afterwards, so no defensive copy is required.
        img = img.to(device)
        attentions = model.get_last_selfattention(img)
        # Row 0 of the attention matrix, with column 0 dropped.
        new_heatmaps = attentions[:, :, 0, 1:]
        avg_attn_maps = torch.mean(new_heatmaps, dim=1)
        # One device-to-host transfer per batch instead of one per sample.
        batch_rows = avg_attn_maps.cpu().numpy().tolist()
        for i in range(len(idx)):
            # .item() turns the index into a plain Python number so it can be a
            # dict key and be serialised later (a Tensor cannot).
            atten_dict[idx[i].item()] = batch_rows[i]
    print(f'Processed batch {j}.')

Processed batch 0.
Processed batch 1.
Processed batch 2.
Processed batch 3.
Processed batch 4.
Processed batch 5.
Processed batch 6.
Processed batch 7.
Processed batch 8.
Processed batch 9.
Processed batch 10.
Processed batch 11.
Processed batch 12.
Processed batch 13.
Processed batch 14.
Processed batch 15.
Processed batch 16.
Processed batch 17.
Processed batch 18.
Processed batch 19.
Processed batch 20.
Processed batch 21.
Processed batch 22.
Processed batch 23.
Processed batch 24.
Processed batch 25.
Processed batch 26.
Processed batch 27.
Processed batch 28.
Processed batch 29.
Processed batch 30.
Processed batch 31.
Processed batch 32.
Processed batch 33.
Processed batch 34.
Processed batch 35.
Processed batch 36.
Processed batch 37.
Processed batch 38.
Processed batch 39.
Processed batch 40.
Processed batch 41.
Processed batch 42.
Processed batch 43.
Processed batch 44.
Processed batch 45.
Processed batch 46.
Processed batch 47.
Processed batch 48.
Processed batch 49.
Processed 

In [5]:
# Write the collected attention vectors to disk as a single JSON object.
# JSON object keys are always strings, so non-string dict keys are written as
# strings and must be converted back by whoever loads this file.
with open("avg_attns_testset.json", mode="w") as json_file:
    json.dump(atten_dict, fp=json_file)