In [39]:
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn, optim
from model.lenet5 import Lenet5
from model.resnet import ResNet18
from torchvision.models import resnet18
from torchvision.models import vgg16
from torchvision.models import vgg19
from torchvision.models import inception_v3
from torchvision.models import mobilenet_v3_large

# 数据加载

In [40]:
# Seed torch's global RNG (weight initialisation and DataLoader shuffling both
# draw from it) so results are repeatable under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
# Mini-batch size shared by the train and test DataLoaders.
batchsz = 32

In [41]:
# Training split of CIFAR-10, stored under data/cifar and downloaded if missing.
# NOTE(review): the mean/std below look like the ImageNet channel statistics
# rather than CIFAR-10's own — confirm this is intentional.
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
train_set = datasets.CIFAR10('data/cifar', train=True,
                             transform=train_transform, download=True)
# Batches of `batchsz` samples drawn in random order.
cifar_train = DataLoader(train_set, batch_size=batchsz, shuffle=True)

Files already downloaded and verified


In [42]:
# Test split of CIFAR-10, preprocessed exactly like the training split.
# NOTE(review): the mean/std below look like the ImageNet channel statistics
# rather than CIFAR-10's own — confirm this is intentional.
test_set = datasets.CIFAR10(
    'data/cifar',
    train=False,
    transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]),
    download=True
)
# Accuracy does not depend on sample order, so evaluation is not shuffled:
# a fixed order keeps evaluation passes deterministic and avoids needless work.
cifar_test = DataLoader(
    test_set,
    batch_size=batchsz,
    shuffle=False
)

Files already downloaded and verified


In [43]:
# Peek at one training batch to confirm the tensor shapes the model will see.
x, label = next(iter(cifar_train))
print('x:', x.shape, 'label:', label.shape)
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

x: torch.Size([32, 3, 32, 32]) label: torch.Size([32])


# 模型定义

## 自定义模型

In [44]:
class MyLeNet5(nn.Module):
    """LeNet-5-style CNN for 3-channel 32x32 images (CIFAR-10 in this notebook).

    Two conv + max-pool stages followed by three fully connected layers.
    ``fc1`` expects a 32x5x5 feature map per sample; inputs whose spatial size
    does not reduce to 5x5 raise an error in ``fc1`` rather than being
    silently reshaped.

    Args:
        num_classes: number of output logits (default 10, as for CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super(MyLeNet5, self).__init__()
        # Conv 1: 3 input channels -> 16 output channels (i.e. 16 kernels), 5x5 kernel, stride 1.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1)
        # Pool 1: 2x2 max pooling with stride 2.
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Conv 2: 16 -> 32 channels, 5x5 kernel, stride 1.
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1)
        # Pool 2: 2x2 max pooling with stride 2.
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # FC 1: flattened 32*5*5 features -> 120.
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        # FC 2: 120 -> 84.
        self.fc2 = nn.Linear(120, 84)
        # FC 3: 84 -> class logits.
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Map a batch of shape [b, 3, 32, 32] to logits of shape [b, num_classes]."""
        x = F.relu(self.conv1(x))  # [b, 16, 28, 28]
        x = self.pool1(x)          # [b, 16, 14, 14]
        x = F.relu(self.conv2(x))  # [b, 32, 10, 10]
        x = self.pool2(x)          # [b, 32, 5, 5]
        # Flatten per sample, keeping the batch dimension: a mis-sized input now
        # fails in fc1 instead of being re-split into a different batch size,
        # which view(-1, 32 * 5 * 5) would do silently.
        x = torch.flatten(x, 1)    # [b, 800]
        x = F.relu(self.fc1(x))    # [b, 120]
        x = F.relu(self.fc2(x))    # [b, 84]
        return self.fc3(x)         # [b, num_classes]


In [45]:
# Build the custom LeNet-5 and move its parameters to the selected device.
model = MyLeNet5()
model = model.to(device)

## 已定义模型

In [46]:
# Alternative architectures: uncomment exactly one line to train it instead of
# the MyLeNet5 instance created above.
# NOTE(review): the torchvision constructors are called without a class-count
# argument, so their output head is presumably not sized for CIFAR-10's 10
# classes, and some of them (e.g. inception_v3) may expect inputs larger than
# 32x32 — confirm before switching. Newer torchvision releases may also
# deprecate `pretrained=` in favour of `weights=`; check the installed version.
# model = Lenet5().to(device)
# model = ResNet18().to(device)
# model = resnet18(pretrained=False).to(device)
# model = vgg16(pretrained=False).to(device)
# model = vgg19(pretrained=False).to(device)
# model = inception_v3(pretrained=False).to(device)
# model = mobilenet_v3_large(pretrained=False).to(device)

In [47]:
# Show the layer-by-layer summary of the model that will be trained.
print(f"{model}")

MyLeNet5(
  (conv1): Conv2d(3, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=800, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


# 模型训练

In [48]:
# Cross-entropy loss over the model's class logits, optimised with Adam.
criteron = nn.CrossEntropyLoss().to(device=device)
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

In [49]:
def train(model, optimizer, criteron, epoch, loader=None, log_every=100):
    """Run one training epoch, updating ``model`` in place.

    Args:
        model: network to train; it is switched to training mode.
        optimizer: optimizer over ``model``'s parameters.
        criteron: loss function called as ``criteron(logits, labels)``.
        epoch: epoch index, used only in the progress log.
        loader: iterable of ``(inputs, labels)`` batches. Defaults to the
            notebook-level ``cifar_train`` loader when omitted.
        log_every: print the current batch loss every ``log_every`` batches
            (batch 0 included); ``0``/``None`` disables logging.
    """
    if loader is None:
        loader = cifar_train
    # Send batches to wherever the model's weights live rather than relying on
    # a global ``device`` (this notebook moves the model to ``device`` anyway).
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else None
    model.train()
    for batchidx, (x, label) in enumerate(loader):
        # x: [b, 3, 32, 32] for CIFAR-10; label: [b]
        if target is not None:
            x, label = x.to(target), label.to(target)
        # Clear gradients accumulated by the previous step.
        optimizer.zero_grad()
        logits = model(x)               # [b, num_classes]
        loss = criteron(logits, label)  # scalar tensor
        loss.backward()                 # back-propagate
        optimizer.step()                # update parameters
        if log_every and batchidx % log_every == 0:
            print('epoch:', epoch, 'batchidx:', batchidx, 'loss:', loss.item())

# 模型评估

In [50]:
def eval(model, criteron, epoch, loader=None):
    """Compute, print and return the top-1 accuracy of ``model``.

    The name shadows Python's built-in ``eval``; it is kept because ``main``
    calls it by this name.

    Args:
        model: network to evaluate; it is switched to eval mode.
        criteron: unused; kept so the signature mirrors ``train``.
        epoch: epoch index, used only in the printed line.
        loader: iterable of ``(inputs, labels)`` batches. Defaults to the
            notebook-level ``cifar_test`` loader when omitted.

    Returns:
        Fraction of correctly classified samples, or ``nan`` when the loader
        yields no samples (instead of raising ZeroDivisionError).
    """
    if loader is None:
        loader = cifar_test
    # Send batches to wherever the model's weights live.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else None
    model.eval()
    total_correct = 0
    total_num = 0
    with torch.no_grad():
        for x, label in loader:
            # x: [b, 3, 32, 32] for CIFAR-10; label: [b]
            if target is not None:
                x, label = x.to(target), label.to(target)
            pred = model(x).argmax(dim=1)  # [b] predicted class indices
            total_correct += (pred == label).sum().item()
            total_num += x.size(0)
    acc = total_correct / total_num if total_num else float('nan')
    print('epoch:', epoch, 'acc:', acc)
    return acc

In [51]:
def main(epochs=5):
    """Train for ``epochs`` epochs, evaluating on the test set after each one.

    Uses the notebook-level ``model``, ``optimizer`` and ``criteron``.

    Args:
        epochs: number of train/evaluate rounds (default 5).
    """
    for epoch in range(epochs):
        train(model, optimizer, criteron, epoch)
        eval(model, criteron, epoch)

In [52]:
# The notebook's cells run with __name__ == '__main__', so executing this cell starts training.
run_now = __name__ == '__main__'
if run_now:
    main()

epoch: 0 batchidx: 0 loss: 2.297548294067383
epoch: 0 batchidx: 100 loss: 1.68293297290802
epoch: 0 batchidx: 200 loss: 1.4933511018753052
epoch: 0 batchidx: 300 loss: 1.564423680305481
epoch: 0 batchidx: 400 loss: 1.4148838520050049
epoch: 0 batchidx: 500 loss: 1.8905407190322876
epoch: 0 batchidx: 600 loss: 1.2580304145812988
epoch: 0 batchidx: 700 loss: 1.503340482711792
epoch: 0 batchidx: 800 loss: 1.7086248397827148
epoch: 0 batchidx: 900 loss: 1.5118569135665894
epoch: 0 batchidx: 1000 loss: 1.3862837553024292
epoch: 0 batchidx: 1100 loss: 0.9959393739700317
epoch: 0 batchidx: 1200 loss: 1.449727177619934
epoch: 0 batchidx: 1300 loss: 1.2866764068603516
epoch: 0 batchidx: 1400 loss: 1.1549031734466553
epoch: 0 batchidx: 1500 loss: 1.3130391836166382
epoch: 0 acc: 0.5828
epoch: 1 batchidx: 0 loss: 1.0471978187561035
epoch: 1 batchidx: 100 loss: 1.0342128276824951
epoch: 1 batchidx: 200 loss: 0.9745177030563354
epoch: 1 batchidx: 300 loss: 1.0224632024765015
epoch: 1 batchidx: 400 