<a href="https://colab.research.google.com/github/seungjun-green/PyTorch-Implementation/blob/main/Densenet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Implementing DenseNet-BC in PyTorch

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Dense Block

In [2]:
class DenseBlock(nn.Module):
    """Stack of BN-ReLU-Conv(1x1)-BN-ReLU-Conv(3x3) layers with dense inputs.

    Layer ``i`` receives the channel-wise concatenation of the block input and
    the outputs of every earlier layer, i.e. ``inital_channels + i * growth_rate``
    channels, and emits ``growth_rate`` channels.

    Args:
        num_layers: Number of composite layers in the block.
        inital_channels: Channel count of the tensor passed to ``forward``.
        growth_rate: Channels produced by each layer (default 32).

    Returns (from ``forward``):
        The output of the last layer only, shape ``(N, growth_rate, H, W)``;
        with ``num_layers == 0`` the input is returned unchanged.

    NOTE(review): the DenseNet paper feeds the concatenation of *all* feature
    maps to the next stage, and its bottleneck 1x1 conv emits 4 * growth_rate
    channels. This block keeps the original behavior (last layer only, 1x1
    conv preserving the channel count) because other code in this file relies
    on a ``growth_rate``-channel output.
    """

    def __init__(self, num_layers, inital_channels, growth_rate=32):
        super().__init__()
        self.num_layers = num_layers
        self.inital_channels = inital_channels
        self.growth_rate = growth_rate

        # Flat lists: entries 2*i and 2*i + 1 belong to layer i.
        self.blocks = nn.ModuleList()
        self.bns = nn.ModuleList()

        for i in range(num_layers):
            # Every earlier layer contributed growth_rate channels to the concat.
            input_channel = inital_channels + i * growth_rate
            self.blocks.append(
                nn.Conv2d(in_channels=input_channel, out_channels=input_channel,
                          kernel_size=1, stride=1, padding=0)
            )
            # The 3x3 conv must emit growth_rate channels (previously a
            # hard-coded 32), otherwise the channel bookkeeping above is wrong
            # for any growth_rate other than 32.
            self.blocks.append(
                nn.Conv2d(in_channels=input_channel, out_channels=growth_rate,
                          kernel_size=3, stride=1, padding=1)
            )

            # Both norms see input_channel channels because the 1x1 conv
            # preserves the channel count.
            self.bns.append(nn.BatchNorm2d(num_features=input_channel))
            self.bns.append(nn.BatchNorm2d(num_features=input_channel))

    def forward(self, x):
        inputs = [x]
        for i in range(0, len(self.blocks), 2):
            # Concatenate x_0 ... x_{l-1} along the channel axis.
            x = torch.cat(inputs, dim=1)

            # BN-ReLU-Conv(1x1) followed by BN-ReLU-Conv(3x3).
            x = self.blocks[i](F.relu(self.bns[i](x)))
            x = self.blocks[i + 1](F.relu(self.bns[i + 1](x)))
            inputs.append(x)

        return x

The following cells test whether the dense blocks produce the correct output sizes, as shown in Table 1 of the paper.

In [18]:
# Dense block 1: 6 layers on a 64-channel, 56x56 feature map.
first_dense_block = DenseBlock(num_layers=6, inital_channels=64)
input_tensor = torch.randn((1, 64, 56, 56))
output_tensor = first_dense_block(input_tensor)
print(output_tensor.size())

# Dense block 2: 12 layers on a 32-channel, 28x28 feature map.
second_dense_block = DenseBlock(num_layers=12, inital_channels=32)
input_tensor = torch.randn((1, 32, 28, 28))
output_tensor = second_dense_block(input_tensor)
print(output_tensor.size())

# Dense block 3: 24 layers on a 32-channel, 14x14 feature map.
third_dense_block = DenseBlock(num_layers=24, inital_channels=32)
input_tensor = torch.randn((1, 32, 14, 14))
output_tensor = third_dense_block(input_tensor)
print(output_tensor.size())

# Dense block 4: 16 layers on a 32-channel, 7x7 feature map.
fourth_dense_block = DenseBlock(num_layers=16, inital_channels=32)
input_tensor = torch.randn((1, 32, 7, 7))
output_tensor = fourth_dense_block(input_tensor)
print(output_tensor.size())

torch.Size([1, 32, 56, 56])
torch.Size([1, 32, 28, 28])
torch.Size([1, 32, 14, 14])
torch.Size([1, 32, 7, 7])


## Transition Layer

In [7]:
class Transition(nn.Module):
    """1x1 convolution followed by 2x2 average pooling (stride 2).

    Args:
        input_channel: Channel count of the incoming tensor (default 32).
        output_channel: Channel count produced by the 1x1 convolution.
            ``None`` (the default) keeps ``input_channel`` channels, which is
            the original behavior; pass a smaller value to compress features.

    Returns (from ``forward``):
        Tensor of shape ``(N, output_channel, H // 2, W // 2)``.

    NOTE(review): the DenseNet paper's transition layer also applies batch
    normalization before the 1x1 conv; it is not added here so that existing
    outputs and state dicts stay unchanged.
    """

    def __init__(self, input_channel=32, output_channel=None):
        super().__init__()
        if output_channel is None:
            output_channel = input_channel
        self.input_channel = input_channel
        self.output_channel = output_channel

        self.conv = nn.Conv2d(in_channels=input_channel, out_channels=output_channel,
                              kernel_size=1, stride=1, padding=0)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        # (N, input_channel, H, W) -> (N, output_channel, H, W)
        x = self.conv(x)
        # -> (N, output_channel, H // 2, W // 2)
        return self.pool(x)

In [8]:
# Smoke test: a 32-channel 56x56 map should come out at half the resolution.
model = Transition(input_channel=32)
input_tensor = torch.randn((1, 32, 56, 56))
output = model(input_tensor)
print(output.size())

torch.Size([1, 32, 28, 28])


## Create DenseNet

In [9]:
class DenseNet(nn.Module):
    """Image classifier assembled from four DenseBlock/Transition stages.

    Pipeline: 7x7 stride-2 conv -> 3x3 stride-2 max pool -> four
    (DenseBlock, Transition) pairs with 6, 12, 24 and 16 layers -> global
    average pool -> linear classifier -> softmax over classes.

    Args:
        num_classes: Size of the classifier output (default 1000).

    Returns (from ``forward``):
        Tensor of shape ``(N, num_classes)`` whose rows sum to 1.

    Note:
        ``forward`` returns softmax probabilities, not logits. Losses that
        apply their own log-softmax (e.g. ``nn.CrossEntropyLoss``) expect
        unnormalized logits instead.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # Stem: 3 -> 64 channels; the conv and the pool each halve H and W.
        self.conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3)
        self.maxPool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Every dense block is built with its default growth rate and every
        # Transition with its default channel count, so stages exchange
        # 32-channel tensors.
        self.dense1 = DenseBlock(6, 64)
        self.trans1 = Transition()
        self.dense2 = DenseBlock(12, 32)
        self.trans2 = Transition()
        self.dense3 = DenseBlock(24, 32)
        self.trans3 = Transition()
        self.dense4 = DenseBlock(16, 32)
        self.trans4 = Transition()

        self.globalAvgPool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # Stem
        x = self.conv(x)
        x = self.maxPool(x)

        # Dense stages, each followed by a down-sampling transition
        x = self.dense1(x)
        x = self.trans1(x)
        x = self.dense2(x)
        x = self.trans2(x)
        x = self.dense3(x)
        x = self.trans3(x)
        x = self.dense4(x)
        x = self.trans4(x)

        # Classification head: (N, C, H, W) -> (N, C, 1, 1) -> (N, C)
        x = self.globalAvgPool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return self.softmax(x)

Testing out model inference

In [10]:
# Smoke test: run a batch of 32 random 224x224 RGB images through the model.
denseNet = DenseNet()
input_tensor = torch.randn((32, 3, 224, 224))
output = denseNet(input_tensor)
print(output.size())

torch.Size([32, 1000])
