# Densely Connected Networks (DenseNet)

The major difference from ResNet is that, within a dense block, each layer is connected to all of the layers preceding it. Outputs are concatenated along the channel dimension instead of added.

## Dense Block

In [3]:
import torch
from torch import nn
from d2l import torch as d2l

In [4]:
def conv_block(num_channels):
    """Return a batch norm -> ReLU -> 3x3 convolution unit.

    Args:
        num_channels: number of output channels of the convolution.

    Returns:
        An ``nn.Sequential`` whose input channel count is inferred lazily on
        the first forward pass. A 3x3 kernel with padding 1 keeps the height
        and width of the input unchanged.
    """
    layers = [
        nn.LazyBatchNorm2d(),
        nn.ReLU(),
        nn.LazyConv2d(num_channels, kernel_size=3, padding=1),
    ]
    return nn.Sequential(*layers)

In [5]:
class DenseBlock(nn.Module):
    """A stack of conv blocks whose outputs are concatenated onto the input.

    Each conv block receives the channel-wise concatenation of the block
    input and every earlier conv block's output, and contributes
    ``num_channels`` new channels. The result therefore has
    ``in_channels + num_convs * num_channels`` channels. Because
    ``conv_block`` preserves height and width, the spatial size is unchanged.
    """

    def __init__(self, num_convs, num_channels):
        super().__init__()
        self.net = nn.Sequential(
            *(conv_block(num_channels) for _ in range(num_convs))
        )

    def forward(self, X):
        features = X
        for conv in self.net:
            # Append this block's new channels to everything seen so far.
            features = torch.cat((features, conv(features)), dim=1)
        return features

In [6]:
# Dense block with 2 conv blocks of 10 channels each, applied to a batch of
# 4 random 3-channel 8x8 inputs. Each conv block concatenates 10 new
# channels, so Y should have 3 + 2 * 10 = 23 channels (see the output below).
blk = DenseBlock(2, 10)
X = torch.randn(4, 3, 8, 8)
Y = blk(X)
Y.shape



torch.Size([4, 23, 8, 8])

## Transition Layers

Since each dense block increases the number of channels, stacking too many of them leads to an overly complex model. Transition layers control this complexity: a 1x1 convolution reduces the number of channels, and average pooling with stride 2 halves the height and width.

In [8]:
def transition_block(num_channels):
    """Return a transition layer placed between dense blocks.

    Args:
        num_channels: number of output channels of the 1x1 convolution.

    Returns:
        An ``nn.Sequential`` that applies batch norm and ReLU, then a 1x1
        convolution to set the channel count to ``num_channels``, and
        finally a 2x2 average pool with stride 2 that halves height and width.
    """
    normalize = (nn.LazyBatchNorm2d(), nn.ReLU())
    squeeze = nn.LazyConv2d(num_channels, kernel_size=1)
    downsample = nn.AvgPool2d(kernel_size=2, stride=2)
    return nn.Sequential(*normalize, squeeze, downsample)

In [9]:
# Transition layer to 10 channels applied to the dense-block output Y from
# the previous cell: the 1x1 conv sets 23 -> 10 channels and the stride-2
# average pool halves 8x8 to 4x4.
blk = transition_block(10)
blk(Y).shape

torch.Size([4, 10, 4, 4])

## DenseNet

In [10]:
class DenseNet(d2l.Classifier):
    """DenseNet classifier built on ``d2l.Classifier``.

    ``self.net`` consists of a convolutional stem (``b1``), then one dense
    block per entry of ``arch`` with a transition layer between consecutive
    dense blocks. It ends with a head of batch norm, ReLU, global average
    pooling, flatten and a linear layer producing ``num_classes`` outputs.
    """

    def b1(self):
        """Return the stem: 7x7 stride-2 conv (64 channels), BN, ReLU, 3x3 stride-2 max pool."""
        return nn.Sequential(
            # ``stride`` was previously misspelled ``stide``; nn.LazyConv2d
            # rejects unknown keywords with a TypeError, so DenseNet could
            # not be constructed.
            nn.LazyConv2d(out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.LazyBatchNorm2d(), nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

    def __init__(self, num_channels=64, growth_rate=32, arch=(4, 4, 4, 4),
                 lr=0.1, num_classes=10):
        """Build ``self.net``.

        Args:
            num_channels: channel count assumed for the stem's output when
                sizing the transition layers. ``b1`` always outputs 64
                channels, so a different value makes the computed transition
                widths disagree with the actual channel counts.
            growth_rate: channels contributed by each conv block inside a
                dense block; a dense block with ``n`` convs adds
                ``n * growth_rate`` channels.
            arch: number of conv blocks in each dense block.
            lr: learning-rate hyperparameter. It is not used directly in this
                method (presumably read by ``d2l.Classifier`` after
                ``save_hyperparameters()``; confirm).
            num_classes: number of outputs of the final linear layer.
        """
        super().__init__()
        self.save_hyperparameters()
        self.net = nn.Sequential(self.b1())

        for i, num_convs in enumerate(arch):
            self.net.add_module(
                f"dense_blk{i + 1}",
                DenseBlock(num_convs, growth_rate)
            )

            # Each of the num_convs conv blocks concatenates growth_rate
            # new channels onto its input.
            num_channels += num_convs * growth_rate

            # Between dense blocks (not after the last one), add a transition
            # layer that halves the channel count and the height/width.
            if i != len(arch) - 1:
                num_channels //= 2
                self.net.add_module(
                    f"tran_blk{i + 1}",
                    transition_block(num_channels)
                )

        # Classification head: normalize, pool each channel to 1x1, flatten,
        # and map to class scores.
        self.net.add_module(
            'last',
            nn.Sequential(
                nn.LazyBatchNorm2d(),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(),
                nn.LazyLinear(num_classes)
            )
        )