In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import os
import transformers
from accelerate import Accelerator
from transformers import ViTImageProcessor, ViTForImageClassification, ViTConfig, ViTForMaskedImageModeling
from PIL import Image
import requests
import matplotlib.pyplot as plt
import numpy as np
import cv2
from datasets import load_dataset,load_metric


# Let ops unsupported on the MPS (Apple Metal) backend fall back to CPU.
# NOTE(review): PyTorch may read this variable when torch is first imported; since
# torch is imported above, confirm it takes effect (or set it before `import torch`).
os.environ.update({"PYTORCH_ENABLE_MPS_FALLBACK": "1"})
# Accelerate selects the device (cuda:0 in the saved outputs of this notebook).
accelerator = Accelerator()
device = accelerator.device

  from .autonotebook import tqdm as notebook_tqdm
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# url = "http://farm3.staticflickr.com/2066/1798910782_5536af8767_z.jpg"
# url = "http://farm1.staticflickr.com/184/399924547_98e6cef97a_z.jpg"
# url = "http://farm1.staticflickr.com/128/318959350_1a39aae18c_z.jpg"
# image = Image.open(requests.get(url, stream=True).raw)
def load_data(dataset_name="mrm8488/ImageNet1K-val", test_size=0.1, val_size=0.1, seed=42):
    """
    Load a Hugging Face dataset and split its 'train' split into train/val/test.

    The 'train' split is first divided into (rest, test) using ``test_size``;
    ``rest`` is then divided into (train, val) using ``val_size``. Note that
    ``val_size`` is a fraction of ``rest``, not of the full dataset.

    Args:
        dataset_name: dataset identifier passed to ``load_dataset``.
        test_size: fraction of the full split held out as the test set.
        val_size: fraction of the remaining data held out as the validation set.
        seed: seed used for both ``train_test_split`` calls (reproducible splits).

    Returns:
        (train_ds, val_ds, test_ds)
    """
    dataset = load_dataset(dataset_name)['train']
    outer_split = dataset.train_test_split(test_size=test_size, seed=seed)
    test_ds = outer_split['test']
    inner_split = outer_split['train'].train_test_split(test_size=val_size, seed=seed)
    train_ds = inner_split['train']
    val_ds = inner_split['test']
    return train_ds, val_ds, test_ds

train_ds, val_ds, test_ds = load_data()

# Index of the training example whose prediction is explained below.
SAMPLE_INDEX = 3
image = train_ds[SAMPLE_INDEX]['image']

# pretrained_name = 'google/vit-base-patch16-224'
pretrained_name = 'vit-base-patch16-224-finetuned-imageneteval/checkpoint-60'
# pretrained_name = 'openai/clip-vit-base-patch32'
config = ViTConfig.from_pretrained(pretrained_name)
processor = ViTImageProcessor.from_pretrained(pretrained_name)
# get mean and std to unnormalize the processed images
mean, std = processor.image_mean, processor.image_std

pred_model = ViTForImageClassification.from_pretrained(pretrained_name)
pred_model.to(device)
pred_model.eval()  # explicit: disable dropout for deterministic inference

inputs = processor(images=image, return_tensors="pt")
# Rebind explicitly instead of relying on `.to` mutating the batch in place.
inputs = inputs.to(device)
# Inference only: no autograd graph is needed for the forward pass.
with torch.no_grad():
    outputs = pred_model(**inputs, output_hidden_states=True)
logits = outputs.logits
# model predicts one of the 1000 ImageNet classes
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", pred_model.config.id2label[predicted_class_idx])

Repo card metadata block was not found. Setting CardData to empty.


In [None]:
def get_pyx_prime(model, outputs, l2_normalize=True):
    """
    Obtain p(y|x) and p(y|x'), where x' is the input with the ith entry missing.

    Every final-layer token embedding is passed through the model's final layer
    norm and classification head. Position 0 (the [CLS] token in ViT) gives
    p(y|x); each remaining patch token i is used as a proxy for p(y|x'_i).

    Args:
        model: a ViT classifier exposing ``model.vit.layernorm`` and
            ``model.classifier`` (e.g. ``ViTForImageClassification``).
        outputs: the outputs of the model given input x, computed with
            ``output_hidden_states=True``.
        l2_normalize: if True (default, original behaviour), each token embedding
            is scaled to unit L2 norm before the classifier. This shrinks the
            logits, so ``pyx`` will NOT equal ``softmax(outputs.logits)``; pass
            False to use the classifier exactly as the model does.
    Returns:
        pyx: p(y|x), shape [N, 1, C]
        pyx_prime: p(y|x'), shape [N, L, C]
    """
    # NOTE(review): presumably hidden_states[-1] is the last encoder layer's output
    # before the final layer norm (HF ViT convention), hence the explicit layernorm.
    image_embeds = outputs.hidden_states[-1]  # [N, L+1, d]
    image_embeds = model.vit.layernorm(image_embeds)  # [N, L+1, d]
    if l2_normalize:
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
    # Classify every token (not only position 0) and convert logits to probabilities.
    probs = torch.softmax(model.classifier(image_embeds), dim=-1)  # [N, L+1, C]

    pyx = probs[:, 0:1, :]  # [N, 1, C]
    pyx_prime = probs[:, 1:, :]  # [N, L, C]
    return pyx, pyx_prime

def get_heatmap(pyx, pyx_prime, grid_size=None):
    """
    Turn per-patch probability drops into a square spatial heatmap.

    For every patch i and class c the score is max(p(y=c|x) - p(y=c|x'_i), 0),
    i.e. how much the class probability drops when patch i is "missing";
    increases are clamped to zero.

    Args:
        pyx: tensor [N, 1, C]
        pyx_prime: tensor [N, L, C]; L must be a perfect square (196 -> 14 x 14).
        grid_size: side length of the patch grid; inferred as sqrt(L) when None.

    Returns:
        heatmap: numpy array [N, grid_size, grid_size, C] ([N, 14, 14, C] for L = 196).

    Raises:
        ValueError: if L != grid_size ** 2.
    """
    # [N, 1, C] broadcasts against [N, L, C] -> [N, L, C]; keep only the positive part.
    res = (pyx - pyx_prime).clamp(min=0)
    n_batch, n_patches, n_classes = res.shape
    if grid_size is None:
        grid_size = int(round(n_patches ** 0.5))
    if grid_size * grid_size != n_patches:
        raise ValueError(
            f"cannot reshape {n_patches} patches into a {grid_size} x {grid_size} grid"
        )
    heatmap = res.reshape(n_batch, grid_size, grid_size, n_classes)
    return heatmap.detach().cpu().numpy()

def unnormalize(img, mean, std):
    """
    Invert per-channel normalization, returning ``img * std + mean``.

    Args:
        img: array whose last axis holds the 3 channels (e.g. H x W x 3).
        mean: per-channel means (length 3).
        std: per-channel standard deviations (length 3).
    Returns:
        The un-normalized array, broadcast against shape (1, 1, 3).
    """
    channel_shape = (1, 1, 3)
    mean_arr = np.reshape(np.array(mean), channel_shape)
    std_arr = np.reshape(np.array(std), channel_shape)
    scaled = img * std_arr
    return scaled + mean_arr

def convert_to_255_scale(img):
    """
    Map a float image in [0, 1] to uint8 in [0, 255].

    Values are clipped to [0, 1] first so out-of-range floats (e.g. tiny
    negatives from un-normalization round-off) cannot wrap around when cast to
    uint8, and are rounded to the nearest integer rather than truncated.
    """
    scaled = np.clip(img, 0.0, 1.0) * 255.0
    return np.rint(scaled).astype(np.uint8)

def unnormalize_and_255_scale(img, mean, std):
    """Undo per-channel normalization of ``img`` and convert the result to a uint8 (0-255) image."""
    restored = unnormalize(img, mean, std)
    return convert_to_255_scale(restored)

def show_superimposed(img, heatmap):
    """
    Unfinished helper: despite its name, nothing is shown and None is returned.

    It converts ``img`` (presumably CHW RGB, given the transpose) to HWC BGR and
    Gaussian-blurs ``heatmap`` (13 x 13 kernel, sigma 11), then discards both.
    TODO(review): finish (blend + display) or delete; it is not called in this notebook.
    """
    hwc_rgb = img.transpose(1, 2, 0)
    _bgr_image = cv2.cvtColor(hwc_rgb, cv2.COLOR_RGB2BGR)
    _blurred_heatmap = cv2.GaussianBlur(heatmap, (13, 13), 11)
    return None

def normalize_and_rescale(heatmap):
    """
    Min-max normalize ``heatmap`` to [0, 1] and convert it to uint8 (0-255).

    A constant heatmap (max == min, e.g. all zeros after the positive-part clamp
    in ``get_heatmap``) would otherwise compute 0 / 0 and produce NaNs; it is
    mapped to all zeros instead.
    """
    min_value = np.min(heatmap)
    value_range = np.max(heatmap) - min_value
    if value_range == 0:
        heatmap_ft = np.zeros_like(heatmap, dtype=float)
    else:
        heatmap_ft = (heatmap - min_value) / value_range  # floating point in [0, 1]
    return convert_to_255_scale(heatmap_ft)  # uint8

def get_overlap(image, heatmap):
    """
    Blend ``heatmap`` and ``image`` with equal 0.5 weights (no gamma offset).

    Uses ``cv2.addWeighted``; per OpenCV, both arrays should have the same size
    and type.
    """
    heatmap_weight = image_weight = 0.5
    return cv2.addWeighted(heatmap, heatmap_weight, image, image_weight, 0)

def plot_overlap(image, heatmap):
    """
    Display the 50/50 blend of ``image`` and ``heatmap`` (see ``get_overlap``) with axes hidden.

    Returns:
        0 (a constant; the return value carries no information).
    """
    blended = get_overlap(image, heatmap)
    plt.imshow(blended)
    plt.axis('off')
    plt.show()
    return 0

def plot_overlap_np(image, heatmap, img_mean, img_std, blur_ksize=(13, 13)):
    """
    Overlay a patch-level heatmap on a normalized image and display it.

    Args:
        image: normalized H x W x 3 image (processor output transposed to HWC);
            un-normalized here with ``img_mean`` / ``img_std``.
        heatmap: 2-D array (e.g. the 14 x 14 patch grid), upsampled to H x W.
        img_mean: per-channel mean used by the processor.
        img_std: per-channel std used by the processor.
        blur_ksize: box-blur kernel size applied to the upsampled heatmap.
    Returns:
        (image, heatmap_img): the uint8 RGB image and the uint8 RGB colour-mapped
        heatmap that were blended and shown.
    """
    height, width = image.shape[:2]
    heatmap_u8 = normalize_and_rescale(heatmap)
    # cv2.resize takes dsize as (width, height): the reverse of numpy's shape order.
    resized_heatmap = cv2.resize(heatmap_u8, (width, height))
    # A box blur takes only a kernel size. cv2.blur's third positional parameter is
    # `dst`, so no sigma belongs here (use cv2.GaussianBlur for a sigma-based blur).
    blurred = cv2.blur(resized_heatmap, blur_ksize)
    heatmap_img = cv2.applyColorMap(blurred, cv2.COLORMAP_JET)
    # applyColorMap produces BGR; matplotlib expects RGB.
    heatmap_img = cv2.cvtColor(heatmap_img, cv2.COLOR_BGR2RGB)

    image_u8 = unnormalize_and_255_scale(image, img_mean, img_std)

    plot_overlap(image_u8, heatmap_img)
    return image_u8, heatmap_img

# Explanation map for the predicted class of the first image in the batch.
pyx, pyx_prime = get_pyx_prime(pred_model, outputs)
heatmap = get_heatmap(pyx, pyx_prime)[0][..., predicted_class_idx]
# The processor returns CHW pixel values; the plotting helpers expect HWC.
img = np.transpose(inputs.pixel_values[0].cpu().numpy(), (1, 2, 0))

image, heatmap_img = plot_overlap_np(img, heatmap, mean, std)

In [None]:
# Human-readable label of the predicted class, looked up in the ViTConfig loaded above.
config.id2label[predicted_class_idx]

In [None]:
# p(y|x') of the predicted class for each patch, laid out on the patch grid.
# NOTE(review): reshape(14,14) assumes 196 patch tokens — confirm for other checkpoints/resolutions.
plt.imshow(pyx_prime[0,:,predicted_class_idx].reshape(14,14).detach().cpu().numpy())

In [None]:
# `image` here is the uint8 array returned by plot_overlap_np, which rebinds the
# dataset sample assigned to `image` in the first code cell.
plt.imshow(image)

In [148]:
# Ten patches with the lowest p(y|x') for the predicted class (topk of the negated values).
# NOTE(review): the saved output below has magnitudes > 1, so it predates the softmax
# in get_pyx_prime — re-run the notebook top to bottom before trusting it.
torch.topk(-pyx_prime[0,:,predicted_class_idx], k=10)
# torch.topk(-pyx_prime[0,:,:], k=10)

torch.return_types.topk(
values=tensor([-0.0796, -0.0848, -0.1004, -0.1096, -0.1098, -0.8484, -1.0187, -1.1403,
        -1.3748, -1.3919], device='cuda:0', grad_fn=<TopkBackward0>),
indices=tensor([41, 27, 55, 28, 14, 24, 25, 38, 53, 52], device='cuda:0'))

In [143]:
# Expected [N, L, C]: batch, patch tokens, classes.
pyx_prime.shape

torch.Size([1, 196, 1000])

In [136]:
# Per-patch change p(y|x) - p(y|x') for the predicted class on the 14 x 14 patch grid
# (reshape assumes batch size 1 and 196 patches). Uses the predicted class instead
# of a hard-coded class index so the cell stays correct for other images.
(pyx - pyx_prime)[:, :, predicted_class_idx].reshape(14, 14)

tensor([[0.6648, 0.1864, 0.1534, 0.4230, 0.5003, 0.2311, 0.2753, 0.2307, 0.4921,
         0.6167, 0.9358, 0.4667, 0.5765, 0.3406],
        [0.7313, 0.3677, 0.3738, 0.4811, 0.4511, 0.1885, 0.2276, 0.4310, 0.2549,
         0.1728, 0.2518, 0.0900, 0.5523, 0.6250],
        [0.6045, 0.6936, 0.5243, 0.7600, 0.3055, 0.4877, 0.6262, 0.4158, 0.1189,
         0.2034, 0.1487, 0.2163, 0.2734, 0.1782],
        [0.6791, 0.8337, 0.5728, 0.1679, 0.4762, 0.7462, 0.4919, 0.4401, 0.1413,
         0.3464, 0.0770, 0.1961, 0.2623, 0.3292],
        [0.1688, 0.2208, 0.1306, 0.1381, 0.1869, 0.2910, 0.4764, 0.8668, 0.1669,
         0.0633, 0.0936, 0.2386, 0.4852, 0.4924],
        [0.2015, 0.2300, 0.0885, 0.0780, 0.1484, 0.1327, 0.3209, 0.1737, 0.1333,
         0.0803, 0.0756, 0.0994, 0.4112, 0.4168],
        [0.2385, 0.1657, 0.1166, 0.1356, 0.5281, 0.4363, 0.2446, 0.1086, 0.0994,
         0.0507, 0.1346, 0.2167, 0.5004, 0.2677],
        [0.3029, 0.1601, 0.0129, 0.2558, 0.3674, 0.1925, 0.1957, 0.2768, 0.2492,
  

In [125]:
# Top-10 raw classifier logits from the first code cell (`outputs.logits`, before softmax).
torch.topk(logits, 10)

torch.return_types.topk(
values=tensor([[12.4186,  9.2246,  8.2434,  6.7615,  5.1892,  3.4349,  3.3213,  3.2908,
          3.1910,  2.9049]], device='cuda:0', grad_fn=<TopkBackward0>),
indices=tensor([[285, 281, 282, 287, 284, 283, 289, 293, 785, 292]], device='cuda:0'))

In [7]:
# Number of CUDA devices visible to this kernel (sanity check for the Accelerator device).
torch.cuda.device_count()

1