In [None]:
!pip install -U gdown -q

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from dataclasses import dataclass
import numpy as np
import requests
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import gdown
import zipfile
import cv2
from sklearn.preprocessing import MinMaxScaler

## Setup configuration

In [None]:
# Side length (in pixels) of the square images fed to the models; used for
# the center crop and for the Keras input shape.
RESOLUTION = 224
# Side length (in pixels) of one patch; used to convert an image size into a
# patch-grid size when building attention heatmaps.
PATCH_SIZE = 16

## Data augmentation & preprocessing



In [None]:
# Center crop down to the square model input size.
crop_layer = layers.CenterCrop(height=RESOLUTION, width=RESOLUTION)

# Per-channel mean / variance expressed on the 0-255 pixel scale.
_channel_means = [m * 255 for m in (0.485, 0.456, 0.406)]
_channel_variances = [(s * 255) ** 2 for s in (0.229, 0.224, 0.225)]
norm_layer = layers.Normalization(mean=_channel_means, variance=_channel_variances)

# Maps a pixel value x to x / 127.5 - 1.
rescale_layer = layers.Rescaling(scale=1.0 / 127.5, offset=-1)

In [None]:
def preprocess_image(image, model_type, size=RESOLUTION):
  """Turn a single image into a model-ready batch of one.

  Args:
    image: Anything `np.array` can convert into an image array (e.g. a PIL
      image). Assumed to be channels-last with three channels -- the
      normalization layer is configured for three; TODO confirm for callers
      other than `load_image_from_url`.
    model_type: 'original_vit' selects rescaling with `rescale_layer`; any
      other value selects normalization with `norm_layer`.
    size: Side length of the square output. The image is resized to
      `int(256 / 224 * size)` on each side and then center-cropped to `size`.

  Returns:
    A NumPy array holding the preprocessed image with a leading batch axis.
  """
  image = np.array(image)
  image = tf.expand_dims(image, 0)

  # If the `model_type` is ViT, rescale the image to [-1, 1]
  if model_type == 'original_vit':
    image = rescale_layer(image)

  resize_size = int((256 / 224) * size)
  image = tf.image.resize(image, (resize_size, resize_size), method='bicubic')

  # Crop to the requested `size`. The shared `crop_layer` is fixed at
  # RESOLUTION, so a matching crop is built when a different size is asked
  # for; otherwise `size` would change the resize but not the output shape.
  if size == RESOLUTION:
    image = crop_layer(image)
  else:
    image = layers.CenterCrop(size, size)(image)

  # If the `model_type` is Deit or DINO, normalize the image
  if model_type != 'original_vit':
    image = norm_layer(image)

  return image.numpy()

## Load a test image and display it

In [None]:
def load_image_from_url(url, model_type):
  """Download an image and preprocess it for the given model family.

  Args:
    url: HTTP(S) address of the image.
    model_type: Forwarded unchanged to `preprocess_image`.

  Returns:
    A tuple `(image, preprocessed_image)`: the decoded PIL image in RGB mode
    and the result of `preprocess_image` for it.

  Raises:
    requests.HTTPError: If the server answers with an error status.
  """
  # A timeout keeps a stalled connection from hanging the notebook forever.
  response = requests.get(url, timeout=30)
  # Fail here with a clear HTTP error instead of letting PIL choke on an
  # error page.
  response.raise_for_status()
  # PNGs may carry an alpha channel or a palette; the preprocessing layers
  # are configured for three channels, so force RGB.
  image = Image.open(BytesIO(response.content)).convert("RGB")
  preprocessed_image = preprocess_image(image, model_type)
  return image, preprocessed_image

In [None]:
# Fetch the ILSVRC-2012 WordNet lemma list; the returned value is the local
# file path, which is opened in the next cell.
mapping_file = keras.utils.get_file(
  origin="https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
)

In [None]:
# Build the class-index -> label lookup: entry i is line i of the mapping
# file with trailing whitespace removed.
with open(mapping_file, 'r') as f:
  lines = f.readlines()
imagenet_int_to_str = list(map(str.rstrip, lines))

In [None]:
# Download the sample image and preprocess it the way the original ViT expects.
img_url = "https://dl.fbaipublicfiles.com/dino/img.png"
image, preprocessed_image = load_image_from_url(img_url, "original_vit")

# Show the image as downloaded, without axis ticks or frame.
fig, ax = plt.subplots()
ax.imshow(image)
ax.axis("off")
plt.show()

## Load a model

This model was pretrained on the ImageNet-21k dataset and then fine-tuned on the ImageNet-1k dataset.

In [None]:
def get_gdrive_model(model_id):
  """Download a zipped saved model from Google Drive and wrap it.

  The archive is fetched with `gdown`, extracted into the current working
  directory, and loaded from the path obtained by stripping the archive's
  file extension.

  Args:
    model_id: Google Drive file id of the zip archive.

  Returns:
    A `keras.Model` that takes `(RESOLUTION, RESOLUTION, 3)` inputs and
    returns `[outputs, attention_weights]` as produced by the loaded model.

  Raises:
    RuntimeError: If the download did not produce a file.
  """
  model_path = gdown.download(id=model_id, quiet=False)
  # Some gdown versions report failures (e.g. access denied) by printing a
  # message and returning None; stop here instead of failing inside zipfile.
  if model_path is None:
    raise RuntimeError(
        f"Could not download model with id {model_id!r} from Google Drive."
    )
  with zipfile.ZipFile(model_path, 'r') as zip_ref:
    zip_ref.extractall()
  # Strip only the final extension so names containing other dots survive.
  model_name = model_path.rsplit('.', 1)[0]
  inputs = keras.Input((RESOLUTION, RESOLUTION, 3))
  model = keras.models.load_model(model_name, compile=False)
  outputs, attention_weights = model(inputs)
  return keras.Model(inputs, outputs=[outputs, attention_weights])

In [None]:
# Download and wrap the ViT used below for classification and for the
# learned-filter visualization.
vit_base_i21k_patch16_224 = get_gdrive_model("1mbtnliT3jRb3yJUHhbItWw8unfYZw8KJ")
print("Model loaded.")

## Inference

In [None]:
# The wrapped model returns two outputs: the class predictions and the
# attention scores.
predictions, attention_score_dict = vit_base_i21k_patch16_224.predict(preprocessed_image)

In [None]:
# Map the index of the highest-scoring entry to its label text.
predicted_label = imagenet_int_to_str[int(np.argmax(predictions))]

In [None]:
# Show the predicted class name.
print(predicted_label)

## Attention heatmaps

In [None]:
# Second model (named as a DINO ViT), loaded the same way; its attention
# scores are visualized below.
vit_dino_base16 = get_gdrive_model("16_1oDm0PeCGJ_KGBG5UKVN7TsAtiRNrN")

In [None]:
# Reload the same image with model_type='dino' so that `preprocess_image`
# applies `norm_layer` instead of `rescale_layer`. This overwrites the
# earlier `image` / `preprocessed_image`.
img_url = "https://dl.fbaipublicfiles.com/dino/img.png"
image, preprocessed_image = load_image_from_url(img_url, model_type='dino')

In [None]:
# Overwrites `predictions` / `attention_score_dict` from the earlier ViT run;
# the attention scores from this call feed the heatmaps below.
predictions, attention_score_dict = vit_dino_base16.predict(preprocessed_image)

In [None]:
# Attention heads per Transformer block for the models above (presumably
# ViT-Base with 12 heads -- confirm); matches the 3 x 4 plot grid used later.
NUM_HEADS = 12

In [None]:
def attention_heatmap(attention_score_dict, image, model_type='dino'):
  """Build per-head [CLS]-to-patch attention maps at image resolution.

  Args:
    attention_score_dict: Mapping from block name to attention scores. Each
      key must carry the block index as its second-to-last
      underscore-separated field. Scores are indexed as
      (batch, head, query token, key token).
    image: Batched channels-last image, (batch, height, width, ...); only its
      spatial size is used.
    model_type: If it contains 'distilled', two leading special tokens are
      skipped instead of one.

  Returns:
    A tensor of shape (h, w, num_heads), where h and w are the image height
    and width rounded down to a multiple of PATCH_SIZE.
  """
  num_tokens = 2 if 'distilled' in model_type else 1

  # Pick the deepest Transformer block, i.e. the largest index in the key name.
  deepest_block = max(
      attention_score_dict, key=lambda name: int(name.split('_')[-2])
  )
  attention_scores = np.asarray(attention_score_dict[deepest_block])

  # Patch-grid size of the image.
  h_featmap = image.shape[1] // PATCH_SIZE
  w_featmap = image.shape[2] // PATCH_SIZE

  # Read the head count from the scores themselves so the function does not
  # depend on a separately maintained global constant.
  num_heads = attention_scores.shape[1]

  # Taking the representation from [CLS] token: query 0, all patch keys.
  attentions = attention_scores[0, :, 0, num_tokens:]

  # (heads, h * w) -> (heads, h, w). Height comes first: this assumes patch
  # tokens are flattened row by row (the usual layout for a convolutional
  # patch projection) and matches the (height, width) resize target below.
  # Using (w, h) here only works for square inputs.
  attentions = attentions.reshape(num_heads, h_featmap, w_featmap)
  # Move heads last so the array looks like an image with one channel per head.
  attentions = attentions.transpose((1, 2, 0))

  # Upsample each patch cell back to PATCH_SIZE x PATCH_SIZE pixels.
  attentions = tf.image.resize(
      attentions, size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE)
  )

  return attentions

In [None]:
# De-normalize the image for visual clarity.
in1k_mean = tf.constant([0.485 * 255, 0.456 * 255, 0.406 * 255])
in1k_std = tf.constant([0.229 * 255, 0.224 * 255, 0.225 * 255])
preprocessed_img_orig = (preprocessed_image * in1k_std) + in1k_mean
preprocessed_img_orig = preprocessed_img_orig / 255.0
preprocessed_img_orig = tf.clip_by_value(preprocessed_img_orig, 0.0, 1.0).numpy()

# Generate the attention heatmaps.
attentions = attention_heatmap(attention_score_dict, preprocessed_img_orig)

# Plot the maps.
fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(13, 13))
img_count = 0
# Heads live on the LAST axis of `attentions`. `len(attentions)` is the size
# of the first axis (image height in pixels), so it cannot be used to bound
# the head index.
num_heads = attentions.shape[-1]

for i in range(3):
    for j in range(4):
        # Hide ticks/frame on every cell, including any left without a head.
        axes[i, j].axis("off")
        if img_count < num_heads:
            # Draw the image first, then the semi-transparent attention map.
            axes[i, j].imshow(preprocessed_img_orig[0])
            axes[i, j].imshow(attentions[..., img_count], cmap="inferno", alpha=0.6)
            axes[i, j].title.set_text(f"Attention head: {img_count}")
            img_count += 1

## Visualize the learned filters

In [None]:
# Pull the kernel of the patch-projection convolution out of the ViT. Its
# shape is unpacked as (patch_h, patch_w, patch_channels, projection_dim).
patch_embedding = vit_base_i21k_patch16_224.layers[1].get_layer('projection')
projections = patch_embedding.get_layer('conv_projection').kernel.numpy()
patch_h, patch_w, patch_channels, projection_dim = projections.shape

# Min-max scale each filter (one column per filter) so it can be displayed
# as an image.
flat_projections = projections.reshape(-1, projection_dim)
scaled_projections = MinMaxScaler().fit_transform(flat_projections)

# Restore the three leading image-like dimensions; the filter index is last.
scaled_projections = scaled_projections.reshape(patch_h, patch_w, patch_channels, -1)

# Visualize the first 128 filters of the learned projections on an 8 x 16
# grid, filled row by row.
limit = 128
fig, axes = plt.subplots(nrows=8, ncols=16, figsize=(13, 8))

for img_count, ax in zip(range(limit), axes.ravel()):
  ax.imshow(scaled_projections[..., img_count])
  ax.axis('off')

fig.tight_layout()
