In [None]:
!pip install vit_pytorch

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar10/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./cifar10/cifar-10-python.tar.gz to ./cifar10
Epoch 0. Train Loss: 2.082646, Train Acc: 0.220835, Valid Loss: 1.976132, Valid Acc: 0.264551, Time 00:02:17
Epoch 1. Train Loss: 1.997106, Train Acc: 0.255891, Valid Loss: 1.906866, Valid Acc: 0.307617, Time 00:02:55
Epoch 2. Train Loss: 1.947064, Train Acc: 0.275383, Valid Loss: 1.848430, Valid Acc: 0.317285, Time 00:02:58
Epoch 3. Train Loss: 1.892530, Train Acc: 0.298756, Valid Loss: 1.772291, Valid Acc: 0.352246, Time 00:02:54
Epoch 4. Train Loss: 1.851919, Train Acc: 0.316889, Valid Loss: 1.723479, Valid Acc: 0.368750, Time 00:02:54


In [None]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
from vit_pytorch import ViT, SimpleViT

# Run on the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


def main():
    """Train the SimpleViT from ``get_vit_model`` on CIFAR-10 for 5 epochs.

    Builds the train/validation transform pipelines, downloads CIFAR-10 into
    ``./cifar10`` if needed, wraps both splits in DataLoaders (batch size 256),
    moves the model to ``device`` and hands everything to ``train``, which
    prints the per-epoch metrics.
    """
    # Per-channel statistics given to Normalize in both pipelines.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    batch_size = 256

    # Training pipeline (with augmentation).
    train_steps = [
        # Crop a random region (random size / aspect ratio; default
        # scale=(0.08, 1.0)) and resize the crop to 224x224.
        transforms.RandomResizedCrop(224),
        # Flip horizontally at random (default probability 0.5).
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    trans_train = transforms.Compose(train_steps)

    # Validation pipeline (deterministic).
    valid_steps = [
        # Scale the shorter side to 256, keeping the aspect ratio.
        transforms.Resize(256),
        # Take the central 224x224 region.
        transforms.CenterCrop(224),
        # PIL image / ndarray -> tensor; uint8 input is divided by 255 to land
        # in [0, 1]. Data on another scale must be rescaled by the caller.
        transforms.ToTensor(),
        # Per-channel standardisation: subtract the mean, divide by the std.
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    trans_valid = transforms.Compose(valid_steps)

    # The training split is created first because it is the one that downloads.
    trainset = torchvision.datasets.CIFAR10(
        root="./cifar10", train=True, download=True, transform=trans_train)
    # Alternative dataset kept for reference:
    # trainset = torchvision.datasets.LFWPeople(root="./lfw", split='train', download=True, transform=trans_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True)

    testset = torchvision.datasets.CIFAR10(
        root='./cifar10', train=False, download=False, transform=trans_valid)
    # testset = torchvision.datasets.LFWPeople(root='./lfw', split='test', download=True, transform=trans_valid)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False)

    # classes = ('plane', 'car', 'bird', 'cat',
    #            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # Build the ViT. get_vit_model() only constructs the network; no
    # pretrained weights are loaded anywhere in this file.
    model = get_vit_model().to(device)

    # Loss function.
    criterion = nn.CrossEntropyLoss()
    # Optimizer: every model parameter is passed in, so the whole network is trained.
    optimizer = torch.optim.SGD(
        model.parameters(), lr=1e-3, weight_decay=1e-3, momentum=0.9)

    # Run training with validation after every epoch.
    num_epochs = 5
    train(model, trainloader, testloader, num_epochs, optimizer, criterion)


# Compute classification accuracy for one batch
def get_acc(output, label):
    """Return the fraction of rows of ``output`` whose arg-max equals ``label``.

    Args:
        output: tensor of per-class scores; the prediction is the index of the
            maximum along dimension 1 (intended shape: (batch, num_classes)).
        label: tensor of target class indices, compared element-wise with the
            predictions.

    Returns:
        float: number of matching predictions divided by ``output.shape[0]``.
    """
    batch_size = output.size(0)
    predicted = output.argmax(dim=1)
    hits = predicted.eq(label).sum().item()
    return hits / batch_size


# Display an image tensor
def imshow(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo the channel normalisation of ``img`` and show it with matplotlib.

    The previous implementation inverted ``Normalize(0.5, 0.5)`` via
    ``img / 2 + 0.5``, which does not match the mean/std this file passes to
    ``transforms.Normalize``; the defaults below are those values.

    Args:
        img: channel-first image tensor (C, H, W), e.g. one sample from the
            DataLoader or a ``torchvision.utils.make_grid`` result.
        mean: per-channel mean that was subtracted during normalisation.
        std: per-channel std that was divided by during normalisation.
    """
    # detach()/cpu() so GPU tensors and tensors tracking gradients also work.
    npimg = img.detach().cpu().numpy()
    # (C, H, W) -> (H, W, C), the layout matplotlib expects; this also puts the
    # channel axis last so the per-channel mean/std broadcast over it.
    npimg = np.transpose(npimg, (1, 2, 0))
    npimg = npimg * np.asarray(std) + np.asarray(mean)
    # Clip to the valid float-image range to avoid matplotlib's clipping warning.
    plt.imshow(np.clip(npimg, 0.0, 1.0))
    plt.show()


# Training loop
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    """Train ``net`` for ``num_epochs`` epochs, printing metrics after each one.

    Changes from the previous version: validation runs under
    ``torch.no_grad()`` (no autograd graph is built for it), and the reported
    time is the wall-clock duration of the epoch it is printed for (training
    plus validation), computed with ``total_seconds()``.

    Args:
        net: the model; switched to train mode for training and eval mode for
            validation. Batches are moved to the module-level ``device``.
        train_data: iterable of ``(images, labels)`` batches that supports
            ``len()`` (number of batches), e.g. a DataLoader.
        valid_data: same kind of iterable for validation, or ``None`` to skip
            validation.
        num_epochs: number of passes over ``train_data``.
        optimizer: optimizer stepped once per training batch.
        criterion: loss callable invoked as ``criterion(output, label)``.

    Reported losses/accuracies are the per-batch values averaged over the
    number of batches.
    """
    for epoch in range(num_epochs):
        epoch_start = datetime.now()
        train_loss = 0
        train_acc = 0
        net = net.train()
        for im, label in train_data:
            im = im.to(device)  # (bs, 3, h, w)
            label = label.to(device)  # (bs,) class indices
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_acc += get_acc(output, label)

        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net = net.eval()
            # No gradients are needed for evaluation; skipping graph
            # construction saves memory and time.
            with torch.no_grad():
                for im, label in valid_data:
                    im = im.to(device)  # (bs, 3, h, w)
                    label = label.to(device)  # (bs,) class indices
                    output = net(im)
                    loss = criterion(output, label)
                    valid_loss += loss.item()
                    valid_acc += get_acc(output, label)
            epoch_str = (
                    "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                    % (epoch, train_loss / len(train_data),
                       train_acc / len(train_data), valid_loss / len(valid_data),
                       valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data),
                          train_acc / len(train_data)))

        # Measure after validation so the time belongs to this epoch only.
        # total_seconds() (unlike .seconds) does not discard whole days.
        elapsed = int((datetime.now() - epoch_start).total_seconds())
        h, remainder = divmod(elapsed, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        print(epoch_str + time_str)


def get_vit_model():
    """Construct the SimpleViT classifier used in this notebook.

    The network is created with the fixed hyper-parameters below (224x224
    input, 16x16 patches, 10 output classes); no weights are loaded here.

    Returns:
        A newly constructed ``SimpleViT`` instance.
    """
    vit_config = {
        "image_size": 224,
        "patch_size": 16,
        "num_classes": 10,
        "dim": 256,
        "depth": 2,
        "heads": 4,
        "mlp_dim": 128,
    }
    return SimpleViT(**vit_config)


# Entry point: start training when this cell/script is executed directly.
if "__main__" == __name__:
    main()

