# Off-the-shelf image classification with DeiT models on TF-Hub

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/sayakpaul/deit-tf/blob/main/notebooks/classification.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/sayakpaul/deit-tf/blob/main/notebooks/classification.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View on GitHub</a>
  </td>
  <td>
    <a href="https://tfhub.dev/sayakpaul/collections/deit/1"><img src="https://www.tensorflow.org/images/hub_logo_32px.png" />See TF Hub models</a>
  </td>
</table>

## Setup

In [None]:
# Fetch the ILSVRC-2012 WordNet lemma list. A later cell reads this file line by
# line and uses the line index to turn a predicted class id into a label string.
!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -O ilsvrc2012_wordnet_lemmas.txt

In [None]:
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow import keras


from PIL import Image
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
import requests
import cv2

## Select a [DeiT](https://arxiv.org/abs/2012.12877) ImageNet-1k model

Find the entire collection [here](https://tfhub.dev/sayakpaul/collections/deit/1).

In [None]:
model_name = "deit_tiny_patch16_224" #@param ["deit_tiny_patch16_224", "deit_tiny_distilled_patch16_224", "deit_small_patch16_224", "deit_small_distilled_patch16_224", "deit_base_patch16_224", "deit_base_distilled_patch16_224", "deit_base_patch16_384", "deit_base_distilled_patch16_384"]

# All handles in this collection share one URL pattern, so the lookup table is
# derived from the variant names instead of spelling every name out twice.
_DEIT_VARIANTS = (
    "deit_tiny_patch16_224",
    "deit_tiny_distilled_patch16_224",
    "deit_small_patch16_224",
    "deit_small_distilled_patch16_224",
    "deit_base_patch16_224",
    "deit_base_distilled_patch16_224",
    "deit_base_patch16_384",
    "deit_base_distilled_patch16_384",
)
model_handle_map = {
    variant: f"https://tfhub.dev/sayakpaul/{variant}/1" for variant in _DEIT_VARIANTS
}

# The last underscore-separated component of the model name is the side length
# of the square input (e.g. "..._224" -> 224).
input_resolution = int(model_name.rsplit("_", 1)[-1])
model_handle = model_handle_map[model_name]
print(f"Input resolution: {input_resolution} x {input_resolution} x 3.")
print(f"TF-Hub handle: {model_handle}.")

## Image preprocessing utilities 

In [None]:
# Per-channel statistics on a 0-1 scale (presumably the usual ImageNet values --
# confirm against the model card). They are multiplied by 255 below because the
# pixels reach the normalization layer without being rescaled to [0, 1].
_CHANNEL_MEANS = (0.485, 0.456, 0.406)
_CHANNEL_STDS = (0.229, 0.224, 0.225)

# Center crop to the resolution of the selected model.
crop_layer = keras.layers.CenterCrop(input_resolution, input_resolution)

# Normalization takes a variance, hence the squared (scaled) standard deviations.
norm_layer = keras.layers.Normalization(
    mean=[channel_mean * 255 for channel_mean in _CHANNEL_MEANS],
    variance=[(channel_std * 255) ** 2 for channel_std in _CHANNEL_STDS],
)


def preprocess_image(image, size=input_resolution):
    """Turn one image into a resized, center-cropped and normalized batch of one.

    Args:
        image: A PIL image or an array-like laid out as (height, width, 3).
            PIL images in any other mode (grayscale, RGBA, palette, ...) are
            converted to RGB first.
        size: Side length of the square output. The image is resized to a
            square of int(256 / 224 * size) pixels per side (aspect ratio is
            not preserved) and then center-cropped to `size`.

    Returns:
        The output of `norm_layer` as a NumPy array with a leading batch
        dimension of 1 and spatial size (size, size).
    """
    if isinstance(image, Image.Image):
        # The normalization layer is configured for three channels; without this
        # a grayscale or RGBA image would have the wrong channel count.
        image = image.convert("RGB")

    batched = tf.expand_dims(np.array(image), 0)
    resize_size = int((256 / 224) * size)
    resized = tf.image.resize(batched, (resize_size, resize_size), method="bicubic")

    # The module-level `crop_layer` is fixed to `input_resolution`. Build a
    # matching crop when a different `size` is requested so the argument is
    # honoured by the crop as well as by the resize.
    if size == input_resolution:
        cropper = crop_layer
    else:
        cropper = keras.layers.CenterCrop(size, size)

    return norm_layer(cropper(resized)).numpy()
    

def load_image_from_url(url):
    """Download an image and return it together with its preprocessed form.

    Args:
        url: Address of the image to download.

    Returns:
        A tuple `(image, preprocessed_image)`: the decoded image as an RGB PIL
        image, and the result of `preprocess_image(image)`.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.Timeout: If the server does not answer within the timeout.
    """
    # Credit: Willi Gierke
    # A timeout keeps the notebook from hanging forever on an unresponsive host.
    response = requests.get(url, timeout=30)
    # Surface 4xx/5xx responses here instead of letting PIL fail on an error page.
    response.raise_for_status()
    # Force three channels: both the preprocessing and the later attention-mask
    # overlay multiply against a (height, width, 3) image.
    image = Image.open(BytesIO(response.content)).convert("RGB")
    preprocessed_image = preprocess_image(image)
    return image, preprocessed_image

## Load ImageNet-1k labels and a demo image

In [None]:
# One label entry per line; the position in the list is used as the class id
# when a prediction is mapped back to a label.
with open("ilsvrc2012_wordnet_lemmas.txt", "r") as label_file:
    imagenet_int_to_str = [entry.rstrip() for entry in label_file]

# Download the demo picture and keep both the raw image (for display) and the
# preprocessed batch (for the model).
img_url = "https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg"
image, preprocessed_image = load_image_from_url(img_url)

plt.imshow(image)
plt.show()

## Run inference

In [None]:
def get_model(model_url: str) -> tf.keras.Model:
    """Wrap a TF-Hub DeiT module in a Keras model that returns only its first output.

    Args:
        model_url: TF-Hub handle of the module to load.

    Returns:
        A `tf.keras.Model` taking images of shape
        (input_resolution, input_resolution, 3) and returning the first element
        of the tuple produced by the hub module.
    """
    image_input = tf.keras.Input((input_resolution, input_resolution, 3))
    hub_layer = hub.KerasLayer(model_url)

    # The module returns a 2-tuple; the second element is a dictionary
    # containing attention scores, which this wrapper discards.
    first_output, _ = hub_layer(image_input)

    return tf.keras.Model(image_input, first_output)

In [None]:
# Build the Keras wrapper around the selected TF-Hub handle and run the demo image.
classification_model = get_model(model_handle)
predictions = classification_model.predict(preprocessed_image)

# The position of the largest score is the class id; look up its label.
top_class_index = int(np.argmax(predictions))
predicted_label = imagenet_int_to_str[top_class_index]
print(predicted_label)

## Obtaining attention scores

The models are capable of outputting attention scores (softmax scores) for each of the transformer blocks.

In [None]:
# Load the model straight from the `gs://tfhub-modules` bucket. Called directly,
# it returns a pair: the logits and a dictionary of attention scores.
saved_model_path = f"gs://tfhub-modules/sayakpaul/{model_name}/1/uncompressed"
loaded_model = tf.keras.models.load_model(saved_model_path)

logits, attention_score_dict = loaded_model(preprocessed_image)
attention_score_dict.keys()

In [None]:
# Inspect the attention scores of a single transformer block, looked up by key.
# Expected layout: (batch_size, nb_attention_heads, seq_length, seq_length)
attention_score_dict["transformer_block_5_att"].shape

Shoutout to [Aritra](https://github.com/ariG23498) for working on this integration.

## Visualizing the attention map

Code copied and modified from [here](https://colab.research.google.com/github/jeonsworld/ViT-pytorch/blob/main/visualize_attention_map.ipynb).

In [None]:
# Attention rollout: fold the per-block attention matrices into one map that
# shows how strongly the output token (row 0) attends to each image patch.


def _block_index(key):
    """Return the block number embedded in a key such as "transformer_block_5_att"."""
    digits = "".join(ch for ch in key if ch.isdigit())
    return int(digits) if digits else 0


# Stack the individual attention matrices from individual transformer blocks.
# Rollout multiplies the blocks in depth order, so order the keys by their block
# number explicitly rather than relying on the dictionary's iteration order
# (a plain string sort would put block 10 before block 2).
block_keys = sorted(attention_score_dict.keys(), key=_block_index)
attn_mat = tf.stack([attention_score_dict[k] for k in block_keys])
# Drop the batch axis; this assumes a single image was passed to the model.
attn_mat = tf.squeeze(attn_mat, axis=1)
print(attn_mat.shape)

# Average the attention weights across all heads.
attn_mat = tf.reduce_mean(attn_mat, axis=1)
print(attn_mat.shape)

# To account for residual connections, we add an identity matrix to the
# attention matrix and re-normalize the weights.
residual_attn = tf.eye(attn_mat.shape[1])
aug_attn_mat = attn_mat + residual_attn
aug_attn_mat = aug_attn_mat / tf.reduce_sum(aug_attn_mat, axis=-1)[..., None]
aug_attn_mat = aug_attn_mat.numpy()
print(aug_attn_mat.shape)

# Recursively multiply the weight matrices
joint_attentions = np.zeros(aug_attn_mat.shape)
joint_attentions[0] = aug_attn_mat[0]

for n in range(1, aug_attn_mat.shape[0]):
    joint_attentions[n] = np.matmul(aug_attn_mat[n], joint_attentions[n - 1])

# Attention from the output token to the input space.
v = joint_attentions[-1]
seq_length = aug_attn_mat.shape[-1]
grid_size = int(np.sqrt(seq_length))
# The sequence begins with the non-patch tokens: a class token, plus a
# distillation token for the "distilled" DeiT variants. Skip however many there
# are. A fixed offset of 1 leaves grid_size**2 + 1 values for distilled models
# and makes the reshape below fail.
num_extra_tokens = seq_length - grid_size ** 2
mask = v[0, num_extra_tokens:].reshape(grid_size, grid_size)
# cv2.resize expects (width, height), which is exactly what PIL's `.size` gives.
mask = cv2.resize(mask / mask.max(), image.size)[..., np.newaxis]
result = (mask * np.asarray(image)).astype("uint8")
print(result.shape)

In [None]:
# Show the input picture next to the attention-weighted version of it.
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 8))
fig.suptitle(f"Predicted label: {predicted_label}.", fontsize=20)

panels = (
    (ax1, image, "Input Image"),
    (ax2, result, "Attention Map"),
)
for panel_ax, panel_image, panel_title in panels:
    panel_ax.imshow(panel_image)
    panel_ax.set_title(panel_title, fontsize=16)
    panel_ax.axis("off")

fig.tight_layout()
fig.subplots_adjust(top=1.35)
fig.show()

In [None]:
# Write the figure built in the previous cell to a PNG file at 300 dpi;
# bbox_inches="tight" fits the saved area to the drawn content.
fig.savefig("attention_map.png", dpi=300, bbox_inches="tight")