This notebook is identical to `dino-attention-maps-video.ipynb` but it generates attention heatmaps using a supervised pre-trained ViT-B16 model.

## Setup

In [None]:
# Install the extra packages this notebook imports later: `ml_collections`
# (run configuration) and `transformers` (for `shape_list`).
!pip install ml_collections transformers -q

## Imports

In [None]:
import tensorflow as tf
import pandas as pd
import numpy as np

from pprint import pformat

## Load the master dataframe (checkpoint index) from the [AugReg paper](https://arxiv.org/abs/2106.10270)

In [None]:
# Read the AugReg checkpoint index straight from the public GCS bucket.
index_csv_path = "gs://vit_models/augreg/index.csv"
with tf.io.gfile.GFile(index_csv_path) as index_file:
    df = pd.read_csv(index_file)

df.head()

## Pick a checkpoint

**Criteria**

* B16 architecture
* Resolution 224
* Patch size 16
* Supervised pre-training on ImageNet-1k
* Best top-1 accuracy on ImageNet-1k

In [None]:
# Keep B/16 checkpoints pre-trained on ImageNet-1k and fine-tuned on
# ImageNet-1k at resolution 224, best fine-tuned test accuracy first.
matches_criteria = (
    (df["ds"] == "i1k")
    & (df["adapt_resolution"] == 224)
    & (df["adapt_ds"] == "imagenet2012")
    & (df["name"] == "B/16")
)
b16s = df[matches_criteria].sort_values("adapt_final_test", ascending=False)
b16s.head()

In [None]:
# The frame is sorted by accuracy, so the first row is the best checkpoint.
top_row = b16s.iloc[0]
best_b16_i1k_checkpoint = str(top_row["adapt_filename"])
top_row["adapt_filename"], top_row["adapt_final_test"]

In [None]:
filename = best_b16_i1k_checkpoint

# Full GCS location of the selected checkpoint archive.
path = f"gs://vit_models/augreg/{filename}.npz"

# Report the checkpoint size (bytes -> MiB) before downloading it.
size_mib = tf.io.gfile.stat(path).length / 1024 / 1024
print(f"{size_mib:.1f} MiB - {path}")

## Copy over the checkpoint and load it

In [None]:
!gsutil cp {path} .
local_path = path.split("//")[-1].split("/")[-1]
local_path

In [None]:
# Materialize every array of the `.npz` checkpoint into a plain dict keyed by
# parameter name. `np.load` returns a lazy `NpzFile`; using it as a context
# manager closes the archive once all arrays have been read.
with np.load(local_path) as npz_file:
    params_jax = {name: npz_file[name] for name in npz_file.files}

## Implement the model in TF

In [None]:
# Clone the `new-block` branch of the `deit-tf` repository; the next cell adds
# the checkout to `sys.path` and imports the ViT implementation from it.
!git clone -q https://github.com/sayakpaul/deit-tf -b new-block

In [None]:
import sys

sys.path.append("deit-tf")

from vit.vit_models import ViTClassifier
from vit.model_configs import base_config
from utils import helpers
from vit.layers import mha

from transformers.tf_utils import shape_list
from tensorflow import keras
import tensorflow as tf

In [None]:
class ViTB16Extended(ViTClassifier):
    """`ViTClassifier` variant that also returns per-block attention scores.

    The positional embedding is resized on the fly (bicubic) whenever the input
    produces a different number of patches than the embedding holds, so the
    model can be run on resolutions other than the one it was built for.

    Relies on attributes expected to be provided by `ViTClassifier`:
    `config`, `projection`, `cls_token`, `positional_embedding`, `dropout`,
    `transformer_blocks`, `layer_norm` and `head`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def interpolate_pos(self, x, N, h, w):
        """Resize the patch position embeddings to the grid of an `h` x `w` input.

        Args:
            x: Token sequence; only the size of its last axis (embedding
                dimension) is read.
            N: Number of patch position embeddings currently stored (cls
                position excluded).
            h: Input height in pixels.
            w: Input width in pixels.

        Returns:
            The cls position embedding followed by the resized patch position
            embeddings, concatenated along axis 1.
        """
        class_pos_embed = self.positional_embedding[:, 0]
        patch_pos_embed = self.positional_embedding[:, 1:]
        dim = shape_list(x)[-1]

        # Calculate the resolution to which we need to perform interpolation.
        h0 = h // self.config.patch_size
        w0 = w // self.config.patch_size

        # We avoided the following because UpSampling2D won't support float sizes:
        # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py#L186
        # NOTE: the stored embeddings are laid out as a square grid, so the
        # reshape below only succeeds when N is a perfect square.
        sqrt_N = tf.math.sqrt(tf.cast(N, "float32"))
        sqrt_N_ceil = tf.cast(tf.math.ceil(sqrt_N), "int32")
        patch_pos_embed = tf.reshape(
            patch_pos_embed, (1, sqrt_N_ceil, sqrt_N_ceil, dim)
        )
        patch_pos_embed = tf.image.resize(patch_pos_embed, (h0, w0), method="bicubic")

        tf.debugging.assert_equal(h0, shape_list(patch_pos_embed)[1])
        tf.debugging.assert_equal(w0, shape_list(patch_pos_embed)[2])

        # Flatten the grid back into a sequence and put the cls position first.
        patch_pos_embed = tf.reshape(patch_pos_embed, (1, -1, dim))
        return tf.concat([class_pos_embed[None, ...], patch_pos_embed], axis=1)

    def interpolate_pos_embedding(self, x, h, w):
        """Resizes the positional embedding in case there is a mismatch in resolution.
        E.g., using 480x480 images instead of 224x224 for a given patch size.

        The stored embedding is returned untouched only when the number of
        patches matches and the input is square; otherwise `interpolate_pos`
        is used.

        Reference:
            https://github.com/facebookresearch/dino/blob/main/vision_transformer.py#L174
        """
        num_patches = shape_list(x)[1] - 1  # Excluding the cls token.
        N = shape_list(self.positional_embedding)[1] - 1

        return tf.cond(
            tf.logical_and(tf.equal(num_patches, N), tf.equal(h, w)),
            lambda: self.positional_embedding,
            lambda: self.interpolate_pos(x, N, h, w),
        )

    def call(self, inputs, training=False):
        """Run the classifier and collect attention scores from every block.

        Args:
            inputs: Batch of images; unpacked as (batch, height, width,
                channels).
            training: When truthy only the head output is returned. Defaults to
                `False` so the model can be invoked without an explicit flag.

        Returns:
            `output` when `training` is truthy, otherwise the tuple
            `(output, attention_scores)` where `attention_scores` maps
            `"<block name>_att"` to the scores returned by that block.
        """
        n, h, w, c = shape_list(inputs)

        # Create patches and project the patches.
        projected_patches = self.projection(inputs)

        # Append class token.
        cls_token = tf.tile(self.cls_token, (n, 1, 1))
        cls_token = tf.cast(cls_token, projected_patches.dtype)
        projected_patches = tf.concat([cls_token, projected_patches], axis=1)

        # Fetch positional embeddings.
        positional_embedding = self.interpolate_pos_embedding(projected_patches, h, w)

        # Add positional embeddings to the projected patches.
        encoded_patches = (
            positional_embedding + projected_patches
        )  # (B, number_patches, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Attention scores from each transformer block, keyed by block name.
        attention_scores = dict()

        # Run the sequence through the stack of Transformer blocks.
        for transformer_module in self.transformer_blocks:
            encoded_patches, attention_score = transformer_module(encoded_patches)
            attention_scores[f"{transformer_module.name}_att"] = attention_score

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Pool representation: keep only the cls token.
        encoded_patches = representation[:, 0]

        # Classification head.
        output = self.head(encoded_patches)

        if training:
            return output
        else:
            return output, attention_scores

## Initialize the TF model

In [None]:
config = base_config.get_config(model_name="vit_base", projection_dim=768, num_heads=12)

vit_b16_in1k = ViTB16Extended(config)

# Run one forward pass on random data so the model's weights get created
# before the pre-trained parameters are copied in. `training` is passed
# explicitly because `call` declares it without a default, and the attention
# scores are only returned when it is falsy.
dummy_inputs = tf.random.normal((2, 224, 224, 3))
outputs, attn_scores = vit_b16_in1k(dummy_inputs, training=False)

# Shape of the attention scores from the last transformer block.
keys = list(attn_scores.keys())
print(attn_scores[keys[-1]].shape)

## Populate the pre-trained params into the TF model

In [None]:
# Projection.
# `assign` accepts the NumPy arrays directly; wrapping them in a throwaway
# `tf.Variable` only allocates an extra copy.
patch_projection = vit_b16_in1k.layers[0].layers[0]
patch_projection.kernel.assign(params_jax["embedding/kernel"])
patch_projection.bias.assign(params_jax["embedding/bias"])
print(" ")  # Ends the cell on a print, so the result of `assign` is not echoed.

In [None]:
# Positional embedding.
# The NumPy array is assigned directly; no intermediate `tf.Variable` is needed.
vit_b16_in1k.positional_embedding.assign(
    params_jax["Transformer/posembed_input/pos_embedding"]
)
print(" ")  # Ends the cell on a print, so the result of `assign` is not echoed.

In [None]:
# Cls token.
# The NumPy array is assigned directly; no intermediate `tf.Variable` is needed.
vit_b16_in1k.cls_token.assign(params_jax["cls"])
print(" ")  # Ends the cell on a print, so the result of `assign` is not echoed.


In [None]:
# Final layer norm (second-to-last layer of the model), fed from the
# checkpoint's encoder norm parameters.
# The NumPy arrays are assigned directly; no intermediate `tf.Variable` is needed.
final_layer_norm = vit_b16_in1k.layers[-2]
final_layer_norm.gamma.assign(params_jax["Transformer/encoder_norm/scale"])
final_layer_norm.beta.assign(params_jax["Transformer/encoder_norm/bias"])

print(" ")  # Ends the cell on a print, so the result of `assign` is not echoed.

In [None]:
# Classification head (last layer of the model).
# The NumPy arrays are assigned directly; no intermediate `tf.Variable` is needed.
classification_head = vit_b16_in1k.layers[-1]
classification_head.kernel.assign(params_jax["head/kernel"])
classification_head.bias.assign(params_jax["head/bias"])
print(" ")  # Ends the cell on a print, so the result of `assign` is not echoed.

In [None]:
def modify_attention_block(tf_component, jax_component, params_jax, config):
    """Copy one attention projection (kernel and bias) from the JAX params.

    Args:
        tf_component: Object exposing `kernel` and `bias`, each with an
            `assign` method (e.g. a dense layer).
        jax_component: Key prefix in `params_jax`; `"/kernel"` and `"/bias"`
            are appended to it.
        params_jax: Mapping from parameter name to NumPy array.
        config: Object exposing `projection_dim`.

    Returns:
        `tf_component`, updated in place.
    """
    # The checkpoint kernel is flattened to (projection_dim, -1) and the bias
    # to 1-D so they fit a plain dense layer (presumably the checkpoint stores
    # them with separate head axes -- confirm against the checkpoint).
    kernel = params_jax[f"{jax_component}/kernel"].reshape(config.projection_dim, -1)
    bias = params_jax[f"{jax_component}/bias"].reshape(-1)

    # `assign` takes the arrays directly; no throwaway `tf.Variable` is needed.
    tf_component.kernel.assign(kernel)
    tf_component.bias.assign(bias)
    return tf_component

In [None]:
# Copy the encoder-block weights. Every nested Keras model other than the
# patch projection is treated as a transformer block, and the i-th such block
# receives the parameters of `encoderblock_{i}` from the checkpoint.
idx = 0
for outer_layer in vit_b16_in1k.layers:
    if not isinstance(outer_layer, tf.keras.Model) or outer_layer.name == "projection":
        continue

    block_prefix = f"Transformer/encoderblock_{idx}"
    attn_prefix = f"{block_prefix}/MultiHeadDotProductAttention_1"

    layer_norm_idx = 0
    ffn_layer_idx = 0
    for layer in outer_layer.layers:
        if isinstance(layer, tf.keras.layers.LayerNormalization):
            # The checkpoint names the two norms LayerNorm_0 and LayerNorm_2,
            # hence the step of 2.
            layer_norm_prefix = f"{block_prefix}/LayerNorm_{layer_norm_idx}"
            layer.gamma.assign(params_jax[f"{layer_norm_prefix}/scale"])
            layer.beta.assign(params_jax[f"{layer_norm_prefix}/bias"])
            layer_norm_idx += 2

        elif isinstance(layer, tf.keras.layers.Dense):
            # FFN layers map to MlpBlock_3/Dense_0, Dense_1, ... in order.
            dense_prefix = f"{block_prefix}/MlpBlock_3/Dense_{ffn_layer_idx}"
            layer.kernel.assign(params_jax[f"{dense_prefix}/kernel"])
            layer.bias.assign(params_jax[f"{dense_prefix}/bias"])
            ffn_layer_idx += 1

        elif isinstance(layer, mha.TFViTAttention):
            # Key / query / value projections are updated in place.
            for projection_name in ("key", "query", "value"):
                modify_attention_block(
                    getattr(layer.self_attention, projection_name),
                    f"{attn_prefix}/{projection_name}",
                    params_jax,
                    config,
                )

            # Final dense projection: flatten the kernel to
            # (-1, projection_dim) to match the dense layer.
            layer.dense_output.dense.kernel.assign(
                params_jax[f"{attn_prefix}/out/kernel"].reshape(
                    -1, config.projection_dim
                )
            )
            layer.dense_output.dense.bias.assign(
                params_jax[f"{attn_prefix}/out/bias"]
            )

    idx += 1

## Test the populated model on a sample image

In [None]:
import requests
from PIL import Image
from io import BytesIO

In [None]:
def preprocess_image(image):
    """Resize an image to 224x224, scale it to [-1, 1] and add a batch axis.

    Returns a NumPy array with a leading axis of size 1.
    """
    pixels = tf.image.resize(np.array(image), (224, 224))
    pixels = tf.cast(pixels, tf.float32)
    # Map pixel values from [0, 255] to [-1, 1].
    normalized = (pixels - 127.5) / 127.5
    batched = tf.expand_dims(normalized, 0)
    return batched.numpy()

def load_image_from_url(url, timeout=30):
    """Download an image and return it preprocessed for the model.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up, so a stalled
            connection cannot hang the notebook.

    Returns:
        The result of `preprocess_image` on the downloaded image.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    response = requests.get(url, timeout=timeout)
    # Fail here on 4xx/5xx rather than handing an error page to PIL.
    response.raise_for_status()
    # Force three channels; RGBA, palette or grayscale files would otherwise
    # yield an array with the wrong number of channels.
    image = Image.open(BytesIO(response.content)).convert("RGB")
    return preprocess_image(image)

# Download the ImageNet-1k label file; it is read line by line below and
# indexed by the predicted class id.
!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -O ilsvrc2012_wordnet_lemmas.txt

In [None]:
# One label per line; the line number is the class index.
with open("ilsvrc2012_wordnet_lemmas.txt", "r") as label_file:
    imagenet_int_to_str = [label.rstrip() for label in label_file]

# Sample image used to sanity-check the ported weights.
img_url = "https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg"
image = load_image_from_url(img_url)

In [None]:
# In inference mode the model returns (logits, attention_scores).
predictions = vit_b16_in1k(image, training=False)
logits = predictions[0].numpy()

# The top-scoring class must be the label expected for the sample image.
top_class_index = int(logits.argmax())
predicted_label = imagenet_int_to_str[top_class_index]
expected_label = "Indian_elephant, Elephas_maximus"
assert (
    predicted_label == expected_label
), f"Expected {expected_label} but was {predicted_label}"

## Video generation utilities

Code for the `VideoGeneratorTF` class has been copied and modified from [here](https://github.com/facebookresearch/dino/blob/main/video_generation.py).

In [None]:
import os
import cv2
import glob
import matplotlib.pyplot as plt

from tqdm import tqdm

In [None]:
# Maps the output video extension to the FourCC codec tag handed to
# `cv2.VideoWriter`. The mp4 tag is lowercase: FFmpeg's mp4 container does not
# accept "MP4V", and OpenCV then warns and falls back to "mp4v" anyway.
FOURCC = {
    "mp4": cv2.VideoWriter_fourcc(*"mp4v"),
    "avi": cv2.VideoWriter_fourcc(*"XVID"),
}

In [None]:
class VideoGeneratorTF:
    """Render a video of attention heatmaps from a video or a folder of frames.

    `args` is a config object from which these attributes are read:
    `input_path`, `output_path`, `video_only`, `video_format`, `fps`,
    `resize`, `patch_size` and `model`. `fps` is also written back when frames
    are extracted from a video.
    """

    def __init__(self, args):
        self.args = args

    def run(self):
        """Run the pipeline selected by `args`.

        * `video_only`: assemble a video from the `attn-*.jpg` files already
          in `input_path`.
        * `input_path` is a file: extract frames, compute heatmaps, assemble.
        * `input_path` is a folder: compute heatmaps for its `*.jpg` files,
          then assemble.

        Exits the interpreter with status 1 when `input_path` is `None` or does
        not exist.
        """
        if self.args.input_path is None:
            print(f"Provided input path {self.args.input_path} is non valid.")
            sys.exit(1)

        if self.args.video_only:
            self._generate_video_from_images(
                self.args.input_path, self.args.output_path
            )
            return

        if not os.path.exists(self.args.input_path):
            print(f"Provided input path {self.args.input_path} doesn't exists.")
            sys.exit(1)

        is_video_file = os.path.isfile(self.args.input_path)
        if is_video_file:
            frames_folder = os.path.join(self.args.output_path, "frames-tf")
            os.makedirs(frames_folder, exist_ok=True)
        elif os.path.isdir(self.args.input_path):
            # The input already is a folder of extracted frames.
            frames_folder = self.args.input_path
        else:
            # Exists but is neither a regular file nor a folder: nothing to do.
            return

        attention_folder = os.path.join(self.args.output_path, "attention-tf")
        os.makedirs(attention_folder, exist_ok=True)

        if is_video_file:
            self._extract_frames_from_video(self.args.input_path, frames_folder)

        self._inference(frames_folder, attention_folder)
        self._generate_video_from_images(attention_folder, self.args.output_path)

    def _extract_frames_from_video(self, inp: str, out: str):
        """Write every frame of video `inp` to `out` as `frame-NNNN.jpg`.

        Also stores the source frame rate in `args.fps` so the generated video
        plays at the same speed.
        """
        vidcap = cv2.VideoCapture(inp)
        try:
            # A non-positive value means the frame rate is unavailable; keep
            # the configured one in that case instead of storing 0.
            fps = vidcap.get(cv2.CAP_PROP_FPS)
            if fps > 0:
                self.args.fps = fps

            print(f"Video: {inp} ({self.args.fps} fps)")
            print(f"Extracting frames to {out}")

            count = 0
            success, image = vidcap.read()
            while success:
                cv2.imwrite(os.path.join(out, f"frame-{count:04}.jpg"), image)
                success, image = vidcap.read()
                count += 1
        finally:
            # Release the capture handle even if reading or writing fails.
            vidcap.release()

    @staticmethod
    def _read_bgr_frame(path: str):
        """Load the image at `path` as a BGR array, as `cv2.VideoWriter` expects."""
        with open(path, "rb") as f:
            img = Image.open(f).convert("RGB")
        return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

    def _generate_video_from_images(self, inp: str, out: str):
        """Assemble the `attn-*.jpg` files of `inp` (sorted by name) into a video.

        The video is written to `out` as `video-tf.<args.video_format>` at
        `args.fps`; the frame size is taken from the first image.

        Raises:
            FileNotFoundError: If `inp` contains no `attn-*.jpg` file.
        """
        attention_images_list = sorted(glob.glob(os.path.join(inp, "attn-*.jpg")))
        if not attention_images_list:
            raise FileNotFoundError(f"No attention images (attn-*.jpg) found in {inp}")

        # The first image determines the video size as (width, height).
        first_frame = self._read_bgr_frame(attention_images_list[0])
        size = (first_frame.shape[1], first_frame.shape[0])

        print(f"Generating video {size} to {out}")

        writer = cv2.VideoWriter(
            os.path.join(out, "video-tf." + self.args.video_format),
            FOURCC[self.args.video_format],
            self.args.fps,
            size,
        )
        try:
            # Stream frames to the writer instead of holding all of them in memory.
            writer.write(first_frame)
            for filename in tqdm(attention_images_list[1:]):
                writer.write(self._read_bgr_frame(filename))
        finally:
            writer.release()
        print("Done")

    def _preprocess_image(self, image: "Image.Image", size: int):
        """Scale `image` so its shorter side equals `size` and normalize to [-1, 1].

        Returns a NumPy array with a leading batch axis of size 1.
        """
        # Reference: https://www.tensorflow.org/lite/examples/style_transfer/overview
        image_resized = tf.expand_dims(np.array(image), 0)
        shape = tf.cast(tf.shape(image_resized)[1:-1], tf.float32)
        scale = size / tf.reduce_min(shape)
        new_shape = tf.cast(shape * scale, tf.int32)
        image_resized = tf.image.resize(image_resized, new_shape)
        image_resized = (image_resized - 127.5) / 127.5
        return image_resized.numpy()

    def _inference(self, inp: str, out: str):
        """Save a cls-token attention heatmap for every `*.jpg` file in `inp`.

        For each frame the attention of the last transformer block is averaged
        over heads, resized to the (cropped) frame size and written to `out`
        as `attn-<frame name>`.
        """
        print(f"Generating attention images to {out}")
        patch_size = self.args.patch_size

        for img_path in tqdm(sorted(glob.glob(os.path.join(inp, "*.jpg")))):
            with open(img_path, "rb") as f:
                img = Image.open(f)
                img = img.convert("RGB")

            preprocessed_image = self._preprocess_image(img, self.args.resize)

            # Crop so that both sides are exact multiples of the patch size.
            h = preprocessed_image.shape[1] - preprocessed_image.shape[1] % patch_size
            w = preprocessed_image.shape[2] - preprocessed_image.shape[2] % patch_size
            preprocessed_image = preprocessed_image[:, :h, :w, :]

            h_featmap = h // patch_size
            w_featmap = w // patch_size

            # Grab the attention scores from the final transformer block: the
            # model fills the dict in block order, so the last entry is the
            # final block whatever the block is named.
            logits, attention_score_dict = self.args.model(
                preprocessed_image, training=False
            )
            attentions = list(attention_score_dict.values())[-1].numpy()

            nh = attentions.shape[1]  # number of heads

            # Keep the attention of the cls token (query 0) over the patch
            # tokens, laid out as (height, width, heads).
            attentions = (
                attentions[0, :, 0, 1:]
                .reshape(nh, h_featmap, w_featmap)
                .transpose((1, 2, 0))
            )

            # Interpolate the patch-level map back to pixel resolution.
            attentions = tf.image.resize(
                attentions,
                size=(h_featmap * patch_size, w_featmap * patch_size),
            )

            # Save the head-averaged attention as a heatmap.
            fname = os.path.join(out, "attn-" + os.path.basename(img_path))
            plt.imsave(
                fname=fname,
                arr=tf.reduce_mean(attentions, axis=-1).numpy(),
                cmap="inferno",
                format="jpg",
            )

## Gather demo videos to run inference on

In [None]:
# Download two demo files from Google Drive by file id.
# NOTE(review): presumably these are the demo videos and one of them is the
# `dog.mp4` used as `args.input_path` below -- confirm which id it is.
!gdown --id 12KScLSdZS5gNvLqoZBenbYeTPaVx4wMj
!gdown --id 1dnPP0QvJ2944GaSE47yMgrt3T0yO4R_R

In [None]:
import ml_collections

# Run configuration consumed by `VideoGeneratorTF`.
args = ml_collections.ConfigDict()

args.model = vit_b16_in1k  # Model whose last-block attention is visualized.
args.patch_size = 16  # Frames are cropped to multiples of this before inference.
args.input_path = "dog.mp4"  # A video file, or a folder of already extracted frames.
args.output_path = "./"  # Frames, heatmaps and the final video are written below this.
args.resize = 512  # The shorter side of each frame is scaled to this many pixels.
args.video_only = False  # True: only assemble a video from existing attn-*.jpg files.
args.fps = 30.0  # Output frame rate; replaced by the source rate when a video is read.
args.video_format = "mp4"  # Must be a key of FOURCC.

## Extract frames, run inference, prepare a video assembling the extracted attention heatmaps

In [None]:
# Extract frames, compute an attention heatmap per frame and assemble the
# heatmaps into `video-tf.<video_format>` under `args.output_path`.
vg = VideoGeneratorTF(args)
vg.run()