<a href="https://colab.research.google.com/github/taravatp/roadLane_InstanceSegmentation/blob/main/ENet_SAD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Creating the building blocks

In [None]:
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

In [None]:
class InitialBlock(nn.Module):
    """ENet initial block.

    Concatenates a strided 3x3 convolution of the input (main branch) with a
    strided 3x3 max pooling of the input (extension branch), then applies
    batch normalisation and the output activation. Both branches use stride 2,
    so the spatial resolution is halved.

    The pooling branch keeps all ``in_channels`` input maps, so the convolution
    produces ``out_channels - in_channels`` maps and the concatenation has
    exactly ``out_channels`` channels for any number of input channels.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the output tensor; must be at
            least ``in_channels``.
        bias: whether the 3x3 convolution uses a bias term.
        relu: use ``nn.ReLU`` if True, otherwise ``nn.PReLU``.

    Raises:
        ValueError: if ``out_channels < in_channels``.
    """

    def __init__(self, in_channels, out_channels, bias=False, relu=True):
        super().__init__()

        if out_channels < in_channels:
            raise ValueError(
                "out_channels ({0}) must be at least in_channels ({1})".format(
                    out_channels, in_channels))

        activation = nn.ReLU if relu else nn.PReLU

        # Convolution branch supplies only the channels the pooling branch
        # does not (the pooling branch passes all in_channels through).
        self.main_branch = nn.Conv2d(
            in_channels, out_channels - in_channels,
            kernel_size=3, stride=2, padding=1, bias=bias)
        self.ext_branch = nn.MaxPool2d(3, stride=2, padding=1)
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.out_activation = activation()

    def forward(self, x):
        main = self.main_branch(x)
        ext = self.ext_branch(x)
        # dim 1 is the channel dimension
        out = torch.cat((main, ext), 1)
        out = self.batch_norm(out)
        return self.out_activation(out)


In [None]:
class RegularBottleneck(nn.Module):
    """ENet regular bottleneck: identity shortcut plus a 1x1 -> conv -> 1x1 branch.

    The extension branch projects ``channels`` down to
    ``channels // internal_ratio`` with a 1x1 convolution. It then applies a
    regular, dilated or asymmetric convolution, expands back to ``channels``
    with a 1x1 convolution, and applies channel dropout. The extension output
    is added to the unchanged input and passed through the output activation.

    ``padding`` must keep the spatial size of the middle convolution (e.g.
    ``padding=dilation`` for ``kernel_size=3``, ``padding=2`` for an
    asymmetric ``kernel_size=5``). Otherwise the residual addition fails with a
    shape mismatch.

    Args:
        channels: number of input and output channels.
        internal_ratio: channel reduction factor of the extension branch; must
            lie in ``(1, channels]``.
        kernel_size: kernel size of the middle convolution.
        padding: zero padding of the middle convolution.
        dilation: dilation of the middle convolution.
        asymmetric: if True, split the middle convolution into a
            ``(kernel_size, 1)`` convolution followed by a ``(1, kernel_size)``
            convolution.
        dropout_prob: probability of the ``nn.Dropout2d`` on the extension
            branch.
        bias: whether the convolutions use a bias term.
        relu: use ``nn.ReLU`` if True, otherwise ``nn.PReLU``.

    Raises:
        RuntimeError: if ``internal_ratio`` is outside ``(1, channels]``.
    """

    def __init__(self, channels, internal_ratio=4, kernel_size=3, padding=0, dilation=1,
                 asymmetric=False, dropout_prob=0, bias=False, relu=True):
        super().__init__()

        # internal_ratio == 1 is rejected too, so the valid range is (1, channels].
        if internal_ratio <= 1 or internal_ratio > channels:
            raise RuntimeError(
                "Value out of range. Expected value in the interval (1, {0}], "
                "got internal_ratio={1}.".format(channels, internal_ratio))

        internal_channels = channels // internal_ratio

        activation = nn.ReLU if relu else nn.PReLU

        # Main branch is the identity shortcut (applied in forward).

        # 1x1 projection convolution - reduce the dimensionality
        self.ext_conv1 = nn.Sequential(
            nn.Conv2d(channels, internal_channels, kernel_size=1, stride=1, bias=bias),
            nn.BatchNorm2d(internal_channels),
            activation())

        if asymmetric:
            # Factor a kxk convolution into kx1 followed by 1xk,
            # e.g. 5x5 -> 5x1 then 1x5.
            self.ext_conv2 = nn.Sequential(
                nn.Conv2d(internal_channels, internal_channels, kernel_size=(kernel_size, 1),
                          stride=1, padding=(padding, 0), dilation=dilation, bias=bias),
                nn.BatchNorm2d(internal_channels),
                activation(),
                nn.Conv2d(internal_channels, internal_channels, kernel_size=(1, kernel_size),
                          stride=1, padding=(0, padding), dilation=dilation, bias=bias),
                nn.BatchNorm2d(internal_channels),
                activation())
        else:
            # Regular or dilated convolution
            self.ext_conv2 = nn.Sequential(
                nn.Conv2d(internal_channels, internal_channels, kernel_size=kernel_size,
                          stride=1, padding=padding, dilation=dilation, bias=bias),
                nn.BatchNorm2d(internal_channels),
                activation())

        # 1x1 expansion convolution back to `channels`
        self.ext_conv3 = nn.Sequential(
            nn.Conv2d(internal_channels, channels, kernel_size=1, stride=1, bias=bias),
            nn.BatchNorm2d(channels),
            activation())

        self.ext_regul = nn.Dropout2d(p=dropout_prob)

        # Activation applied after adding the two branches
        self.out_activation = activation()

    def forward(self, x):
        ext = self.ext_conv1(x)
        ext = self.ext_conv2(ext)
        ext = self.ext_conv3(ext)
        ext = self.ext_regul(ext)

        # Element-wise addition of the identity shortcut and the extension branch
        return self.out_activation(x + ext)


In [None]:
class DownsamplingBottleneck(nn.Module):
    """ENet downsampling bottleneck: halves the resolution and widens the channels.

    Main branch: 2x2 max pooling with stride 2, then zero-padding along the
    channel dimension from ``in_channels`` up to ``out_channels``.

    Extension branch: 2x2 stride-2 projection convolution to
    ``in_channels // internal_ratio`` channels, then a 3x3 convolution, then a
    1x1 expansion convolution to ``out_channels``, then channel dropout.

    The two branches are summed and passed through the output activation.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels; must be at least
            ``in_channels`` so the main branch can be zero-padded to it.
        internal_ratio: channel reduction factor of the extension branch; must
            lie in ``(1, in_channels]``.
        return_indices: if True, ``forward`` returns the max-pooling indices
            (e.g. for a later ``nn.MaxUnpool2d``); otherwise it returns None in
            their place.
        dropout_prob: probability of the ``nn.Dropout2d`` on the extension
            branch.
        bias: whether the convolutions use a bias term.
        relu: use ``nn.ReLU`` if True, otherwise ``nn.PReLU``.

    Raises:
        RuntimeError: if ``internal_ratio`` is outside ``(1, in_channels]``.
    """

    def __init__(self, in_channels, out_channels, internal_ratio=4, return_indices=False,
                 dropout_prob=0, bias=False, relu=True):
        super().__init__()

        self.return_indices = return_indices

        # internal_ratio == 1 is rejected too, so the valid range is (1, in_channels].
        if internal_ratio <= 1 or internal_ratio > in_channels:
            raise RuntimeError(
                "Value out of range. Expected value in the interval (1, {0}], "
                "got internal_ratio={1}.".format(in_channels, internal_ratio))

        internal_channels = in_channels // internal_ratio

        activation = nn.ReLU if relu else nn.PReLU

        # Main branch - max pooling (channel padding happens in forward)
        self.main_max1 = nn.MaxPool2d(2, stride=2, return_indices=return_indices)

        # 2x2 projection convolution with stride 2
        self.ext_conv1 = nn.Sequential(
            nn.Conv2d(in_channels, internal_channels, kernel_size=2, stride=2, bias=bias),
            nn.BatchNorm2d(internal_channels),
            activation())

        # 3x3 convolution
        self.ext_conv2 = nn.Sequential(
            nn.Conv2d(internal_channels, internal_channels, kernel_size=3, stride=1,
                      padding=1, bias=bias),
            nn.BatchNorm2d(internal_channels),
            activation())

        # 1x1 expansion convolution
        self.ext_conv3 = nn.Sequential(
            nn.Conv2d(internal_channels, out_channels, kernel_size=1, stride=1, bias=bias),
            nn.BatchNorm2d(out_channels),
            activation())

        self.ext_regul = nn.Dropout2d(p=dropout_prob)

        # Activation applied after adding the two branches
        self.out_activation = activation()

    def forward(self, x):
        """Return ``(output, max_indices)``; ``max_indices`` is None unless
        ``return_indices`` was set."""
        # Bound up front so the return below works when indices are not requested.
        max_indices = None
        if self.return_indices:
            main, max_indices = self.main_max1(x)
        else:
            main = self.main_max1(x)

        ext = self.ext_conv1(x)
        ext = self.ext_conv2(ext)
        ext = self.ext_conv3(ext)
        ext = self.ext_regul(ext)

        # Zero-pad the main branch along channels to match the extension branch.
        # new_zeros keeps main's dtype and exact device (including the CUDA index).
        n, ch_ext, h, w = ext.size()
        ch_main = main.size(1)
        padding = main.new_zeros(n, ch_ext - ch_main, h, w)
        main = torch.cat((main, padding), 1)

        out = main + ext

        return self.out_activation(out), max_indices


In [None]:
class UpsamplingBottleneck(nn.Module):
    """ENet upsampling bottleneck: doubles the resolution and narrows the channels.

    Main branch: 1x1 convolution + batch norm from ``in_channels`` to
    ``out_channels``, then max-unpooling with indices from a matching max pooling.

    Extension branch: 1x1 projection to ``in_channels // internal_ratio``
    channels, then a 2x2 stride-2 transposed convolution, then a 1x1 expansion
    to ``out_channels``, then channel dropout.

    The two branches are summed and passed through the output activation.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        internal_ratio: channel reduction factor of the extension branch; must
            lie in ``(1, in_channels]``.
        dropout_prob: probability of the ``nn.Dropout2d`` on the extension
            branch.
        bias: whether the convolutions use a bias term.
        relu: use ``nn.ReLU`` if True, otherwise ``nn.PReLU``.

    Raises:
        RuntimeError: if ``internal_ratio`` is outside ``(1, in_channels]``.
    """

    def __init__(self, in_channels, out_channels, internal_ratio=4, dropout_prob=0,
                 bias=False, relu=True):
        super().__init__()

        # internal_ratio == 1 is rejected too, so the valid range is (1, in_channels].
        if internal_ratio <= 1 or internal_ratio > in_channels:
            raise RuntimeError(
                "Value out of range. Expected value in the interval (1, {0}], "
                "got internal_ratio={1}.".format(in_channels, internal_ratio))

        internal_channels = in_channels // internal_ratio

        activation = nn.ReLU if relu else nn.PReLU

        # Main branch - 1x1 convolution to out_channels, then max-unpooling
        self.main_conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias),
            nn.BatchNorm2d(out_channels))

        # Stride defaults to kernel_size, mirroring the 2x2/stride-2 max pooling
        # whose indices are used here.
        self.main_unpool1 = nn.MaxUnpool2d(kernel_size=2)

        # 1x1 projection convolution with stride 1
        self.ext_conv1 = nn.Sequential(
            nn.Conv2d(in_channels, internal_channels, kernel_size=1, bias=bias),
            nn.BatchNorm2d(internal_channels),
            activation())

        # 2x2 transposed convolution with stride 2 (doubles the resolution)
        self.ext_tconv1 = nn.ConvTranspose2d(internal_channels, internal_channels,
                                             kernel_size=2, stride=2, bias=bias)
        self.ext_tconv1_bnorm = nn.BatchNorm2d(internal_channels)
        self.ext_tconv1_activation = activation()

        # 1x1 expansion convolution to out_channels
        self.ext_conv2 = nn.Sequential(
            nn.Conv2d(internal_channels, out_channels, kernel_size=1, bias=bias),
            nn.BatchNorm2d(out_channels),
            activation())

        self.ext_regul = nn.Dropout2d(p=dropout_prob)

        # Activation applied after adding the two branches
        self.out_activation = activation()

    def forward(self, x, max_indices, output_size):
        """Upsample ``x`` to ``output_size``.

        Args:
            x: input feature map with ``in_channels`` channels.
            max_indices: indices from the max pooling being undone; their shape
                must match the main branch after ``main_conv1``.
            output_size: target size, passed to both ``nn.MaxUnpool2d`` and
                ``nn.ConvTranspose2d``.
        """
        main = self.main_conv1(x)
        main = self.main_unpool1(main, max_indices, output_size=output_size)

        ext = self.ext_conv1(x)
        ext = self.ext_tconv1(ext, output_size=output_size)
        ext = self.ext_tconv1_bnorm(ext)
        ext = self.ext_tconv1_activation(ext)
        ext = self.ext_conv2(ext)
        ext = self.ext_regul(ext)

        out = main + ext

        return self.out_activation(out)


In [None]:
class SpatialSoftmax(nn.Module):
    """Softmax over the flattened last two axes of a feature tensor.

    The input is viewed as ``(batch, -1, shape[1] * shape[2])`` and a softmax
    is taken along the last axis. For a 3-D input ``(N, A, B)`` the result has
    shape ``(N, 1, A * B)`` and each sample's values sum to 1.
    """

    def forward(self, feature):
        batch = feature.shape[0]
        rows, cols = feature.shape[1], feature.shape[2]
        flattened = feature.view(batch, -1, rows * cols)
        return flattened.softmax(dim=-1)

# Building the model

In [None]:
class Enet_SAD(nn.Module):
    """ENet with two task branches and self attention distillation (SAD).

    Stages 1 and 2 of the encoder are shared. Stage 3 and the decoder are
    duplicated into a binary-segmentation branch and an embedding branch.
    When ``sad`` is True, attention maps of consecutive encoder stages
    (stage 1 -> stage 2 and stage 2 -> binary stage 3) are compared. The
    resulting distillation loss is returned alongside the logits.

    Args:
        binary_seg: number of output channels of the binary-segmentation logits.
        embedding_dim: number of output channels of the embedding branch.
        encoder_relu: use ReLU (True) or PReLU (False) in the encoder.
        decoder_relu: use ReLU (True) or PReLU (False) in the decoder.
        sad: compute the distillation loss and the attention maps.
    """

    def __init__(self, binary_seg, embedding_dim, encoder_relu=False, decoder_relu=True, sad=True):
        super(Enet_SAD, self).__init__()

        self.sad = sad
        self.initial_block = InitialBlock(3, 16, relu=encoder_relu)

        # Stage 1 (shared)
        self.downsample1_0 = DownsamplingBottleneck(16, 64, return_indices=True, dropout_prob=0.01, relu=encoder_relu)
        self.regular1_1 = RegularBottleneck(64, padding=1, dropout_prob=0.01, relu=encoder_relu)
        self.regular1_2 = RegularBottleneck(64, padding=1, dropout_prob=0.01, relu=encoder_relu)
        self.regular1_3 = RegularBottleneck(64, padding=1, dropout_prob=0.01, relu=encoder_relu)
        self.regular1_4 = RegularBottleneck(64, padding=1, dropout_prob=0.01, relu=encoder_relu)

        # Stage 2 (shared)
        self.downsample2_0 = DownsamplingBottleneck(64, 128, return_indices=True, dropout_prob=0.1, relu=encoder_relu)
        self.regular2_1 = RegularBottleneck(128, padding=1, dropout_prob=0.1, relu=encoder_relu)
        self.dilated2_2 = RegularBottleneck(128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu)
        self.asymmetric2_3 = RegularBottleneck(128, kernel_size=5, padding=2, asymmetric=True, dropout_prob=0.1, relu=encoder_relu)
        self.dilated2_4 = RegularBottleneck(128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu)
        self.regular2_5 = RegularBottleneck(128, padding=1, dropout_prob=0.1, relu=encoder_relu)
        self.dilated2_6 = RegularBottleneck(128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu)
        self.asymmetric2_7 = RegularBottleneck(128, kernel_size=5, asymmetric=True, padding=2, dropout_prob=0.1, relu=encoder_relu)
        self.dilated2_8 = RegularBottleneck(128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu)

        # Stage 3, binary-segmentation branch
        self.regular_binary_3_0 = RegularBottleneck(128, padding=1, dropout_prob=0.1, relu=encoder_relu)
        self.dilated_binary_3_1 = RegularBottleneck(128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu)
        self.asymmetric_binary_3_2 = RegularBottleneck(128, kernel_size=5, padding=2, asymmetric=True, dropout_prob=0.1, relu=encoder_relu)
        self.dilated_binary_3_3 = RegularBottleneck(128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu)
        self.regular_binary_3_4 = RegularBottleneck(128, padding=1, dropout_prob=0.1, relu=encoder_relu)
        self.dilated_binary_3_5 = RegularBottleneck(128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu)
        self.asymmetric_binary_3_6 = RegularBottleneck(128, kernel_size=5, asymmetric=True, padding=2, dropout_prob=0.1, relu=encoder_relu)
        self.dilated_binary_3_7 = RegularBottleneck(128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu)

        # Stage 3, embedding branch
        self.regular_embedding_3_0 = RegularBottleneck(128, padding=1, dropout_prob=0.1, relu=encoder_relu)
        self.dilated_embedding_3_1 = RegularBottleneck(128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu)
        self.asymmetric_embedding_3_2 = RegularBottleneck(128, kernel_size=5, padding=2, asymmetric=True, dropout_prob=0.1, relu=encoder_relu)
        self.dilated_embedding_3_3 = RegularBottleneck(128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu)
        self.regular_embedding_3_4 = RegularBottleneck(128, padding=1, dropout_prob=0.1, relu=encoder_relu)
        self.dilated_embedding_3_5 = RegularBottleneck(128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu)
        # NOTE: "bembedding" is a typo, kept on purpose: renaming the attribute
        # would change state_dict keys and break loading existing checkpoints.
        self.asymmetric_bembedding_3_6 = RegularBottleneck(128, kernel_size=5, asymmetric=True, padding=2, dropout_prob=0.1, relu=encoder_relu)
        self.dilated_embedding_3_7 = RegularBottleneck(128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu)

        # Decoder, binary-segmentation branch
        self.upsample_binary_4_0 = UpsamplingBottleneck(128, 64, dropout_prob=0.1, relu=decoder_relu)
        self.regular_binary_4_1 = RegularBottleneck(64, padding=1, dropout_prob=0.1, relu=decoder_relu)
        self.regular_binary_4_2 = RegularBottleneck(64, padding=1, dropout_prob=0.1, relu=decoder_relu)
        self.upsample_binary_5_0 = UpsamplingBottleneck(64, 16, dropout_prob=0.1, relu=decoder_relu)
        self.regular_binary_5_1 = RegularBottleneck(16, padding=1, dropout_prob=0.1, relu=decoder_relu)
        self.binary_transposed_conv = nn.ConvTranspose2d(16, binary_seg, kernel_size=3, stride=2, padding=1, bias=False)

        # Decoder, embedding branch
        self.upsample_embedding_4_0 = UpsamplingBottleneck(128, 64, dropout_prob=0.1, relu=decoder_relu)
        self.regular_embedding_4_1 = RegularBottleneck(64, padding=1, dropout_prob=0.1, relu=decoder_relu)
        self.regular_embedding_4_2 = RegularBottleneck(64, padding=1, dropout_prob=0.1, relu=decoder_relu)
        self.upsample_embedding_5_0 = UpsamplingBottleneck(64, 16, dropout_prob=0.1, relu=decoder_relu)
        self.regular_embedding_5_1 = RegularBottleneck(16, padding=1, dropout_prob=0.1, relu=decoder_relu)
        self.embedding_transposed_conv = nn.ConvTranspose2d(16, embedding_dim, kernel_size=3, stride=2, padding=1, bias=False)

    @staticmethod
    def _attention_distillation(x1, x2, to_map, report_upsampled):
        """Shared implementation of the ``generate_attention_type*`` methods.

        Args:
            x1: previous encoder feature map.
            x2: current encoder feature map.
            to_map: ``to_map(t, keepdim)`` collapses ``t`` over dim 1 (channels)
                into an attention map.
            report_upsampled: when the sizes differ, return the upsampled map
                of ``x2`` (True) or the map before upsampling (False) as
                ``attention2``.

        Returns:
            ``(loss, attention1, attention2)``, where ``loss`` is the mean
            squared error between the spatial softmaxes of both maps.
        """
        spatial_softmax = SpatialSoftmax()

        attention1 = to_map(x1, False)
        if x1.size() != x2.size():
            # Assumes x2 is at half the resolution of x1 (true for the
            # stage-1/stage-2 pair used in forward) - TODO confirm for new callers.
            pooled = to_map(x2, True)
            upsampled = F.interpolate(pooled, scale_factor=2, mode='bilinear', align_corners=False)
            attention2 = upsampled if report_upsampled else pooled
            target = torch.squeeze(upsampled, dim=1)
        else:
            attention2 = to_map(x2, False)
            target = attention2

        loss = F.mse_loss(spatial_softmax(attention1), spatial_softmax(target), reduction='mean')
        return loss, attention1, attention2

    def generate_attention_type1(self, x1, x2):
        """Distillation with attention maps ``sum_c |F_c|`` over channels.

        ``x1`` is the previous and ``x2`` the current encoder feature map.
        Returns ``(loss, attention1, attention2)``. When the sizes differ,
        ``attention2`` is the upsampled map of ``x2``.
        """
        return self._attention_distillation(
            x1, x2,
            lambda t, keepdim: torch.sum(torch.abs(t), dim=1, keepdim=keepdim),
            report_upsampled=True)

    def generate_attention_type2(self, x1, x2):
        """Distillation with attention maps ``sum_c F_c^2`` over channels.

        ``x1`` is the previous and ``x2`` the current encoder feature map.
        Returns ``(loss, attention1, attention2)``. When the sizes differ,
        ``attention2`` is the map of ``x2`` before upsampling, with shape
        ``(N, 1, H, W)``.
        """
        return self._attention_distillation(
            x1, x2,
            lambda t, keepdim: torch.sum(t * t, dim=1, keepdim=keepdim),
            report_upsampled=False)

    def generate_attention_type3(self, x1, x2):
        """Distillation with attention maps ``max_c F_c^2`` over channels.

        ``x1`` is the previous and ``x2`` the current encoder feature map.
        Returns ``(loss, attention1, attention2)``. When the sizes differ,
        ``attention2`` is the upsampled map of ``x2``.
        """
        return self._attention_distillation(
            x1, x2,
            lambda t, keepdim: torch.max(t * t, dim=1, keepdim=keepdim).values,
            report_upsampled=True)

    def forward(self, x):
        """Run both branches.

        Args:
            x: input batch with 3 channels (the initial block is built for 3).
                The spatial size is presumably expected to be divisible by 8
                (three stride-2 stages) - TODO confirm.

        Returns:
            A 6-tuple ``(binary_final_logits, instance_notfinal_logits,
            distillation_loss, attention_stage1, attention_stage2,
            attention_stage3)``. Both logit tensors have the spatial size of
            ``x``, with ``binary_seg`` and ``embedding_dim`` channels
            respectively. When ``self.sad`` is False, ``distillation_loss`` is
            a zero scalar tensor and the three attention maps are None.
        """
        input_size = x.size()
        x = self.initial_block(x)

        # Stage 1 (shared)
        stage1_input_size = x.size()
        x, max_indices1_0 = self.downsample1_0(x)
        x = self.regular1_1(x)
        x = self.regular1_2(x)
        x = self.regular1_3(x)
        x_1 = self.regular1_4(x)

        # Stage 2 (shared)
        stage2_input_size = x_1.size()
        x, max_indices2_0 = self.downsample2_0(x_1)
        x = self.regular2_1(x)
        x = self.dilated2_2(x)
        x = self.asymmetric2_3(x)
        x = self.dilated2_4(x)
        x = self.regular2_5(x)
        x = self.dilated2_6(x)
        x = self.asymmetric2_7(x)
        x_2 = self.dilated2_8(x)

        # Stage 3, binary branch
        x_binary = self.regular_binary_3_0(x_2)
        x_binary = self.dilated_binary_3_1(x_binary)
        x_binary = self.asymmetric_binary_3_2(x_binary)
        x_binary = self.dilated_binary_3_3(x_binary)
        x_binary = self.regular_binary_3_4(x_binary)
        x_binary = self.dilated_binary_3_5(x_binary)
        x_binary = self.asymmetric_binary_3_6(x_binary)
        x_3 = self.dilated_binary_3_7(x_binary)

        # Stage 3, embedding branch
        x_embedding = self.regular_embedding_3_0(x_2)
        x_embedding = self.dilated_embedding_3_1(x_embedding)
        x_embedding = self.asymmetric_embedding_3_2(x_embedding)
        x_embedding = self.dilated_embedding_3_3(x_embedding)
        x_embedding = self.regular_embedding_3_4(x_embedding)
        x_embedding = self.dilated_embedding_3_5(x_embedding)
        x_embedding = self.asymmetric_bembedding_3_6(x_embedding)
        x_embedding = self.dilated_embedding_3_7(x_embedding)

        if self.sad:
            # Only the stage-1 map is kept from the first pair; the stage-2 map
            # returned is the one from the (x_2, x_3) pair, which has the same
            # resolution as the stage-3 map.
            loss_1, attention_stage1, _ = self.generate_attention_type2(x_1, x_2)
            loss_2, attention_stage2, attention_stage3 = self.generate_attention_type2(x_2, x_3)
            distillation_loss = loss_1 + loss_2
        else:
            # Zero scalar on the right device/dtype so `loss + w * distillation_loss`
            # still works when SAD is disabled.
            distillation_loss = x_2.new_zeros(())
            attention_stage1 = attention_stage2 = attention_stage3 = None

        # Binary decoder: consumes the full stage-3 output, including dilated_binary_3_7.
        x_binary = self.upsample_binary_4_0(x_3, max_indices2_0, output_size=stage2_input_size)
        x_binary = self.regular_binary_4_1(x_binary)
        x_binary = self.regular_binary_4_2(x_binary)
        x_binary = self.upsample_binary_5_0(x_binary, max_indices1_0, output_size=stage1_input_size)
        x_binary = self.regular_binary_5_1(x_binary)
        binary_final_logits = self.binary_transposed_conv(x_binary, output_size=input_size)

        # Embedding decoder
        x_embedding = self.upsample_embedding_4_0(x_embedding, max_indices2_0, output_size=stage2_input_size)
        x_embedding = self.regular_embedding_4_1(x_embedding)
        x_embedding = self.regular_embedding_4_2(x_embedding)
        x_embedding = self.upsample_embedding_5_0(x_embedding, max_indices1_0, output_size=stage1_input_size)
        x_embedding = self.regular_embedding_5_1(x_embedding)
        instance_notfinal_logits = self.embedding_transposed_conv(x_embedding, output_size=input_size)

        return (binary_final_logits, instance_notfinal_logits, distillation_loss,
                attention_stage1, attention_stage2, attention_stage3)

To-do list
---


*   Add one more encoder stage, so that the binary branch has an extra encoder of its own.
*   Apply the attention (distillation) only in the segmentation-specific branch.
* You could also add a separate attention for the discriminative (embedding) branch (test whether it changes performance or not).
* In general, keep stage 0 and stage 1 shared and make the remaining stages specific to each branch.
* Split stage 2 into separate per-branch copies.



In [None]:
if __name__ == '__main__':
    # Smoke test: push a dummy batch of 8 three-channel 256x512 images through
    # the network (2 binary-segmentation channels, 4-dim embedding) and report
    # the shapes of everything it returns.
    test_input = torch.ones((8, 3, 256, 512))
    net = Enet_SAD(2, 4)
    (binary_final_logits, instance_notfinal_logits, distillation_loss,
     attention_stage1, attention_stage2, attention_stage3) = net(test_input)
    print('binary logits:', tuple(binary_final_logits.shape))
    print('embedding logits:', tuple(instance_notfinal_logits.shape))
    print('distillation loss:', distillation_loss.item())
    print('attention stage1:', tuple(attention_stage1.shape))
    print('attention stage2:', tuple(attention_stage2.shape))
    print('attention stage3:', tuple(attention_stage3.shape))

torch.Size([8, 128, 32, 64])
torch.Size([8, 128, 32, 64])
torch.Size([8, 64, 64, 128])
