## Introduction

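Videos are sequences of images. Suppose you already have an image representation model (CNN, ViT, etc.) and a sequence model (RNN, LSTM, etc.), and we ask you to adapt them for video classification. The intuitive approach is to apply the image model to individual frames, use the sequence model to learn over the resulting sequence of image representations, and then apply a classification head to the learned sequence representation. [Video Classification with a CNN-RNN Architecture](https://keras.io/examples/vision/video_classification/) explains this approach in detail. Going one step further, you can also build a hybrid Transformer-based model for video classification, as shown in [Video Classification with Transformers](https://keras.io/examples/vision/video_transformers/).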

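This example is a minimal implementation of [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) by Arnab et al., a pure Transformer-based model for video classification. The authors propose a novel embedding scheme and a number of Transformer variants for modeling video clips. In this example we implement the embedding scheme and a simplified version of one of the Transformer variants.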

This example requires TensorFlow 2.6 or higher. Install the additional dependencies first.

In [19]:
!pip install ipywidgets

In [1]:
!pip install -qq medmnist

## Imports

In [3]:
import os
import io
import imageio
import medmnist
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# setting seed for reproducibility
SEED = 42
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
keras.utils.set_random_seed(SEED)

## Hyperparameters

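The hyperparameters in this example were chosen through a hyperparameter search. More information about that search can be found in the final section of this example.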

In [4]:
# DATA
DATASET_NAME = "organmnist3d"
BATCH_SIZE = 32
AUTO = tf.data.AUTOTUNE
INPUT_SHAPE = (28, 28, 28, 1)
NUM_CLASSES = 11

# OPTIMIZER
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5

# TRAINING
EPOCHS = 60

# TUBELET EMBEDDING
PATCH_SIZE = (8, 8, 8)
# Tubelets tile depth, height and width, so the per-axis count is cubed (3 ** 3 = 27).
NUM_PATCHES = (INPUT_SHAPE[0] // PATCH_SIZE[0]) ** 3

# ViViT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 8
NUM_LAYERS = 8
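
As a quick sanity check on the tubelet settings, the sketch below counts how many tokens one video yields. It assumes non-overlapping tubelets with "valid" padding, so any remainder along an axis is dropped; `count_tubelets` is a hypothetical helper and not part of the original example.

```python
# Values copied from the hyperparameter cell above.
INPUT_SHAPE = (28, 28, 28, 1)  # (frames, height, width, channels)
PATCH_SIZE = (8, 8, 8)


def count_tubelets(input_shape, patch_size):
    """Count non-overlapping tubelets; incomplete tubelets at the border are dropped."""
    count = 1
    for size, patch in zip(input_shape[:3], patch_size):
        count *= size // patch
    return count


num_tokens = count_tubelets(INPUT_SHAPE, PATCH_SIZE)
print(num_tokens)  # 3 tubelets per axis -> 3 * 3 * 3 = 27
```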

## Dataset

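This example uses the [MedMNIST v2: A Large-Scale Lightweight Benchmark for 2D and 3D Biomedical Image Classification](https://medmnist.com/) dataset. Its videos are small, which makes them easy to train on.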

In [9]:
!wget https://modelscope.oss-cn-beijing.aliyuncs.com/resource/organmnist3d.npz

In [10]:
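def download_and_prepare_dataset(data_info: dict):
    """
    Utility function to load the dataset and return train/valid/test
    videos and labels. The archive is read from the local file fetched
    by the `wget` cell above; nothing is downloaded here.
    Arguments:
        data_info (dict): Dataset metadata (not used when loading locally)
    """
    # Adjust this path if the archive was saved somewhere else.
    data_path = "/mnt/workspace/organmnist3d.npz"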

    with np.load(data_path) as data:
        # Get videos
        train_videos = data["train_images"]
        valid_videos = data["val_images"]
        test_videos = data["test_images"]

        # Get labels
        train_labels = data["train_labels"].flatten()
        valid_labels = data["val_labels"].flatten()
        test_labels = data["test_labels"].flatten()

    return (
        (train_videos, train_labels),
        (valid_videos, valid_labels),
        (test_videos, test_labels),
    )


# Get the metadata of the dataset
info = medmnist.INFO[DATASET_NAME]

# Get the dataset
prepared_dataset = download_and_prepare_dataset(info)
(train_videos, train_labels) = prepared_dataset[0]
(valid_videos, valid_labels) = prepared_dataset[1]
(test_videos, test_labels) = prepared_dataset[2]

### `tf.data` pipeline

In [11]:
@tf.function
def preprocess(frames: tf.Tensor, label: tf.Tensor):
    """Preprocess the frames tensors and parse the labels"""
    # Preprocess images
    frames = tf.image.convert_image_dtype(
        frames[
            ..., tf.newaxis
        ],  # The new axis is to help for further processing with Conv3D layers
        tf.float32,
    )

    # Parse label
    label = tf.cast(label, tf.float32)
    return frames, label


def prepare_dataloader(
    videos: np.ndarray,
    labels: np.ndarray,
    loader_type: str = "train",
    batch_size: int = BATCH_SIZE,
):
    """Utility function to prepare dataloader"""
    dataset = tf.data.Dataset.from_tensor_slices((videos, labels))

    if loader_type == "train":
        dataset = dataset.shuffle(BATCH_SIZE * 2)

    dataloader = (
        dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

    return dataloader


trainloader = prepare_dataloader(train_videos, train_labels, "train")
validloader = prepare_dataloader(valid_videos, valid_labels, "valid")
testloader = prepare_dataloader(test_videos, test_labels, "test")
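
To make the effect of `preprocess` concrete, here is a rough NumPy equivalent applied to a single volume. It assumes the volumes are stored as `uint8`, in which case `tf.image.convert_image_dtype` rescales values to the range [0, 1]; the random `volume` is a hypothetical stand-in for one MedMNIST sample.

```python
import numpy as np

# Hypothetical stand-in for one volume, assuming uint8 storage.
rng = np.random.default_rng(0)
volume = rng.integers(0, 256, size=(28, 28, 28), dtype=np.uint8)

# Add a trailing channel axis and rescale [0, 255] to [0.0, 1.0].
frames = volume[..., np.newaxis].astype(np.float32) / 255.0

print(frames.shape, frames.dtype)
```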

## Tubelet Embedding

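In ViTs, an image is divided into patches, which are then spatially flattened and projected; this is the tokenization scheme. For a video, the same process can be repeated for individual frames. **Uniform frame sampling**, as suggested by the authors, is a tokenization scheme in which frames are sampled from the video clip and then tokenized with plain ViT tokenization.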

![](https://i.imgur.com/aaPyLPX.png)
(Figure: uniform frame sampling. [Source](https://arxiv.org/abs/2103.15691))

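**Tubelet Embedding** differs in how it captures temporal information. From the video, we extract volumes that span several consecutive frames. These volumes contain not only patches of the frames but also temporal information. The volumes are then flattened and projected to build the video tokens.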

![](https://i.imgur.com/9G7QTfV.png)
(Figure: Tubelet Embedding. [Source](https://arxiv.org/abs/2103.15691))

In [12]:
class TubeletEmbedding(layers.Layer):
    def __init__(self, embed_dim, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.projection = layers.Conv3D(
            filters=embed_dim,
            kernel_size=patch_size,
            strides=patch_size,
            padding="VALID",
        )
        self.flatten = layers.Reshape(target_shape=(-1, embed_dim))

    def call(self, videos):
        projected_patches = self.projection(videos)
        flattened_patches = self.flatten(projected_patches)
        return flattened_patches
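
The `Conv3D` layer with equal kernel size and stride is a compact way of cutting the video into non-overlapping tubelets and applying one linear projection to each. The NumPy sketch below spells out those two steps under simplifying assumptions: a random `video` and random `weights` stand in for real data and the learned kernel, and the bias term is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
video = rng.random((28, 28, 28, 1), dtype=np.float32)  # (T, H, W, C), hypothetical input
pt, ph, pw = 8, 8, 8  # tubelet size
embed_dim = 128

# "VALID" padding: drop the border that does not fill a whole tubelet.
T, H, W, C = video.shape
nt, nh, nw = T // pt, H // ph, W // pw
cropped = video[: nt * pt, : nh * ph, : nw * pw]

# Split each axis into (block index, offset inside block), then group block indices first.
tubelets = cropped.reshape(nt, pt, nh, ph, nw, pw, C)
tubelets = tubelets.transpose(0, 2, 4, 1, 3, 5, 6)
tubelets = tubelets.reshape(nt * nh * nw, pt * ph * pw * C)  # (27, 512)

# A random matrix stands in for the learned projection.
weights = rng.standard_normal((pt * ph * pw * C, embed_dim)).astype(np.float32)
tokens = tubelets @ weights  # (27, 128)

print(tubelets.shape, tokens.shape)
```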

## Positional Embedding

This layer adds position embeddings to the encoded video tokens.

In [13]:
class PositionalEncoder(layers.Layer):
    def __init__(self, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim

    def build(self, input_shape):
        _, num_tokens, _ = input_shape
        self.position_embedding = layers.Embedding(
            input_dim=num_tokens, output_dim=self.embed_dim
        )
        self.positions = tf.range(start=0, limit=num_tokens, delta=1)

    def call(self, encoded_tokens):
        # Encode the positions and add it to the encoded tokens
        encoded_positions = self.position_embedding(self.positions)
        encoded_tokens = encoded_tokens + encoded_positions
        return encoded_tokens
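
The lookup-and-add performed by `PositionalEncoder` can be sketched in NumPy as follows. The random `position_table` is only a stand-in for the learned embedding weights, and the batch size of 2 is an arbitrary assumption for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, num_tokens, embed_dim = 2, 27, 128

tokens = rng.standard_normal((batch, num_tokens, embed_dim))
# Stand-in for the learned embedding table: one row per token position.
position_table = rng.standard_normal((num_tokens, embed_dim))

positions = np.arange(num_tokens)
encoded_positions = position_table[positions]  # embedding lookup, shape (27, 128)

# Broadcasting adds the same positional vectors to every video in the batch.
encoded_tokens = tokens + encoded_positions

print(encoded_tokens.shape)
```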

## Video Vision Transformer


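The authors propose four variants of the Video Vision Transformer:

- Spatio-temporal attention
- Factorized encoder
- Factorized self-attention
- Factorized dot-product attention

In this example, we implement the **Spatio-temporal attention** model for simplicity. The following code snippet is heavily inspired by [Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).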
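
In the spatio-temporal variant, every token attends to every other token, regardless of where or when its tubelet occurs in the video. The single-head NumPy sketch below illustrates that joint attention pattern. It is a simplification: the random projection matrices stand in for learned weights, and the multi-head split, dropout and output projection of `layers.MultiHeadAttention` are left out.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(tokens, w_q, w_k, w_v):
    """Scaled dot-product self-attention over all tokens jointly."""
    q, k, v = tokens @ w_q, tokens @ w_k, tokens @ w_v
    scores = q @ k.T / np.sqrt(q.shape[-1])  # (num_tokens, num_tokens)
    weights = softmax(scores, axis=-1)
    return weights @ v, weights


rng = np.random.default_rng(0)
num_tokens, embed_dim, key_dim = 27, 128, 16  # key_dim = 128 // 8 heads
tokens = rng.standard_normal((num_tokens, embed_dim))
w_q, w_k, w_v = (rng.standard_normal((embed_dim, key_dim)) * 0.1 for _ in range(3))

output, weights = self_attention(tokens, w_q, w_k, w_v)
print(output.shape, weights.shape)
```

Each row of `weights` spans all 27 tubelet tokens, which is what makes the attention joint over space and time.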

In [14]:
def create_vivit_classifier(
    tubelet_embedder,
    positional_encoder,
    input_shape=INPUT_SHAPE,
    transformer_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    embed_dim=PROJECTION_DIM,
    layer_norm_eps=LAYER_NORM_EPS,
    num_classes=NUM_CLASSES,
):

    # Get the input layer
    inputs = layers.Input(shape=input_shape)
    # Create patches.
    patches = tubelet_embedder(inputs)
    # Encode patches.
    encoded_patches = positional_encoder(patches)

    # Create multiple layers of the Transformer block.
    for _ in range(transformer_layers):
        # Layer normalization and MHSA
        x1 = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim // num_heads, dropout=0.1
        )(x1, x1)

        # Skip connection
        x2 = layers.Add()([attention_output, encoded_patches])

        # Layer Normalization and MLP
        x3 = layers.LayerNormalization(epsilon=layer_norm_eps)(x2)
        x3 = keras.Sequential(
            [
                layers.Dense(units=embed_dim * 4, activation=tf.nn.gelu),
                layers.Dense(units=embed_dim, activation=tf.nn.gelu),
            ]
        )(x3)

        # Skip connection
        encoded_patches = layers.Add()([x3, x2])

    # Layer normalization and Global average pooling.
    representation = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
    representation = layers.GlobalAvgPool1D()(representation)

    # Classify outputs.
    outputs = layers.Dense(units=num_classes, activation="softmax")(representation)

    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

## Train

In [15]:
def run_experiment():
    # Initialize model
    model = create_vivit_classifier(
        tubelet_embedder=TubeletEmbedding(
            embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE
        ),
        positional_encoder=PositionalEncoder(embed_dim=PROJECTION_DIM),
    )

    # Compile the model with the optimizer, loss function
    # and the metrics.
    optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    # Train the model.
    _ = model.fit(trainloader, epochs=EPOCHS, validation_data=validloader)

    _, accuracy, top_5_accuracy = model.evaluate(testloader)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    return model

In [16]:
model = run_experiment()

Test accuracy: 78.36%
Test top 5 accuracy: 97.05%
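
For reference, the two reported metrics can be reproduced by hand. The sketch below uses a hypothetical helper, `top_k_accuracy`, on a tiny made-up batch; it is meant to illustrate the definition rather than mirror the Keras metric classes exactly (tie-breaking, for instance, may differ).

```python
import numpy as np


def top_k_accuracy(probs, labels, k):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    top_k = np.argsort(probs, axis=-1)[:, -k:]
    hits = (top_k == labels[:, None]).any(axis=-1)
    return float(hits.mean())


# Made-up class probabilities for three samples and three classes.
probs = np.array(
    [
        [0.7, 0.2, 0.1],
        [0.1, 0.3, 0.6],
        [0.5, 0.4, 0.1],
    ]
)
labels = np.array([0, 1, 1])

top1 = top_k_accuracy(probs, labels, k=1)  # only the first sample is correct
top2 = top_k_accuracy(probs, labels, k=2)  # all true labels fall in the top two
print(top1, top2)
```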


## Inference

In [20]:
import ipywidgets
NUM_SAMPLES_VIZ = 25
testsamples, labels = next(iter(testloader))
testsamples, labels = testsamples[:NUM_SAMPLES_VIZ], labels[:NUM_SAMPLES_VIZ]

ground_truths = []
preds = []
videos = []


for i, (testsample, label) in enumerate(zip(testsamples, labels)):
    # Generate gif
    with io.BytesIO() as gif:
        imageio.mimsave(gif, (testsample.numpy() * 255).astype("uint8")[..., 0], "GIF", fps=5)
        videos.append(gif.getvalue())

    # Get model prediction
    output = model.predict(tf.expand_dims(testsample, axis=0))[0]
    pred = np.argmax(output, axis=0)

    ground_truths.append(label.numpy().astype("int"))
    preds.append(pred)


def make_box_for_grid(image_widget, fit):
    """
    Make a VBox to hold caption/image for demonstrating
    option_fit values.
    Source: https://ipywidgets.readthedocs.io/en/latest/examples/Widget%20Styling.html
    """
    # Make the caption
    if fit is not None:
        fit_str = "'{}'".format(fit)
    else:
        fit_str = str(fit)

    h = ipywidgets.HTML(value=fit_str)

    # Make the green box with the image widget inside it
    boxb = ipywidgets.widgets.Box()
    boxb.children = [image_widget]

    # Compose into a vertical box
    vb = ipywidgets.widgets.VBox()
    vb.layout.align_items = "center"
    vb.children = [h, boxb]
    return vb


boxes = []
for i in range(NUM_SAMPLES_VIZ):
    ib = ipywidgets.widgets.Image(value=videos[i], width=100, height=100)
    true_class = info["label"][str(ground_truths[i])]
    pred_class = info["label"][str(preds[i])]
    caption = f"T: {true_class} | P: {pred_class}"

    boxes.append(make_box_for_grid(ib, caption))

ipywidgets.widgets.GridBox(
    boxes, layout=ipywidgets.widgets.Layout(grid_template_columns="repeat(5, 200px)")
)



## Final thoughts

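With a vanilla implementation, we achieve about 79-80% top-1 accuracy on the test dataset; the run recorded above came in slightly lower, at 78.36%.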

Possible improvements:

- Use data augmentation.
- Use a better regularization scheme during training.
- Try the other Transformer model variants.

The hyperparameters used in this tutorial were determined by running a hyperparameter search with [W&B Sweeps](https://docs.wandb.ai/guides/sweeps). You can find the sweep results [here](https://wandb.ai/minimal-implementations/vivit/sweeps/66fp0lhz) and our quick analysis of them [here](https://wandb.ai/minimal-implementations/vivit/reports/Hyperparameter-Tuning-Analysis--VmlldzoxNDEwNzcx).

We are very grateful to [Weights and Biases](https://wandb.ai/site) for supporting this project with GPU compute resources.