# 9. Densely Connected Networks (DenseNet)

```python
import torch
import torch.nn as nn

import numpy as np
```

## From ResNet to DenseNet

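ResNet decomposes the desired mapping $f$ into a simple linear term and a nonlinear term, much like the leading terms of a **Taylor expansion**:

$$f(\mathbf{x}) = \mathbf{x} + g(\mathbf{x})$$

**DenseNet** extends this decomposition to more than two terms, so that each layer can draw on the information produced by all the layers before it rather than on a single summed signal: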

![](http://d2l.ai/_images/densenet-block.svg)

As shown above, the key difference from ResNet is that DenseNet **concatenates** the outputs (denoted by $[\cdot, \cdot]$) instead of **adding** them.
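A minimal PyTorch sketch of the difference, with random tensors standing in for $\mathbf{x}$ and $g(\mathbf{x})$ (the shapes are chosen only for illustration):

```python
import torch

# Toy feature maps: batch of 2, 4 channels, 3x3 spatial (illustrative shapes)
x = torch.randn(2, 4, 3, 3)
gx = torch.randn(2, 4, 3, 3)  # stand-in for the output of some block g(x)

# ResNet-style: element-wise addition, the channel count is unchanged
res_out = x + gx
# DenseNet-style: concatenation along the channel dimension (dim=1), channels stack up
dense_out = torch.cat((x, gx), dim=1)

print(res_out.shape)    # torch.Size([2, 4, 3, 3])
print(dense_out.shape)  # torch.Size([2, 8, 3, 3])
```

Addition requires $\mathbf{x}$ and $g(\mathbf{x})$ to have identical shapes, whereas concatenation only requires them to agree in every dimension except the channel one, and it keeps both inputs intact side by side.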

Because the outputs are concatenated, DenseNet maps $\mathbf{x}$ to its values after applying an **increasingly complex sequence** of functions:

$$\mathbf{x} \to \left[
\mathbf{x},
f_1(\mathbf{x}),
f_2\left(\left[\mathbf{x}, f_1\left(\mathbf{x}\right)\right]\right), f_3\left(\left[\mathbf{x}, f_1\left(\mathbf{x}\right), f_2\left(\left[\mathbf{x}, f_1\left(\mathbf{x}\right)\right]\right)\right]\right), \ldots\right]$$

In the end, all these functions are combined by an **MLP** to reduce the number of features again. The last layer of such a chain is **densely connected** to all the previous layers, which is where the name DenseNet comes from:

![](http://d2l.ai/_images/densenet.svg)
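The chain in the equation above can be sketched on plain feature vectors before moving to convolutions. This is a minimal illustration under stated assumptions, not DenseNet itself: each $f_i$ is a hypothetical `nn.Linear` followed by `nn.ReLU` that emits `k` new features, and a single `nn.Linear` called `head` stands in for the final MLP that reduces the number of features.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d, k, num_funcs, out_features = 5, 3, 3, 4  # illustrative sizes (assumptions)

# The (i+1)-th function sees [x, f_1(x), ..., f_i(...)], i.e. d + i * k features
fs = nn.ModuleList([nn.Sequential(nn.Linear(d + i * k, k), nn.ReLU())
                    for i in range(num_funcs)])
# Hypothetical final layer that combines all terms and reduces the feature count
head = nn.Linear(d + num_funcs * k, out_features)

x = torch.randn(8, d)
features = [x]
for f in fs:
    # Each function receives every earlier term, concatenated along the feature dim
    features.append(f(torch.cat(features, dim=1)))

z = torch.cat(features, dim=1)  # [x, f_1(x), f_2([x, f_1(x)]), f_3(...)]
y = head(z)
print(z.shape, y.shape)  # torch.Size([8, 14]) torch.Size([8, 4])
```

The feature count of `z` grows linearly, $d + 3k = 14$ here, and only the final layer shrinks it back down.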

## Dense Block

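DenseNet uses the modified "batch normalization, activation, and convolution" ordering of ResNet, so we first implement a **convolutional block** with that structure: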

```python
def conv_block(input_channels, num_channels):
    # BN -> ReLU -> 3x3 conv; padding=1 keeps height and width unchanged
    return nn.Sequential(nn.BatchNorm2d(input_channels),
                         nn.ReLU(),
                         nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1))
```
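A quick self-contained check of `conv_block` (redefined here so the snippet runs on its own; the input shape is illustrative): it changes the number of channels but, thanks to `padding=1`, not the spatial size, and its parameter count can be worked out by hand.

```python
import torch
import torch.nn as nn

def conv_block(input_channels, num_channels):
    # BN -> ReLU -> 3x3 conv; padding=1 keeps height and width unchanged
    return nn.Sequential(nn.BatchNorm2d(input_channels),
                         nn.ReLU(),
                         nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1))

blk = conv_block(3, 10)
X = torch.randn(4, 3, 8, 8)
Y = blk(X)
print(Y.shape)  # torch.Size([4, 10, 8, 8])

# BN: 2 * 3 (scale and shift); conv: 3 * 10 * 3 * 3 weights + 10 biases
num_params = sum(p.numel() for p in blk.parameters())
print(num_params)  # 6 + 270 + 10 = 286
```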

A **dense block** consists of multiple convolutional blocks, each with the same number of **output channels**. 

In forward propagation, we **concatenate** the input and output of each convolutional block along the channel dimension:

```python
class DenseBlock(nn.Module):
    def __init__(self, num_convs, input_channels, num_channels):
        super().__init__()
        layer = []
        for i in range(num_convs):
            # The i-th conv block sees the original input plus i earlier outputs
            layer.append(conv_block(num_channels * i + input_channels, num_channels))
        self.net = nn.Sequential(*layer)

    def forward(self, X):
        for blk in self.net:
            Y = blk(X)
            # Concatenate the input and output of each block on the channel dimension
            X = torch.cat((X, Y), dim=1)
        return X
```
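To see the dense connectivity inside a block, the sketch below (a self-contained, slightly condensed copy of the definitions above; the 3-block configuration and input shape are illustrative) inspects how many input channels each convolution expects and confirms that the original input passes through unchanged as the first channels of the output.

```python
import torch
import torch.nn as nn

def conv_block(input_channels, num_channels):
    return nn.Sequential(nn.BatchNorm2d(input_channels), nn.ReLU(),
                         nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1))

class DenseBlock(nn.Module):
    def __init__(self, num_convs, input_channels, num_channels):
        super().__init__()
        self.net = nn.Sequential(*[conv_block(num_channels * i + input_channels, num_channels)
                                   for i in range(num_convs)])

    def forward(self, X):
        for blk in self.net:
            X = torch.cat((X, blk(X)), dim=1)  # concatenate on the channel dimension
        return X

blk = DenseBlock(num_convs=3, input_channels=3, num_channels=10)

# Index 2 of each conv_block is its Conv2d layer
in_chs = [b[2].in_channels for b in blk.net]
print(in_chs)  # [3, 13, 23]

X = torch.randn(2, 3, 5, 5)
Y = blk(X)
print(Y.shape)  # torch.Size([2, 33, 5, 5])
# Concatenation keeps the original input as the first 3 channels
print(torch.equal(Y[:, :3], X))  # True
```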

In the following example, we define a `DenseBlock` instance with **2 convolutional blocks** of **10 output channels** each.

With an **input of 3 channels**, we get an **output of $3+2\times10=23$ channels**. The number of output channels of each convolutional block controls how fast the channel count grows relative to the number of input channels, which is why it is called the **growth rate**.

```python
blk = DenseBlock(num_convs=2, input_channels=3, num_channels=10)
X = torch.randn(4, 3, 8, 8)
Y = blk(X)
Y.shape
```

```
torch.Size([4, 23, 8, 8])
```
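The growth-rate arithmetic can be written out directly. `dense_block_out_channels` is a hypothetical helper introduced only for illustration; it encodes the fact that each of the `num_convs` convolutional blocks appends `growth_rate` channels.

```python
def dense_block_out_channels(input_channels, num_convs, growth_rate):
    # Each of the num_convs conv blocks appends growth_rate channels to its input
    return input_channels + num_convs * growth_rate

# The example above: 3 input channels, 2 conv blocks, growth rate 10
print(dense_block_out_channels(3, 2, 10))  # 23

# The channel count grows linearly as more conv blocks are stacked
print([dense_block_out_channels(3, n, 10) for n in range(5)])  # [3, 13, 23, 33, 43]
```

This linear growth is exactly what the transition layers below keep in check.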

## Transition Layer

Each dense block increases the number of channels, so stacking too many of them would make the model **overly complex**.

**Transition layers** keep this in check: they **reduce the number of channels** with a $1\times1$ convolutional layer and **halve the height and width** with an average-pooling layer of stride 2.

```python
def transition_block(input_channels, num_channels):
    return nn.Sequential(nn.BatchNorm2d(input_channels),
                         nn.ReLU(),
                         nn.Conv2d(input_channels, num_channels, kernel_size=1),
                         nn.AvgPool2d(kernel_size=2, stride=2))
```

```python
blk = transition_block(23, 10)
blk(Y).shape
```

```
torch.Size([4, 10, 4, 4])
```
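"Halving" the height and width is exact only for even sizes: with `kernel_size=2, stride=2` and PyTorch's default `ceil_mode=False`, average pooling floors odd sizes. The self-contained sketch below (the odd 7×7 input is an illustrative choice) checks both cases and counts the parameters of the cheap $1\times1$ convolution.

```python
import torch
import torch.nn as nn

def transition_block(input_channels, num_channels):
    return nn.Sequential(nn.BatchNorm2d(input_channels),
                         nn.ReLU(),
                         nn.Conv2d(input_channels, num_channels, kernel_size=1),
                         nn.AvgPool2d(kernel_size=2, stride=2))

blk = transition_block(23, 10)

out_shapes = {}
for size in (8, 7):
    # The 1x1 conv keeps the spatial size; pooling with kernel 2, stride 2 floors H/2 and W/2
    out_shapes[size] = tuple(blk(torch.randn(4, 23, size, size)).shape)
print(out_shapes)  # {8: (4, 10, 4, 4), 7: (4, 10, 3, 3)}

# Parameters of the 1x1 conv: 23 * 10 weights + 10 biases
conv_params = sum(p.numel() for p in blk[2].parameters())
print(conv_params)  # 240
```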

## DenseNet Model

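DenseNet first uses the same single convolutional layer and max-pooling layer as ResNet, here for single-channel input:

```python
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64),
                   nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
```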
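A layer-by-layer shape trace of `b1`, assuming a single-channel 96×96 input (an illustrative size, not one fixed by the code above). The stride-2 convolution and the stride-2 max-pooling each halve the height and width, so 96 → 48 → 24.

```python
import torch
import torch.nn as nn

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64),
                   nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

# Assumed input: one single-channel 96x96 image
X = torch.randn(1, 1, 96, 96)
shapes = []
for layer in b1:
    X = layer(X)
    shapes.append((layer.__class__.__name__, tuple(X.shape)))
    print(layer.__class__.__name__, tuple(X.shape))
```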

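Next, analogous to the four residual modules of ResNet, DenseNet uses four dense blocks. Each dense block here has 4 convolutional blocks and a growth rate of 32, so it adds $4\times32=128$ channels. Between consecutive dense blocks, a transition layer halves the number of channels:

```python
num_channels, growth_rate = 64, 32
num_convs_in_dense_blocks = [4, 4, 4, 4]

blks = []
for i, num_convs in enumerate(num_convs_in_dense_blocks):
    blks.append(DenseBlock(num_convs, num_channels, growth_rate))
    # Number of output channels of the dense block just added
    num_channels += num_convs * growth_rate

    # Add a transition layer that halves the channels between dense blocks
    if i != len(num_convs_in_dense_blocks) - 1:
        blks.append(transition_block(num_channels, num_channels // 2))
        num_channels = num_channels // 2
```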
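The channel bookkeeping in this loop can be traced without building any layers. This sketch replays the same arithmetic in plain Python and records the channel count after every dense block and transition layer.

```python
num_channels, growth_rate = 64, 32
num_convs_in_dense_blocks = [4, 4, 4, 4]

trace = []
for i, num_convs in enumerate(num_convs_in_dense_blocks):
    num_channels += num_convs * growth_rate  # dense block adds num_convs * growth_rate
    trace.append(("dense", num_channels))
    if i != len(num_convs_in_dense_blocks) - 1:
        num_channels //= 2  # transition layer halves the channels
        trace.append(("transition", num_channels))

print(trace)
# [('dense', 192), ('transition', 96), ('dense', 224), ('transition', 112),
#  ('dense', 240), ('transition', 120), ('dense', 248)]
```

The final value, 248, is the number of channels that the classification head receives.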

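Finally, as in ResNet, a global average-pooling layer and a fully connected layer produce the output, here for 10 classes:

```python
net = nn.Sequential(b1, *blks,
                    nn.BatchNorm2d(num_channels),
                    nn.ReLU(),
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(),
                    nn.Linear(num_channels, 10))
```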
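An end-to-end shape check of the whole model, as a self-contained sketch that condenses the definitions above. It assumes a single-channel 96×96 input (an illustrative size); `AdaptiveAvgPool2d((1, 1))` would also accept other input sizes.

```python
import torch
import torch.nn as nn

def conv_block(input_channels, num_channels):
    # BN -> ReLU -> 3x3 conv (spatial size preserved)
    return nn.Sequential(nn.BatchNorm2d(input_channels), nn.ReLU(),
                         nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1))

class DenseBlock(nn.Module):
    def __init__(self, num_convs, input_channels, num_channels):
        super().__init__()
        self.net = nn.Sequential(*[conv_block(num_channels * i + input_channels, num_channels)
                                   for i in range(num_convs)])

    def forward(self, X):
        for blk in self.net:
            X = torch.cat((X, blk(X)), dim=1)  # concatenate on the channel dimension
        return X

def transition_block(input_channels, num_channels):
    # 1x1 conv reduces channels; stride-2 average pooling halves height and width
    return nn.Sequential(nn.BatchNorm2d(input_channels), nn.ReLU(),
                         nn.Conv2d(input_channels, num_channels, kernel_size=1),
                         nn.AvgPool2d(kernel_size=2, stride=2))

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

num_channels, growth_rate = 64, 32
blks = []
for i, num_convs in enumerate([4, 4, 4, 4]):
    blks.append(DenseBlock(num_convs, num_channels, growth_rate))
    num_channels += num_convs * growth_rate
    if i != 3:  # no transition layer after the last dense block
        blks.append(transition_block(num_channels, num_channels // 2))
        num_channels //= 2

net = nn.Sequential(b1, *blks,
                    nn.BatchNorm2d(num_channels), nn.ReLU(),
                    nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),
                    nn.Linear(num_channels, 10))

# Assumed input: one single-channel 96x96 image
X = torch.randn(1, 1, 96, 96)
shapes = []
for layer in net:
    X = layer(X)
    shapes.append((layer.__class__.__name__, tuple(X.shape)))
    print(f"{layer.__class__.__name__:>18} output shape: {tuple(X.shape)}")
```

With this input, the spatial size shrinks 24 → 12 → 6 → 3 across the transition layers while the channel count follows 192, 96, 224, 112, 240, 120, 248, and the head maps the 248 pooled features to 10 outputs.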