In [None]:
!pip install adversarial-robustness-toolbox
!pip install timm
!pip install vision_transformer_pytorch

In [1]:
import sys
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

import cv2
from PIL import Image
from torchvision import transforms

In [3]:
import urllib
import urllib.request

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models

from art.attacks.evasion import FastGradientMethod, ProjectedGradientDescentPyTorch
from art.estimators.classification import PyTorchClassifier
from art.utils import load_mnist

from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

In [6]:
def plot_attention_map(original_img, att_map):
    """Draw an image and its attention map next to each other.

    Creates a two-column matplotlib figure (16x16 inches); the left panel
    shows ``original_img`` and the right panel shows ``att_map``. Both are
    handed to ``Axes.imshow`` unchanged. Nothing is returned.

    Args:
        original_img: Image data for the left panel, titled 'Original'.
        att_map: Image data for the right panel, titled
            'Attention Map Last Layer'.
    """
    panels = (
        ('Original', original_img),
        ('Attention Map Last Layer', att_map),
    )
    fig, axes = plt.subplots(ncols=2, figsize=(16, 16))
    for axis, (title, image) in zip(axes, panels):
        axis.set_title(title)
        axis.imshow(image)

def get_attention_map(img, get_mask=False):
    """Compute an attention-rollout map for ``img`` and resize it to the image.

    Uses the notebook-level ``transform`` and ``model`` objects: the image is
    preprocessed with ``transform``, a batch axis is added, and the batch is
    passed to ``model``. The per-layer attention weights are averaged over
    heads, augmented with an identity matrix (residual connections),
    row-normalised and multiplied layer by layer. Row 0 of the final product
    (all columns except the first) is reshaped to a square grid, scaled by
    its maximum and resized to ``img.size`` with ``cv2.resize``.

    Args:
        img: Input image. It is passed to ``transform``, its ``size``
            attribute is used as the resize target, and (when ``get_mask``
            is false) it is multiplied element-wise with the mask.
            Presumably a PIL RGB image -- TODO confirm against callers.
        get_mask: If true, return the resized mask itself; otherwise return
            the image weighted by the mask.

    Returns:
        numpy.ndarray: the resized mask when ``get_mask`` is true, else
        ``mask * img`` cast to ``uint8``.

    Raises:
        ValueError: if ``model(...)`` does not return a 2-item tuple/list of
            ``(logits, attention_weights)``. A model whose forward pass
            returns only logits cannot be used here.
    """
    x = transform(img)

    # Inference only: the result is detached below, so no autograd graph is needed.
    with torch.no_grad():
        output = model(x.unsqueeze(0))

    # Fail with an actionable message instead of an opaque unpacking error.
    if not (isinstance(output, (tuple, list)) and len(output) == 2):
        raise ValueError(
            "get_attention_map needs a model whose forward pass returns "
            "(logits, attention_weights); got {} instead".format(type(output).__name__)
        )
    _, att_mat = output

    # Stack the per-layer attention tensors and drop the size-1 axis at dim 1.
    att_mat = torch.stack(att_mat).squeeze(1)

    # Average the attention weights across all heads.
    att_mat = torch.mean(att_mat, dim=1)

    # To account for residual connections, add an identity matrix to the
    # attention matrix and re-normalize the weights. The identity is created
    # on the same device as the attention weights so GPU models work too.
    residual_att = torch.eye(att_mat.size(1), device=att_mat.device)
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1, keepdim=True)

    # Multiply the weight matrices layer by layer; only the final product is
    # used, so a running product is kept instead of a per-layer buffer.
    joint_attention = aug_att_mat[0]
    for n in range(1, aug_att_mat.size(0)):
        joint_attention = torch.matmul(aug_att_mat[n], joint_attention)

    # Row 0, columns 1..: the first token's weights over the remaining tokens,
    # laid out as a square grid.
    grid_size = int(np.sqrt(aug_att_mat.size(-1)))
    mask = joint_attention[0, 1:].reshape(grid_size, grid_size).detach().cpu().numpy()

    # Scale to a maximum of 1 and resize to the image's size.
    mask = cv2.resize(mask / mask.max(), img.size)
    if get_mask:
        return mask

    # Add a trailing axis so the mask broadcasts over the image channels.
    return (mask[..., np.newaxis] * img).astype("uint8")

In [16]:
# model = timm.create_model('vit_large_patch16_224', pretrained=True)

In [22]:
from vision_transformer_pytorch import VisionTransformer

# Load the pretrained ViT; other checkpoint names accepted by from_pretrained:
# 'ViT-B_32', 'ViT-L_16', 'ViT-L_32', 'R50+ViT-B_16'
model = VisionTransformer.from_pretrained('ViT-B_16')

# Switch to inference mode so train-only layers such as dropout are disabled
# and repeated forward passes on the same input are deterministic.
_ = model.eval()

# Usage examples:
# inputs = torch.randn(1, 3, *model.image_size)
# model(inputs)
# model.extract_features(inputs)

Loaded pretrained weights for ViT-B_16


(384, 384)

In [None]:
# Echo the model object so the notebook renders its repr as this cell's output.
model

In [19]:
# Preprocessing pipeline applied to every input image: resize to 384x384,
# convert to a tensor, then normalise each channel.
# NOTE(review): the mean/std values below look like the usual ImageNet
# statistics -- confirm they match what the pretrained ViT weights expect.
transform = transforms.Compose(
    [
        transforms.Resize((384, 384)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [23]:
# config = resolve_data_config({}, model=model)
# transform = create_transform(**config)

# Download the sample image next to the notebook.
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
urllib.request.urlretrieve(url, filename)

# Load as RGB; `img1` is the image object the rest of the notebook works with.
img1 = Image.open(filename).convert('RGB')
tensor = transform(img1).unsqueeze(0)  # transform and add batch dimension
tensor.size()

torch.Size([1, 3, 384, 384])

In [24]:
# Preprocess the sample image loaded earlier as `img1` (no batch axis yet).
x = transform(img1)
x.size()

torch.Size([3, 384, 384])

In [25]:
# Shape check: unsqueeze(0) prepends a batch axis (recorded output: [1, 3, 384, 384]).
x.unsqueeze(0).shape

torch.Size([1, 3, 384, 384])

In [11]:
# Inference only, so skip autograd bookkeeping for this forward pass.
with torch.no_grad():
    res = model(x.unsqueeze(0))



In [12]:
# Shape of the model output for the single-image batch (recorded output: [1, 1000]).
res.shape

torch.Size([1, 1000])

In [8]:
# NOTE(review): the recorded run of this cell raised ValueError.
# get_attention_map expects model(...) to yield two items (logits and
# attention weights), but the recorded output of `res.shape` shows this model
# returning a single [1, 1000] tensor. A model that also returns its attention
# weights is needed for this call to succeed.
result1 = get_attention_map(img1)

ValueError: ignored