<a href="https://colab.research.google.com/github/sayakpaul/cait-tf/blob/main/notebooks/classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Off-the-shelf image classification with CaiT models on TF-Hub

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/sayakpaul/cait-tf/blob/main/notebooks/classification.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/sayakpaul/cait-tf/blob/main/notebooks/classification.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View on GitHub</a>
  </td>
  <td>
    <a href="https://tfhub.dev/sayakpaul/collections/cait/1"><img src="https://www.tensorflow.org/images/hub_logo_32px.png" />Models on Hub</a>
  </td>
</table>

## Setup

In [None]:
# Fetch the ImageNet-1k class-name file (one name per line); later cells map a
# predicted class index to its name by line position.
!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -O ilsvrc2012_wordnet_lemmas.txt

In [None]:
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow import keras


from PIL import Image
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
import requests
import cv2

## Select a [CaiT](https://arxiv.org/abs/2103.17239) ImageNet-1k model

Find the entire collection [here](https://tfhub.dev/sayakpaul/collections/cait/1).

In [None]:
model_name = "cait_xxs24_224"  # @param ['cait_xxs24_224', 'cait_xxs24_384', 'cait_xxs36_224', 'cait_xxs36_384', 'cait_xs24_384', 'cait_s24_224', 'cait_s24_384', 'cait_s36_384', 'cait_m36_384', 'cait_m48_448']

# Every CaiT checkpoint lives under the same TF-Hub namespace and version, so
# each handle is derived from its variant name (insertion order follows the
# dropdown options above).
model_handle_map = {
    variant: f"https://tfhub.dev/sayakpaul/{variant}/1"
    for variant in (
        "cait_xxs24_224",
        "cait_xxs24_384",
        "cait_xxs36_224",
        "cait_xxs36_384",
        "cait_xs24_384",
        "cait_s24_224",
        "cait_s24_384",
        "cait_s36_384",
        "cait_m36_384",
        "cait_m48_448",
    )
}

# The text after the last underscore is the square input size in pixels,
# e.g. "cait_xxs24_224" -> 224.
input_resolution = int(model_name.rpartition("_")[2])
model_handle = model_handle_map[model_name]
print(f"Input resolution: {input_resolution} x {input_resolution} x 3.")
print(f"TF-Hub handle: {model_handle}.")

## Image preprocessing utilities 

In [None]:
# Per-channel (R, G, B) mean/std on the [0, 1] scale (the usual ImageNet
# values). The pixels fed to norm_layer are on the [0, 255] scale, so both
# statistics are rescaled by 255 below.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

crop_layer = keras.layers.CenterCrop(input_resolution, input_resolution)
norm_layer = keras.layers.Normalization(
    mean=[channel_mean * 255 for channel_mean in IMAGENET_MEAN],
    variance=[(channel_std * 255) ** 2 for channel_std in IMAGENET_STD],
)


def preprocess_image(image, size=input_resolution):
    """Turn an image into a normalized single-image batch for a CaiT model.

    The image is resized to a square of ``int(256 / 224 * size)`` pixels per
    side (aspect ratio is not preserved), center-cropped to ``size x size``,
    and normalized with ``norm_layer``.

    Args:
        image: A PIL image, or an array-like assumed to be (H, W, 3) RGB.
        size: Side length of the output crop; defaults to the model's input
            resolution.

    Returns:
        A numpy array of shape (1, size, size, 3).
    """
    if isinstance(image, Image.Image):
        # Grayscale/RGBA/palette images would not yield an (H, W, 3) array.
        image = image.convert("RGB")
    batch = tf.expand_dims(np.asarray(image), 0)
    resize_size = int((256 / 224) * size)
    batch = tf.image.resize(batch, (resize_size, resize_size), method="bicubic")
    # Crop to `size` itself. The shared `crop_layer` is fixed to
    # `input_resolution` and would disagree with a non-default `size`.
    batch = keras.layers.CenterCrop(size, size)(batch)
    return norm_layer(batch).numpy()


def load_image_from_url(url, timeout=30):
    """Download an image and return it alongside its preprocessed batch.

    Args:
        url: HTTP(S) address of the image.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        ``(image, preprocessed_image)``: the RGB PIL image and the output of
        ``preprocess_image`` for it.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # Credit: Willi Gierke
    response = requests.get(url, timeout=timeout)
    # Fail loudly instead of handing an HTML error page to PIL.
    response.raise_for_status()
    # Normalize to 3-channel RGB so the displayed image and the model input
    # agree even for grayscale/RGBA/palette sources.
    image = Image.open(BytesIO(response.content)).convert("RGB")
    preprocessed_image = preprocess_image(image)
    return image, preprocessed_image

## Load ImageNet-1k labels and a demo image

In [None]:
# One class name per line; line i names the class at logit index i.
with open("ilsvrc2012_wordnet_lemmas.txt", "r") as label_file:
    imagenet_int_to_str = [label.rstrip() for label in label_file]

# Demo photo, originally from https://unsplash.com/photos/Ho93gVTRWW8
img_url = "https://i.imgur.com/ErgfLTn.jpg"
image, preprocessed_image = load_image_from_url(img_url)

plt.imshow(image)
plt.axis("off")
plt.show()

## Run inference

In [None]:
def get_model(model_url: str) -> tf.keras.Model:
    """Wrap a CaiT TF-Hub module in a Keras model that returns only logits.

    The hub module yields a 3-tuple: the logits followed by two dictionaries
    of attention scores (self-attention and class-attention). This wrapper
    keeps the logits and drops both dictionaries.

    Args:
        model_url: TF-Hub handle of the CaiT checkpoint.

    Returns:
        A ``tf.keras.Model`` mapping a batch of
        (input_resolution, input_resolution, 3) images to logits.
    """
    image_batch = tf.keras.Input((input_resolution, input_resolution, 3))
    logits, _sa_scores, _ca_scores = hub.KerasLayer(model_url)(image_batch)
    return tf.keras.Model(image_batch, logits)

In [None]:
classification_model = get_model(model_handle)
predictions = classification_model.predict(preprocessed_image)

# The input is a single-image batch, so the argmax over the whole prediction
# array is the top-1 class index.
top1_index = int(np.argmax(predictions))
predicted_label = imagenet_int_to_str[top1_index]
print(predicted_label)

## Obtaining attention scores

The models can also output attention scores (softmax scores) for each of the transformer blocks: one dictionary for the self-attention blocks and one for the class-attention blocks.

In [None]:
# Load the uncompressed SavedModel copy of the same checkpoint (presumably the
# storage behind the tfhub.dev handle — confirm). Its predict() returns all
# three outputs: logits, self-attention scores and class-attention scores.
updated_model_handle = f"gs://tfhub-modules/sayakpaul/{model_name}/1/uncompressed"
loaded_model = tf.keras.models.load_model(updated_model_handle)

model_outputs = loaded_model.predict(preprocessed_image)
logits, sa_atn_score_dict, ca_atn_score_dict = model_outputs
ca_atn_score_dict.keys()

In [None]:
# Layout of each class-attention score array, as get_cls_attention_map indexes it:
# (batch_size, nb_attention_heads, num_cls_token, seq_length)
# seq_length presumably counts the CLS token plus every patch token (the map
# code drops key index 0 before reshaping) — TODO confirm.
ca_atn_score_dict["ca_ffn_block_0_att"].shape

## Visualizing attention maps - figures 6 and 7 of the [paper](https://arxiv.org/abs/2103.17239)

### Class attention maps (spatial-class relationship)

In [None]:
# Reference:
# https://github.com/facebookresearch/dino/blob/main/visualize_attention.py

patch_size = 16


def get_cls_attention_map(
    attn_score_dict=None,
    block_key="ca_ffn_block_0_att",
    return_saliency=False,
):
    """Convert one block's CLS-token attention into image-sized maps.

    Args:
        attn_score_dict: Mapping from block name to an attention-score array
            indexed as (batch, head, query, key). Defaults to the global
            ``ca_atn_score_dict``, looked up at call time so that re-running
            inference is picked up.
        block_key: Entry of ``attn_score_dict`` to visualize.
        return_saliency: If True, average over heads and min-max scale to
            [0, 1], giving one channel. Otherwise keep one channel per head.

    Returns:
        A float tensor of shape (H, W, num_heads), or (H, W, 1) when
        ``return_saliency`` is True. H and W are the preprocessed image's
        height and width rounded down to a multiple of ``patch_size``.
    """
    if attn_score_dict is None:
        attn_score_dict = ca_atn_score_dict

    # preprocessed_image is NHWC: axis 1 is height, axis 2 is width.
    h_featmap = preprocessed_image.shape[1] // patch_size
    w_featmap = preprocessed_image.shape[2] // patch_size

    attention_scores = attn_score_dict[block_key]
    nh = attention_scores.shape[1]  # Number of attention heads.

    # First image, all heads, CLS query (index 0), patch keys only (the CLS
    # key at index 0 is dropped).
    attentions = attention_scores[0, :, 0, 1:]

    # Assumes patch tokens are flattened row-major (standard ViT patch
    # embedding), so the height axis comes before width — TODO confirm.
    attentions = attentions.reshape(nh, h_featmap, w_featmap)

    if return_saliency:
        attentions = attentions.mean(axis=0)
        value_range = attentions.max() - attentions.min()
        # A perfectly flat map would otherwise divide by zero.
        if value_range > 0:
            attentions = (attentions - attentions.min()) / value_range
        else:
            attentions = np.zeros_like(attentions)
        attentions = attentions[..., np.newaxis]
    else:
        # (heads, h, w) -> (h, w, heads) so each head becomes an image channel.
        attentions = attentions.transpose((1, 2, 0))

    # Upsample the patch grid back to pixel resolution
    # (e.g. 14 x 14 patches -> 224 x 224 for a 224 model).
    return tf.image.resize(
        attentions,
        size=(h_featmap * patch_size, w_featmap * patch_size),
        method="bicubic",
    )

In [None]:
attentions_ca_block_0 = get_cls_attention_map()

# The head count is read from the scores, so the grid is sized from it
# (rows of at most four) rather than hard-coded to 1 x 4.
num_heads = attentions_ca_block_0.shape[-1]
ncols = min(num_heads, 4)
nrows = -(-num_heads // ncols)  # ceiling division
fig, axes = plt.subplots(
    nrows=nrows, ncols=ncols, figsize=(3.25 * ncols, 3.25 * nrows), squeeze=False
)

for head, ax in enumerate(axes.flat):
    ax.axis("off")
    if head < num_heads:  # Leftover cells in the last row stay blank.
        ax.imshow(attentions_ca_block_0[:, :, head])
        ax.set_title(f"Attention head: {head}")

fig.tight_layout()
fig.savefig("class_attention_heads_0.png", dpi=300, bbox_inches="tight")

In [None]:
attentions_ca_block_1 = get_cls_attention_map(block_key="ca_ffn_block_1_att")

# The head count is read from the scores, so the grid is sized from it
# (rows of at most four) rather than hard-coded to 1 x 4.
num_heads = attentions_ca_block_1.shape[-1]
ncols = min(num_heads, 4)
nrows = -(-num_heads // ncols)  # ceiling division
fig, axes = plt.subplots(
    nrows=nrows, ncols=ncols, figsize=(3.25 * ncols, 3.25 * nrows), squeeze=False
)

for head, ax in enumerate(axes.flat):
    ax.axis("off")
    if head < num_heads:  # Leftover cells in the last row stay blank.
        ax.imshow(attentions_ca_block_1[:, :, head])
        ax.set_title(f"Attention head: {head}")

fig.tight_layout()
fig.savefig("class_attention_heads_1.png", dpi=300, bbox_inches="tight")

### Saliency maps

In [None]:
# Head-averaged, min-max scaled CLS attention from the first class-attention block.
saliency_attention = get_cls_attention_map(
    block_key="ca_ffn_block_0_att", return_saliency=True
)

In [None]:
# Rebuild the un-normalized view of the demo image with the same
# resize + center-crop geometry as preprocess_image, so the saliency map
# lines up pixel-for-pixel. A new name avoids rebinding the global `image`.
image_array = np.array(image)
image_resized = tf.expand_dims(image_array, 0)
resize_size = int((256 / 224) * input_resolution)
image_resized = tf.image.resize(
    image_resized, (resize_size, resize_size), method="bicubic"
)
image_resized = crop_layer(image_resized)

# Bicubic resampling can overshoot [0, 255]. Clip before casting so imshow
# receives valid 8-bit RGB instead of clipping it with a warning.
display_image = np.clip(image_resized.numpy().squeeze(), 0, 255).astype("uint8")

plt.imshow(display_image)
plt.imshow(saliency_attention.numpy().squeeze(), cmap="cividis", alpha=0.9)
plt.axis("off")
plt.savefig("saliency_attention_map.png", dpi=300, bbox_inches="tight")

plt.show()