In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, random_split
import numpy as np

In [2]:
# 定义图像分块和嵌入模块
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embedding_size=768):
        """
        Args:
            img_size (int): 输入图像的高度和宽度（假设为正方形图像）
            patch_size (int): 每个图像块的高度和宽度
            in_channels (int): 输入图像的通道数
            embedding_size (int): 嵌入维度
        """
        super(PatchEmbedding, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.projection = nn.Conv2d(in_channels, embedding_size, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_size))
        self.positional_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, embedding_size))

    def forward(self, x):
        x = self.projection(x)  # (B, embed_dim, H/patch_size, W/patch_size)
        x = x.flatten(2)  # (B, embed_dim, num_patches)
        x = x.transpose(1, 2)  # (B, num_patches, embed_dim)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, num_patches + 1, embed_dim)
        x = x + self.positional_embedding  # 添加位置嵌入
        return x

In [3]:
# Transformer编码器模块
class TransformerBlock(nn.Module):
    def __init__(self, embedding_size, num_heads, forward_expansion, dropout=0.1):
        """
        Args:
            embedding_size (int): 嵌入维度
            num_heads (int): 多头注意力机制的头数
            forward_expansion (int): 前馈网络扩展因子
            dropout (float): Dropout概率
        """
        super(TransformerBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embedding_size, num_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embedding_size)
        self.norm2 = nn.LayerNorm(embedding_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embedding_size, forward_expansion * embedding_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embedding_size, embedding_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query):
        """
        Args:
            value, key, query: 输入张量，形状为 (seq_length, batch_size, embedding_size)
        Returns:
            输出张量，形状为 (seq_length, batch_size, embedding_size)
        """
        attn_output, _ = self.attention(query, key, value)
        x = self.norm1(value + self.dropout(attn_output))
        forward_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(forward_output))
        return x

In [4]:
# Vision Transformer模型
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embedding_size=768, num_heads=12,
                 forward_expansion=4, num_layers=12, num_classes=1000, dropout=0.1):
        """
        Args:
            img_size (int): 输入图像的高度和宽度
            patch_size (int): 每个图像块的高度和宽度
            in_channels (int): 输入图像的通道数
            embedding_size (int): 嵌入维度
            num_heads (int): 多头注意力机制的头数
            forward_expansion (int): 前馈网络扩展因子
            num_layers (int): Transformer编码器层数
            num_classes (int): 分类类别数
            dropout (float): Dropout概率
        """
        super(VisionTransformer, self).__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embedding_size)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(embedding_size, num_heads, forward_expansion, dropout) for _ in range(num_layers)]
        )
        self.to_cls_token = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embedding_size),
            nn.Linear(embedding_size, num_classes)
        )
    
    def forward(self, x):
        """
        Args:
            x: 输入图像张量，形状为 (batch_size, in_channels, img_size, img_size)
        Returns:
            输出张量，形状为 (batch_size, num_classes)
        """
        x = self.patch_embedding(x)  # (B, num_patches + 1, embed_dim)
        for block in self.transformer_blocks:
            x = block(x, x, x)
        x = self.to_cls_token(x[:, 0])  # 取cls token
        x = self.mlp_head(x)
        return x

In [5]:
#模型下载
from modelscope import snapshot_download
from modelscope.utils.constant import REPO_TYPE_DATASET
model_dir = snapshot_download('EFate1006/CIFAR-10', local_dir='./CIFAR-10', repo_type=REPO_TYPE_DATASET)

  from .autonotebook import tqdm as notebook_tqdm
2026-01-27 17:44:27,822 - modelscope - INFO - Not logged-in, you can login for uploadingor accessing controlled entities.
2026-01-27 17:44:28,123 - modelscope - INFO - Fetching dataset repo file list...


Downloading Dataset to directory: /data/lvm_data_48T/zhangningboo/workspace/tsinghua-lm-books/从零构建大模型-算法训练与微调/第四章ViT模型/CIFAR-10


2026-01-27 17:44:28,553 - modelscope - INFO - Got 1 files, start to download ...
Downloading [data/cifar-10-python.tar.gz]: 100%|██████████| 163M/163M [01:18<00:00, 2.16MB/s]
Processing 1 items: 100%|██████████| 1.00/1.00 [01:19<00:00, 79.2s/it]
2026-01-27 17:45:47,737 - modelscope - INFO - Download dataset 'EFate1006/CIFAR-10' successfully.


In [9]:
# 数据加载和预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
dataset = CIFAR10(root='./CIFAR-10', train=True, download=True, transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

100%|██████████| 170M/170M [03:21<00:00, 848kB/s]    


In [10]:
# 模型、损失函数和优化器
model = VisionTransformer(num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=3e-4)

In [11]:
# 训练函数
def train(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

In [12]:
# 验证函数
def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / len(dataloader.dataset)
    print(f'Validation Accuracy: {accuracy:.2f}%')
    return total_loss / len(dataloader)

In [13]:
# 训练和验证
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train(model, train_loader, criterion, optimizer, device)
    val_loss = validate(model, val_loader, criterion, device)
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

Validation Accuracy: 9.74%
Epoch [1/10], Train Loss: 2.3561, Val Loss: 2.3415


KeyboardInterrupt: 