In [9]:
import os
# Fall back to the host name 'ODIN' when the environment does not define one;
# an existing HOSTNAME value is left untouched.
# NOTE(review): presumably read by the project code imported below - confirm.
os.environ.setdefault('HOSTNAME', 'ODIN')

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torchvision
import torchvision.transforms as transforms

from core.CNN_scorers import TorchScorer

True


In [10]:
# Temporary save directory (currently empty).
tmpsavedir = ""

# Run on the GPU when torch can see one, otherwise on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Target unit: network name, layer identifier and a three-element unit index.
network = 'vit_b_32'
layer = '.heads.Linearhead'
unit_idx = [373, 0, 0]
model_unit = (network, layer, *unit_idx)

# Build the scorer for the chosen network and point it at the selected unit.
model = TorchScorer(network, device=device)
model.select_unit(model_unit)

# Experiment label; used as the stem of the .npz file written by the scoring cell.
explabel = f'activation_{network}_{layer}_{unit_idx[0]}-{unit_idx[1]}-{unit_idx[2]}'




In [11]:
batch_size = 256

# Root folder of the training images (path on odin).
train_path = '/data/imagenet-2012/imagenet12/images/train'

# Preprocessing: resize to 224x224, convert to a tensor, then normalise
# each of the three channels with the given mean and standard deviation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# One sub-folder per class; every image goes through `transform` on load.
imagenet_data = torchvision.datasets.ImageFolder(train_path, transform=transform)

# shuffle=False keeps the batches in dataset order.
loader_options = dict(batch_size=batch_size, shuffle=False, num_workers=12)
data_loader = torch.utils.data.DataLoader(imagenet_data, **loader_options)

In [12]:
# One row per image: column 0 = recorded activation, column 1 = class label.
scores = np.zeros((len(imagenet_data), 2))

offset = 0  # first row of `scores` written by the current batch
for img_batch, labels in tqdm(data_loader, total=len(data_loader)):
    rows = slice(offset, offset + len(labels))

    # Forward pass without gradients; the value of interest is read from
    # model.activation afterwards rather than from the model's return value.
    # NOTE(review): presumably filled by a hook set up in select_unit - confirm.
    with torch.no_grad():
        model.model(img_batch.to(model.device))
    batch_scores = model.activation["score"].squeeze().cpu().numpy()

    scores[rows, 0] = batch_scores
    scores[rows, 1] = labels
    offset = rows.stop

# Persist activations and labels as two separate arrays in one .npz file.
np.savez(f"{explabel}.npz", scores=scores[:, 0], labels=scores[:, 1])

  1%|▏         | 66/4809 [00:30<29:12,  2.71it/s] 

In [None]:
# Scatter of every image: activation on the x axis, class index on the y axis.
plt.scatter(x=scores[:, 0], y=scores[:, 1], alpha=0.1)
ax = plt.gca()
ax.set_ylabel('ImageNet class')
ax.set_xlabel('activity')

In [None]:
# Per-class summary of the recorded activations:
#   maxmedian[c, 0] = maximum activation over the images of class c
#   maxmedian[c, 1] = median activation over the images of class c
# The number of classes is taken from the labels stored in `scores` instead of
# being fixed at 1000, so the cell also works on a subset of classes or on a
# partially filled `scores` array.
class_ids = scores[:, 1].astype(int)
n_classes = int(class_ids.max()) + 1

maxmedian = np.full((n_classes, 2), np.nan)
for i in range(n_classes):
    x = scores[class_ids == i, 0]
    if x.size == 0:
        # No image of this class: keep NaN instead of letting np.max raise
        # ValueError on an empty selection.
        continue
    maxmedian[i, 0] = np.max(x)
    maxmedian[i, 1] = np.median(x)

plt.plot(maxmedian, label=['max', 'median'])
plt.legend()
plt.gca().set_xlabel('ImageNet class')
plt.gca().set_ylabel('activity')
# Class index with the largest max / largest median; NaN-aware so that classes
# without samples can never be selected.
np.nanargmax(maxmedian, axis=0)