# **AlexNet**

<div align=center>
<img width="600" src="../image/5.6_alexnet.png"/>
</div>
<div align=center>AlexNet网络结构</div>

AlexNet包括八层变换(5层卷积, 3层全连接):
- 第一层卷积: 卷积核大小为 (11, 11)
- 第二层卷积: 卷积核大小为 (5, 5)
- 之后的卷积: 卷积核大小为 (3, 3)


AlexNet将激活函数换成了ReLU函数     
AlexNet使用了dropout     
我们实现简化后的AlexNet

In [1]:
import torch
from torch import nn, optim
import torchvision

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines where CUDA is not available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
import sys
sys.path.append(r'C:\D\ProgramFile\jupyter\torch_learn\dive_to_dp\utils') 
import d2lzh as d2l

## **模型定义**

In [3]:
class AlexNet(nn.Module):
    """Simplified AlexNet: five convolution layers followed by three linear layers.

    The fully connected head expects a 256 x 5 x 5 feature map, which is what
    the convolutional stack produces for 224 x 224 inputs.

    Args:
        in_channels: Number of channels of the input images (default 1).
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=96, kernel_size=11, stride=4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # padding=2, i.e. (5 - 1) / 2, keeps the height and width unchanged
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.fc = nn.Sequential(
            nn.Linear(256 * 5 * 5, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits for a batch of images ``x``."""
        feature = self.conv(x)
        # Collapse everything except the batch dimension before the linear head.
        return self.fc(feature.flatten(start_dim=1))

In [4]:
# Instantiate the model with its default arguments and print its layer structure.
net = AlexNet()
print(net)

AlexNet(
  (conv): Sequential(
    (0): Conv2d(1, 96, kernel_size=(11, 11), stride=(4, 4))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU()
    (10): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=6400, out_features=4096, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, o

## **数据读取**

In [5]:
def load_data_fashion_mnist(batch_size, resize=None, root='../datasets/FashionMnisT', num_workers=4):
    """Download Fashion-MNIST (if needed) and wrap it in train/test DataLoaders.

    Args:
        batch_size: Number of samples per batch for both loaders.
        resize: Optional size passed to ``torchvision.transforms.Resize``;
            when falsy, no resize step is applied.
        root: Directory the dataset is stored in / downloaded to.
        num_workers: Number of DataLoader worker processes. Defaults to 4,
            the value that used to be hard-coded.

    Returns:
        A ``(train_iter, test_iter)`` tuple of ``torch.utils.data.DataLoader``
        objects; only the training loader shuffles its samples.
    """
    transforms = torchvision.transforms
    steps = []
    if resize:
        steps.append(transforms.Resize(size=resize))
    steps.append(transforms.ToTensor())
    transform = transforms.Compose(steps)

    mnist_train = torchvision.datasets.FashionMNIST(
        root=root, train=True, download=True, transform=transform)
    mnist_test = torchvision.datasets.FashionMNIST(
        root=root, train=False, download=True, transform=transform)

    train_iter = torch.utils.data.DataLoader(
        mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(
        mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_iter, test_iter

In [6]:
# Build the loaders with images resized to 224, the input size the network's
# fully connected head is dimensioned for.
batch_size = 128
train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size, resize=224)

## **训练**

In [7]:
# Training setup: learning rate, epoch count, cross-entropy loss and an Adam
# optimizer over all parameters of the network.
lr = 0.001
nums_epochs = 5
loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=lr)

In [8]:
from d2lzh import evaluate_accuracy

In [9]:
# Train on `device`: one pass over train_iter per epoch, then print the mean
# batch loss, the training accuracy and the test accuracy.
net = net.to(device)
for epoch in range(nums_epochs):
    # The model contains Dropout, so make sure it is in training mode at the
    # start of every epoch (evaluation below switches it to eval mode).
    net.train()
    train_l_sum, train_acc_sum, n, batch_count = 0.0, 0.0, 0, 0
    for X, y in train_iter:
        X = X.to(device)
        y = y.to(device)
        y_hat = net(X)
        l = loss(y_hat, y)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
        # .item() converts to plain Python numbers so the running totals do
        # not hold on to tensors carrying autograd history.
        train_l_sum += l.item()
        train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
        n += y.shape[0]
        batch_count += 1
    # Evaluate with dropout disabled.
    net.eval()
    test_acc = evaluate_accuracy(test_iter, net)
    print(f'epoch{epoch+1}: loss {train_l_sum/batch_count:.4f} train_acc {train_acc_sum / n:.4f} test_acc {test_acc:.4f}')

epoch1: loss 0.6118 train_acc 0.7673 test_acc 0.8590
epoch2: loss 0.3350 train_acc 0.8762 test_acc 0.8858
epoch3: loss 0.2853 train_acc 0.8932 test_acc 0.8959
epoch4: loss 0.2574 train_acc 0.9045 test_acc 0.9008
epoch5: loss 0.2361 train_acc 0.9108 test_acc 0.9087


参数数量显示工具

In [10]:
import torchsummary

In [11]:
# Print per-layer output shapes and parameter counts for the network.
torchsummary.summary(net.cuda(), (1, 224, 224)) # the second argument is input_size

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 54, 54]          11,712
              ReLU-2           [-1, 96, 54, 54]               0
         MaxPool2d-3           [-1, 96, 26, 26]               0
            Conv2d-4          [-1, 256, 26, 26]         614,656
              ReLU-5          [-1, 256, 26, 26]               0
         MaxPool2d-6          [-1, 256, 12, 12]               0
            Conv2d-7          [-1, 384, 12, 12]         885,120
              ReLU-8          [-1, 384, 12, 12]               0
            Conv2d-9          [-1, 384, 12, 12]       1,327,488
             ReLU-10          [-1, 384, 12, 12]               0
           Conv2d-11          [-1, 256, 12, 12]         884,992
             ReLU-12          [-1, 256, 12, 12]               0
        MaxPool2d-13            [-1, 256, 5, 5]               0
           Linear-14                 [-