In [1]:
from mxnet import nd
from mxnet.gluon import nn

In [4]:
def conv_block(channels):
    """Build a pre-activation conv unit: BatchNorm -> ReLU -> 3x3 Conv2D.

    With a 3x3 kernel and padding=1 the convolution keeps the input height and
    width unchanged (the 8x8 input of the shape check below stays 8x8).

    Args:
        channels: number of output channels of the convolution.

    Returns:
        An ``nn.Sequential`` containing the three layers, in order.
    """
    unit = nn.Sequential()
    for layer in (nn.BatchNorm(),
                  nn.Activation('relu'),
                  nn.Conv2D(channels, kernel_size=3, padding=1)):
        unit.add(layer)
    return unit

class DenseBlock(nn.Block):
    """DenseNet dense block: ``layers`` conv units with dense connectivity.

    Each unit is applied to the running feature map, and its output is
    concatenated onto that map along dim=1. Every unit therefore sees the
    block input plus all earlier units' outputs. Each unit adds
    ``growth_rate`` channels, so the output has
    ``in_channels + layers * growth_rate`` channels (e.g. 3 + 2 * 10 = 23).

    Args:
        layers: number of ``conv_block`` units in the block.
        growth_rate: output channels produced by each unit.
        **kwargs: forwarded to ``nn.Block.__init__``.
    """

    def __init__(self, layers, growth_rate, **kwargs):
        super(DenseBlock, self).__init__(**kwargs)
        # Create children inside name_scope so their parameters carry this
        # block's prefix, consistent with how dense_net() builds its network.
        with self.name_scope():
            self.net = nn.Sequential()
            for _ in range(layers):
                self.net.add(conv_block(growth_rate))

    def forward(self, x):
        """Run each unit and append its output to the running feature map."""
        for layer in self.net:
            out = layer(x)
            # dim=1 is the channel axis here (shape check: 3 -> 23 channels).
            x = nd.concat(x, out, dim=1)
        return x

In [5]:
# Sanity-check a small dense block: 2 conv units, growth rate 10.
dblk = DenseBlock(layers=2, growth_rate=10)
dblk.initialize()

# Input: batch of 4, 3 channels, 8x8 spatial size.
X = nd.random.uniform(shape=(4, 3, 8, 8))
# Output channels = in_channels + growth_rate * layers = 3 + 10 * 2 = 23.
dblk(X).shape

(4L, 23L, 8L, 8L)

In [6]:
def transition_block(channels):
    """Build a DenseNet transition block: BatchNorm -> ReLU -> 1x1 Conv2D -> AvgPool.

    Because dense blocks concatenate their outputs, the channel count can grow
    quickly. To keep model complexity under control, a transition block is
    placed between dense blocks. Its 1x1 convolution sets the number of
    channels, and its 2x2 stride-2 average pooling halves the height and width
    (the 8x8 input of the shape check below becomes 4x4).

    Args:
        channels: number of output channels of the 1x1 convolution.

    Returns:
        An ``nn.Sequential`` containing the four layers, in order.
    """
    block = nn.Sequential()
    # Pre-activation, matching conv_block.
    block.add(nn.BatchNorm())
    block.add(nn.Activation('relu'))
    # 1x1 convolution: changes the channel count only.
    block.add(nn.Conv2D(channels, kernel_size=1))
    # Downsample height and width by 2.
    block.add(nn.AvgPool2D(pool_size=2, strides=2))
    return block

In [8]:
tblk = transition_block(10)
tblk.initialize()

tblk(X).shape

(4L, 10L, 4L, 4L)

In [9]:
# The body of DenseNet alternates DenseBlocks and transition blocks.
init_channels = 64              # output channels of the stem convolution
growth_rate = 32                # channels each conv unit in a DenseBlock adds
block_layers = [6, 12, 24, 16]  # number of conv units in each DenseBlock
num_classes = 10

def dense_net():
    """Build a DenseNet classifier as an ``nn.Sequential``.

    Layout:
      * stem: 7x7 stride-2 conv, BatchNorm, ReLU, 3x3 stride-2 max-pool;
      * one DenseBlock per entry of ``block_layers``;
      * a channel-halving transition block between consecutive DenseBlocks;
      * head: BatchNorm, ReLU, global average pooling, Flatten, Dense.

    Reads the module-level ``init_channels``, ``growth_rate``,
    ``block_layers`` and ``num_classes`` when called.

    Returns:
        The uninitialized network.
    """
    net = nn.Sequential()
    with net.name_scope():
        # First block (stem).
        net.add(
            nn.Conv2D(init_channels, kernel_size=7, strides=2, padding=3),
            nn.BatchNorm(),
            nn.Activation('relu'),
            nn.MaxPool2D(pool_size=3, strides=2, padding=1)
        )
        # Dense blocks. `channels` tracks the real number of channels flowing
        # through the network, so each transition can halve it.
        channels = init_channels
        for i, layers in enumerate(block_layers):
            net.add(DenseBlock(layers=layers, growth_rate=growth_rate))
            channels += growth_rate * layers
            if i != len(block_layers) - 1:
                # The transition outputs channels // 2. Update the running count
                # so later transitions are sized from the actual input channels.
                channels //= 2
                net.add(transition_block(channels))
        # Last block. Global average pooling reduces whatever spatial size
        # remains to 1x1, so the classifier input does not depend on the
        # input resolution.
        net.add(
            nn.BatchNorm(),
            nn.Activation('relu'),
            nn.GlobalAvgPool2D(),
            nn.Flatten(),
            nn.Dense(num_classes)
        )
    return net

In [10]:
import sys
sys.path.append('..')
import utils
from mxnet import gluon
from mxnet import init

In [11]:
# Training hyperparameters.
batch_size = 64
image_size = 32       # passed to utils as `resize`
learning_rate = 0.1
num_epochs = 1

train_data, test_data = utils.load_data_fashion_mnist(
    batch_size=batch_size, resize=image_size)

# NOTE(review): the saved run fell back to cpu(0) and was interrupted;
# training this network on CPU is slow.
ctx = utils.try_gpu()

net = dense_net()
net.initialize(ctx=ctx, init=init.Xavier())

loss = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'learning_rate': learning_rate})
utils.train(train_data, test_data, net, loss, trainer, ctx,
            num_epochs=num_epochs)

('Start training on ', cpu(0))


KeyboardInterrupt: 