## Limitations and disclaimers

This notebook shows how to generate attention maps from a video using [DINO ViT-B16 checkpoints](https://github.com/facebookresearch/dino#pretrained-models). We prepared this demo from existing code snippets and credit the original sources in the respective sections.

To generate salient attention heatmaps from videos, it's important to resize individual frames without changing their aspect ratios. The positional embeddings within the ViT model then have to be interpolated to match the new resolution. Currently, we support this through a series of hacks that we aren't very proud of. There's likely a better way to accomplish this in TensorFlow, but we're yet to figure it out.

Some gotchas you should know about these hacks before proceeding:

* With interpolation of positional embeddings, it's currently not possible to save the model. 
* This is why we first assemble the original DINO checkpoints for ViT-B16, then implement the DINO variant in TensorFlow and manually port the pre-trained parameters into that implementation.
* We then run inference and show how to generate attention heatmaps from a given video. 
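
To make the need for interpolation concrete, here is a minimal, dependency-free sketch of how the patch grid changes once a frame is resized with its aspect ratio preserved. The function name `resized_grid` and the 720x1280 frame size are illustrative assumptions, and the sketch uses integer floor division where the notebook truncates a float32 product, so edge cases may differ by a pixel.

```python
def resized_grid(height, width, short_side=512, patch_size=16):
    """Return the resized frame size and the ViT patch grid for one frame.

    Hypothetical helper for illustration; not part of the notebook's code.
    """
    short = min(height, width)
    # Scale so the shorter side equals `short_side`, keeping the aspect ratio.
    new_h = height * short_side // short
    new_w = width * short_side // short
    # Drop trailing pixels that do not fill a whole patch.
    new_h -= new_h % patch_size
    new_w -= new_w % patch_size
    return (new_h, new_w), (new_h // patch_size, new_w // patch_size)


# The checkpoint was trained at 224x224, i.e. a 14x14 grid of patch positions.
pretrain_grid = (224 // 16, 224 // 16)

size, grid = resized_grid(720, 1280)
print(pretrain_grid, size, grid, grid[0] * grid[1])
# (14, 14) (512, 896) (32, 56) 1792
```

A 14x14 table of positional embeddings cannot be added to a 32x56 grid of patch tokens, which is why the embeddings are resized at call time.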

## Setup

In [None]:
!pip install -q ml_collections
!pip install -q transformers

In [None]:
# Backbone
!wget https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth -q

# Linear layer
!wget https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_linearweights.pth -q

## Assemble DINO weights in PyTorch

In [None]:
import torch

In [None]:
backbone_state_dict = torch.load(
    "dino_vitbase16_pretrain.pth", map_location=torch.device("cpu")
)
linear_layer_state_dict = torch.load(
    "dino_vitbase16_linearweights.pth", map_location=torch.device("cpu")
)["state_dict"]

In [None]:
backbone_state_dict.update(linear_layer_state_dict)
backbone_state_dict["head.weight"] = backbone_state_dict.pop("module.linear.weight")
backbone_state_dict["head.bias"] = backbone_state_dict.pop("module.linear.bias")
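
The merge above is plain dictionary manipulation. The following sketch replays it on toy dictionaries, with short strings standing in for the real tensors (the keys other than `module.linear.*` and `head.*` are just a small sample, chosen for illustration).

```python
# Toy stand-ins for the two checkpoints; strings replace the real tensors.
backbone = {"cls_token": "B0", "pos_embed": "B1", "norm.weight": "B2"}
linear = {"module.linear.weight": "W", "module.linear.bias": "b"}

merged = dict(backbone)  # copy so the toy backbone stays untouched
merged.update(linear)    # bring the linear-probe parameters in

# Rename the linear-probe keys to the `head.*` names used for the classifier.
for old, new in [
    ("module.linear.weight", "head.weight"),
    ("module.linear.bias", "head.bias"),
]:
    merged[new] = merged.pop(old)

print(sorted(merged))
# ['cls_token', 'head.bias', 'head.weight', 'norm.weight', 'pos_embed']
```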

## Setup model conversion utilities

In [None]:
!git clone -q https://github.com/sayakpaul/deit-tf -b new-block

In [None]:
import sys

sys.path.append("deit-tf")

from vit.vit_models import ViTClassifier
from vit.model_configs import base_config
from utils import helpers
from vit.layers import mha

from transformers.tf_utils import shape_list
from tensorflow import keras
import tensorflow as tf
import ml_collections

## A custom `ViTDINOBase` class to account for DINO's custom representation pooling

In [None]:
class ViTDINOBase(ViTClassifier):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def interpolate_pos(self, x, N, h, w):
        class_pos_embed = self.positional_embedding[:, 0]
        patch_pos_embed = self.positional_embedding[:, 1:]
        dim = shape_list(x)[-1]

        # Calculate the resolution to which we need to perform interpolation.
        h0 = h // self.config.patch_size
        w0 = w // self.config.patch_size

        # Reference:
        # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py#L186
        sqrt_N = tf.math.sqrt(tf.cast(N, "float32"))
        sqrt_N_ceil = tf.cast(tf.math.ceil(sqrt_N), "int32")
        patch_pos_embed = tf.reshape(
            patch_pos_embed, (1, sqrt_N_ceil, sqrt_N_ceil, dim)
        )
        patch_pos_embed = tf.image.resize(patch_pos_embed, (h0, w0), method="bicubic")

        tf.debugging.assert_equal(h0, shape_list(patch_pos_embed)[1])
        tf.debugging.assert_equal(w0, shape_list(patch_pos_embed)[2])

        patch_pos_embed = tf.reshape(patch_pos_embed, (1, -1, dim))
        return tf.concat([class_pos_embed[None, ...], patch_pos_embed], axis=1)

    def interpolate_pos_embedding(self, x, h, w):
        """Resizes the positional embedding in case there is a mismatch in resolution.
        E.g., using 480x480 images instead of 224x224 for a given patch size.

        Reference:
            https://github.com/facebookresearch/dino/blob/main/vision_transformer.py#L174
        """
        num_patches = shape_list(x)[1] - 1  # Excluding the cls token.
        N = shape_list(self.positional_embedding)[1] - 1

        pos_embed = tf.cond(
            tf.logical_and(tf.equal(num_patches, N), tf.equal(h, w)),
            lambda: self.positional_embedding,
            lambda: self.interpolate_pos(x, N, h, w),
        )
        return pos_embed

    def call(self, inputs, training):
        n, h, w, c = shape_list(inputs)

        # Create patches and project the patches.
        projected_patches = self.projection(inputs)

        # Append class token.
        cls_token = tf.tile(self.cls_token, (n, 1, 1))
        if cls_token.dtype != projected_patches.dtype:
            cls_token = tf.cast(cls_token, projected_patches.dtype)
        projected_patches = tf.concat([cls_token, projected_patches], axis=1)

        # Fetch positional embeddings.
        positional_embedding = self.interpolate_pos_embedding(projected_patches, h, w)

        # Add positional embeddings to the projected patches.
        encoded_patches = (
            positional_embedding + projected_patches
        )  # (B, number_patches, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Initialize a dictionary to store attention scores from each transformer
        # block.
        attention_scores = dict()

        # Iterate over the number of layers and stack up blocks of
        # Transformer.
        for transformer_module in self.transformer_blocks:
            # Add a Transformer block.
            encoded_patches, attention_score = transformer_module(encoded_patches)
            attention_scores[f"{transformer_module.name}_att"] = attention_score

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Pool representation.
        # Reference: https://github.com/facebookresearch/dino/blob/main/eval_linear.py#L259-L260
        encoded_patches = representation[:, 0]
        encoded_patches_exp = tf.expand_dims(encoded_patches, -1)
        avg_patch_tokens = tf.reduce_mean(representation[:, 1:], 1)
        avg_patch_tokens = tf.expand_dims(avg_patch_tokens, -1)
        output = tf.concat([encoded_patches_exp, avg_patch_tokens], -1)
        output = tf.reshape(output, (n, -1))

        # Classification head.
        output = self.head(output)

        if training:
            return output
        else:
            return output, attention_scores
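
The pooling step in `call` is easy to misread, so here is a small sketch of what it computes for a single image, written with plain Python lists instead of tensors. The helper name `dino_pool` is hypothetical. Stacking the CLS vector and the mean patch vector on a new last axis and then flattening interleaves the two, so the head receives twice the projection dimension (1536 values for a 768-dimensional ViT-B).

```python
def dino_pool(tokens):
    """Pool one image's tokens the way the `call` method above does.

    `tokens` is a list of token vectors; tokens[0] is the CLS token.
    Hypothetical helper for illustration only.
    """
    cls = tokens[0]
    patches = tokens[1:]
    dim = len(cls)
    # Mean over the patch tokens, one value per feature dimension.
    mean_patch = [sum(p[d] for p in patches) / len(patches) for d in range(dim)]
    # Stack on a new last axis, then flatten:
    # [cls[0], mean[0], cls[1], mean[1], ...].
    pooled = []
    for c, m in zip(cls, mean_patch):
        pooled.extend([c, m])
    return pooled


tokens = [[1.0, 2.0], [10.0, 20.0], [30.0, 40.0]]  # CLS + two patch tokens
print(dino_pool(tokens))
# [1.0, 20.0, 2.0, 30.0]
```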

## Validating the initial architecture

In [None]:
config = base_config.get_config(model_name="vit_base", projection_dim=768, num_heads=12)

vit_dino_base = ViTDINOBase(config)

dummy_inputs = tf.random.normal((2, 224, 224, 3))
# Pass `training=False` explicitly so the call returns the attention scores too.
outputs, attn_scores = vit_dino_base(dummy_inputs, training=False)

keys = list(attn_scores.keys())
print(attn_scores[keys[-1]].shape)
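
As a sanity check on the printed shape, the sketch below computes what we would expect, assuming the attention scores are laid out as `(batch, heads, tokens, tokens)`. That layout is an assumption here, although it is consistent with how the scores are indexed later in the notebook. The function name is illustrative.

```python
def expected_attention_shape(batch, image_size, patch_size=16, num_heads=12):
    """Shape of one block's attention scores for square inputs (assumed layout)."""
    # One token per patch, plus the CLS token.
    tokens = (image_size // patch_size) ** 2 + 1
    return (batch, num_heads, tokens, tokens)


print(expected_attention_shape(2, 224))
# (2, 12, 197, 197)
```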

## Port pre-trained DINO params

Reference: https://github.com/sayakpaul/deit-tf/

In [None]:
pt_model_dict = {k: backbone_state_dict[k].numpy() for k in backbone_state_dict}

In [None]:
vit_dino_base.layers[0].layers[0] = helpers.modify_tf_block(
    vit_dino_base.layers[0].layers[0],
    pt_model_dict["patch_embed.proj.weight"],
    pt_model_dict["patch_embed.proj.bias"],
)

# Positional embedding.
vit_dino_base.positional_embedding.assign(tf.Variable(pt_model_dict["pos_embed"]))

# CLS token. DINO checkpoints carry no distillation token, so there is
# nothing else to port here.
vit_dino_base.cls_token.assign(tf.Variable(pt_model_dict["cls_token"]))

# Layer norm layers.
ln_idx = -2
vit_dino_base.layers[ln_idx] = helpers.modify_tf_block(
    vit_dino_base.layers[ln_idx],
    pt_model_dict["norm.weight"],
    pt_model_dict["norm.bias"],
)

# Head layers.
head_layer = vit_dino_base.get_layer("classification_head")
vit_dino_base.layers[-1] = helpers.modify_tf_block(
    head_layer,
    pt_model_dict["head.weight"],
    pt_model_dict["head.bias"],
)

# Transformer blocks.
idx = 0

for outer_layer in vit_dino_base.layers:
    if isinstance(outer_layer, tf.keras.Model) and outer_layer.name != "projection":
        tf_block = vit_dino_base.get_layer(outer_layer.name)
        pt_block_name = f"blocks.{idx}"

        # LayerNorm layers.
        layer_norm_idx = 1
        for layer in tf_block.layers:
            if isinstance(layer, tf.keras.layers.LayerNormalization):
                layer_norm_pt_prefix = f"{pt_block_name}.norm{layer_norm_idx}"
                layer.gamma.assign(
                    tf.Variable(pt_model_dict[f"{layer_norm_pt_prefix}.weight"])
                )
                layer.beta.assign(
                    tf.Variable(pt_model_dict[f"{layer_norm_pt_prefix}.bias"])
                )
                layer_norm_idx += 1

        # FFN layers.
        ffn_layer_idx = 1
        for layer in tf_block.layers:
            if isinstance(layer, tf.keras.layers.Dense):
                dense_layer_pt_prefix = f"{pt_block_name}.mlp.fc{ffn_layer_idx}"
                layer = helpers.modify_tf_block(
                    layer,
                    pt_model_dict[f"{dense_layer_pt_prefix}.weight"],
                    pt_model_dict[f"{dense_layer_pt_prefix}.bias"],
                )
                ffn_layer_idx += 1

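        # Attention layer. The fused PyTorch QKV parameters depend only on the
        # block, so split them once per block rather than once per layer.
        (q_w, k_w, v_w), (q_b, k_b, v_b) = helpers.get_tf_qkv(
            f"{pt_block_name}.attn",
            pt_model_dict,
            config,
        )
        for layer in tf_block.layers:
            if isinstance(layer, mha.TFViTAttention):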
                # Key
                layer.self_attention.key = helpers.modify_tf_block(
                    layer.self_attention.key,
                    k_w,
                    k_b,
                    is_attn=True,
                )
                # Query
                layer.self_attention.query = helpers.modify_tf_block(
                    layer.self_attention.query,
                    q_w,
                    q_b,
                    is_attn=True,
                )
                # Value
                layer.self_attention.value = helpers.modify_tf_block(
                    layer.self_attention.value,
                    v_w,
                    v_b,
                    is_attn=True,
                )
                # Final dense projection
                layer.dense_output.dense = helpers.modify_tf_block(
                    layer.dense_output.dense,
                    pt_model_dict[f"{pt_block_name}.attn.proj.weight"],
                    pt_model_dict[f"{pt_block_name}.attn.proj.bias"],
                )

        idx += 1
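
Two layout differences drive most of the porting work. PyTorch stores an `nn.Linear` weight as `(out_features, in_features)`, while a Keras `Dense` kernel is `(in_features, out_features)`, and DINO's ViT keeps query, key, and value in one fused `attn.qkv` parameter whose rows are stacked in that order. The sketch below illustrates both with plain lists. It is only our assumption about what `helpers.modify_tf_block` and `helpers.get_tf_qkv` do in spirit; `transpose` and `split_qkv` are hypothetical names, not functions from `deit-tf`.

```python
def transpose(matrix):
    """(out, in) -> (in, out), the layout a Keras Dense kernel expects."""
    return [list(col) for col in zip(*matrix)]


def split_qkv(qkv_weight, qkv_bias):
    """Split fused QKV parameters into per-projection kernels and biases.

    Hypothetical helper; assumes rows are stacked as [query; key; value].
    """
    dim = len(qkv_weight) // 3
    weights = [transpose(qkv_weight[i * dim:(i + 1) * dim]) for i in range(3)]
    biases = [qkv_bias[i * dim:(i + 1) * dim] for i in range(3)]
    return weights, biases


# Toy fused parameters for dim=2: weight is (3 * dim, dim), bias is (3 * dim,).
qkv_w = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]
qkv_b = [0, 1, 2, 3, 4, 5]

(q_w, k_w, v_w), (q_b, k_b, v_b) = split_qkv(qkv_w, qkv_b)
print(q_w, k_w, v_b)
# [[1, 3], [2, 4]] [[5, 7], [6, 8]] [4, 5]
```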

## Video generation for attention maps

Code copied and modified from the [official code](https://github.com/facebookresearch/dino/blob/main/video_generation.py). 

In [None]:
import os 
import cv2
import glob
import matplotlib.pyplot as plt
import numpy as np

from tqdm import tqdm
from PIL import Image

In [None]:
FOURCC = {
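    # Lowercase "mp4v" is the tag the MP4 container expects; the uppercase
    # variant typically makes OpenCV warn and fall back to it.
    "mp4": cv2.VideoWriter_fourcc(*"mp4v"),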
    "avi": cv2.VideoWriter_fourcc(*"XVID"),
}

In [None]:
class VideoGeneratorTF:
    def __init__(self, args):
        self.args = args

        # For DeiT, DINO this should be unchanged. For the original ViT-B16 models,
        # input images should be scaled to [-1, 1] range.
        self.norm_layer = keras.layers.Normalization(
            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
            variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],
        )

    def run(self):
        if self.args.input_path is None:
            print(f"Provided input path {self.args.input_path} is not valid.")
            sys.exit(1)
        else:
            if self.args.video_only:
                self._generate_video_from_images(
                    self.args.input_path, self.args.output_path
                )
            else:
                # If input path exists
                if os.path.exists(self.args.input_path):
                    # If input is a video file
                    if os.path.isfile(self.args.input_path):
                        frames_folder = os.path.join(self.args.output_path, "frames-tf")
                        attention_folder = os.path.join(
                            self.args.output_path, "attention-tf"
                        )

                        os.makedirs(frames_folder, exist_ok=True)
                        os.makedirs(attention_folder, exist_ok=True)

                        self._extract_frames_from_video(
                            self.args.input_path, frames_folder
                        )

                        self._inference(
                            frames_folder,
                            attention_folder,
                        )

                        self._generate_video_from_images(
                            attention_folder, self.args.output_path
                        )

                    # If input is a folder of already extracted frames
                    elif os.path.isdir(self.args.input_path):
                        attention_folder = os.path.join(
                            self.args.output_path, "attention-tf"
                        )

                        os.makedirs(attention_folder, exist_ok=True)

                        self._inference(self.args.input_path, attention_folder)

                        self._generate_video_from_images(
                            attention_folder, self.args.output_path
                        )

                # If input path doesn't exist
                else:
                    print(f"Provided input path {self.args.input_path} doesn't exist.")
                    sys.exit(1)

    def _extract_frames_from_video(self, inp: str, out: str):
        vidcap = cv2.VideoCapture(inp)
        self.args.fps = vidcap.get(cv2.CAP_PROP_FPS)

        print(f"Video: {inp} ({self.args.fps} fps)")
        print(f"Extracting frames to {out}")

        success, image = vidcap.read()
        count = 0
        while success:
            cv2.imwrite(
                os.path.join(out, f"frame-{count:04}.jpg"),
                image,
            )
            success, image = vidcap.read()
            count += 1
        # Release the capture handle once all frames have been written.
        vidcap.release()

    def _generate_video_from_images(self, inp: str, out: str):
        img_array = []
        attention_images_list = sorted(glob.glob(os.path.join(inp, "attn-*.jpg")))

        # Get size of the first image
        with open(attention_images_list[0], "rb") as f:
            img = Image.open(f)
            img = img.convert("RGB")
            size = (img.width, img.height)
            img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))

        print(f"Generating video {size} to {out}")

        for filename in tqdm(attention_images_list[1:]):
            with open(filename, "rb") as f:
                img = Image.open(f)
                img = img.convert("RGB")
                img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))

        # Use a separate name for the writer so the `out` path is not shadowed.
        writer = cv2.VideoWriter(
            os.path.join(out, "video-tf." + self.args.video_format),
            FOURCC[self.args.video_format],
            self.args.fps,
            size,
        )

        for frame in img_array:
            writer.write(frame)
        writer.release()
        print("Done")

    def _preprocess_image(self, image: Image.Image, size: int):
        # Reference: https://www.tensorflow.org/lite/examples/style_transfer/overview
        image = np.array(image)
        image_resized = tf.expand_dims(image, 0)
        shape = tf.cast(tf.shape(image_resized)[1:-1], tf.float32)
        # Scale so that the shorter side equals `size`, keeping the aspect ratio.
        short_dim = tf.reduce_min(shape)
        scale = size / short_dim
        new_shape = tf.cast(shape * scale, tf.int32)
        image_resized = tf.image.resize(
            image_resized,
            new_shape,
        )
        return self.norm_layer(image_resized).numpy()

    def _inference(self, inp: str, out: str):
        print(f"Generating attention images to {out}")

        for img_path in tqdm(sorted(glob.glob(os.path.join(inp, "*.jpg")))):
            with open(img_path, "rb") as f:
                img = Image.open(f)
                img = img.convert("RGB")

            preprocessed_image = self._preprocess_image(img, self.args.resize)
            h, w = (
                preprocessed_image.shape[1]
                - preprocessed_image.shape[1] % self.args.patch_size,
                preprocessed_image.shape[2]
                - preprocessed_image.shape[2] % self.args.patch_size,
            )
            preprocessed_image = preprocessed_image[:, :h, :w, :]

            h_featmap = preprocessed_image.shape[1] // self.args.patch_size
            w_featmap = preprocessed_image.shape[2] // self.args.patch_size

            # Grab the attention scores from the final transformer block.
            logits, attention_score_dict = self.args.model(
                preprocessed_image, training=False
            )
            attentions = attention_score_dict["transformer_block_11_att"].numpy()

            nh = attentions.shape[1]  # Number of heads.

            # Keep only the attention from the cls token to the patch tokens.
            attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
            attentions = attentions.reshape(nh, h_featmap, w_featmap)
            attentions = attentions.transpose((1, 2, 0))

            # Upsample each head's map back to the (cropped) frame resolution.
            attentions = tf.image.resize(
                attentions,
                size=(
                    h_featmap * self.args.patch_size,
                    w_featmap * self.args.patch_size,
                ),
            )

            # Average over the heads and save the resulting heatmap.
            fname = os.path.join(out, "attn-" + os.path.basename(img_path))
            plt.imsave(
                fname=fname,
                arr=tf.reduce_mean(attentions, axis=-1).numpy(),
                cmap="inferno",
                format="jpg",
            )
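
The `Normalization` layer in `__init__` works directly on raw 0-255 pixel values, which is why the ImageNet mean and standard deviation are multiplied by 255 (and the variance by 255 squared). The sketch below checks on one pixel that this matches the more familiar "scale to [0, 1], then standardize" recipe. The function names are illustrative, and the layer is modelled simply as `(x - mean) / sqrt(variance)`.

```python
import math

MEAN = (0.485, 0.456, 0.406)  # ImageNet channel means for [0, 1] inputs
STD = (0.229, 0.224, 0.225)   # ImageNet channel standard deviations


def normalize_unit_range(pixel):
    """Scale to [0, 1] first, then standardize."""
    return [(v / 255.0 - m) / s for v, m, s in zip(pixel, MEAN, STD)]


def normalize_raw(pixel):
    """Standardize raw 0-255 values, as the layer above is configured to do."""
    means = [m * 255.0 for m in MEAN]
    variances = [(s * 255.0) ** 2 for s in STD]
    return [(v - m) / math.sqrt(var) for v, m, var in zip(pixel, means, variances)]


pixel = (200.0, 30.0, 104.0)  # an arbitrary RGB pixel
a, b = normalize_unit_range(pixel), normalize_raw(pixel)
print(all(math.isclose(x, y, rel_tol=1e-9, abs_tol=1e-12) for x, y in zip(a, b)))
# True
```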

In [None]:
# Get demo videos.
!gdown --id 12KScLSdZS5gNvLqoZBenbYeTPaVx4wMj
!gdown --id 1dnPP0QvJ2944GaSE47yMgrt3T0yO4R_R

In [None]:
args = ml_collections.ConfigDict()

args.model = vit_dino_base
args.patch_size = 16
args.input_path = "dino.mp4"
args.output_path = "./"
args.resize = 512
args.video_only = False
args.fps = 30.0
args.video_format = "mp4"

In [None]:
vg = VideoGeneratorTF(args)
vg.run()
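
Each `attn-*.jpg` frame written by the run is, before upsampling and colour mapping, the attention from the CLS token to every patch, averaged over heads. Here is a minimal sketch of that reduction on a toy attention tensor stored as nested lists `[heads][tokens][tokens]`. The helper name `cls_attention_map` is hypothetical, and the bilinear upsampling back to frame resolution is left out.

```python
def cls_attention_map(attention, h_featmap, w_featmap):
    """Average the CLS-to-patch attention over heads for one image.

    `attention` is a nested list indexed as [head][query_token][key_token],
    with token 0 being the CLS token. Hypothetical helper for illustration.
    """
    num_heads = len(attention)
    heat = [[0.0] * w_featmap for _ in range(h_featmap)]
    for head in attention:
        cls_to_patches = head[0][1:]  # CLS query row, without the CLS key
        for idx, value in enumerate(cls_to_patches):
            row, col = divmod(idx, w_featmap)  # patches are in row-major order
            heat[row][col] += value / num_heads
    return heat


# Two heads, a 1x2 patch grid, hence 3 tokens (CLS + 2 patches).
toy = [
    [[0.5, 0.25, 0.25], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
    [[0.0, 0.75, 0.25], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
]
print(cls_attention_map(toy, 1, 2))
# [[0.5, 0.25]]
```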