Vision Transformer（ViT）是一种基于Transformer的视觉注意力模型，用于图像分类和其他计算机视觉任务。与传统卷积神经网络（CNN）不同，ViT使用自注意力机制来捕捉图像中的关键特征。该模型将输入图像分割成小块，并将每个块作为序列元素输入到Transformer编码器中。这种方法在训练时可以捕捉到图像中的局部和全局特征，从而提高了模型的性能。ViT已被证明在许多视觉任务上具有很高的准确性，并成为计算机视觉领域的热门研究方向之一。

![](https://raw.githubusercontent.com/xuehangcang/DeepLearning/main/static/ViT-Transformer.png)

In [30]:
import torch
from torch import nn
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [31]:
transform = transforms.Compose(
    [transforms.Resize(224),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

training_data = datasets.CIFAR10(root='data', train=True,download=True, transform=transform)
test_data = datasets.CIFAR10(root='data', train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [32]:
batch_size = 64

train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

In [33]:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


## 1.PatchEmbedding 

In [34]:
class PatchEmbedding(nn.Module):
    """将 2 维图像转化 1 维序列的嵌入向量"""

    def __init__(self,
                 in_channels: int = 3,  # 输入图像的颜色通道数，默认值为3
                 patch_size: int = 16,  # 将输入图像转换为卷积核的大小，默认值为16
                 embedding_dim: int = 768):  # 将图像转换嵌入的维度，默认值为768。
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        # 用卷积核将图像转换为嵌入向量
        self.patcher = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embedding_dim,
            kernel_size=patch_size,
            stride=patch_size,
            padding=0
        )
        # 将卷积后的图像转换为 1 维序列
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self, x):
        image_resolution = x.shape[-1]  # 输入图像的分辨率
        assert image_resolution % self.patch_size == 0, f"图像必须能够被卷积核整除{image_resolution},{self.patch_size}"
        x_patched = self.patcher(x)
        x_flattened = self.flatten(x_patched)
        return x_flattened.permute(0, 2, 1)  # 将嵌入调整为最终维度，从 [batch_size, P^2•C, N] 转换为 [batch_size, N, P^2•C]。

## 2.MLPBlock

In [35]:
class MLPBlock(nn.Module):
    """包含一个归一化层的多层感知器块，简称 MLP 块"""

    def __init__(self,
                 embedding_dim: int = 768,  # ViT-Base 表中的隐藏大小D
                 mlp_size: int = 3072,  # ViT-Base 表中的 MLP 大小
                 dropout: float = 0.1):  # ViT-Base 表中的 MLP 的丢弃率

        super(MLPBlock, self).__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)  # 层归一化
        self.mlp = nn.Sequential(  # 多层感知器
            nn.Linear(in_features=embedding_dim, out_features=mlp_size),  # 线性变换
            nn.GELU(),  # GELU 激活函数
            nn.Dropout(p=dropout),  # 丢弃
            nn.Linear(in_features=mlp_size, out_features=embedding_dim),  # 线性变换
            nn.Dropout(p=dropout)  # 丢弃
        )

    def forward(self, x):
        x = self.layer_norm(x)  # 归一化
        x = self.mlp(x)  # 多层感知器
        return x


## 3.MSABlock

In [36]:
class MSABlock(nn.Module):
    """创建一个多头自注意力块，简称 MSA 块 """

    def __init__(self,
                 embedding_dim: int = 768,  # ViT-Base 表中的隐藏大小D
                 num_heads: int = 12,  # ViT-Base 表中的头数
                 attn_dropout: float = 0):  # ViT-Base 表中的注意力层的丢弃率

        super(MSABlock, self).__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)  # 层归一化
        self.multihead_attn = nn.MultiheadAttention(  # 多头注意力
            embed_dim=embedding_dim,
            num_heads=num_heads,
            dropout=attn_dropout,
            batch_first=True
        )

    def forward(self, x):
        x = self.layer_norm(x)
        attn_output, _ = self.multihead_attn(query=x, key=x, value=x, need_weights=False)
        return attn_output


## 4.Encoder

In [37]:
class Encoder(nn.Module):
    """创建一个Transformer编码器块"""

    def __init__(self,
                 embedding_dim: int = 768,  # ViT-Base 表中的隐藏大小D
                 num_heads: int = 12,  # ViT-Base 表中的头数
                 mlp_size: int = 3072,  # ViT-Base 表中的 MLP 大小
                 mlp_dropout: float = 0.1,  # ViT-Base 表中的 MLP 的丢弃率
                 attn_dropout: float = 0):  # ViT-Base 表中的注意力层的丢弃率
        super(Encoder, self).__init__()
        self.msa_block = MSABlock(  # 多头自注意力块
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            attn_dropout=attn_dropout
        )
        self.mlp_block = MLPBlock(  # 多层感知器块
            embedding_dim=embedding_dim,
            mlp_size=mlp_size,
            dropout=mlp_dropout
        )

    def forward(self, x):
        x = self.msa_block(x) + x  # 多头自注意力块
        x = self.mlp_block(x) + x  # 多层感知器块
        return x

## 5.VisionTransformer

In [38]:
class VisionTransformer(nn.Module):
    """默认情况下使用 ViT-Base 超参数创建一个 Vision Transformer 架构"""

    def __init__(self,
                 img_size: int = 224,  # ViT论文中表3的训练分辨率, 输入图像的大小，默认值为224
                 in_channels: int = 3,  # 输入图像通道数
                 patch_size: int = 16,  # 将输入图像转换为卷积核的大小，默认值为16
                 num_transformer_layers: int = 12,  # ViT-Base 表中的编码器层数
                 embedding_dim: int = 768,  # ViT-Base 表中的隐藏大小D
                 mlp_size: int = 3072,  # ViT-Base 表中的 MLP 大小
                 num_heads: int = 12,  # ViT-Base 表中的头数
                 attn_dropout: float = 0,  # ViT-Base 表中的注意力层的丢弃率
                 mlp_dropout: float = 0.1,  # ViT-Base 表中的 MLP 的丢弃率
                 embedding_dropout: float = 0.1,  # ViT-Base 表中的嵌入层的丢弃率
                 num_classes: int = 1000):  # ViT-Base 表中的分类数

        super(VisionTransformer, self).__init__()
        assert img_size % patch_size == 0, f"图像必须能够被卷积核整除，{img_size},{patch_size}"

        self.num_patches = (img_size * img_size) // patch_size ** 2  # 计算图像中的卷积核数量
        self.class_embedding = nn.Parameter(data=torch.randn(1, 1, embedding_dim),  # 类别嵌入
                                            requires_grad=True)
        self.position_embedding = nn.Parameter(data=torch.randn(1, self.num_patches + 1, embedding_dim),  # 位置嵌入
                                               requires_grad=True)
        self.embedding_dropout = nn.Dropout(p=embedding_dropout)  # 嵌入层的丢弃率
        self.patch_embedding = PatchEmbedding(in_channels=in_channels, patch_size=patch_size,  # 卷积嵌入
                                              embedding_dim=embedding_dim)
        self.transformer_encoder = nn.Sequential(  # 编码器
            *[Encoder(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                mlp_size=mlp_size,
                mlp_dropout=mlp_dropout) for _ in
                range(num_transformer_layers)]
        )

        self.classifier = nn.Sequential(  # 分类器
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim,
                      out_features=num_classes)
        )

    def forward(self, x):
        batch_size = x.shape[0]  # 获取批次大小
        class_token = self.class_embedding.expand(batch_size, -1, -1)  # 扩展类别嵌入
        x = self.patch_embedding(x)  # 卷积嵌入
        x = torch.cat((class_token, x), dim=1)  # 将类别嵌入和卷积嵌入连接
        x = self.position_embedding + x  # 位置嵌入
        x = self.embedding_dropout(x)  # 嵌入层的丢弃率
        x = self.transformer_encoder(x)  # 编码器
        x = self.classifier(x[:, 0])  # 分类器
        return x
model = VisionTransformer()
model = model.to(device)

In [39]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

In [40]:
def train(dataloader, model, loss_fn, optimizer):
    """训练"""
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        X = X.to(device)
        y = y.to(device)

        # 计算预测和损失
        pred = model(X)
        loss = loss_fn(pred, y)

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


In [41]:
def test(dataloader, model, loss_fn):
    """测试"""
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)

            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= size
    correct /= size
    print(f"测试错误率: {(100 * (1 - correct)):.2f}%, 平均损失: {test_loss:>8f}\n")

In [42]:
epochs = 1
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)

Epoch 1
-------------------------------
loss: 7.044998 [    0/50000]
loss: 2.129542 [ 6400/50000]
loss: 2.100320 [12800/50000]
loss: 2.206093 [19200/50000]
loss: 2.299979 [25600/50000]
loss: 2.141643 [32000/50000]
loss: 2.110022 [38400/50000]
loss: 2.037074 [44800/50000]
测试错误率: 79.29%, 平均损失: 0.032642



In [43]:
!nvidia-smi

Wed Apr 26 00:25:39 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 531.41                 Driver Version: 531.41       CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                      TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3090       WDDM | 00000000:01:00.0  On |                  N/A |
| 77%   60C    P2              256W / 350W|  13845MiB / 24576MiB |     75%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    