In [3]:
# DenseNet 与 ResNet的主要区别：一个数输出和输入简单的加和，另一个是输出和输出
# 形成了连接的关系——稠密连接
# DenseNet的主要勾践模块： 稠密块（dense block）和过渡层（transition）
# 前者定义了输入和输出是如何连结的，后者则用来控制通道数，使之不会过大

# 稠密块
import time
import torch
from torch import nn,optim
import torch.nn.functional as F

import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def conv_block(in_channels,out_channels):
    blk = nn.Sequential(
        nn.BatchNorm2d(in_channels),
        nn.ReLU(),
        nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1)
    )
    return blk

# 稠密块
class DenseBlock(nn.Module):
    def __init__(self,num_convs,in_channels,out_channels):
        super(DenseBlock,self).__init__()
        net = []
        for i in range(num_convs):
            in_c = in_channels + i*out_channels
            net.append(conv_block(in_c,out_channels))
        self.net = nn.ModuleList(net)
        self.out_channels = in_channels + num_convs * out_channels
        
    def forward(self,x):
        for blk in self.net:
            Y = blk(x)
            x = torch.cat((x,Y),dim=1)# 在通道维上将输入和输出连结
        return x
blk = DenseBlock(2,3,10)
x = torch.rand(4,3,8,8)
y = blk(x)
y.shape

torch.Size([4, 23, 8, 8])

In [6]:
# 由于每个稠密块都会带来通道数的增加，使用过多则会带来过于复杂的模型，
# 过渡层用来控制模型的复杂度，通过1x1的卷积层来减小通道数，并使用步长为2的平均池化层
# 来减半高和宽，从而进一步降低模型复杂度

def transition_block(in_channels,out_channels):
    blk = nn.Sequential(
        nn.BatchNorm2d(in_channels),
        nn.ReLU(),
        nn.Conv2d(in_channels,out_channels,kernel_size=1),
        nn.AvgPool2d(kernel_size=2,stride=2)
    )
    return blk

blk = transition_block(23,10)
blk(y).shape

torch.Size([4, 10, 4, 4])

In [7]:
# DenseNet模型
# DenseNet首先使用同ResNet一样的单卷积层和最大化层
net = nn.Sequential(
    nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
)

num_channels,growth_rate = 64,32
num_convc_in_dense_blocks = [4,4,4,4]
for i,num_convs in enumerate(num_convc_in_dense_blocks):
    DB = DenseBlock(num_convs,num_channels,growth_rate)
    net.add_module("DenseBlock_%d"%i,DB)
    num_channels = DB.out_channels
    if i!= len(num_convc_in_dense_blocks) - 1:
        net.add_module("transition_block_%d"%i,transition_block(num_channels,num_channels//2))
        net_channels = num_channels//2
        
net.add_module("BN", nn.BatchNorm2d(num_channels))
net.add_module("relu", nn.ReLU())
net.add_module("global_avg_pool", d2l.GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, num_channels, 1, 1)
net.add_module("fc", nn.Sequential(d2l.FlattenLayer(), nn.Linear(num_channels, 10))) 

X = torch.rand((1, 1, 96, 96))
for name, layer in net.named_children():
    X = layer(X)
    print(name, ' output shape:\t', X.shape)

0  output shape:	 torch.Size([1, 64, 48, 48])
1  output shape:	 torch.Size([1, 64, 48, 48])
2  output shape:	 torch.Size([1, 64, 48, 48])
3  output shape:	 torch.Size([1, 64, 24, 24])
DenseBlock_0  output shape:	 torch.Size([1, 192, 24, 24])
transition_block_0  output shape:	 torch.Size([1, 96, 12, 12])


RuntimeError: running_mean should contain 96 elements not 192