
Paper: [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)

Recommended reading: [Implementations of the ViT model family](https://github.com/lucidrains/vit-pytorch)

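# Vision Transformer (ViT)
**Vision Transformer (ViT)** is a deep learning model that applies the Transformer architecture, originally developed for natural language processing (NLP), to computer vision tasks. ViT splits an image into a sequence of small patches and feeds them to a Transformer encoder as a sequence of tokens, which is how it extracts image features and performs classification.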

![alt text](../_img/vit.gif)

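As the figure shows, `ViT` splits the input image into a sequence of 2D patches (16x16 pixels each) and uses a linear projection layer to turn each `Patch` into a fixed-length feature vector, the counterpart of word embeddings in NLP. Each `Patch` also has a position index, which an `Embedding` layer maps to a position vector. The `Patch` vectors and the position vectors are summed to form the input to the `Transformer Encoder`, much as in `BERT`.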

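## Core components of ViT

The core components of ViT are **patch embeddings**, **position embeddings**, the **classification token**, the **linear projection of flattened patches**, and the **Transformer encoder**. Together they let ViT process image data effectively and perform well on a range of vision tasks.

ViT also uses a trainable CLS token to obtain a representation of the whole image, followed by a fully connected layer for the downstream classification task. When pre-trained on large amounts of data and then transferred to mid-sized or small image recognition benchmarks (ImageNet, CIFAR-100, VTAB, etc.), ViT achieves results comparable to or better than CNN-based models while requiring substantially fewer computational resources to train. The idea behind ViT is not complicated, and many people could have come up with it, so why did effective work not appear sooner? The key to ViT's success is pre-training on large amounts of data. Without that step, when trained directly on smaller public datasets, ViT still falls short of CNNs, which have stronger inductive biases.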


### Patch Embeddings
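When a Transformer processes text, its Embedding layer encodes each token as a vector. For example, with a context length of 4096 and 512-dimensional token vectors, the embedded result is a matrix of shape `[4096,512]`. The `Patch Embedding` in `ViT` follows the same idea: an image is split into `N` patches, each `Patch` is treated as a `token`, and each is embedded as a vector.

For a 2D image $x \in \mathbb{R}^{H \times W \times C}$, where $C$ is the number of channels (3 for RGB), the image is split into $N = HW / P^2$ patches, each of size $P \times P \times C$. The resulting `Patch` sequence has shape $x_p \in \mathbb{R}^{N \times (P^2 \cdot C)}$.

Each `Patch` is then flattened, and a trainable fully connected layer maps the dimension $P^2 \cdot C$ to $D$. The result is a matrix of shape `[N,D]`, the same form as the output of a text Embedding layer, so it can be fed into the Transformer in the same way.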
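
The split, flatten, and project steps described above can be sketched directly with tensor reshapes. This is a minimal illustration, not the library implementation: the helper name `patchify` and the choice `D = 512` are assumptions made for the example.

```python
import torch
from torch import nn


def patchify(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Split [B, C, H, W] images into flattened patches of shape [B, N, P*P*C]."""
    B, C, H, W = images.shape
    P = patch_size
    assert H % P == 0 and W % P == 0, "image size must be divisible by patch size"
    # [B, C, H/P, P, W/P, P]: split height and width into (grid, within-patch) axes
    x = images.reshape(B, C, H // P, P, W // P, P)
    # [B, H/P, W/P, P, P, C]: bring the pixels of each patch together
    x = x.permute(0, 2, 4, 3, 5, 1)
    # [B, N, P*P*C] with N = (H/P) * (W/P)
    return x.reshape(B, (H // P) * (W // P), P * P * C)


images = torch.randn(2, 3, 224, 224)
patches = patchify(images, patch_size=16)   # [2, 196, 768]
projection = nn.Linear(16 * 16 * 3, 512)    # P^2*C -> D, with D = 512 chosen for illustration
tokens = projection(patches)                # [2, 196, 512]
print(patches.shape, tokens.shape)
```

With a 224x224 image and 16x16 patches, N = 14 x 14 = 196 and each flattened patch holds 16 x 16 x 3 = 768 values.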

Here is the code for this part (the `PatchEmbed` layer from timm):

In [None]:
import logging
import math
from typing import Callable, List, Optional, Tuple, Union

import torch
from torch import nn as nn
import torch.nn.functional as F

from .format import Format, nchw_to
from .helpers import to_2tuple
from .trace_utils import _assert

_logger = logging.getLogger(__name__)


class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    output_fmt: Format
    dynamic_img_pad: torch.jit.Final[bool]

    def __init__(
            self,
            img_size: Optional[int] = 224,
            patch_size: int = 16,
            in_chans: int = 3,
            embed_dim: int = 768,
            norm_layer: Optional[Callable] = None,
            flatten: bool = True,
            output_fmt: Optional[str] = None,
            bias: bool = True,
            strict_img_size: bool = True,
            dynamic_img_pad: bool = False,
    ):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)

        if output_fmt is not None:
            self.flatten = False
            self.output_fmt = Format(output_fmt)
        else:
            # flatten spatial dim and transpose to channels last, kept for bwd compat
            self.flatten = flatten
            self.output_fmt = Format.NCHW
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def _init_img_size(self, img_size: Union[int, Tuple[int, int]]):
        assert self.patch_size
        if img_size is None:
            return None, None, None
        img_size = to_2tuple(img_size)
        grid_size = tuple([s // p for s, p in zip(img_size, self.patch_size)])
        num_patches = grid_size[0] * grid_size[1]
        return img_size, grid_size, num_patches

    def set_input_size(
            self,
            img_size: Optional[Union[int, Tuple[int, int]]] = None,
            patch_size: Optional[Union[int, Tuple[int, int]]] = None,
    ):
        new_patch_size = None
        if patch_size is not None:
            new_patch_size = to_2tuple(patch_size)
        if new_patch_size is not None and new_patch_size != self.patch_size:
            with torch.no_grad():
                new_proj = nn.Conv2d(
                    self.proj.in_channels,
                    self.proj.out_channels,
                    kernel_size=new_patch_size,
                    stride=new_patch_size,
                    bias=self.proj.bias is not None,
                )
                # NOTE: resample_patch_embed is a helper defined alongside PatchEmbed in timm (not shown here)
                new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True))
                if self.proj.bias is not None:
                    new_proj.bias.copy_(self.proj.bias)
                self.proj = new_proj
            self.patch_size = new_patch_size
        img_size = img_size or self.img_size
        if img_size != self.img_size or new_patch_size is not None:
            self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)

    def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
        if as_scalar:
            return max(self.patch_size)
        else:
            return self.patch_size

    def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
        """ Get grid (feature) size for given image size taking account of dynamic padding.
        NOTE: must be torchscript compatible so using fixed tuple indexing
        """
        if self.dynamic_img_pad:
            return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1])
        else:
            return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]

    def forward(self, x):
        B, C, H, W = x.shape
        if self.img_size is not None:
            if self.strict_img_size:
                _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
                _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
            elif not self.dynamic_img_pad:
                _assert(
                    H % self.patch_size[0] == 0,
                    f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
                )
                _assert(
                    W % self.patch_size[1] == 0,
                    f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
                )
        if self.dynamic_img_pad:
            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
            x = F.pad(x, (0, pad_w, 0, pad_h))
        x = self.proj(x)
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        elif self.output_fmt != Format.NCHW:
            x = nchw_to(x, self.output_fmt)
        x = self.norm(x)
        return x
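
The `PatchEmbed` class above uses an `nn.Conv2d` whose kernel size and stride both equal the patch size, rather than an explicit flatten followed by a linear layer. The sketch below checks numerically that the two are the same operation once the convolution weights are copied into a linear layer. The variable names are illustrative only.

```python
import torch
from torch import nn

torch.manual_seed(0)
P, C, D = 16, 3, 768
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
linear = nn.Linear(C * P * P, D)

# Copy the conv weights into the linear layer: [D, C, P, P] -> [D, C*P*P]
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(D, -1))
    linear.bias.copy_(conv.bias)

x = torch.randn(2, C, 224, 224)
out_conv = conv(x).flatten(2).transpose(1, 2)   # [2, 196, 768]

# Flatten each patch in (C, P, P) order so it matches the conv weight layout
B, _, H, W = x.shape
patches = x.reshape(B, C, H // P, P, W // P, P).permute(0, 2, 4, 1, 3, 5)
patches = patches.reshape(B, -1, C * P * P)     # [2, 196, 768]
out_linear = linear(patches)

max_diff = (out_conv - out_linear).abs().max().item()
print(out_conv.shape, max_diff)
```

The maximum difference should be at the level of floating-point rounding, which is why the convolution is a convenient way to implement the linear projection of flattened patches.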

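## Learnable cls token
Like the `cls` token in BERT, ViT prepends a learnable token embedding to the `Patch` sequence: $z_0^0 = x_{class}$. After this token passes through the `Vision Transformer` encoder, its output $z_L^0$ serves as the image representation $y$, which is used for classification and other downstream tasks.

In both pre-training and fine-tuning, a `classification head` is attached to $z_L^0$ for image classification. During pre-training the `classification head` is an `MLP` with one hidden layer; during fine-tuning it is a single linear layer.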
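
The role of the class token can be sketched in a few lines. This is a simplified illustration under stated assumptions: `nn.Identity` stands in for the Transformer encoder, the sizes are arbitrary example values, and the head is the fine-tuning variant (a single linear layer) applied after a LayerNorm.

```python
import torch
from torch import nn

B, N, D, num_classes = 4, 196, 768, 1000

cls_token = nn.Parameter(torch.zeros(1, 1, D))   # learnable z_0^0 = x_class
patch_tokens = torch.randn(B, N, D)              # stand-in for the patch embeddings

# Prepend the same class token to every sequence in the batch
z0 = torch.cat([cls_token.expand(B, -1, -1), patch_tokens], dim=1)   # [B, N+1, D]

encoder = nn.Identity()          # placeholder for the Transformer encoder
zL = encoder(z0)

norm = nn.LayerNorm(D)
head = nn.Linear(D, num_classes) # fine-tuning style head: one linear layer
logits = head(norm(zL[:, 0]))    # only the class-token output z_L^0 is used
print(z0.shape, logits.shape)
```

Only index 0 of the output sequence reaches the head; the patch outputs influence the prediction through attention inside the encoder.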


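## Position Embeddings
**Position embeddings $E_{pos} \in \mathbb{R}^{(N+1) \times D}$ are added to the patch embeddings to retain the spatial position of each input patch.** Unlike a CNN, a Transformer needs position embeddings to encode where each patch token sits. This follows from the **permutation invariance** of **self-attention** (strictly speaking, permutation equivariance): **shuffling the tokens in a sequence only shuffles the outputs in the same way, so the result computed for each token does not depend on the order.**

Without position information, the model would have to learn to solve a jigsaw puzzle from patch semantics alone, which adds learning cost. The ViT paper compares several position encoding schemes:

> 1.  **No position embedding**
> 2.  **1-D position embedding (1D-PE)**: treats the 2-D patches as a 1-D sequence
> 3.  **2-D position embedding (2D-PE)**: uses the 2-D position (x, y) of each patch
> 4.  **Relative position embedding (RPE)**: uses the relative positions of patches

The finding is that **omitting position embeddings hurts performance**, while the different types of encoding perform about the same. The main reason is that ViT operates on relatively large patches rather than pixels, so position information is much easier to learn.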
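
The permutation property of self-attention mentioned above can be checked numerically. The sketch below uses `nn.MultiheadAttention` as a stand-in for ViT's attention layer (an assumption for illustration) and reverses the token order as the permutation.

```python
import torch
from torch import nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
attn.eval()

x = torch.randn(1, 10, 64)            # 10 tokens without any position information
perm = torch.arange(9, -1, -1)        # a fixed permutation: reverse the token order
pos = torch.randn(1, 10, 64)          # a random "position embedding" table

with torch.no_grad():
    out, _ = attn(x, x, x)
    out_perm, _ = attn(x[:, perm], x[:, perm], x[:, perm])

    # With position embeddings, the same content in a different order gives different outputs
    xp, xp_perm = x + pos, x[:, perm] + pos
    out_pos, _ = attn(xp, xp, xp)
    out_pos_perm, _ = attn(xp_perm, xp_perm, xp_perm)

same_without_pos = torch.allclose(out[:, perm], out_perm, atol=1e-5)
same_with_pos = torch.allclose(out_pos[:, perm], out_pos_perm, atol=1e-5)
print(same_without_pos, same_with_pos)
```

Without position embeddings, permuting the inputs merely permutes the outputs. Once position embeddings are added, the layer can distinguish the two orderings.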

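The original Transformer paper uses **fixed position encodings** by default, whereas ViT uses **standard learnable 1-D position embeddings**, because more advanced 2-D-aware position embeddings (Appendix D.4) were not observed to bring significant gains (although many later ViT variants do use 2-D position embeddings). The position embeddings are added directly to the patch embeddings before the sequence enters the Transformer encoder.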


The `cls token` and `position embedding` are implemented in the `_pos_embed` method of the `VisionTransformer` class:

In [None]:

def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
    if self.pos_embed is None:
        return x.view(x.shape[0], -1, x.shape[-1])

    if self.dynamic_img_size:
        B, H, W, C = x.shape
        prev_grid_size = self.patch_embed.grid_size
        pos_embed = resample_abs_pos_embed(
            self.pos_embed,
            new_size=(H, W),
            old_size=prev_grid_size,
            num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
        )
        x = x.view(B, -1, C)
    else:
        pos_embed = self.pos_embed

    to_cat = []
    if self.cls_token is not None:
        to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) # prepend the cls token
    if self.reg_token is not None:
        to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))

    if self.no_embed_class:
        # deit-3, updated JAX (big vision)
        # position embedding does not overlap with class token, add then concat
        x = x + pos_embed
        if to_cat:
            x = torch.cat(to_cat + [x], dim=1)
    else:
        # original timm, JAX, and deit vit impl
        # pos_embed has entry for class token, concat then add
        if to_cat:
            x = torch.cat(to_cat + [x], dim=1)
        x = x + pos_embed

    return self.pos_drop(x)
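
The two branches at the end of `_pos_embed` differ only in whether the position table covers the class token. The following sketch reproduces both orderings with plain tensors; the sizes are example values and the variable names are illustrative.

```python
import torch

B, N, D = 2, 196, 768
patches = torch.randn(B, N, D)
cls_token = torch.zeros(1, 1, D)

# Variant 1 (original ViT / timm default): concat then add, position table has N + 1 rows
pos_embed_full = torch.randn(1, N + 1, D)
x1 = torch.cat([cls_token.expand(B, -1, -1), patches], dim=1) + pos_embed_full

# Variant 2 (no_embed_class=True): add then concat, position table has N rows
pos_embed_patches = torch.randn(1, N, D)
x2 = torch.cat([cls_token.expand(B, -1, -1), patches + pos_embed_patches], dim=1)

print(x1.shape, x2.shape)
```

Both variants produce a sequence of length N + 1. In the second, the class token enters the encoder without any position embedding added to it.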

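## Transformer encoder

ViT's Transformer encoder is a stack of layers, each consisting of multi-head self-attention (MSA) and a fully connected feed-forward network (MLP block). These layers process the patch embeddings and extract the features that are used for classification.

Implementation of the attention mechanism: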

In [None]:
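import torch
import torch.nn.functional as F
from torch import nn
from torch.jit import Final

from timm.layers import use_fused_attn  # timm helper that reports whether fused attention can be used


class Attention(nn.Module):
    fused_attn: Final[bool]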

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
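
The two branches in `forward` compute the same quantity. The sketch below compares the manual softmax path with `F.scaled_dot_product_attention` on random tensors. It assumes PyTorch 2.0 or newer, where that function is available; the tensor sizes are example values.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N, d = 2, 12, 197, 64          # batch, heads, tokens, head_dim
q, k, v = (torch.randn(B, H, N, d) for _ in range(3))

# Manual path, mirroring the non-fused branch
attn = (q * d ** -0.5) @ k.transpose(-2, -1)    # [B, H, N, N]
attn = attn.softmax(dim=-1)
out_manual = attn @ v                            # [B, H, N, d]

# Fused path (PyTorch >= 2.0); the default scale is 1 / sqrt(head_dim)
out_fused = F.scaled_dot_product_attention(q, k, v)

matches = torch.allclose(out_manual, out_fused, atol=1e-4)
print(out_manual.shape, matches)
```

The fused kernel avoids materializing the full `[B, H, N, N]` attention matrix where the backend supports it, which is the reason for the `fused_attn` switch.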


A Transformer encoder block is implemented as follows:

In [None]:
# Transformer Encoder Block
from timm.layers import DropPath, Mlp  # stochastic depth and MLP helpers from timm


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()

        # Layer Norm applied before MHA (pre-norm)
        self.norm1 = norm_layer(dim)
        # MHA. NOTE: qk_scale stays in the signature for backward compatibility but is not forwarded,
        # because the Attention class above has no such argument and always uses head_dim ** -0.5.
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias,
                            attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # Layer Norm applied before the MLP (pre-norm)
        self.norm2 = norm_layer(dim)
        # hidden layer dimension
        mlp_hidden_dim = int(dim * mlp_ratio)
        # MLP
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        # Layer Norm -> MHA -> residual add
        x = x + self.drop_path(self.attn(self.norm1(x)))
        # Layer Norm -> MLP -> residual add
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
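
To try the pre-norm structure without installing timm, the block can be approximated with standard `torch.nn` layers. `PreNormBlock` below is a hypothetical stand-in written for this example, not the timm class: it omits dropout and stochastic depth and uses `nn.MultiheadAttention` in place of the `Attention` class above.

```python
import torch
from torch import nn


class PreNormBlock(nn.Module):
    """Simplified pre-norm encoder block built only from torch.nn layers."""

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)                                   # normalize before attention
        x = x + self.attn(h, h, h, need_weights=False)[0]   # residual connection
        x = x + self.mlp(self.norm2(x))                     # normalize, MLP, residual
        return x


block = PreNormBlock(dim=768, num_heads=12)
x = torch.randn(2, 197, 768)
y = block(x)
print(y.shape)
```

The block preserves the input shape, which is what allows an arbitrary number of such blocks to be stacked.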

### Example: tensor shape changes in ViT

<p align="center">
 <img src="../_img/vit_c_exam.png"/>
</p>

> 1.  The input images have shape (b = b, c = 3, h = 256, w = 256).
> 2.  The input images are split and flattened into patches: the batch size is still b, with c = **3** channels, patch size P = **32**, and N = (256×256) / (32×32) = **64** patches, each holding P²c = 32×32×3 = **3072** values.
> 3.  The patches are fed into the linear projection layer, giving patch embeddings of length N = **64** and dimension D = **1024** (D is a chosen hyperparameter, not something derived from the patch size).
> 4.  The position embeddings are added element-wise to the patch embeddings; the size remains N×D = **64×1024**.
> 5.  A learnable **1×1024** embedding used to predict the class is then concatenated along the length dimension, giving a full embedding of size **65×1024** (length N+1 = 64+1 = **65**). This add-then-concatenate order matches the `no_embed_class` branch of `_pos_embed` above.
> 6.  The full embedding passes through the encoder, and the output keeps the size (N+1)×D = **65×1024**.
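
The six steps above can be traced with real tensors. This is a sketch under assumptions: batch size 2, random position embeddings, and a single `nn.TransformerEncoderLayer` with 8 heads standing in for the full encoder.

```python
import torch
from torch import nn

b, c, h, w = 2, 3, 256, 256          # step 1: input shape
P, D = 32, 1024

images = torch.randn(b, c, h, w)
N = (h // P) * (w // P)              # 64 patches

# step 2: split into patches and flatten -> [b, 64, 3072]
patches = images.reshape(b, c, h // P, P, w // P, P).permute(0, 2, 4, 3, 5, 1)
patches = patches.reshape(b, N, P * P * c)

tokens = nn.Linear(P * P * c, D)(patches)          # step 3: [b, 64, 1024]
tokens = tokens + torch.randn(1, N, D)             # step 4: [b, 64, 1024]
cls = torch.zeros(1, 1, D).expand(b, -1, -1)
tokens = torch.cat([cls, tokens], dim=1)           # step 5: [b, 65, 1024]

# step 6: one encoder layer as a stand-in for the whole encoder
layer = nn.TransformerEncoderLayer(d_model=D, nhead=8, batch_first=True, norm_first=True)
out = layer(tokens)
print(patches.shape, tokens.shape, out.shape)
```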


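# Complete code example
The following is based mainly on the `transformers.models.vit.modeling_vit` code in Hugging Face Transformers and is used to test the individual modules of ViT. Source: [Source code for transformers.models.vit.modeling_vit](https://huggingface.co/transformers/v4.11.3/_modules/transformers/models/vit/modeling_vit.html)\
GitHub: [transformers.models.vit](https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py)

It is worth understanding every line. Beyond the principle behind the code, the boundary checks in the many `if` statements matter too; they are the difference between code that works and code that is good.

## Imports and basic configuration
So that everything runs directly in this notebook, the functions needed from `utils` and `modeling_outputs` are reproduced below, so running this notebook alone is enough.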

In [None]:
import collections.abc
import logging
import math
from typing import Optional

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

# from ...activations import ACT2FN
# from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
# from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
# from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
# from ...utils import logging
# from .configuration_vit import ViTConfig


# The transformers logging wrapper is commented out above, so use the standard library logger instead.
logger = logging.getLogger(__name__)

_CONFIG_FOR_DOC = "ViTConfig"
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224"

VIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/vit-base-patch16-224",
    # See all ViT models at https://huggingface.co/models?filter=vit
]


# Inspired by
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
# From PyTorch internals
def to_2tuple(x):
    if isinstance(x, collections.abc.Iterable):
        return x
    return (x, x)

In [None]:
class ViTConfig():
    model_type = "vit"

    def __init__(
        self,
        hidden_size=768, # Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers=12, # Number of hidden layers in the Transformer encoder.
        num_attention_heads=12, # Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size=3072, # The size of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act="gelu", # The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,             `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob=0.0, # The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob=0.0, # The dropout ratio for the attention probabilities.
        initializer_range=0.02, # The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps=1e-12, # The epsilon used by the layer normalization layers.
        image_size=224, # The size of each image (resolution).
        patch_size=16, # The size of the patch to be extracted from the image.
        num_channels=3, # The number of channels in the input images.
        qkv_bias=True, # Whether to include bias parameters in the Q, K, V projections.
        encoder_stride=16, # The stride to use when extracting image patches.
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.qkv_bias = qkv_bias
        self.encoder_stride = encoder_stride

vit_config = ViTConfig()
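
The default values above imply a few derived quantities that show up repeatedly in the shapes later on. The short calculation below spells them out; it copies the defaults into local variables (named here only for illustration) so that it runs on its own.

```python
# Defaults copied from the configuration above
hidden_size, num_attention_heads, intermediate_size = 768, 12, 3072
image_size, patch_size = 224, 16

head_dim = hidden_size // num_attention_heads     # per-head dimension
num_patches = (image_size // patch_size) ** 2     # patches per image
seq_len = num_patches + 1                         # plus one class token
mlp_ratio = intermediate_size / hidden_size       # feed-forward expansion factor

print(head_dim, num_patches, seq_len, mlp_ratio)
```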

# ViT embedding
First comes the implementation of the `Patch Embedding`, followed by the implementation that adds the `cls token` and the `position embedding`.

In [None]:
class ViTPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    input:(batch_size, num_channels, height, width)
    output:(batch_size, seq_length, hidden_size)
    """
    
    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        # channel -> hidden size
        # a convolution with kernel_size == stride == patch_size maps each patch to a hidden_size vector
        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) 

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels: # the input channel count must match the configuration
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )
        if not interpolate_pos_encoding: # without interpolation
            if height != self.image_size[0] or width != self.image_size[1]: # the image size must match the configuration
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model"
                    f" ({self.image_size[0]}*{self.image_size[1]})."
                )
                
        # embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        # Shape comments assume the default config with a batch of 8: 224x224 input, patch 16, hidden_size 768
        print(f"pixel_values:{pixel_values.shape}")
        projection = self.projection(pixel_values) # 8, 768, 14, 14
        print(f"projection:{projection.shape}")
        flatten = projection.flatten(2)            # 8, 768, 196
        print(f"flatten:{flatten.shape}")
        embeddings = flatten.transpose(1, 2)       # 8, 196, 768
        print(f"embeddings:{embeddings.shape}")
        return embeddings


A quick test with input `[batch_size=8, num_channels=3, height=224, width=224]`; watch how the printed shapes change during `forward`.

In [None]:
test_input = torch.randn(8, 3, 224, 224) # [batch_size, num_channels, image_size, image_size]
test_input.shape
ViT_Patch_Embedding = ViTPatchEmbeddings(vit_config)
test_output = ViT_Patch_Embedding(test_input)

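The code below is the complete embedding process: the input first goes through `Patch_embedding`, then the `cls token` is prepended, then the `position_embedding` is added, and finally `dropout` is applied before the output.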

In [None]:
class ViTEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
    """

    def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
        super().__init__()

        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) # shape: [1, 1, hidden_size]
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
        self.patch_embeddings = ViTPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.patch_size = config.patch_size
        self.config = config

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
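        This method interpolates the pre-trained position encodings so that the model can be used on
        higher-resolution images. It has also been adapted to support torch.jit tracing.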
        """

        num_patches = embeddings.shape[1] - 1 # number of patches (excluding the CLS token)
        num_positions = self.position_embeddings.shape[1] - 1 # number of patch position embeddings

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

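        # plain int() is used here; the transformers source calls its tracing-aware `torch_int` helper,
        # which is not defined in this notebook
        sqrt_num_positions = int(num_positions**0.5)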
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        if bool_masked_pos is not None:
            seq_length = embeddings.shape[1]
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        if interpolate_pos_encoding: # interpolate the position encoding
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings) # dropout

        return embeddings

Now test the embedding output:

In [None]:
test_embed_config = ViTConfig()
ViT_embedding = ViTEmbeddings(test_embed_config, use_mask_token=False)
print(ViT_embedding)

test_input = torch.randn(8, 3, 224, 224)

x_cls_embd = ViT_embedding(test_input)
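
The test above uses the training resolution, so the `interpolate_pos_encoding` path is never exercised. The standalone sketch below shows the same idea in isolation: resizing a position table learned for a 14x14 patch grid (224 / 16) to a 24x24 grid (384 / 16). The function name `resize_pos_embed` is hypothetical, and the table is random rather than pre-trained.

```python
import torch
import torch.nn.functional as F


def resize_pos_embed(pos_embed: torch.Tensor, new_grid: tuple) -> torch.Tensor:
    """Resize a [1, 1 + g*g, D] position table (class slot first) to a new patch grid."""
    cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    dim = pos_embed.shape[-1]
    old = int(patch_pos.shape[1] ** 0.5)
    assert old * old == patch_pos.shape[1], "expects a square source grid"
    # [1, g*g, D] -> [1, D, g, g] so the table can be treated as an image
    patch_pos = patch_pos.reshape(1, old, old, dim).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(patch_pos, size=new_grid, mode="bicubic", align_corners=False)
    # back to [1, new_h * new_w, D]
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, -1, dim)
    # the class-token position is kept unchanged
    return torch.cat([cls_pos, patch_pos], dim=1)


pos_embed = torch.randn(1, 1 + 14 * 14, 768)
resized = resize_pos_embed(pos_embed, (24, 24))
print(resized.shape)
```

Only the patch positions are interpolated; the class-token entry is copied through, since it has no spatial location.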

# Attention mechanism
The theory is not covered in detail here; see the figures:
<p align="center">
 <img src="./_img/sdpa.png" width="49%"/>
 <img src="./_img/mha.png" width="49%"/>
</p>
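
For reference, the two figures correspond to the standard formulation of scaled dot-product attention and multi-head attention from the original Transformer paper, where $d_k$ is the per-head key dimension and $h$ is the number of heads:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
$$

$$
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^{O}, \qquad \mathrm{head}_i = \mathrm{Attention}(Q W_i^{Q}, K W_i^{K}, V W_i^{V})
$$

In the ViT encoder this is self-attention, so $Q$, $K$, and $V$ are all projections of the same token sequence.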
