# Import Necessary Libraries

In [1]:
import torch
import torch.nn as nn


# Transform Images into Embeddings

<center>
<img src="res/ViT-arch-embed.png" alt="Image description" style="width:60%; height:60%;"/>
</center>


In order to feed input images to a Transformer model, we need to convert the images to a sequence of vectors. This is done by splitting the image into a grid of non-overlapping patches, which are then linearly projected to obtain a fixed-size embedding vector for each patch. We can use PyTorch’s nn.Conv2d layer for this purpose:

In [2]:
class PatchEmbeddings(nn.Module):
    """
    Convert the image into patches and then project them into a vector space.
    """
    def __init__(self, config):
        super().__init__()
        self.image_size = config["image_size"]
        self.patch_size = config["patch_size"]
        self.num_channels = config["num_channels"]
        self.hidden_size = config["hidden_size"]

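        # Number of patches in the image = (image size // patch size) ^ 2.
        # For example, a 512x512 image with 8x8 patches gives 64^2 = 4096 patches.
        self.num_patches = (self.image_size // self.patch_size) ** 2

        # Convolution that turns the image into patch embeddings: the kernel is
        # patch_size x patch_size and the stride is patch_size.
        # For an input of shape (1, 3, 512, 512), it outputs (1, hidden_size, 64, 64);
        # forward() then reshapes this to (1, 4096, hidden_size).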
        self.projection = nn.Conv2d(self.num_channels, 
                                    self.hidden_size, 
                                    kernel_size=self.patch_size, 
                                    stride=self.patch_size)

    def forward(self, x):
        """
        :param x: Input image of shape (batch_size, num_channels, image_size, image_size)
        :return: Projected patches of shape (batch_size, num_patches, hidden_size)
        """
        x = self.projection(x)
        x = x.flatten(2).transpose(1, 2)
        return x    

Setting kernel_size=self.patch_size and stride=self.patch_size ensures that the layer’s filters are applied to non-overlapping patches.
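To see why a strided convolution is the same thing as "cut into patches, then apply a linear projection", the sketch below compares the two on a random image. The sizes (a 32x32 image, 4x4 patches, hidden size 48) and all variable names are illustrative assumptions, not part of the model code above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions for this sketch)
patch_size, num_channels, hidden_size = 4, 3, 48
image = torch.randn(2, num_channels, 32, 32)

# Path 1: strided convolution, as in PatchEmbeddings
conv = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
conv_tokens = conv(image).flatten(2).transpose(1, 2)  # (2, 64, 48)

# Path 2: extract non-overlapping patches explicitly, then project them
# with the convolution's weights viewed as a plain linear layer
patches = nn.functional.unfold(image, kernel_size=patch_size, stride=patch_size)
patches = patches.transpose(1, 2)              # (2, 64, 3 * 4 * 4)
weight = conv.weight.reshape(hidden_size, -1)  # (48, 3 * 4 * 4)
manual_tokens = patches @ weight.T + conv.bias

print(conv_tokens.shape)
print(torch.allclose(conv_tokens, manual_tokens, atol=1e-4))
```

Both paths should agree up to floating-point error, which is why the convolution can stand in for an explicit patch-extraction step.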

After the patches are converted to a sequence of embeddings, the [CLS] token is added to the beginning of the sequence. It will be used later in the classification layer to classify the image. The [CLS] token’s embedding is learned during training.

As patches from different positions may contribute differently to the final predictions, we also need a way to encode patch positions into the sequence. We’re going to use learnable position embeddings to add positional information to the embeddings. This is similar to how position embeddings are used in Transformer models for NLP tasks.

In [4]:
class Embeddings(nn.Module):
    """
    Combine the patch embeddings with the class token and position embeddings.
    """
        
    def __init__(self, config):
        super().__init__()

        # Keep the config for later use
        self.config = config

        # patch_embeddings: PatchEmbeddings module that converts the image into patch embeddings.
        self.patch_embeddings = PatchEmbeddings(config)

        # cls_token: a learnable [CLS] token, similar to the [CLS] token used in BERT.
        # It is prepended to the input sequence and used to classify the whole sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, config["hidden_size"]))

        # position_embeddings: learnable position embeddings for the [CLS] token and the patch embeddings.
        # The sequence length is num_patches + 1 because of the [CLS] token.
        self.position_embeddings = nn.Parameter(torch.randn(1, self.patch_embeddings.num_patches + 1, config["hidden_size"]))

        # Dropout to reduce overfitting
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x):
        """
        :param x: Input image of shape (batch_size, num_channels, image_size, image_size)
        :return: Projected patches of shape (batch_size, num_patches + 1, hidden_size)
        """

        # Convert the image into patch embeddings
        x = self.patch_embeddings(x)

        # Get the batch size
        batch_size, _, _ = x.size()

        # cls_tokens: expand the [CLS] token to match the batch size.
        # (1, 1, hidden_size) -> (batch_size, 1, hidden_size)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)

        # Concatenate: prepend the [CLS] token to the input sequence,
        # so the sequence length becomes num_patches + 1.
        y = torch.cat((cls_tokens, x), dim=1)

        # Position embeddings: add them to the input sequence.
        y = y + self.position_embeddings

        # Dropout: apply the dropout layer.
        y = self.dropout(y)

        return y
    

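    """Additional notes:

    This class uses nn.Parameter, which wraps a tensor as a learnable parameter. Unlike a tensor created with
    torch.tensor, it is automatically registered as a parameter of the module, takes part in gradient computation,
    and is updated by the optimizer. Here, nn.Parameter creates two parameters: cls_token for the [CLS] token
    and position_embeddings for the position embeddings.

    Usage:
        * nn.Parameter: defines learnable parameters that need to be optimized, such as the weights and biases of a layer.
        * torch.tensor: defines ordinary tensors, typically used to store data, temporary variables, and so on.
    """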
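The difference between nn.Parameter and a plain tensor can be checked directly. The `Demo` module below is a hypothetical example written only for this illustration; it is not part of the ViT code.

```python
import torch
import torch.nn as nn


class Demo(nn.Module):
    """Hypothetical module used only to contrast the two kinds of attributes."""

    def __init__(self):
        super().__init__()
        # Registered as a module parameter: visible to optimizers
        self.learnable = nn.Parameter(torch.randn(1, 1, 4))
        # Plain tensor attribute: not registered, not trained
        self.plain = torch.randn(1, 1, 4)


demo = Demo()
param_names = [name for name, _ in demo.named_parameters()]
print(param_names)                   # ['learnable']
print(demo.learnable.requires_grad)  # True
print(demo.plain.requires_grad)      # False
```

Only `learnable` shows up in `named_parameters()`, so only it would be passed to an optimizer built from `demo.parameters()`.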

At this point, the input image has been converted to a sequence of embeddings with positional information and is ready to be fed into the transformer layers.
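As a quick shape check, the sketch below replays the [CLS] and position-embedding steps on random tensors. The sizes (64 patches, hidden size 48, batch of 2) are assumptions chosen for illustration, and `patch_tokens` is a stand-in for the output of the patch embedding layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions for this sketch)
batch_size, num_patches, hidden_size = 2, 64, 48

# Stand-in for the output of the patch embedding layer
patch_tokens = torch.randn(batch_size, num_patches, hidden_size)

cls_token = nn.Parameter(torch.randn(1, 1, hidden_size))
position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, hidden_size))

# The same learned [CLS] vector is shared by every image in the batch
cls_tokens = cls_token.expand(batch_size, -1, -1)
with_cls = torch.cat((cls_tokens, patch_tokens), dim=1)  # (2, 65, 48)

# Position embeddings have batch dimension 1 and broadcast over the batch
sequence = with_cls + position_embeddings                # (2, 65, 48)

print(with_cls.shape, sequence.shape)
```

The sequence grows from 64 to 65 tokens, and index 0 of every sequence holds the [CLS] token.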

# Multi-head Attention

<center>
<img src="res/ViT-arch-attention.png" alt="Image description" style="width:50%;"/>
</center>

Before going through the transformer encoder, we first explore the multi-head attention module, which is its core component. Multi-head attention computes the interactions between different patches in the input image. It consists of multiple attention heads, each of which is a single attention layer.

Let’s implement a head of the multi-head attention module. The module takes a sequence of embeddings as input and computes query, key, and value vectors for each embedding. The query and key vectors are then used to compute the attention weights for each token. The attention weights are then used to compute new embeddings using a weighted sum of the value vectors. We can think of this mechanism as a soft version of a database query, where the query vectors find the most relevant key vectors in the database, and the value vectors are retrieved to compute the query output.
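Before the full module, the computation described above can be sketched on raw tensors. The sizes and the names `query_proj`, `key_proj`, and `value_proj` are assumptions made for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions for this sketch)
batch_size, seq_len, hidden_size, head_size = 2, 5, 16, 4
x = torch.randn(batch_size, seq_len, hidden_size)

# One linear projection each for query, key, and value
query_proj = nn.Linear(hidden_size, head_size)
key_proj = nn.Linear(hidden_size, head_size)
value_proj = nn.Linear(hidden_size, head_size)
q, k, v = query_proj(x), key_proj(x), value_proj(x)

# Similarity of every query with every key, scaled by sqrt(head_size)
scores = q @ k.transpose(-1, -2) / head_size ** 0.5  # (2, 5, 5)

# Each row becomes a probability distribution over the tokens
weights = torch.softmax(scores, dim=-1)

# New embedding of each token = weighted sum of the value vectors
output = weights @ v                                 # (2, 5, 4)

print(weights.sum(dim=-1))  # every row sums to 1
print(output.shape)
```

Each row of `weights` plays the role of the soft database lookup: it says how much of every value vector goes into that token's new embedding.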

In [6]:
class AttentionHead(nn.Module):
    
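    """
    A single attention head, as used in multi-head attention.
    Attention is a core component of many modern deep learning models, Transformers in particular.

    Parameters:

    * hidden_size: hidden dimension of the input sequence.
    * attention_head_size: output dimension of each attention head.
    * dropout: dropout probability applied to the attention probabilities to reduce overfitting.
    * bias: whether the linear layers use a bias; defaults to True.

    Components:
    * query, key, value: three linear layers that project the input into the query, key, and value spaces.
    * dropout: dropout layer applied after the attention probabilities are computed.
    """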

    def __init__(self, hidden_size, attention_head_size, dropout, bias=True):
        super().__init__()

        # hidden_size: hidden dimension of the input sequence
        self.hidden_size = hidden_size

        # attention_head_size: output dimension of each attention head
        self.attention_head_size = attention_head_size

        # query: projects the input into the query space
        self.query = nn.Linear(hidden_size, attention_head_size, bias=bias)

        # key: projects the input into the key space
        self.key = nn.Linear(hidden_size, attention_head_size, bias=bias)

        # value: projects the input into the value space
        self.value = nn.Linear(hidden_size, attention_head_size, bias=bias)

        # dropout: applied after the attention probabilities are computed
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        """
        :param x: Input sequence of shape (batch_size, sequence_length, hidden_size)
        :return: Tuple of the attention output, of shape (batch_size, sequence_length, attention_head_size),
                 and the attention probabilities, of shape (batch_size, sequence_length, sequence_length)
        """

        # Project the input to query, key, and value.
        # The same input produces all three, which is why this is called self-attention.
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)

        # Attention weights: softmax(Q * K.T / sqrt(head_size))
        # 1. The dot product between query and key gives the attention scores;
        #    key.transpose(-1, -2) swaps the last two dimensions of key for the matrix product.
        # 2. The scores are scaled by the square root of attention_head_size.
        # 3. Softmax over the last dimension turns the scores into attention weights.
        # 4. Dropout is applied to the attention weights.
        attention_scores = torch.matmul(query, key.transpose(-1, -2))
        attention_scores = attention_scores / (self.attention_head_size ** 0.5)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        # Attention output: weighted sum of the value vectors, using the attention weights.
        attention_output = torch.matmul(attention_probs, value)

        return (attention_output, attention_probs)

The outputs from all attention heads are then concatenated and linearly projected to obtain the final output of the multi-head attention module.

In [7]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention module.
    
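    Implements the multi-head attention mechanism used in the Transformer encoder.
    Computing several attention heads in parallel lets the model attend to different parts of the
    input sequence and capture richer context.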
    """

    def __init__(self, config):
        super().__init__()
        
        # Read the config
        self.hidden_size = config["hidden_size"]  # Hidden dimension of the input sequence.
        self.num_attention_heads = config["num_attention_heads"]  # Number of attention heads.

        # Compute the size of each attention head
        self.attention_head_size = self.hidden_size // self.num_attention_heads  # Output dimension of each head.
        self.all_head_size = self.num_attention_heads * self.attention_head_size  # Output dimension of all heads combined.

        # Whether the query, key, and value projections use a bias
        self.qkv_bias = config["qkv_bias"]  # Set to True in the config used in this post

        # Build the list of attention heads; each one is an AttentionHead instance.
        self.heads = nn.ModuleList([])
        for _ in range(self.num_attention_heads):
            head = AttentionHead(
                self.hidden_size,
                self.attention_head_size,
                config["attention_probs_dropout_prob"],
                self.qkv_bias
            )
            self.heads.append(head)

        # Linear layer that projects the concatenated head outputs back to the hidden size
        self.output_projection = nn.Linear(self.all_head_size, self.hidden_size)
        # Dropout applied to the projected output
        self.output_dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x, output_attentions=False):

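        # Compute the attention output of each head
        attention_outputs = [head(x) for head in self.heads]

        # Concatenate the outputs of all heads along the feature dimension
        attention_output = torch.cat([attention_output for attention_output, _ in attention_outputs], dim=-1)

        # Project the concatenated output back to the hidden size
        attention_output = self.output_projection(attention_output)
        attention_output = self.output_dropout(attention_output)

        # Return the attention output and, optionally, the attention probabilities
        if not output_attentions:
            return (attention_output, None)
        else:
            attention_probs = torch.stack([attention_probs for _, attention_probs in attention_outputs], dim=1)
            return (attention_output, attention_probs)


    """Additional notes:

        Multi-head attention: computing several heads in parallel lets the model attend to different parts of the input sequence at once and capture more context.
        Concatenation and projection: the head outputs are concatenated and projected back to the original hidden size by a linear layer.
        Optional output: depending on output_attentions, the attention weights are also returned for further analysis or visualization.
    """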
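The shape bookkeeping of the concatenation and stacking steps can be sketched with random stand-ins for the per-head results. The sizes (sequence length 65, hidden size 48, 4 heads) are assumptions for illustration, and `head_outputs` and `head_probs` are placeholders rather than real attention results.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions for this sketch)
batch_size, seq_len, hidden_size, num_heads = 2, 65, 48, 4
head_size = hidden_size // num_heads  # 12

# Random stand-ins for what each attention head would return
head_outputs = [torch.randn(batch_size, seq_len, head_size) for _ in range(num_heads)]
head_probs = [
    torch.softmax(torch.randn(batch_size, seq_len, seq_len), dim=-1)
    for _ in range(num_heads)
]

# Concatenate along the feature dimension: 4 heads x 12 features = 48
concatenated = torch.cat(head_outputs, dim=-1)  # (2, 65, 48)

# Mix information across heads with a final linear projection
output_projection = nn.Linear(num_heads * head_size, hidden_size)
projected = output_projection(concatenated)     # (2, 65, 48)

# Stack the attention maps along a new "head" dimension
stacked_probs = torch.stack(head_probs, dim=1)  # (2, 4, 65, 65)

print(concatenated.shape, projected.shape, stacked_probs.shape)
```

Note that `torch.cat` joins along an existing dimension, while `torch.stack` creates a new one, which is why the attention maps end up with a separate head axis.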

# Transformer Encoder

<center>
<img src="res/ViT-arch-encoder.png" alt="Image description" style="width:50%;"/>
</center>

The transformer encoder is made of a stack of transformer layers. Each transformer layer mainly consists of the multi-head attention module that we just implemented and a feed-forward network. To better scale the model and stabilize training, two layer normalization layers and skip connections are added to the transformer layer.

Let’s implement a transformer layer (referred to as Block in the code as it’s the building block for the transformer encoder). We’ll begin with the feed-forward network, which is a simple two-layer MLP with GELU activation in between.

In [9]:
class NewGELUActivation(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415

    Taken from https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py
    """

    def forward(self, input):
        # 0.7978845608028654 is math.sqrt(2.0 / math.pi), precomputed
        return 0.5 * input * (1.0 + torch.tanh(0.7978845608028654 * (input + 0.044715 * torch.pow(input, 3.0))))


class MLP(nn.Module):
    """
    A multi-layer perceptron module.
    """

    def __init__(self, config):
        super().__init__()
        self.dense_1 = nn.Linear(config["hidden_size"], config["intermediate_size"])
        self.activation = NewGELUActivation()
        self.dense_2 = nn.Linear(config["intermediate_size"], config["hidden_size"])
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x):
        x = self.dense_1(x)
        x = self.activation(x)
        x = self.dense_2(x)
        x = self.dropout(x)
        return x
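The activation above is the tanh approximation of GELU. As a sanity check, the sketch below compares it with the exact erf-based definition on a grid of inputs; the grid range and variable names are assumptions for this illustration.

```python
import math
import torch

# Grid of inputs to compare on (range chosen for illustration)
x = torch.linspace(-5.0, 5.0, steps=1001)

# Tanh approximation, as used in NewGELUActivation
tanh_gelu = 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))

# Exact definition: x * Phi(x), with Phi the standard normal CDF
exact_gelu = 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

max_gap = (tanh_gelu - exact_gelu).abs().max().item()

print(math.sqrt(2.0 / math.pi))  # the hard-coded constant
print(max_gap)                   # largest absolute difference on the grid
```

The gap should be small across the grid, which is why the two forms are used interchangeably in practice.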

Now that we have implemented the multi-head attention and the MLP, we can combine them to create the transformer layer. Layer normalization is applied to the input of each sub-layer, and a skip connection adds the sub-layer’s output back to its input.

In [None]:
class Block(nn.Module):
    """
    A single transformer block.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.layernorm_1 = nn.LayerNorm(config["hidden_size"])
        self.mlp = MLP(config)
        self.layernorm_2 = nn.LayerNorm(config["hidden_size"])

    def forward(self, x, output_attentions=False):
        # Self-attention
        attention_output, attention_probs = \
            self.attention(self.layernorm_1(x), output_attentions=output_attentions)
        # Skip connection
        x = x + attention_output
        # Feed-forward network
        mlp_output = self.mlp(self.layernorm_2(x))
        # Skip connection
        x = x + mlp_output
        # Return the transformer block's output and the attention probabilities (optional)
        if not output_attentions:
            return (x, None)
        else:
            return (x, attention_probs)
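Both halves of the block follow the same pattern, x + sublayer(LayerNorm(x)). The sketch below isolates that pattern in a hypothetical `PreNormResidual` wrapper (a name introduced only for this illustration) and shows that the skip connection passes the input through unchanged when the sub-layer outputs zeros.

```python
import torch
import torch.nn as nn


class PreNormResidual(nn.Module):
    """Hypothetical wrapper: normalize, apply a sub-layer, add the input back."""

    def __init__(self, hidden_size, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.sublayer = sublayer

    def forward(self, x):
        return x + self.sublayer(self.norm(x))


torch.manual_seed(0)
x = torch.randn(2, 65, 48)

# A linear layer stands in for the attention or MLP sub-layer
layer = PreNormResidual(48, nn.Linear(48, 48))
out = layer(x)

# With a sub-layer that outputs zeros, the block reduces to the identity
nn.init.zeros_(layer.sublayer.weight)
nn.init.zeros_(layer.sublayer.bias)
identity_out = layer(x)

print(out.shape)
print(torch.equal(identity_out, x))
```

This identity path is what lets gradients flow through a deep stack of blocks even when individual sub-layers contribute little.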

The transformer encoder stacks multiple transformer layers sequentially:

In [None]:
class Encoder(nn.Module):
    """
    The transformer encoder module.
    """

    def __init__(self, config):
        super().__init__()
        # Create a list of transformer blocks
        self.blocks = nn.ModuleList([])
        for _ in range(config["num_hidden_layers"]):
            block = Block(config)
            self.blocks.append(block)

    def forward(self, x, output_attentions=False):
        # Calculate the transformer block's output for each block
        all_attentions = []
        for block in self.blocks:
            x, attention_probs = block(x, output_attentions=output_attentions)
            if output_attentions:
                all_attentions.append(attention_probs)
        # Return the encoder's output and the attention probabilities (optional)
        if not output_attentions:
            return (x, None)
        else:
            return (x, all_attentions)

# ViT for Image Classification

After inputting the image to the embedding layer and transformer encoder, we obtain new embeddings for both the image patches and the [CLS] token. At this point, the embeddings should have some useful signals for classification after being processed by the transformer encoder. Similar to BERT, we’ll use only the [CLS] token’s embedding to pass to the classification layer.

The classification layer is a fully connected layer that takes the [CLS] embedding as input and outputs logits for each image. The following code implements the ViT model for image classification:

In [None]:
class ViTForClassfication(nn.Module):
    """
    The ViT model for classification.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_size = config["image_size"]
        self.hidden_size = config["hidden_size"]
        self.num_classes = config["num_classes"]
        # Create the embedding module
        self.embedding = Embeddings(config)
        # Create the transformer encoder module
        self.encoder = Encoder(config)
        # Create a linear layer to project the encoder's output to the number of classes
        self.classifier = nn.Linear(self.hidden_size, self.num_classes)
        # Initialize the weights
        self.apply(self._init_weights)

    def forward(self, x, output_attentions=False):
        # Calculate the embedding output
        embedding_output = self.embedding(x)
        # Calculate the encoder's output
        encoder_output, all_attentions = self.encoder(embedding_output, output_attentions=output_attentions)
        # Calculate the logits, take the [CLS] token's output as features for classification
        logits = self.classifier(encoder_output[:, 0])
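        # Return the logits and the attention probabilities (optional)
        if not output_attentions:
            return (logits, None)
        else:
            return (logits, all_attentions)

    def _init_weights(self, module):
        # Assumed initialization scheme: this method is referenced in __init__ but not
        # shown in the post. Weights are drawn from a normal distribution with
        # std = config["initializer_range"], biases are zeroed, and LayerNorm is reset
        # to the identity transform.
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.normal_(module.weight, mean=0.0, std=self.config["initializer_range"])
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)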
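The classification step itself is small enough to sketch on its own. Here `encoder_output` is a random stand-in for the encoder's result, and the sizes (sequence length 65, hidden size 48, 10 classes) are assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions for this sketch)
batch_size, seq_len, hidden_size, num_classes = 2, 65, 48, 10

# Random stand-in for the encoder output
encoder_output = torch.randn(batch_size, seq_len, hidden_size)
classifier = nn.Linear(hidden_size, num_classes)

# Index 0 along the sequence dimension is the [CLS] token
cls_features = encoder_output[:, 0]         # (2, 48)
logits = classifier(cls_features)           # (2, 10)

# The predicted class is the index of the largest logit
predicted_classes = logits.argmax(dim=-1)   # (2,)

print(cls_features.shape, logits.shape, predicted_classes.shape)
```

The other 64 token embeddings are ignored by the classifier; they influence the prediction only through attention inside the encoder.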

To train the model, you can follow the standard steps for training classification models. You can find the training script here.
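For reference, one step of such a training loop might look like the sketch below. It is not the training script mentioned above: the linear `model` is a stand-in so that the snippet runs on its own, the random batch replaces real data, and the choice of AdamW is an assumption.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in classifier so the sketch is self-contained. With the ViT above you
# would build the model from the config and unpack its output as
# `logits, _ = model(images)`, since it returns a tuple.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))

# Optimizer choice is an assumption of this sketch
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# Random batch shaped like CIFAR-10 images and labels
images = torch.randn(8, 3, 32, 32)
labels = torch.randint(0, 10, (8,))

model.train()
optimizer.zero_grad()           # clear gradients from the previous step
logits = model(images)          # forward pass
loss = loss_fn(logits, labels)  # cross-entropy between logits and labels
loss.backward()                 # backpropagate
optimizer.step()                # update the parameters

print(loss.item())
```

A full script would wrap this step in loops over epochs and batches and evaluate on the test set under `torch.no_grad()`.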

# Results

Since the goal is not to achieve state-of-the-art performance but to demonstrate how the model works intuitively, the model I trained is much smaller than the original ViT models described in the paper, which have at least 12 layers and a hidden size of 768. The model config I used for training is:

```python
config = {
    "patch_size": 4,
    "hidden_size": 48,
    "num_hidden_layers": 4,
    "num_attention_heads": 4,
    "intermediate_size": 4 * 48,
    "hidden_dropout_prob": 0.0,
    "attention_probs_dropout_prob": 0.0,
    "initializer_range": 0.02,
    "image_size": 32,
    "num_classes": 10,
    "num_channels": 3,
    "qkv_bias": True
}
```
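The sizes that follow from this config can be worked out with a few lines of arithmetic. The sketch below copies only the relevant entries; the derived variable names are introduced here for illustration.

```python
# Relevant entries copied from the config above
config = {
    "patch_size": 4,
    "hidden_size": 48,
    "num_attention_heads": 4,
    "intermediate_size": 4 * 48,
    "image_size": 32,
    "num_channels": 3,
}

# 32 // 4 = 8 patches per side, so 8 * 8 = 64 patches per image
num_patches = (config["image_size"] // config["patch_size"]) ** 2

# One extra position for the [CLS] token
sequence_length = num_patches + 1

# The hidden size is split evenly across the heads
head_size = config["hidden_size"] // config["num_attention_heads"]

# Number of raw pixel values in one patch
values_per_patch = config["num_channels"] * config["patch_size"] ** 2

print(num_patches)       # 64
print(sequence_length)   # 65
print(head_size)         # 12
print(values_per_patch)  # 48
```

With this config, each 4x4 RGB patch holds 48 raw values, which happens to match the hidden size of 48.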

The model was trained on the CIFAR-10 dataset for 100 epochs with a batch size of 256. The learning rate was set to 0.01, and no learning rate schedule was used. The model achieved 75.5% accuracy after 100 epochs of training. The following figure shows the training loss, test loss, and accuracy on the test set during training.


<center>
<img src="res/attention.png" alt="Image description" style="width:50%;"/>
</center>


# Conclusion

In this post, we have learned how the Vision Transformer works, from the embedding layer to the transformer encoder and finally to the classification layer. We have also learned how to implement each component of the model using PyTorch.

Since this implementation is not intended for production use, I recommend using a more mature transformer library, such as Hugging Face Transformers, if you intend to train full-sized models or train on large datasets.