In [9]:
import torch
from torch import nn
from torch.utils import data
from torchvision import datasets,transforms
import numpy as np
import matplotlib.pyplot as plt
import matplotlib_inline
from IPython import display
from torch.functional import F
import time

In [10]:
def conv_block(input_channels, output_channels):
    """Build one BatchNorm -> ReLU -> 3x3 convolution unit.

    Args:
        input_channels: Number of channels of the incoming feature map.
        output_channels: Number of channels the convolution produces.

    Returns:
        An ``nn.Sequential`` holding BatchNorm2d, ReLU and a Conv2d with
        ``kernel_size=3, padding=1``.
    """
    norm = nn.BatchNorm2d(input_channels)
    activation = nn.ReLU()
    # Kernel 3 with padding 1 (default stride 1) keeps height and width unchanged.
    conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1)
    return nn.Sequential(norm, activation, conv)
class DenseBlock(nn.Module):
    """A chain of conv blocks whose outputs are concatenated channel-wise.

    Conv block ``i`` receives the block input concatenated with the outputs
    of all earlier conv blocks, i.e. ``input_channels + i * num_channels``
    channels, and emits ``num_channels`` new channels.
    """

    def __init__(self, num_convs, input_channels, num_channels):
        """Create the conv blocks.

        Args:
            num_convs: How many conv blocks to chain.
            input_channels: Channels of the tensor fed to the dense block.
            num_channels: Channels added by every conv block.
        """
        super().__init__()
        stages = [
            conv_block(input_channels + idx * num_channels, num_channels)
            for idx in range(num_convs)
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, X):
        """Return ``X`` concatenated along dim 1 with every conv block output."""
        features = X
        for stage in self.net:
            new_maps = stage(features)
            # Stack on the channel axis so later stages see everything so far.
            features = torch.cat((features, new_maps), dim=1)
        return features

In [11]:
# Smoke test: 2 conv blocks, each adding 10 channels to a 3-channel input,
# should yield 3 + 2 * 10 = 23 output channels with the 8x8 size unchanged.
blk = DenseBlock(num_convs=2, input_channels=3, num_channels=10)
X = torch.randn(size=(4, 3, 8, 8))
Y = blk(X)
Y.shape

torch.Size([4, 23, 8, 8])

In [12]:
def transition_block(input_channels, num_channels):
    """Build the layer that shrinks a feature map in channels and resolution.

    Args:
        input_channels: Channels of the incoming feature map.
        num_channels: Channels produced by the 1x1 convolution.

    Returns:
        An ``nn.Sequential`` of BatchNorm2d, ReLU, a 1x1 Conv2d and a
        2x2 average pool with stride 2.
    """
    stages = [
        nn.BatchNorm2d(input_channels),
        nn.ReLU(),
        # A 1x1 convolution changes only the channel count.
        nn.Conv2d(input_channels, num_channels, kernel_size=1),
        # Stride-2 average pooling halves height and width (rounding down).
        nn.AvgPool2d(kernel_size=2, stride=2),
    ]
    return nn.Sequential(*stages)

In [13]:
# Smoke test: the transition block should cut the 23 channels of Y down to 10
# and halve its 8x8 spatial size to 4x4.
blk = transition_block(input_channels=23, num_channels=10)
blk(Y).shape

torch.Size([4, 10, 4, 4])

In [14]:
# Stem of the network: takes single-channel input and produces 64 channels.
# The stride-2 convolution and the stride-2 max pool each halve the spatial
# size, so the feature map ends up at a quarter of the input resolution.
b1 = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(num_features=64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)

In [15]:
# `num_channels` tracks the running channel count, starting at 64;
# every conv block inside a dense block adds `growth_rate` channels.
num_channels, growth_rate = 64, 32
num_convs_in_dense_blocks = [4, 4, 4, 4]
blks = []
for i, num_convs in enumerate(num_convs_in_dense_blocks):
    blks.append(DenseBlock(num_convs, num_channels, growth_rate))
    # Channel count leaving the dense block that was just appended.
    num_channels += num_convs * growth_rate
    # Insert a transition block between dense blocks (none after the last one);
    # it halves the channel count.
    if i < len(num_convs_in_dense_blocks) - 1:
        blks.append(transition_block(num_channels, num_channels // 2))
        num_channels //= 2

In [16]:
# Full model: stem, dense/transition blocks, then a classification head that
# pools each channel down to 1x1, flattens, and maps to 10 outputs.
# NOTE(review): the head uses global *max* pooling; the reference DenseNet
# design uses global average pooling (nn.AdaptiveAvgPool2d) — confirm this
# deviation is intentional.
net = nn.Sequential(
    b1,
    *blks,
    nn.BatchNorm2d(num_features=num_channels),
    nn.ReLU(),
    nn.AdaptiveMaxPool2d(output_size=(1, 1)),
    nn.Flatten(),
    nn.Linear(in_features=num_channels, out_features=10),
)

In [17]:
# Push one dummy 1x96x96 input through the top-level layers one at a time and
# report the shape after each, to verify the channel/resolution bookkeeping.
X = torch.rand(size=(1, 1, 96, 96))
for layer in net:
    X = layer(X)
    print(type(layer).__name__, 'output shape:\t', X.shape)

Sequential output shape:	 torch.Size([1, 64, 24, 24])
DenseBlock output shape:	 torch.Size([1, 192, 24, 24])
Sequential output shape:	 torch.Size([1, 96, 12, 12])
DenseBlock output shape:	 torch.Size([1, 224, 12, 12])
Sequential output shape:	 torch.Size([1, 112, 6, 6])
DenseBlock output shape:	 torch.Size([1, 240, 6, 6])
Sequential output shape:	 torch.Size([1, 120, 3, 3])
DenseBlock output shape:	 torch.Size([1, 248, 3, 3])
BatchNorm2d output shape:	 torch.Size([1, 248, 3, 3])
ReLU output shape:	 torch.Size([1, 248, 3, 3])
AdaptiveMaxPool2d output shape:	 torch.Size([1, 248, 1, 1])
Flatten output shape:	 torch.Size([1, 248])
Linear output shape:	 torch.Size([1, 10])
