### 5.8 Network in Network (NiN)

#### 5.8.1 NiN Blocks

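> 1. **The inputs and outputs of convolutional layers are usually 4-D arrays (sample, channel, height, width)**
> 2. **The inputs and outputs of fully connected layers are usually 2-D arrays (sample, feature)**
> 3. **A `1x1` convolutional layer can be viewed as a fully connected layer. Each element along the spatial dimensions (height and width) acts as a sample, and each channel acts as a feature. NiN therefore uses `1x1` convolutional layers in place of fully connected layers, so that spatial information is passed on to later layers**
> 4. **A fragment of the AlexNet and VGG structure: convolutional layer -> convolutional layer -> fully connected layer -> fully connected layer**
> 5. **A fragment of the NiN structure: convolutional layer -> `1x1` convolutional layer -> convolutional layer -> `1x1` convolutional layer**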
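
Point 3 can be checked numerically with a small sketch. A `1x1` convolution should give the same result as an `nn.Linear` layer applied separately at every spatial position, provided both layers share the same weights. The tensor sizes (2 samples, 3 input channels, 6 output channels, a 4x5 grid) are arbitrary choices for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy batch: 2 samples, 3 channels, 4x5 spatial grid (layout N, C, H, W)
x = torch.rand(2, 3, 4, 5)

conv1x1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=1)
fc = nn.Linear(in_features=3, out_features=6)

# Copy the convolution's parameters into the linear layer.
# Conv weight shape: (6, 3, 1, 1); linear weight shape: (6, 3)
with torch.no_grad():
    fc.weight.copy_(conv1x1.weight.view(6, 3))
    fc.bias.copy_(conv1x1.bias)

# 1x1 convolution on the 4-D tensor
y_conv = conv1x1(x)                                         # (2, 6, 4, 5)

# Treat each spatial position as a "sample" and the channels as "features":
# (N, C, H, W) -> (N*H*W, C), apply the linear layer, then restore the layout
x_flat = x.permute(0, 2, 3, 1).reshape(-1, 3)               # (40, 3)
y_fc = fc(x_flat).reshape(2, 4, 5, 6).permute(0, 3, 1, 2)   # (2, 6, 4, 5)

print(y_conv.shape, torch.allclose(y_conv, y_fc, atol=1e-5))
```

The comparison uses a small tolerance because the convolution and matrix-multiply kernels may round slightly differently.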

In [1]:
import time
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append("..") 
import d2lzh_pytorch.utils as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
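# One ordinary convolutional layer followed by two 1x1 convolutional layers that act as fully connected layers
# The first convolution's hyperparameters are chosen by the caller; the two 1x1 layers use fixed settings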
def nin_block(in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU()
    )
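
The next sketch counts the parameters of one NiN block by hand and checks the result against PyTorch. Each convolution has `k*k*c_in*c_out` weights plus `c_out` biases. The helper names `conv_params` and `nin_block_params` are hypothetical and introduced only here. The choice of the second block of the model below (96 -> 256 channels, 5x5 kernel) is also just an example.

```python
from torch import nn

def nin_block(in_channels, out_channels, kernel_size, stride, padding):
    # Same definition as above, repeated so this cell runs on its own
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())

def conv_params(c_in, c_out, k):
    # k*k*c_in weights per output channel, plus one bias per output channel
    return k * k * c_in * c_out + c_out

def nin_block_params(c_in, c_out, k):
    # one k x k convolution followed by two 1x1 convolutions (c_out -> c_out)
    return conv_params(c_in, c_out, k) + 2 * conv_params(c_out, c_out, 1)

blk = nin_block(96, 256, kernel_size=5, stride=1, padding=2)
counted = sum(p.numel() for p in blk.parameters())
print(counted, nin_block_params(96, 256, 5))  # prints: 746240 746240
```

In this example the two `1x1` layers contribute 131,584 of the 746,240 parameters. The rest belong to the 5x5 convolution.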

#### 5.8.2 NiN Model

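> 1. **NiN removes the last three fully connected layers of AlexNet and instead uses a NiN block whose number of output channels equals the number of label classes**
> 2. **A global average pooling layer then averages all elements in each channel, and the resulting values are used directly for classification**
> 3. **The global average pooling layer is an average pooling layer whose window size equals the spatial shape (height and width) of its input. It has no parameters, so it reduces the number of model parameters**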
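
To give a sense of the savings, the sketch below compares two classifier heads. The first is an AlexNet-style fully connected head. Its dimensions (flattened `256*5*5` features -> 4096 -> 4096 -> 10) are an assumption taken from the d2l AlexNet section and do not appear in this section. The second is NiN's final `nin_block(384, 10, kernel_size=3)` plus global average pooling, and the pooling itself has no parameters. The two heads are not an exact like-for-like replacement; the numbers only show the order of magnitude.

```python
def linear_params(n_in, n_out):
    # weight matrix plus bias vector
    return n_in * n_out + n_out

def conv_params(c_in, c_out, k):
    # k*k*c_in weights per output channel, plus one bias per output channel
    return k * k * c_in * c_out + c_out

# Assumed AlexNet-style head: 256*5*5 -> 4096 -> 4096 -> 10
fc_head = (linear_params(256 * 5 * 5, 4096)
           + linear_params(4096, 4096)
           + linear_params(4096, 10))

# NiN head: last NiN block (3x3 conv + two 1x1 convs); global average pooling adds 0
nin_head = conv_params(384, 10, 3) + 2 * conv_params(10, 10, 1)

print(fc_head, nin_head)  # prints: 43040778 34790
```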

In [3]:
class GlobalAvgPool2d(nn.Module):
    # Global average pooling can be implemented by setting the pooling window to the input's height and width
    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()
    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=x.size()[2:])


net = nn.Sequential(
    nin_block(1, 96, kernel_size=11, stride=4, padding=0),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(96, 256, kernel_size=5, stride=1, padding=2),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(256, 384, kernel_size=3, stride=1, padding=1),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Dropout(0.5),
    # There are 10 label classes
    nin_block(384, 10, kernel_size=3, stride=1, padding=1),
    GlobalAvgPool2d(),
    # Convert the 4-D output to a 2-D output of shape (batch size, 10)
    d2l.FlattenLayer()
)
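
`GlobalAvgPool2d` should be equivalent to taking the per-channel mean over height and width. It should also match PyTorch's built-in `nn.AdaptiveAvgPool2d(1)`. The sketch below checks both on a random tensor shaped like the input to the pooling layer (batch of 2, 10 channels, 5x5).

```python
import torch
from torch import nn
import torch.nn.functional as F

class GlobalAvgPool2d(nn.Module):
    # Same layer as above: the pooling window covers the full height and width
    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=x.size()[2:])

x = torch.rand(2, 10, 5, 5)
y = GlobalAvgPool2d()(x)                      # (2, 10, 1, 1)
y_mean = x.mean(dim=(2, 3), keepdim=True)     # per-channel mean over H and W
y_adaptive = nn.AdaptiveAvgPool2d(1)(x)       # built-in alternative

print(y.shape, torch.allclose(y, y_mean), torch.allclose(y, y_adaptive))
```

`nn.AdaptiveAvgPool2d(1)` is often used instead of a custom layer because it does not depend on the input size at construction time.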

In [4]:
X = torch.rand(1, 1, 224, 224)
for name, blk in net.named_children(): 
    X = blk(X)
    print(name, 'output shape: ', X.shape)

0 output shape:  torch.Size([1, 96, 54, 54])
1 output shape:  torch.Size([1, 96, 26, 26])
2 output shape:  torch.Size([1, 256, 26, 26])
3 output shape:  torch.Size([1, 256, 12, 12])
4 output shape:  torch.Size([1, 384, 12, 12])
5 output shape:  torch.Size([1, 384, 5, 5])
6 output shape:  torch.Size([1, 384, 5, 5])
7 output shape:  torch.Size([1, 10, 5, 5])
8 output shape:  torch.Size([1, 10, 1, 1])
9 output shape:  torch.Size([1, 10])
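
The spatial sizes printed above follow from the standard output-size formula for convolution and pooling, `floor((n + 2*padding - kernel_size) / stride) + 1`. The sketch below reproduces them in plain Python. The helper `out_size` is hypothetical. Only layers that change the spatial size are listed, since the `1x1` convolutions and the dropout layer leave it unchanged.

```python
def out_size(n, kernel_size, stride=1, padding=0):
    # floor((n + 2*padding - kernel_size) / stride) + 1
    return (n + 2 * padding - kernel_size) // stride + 1

# (name, kernel_size, stride, padding) of the first layer in each stage
layers = [
    ("nin_block 1", 11, 4, 0),
    ("maxpool",      3, 2, 0),
    ("nin_block 2",  5, 1, 2),
    ("maxpool",      3, 2, 0),
    ("nin_block 3",  3, 1, 1),
    ("maxpool",      3, 2, 0),
    ("nin_block 4",  3, 1, 1),
]

n = 224
sizes = []
for name, k, s, p in layers:
    n = out_size(n, k, s, p)
    sizes.append(n)
    print(name, n)
# Global average pooling then reduces the remaining n x n map to 1 x 1
```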


In [5]:
batch_size = 32
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)

lr, num_epochs = 0.002, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
# Uncomment the line below to train; with 224x224 inputs this is slow without a GPU
# d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)
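
The next sketch shows a single optimization step on one batch, roughly the kind of work a training loop performs per batch. It is not a reproduction of `d2l.train_ch5`, whose internals are not shown here. Several parts are assumptions:

- The loss is `nn.CrossEntropyLoss`.
- The batch is random data standing in for resized Fashion-MNIST images.
- `nn.AdaptiveAvgPool2d(1)` plus `nn.Flatten()` replace `GlobalAvgPool2d` plus `d2l.FlattenLayer`, so the cell does not depend on `d2lzh_pytorch`.

```python
import torch
from torch import nn

torch.manual_seed(0)

def nin_block(in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())

# Same layout as the NiN model above, with built-in pooling/flattening layers
net = nn.Sequential(
    nin_block(1, 96, 11, 4, 0), nn.MaxPool2d(3, 2),
    nin_block(96, 256, 5, 1, 2), nn.MaxPool2d(3, 2),
    nin_block(256, 384, 3, 1, 1), nn.MaxPool2d(3, 2),
    nn.Dropout(0.5),
    nin_block(384, 10, 3, 1, 1),
    nn.AdaptiveAvgPool2d(1), nn.Flatten())

optimizer = torch.optim.Adam(net.parameters(), lr=0.002)
loss_fn = nn.CrossEntropyLoss()

# Hypothetical batch: 4 random 1x224x224 "images" with random labels in [0, 10)
X = torch.rand(4, 1, 224, 224)
y = torch.randint(0, 10, (4,))

net.train()                      # enable dropout
logits = net(X)                  # (4, 10): one score per class
loss = loss_fn(logits, y)
optimizer.zero_grad()            # clear old gradients
loss.backward()                  # compute new gradients
optimizer.step()                 # Adam update of all parameters

print(logits.shape, loss.item())
```

Because the last NiN block ends with a ReLU, the class scores produced by global average pooling are non-negative.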