This notebook visualizes the attention scores of the attention heads in the final Transformer block, averaged over the heads into one heat map per video frame, as [DINO does](https://arxiv.org/abs/2104.14294).

Code used in this notebook has been copied and modified from the [official DINO implementation](https://github.com/facebookresearch/dino/blob/main/video_generation.py). 

## Setup

In [1]:
!pip install -q ml_collections

[?25l[K     |████▏                           | 10 kB 16.8 MB/s eta 0:00:01[K     |████████▍                       | 20 kB 8.9 MB/s eta 0:00:01[K     |████████████▋                   | 30 kB 7.3 MB/s eta 0:00:01[K     |████████████████▉               | 40 kB 6.7 MB/s eta 0:00:01[K     |█████████████████████           | 51 kB 4.2 MB/s eta 0:00:01[K     |█████████████████████████▎      | 61 kB 4.9 MB/s eta 0:00:01[K     |█████████████████████████████▍  | 71 kB 5.4 MB/s eta 0:00:01[K     |████████████████████████████████| 77 kB 2.7 MB/s 
[?25h  Building wheel for ml-collections (setup.py) ... [?25l[?25hdone


In [2]:
!pip install -U -q gdown
!gdown --id 12KScLSdZS5gNvLqoZBenbYeTPaVx4wMj
!gdown --id 16_1oDm0PeCGJ_KGBG5UKVN7TsAtiRNrN
!unzip -q vit_dino_base16.zip

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
  Building wheel for gdown (PEP 517) ... [?25l[?25hdone
Downloading...
From: https://drive.google.com/uc?id=12KScLSdZS5gNvLqoZBenbYeTPaVx4wMj
To: /content/dog.mp4
100% 12.8M/12.8M [00:00<00:00, 38.4MB/s]
Downloading...
From: https://drive.google.com/uc?id=16_1oDm0PeCGJ_KGBG5UKVN7TsAtiRNrN
To: /content/vit_dino_base16.zip
100% 326M/326M [00:02<00:00, 149MB/s]


## Imports

In [4]:
import glob
import os
import sys

import cv2
import matplotlib.pyplot as plt
import ml_collections
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow import keras
from tqdm import tqdm

In [5]:
# Maps each supported output container (the value of `args.video_format`) to the
# FourCC codec code handed to cv2.VideoWriter.
# The mp4 tag is lower-case "mp4v": the upper-case spelling is not a valid tag for
# the mp4 container and makes OpenCV's FFMPEG backend warn and fall back to "mp4v".
FOURCC = {
    "mp4": cv2.VideoWriter_fourcc(*"mp4v"),
    "avi": cv2.VideoWriter_fourcc(*"XVID"),
}

## Video generator class inspired by DINO

In [24]:
class VideoGeneratorTF:
    """Turn a video (or a folder of frames) into a video of attention heat maps.

    The pipeline has three stages, each usable through `run()`:

    1. extract every frame of the input video as ``frame-XXXX.jpg``;
    2. run the saved Keras model on each frame and save the head-averaged
       attention of the final transformer block as ``attn-<frame name>``;
    3. stitch the ``attn-*.jpg`` images into ``video-tf.<video_format>``.

    Args:
        args: Config object exposing the attributes ``model_path``, ``patch_size``,
            ``input_path``, ``output_path``, ``resize``, ``video_only``, ``fps``
            and ``video_format``. ``args.fps`` is overwritten with the frame rate
            of the input video when one is read.

    Raises:
        ValueError: If ``args.resize`` is not 224.
    """

    def __init__(self, args):
        self.args = args

        if self.args.resize != 224:
            raise ValueError(
                "We currently support resizing to only 224x224 resolution :("
            )

        # The model is only needed when attention maps have to be computed.
        if not self.args.video_only:
            self.model = self.__load_model()

        # Per-channel mean / variance expressed on the 0-255 pixel scale.
        # For DeiT, DINO this should be unchanged. For the original ViT-B16 models,
        # input images should be scaled to [-1, 1] range.
        self.norm_layer = keras.layers.Normalization(
            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
            variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],
        )

    def run(self):
        """Run the stages required by the configuration in ``self.args``.

        * ``video_only``: only stitch the ``attn-*.jpg`` images found in
          ``input_path`` into a video.
        * ``input_path`` is a file: extract frames, compute attention maps,
          then build the video.
        * ``input_path`` is a directory of ``*.jpg`` frames: compute attention
          maps, then build the video.

        Exits the interpreter with status 1 (``SystemExit``) when ``input_path``
        is ``None`` or does not exist.
        """
        input_path = self.args.input_path
        output_path = self.args.output_path

        if input_path is None:
            print(f"Provided input path {self.args.input_path} is non valid.")
            sys.exit(1)

        if self.args.video_only:
            self._generate_video_from_images(input_path, output_path)
            return

        if not os.path.exists(input_path):
            print(f"Provided input path {self.args.input_path} doesn't exists.")
            sys.exit(1)

        if os.path.isfile(input_path):
            # Input is a video file: dump its frames first.
            frames_folder = os.path.join(output_path, "frames-tf")
            os.makedirs(frames_folder, exist_ok=True)
            self._extract_frames_from_video(input_path, frames_folder)
        elif os.path.isdir(input_path):
            # Input is a folder of already extracted frames.
            frames_folder = input_path
        else:
            # Exists but is neither a regular file nor a directory: nothing to do.
            return

        attention_folder = os.path.join(output_path, "attention-tf")
        os.makedirs(attention_folder, exist_ok=True)

        self._inference(frames_folder, attention_folder)
        self._generate_video_from_images(attention_folder, output_path)

    def _extract_frames_from_video(self, inp: str, out: str):
        """Write every frame of video `inp` to `out` as ``frame-XXXX.jpg``.

        Also stores the video's frame rate in ``self.args.fps`` so the output
        video plays at the same speed; the configured value is kept when the
        reported frame rate is not a positive number.

        Raises:
            IOError: If OpenCV cannot open `inp`.
        """
        vidcap = cv2.VideoCapture(inp)
        try:
            if not vidcap.isOpened():
                raise IOError(f"Could not open video file {inp}")

            fps = vidcap.get(cv2.CAP_PROP_FPS)
            if fps > 0:  # False for 0 and NaN, i.e. when the backend reports no fps
                self.args.fps = fps

            print(f"Video: {inp} ({self.args.fps} fps)")
            print(f"Extracting frames to {out}")

            count = 0
            success, image = vidcap.read()
            while success:
                cv2.imwrite(
                    os.path.join(out, f"frame-{count:04}.jpg"),
                    image,
                )
                count += 1
                success, image = vidcap.read()
        finally:
            # Always free the decoder / file handle, even if writing a frame fails.
            vidcap.release()

    def _generate_video_from_images(self, inp: str, out: str):
        """Stitch the ``attn-*.jpg`` images in `inp` (sorted by name) into a video.

        The video is written to ``<out>/video-tf.<video_format>`` at
        ``self.args.fps`` frames per second, using the size of the first image.

        Raises:
            FileNotFoundError: If `inp` contains no ``attn-*.jpg`` image.
        """
        attention_images_list = sorted(glob.glob(os.path.join(inp, "attn-*.jpg")))
        if not attention_images_list:
            raise FileNotFoundError(
                f"No attention images matching 'attn-*.jpg' found in {inp}"
            )

        # The frame size of the video is taken from the first image.
        with Image.open(attention_images_list[0]) as first_image:
            size = (first_image.width, first_image.height)

        print(f"Generating video {size} to {out}")

        writer = cv2.VideoWriter(
            os.path.join(out, "video-tf." + self.args.video_format),
            FOURCC[self.args.video_format],
            self.args.fps,
            size,
        )
        try:
            # Frames are streamed straight to the writer instead of being
            # collected in a list first, so memory use stays constant.
            for filename in tqdm(attention_images_list):
                with Image.open(filename) as img:
                    frame_rgb = np.array(img.convert("RGB"))
                # PIL yields RGB, OpenCV expects BGR.
                writer.write(cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR))
        finally:
            writer.release()
        print("Done")

    def _preprocess_image(self, image: "Image.Image", size: int):
        """Prepare one frame for the model and for heat-map sizing.

        Args:
            image: RGB PIL image.
            size: Side length of the square model input.

        Returns:
            A tuple ``(model_input, image_w_ar)`` of NumPy arrays:
            ``model_input`` is the frame resized to ``size x size``, normalized
            by ``self.norm_layer`` and carrying a leading batch axis of 1;
            ``image_w_ar`` is the frame resized to fit in
            ``size // 2 x size // 2`` while preserving its aspect ratio.
        """
        image = np.array(image)
        image_resized = tf.expand_dims(image, 0)
        image_resized = tf.image.resize(image_resized, (size, size))
        image_w_ar = tf.image.resize(
            image, (size // 2, size // 2), preserve_aspect_ratio=True
        )

        return self.norm_layer(image_resized).numpy(), image_w_ar.numpy()

    @staticmethod
    def _patch_attention_maps(attentions, h_featmap: int, w_featmap: int):
        """Return the per-head attention of token 0 over the patch tokens.

        Args:
            attentions: 4-D array indexed as ``[batch, head, query, key]``
                (assumed layout of the model's attention output - TODO confirm);
                only batch element 0 is used.
            h_featmap: Number of patch rows.
            w_featmap: Number of patch columns.

        Returns:
            Array of shape ``(heads, h_featmap, w_featmap)``. Key token 0 is
            dropped (presumably the class token); the remaining keys are laid
            out row-major, i.e. rows first, then columns.
        """
        num_heads = attentions.shape[1]
        return attentions[0, :, 0, 1:].reshape(num_heads, h_featmap, w_featmap)

    def _inference(self, inp: str, out: str):
        """Save one attention heat map per ``*.jpg`` frame found in `inp`.

        For each frame, the attention scores of the final transformer block
        are averaged over the heads, enlarged to the aspect-ratio-preserving
        frame size and written to ``<out>/attn-<frame file name>`` with the
        "inferno" colour map.
        """
        print(f"Generating attention images to {out}")

        for img_path in tqdm(sorted(glob.glob(os.path.join(inp, "*.jpg")))):
            with open(img_path, "rb") as f:
                img = Image.open(f).convert("RGB")

            preprocessed_image, image_w_ar = self._preprocess_image(
                img, self.args.resize
            )

            # preprocessed_image is (1, height, width, channels).
            h_featmap = preprocessed_image.shape[1] // self.args.patch_size
            w_featmap = preprocessed_image.shape[2] // self.args.patch_size

            # Grab the attention scores from the final transformer block.
            # verbose=0 keeps Keras from printing a progress bar for every frame.
            _, attention_score_dict = self.model.predict(
                preprocessed_image, verbose=0
            )
            attentions = attention_score_dict["transformer_block_11_att"]

            # Nearest-neighbour resizing only copies values, so the heads can be
            # averaged first and the result resized once instead of once per head.
            head_maps = self._patch_attention_maps(attentions, h_featmap, w_featmap)
            mean_map = head_maps.mean(axis=0)[..., None]

            height, width = image_w_ar.shape[:2]
            heatmap = tf.image.resize(
                mean_map, (height, width), method="nearest"
            ).numpy()[..., 0]

            # save attentions heatmaps
            fname = os.path.join(out, "attn-" + os.path.basename(img_path))
            plt.imsave(fname=fname, arr=heatmap, cmap="inferno", format="jpg")

    def __load_model(self):
        """Load and return the Keras model saved at ``self.args.model_path``."""
        model = keras.models.load_model(self.args.model_path)
        print("Model loaded.")
        return model

## Run inference

In [25]:
# All settings read by VideoGeneratorTF, gathered in one config object.
args = ml_collections.ConfigDict(
    {
        "model_path": "vit_dino_base16",
        "patch_size": 16,
        "pretrained_weights": "",
        "input_path": "dog.mp4",
        "output_path": "./",
        "resize": 224,
        "video_only": False,
        "fps": 30.0,
        "video_format": "mp4",
    }
)

In [26]:
# Build the generator from the config (the saved model is loaded here because
# args.video_only is False), then run it. For a video-file input this extracts the
# frames, computes the attention heat maps and assembles the output video.
vg = VideoGeneratorTF(args)
vg.run()

Model loaded.
Video: dog.mp4 (29.97002997002997 fps)
Extracting frames to ./frames-tf
Generating attention images to ./attention-tf


100%|██████████| 150/150 [01:01<00:00,  2.44it/s]


Generating video (112, 63) to ./


100%|██████████| 149/149 [00:00<00:00, 2761.34it/s]

Done



