In [1]:
import torch
import torch.nn as nn

In [8]:
class Bottleneck(nn.Module):
    """DenseNet bottleneck layer: BN-ReLU-Conv(1x1) then BN-ReLU-Conv(3x3).

    The 1x1 convolution maps the input to ``4 * k`` channels and the 3x3
    convolution reduces that to ``k`` new feature maps, which are concatenated
    to the input along the channel dimension.

    Args:
        in_channels: Number of channels of the incoming feature map.
        k: Growth rate, i.e. the number of channels this layer appends.
    """

    def __init__(self, in_channels, k):
        super().__init__()

        bottleneck_width = 4 * k
        stages = [
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, bottleneck_width, kernel_size=1, bias=False),
            nn.BatchNorm2d(bottleneck_width),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck_width, k, kernel_size=3, padding=1, bias=False),
        ]
        self.residual = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` with the ``k`` newly computed feature maps appended on dim 1."""
        # x already carries the feature maps of every earlier layer in the
        # block, not only those of the immediately preceding one.
        new_features = self.residual(x)
        return torch.cat((x, new_features), dim=1)

class Transition(nn.Module):
    """Transition layer: BN-ReLU-Conv(1x1) followed by 2x2 average pooling.

    The 1x1 convolution changes the channel count from ``in_channels`` to
    ``out_channels``; the pooling halves the spatial resolution.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        norm = nn.BatchNorm2d(in_channels)
        activation = nn.ReLU(inplace=True)
        projection = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        downsample = nn.AvgPool2d(2)
        self.transition = nn.Sequential(norm, activation, projection, downsample)

    def forward(self, x):
        """Apply the normalise / activate / project / downsample pipeline to ``x``."""
        downsampled = self.transition(x)
        return downsampled

# DenseNet-BC
# B stands for the bottleneck layer: BN-ReLU-Conv(1x1)-BN-ReLU-Conv(3x3)
# C stands for the compression factor theta (0 < theta ≤ 1)
class DenseNet(nn.Module):
    """DenseNet classifier built from bottleneck layers and compressing transitions.

    Layout: a stem (7x7 stride-2 convolution on 3 input channels, BN, ReLU,
    3x3 stride-2 max pooling), dense blocks separated by ``Transition``
    layers, a closing BN-ReLU, global average pooling and a linear classifier.

    Args:
        num_block_list: Number of bottleneck layers in each dense block; a
            transition follows every block except the last.
        growth_rate: Channels appended by each bottleneck layer (stored as ``self.k``).
        reduction: Channel scaling applied by each transition
            (``out_channels = int(reduction * in_channels)``).
        num_class: Number of outputs of the final linear layer.
    """

    def __init__(self, num_block_list, growth_rate, reduction=0.5, num_class=100):
        super().__init__()
        self.k = growth_rate

        # The stem emits twice the growth rate; `channels` then tracks the
        # running channel count through the rest of the network.
        channels = 2 * self.k

        stem = [
            nn.Conv2d(3, channels, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        self.conv1 = nn.Sequential(*stem)

        stages = []
        # Every dense block except the last is followed by a transition.
        for num_blocks in num_block_list[:-1]:
            stages.append(self._make_dense_layers(channels, num_blocks))
            channels += self.k * num_blocks

            compressed = int(reduction * channels)
            stages.append(Transition(channels, compressed))
            channels = compressed

        # Last dense block: no transition, just a closing BN-ReLU.
        final_blocks = num_block_list[-1]
        stages.append(self._make_dense_layers(channels, final_blocks))
        channels += self.k * final_blocks
        stages.extend([nn.BatchNorm2d(channels), nn.ReLU(inplace=True)])
        self.features = nn.Sequential(*stages)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(channels, num_class)

        # Weight initialisation, noted by the original author as the official
        # init from the torch repo.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Run stem, dense features, global pooling and classifier on ``x``."""
        stem_out = self.conv1(x)
        feature_maps = self.features(stem_out)
        pooled = self.avgpool(feature_maps)
        flat = torch.flatten(pooled, start_dim=1)
        return self.linear(flat)

    def _make_dense_layers(self, in_channels, nblocks):
        """Stack ``nblocks`` bottleneck layers, each seeing ``self.k`` more channels than the previous."""
        blocks = [
            Bottleneck(in_channels + index * self.k, self.k)
            for index in range(nblocks)
        ]
        return nn.Sequential(*blocks)

In [9]:
def densenet121(**kwargs):
    """Build a DenseNet with dense blocks of 6, 12, 24 and 16 layers and growth rate 32.

    Any keyword arguments are forwarded to ``DenseNet``.
    """
    block_config = [6, 12, 24, 16]
    return DenseNet(block_config, growth_rate=32, **kwargs)

def densenet169(**kwargs):
    """Build a DenseNet with dense blocks of 6, 12, 32 and 32 layers and growth rate 32.

    Any keyword arguments are forwarded to ``DenseNet``.
    """
    block_config = [6, 12, 32, 32]
    return DenseNet(block_config, growth_rate=32, **kwargs)

def densenet201(**kwargs):
    """Build a DenseNet with dense blocks of 6, 12, 48 and 32 layers and growth rate 32.

    Any keyword arguments are forwarded to ``DenseNet``.
    """
    block_config = [6, 12, 48, 32]
    return DenseNet(block_config, growth_rate=32, **kwargs)

def densenet264(**kwargs):
    """Build a DenseNet with dense blocks of 6, 12, 64 and 48 layers and growth rate 32.

    Any keyword arguments are forwarded to ``DenseNet``.
    """
    block_config = [6, 12, 64, 48]
    return DenseNet(block_config, growth_rate=32, **kwargs)

In [10]:
model = densenet121()
# print(model)
!pip install torchinfo
from torchinfo import summary
summary(model, input_size=(2,3,224,224), device='cpu')

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


Layer (type:depth-idx)                        Output Shape              Param #
DenseNet                                      [2, 100]                  --
├─Sequential: 1-1                             [2, 64, 56, 56]           --
│    └─Conv2d: 2-1                            [2, 64, 112, 112]         9,408
│    └─BatchNorm2d: 2-2                       [2, 64, 112, 112]         128
│    └─ReLU: 2-3                              [2, 64, 112, 112]         --
│    └─MaxPool2d: 2-4                         [2, 64, 56, 56]           --
├─Sequential: 1-2                             [2, 1024, 7, 7]           --
│    └─Sequential: 2-5                        [2, 256, 56, 56]          --
│    │    └─Bottleneck: 3-1                   [2, 96, 56, 56]           45,440
│    │    └─Bottleneck: 3-2                   [2, 128, 56, 56]          49,600
│    │    └─Bottleneck: 3-3                   [2, 160, 56, 56]          53,760
│    │    └─Bottleneck: 3-4                   [2, 192, 56, 56]          57,920


In [11]:
# Smoke test: push a random batch of two 3x224x224 inputs through the model
# and print the shape of what comes out.
x = torch.randn((2, 3, 224, 224))
print(model(x).size())

torch.Size([2, 100])
