# Compute Attention Maps for ViT Tokens

## Self-Attention

Given queries $Q$, keys $K$ and values $V$, with $d_{k}$ the dimension of the keys, scaled dot-product attention is:

\begin{equation}
    \begin{aligned}
        & \mathrm{Attention}(Q,K,V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V
    \end{aligned}
\end{equation}
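The formula maps directly onto a few lines of NumPy. The sketch below is a minimal, unbatched illustration with random inputs (the shapes are arbitrary choices for the example), not the implementation used inside the pretrained model.

```python
import numpy as np


def softmax(z, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Q: (n_queries, d_k), K: (n_keys, d_k), V: (n_keys, d_v).
    Returns an (n_queries, d_v) array.
    """
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)     # (n_queries, n_keys)
    weights = softmax(scores, axis=-1)  # each row sums to 1
    return weights @ V


rng = np.random.default_rng(0)
Q = rng.normal(size=(4, 8))
K = rng.normal(size=(6, 8))
V = rng.normal(size=(6, 5))
out = attention(Q, K, V)
print(out.shape)  # (4, 5)
```

Each output row is a weighted average of the rows of `V`; with all-zero queries the weights become uniform and every output row equals the mean of `V`.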

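Self-attention takes a sequence of input token vectors $x$ (one row per token) and projects it with learned matrices $W_{q}$, $W_{k}$ and $W_{v}$ into queries $Q = xW_{q}$, keys $K = xW_{k}$ and values $V = xW_{v}$: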

\begin{equation}
    \begin{aligned}
        & \mathrm{SelfAttention}(x) = \mathrm{softmax}\left(\frac{xW_{q}W_{k}^{T}x^{T}}{\sqrt{d_{k}}}\right)xW_{v}
    \end{aligned}
\end{equation}
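As a sanity check of the identity $QK^{T} = xW_{q}W_{k}^{T}x^{T}$, here is a minimal single-head sketch in NumPy. The random weights and the sizes `n_tokens`, `d_model` and `d_k` are illustrative assumptions; a real ViT uses multiple heads, each with its own projections.

```python
import numpy as np


def softmax(z, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(x, W_q, W_k, W_v):
    """Single-head self-attention for tokens x of shape (n_tokens, d_model).

    Returns the outputs (n_tokens, d_v) and the attention matrix (n_tokens, n_tokens).
    """
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    d_k = W_k.shape[1]
    # Q K^T = (x W_q)(x W_k)^T = x W_q W_k^T x^T, the numerator in the formula
    A = softmax(Q @ K.T / np.sqrt(d_k), axis=-1)
    return A @ V, A


rng = np.random.default_rng(0)
n_tokens, d_model, d_k = 5, 16, 8
x = rng.normal(size=(n_tokens, d_model))
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))
out, A = self_attention(x, W_q, W_k, W_v)
print(out.shape, A.shape)  # (5, 8) (5, 5)
```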

## Understand ViT tokens

![ViT architecture](https://raw.githubusercontent.com/schwartz-cnl/Computational-Neuroscience-Class/refs/heads/main/Transformers%20and%20Self-Attention/assets/vit_architecture.jpg "ViT architecture")

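Whereas LLMs tokenize text into words or subwords, vision transformers split the image into a sequence of fixed-size patches (16×16 pixels for the ViT-B/16 model used below), embed each patch as a token, and prepend a learnable CLS token whose final representation is used for classification.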
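The patch tokenization step can be sketched as follows. `patchify` is a hypothetical helper written for illustration: it only splits and flattens the image, while in ViT each flattened patch is then linearly projected and combined with a position embedding. For a 224×224 RGB image and 16×16 patches, this yields 14 × 14 = 196 patch tokens, plus one CLS token.

```python
import numpy as np


def patchify(image, patch_size=16):
    """Hypothetical helper: split an (H, W, C) image into flattened patches.

    Returns an array of shape (num_patches, patch_size * patch_size * C),
    with patches ordered row by row (raster order).
    """
    H, W, C = image.shape
    assert H % patch_size == 0 and W % patch_size == 0
    gh, gw = H // patch_size, W // patch_size
    patches = image.reshape(gh, patch_size, gw, patch_size, C)
    patches = patches.transpose(0, 2, 1, 3, 4)  # (gh, gw, p, p, C)
    return patches.reshape(gh * gw, patch_size * patch_size * C)


image = np.zeros((224, 224, 3), dtype=np.float32)
patches = patchify(image)
print(patches.shape)  # (196, 768)
# Prepending one CLS token gives the 197-token sequence the transformer sees
num_tokens = patches.shape[0] + 1
print(num_tokens)  # 197
```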

## Understand attention score

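An attention score is defined for every pair of tokens: entry $(i, j)$ measures how much query token $i$ attends to key token $j$. The normalized attention scores are the softmax term of the attention formulation, so each row sums to 1: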

\begin{equation}
    \begin{aligned}
        \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)
    \end{aligned}
\end{equation}
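For ViT-B/16 (12 heads, head dimension 64, 197 tokens), each layer therefore produces one 197×197 score matrix per head. The sketch below uses random queries and keys as stand-ins for the model's real ones, purely to show the shapes and how the CLS row is turned into a 14×14 patch map.

```python
import numpy as np

num_heads, num_tokens, d_k = 12, 197, 64
rng = np.random.default_rng(0)
# Random stand-ins for one layer's per-head queries and keys
Q = rng.normal(size=(num_heads, num_tokens, d_k))
K = rng.normal(size=(num_heads, num_tokens, d_k))

# Per-head softmax(Q K^T / sqrt(d_k)), shape (heads, tokens, tokens)
scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)
scores -= scores.max(axis=-1, keepdims=True)
attn = np.exp(scores)
attn /= attn.sum(axis=-1, keepdims=True)

# Row 0 is the CLS query; drop column 0 (the CLS key) and lay the
# remaining 196 patch scores out on the 14 x 14 patch grid
cls_maps = attn[:, 0, 1:].reshape(num_heads, 14, 14)
print(attn.shape, cls_maps.shape)  # (12, 197, 197) (12, 14, 14)
```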

### How to visualize attention scores

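1. Mean attention distance: the average spatial distance between a query patch and the patches it attends to, weighted by the attention scores.
2. Attention rollout: multiply the attention matrices of successive layers (with an identity term for the residual connections) to propagate attention from the output back to the input tokens.
3. Attention heatmaps: overlay the normalized attention scores of a chosen query token (e.g. CLS) on the image patches.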
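The first two methods can be sketched in NumPy as below. Both functions are illustrative assumptions rather than the exact code of the credited sources: `mean_attention_distance` drops the CLS token and renormalizes the patch-to-patch weights (some implementations skip the renormalization), and `attention_rollout` averages heads and adds an identity matrix for the residual connection, following the usual rollout recipe. Both expect per-layer attention of shape (heads, tokens, tokens) with the CLS token at index 0.

```python
import numpy as np


def mean_attention_distance(attn, grid_size=14, patch_size=16):
    """Attention-weighted mean pixel distance, one value per head (sketch).

    attn: (heads, 1 + grid_size**2, 1 + grid_size**2), CLS token at index 0.
    The CLS token has no spatial position, so it is dropped and the
    remaining patch-to-patch weights are renormalized per query (a choice).
    """
    patch_attn = attn[:, 1:, 1:]
    patch_attn = patch_attn / patch_attn.sum(axis=-1, keepdims=True)
    # Pixel coordinates of each patch's top-left corner, in raster order
    rows, cols = np.divmod(np.arange(grid_size * grid_size), grid_size)
    coords = np.stack([rows, cols], axis=-1) * patch_size
    # Euclidean distance between every pair of patches: (patches, patches)
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    # Weighted distance for each query, then averaged over all queries
    return (patch_attn * dist).sum(axis=-1).mean(axis=-1)


def attention_rollout(attn_layers):
    """Attention rollout over a list of per-layer (heads, tokens, tokens) arrays.

    Heads are averaged, an identity matrix is added for the residual
    connection, rows are renormalized, and layers are multiplied in order.
    """
    n = attn_layers[0].shape[-1]
    rollout = np.eye(n)
    for attn in attn_layers:
        a = attn.mean(axis=0) + np.eye(n)
        a = a / a.sum(axis=-1, keepdims=True)
        rollout = a @ rollout
    # Row 0 holds the CLS token's rolled-out attention to every input token
    return rollout
```

A head whose patches only attend to themselves has a mean distance of 0; heads that spread attention across the image get larger values.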

In this lab, we visualize the attention of the CLS token to every image patch, which shows what the model attends to when classifying the image.

As a further challenge, we will examine pairwise token-to-token attention and compare it with ground-truth masks.

## Lab: visualize attention heatmap

In [None]:
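import os
# Set before importing TensorFlow: makes tf.keras use legacy Keras 2 (the tf-keras
# package), since Keras 3 cannot load the TensorFlow SavedModels used below
os.environ["TF_USE_LEGACY_KERAS"] = "1"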

In [None]:
import math
import numpy as np
import sys
import multiprocessing
import time
import matplotlib.pyplot as plt
import tensorflow.compat.v2 as tf

### Load pretrained models

In [None]:
import zipfile

# Release archive containing the pretrained ViT checkpoints
GITHUB_RELEASE = "https://github.com/sayakpaul/probing-vits/releases/download/v1.0.0/probing_vits.zip"
FNAME = "probing_vits.zip"
# Per-model archives inside the release (paths after extraction and renaming)
MODELS_ZIP = {
    "vit_dino_base16": "Probing_ViTs/vit_dino_base16.zip",
    "vit_b16_patch16_224": "Probing_ViTs/vit_b16_patch16_224.zip",
    "vit_b16_patch16_224-i1k_pretrained": "Probing_ViTs/vit_b16_patch16_224-i1k_pretrained.zip",
}

In [None]:
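# Download the release archive (cached after the first run) and extract it
zip_path = tf.keras.utils.get_file(
    fname=FNAME,
    origin=GITHUB_RELEASE,
)

with zipfile.ZipFile(zip_path, "r") as zip_ref:
    zip_ref.extractall("./")

# The archive extracts to a folder with a space in its name; rename it only
# once so that re-running this cell does not fail
if os.path.isdir("Probing ViTs") and not os.path.exists("Probing_ViTs"):
    os.rename("Probing ViTs", "Probing_ViTs")

# Extract the ImageNet-1k pretrained ViT-B/16 used in this lab
with zipfile.ZipFile(MODELS_ZIP["vit_b16_patch16_224-i1k_pretrained"], "r") as zip_ref:
    zip_ref.extractall("Probing_ViTs/")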

In [None]:
model = tf.keras.models.load_model('./Probing_ViTs/vit_b16_patch16_224-i1k_pretrained')

### Preprocess the image

In [None]:
from PIL import Image
import requests
import io

In [None]:
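input_resolution = 224

# ImageNet channel mean and variance, expressed in the 0-255 pixel range
norm_layer = tf.keras.layers.Normalization(
    mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
    variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],
)


def preprocess_image(image, size=input_resolution):
    """Resize an (H, W, C) image to (size, size) and normalize it.

    Returns (model input, resized image for display), both with a batch axis.
    Single-channel images (e.g. masks) are resized but not normalized.
    """
    image = np.array(image)
    image_resized = tf.expand_dims(image, 0)  # (1, H, W, C)
    image_resized = tf.image.resize(
        image_resized, (size, size), method="bicubic"
    )
    if image_resized.shape[-1] == 1:
        norm_img = image_resized
    else:
        norm_img = norm_layer(image_resized).numpy()
    return norm_img, image_resized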


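def load_image_from_url(url):
    response = requests.get(url)
    response.raise_for_status()  # fail loudly on a bad URL instead of a PIL error
    image = Image.open(io.BytesIO(response.content))
    image_np = np.array(image)
    if image_np.ndim == 2:
        # Grayscale (e.g. a mask): add a channel axis -> (H, W, 1)
        image_np = np.expand_dims(image_np, axis=-1)
    elif image_np.shape[-1] == 4:
        # RGBA: drop the alpha channel so the 3-channel normalization applies
        image_np = image_np[..., :3]
    preprocessed_image, resized_image = preprocess_image(image_np)
    return image, preprocessed_image, resized_image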
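The `Normalization` layer above is the standard ImageNet mean/std normalization rewritten for 0-255 pixels. Per the Keras documentation, the layer computes `(input - mean) / sqrt(variance)`, so passing `mean * 255` and `(std * 255) ** 2` is equivalent to scaling pixels to [0, 1] and then standardizing. A quick numerical check on random pixels (an illustration, not part of the lab pipeline):

```python
import numpy as np

# ImageNet channel statistics, as used in norm_layer
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=(4, 4, 3)).astype(np.float64)  # raw 0-255 RGB

# What norm_layer computes: (x - mean) / sqrt(variance) with 0-255 statistics
keras_style = (pixels - mean * 255) / np.sqrt((std * 255) ** 2)
# Equivalent recipe: scale pixels to [0, 1], then standardize
unit_scale_style = (pixels / 255 - mean) / std

print(np.allclose(keras_style, unit_scale_style))  # True
```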

In [None]:
# Load image from public url
image, preprocessed_image, resized_image = load_image_from_url(
    "https://dl.fbaipublicfiles.com/dino/img.png"
)
# Min-max scale the resized image to [0, 1] so imshow can display it
normalized_resize = (resized_image-np.min(resized_image))/(np.max(resized_image)-np.min(resized_image))
image

In [None]:
### TO DO: upload your own image to a public url and visualize its attention score. ###

### Extract Attention Score

In [None]:
# Names of the attention outputs of the 12 transformer blocks, in order
blocknames = [f"transformer_block_{i}_att" for i in range(12)]

In [None]:
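# Run the model once; its second output is a dict mapping block name -> attention scores
attention_score_dict = model(preprocessed_image, training=False)[1]
# Order the per-layer scores from the first block to the last
attention_score_list = [attention_score_dict[k] for k in blocknames]

nh = 12  # number of attention heads per block in ViT-B/16
token_size = input_resolution // 16  # 14 patches per side (224 / 16)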
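Before plotting, it can help to confirm that the scores have the expected layout. `stack_attention` below is a hypothetical helper that assumes each entry of `attention_score_list` has shape (1, 12, 197, 197), i.e. (batch, heads, query tokens, key tokens) for a single image; adjust `num_heads` and `num_tokens` if your model differs.

```python
import numpy as np


def stack_attention(attention_list, num_heads=12, num_tokens=197):
    """Hypothetical helper: stack per-layer scores into (layers, heads, tokens, tokens).

    Assumes each entry has shape (1, num_heads, num_tokens, num_tokens),
    i.e. batch size 1, as produced for a single preprocessed image.
    """
    stacked = np.stack([np.asarray(a)[0] for a in attention_list])
    expected = (len(attention_list), num_heads, num_tokens, num_tokens)
    if stacked.shape != expected:
        raise ValueError(f"unexpected attention shape {stacked.shape}, expected {expected}")
    # Every query row of a softmax attention matrix should sum to 1
    if not np.allclose(stacked.sum(axis=-1), 1.0, atol=1e-3):
        raise ValueError("attention rows do not sum to 1")
    return stacked
```

Usage, assuming the cells above have run: `all_attention = stack_attention(attention_score_list)`, after which `all_attention[layer]` is one layer's (heads, tokens, tokens) attention.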

### Visualize heatmap for last layer per head

In [None]:
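# Pick attention scores from a specific layer (here the last one)
attention = attention_score_list[-1].numpy()

# Shape (batch, heads, tokens, tokens): keep batch 0, all heads, the CLS
# query (row 0), and the keys of the 196 patch tokens (dropping the CLS key)
attention = attention[0, :, 0, 1:]
# Lay the patch scores out on the 14 x 14 patch grid: (nh, 14, 14)
attention = attention.reshape(nh, token_size, token_size)
# tf.image.resize expects channels last, so treat the heads as channels
attention = attention.transpose((1, 2, 0))
attention = tf.image.resize(
    attention,
    size=(input_resolution, input_resolution),
).numpy()
# Back to (nh, 224, 224): one upsampled heatmap per head
attention = attention.transpose((2, 0, 1))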
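The processing above can be wrapped in a small function so it is easy to repeat for any layer. `cls_heatmaps` is a hypothetical helper; to stay NumPy-only, it upsamples each 14×14 map with nearest-neighbour repetition (each patch becomes a 16×16 block) instead of the bilinear `tf.image.resize` call above, so the heatmaps look blockier.

```python
import numpy as np


def cls_heatmaps(attention_scores, num_heads=12, grid_size=14, upsample=16):
    """Hypothetical helper: per-head CLS-to-patch maps at image resolution.

    attention_scores: (1, num_heads, 1 + grid_size**2, 1 + grid_size**2).
    Uses nearest-neighbour upsampling (np.repeat), not bilinear resizing.
    """
    att = np.asarray(attention_scores)[0, :, 0, 1:]  # CLS query, patch keys
    att = att.reshape(num_heads, grid_size, grid_size)
    # Repeat every grid cell `upsample` times along both spatial axes
    att = np.repeat(np.repeat(att, upsample, axis=1), upsample, axis=2)
    return att  # (num_heads, grid_size * upsample, grid_size * upsample)
```

For example, `attention = cls_heatmaps(attention_score_list[5].numpy())` produces per-head heatmaps for the sixth block in the same (12, 224, 224) layout that the plotting cell below expects.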

In [None]:
# Plot image and heat map per head
fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(13, 13))
img_count = 0

for i in range(3):
    for j in range(4):
        if img_count < len(attention):
            axes[i, j].imshow(normalized_resize[0])
            axes[i, j].imshow(attention[img_count], alpha=0.5)
            axes[i, j].title.set_text(f"Attention head: {img_count}")
            axes[i, j].axis("off")
            img_count += 1

In [None]:
### TO DO: Visualize attention heatmap for other layers

### Optional: examine attention heatmap for any arbitrary token pair

In [None]:
### TO DO: compare all tokens pair-wise and look for semantically meaningful attentions with ground truth masks.
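One way to start is to look at the attention of an arbitrary query token instead of CLS. `query_token_map` below is a hypothetical helper: it takes one layer's scores with shape (1, heads, 197, 197), selects the row of the chosen query token, drops the CLS key, and returns a 14×14 grid, either for one head or averaged over heads. The score for a single token pair (i, j) in head h is simply `scores[0, h, i, j]`.

```python
import numpy as np


def query_token_map(attention_scores, query_index, head=None, grid_size=14):
    """Hypothetical helper: attention from one query token to all patch tokens.

    attention_scores: (1, heads, 1 + grid_size**2, 1 + grid_size**2).
    query_index: token index (0 = CLS, 1..grid_size**2 = patches in raster order).
    head: a single head index, or None to average over all heads.
    """
    att = np.asarray(attention_scores)[0]
    att = att.mean(axis=0) if head is None else att[head]
    # Drop the CLS key and arrange the patch keys on the patch grid
    return att[query_index, 1:].reshape(grid_size, grid_size)
```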

### Load image and masks

In [None]:
patch_size = 16
# 14 x 14 = 196 patch tokens plus 1 CLS token = 197 tokens
num_tokens = (input_resolution // patch_size) ** 2 + 1
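Token indices and image locations are related by the raster order of the patches: token 0 is CLS and token k >= 1 is the patch in row (k - 1) // 14 and column (k - 1) % 14. The hypothetical helpers below convert between the two, assuming the 224×224 input and 16×16 patches used here.

```python
def token_to_patch(token_index, grid_size=14, patch_size=16):
    """Hypothetical helper: map a patch token index to its grid cell and pixel box.

    Index 0 is the CLS token, which has no spatial location.
    Returns (row, col, (y0, y1, x0, x1)).
    """
    if not 1 <= token_index <= grid_size * grid_size:
        raise ValueError("token_index must refer to a patch token")
    row, col = divmod(token_index - 1, grid_size)
    y0, x0 = row * patch_size, col * patch_size
    return row, col, (y0, y0 + patch_size, x0, x0 + patch_size)


def patch_to_token(row, col, grid_size=14):
    """Hypothetical helper: inverse of token_to_patch (grid cell -> token index)."""
    return 1 + row * grid_size + col
```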

In [None]:
image, preprocessed_image, resized_image = load_image_from_url(
    # "https://raw.githubusercontent.com/schwartz-cnl/Computational-Neuroscience-Class/refs/heads/main/Transformers%20and%20Self-Attention/assets/beer.jpg"
    # "https://raw.githubusercontent.com/schwartz-cnl/Computational-Neuroscience-Class/refs/heads/main/Transformers%20and%20Self-Attention/assets/chicks.jpg"
    "https://raw.githubusercontent.com/schwartz-cnl/Computational-Neuroscience-Class/refs/heads/main/Transformers%20and%20Self-Attention/assets/duck.jpg"
)
normalized_resize = (resized_image-np.min(resized_image))/(np.max(resized_image)-np.min(resized_image))
image

In [None]:
image_dist, _, distractor = load_image_from_url(
    # "https://raw.githubusercontent.com/schwartz-cnl/Computational-Neuroscience-Class/refs/heads/main/Transformers%20and%20Self-Attention/assets/beer_dist.jpg"
    # "https://raw.githubusercontent.com/schwartz-cnl/Computational-Neuroscience-Class/refs/heads/main/Transformers%20and%20Self-Attention/assets/chicks_dist.jpg"
    "https://raw.githubusercontent.com/schwartz-cnl/Computational-Neuroscience-Class/refs/heads/main/Transformers%20and%20Self-Attention/assets/duck_dist.jpg"
)
normalized_distractor = (distractor-np.min(distractor))/(np.max(distractor)-np.min(distractor))
image_dist

In [None]:
image_targ, _, target = load_image_from_url(
    # "https://raw.githubusercontent.com/schwartz-cnl/Computational-Neuroscience-Class/refs/heads/main/Transformers%20and%20Self-Attention/assets/beer_targ.jpg"
    # "https://raw.githubusercontent.com/schwartz-cnl/Computational-Neuroscience-Class/refs/heads/main/Transformers%20and%20Self-Attention/assets/chicks_targ.jpg"
    "https://raw.githubusercontent.com/schwartz-cnl/Computational-Neuroscience-Class/refs/heads/main/Transformers%20and%20Self-Attention/assets/duck_targ.jpg"
)
normalized_target = (target-np.min(target))/(np.max(target)-np.min(target))
image_targ
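To compare attention with the ground-truth masks at token resolution, a mask can be pooled onto the 14×14 patch grid. `mask_to_patch_fraction` is a hypothetical helper that assumes the mask's object pixels are bright on a dark background; it binarizes the mask at half its range and returns the fraction of object pixels in each patch.

```python
import numpy as np


def mask_to_patch_fraction(mask, patch_size=16):
    """Hypothetical helper: fraction of 'on' mask pixels inside each patch.

    mask: (H, W) or (H, W, C) array; channels are averaged, the result is
    min-max scaled to [0, 1] and thresholded at 0.5 (assumes bright = object).
    Returns an (H // patch_size, W // patch_size) array in [0, 1].
    """
    m = np.asarray(mask, dtype=np.float64)
    if m.ndim == 3:
        m = m.mean(axis=-1)
    m = (m - m.min()) / max(m.max() - m.min(), 1e-12)
    on = (m > 0.5).astype(np.float64)
    H, W = on.shape
    gh, gw = H // patch_size, W // patch_size
    on = on[: gh * patch_size, : gw * patch_size]
    # Average the binary mask inside each patch_size x patch_size block
    return on.reshape(gh, patch_size, gw, patch_size).mean(axis=(1, 3))
```

For example, `mask_to_patch_fraction(target[0])` and `mask_to_patch_fraction(distractor[0])` give per-patch target and distractor coverage that can be compared with a 14×14 attention map.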

In [None]:
# Helper softmax and overlay color maps
import matplotlib as mpl


def softmax(x):
    """Numerically stable softmax over all elements of x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

# Generate the colors: fully transparent white, and semi-transparent red / cyan
color1 = mpl.colors.to_rgba('white', alpha=0.0)
color2 = mpl.colors.to_rgba('red', alpha=0.8)
color3 = mpl.colors.to_rgba('cyan', alpha=0.8)

# Make the colormaps: transparent -> red, and transparent -> cyan
cmap1 = mpl.colors.LinearSegmentedColormap.from_list('my_cmap', [color1, color2], 256)
cmap2 = mpl.colors.LinearSegmentedColormap.from_list('my_cmap2', [color1, color3], 256)

Credits:
1. Adapted from https://keras.io/examples/vision/probing_vits/#method-iii-attention-heatmaps
2. Methods adapted from https://arxiv.org/abs/2405.14880