In [None]:
import os
import sys
import cv2
from PIL import Image
sys.path.append('../../../src')
import copy
import matplotlib.pyplot as plt

import tortto as tt
import tortto.nn as nn
import tortto.nn.functional as F

from torchvision.transforms import Compose, Resize, ToTensor, Normalize

In [None]:
# Per-channel statistics handed to Normalize (three channels, all 0.5).
img_mean = (0.5, 0.5, 0.5)
img_sd = (0.5, 0.5, 0.5)

# Preprocessing pipeline: resize to 224x224, convert to a tensor, then
# normalise each channel with the statistics above.
transform = Compose(
    [
        Resize((224, 224)),
        ToTensor(),
        Normalize(mean=img_mean, std=img_sd),
    ]
)

# Human-readable label for each class index of the classifier output.
# NOTE(review): presumably the CIFAR-10 label order -- confirm against the training data.
classes = (
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)

In [None]:
class TransformerEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer that also exposes its attention weights.

    Each sub-layer (self-attention, then feed-forward) is applied to a
    layer-normalised copy of its input and added back to the un-normalised
    input as a residual.

    Args:
        d_model: width of the token embeddings.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward sub-layer.
        dropout: dropout probability shared by the attention module and the
            three dropout modules of this layer.
        activation: callable applied between the two feed-forward linears.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu):
        super().__init__()
        # NOTE: the registration order below is kept deliberately; it fixes the
        # parameter order of the module and the order in which weights are initialised.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Run the layer and return ``(output, attention_weights)``.

        The attention weights are the per-head (not head-averaged) weights
        produced by the self-attention module.
        """
        attn_out, attn_weight = self._sa_block(self.norm1(src), src_mask, src_key_padding_mask)
        hidden = src + attn_out  # residual around the attention sub-layer
        hidden = hidden + self._ff_block(self.norm2(hidden))  # residual around the feed-forward sub-layer
        return hidden, attn_weight

    def _sa_block(self, x, attn_mask, key_padding_mask):
        """Self-attention sub-layer: returns the dropped-out output and the raw weights."""
        attended, weights = self.self_attn(
            x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=True,
            average_attn_weights=False,  # keep one weight map per head
        )
        return self.dropout1(attended), weights

    def _ff_block(self, x):
        """Feed-forward sub-layer: linear -> activation -> dropout -> linear -> dropout."""
        expanded = self.activation(self.linear1(x))
        projected = self.linear2(self.dropout(expanded))
        return self.dropout2(projected)

class TransformerEncoder(nn.Module):
    """Stack of encoder layers that records every layer's attention weights.

    Args:
        encoder_layer: template layer; it is deep-copied ``num_layers`` times,
            so the stacked layers do not share parameters with each other.
            Each layer must return ``(output, attention_weights)``.
        num_layers: number of copies to stack.
        norm: optional module applied to the output of the last layer.

    Attributes:
        attn: list with one attention-weight entry per layer, holding the
            weights produced by the most recent ``forward`` call.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
        self.num_layers = num_layers
        self.norm = norm
        self.attn = []

    def forward(self, src, mask=None, src_key_padding_mask=None):
        """Pass ``src`` through every layer in turn and return the final output.

        As a side effect ``self.attn`` is replaced with the attention weights
        of this call, in layer order.
        """
        # Start from a fresh list on every call. Appending to the previous list
        # would keep old tensors alive indefinitely and would mix attention maps
        # from different inputs. Rebinding (rather than clearing) leaves any
        # list a caller grabbed from an earlier call untouched.
        self.attn = []
        output = src
        for mod in self.layers:
            output, attn_weight = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
            self.attn.append(attn_weight)
        if self.norm is not None:
            output = self.norm(output)
        return output

class VisionTransformer_B16(nn.Module):
    """Vision transformer classifier built from the encoder classes in this file.

    The image is cut into patches by a strided convolution, a learnable class
    token is prepended, learnable position embeddings are added, the tokens go
    through the transformer encoder, and the class token is classified.

    Args:
        image_size: (height, width) used to size the position embedding.
        patch_size: (height, width) of a patch; used as both kernel size and stride.
        embed_dim: token embedding width.
        hidden_dim: feed-forward width inside every encoder layer.
        in_channels: number of input image channels.
        num_heads: attention heads per encoder layer.
        num_layers: number of encoder layers.
        num_classes: size of the classifier output.
        dropout: dropout probability for the embeddings and the encoder layers.
    """

    def __init__(self, image_size=(224,224), patch_size=(16,16), embed_dim=768, hidden_dim=3072, in_channels=3,
                 num_heads=12, num_layers=12, num_classes=10, dropout=0.):
        super().__init__()
        img_h, img_w = image_size
        patch_h, patch_w = patch_size
        # Number of whole patches along each axis (kept in the original float form).
        patches_down = (img_h + 0.1) // patch_h
        patches_across = (img_w + 0.1) // patch_w
        num_patches = int(patches_down * patches_across)

        # Patch embedding: one convolution step per non-overlapping patch.
        self.input_layer = nn.Conv2d(in_channels, embed_dim, patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(tt.randn(1, 1, embed_dim))
        # One position vector per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(tt.randn(1, num_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)

        layer_template = TransformerEncoderLayer(embed_dim, num_heads, hidden_dim, dropout, activation=F.gelu)
        self.transformer = TransformerEncoder(layer_template, num_layers, norm=nn.LayerNorm(embed_dim))
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):  # B: batchsize, C:in_channels, E:embedding dim
        """Return per-class log-probabilities for a batch of images."""
        patches = self.input_layer(x)        # B,C,H,W -> B,E,h,w
        tokens = patches.flatten(-2, -1)     # B,E,h,w -> B,E,h*w
        tokens = tokens.swapaxes(-1, -2)     # B,E,h*w -> B,h*w,E
        batch_size, n_patches = tokens.shape[:2]

        # Prepend one copy of the class token to every sequence in the batch.
        cls_tokens = self.cls_token.repeat(batch_size, 1, 1)
        tokens = tt.cat([cls_tokens, tokens], dim=1)

        # Only the first n_patches + 1 position vectors are used.
        tokens = tokens + self.pos_embedding[:, :(n_patches + 1)]
        tokens = self.dropout(tokens)

        encoded = self.transformer(tokens)
        logits = self.classifier(encoded[:, 0])  # classify from the class token only
        return F.log_softmax(logits, dim=-1)

In [None]:
# https://github.com/jacobgil/vit-explain/blob/main/vit_rollout.py
import numpy as np
from torch import eye, tensor
def rollout(attentions, discard_ratio, head_fusion):
    """Compute an attention-rollout map from the class token to the image patches.

    Args:
        attentions: sequence of per-layer torch attention tensors, each indexed
            as (batch, heads, tokens, tokens); token 0 is treated as the class
            token and the remaining tokens must form a square grid.
        discard_ratio: fraction (0..1) of the lowest fused attention entries
            that is zeroed in every layer before rolling out.
        head_fusion: how heads are combined: "mean", "max" or "min".

    Returns:
        A 2-D numpy array of shape (width, width), scaled so its maximum is 1,
        computed for the first item of the batch.

    Raises:
        ValueError: if ``head_fusion`` is not one of the supported names.
    """
    first = attentions[0]
    # Build the identity on the same dtype/device as the inputs so the matrix
    # products below never mix dtypes or devices.
    result = eye(first.size(-1), dtype=first.dtype, device=first.device)
    for attention in attentions:
        if head_fusion == "mean":
            attention_heads_fused = attention.mean(dim=1)
        elif head_fusion == "max":
            attention_heads_fused = attention.max(dim=1)[0]
        elif head_fusion == "min":
            attention_heads_fused = attention.min(dim=1)[0]
        else:
            # Raising a plain string is itself a TypeError in Python 3.
            raise ValueError("Attention head fusion type Not supported")

        # Drop the lowest attentions, but never flat index 0 (class token
        # attending to itself). `flat` is a view, so the writes below modify
        # `attention_heads_fused` (a fresh tensor, not the caller's input).
        flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
        num_discard = int(flat.size(-1) * discard_ratio)
        if num_discard > 0:
            _, indices = flat.topk(num_discard, -1, False)
            # Handle each batch item separately so one item's discarded
            # positions are never applied to another item.
            for b in range(flat.size(0)):
                batch_indices = indices[b]
                flat[b, batch_indices[batch_indices != 0]] = 0

        # Add the identity to account for the residual connection, then
        # re-normalise every row so it sums to 1 again.
        I = eye(attention_heads_fused.size(-1),
                dtype=attention_heads_fused.dtype, device=attention_heads_fused.device)
        a = (attention_heads_fused + 1.0 * I) / 2
        # keepdim=True is required: without it the row sums broadcast along
        # the last axis and each entry is divided by the wrong row's sum.
        a = a / a.sum(dim=-1, keepdim=True)

        result = a @ result

    # Look at the total attention between the class token and the image patches
    mask = result[0, 0, 1:]
    # In case of 224x224 image, this brings us from 196 to 14
    width = int(mask.size(-1) ** 0.5)
    mask = mask.reshape(width, width).detach().cpu().numpy()
    mask = mask / np.max(mask)
    return mask
def show_mask_on_image(img, mask):
    """Blend a JET-coloured rendering of ``mask`` onto ``img``.

    Args:
        img: image array on a 0-255 scale (it is divided by 255 here).
        mask: 2-D array; assumed to lie in [0, 1] since it is scaled by 255
            and cast to uint8 -- TODO confirm with callers.

    Returns:
        uint8 array: the sum of image and heat map, rescaled so its maximum is 255.
    """
    # NOTE(review): cv2 colour maps are presumably returned in BGR channel
    # order; confirm that `img` uses the same order, otherwise red and blue
    # are swapped in the overlay.
    colour_map = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    image_unit = np.float32(img) / 255
    overlay = np.float32(colour_map) / 255 + np.float32(image_unit)
    overlay = overlay / np.max(overlay)
    return np.uint8(255 * overlay)

In [None]:
# Build the model with its default hyper-parameters, restore the weights
# stored under the 'model' key of the checkpoint file, then move it to the GPU.
net = VisionTransformer_B16()
saved_state = tt.load('checkpoint_014.npy')['model']
net.load_state_dict(saved_state)
net = net.cuda()

# Uncomment the code in the next cell to plot the attention maps.

In [None]:
# fns=os.listdir('images')
# fig,axs=plt.subplots(len(fns),4,figsize=(12,43))

# axs[0,0].set_title('Original',fontsize=15)
# axs[0,1].set_title('Attention weight',fontsize=15)
# axs[0,2].set_title('Overlap',fontsize=15)
# axs[0,3].set_title('Prediction',fontsize=15)
# i=0
# for fn in fns:
#     fn='images/'+fn
#     original=Image.open(fn).convert('RGB')
#     data=tt.tensor(transform(original)[None].numpy()).cuda()

#     net.eval()
#     with tt.no_grad():
#         net.transformer.attn=[]
#         outputs=net(data)
#         predicted=classes[outputs.argmax(-1).item()]
#     prob=outputs.exp().detach().cpu().numpy()
#     attn_weight=[tensor(t.cpu().numpy()) for t in net.transformer.attn]
#     mask=rollout(attn_weight, discard_ratio=0, head_fusion='mean')
#     mask=cv2.resize(mask, (224, 224))
#     mask/=mask.max()
#     img=show_mask_on_image(255*(data.cpu().numpy()[0].transpose(1,2,0)*0.5+0.5), mask)

#     axs[i,0].imshow(original)
#     axs[i,0].axis('off')
    
#     axs[i,1].imshow(mask, cmap='jet')
#     axs[i,1].axis('off')

#     axs[i,2].imshow(img, cmap='jet')
#     axs[i,2].axis('off')

#     axs[i,3].bar(classes, prob[0])
#     axs[i,3].tick_params('x', labelrotation=45)
#     axs[i,3].set_ylim([0, 1.1])
#     axs[i,3].set_aspect(10)
#     i+=1
# fig.tight_layout()
# plt.show()