# Reproduce DL
## Automated Pavement Crack Segmentation

In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary


class ResBlock(nn.Module):
    """
    Residual block (Green)
    ---
    Two 3x3 convolutions, each followed by batch normalisation, whose
    result is added to a skip connection and passed through a ReLU.
    The skip connection has no gates, so information from the block
    input flows straight to the block output.

    The skip path is the identity when the main path keeps both the
    spatial size and the channel count. Otherwise the input is
    projected with a strided 1x1 convolution + batch normalisation so
    that the two tensors can be added.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the block.
        stride: stride of the first 3x3 convolution and of the 1x1
            projection on the skip path.
    """

    def __init__(self, in_channels, out_channels, stride):
        super(ResBlock, self).__init__()
        self.stride_one = (stride == 1)
        # The identity shortcut only lines up with the main path when neither
        # the resolution (stride) nor the channel count changes.
        self.use_projection = (not self.stride_one) or (in_channels != out_channels)

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Always registered (even when unused) so the parameter layout and
        # state_dict keys stay the same for every block.
        self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        self.bn_shortcut = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return relu(main_path(x) + shortcut(x))."""
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)

        if self.use_projection:
            # Match the main path's shape before the element-wise addition.
            x = self.bn_shortcut(self.conv_shortcut(x))
        y += x

        return self.relu(y)


class CSEBlock(nn.Module):
    """
    Spatial Squeeze and Channel Excitation block
    ---
    Channel-wise focus

    Each channel is squeezed to a single value by global average
    pooling. Two fully connected layers (with a bottleneck of
    ``in_channels // reduction`` units) followed by a sigmoid turn
    those values into one weight per channel, and the input is
    rescaled channel by channel with these weights.

    Because the weights are computed from the pooled descriptor,
    every weight depends on the whole spatial extent of the input.
    """

    def __init__(self, in_channels, reduction=2):
        super(CSEBlock, self).__init__()
        bottleneck = in_channels // reduction
        # Adaptive pooling to 1x1 is a global average pool.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(in_channels, bottleneck)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(bottleneck, in_channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        n, c, _, _ = x.size()

        # Squeeze: (n, c, H, W) -> (n, c)
        descriptor = self.avg_pool(x).view(n, c)

        # Excite: (n, c) -> per-channel weights in (0, 1)
        weights = self.sigmoid(self.fc2(self.relu(self.fc1(descriptor))))

        # Broadcast the weights over the spatial dimensions.
        return x * weights.view(n, c, 1, 1)


class SSEBlock(nn.Module):
    """
    Channel Squeeze and Spatial Excitation block
    ---
    Spatial-wise focus

    A 1x1 convolution collapses all channels into a single map and a
    sigmoid turns it into one weight per spatial location. The input
    is multiplied by this map, so every channel at a given location
    is scaled by the same factor.

    It acts as a spatial attention map: it says where to look rather
    than which features matter.
    """

    def __init__(self, in_channels):
        super(SSEBlock, self).__init__()
        # 1x1 convolution with a single output channel
        self.conv = nn.Conv2d(in_channels, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        n, _, height, width = x.size()

        attention = self.sigmoid(self.conv(x))

        # (n, 1, H, W) broadcasts across the channel dimension of x.
        return x * attention.view(n, 1, height, width)


class SCSEBlock(nn.Module):
    """
    Spatial and Channel Squeeze and Excitation block
    ---
    Runs the channel-wise (CSE) and the spatial (SSE) recalibration
    on the same input and keeps, element by element, the larger of
    the two results.
    """

    def __init__(self, in_channels, reduction=2):
        super(SCSEBlock, self).__init__()
        self.CSE = CSEBlock(in_channels, reduction)
        self.SSE = SSEBlock(in_channels)

    def forward(self, x):
        channel_recalibrated = self.CSE(x)
        spatial_recalibrated = self.SSE(x)
        # Element-wise maximum of the two recalibrated feature maps.
        return torch.max(channel_recalibrated, spatial_recalibrated)


class UpsampBlock(nn.Module):
    """
    Upsampling block
    ---
    ReLU -> batch normalisation -> SCSEBlock, applied to the block
    input. None of these layers changes the number of channels, so
    the output has ``in_channels`` channels.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: kept for backward compatibility of the
            signature; the block does not change the channel count,
            so this value is not used to size any layer.

    Note:
        The ReLU is in-place, so the tensor passed to ``forward`` is
        modified.
    """

    def __init__(self, in_channels, out_channels):
        super(UpsampBlock, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        # The normalisation runs on the block *input*, so it must be sized
        # by in_channels; sizing it by out_channels only worked when the
        # two values happened to be equal.
        self.bn = nn.BatchNorm2d(in_channels)
        self.scse = SCSEBlock(in_channels)

    def forward(self, x):
        y = self.relu(x)
        y = self.bn(y)

        return self.scse(y)


class Net(nn.Module):
    """
    U-shaped segmentation network
    ---
    Encoder: a 7x7 stem followed by four stages of residual blocks
    (3, 4, 6 and 3 blocks with 64, 128, 256 and 512 channels).
    Decoder: transposed convolutions that double the resolution,
    each followed by concatenation with a 1x1-projected encoder
    feature map and an UpsampBlock.

    ``forward`` returns the raw output of the last transposed
    convolution (one channel, no activation applied).
    """

    def __init__(self):
        super(Net, self).__init__()
        # -- Blue --
        # Stem: 3 -> 64 channels, 7x7 convolution, stride 2.
        # Padding is added by hand in forward().
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)

        # -- Green --
        # 3x3 max pooling with stride 2; ceil_mode keeps the partial window
        # at the border instead of dropping it.
        self.maxpool_1_to_rs1 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        # Residual stage 1: 64 channels, 3 blocks
        self.conv_rs1_1 = ResBlock(64, 64, 1)
        self.conv_rs1_2 = ResBlock(64, 64, 1)
        self.conv_rs1_3 = ResBlock(64, 64, 1)

        # Residual stage 2: 128 channels, 4 blocks (first one strided)
        self.conv_rs2_1 = ResBlock(64, 128, 2)
        self.conv_rs2_2 = ResBlock(128, 128, 1)
        self.conv_rs2_3 = ResBlock(128, 128, 1)
        self.conv_rs2_4 = ResBlock(128, 128, 1)

        # Residual stage 3: 256 channels, 6 blocks (first one strided)
        self.conv_rs3_1 = ResBlock(128, 256, 2)
        self.conv_rs3_2 = ResBlock(256, 256, 1)
        self.conv_rs3_3 = ResBlock(256, 256, 1)
        self.conv_rs3_4 = ResBlock(256, 256, 1)
        self.conv_rs3_5 = ResBlock(256, 256, 1)
        self.conv_rs3_6 = ResBlock(256, 256, 1)

        # Residual stage 4: 512 channels, 3 blocks (first one strided)
        self.conv_rs4_1 = ResBlock(256, 512, 2)
        self.conv_rs4_2 = ResBlock(512, 512, 1)
        self.conv_rs4_3 = ResBlock(512, 512, 1)

        # -- Yellow --
        # 1x1 convolutions that bring each encoder level to 128 channels
        # before it is concatenated in the decoder (listed top to bottom).
        self.conv_gr_to_yel_1 = nn.Conv2d(64, 128, kernel_size=1, stride=1)
        self.conv_gr_to_yel_2 = nn.Conv2d(64, 128, kernel_size=1, stride=1)
        self.conv_gr_to_yel_3 = nn.Conv2d(128, 128, kernel_size=1, stride=1)
        self.conv_gr_to_yel_4 = nn.Conv2d(256, 128, kernel_size=1, stride=1)

        # -- Purple --
        # Bottleneck: 512 -> 512 channels, 1x1 convolution, stride 1.
        self.conv_gr_to_purp = nn.Conv2d(512, 512, kernel_size=1, stride=1)
        # One UpsampBlock per decoder level, applied to the 256-channel
        # concatenation (128 skip + 128 upsampled).
        self.conv_mag_to_purp_1 = UpsampBlock(256, 256)
        self.conv_mag_to_purp_2 = UpsampBlock(256, 256)
        self.conv_mag_to_purp_3 = UpsampBlock(256, 256)
        self.conv_mag_to_purp_4 = UpsampBlock(256, 256)

        # -- Magenta --
        # 2x2 transposed convolutions with stride 2 (bottom to top); the
        # last one produces the single-channel output.
        self.conv_purp_to_mag_1 = nn.ConvTranspose2d(512, 128, kernel_size=2, stride=2)
        self.conv_purp_to_mag_2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv_purp_to_mag_3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv_purp_to_mag_4 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv_purp_to_mag_5 = nn.ConvTranspose2d(256, 1, kernel_size=2, stride=2)

    def forward(self, x):
        # Keras-style "same" padding for the strided 7x7 stem:
        # 2 before / 3 after on both width and height.
        x = F.pad(x, (2, 3, 2, 3))

        # ---- Encoder (ResNet34 layout) ----
        feat = self.relu1(self.bn1(self.conv1(x)))
        skip_1 = self.conv_gr_to_yel_1(feat)

        feat = self.maxpool_1_to_rs1(feat)
        for block in (self.conv_rs1_1, self.conv_rs1_2, self.conv_rs1_3):
            feat = block(feat)
        skip_2 = self.conv_gr_to_yel_2(feat)

        for block in (self.conv_rs2_1, self.conv_rs2_2,
                      self.conv_rs2_3, self.conv_rs2_4):
            feat = block(feat)
        skip_3 = self.conv_gr_to_yel_3(feat)

        for block in (self.conv_rs3_1, self.conv_rs3_2, self.conv_rs3_3,
                      self.conv_rs3_4, self.conv_rs3_5, self.conv_rs3_6):
            feat = block(feat)
        skip_4 = self.conv_gr_to_yel_4(feat)

        for block in (self.conv_rs4_1, self.conv_rs4_2, self.conv_rs4_3):
            feat = block(feat)

        # ---- Decoder ----
        up = self.conv_purp_to_mag_1(self.conv_gr_to_purp(feat))

        # At every level: concatenate skip + upsampled features on the
        # channel axis, recalibrate, then upsample again.
        up = self.conv_mag_to_purp_1(torch.cat((skip_4, up), dim=1))
        up = self.conv_purp_to_mag_2(up)

        up = self.conv_mag_to_purp_2(torch.cat((skip_3, up), dim=1))
        up = self.conv_purp_to_mag_3(up)

        up = self.conv_mag_to_purp_3(torch.cat((skip_2, up), dim=1))
        up = self.conv_purp_to_mag_4(up)

        up = self.conv_mag_to_purp_4(torch.cat((skip_1, up), dim=1))
        return self.conv_purp_to_mag_5(up)


# Build the model (parameters live on the CPU) and print a per-layer summary
# for a 3 x 320 x 480 input.
net = Net()
# print(net)
# torchsummary defaults to device="cuda" and then feeds a CUDA tensor to the
# model whenever CUDA is available; name the device the model is actually on
# so the call also works on GPU machines.
summary(net, input_size=(3, 320, 480), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 160, 240]           9,472
       BatchNorm2d-2         [-1, 64, 160, 240]             128
              ReLU-3         [-1, 64, 160, 240]               0
            Conv2d-4        [-1, 128, 160, 240]           8,320
         MaxPool2d-5          [-1, 64, 80, 120]               0
            Conv2d-6          [-1, 64, 80, 120]          36,928
       BatchNorm2d-7          [-1, 64, 80, 120]             128
              ReLU-8          [-1, 64, 80, 120]               0
            Conv2d-9          [-1, 64, 80, 120]          36,928
      BatchNorm2d-10          [-1, 64, 80, 120]             128
             ReLU-11          [-1, 64, 80, 120]               0
         ResBlock-12          [-1, 64, 80, 120]               0
           Conv2d-13          [-1, 64, 80, 120]          36,928
      BatchNorm2d-14          [-1, 64, 

In [44]:
# from torchviz import make_dot

# x = torch.randn(1,3, 320, 480)
# y = net(x)

# make_dot(y).view()

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=411d58e9-cb4b-4924-bef0-2f383eff0187' target="_blank">
 </img>
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>