# About this notebook  
- Visualize ViT Attention Map
- ViT github is [here](https://github.com/tczhangzhi/VisionTransformer-Pytorch).
(I modified it slightly so that it returns the attention maps. Please see this [issue](https://github.com/tczhangzhi/VisionTransformer-Pytorch/issues/1#issuecomment-739138519).)


I want to show the attention maps for cassava leaf images.
- I only show a few samples from the 2019 training dataset.

You can find my pretrained ViT weights [here](https://www.kaggle.com/piantic/cassava-vit-b-16).

### If this kernel is useful, feel free to upvote:)

# Vision Transformer (ViT) : Attention Map

This is an example of an attention map.
- Reference is [here](https://github.com/jeonsworld/ViT-pytorch/blob/main/visualize_attention_map.ipynb).

<img src='https://user-images.githubusercontent.com/6073256/101206904-2a338f00-36b3-11eb-8920-f617abab1604.png'>

Next, we will look at the attention maps for cassava leaves!

# Import libraries

In [None]:
import sys

package_path = '../input/visiontransformerpytorch121/VisionTransformer-Pytorch'
sys.path.append(package_path)

import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torchvision.models as models

import cv2
from PIL import Image
from torchvision import transforms

from vision_transformer_pytorch import VisionTransformer

In [None]:
# Preprocessing applied to every PIL image before the ViT forward pass:
# resize to 384x384, convert to a float tensor, then normalize each
# channel with the mean/std values below.
transform = transforms.Compose(
    [
        transforms.Resize(size=(384, 384)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

## Helper functions

In [None]:
# ====================================================
# Helper functions
# ====================================================
def load_state(model_path, map_location="cpu"):
    """Load the ``'model'`` state dict from a checkpoint and normalize its keys.

    A leading ``'module.'`` prefix is stripped first, and then a leading
    ``'model.'`` prefix, so the keys match a bare model's ``state_dict``.

    Args:
        model_path: Path to a file produced by ``torch.save`` whose top-level
            object is a mapping containing a ``'model'`` entry.
        map_location: Forwarded to ``torch.load``. It defaults to ``"cpu"`` so a
            checkpoint saved on a GPU still loads on a CPU-only machine; the
            tensors are copied into the model's parameters by
            ``load_state_dict`` anyway.

    Returns:
        dict: parameter name -> tensor, with the prefixes removed.
    """
    state_dict = torch.load(model_path, map_location=map_location)['model']

    def _strip_prefixes(key):
        # Same order as before: 'module.' first, then 'model.'.
        for prefix in ('module.', 'model.'):
            if key.startswith(prefix):
                key = key[len(prefix):]
        return key

    return {_strip_prefixes(k): v for k, v in state_dict.items()}

In [None]:
def get_attention_map(img, get_mask=False):
    """Compute an attention-rollout map for ``img`` with the global ``model``.

    The steps are:

    1. Per-layer attention weights (the model's second output) are averaged
       over heads.
    2. An identity matrix is added to account for residual connections, and
       each row is re-normalized.
    3. The resulting matrices are multiplied across layers.
    4. Row 0 of the final joint attention, restricted to tokens 1.., is
       reshaped to a square grid and resized to the image size.

    Args:
        img: PIL image; it is passed through the global ``transform``.
        get_mask: If True, return the normalized mask (2-D float array, max 1).
            Otherwise return the image pixels weighted by the mask, as
            ``uint8``.

    Returns:
        numpy.ndarray at the original image size.
    """
    x = transform(img)

    # Disable dropout (if the model has any) so the map is deterministic, and
    # skip autograd bookkeeping. Restore the caller's train/eval mode afterwards.
    was_training = model.training
    model.eval()
    try:
        device = next(model.parameters()).device
        with torch.no_grad():
            _, att_mat = model(x.unsqueeze(0).to(device))
    finally:
        model.train(was_training)

    # NOTE(review): assumes att_mat is a sequence with one (1, heads, tokens,
    # tokens) tensor per layer - confirm against the modified ViT.
    att_mat = torch.stack(att_mat).squeeze(1).cpu()

    # Average the attention weights across all heads.
    att_mat = torch.mean(att_mat, dim=1)

    # To account for residual connections, we add an identity matrix to the
    # attention matrix and re-normalize the weights.
    residual_att = torch.eye(att_mat.size(1))
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

    # Recursively multiply the weight matrices
    joint_attentions = torch.zeros(aug_att_mat.size())
    joint_attentions[0] = aug_att_mat[0]
    for n in range(1, aug_att_mat.size(0)):
        joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n - 1])

    v = joint_attentions[-1]
    # Token 0 (presumably [CLS]) is excluded; the remaining tokens form a
    # square patch grid.
    grid_size = int(np.sqrt(aug_att_mat.size(-1) - 1))
    mask = v[0, 1:].reshape(grid_size, grid_size).numpy()
    # cv2.resize takes dsize as (width, height), which matches PIL's img.size.
    mask = cv2.resize(mask / mask.max(), img.size)
    if get_mask:
        return mask
    return (mask[..., np.newaxis] * np.asarray(img)).astype("uint8")

def plot_attention_map(original_img, att_map):
    """Draw ``original_img`` and ``att_map`` side by side in one 16x16 figure.

    The left panel is titled 'Original' and the right panel
    'Attention Map Last Layer'. Nothing is returned.
    """
    fig, axes = plt.subplots(ncols=2, figsize=(16, 16))
    panels = (('Original', original_img), ('Attention Map Last Layer', att_map))
    for ax, (title, picture) in zip(axes, panels):
        ax.set_title(title)
        _ = ax.imshow(picture)

# Load ViT Model

In [None]:
model = VisionTransformer.from_name('ViT-B_16', num_classes=5)
# Inference only: switch to eval mode so dropout (if present) does not make
# the attention maps below vary from run to run.
model.eval()
state = load_state('../input/cassava-vit-b-16/ViT-B_16_fold0.pth')
model.load_state_dict(state)

# Visualize Attention Map

In [None]:
# Load the label mapping; with orient='index' the JSON object's keys become
# the DataFrame index (presumably the numeric class ids).
label_map = pd.read_json(
    '../input/cassava-leaf-disease-classification/label_num_to_disease_map.json',
    orient='index',
)
display(label_map)

## CBB - Class0

In [None]:
# Two CBB samples and their masked-image attention maps.
img1, img2 = map(Image.open, ("../input/cassava-vit-b-16/train-cbb-114.jpg",
                              "../input/cassava-vit-b-16/train-cbb-44.jpg"))
result1, result2 = map(get_attention_map, (img1, img2))

In [None]:
plot_attention_map(original_img=img1, att_map=result1)

In [None]:
plot_attention_map(original_img=img2, att_map=result2)

### Check mask for Attention Map

In [None]:
# Raw (normalized) masks instead of masked images.
result1, result2 = (get_attention_map(image, get_mask=True) for image in (img1, img2))

In [None]:
plot_attention_map(original_img=img1, att_map=result1)

In [None]:
plot_attention_map(original_img=img2, att_map=result2)

## CBSD - Class1

In [None]:
# Two CBSD samples.
img1, img2 = map(Image.open, ("../input/cassava-vit-b-16/train-cbsd-154.jpg",
                              "../input/cassava-vit-b-16/train-cbsd-821.jpg"))

In [None]:
result1, result2 = (get_attention_map(image) for image in (img1, img2))

In [None]:
plot_attention_map(original_img=img1, att_map=result1)

In [None]:
plot_attention_map(original_img=img2, att_map=result2)

### Check mask for Attention Map

In [None]:
result1, result2 = (get_attention_map(image, get_mask=True) for image in (img1, img2))

In [None]:
plot_attention_map(original_img=img1, att_map=result1)

In [None]:
plot_attention_map(original_img=img2, att_map=result2)

## CGM - Class2

In [None]:
# Two CGM samples and their masked-image attention maps.
img1, img2 = map(Image.open, ("../input/cassava-vit-b-16/train-cgm-498.jpg",
                              "../input/cassava-vit-b-16/train-cgm-6.jpg"))
result1, result2 = map(get_attention_map, (img1, img2))

In [None]:
for image, attention in ((img1, result1), (img2, result2)):
    plot_attention_map(image, attention)

### Check mask for Attention Map

In [None]:
result1, result2 = (get_attention_map(image, get_mask=True) for image in (img1, img2))

In [None]:
plot_attention_map(original_img=img1, att_map=result1)

In [None]:
plot_attention_map(original_img=img2, att_map=result2)

## CMD - Class3

In [None]:
# NOTE(review): this "CMD - Class3" cell loads the same CBSD files used in the
# "CBSD - Class1" section (train-cbsd-154 / train-cbsd-821). As a result, the
# CMD plots below actually show CBSD leaves. TODO: replace these with CMD
# samples from ../input/cassava-vit-b-16/ (the correct filenames are not
# visible in this notebook).
img1 = Image.open("../input/cassava-vit-b-16/train-cbsd-154.jpg")
img2 = Image.open("../input/cassava-vit-b-16/train-cbsd-821.jpg")

In [None]:
result1, result2 = (get_attention_map(image) for image in (img1, img2))

In [None]:
plot_attention_map(original_img=img1, att_map=result1)

In [None]:
plot_attention_map(original_img=img2, att_map=result2)

### Check mask for Attention Map

In [None]:
result1, result2 = (get_attention_map(image, get_mask=True) for image in (img1, img2))

In [None]:
plot_attention_map(original_img=img1, att_map=result1)

In [None]:
plot_attention_map(original_img=img2, att_map=result2)

## Healthy - Class4

In [None]:
# Two healthy-leaf samples and their masked-image attention maps.
img1, img2 = map(Image.open, ("../input/cassava-vit-b-16/train-healthy-105.jpg",
                              "../input/cassava-vit-b-16/train-healthy-236.jpg"))
result1, result2 = map(get_attention_map, (img1, img2))

In [None]:
plot_attention_map(original_img=img1, att_map=result1)

In [None]:
plot_attention_map(original_img=img2, att_map=result2)

### Check mask for Attention Map

In [None]:
result1, result2 = (get_attention_map(image, get_mask=True) for image in (img1, img2))

In [None]:
plot_attention_map(original_img=img1, att_map=result1)

In [None]:
plot_attention_map(original_img=img2, att_map=result2)

# Visualize Attention Maps

As an example, I will use the CBSD images.

In [None]:
def get_attention_info(img):
    """Return the per-layer attention rollout for ``img`` and the patch-grid size.

    This uses the global ``model`` and ``transform``. Attention from every
    layer is processed as follows:

    1. Average the attention over heads.
    2. Add an identity matrix to account for residual connections.
    3. Row-normalize the result.
    4. Multiply cumulatively across layers, so ``joint_attentions[n]`` is
       the rollout up to and including layer ``n``.

    Args:
        img: PIL image; it is passed through the global ``transform``.

    Returns:
        tuple: ``(joint_attentions, grid_size)``.
            ``joint_attentions`` is a CPU tensor, presumably shaped
            (layers, tokens, tokens).
            ``grid_size`` is the side of the square grid formed by all
            tokens except the first.
    """
    x = transform(img)

    # Disable dropout (if the model has any) and autograd for a deterministic,
    # cheaper forward pass. Restore the caller's train/eval mode afterwards.
    was_training = model.training
    model.eval()
    try:
        device = next(model.parameters()).device
        with torch.no_grad():
            _, att_mat = model(x.unsqueeze(0).to(device))
    finally:
        model.train(was_training)

    att_mat = torch.stack(att_mat).squeeze(1).cpu()

    # Average the attention weights across all heads.
    att_mat = torch.mean(att_mat, dim=1)

    # To account for residual connections, we add an identity matrix to the
    # attention matrix and re-normalize the weights.
    residual_att = torch.eye(att_mat.size(1))
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

    # Recursively multiply the weight matrices
    joint_attentions = torch.zeros(aug_att_mat.size())
    joint_attentions[0] = aug_att_mat[0]
    for n in range(1, aug_att_mat.size(0)):
        joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n - 1])

    # Token 0 (presumably [CLS]) is not part of the patch grid.
    grid_size = int(np.sqrt(aug_att_mat.size(-1) - 1))

    return joint_attentions, grid_size

In [None]:
img1, img2 = map(Image.open, ("../input/cassava-vit-b-16/train-cbsd-154.jpg",
                              "../input/cassava-vit-b-16/train-cbsd-821.jpg"))

In [None]:
(joint_att1, grid_size1), (joint_att2, grid_size2) = map(get_attention_info, (img1, img2))

In [None]:
# Plot the cumulative attention rollout after each layer for img1.
# (Previously `v` was overwritten with joint_att1[-1] on every iteration,
# so every "layer" plot showed the last layer.)
for i, v in enumerate(joint_att1):
    mask = v[0, 1:].reshape(grid_size1, grid_size1).detach().numpy()
    # cv2.resize takes dsize as (width, height), which matches PIL's img.size.
    mask = cv2.resize(mask / mask.max(), img1.size)[..., np.newaxis]
    result = (mask * np.asarray(img1)).astype("uint8")

    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))
    ax1.set_title('Original')
    ax2.set_title('Attention Map_%d Layer' % (i+1))
    _ = ax1.imshow(img1)
    _ = ax2.imshow(result)

In [None]:
# Plot the cumulative attention rollout after each layer for img2.
# Two fixes:
# - `v` is no longer overwritten with the last layer on every iteration.
# - The mask is resized to img2's size, not img1's.
for i, v in enumerate(joint_att2):
    mask = v[0, 1:].reshape(grid_size2, grid_size2).detach().numpy()
    # cv2.resize takes dsize as (width, height), which matches PIL's img.size.
    mask = cv2.resize(mask / mask.max(), img2.size)[..., np.newaxis]
    result = (mask * np.asarray(img2)).astype("uint8")

    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))
    ax1.set_title('Original')
    ax2.set_title('Attention Map_%d Layer' % (i+1))
    _ = ax1.imshow(img2)
    _ = ax2.imshow(result)

## If this kernel is useful, <font color='orange'>please upvote</font>!