<a href="https://colab.research.google.com/github/rng70/ViViT/blob/main/video_vision_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Video Vision Transformer
**Original Authors:** Aritra Roy Gosthipaty, Ayush Thakur <br>
**Author:** Al Arafat Tanin


# Introduction

Videos are sequences of images. Let's assume you have an image representation model (CNN, ViT, etc.) and a sequence model (RNN, LSTM, etc.) at hand, and you are asked to adapt them for video classification. The simplest approach would be to apply the image model to individual frames, use the sequence model to learn sequences of image features, then apply a classification head on the learned sequence representation. The Keras example [Video Classification with a CNN-RNN Architecture](https://keras.io/examples/vision/video_classification/) explains this approach in detail. Alternatively, you can build a hybrid Transformer-based model for video classification, as shown in the Keras example [Video Classification with Transformers](https://keras.io/examples/vision/video_transformers/).

In this example, we minimally implement [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) by Arnab et al., a pure Transformer-based model for video classification. The authors propose a novel embedding scheme and a number of Transformer variants to model video clips. We implement the embedding scheme and one of the variants of the Transformer architecture, for simplicity.

This example requires TensorFlow 2.6 or higher, and the medmnist package, which can be installed by running the code cell below.

In [None]:
!pip install -qq medmnist

# Imports

In [None]:
import os
import io
import imageio
import medmnist
import ipywidgets
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Setting seed for reproducibility
SEED = 42
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
keras.utils.set_random_seed(SEED)

# Hyperparameters
The hyperparameters are chosen via hyperparameter search. You can learn more about the process in the "conclusion" section.

In [None]:
# DATA
DATASET_NAME = "organmnist3d"
BATCH_SIZE = 32
AUTO = tf.data.AUTOTUNE
INPUT_SHAPE = (28, 28, 28, 1)
NUM_CLASSES = 11

# OPTIMIZER
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5

# TRAINING
EPOCHS = 60

# TUBELET EMBEDDING
PATCH_SIZE = (8, 8, 8)
NUM_PATCHES = (INPUT_SHAPE[0] // PATCH_SIZE[0]) ** 3  # tubelets tile all three axes of the cubic input

# ViViT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 8
NUM_LAYERS = 8

# Dataset
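For our example we use the [MedMNIST v2: A Large-Scale Lightweight Benchmark for 2D and 3D Biomedical Image Classification](https://medmnist.com/) dataset, specifically the 3D subset named in `DATASET_NAME`. Each sample is a small 28 × 28 × 28 volume that we treat as a short video, so the data is lightweight and easy to train on.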

In [None]:
def download_and_prepare_dataset(data_info: dict):
    """Utility function to download the dataset and split it.

    Arguments:
        data_info (dict): Dataset metadata.

    Returns:
        Three (videos, labels) tuples for the train, validation and test splits.
    """

    data_path = keras.utils.get_file(origin=data_info["url"], md5_hash=data_info["MD5"])

    with np.load(data_path) as data:
        # Get videos
        train_videos = data["train_images"]
        valid_videos = data["val_images"]
        test_videos = data["test_images"]

        # Get labels (flattened to 1D vectors of class indices)
        train_labels = data["train_labels"].flatten()
        valid_labels = data["val_labels"].flatten()
        test_labels = data["test_labels"].flatten()

        return (
            (train_videos, train_labels),
            (valid_videos, valid_labels),
            (test_videos, test_labels),
        )


# Get the metadata of the dataset
info = medmnist.INFO[DATASET_NAME]

# Get the dataset
prepared_dataset = download_and_prepare_dataset(info)
(train_videos, train_labels) = prepared_dataset[0]
(valid_videos, valid_labels) = prepared_dataset[1]
(test_videos, test_labels) = prepared_dataset[2]
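
To make the expected archive layout concrete, here is a minimal sketch that builds a tiny synthetic `.npz` file with the same keys the loader reads and then loads it the same way. The array contents, the `uint8` dtype, the `(N, 1)` label shape, the file name and the `demo_*` names are all assumptions chosen for illustration; they are not taken from the real download.

```python
import os
import tempfile

import numpy as np

# Synthetic stand-in for the real archive (assumed layout, random contents).
rng = np.random.default_rng(0)
demo_arrays = {}
for split, n in {"train": 6, "val": 2, "test": 2}.items():
    # Assumed: volumes are stored as uint8 with shape (N, 28, 28, 28).
    demo_arrays[f"{split}_images"] = rng.integers(
        0, 256, size=(n, 28, 28, 28), dtype=np.uint8
    )
    # Assumed: labels are stored as a column of shape (N, 1).
    demo_arrays[f"{split}_labels"] = rng.integers(0, 11, size=(n, 1))

with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "demo_archive.npz")  # hypothetical file name
    np.savez(demo_path, **demo_arrays)

    with np.load(demo_path) as data:
        demo_videos = data["train_images"]
        # flatten() turns the (N, 1) label column into a 1D vector of length N.
        demo_labels = data["train_labels"].flatten()

print(demo_videos.shape, demo_videos.dtype)  # (6, 28, 28, 28) uint8
print(demo_labels.shape)  # (6,)
```

Running the same shape checks on the real splits is a cheap way to confirm that `INPUT_SHAPE` matches the data before training.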

## tf.data pipeline

In [None]:
@tf.function
def preprocess(frames: tf.Tensor, label: tf.Tensor):
    """Preprocess the frames tensors and parse the labels."""
    # Preprocess images
    frames = tf.image.convert_image_dtype(
        frames[
            ..., tf.newaxis
        ],  # The new axis is to help for further processing with Conv3D layers
        tf.float32,
    )
    # Parse label
    label = tf.cast(label, tf.float32)
    return frames, label


def prepare_dataloader(
    videos: np.ndarray,
    labels: np.ndarray,
    loader_type: str = "train",
    batch_size: int = BATCH_SIZE,
):
    """Utility function to prepare the dataloader."""
    dataset = tf.data.Dataset.from_tensor_slices((videos, labels))

    if loader_type == "train":
        dataset = dataset.shuffle(batch_size * 2)

    dataloader = (
        dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
    return dataloader


trainloader = prepare_dataloader(train_videos, train_labels, "train")
validloader = prepare_dataloader(valid_videos, valid_labels, "valid")
testloader = prepare_dataloader(test_videos, test_labels, "test")
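
As a quick sanity check of what this pipeline yields, the sketch below pushes a few synthetic volumes through the same preprocessing steps. The random `uint8` data and the `demo_*` names are assumptions for illustration only. With integer inputs, `tf.image.convert_image_dtype` rescales values into the `[0, 1]` range while converting to `float32`.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-ins for a handful of volumes and labels (not real data).
rng = np.random.default_rng(0)
demo_videos = rng.integers(0, 256, size=(4, 28, 28, 28), dtype=np.uint8)
demo_labels = rng.integers(0, 11, size=(4,))


def demo_preprocess(frames, label):
    # Add a trailing channel axis, then convert uint8 -> float32 in [0, 1].
    frames = tf.image.convert_image_dtype(frames[..., tf.newaxis], tf.float32)
    label = tf.cast(label, tf.float32)
    return frames, label


demo_loader = (
    tf.data.Dataset.from_tensor_slices((demo_videos, demo_labels))
    .map(demo_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(2)
    .prefetch(tf.data.AUTOTUNE)
)

demo_frames, demo_batch_labels = next(iter(demo_loader))
demo_batch_shape = tuple(demo_frames.shape)
demo_min = float(tf.reduce_min(demo_frames))
demo_max = float(tf.reduce_max(demo_frames))

print(demo_batch_shape, demo_frames.dtype)  # (2, 28, 28, 28, 1) float32
print(demo_batch_labels.dtype)  # float32
```

Each batch therefore has the layout `(batch, depth, height, width, channels)`, which is the five-dimensional input that 3D convolution layers expect.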

# Tubelet Embedding

In ViTs, an image is divided into patches, which are then spatially
flattened, a process known as tokenization. For a video, one can
repeat this process for individual frames. **Uniform frame sampling**
as suggested by the authors is a tokenization scheme in which we
sample frames from the video clip and perform simple ViT tokenization.

| ![uniform frame sampling](https://github.com/rng70/ViViT/blob/main/images/uniform%20frame%20sampling.png) |
| :--: |
| Uniform Frame Sampling [Source](https://arxiv.org/abs/2103.15691) |
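
Uniform frame sampling can be sketched in plain NumPy as follows. This is an illustration under stated assumptions, not the code used later in this example: the helper name `uniform_frame_tokens`, the choice of 4 sampled frames and the 8 × 8 patch size are hypothetical, and the learned linear projection of each token is omitted.

```python
import numpy as np


def uniform_frame_tokens(video, num_frames, patch_hw):
    """Sample frames evenly, then cut each frame into flattened 2D patches.

    video: array of shape (T, H, W, C).
    Returns an array of shape (num_frames * nh * nw, ph * pw * C).
    """
    t, h, w, c = video.shape
    ph, pw = patch_hw
    # Evenly spaced frame indices across the clip.
    frame_idx = np.linspace(0, t - 1, num_frames).astype(int)
    frames = video[frame_idx]  # (num_frames, H, W, C)

    # Number of whole patches per axis; leftover border pixels are dropped.
    nh, nw = h // ph, w // pw
    frames = frames[:, : nh * ph, : nw * pw]

    # Split H and W into (blocks, block_size), then group block indices first.
    patches = frames.reshape(num_frames, nh, ph, nw, pw, c)
    patches = patches.transpose(0, 1, 3, 2, 4, 5)  # (frames, nh, nw, ph, pw, C)
    return patches.reshape(num_frames * nh * nw, ph * pw * c)


frame_video = np.arange(28 * 28 * 28, dtype=np.float32).reshape(28, 28, 28, 1)
frame_tokens = uniform_frame_tokens(frame_video, num_frames=4, patch_hw=(8, 8))
print(frame_tokens.shape)  # (36, 64)
```

Each token here covers a single frame, so any temporal relationship between frames has to be learned entirely by the Transformer layers that follow.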

**Tubelet Embedding** differs in how it captures temporal
information from the video.
First, we extract volumes from the video. Each volume spans a patch of
the frame across several consecutive frames, so it carries temporal
information as well as spatial information. The volumes
are then flattened to build video tokens.

| ![tubelet embedding](https://github.com/rng70/ViViT/blob/main/images/tubelet%20embedding.png) |
| :--: |
| Tubelet Embedding [Source](https://arxiv.org/abs/2103.15691) |
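
The volume extraction described above can be sketched in plain NumPy. This is a minimal illustration, assuming non-overlapping tubelets and dropping any remainder along each axis; the helper name `extract_tubelets` is hypothetical, and the learned projection that maps each flattened tubelet to the embedding dimension is left out.

```python
import numpy as np


def extract_tubelets(video, patch_size):
    """Split a (T, H, W, C) volume into flattened, non-overlapping tubelets.

    Returns an array of shape (nt * nh * nw, pt * ph * pw * C).
    """
    pt, ph, pw = patch_size
    t, h, w, c = video.shape
    # Number of whole tubelets per axis; the remainder is cropped away.
    nt, nh, nw = t // pt, h // ph, w // pw
    video = video[: nt * pt, : nh * ph, : nw * pw]

    # Split every axis into (blocks, block_size).
    blocks = video.reshape(nt, pt, nh, ph, nw, pw, c)
    # Group block indices first: (nt, nh, nw, pt, ph, pw, C).
    blocks = blocks.transpose(0, 2, 4, 1, 3, 5, 6)
    # One row per tubelet, flattened over time, height, width and channels.
    return blocks.reshape(nt * nh * nw, pt * ph * pw * c)


tubelet_video = np.arange(28 * 28 * 28, dtype=np.float32).reshape(28, 28, 28, 1)
tubelet_tokens = extract_tubelets(tubelet_video, patch_size=(8, 8, 8))
print(tubelet_tokens.shape)  # (27, 512)
```

For a 28 × 28 × 28 input and 8 × 8 × 8 tubelets, three whole tubelets fit along each axis, giving 3 × 3 × 3 = 27 tokens of 512 values each. Unlike the per-frame patches of uniform frame sampling, every token already mixes information from eight consecutive frames.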