<a href="https://colab.research.google.com/github/pushkar-khetrapal/Yet-Another-EfficientDet-Pytorch/blob/master/Semantic_Segmention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install git+https://github.com/mapillary/inplace_abn

Collecting git+https://github.com/mapillary/inplace_abn
  Cloning https://github.com/mapillary/inplace_abn to /tmp/pip-req-build-g7j8h66o
  Running command git clone -q https://github.com/mapillary/inplace_abn /tmp/pip-req-build-g7j8h66o
Building wheels for collected packages: inplace-abn
  Building wheel for inplace-abn (setup.py) ... [?25l[?25hdone
  Created wheel for inplace-abn: filename=inplace_abn-1.0.12-cp36-cp36m-linux_x86_64.whl size=3262983 sha256=dea1e1a7e90e75adf5c3ed51335b5b0bf75c343374f6d94cccea877cabea8b82
  Stored in directory: /tmp/pip-ephem-wheel-cache-794mhfwk/wheels/fe/0b/49/1303ca37166cc1be8784e2367a172133634dcd864a9df0ab56
Successfully built inplace-abn
Installing collected packages: inplace-abn
Successfully installed inplace-abn-1.0.12


In [None]:
!pip install pytorch-model-summary

Collecting pytorch-model-summary
  Downloading https://files.pythonhosted.org/packages/a0/de/f3548f3081045cfc4020fc297cc9db74839a6849da8a41b89c48a3307da7/pytorch_model_summary-0.1.1-py3-none-any.whl
Installing collected packages: pytorch-model-summary
Successfully installed pytorch-model-summary-0.1.1


In [None]:
import torch
from inplace_abn.abn import InPlaceABN, InPlaceABNSync
import torch.distributed as dist

In [5]:
## need to use iABNsync layer with leakyRelu
import torch
import torch.nn.functional as F
import torch.distributed as distributed

import math

from torch import nn
import torch.nn.functional as F


class Conv2dStaticSamePadding(nn.Module):
    """
    created by Zylo117
    The real keras/tensorflow conv2d with same padding

    The input is zero-padded on the fly so that the output spatial size is
    ``ceil(input_size / stride)`` along each axis, as with TensorFlow's
    ``padding='SAME'``. The padding is computed from the *effective* kernel
    size ``(kernel_size - 1) * dilation + 1``, so dilated convolutions keep
    the "same" output size too. When the total padding along an axis is odd,
    the extra row/column goes to the bottom/right (TensorFlow convention).
    Padding is never negative (TensorFlow clamps it at 0).

    Args:
        in_channels, out_channels, kernel_size, stride, bias, groups, dilation:
            forwarded to ``nn.Conv2d`` (which itself uses no padding).
        **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, groups=1, dilation=1, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
                              bias=bias, groups=groups, dilation=dilation)
        # nn.Conv2d already normalizes these to (height, width) 2-tuples.
        self.stride = self.conv.stride
        self.kernel_size = self.conv.kernel_size
        self.dilation = self.conv.dilation

    @staticmethod
    def _total_padding(size, stride, kernel_size, dilation):
        """Return the total TF-'SAME' padding needed along one spatial axis."""
        effective_kernel = (kernel_size - 1) * dilation + 1
        needed = (math.ceil(size / stride) - 1) * stride + effective_kernel - size
        # A negative value would make F.pad crop the input; TF clamps at 0.
        return max(needed, 0)

    def forward(self, x):
        h, w = x.shape[-2:]

        pad_w = self._total_padding(w, self.stride[1], self.kernel_size[1], self.dilation[1])
        pad_h = self._total_padding(h, self.stride[0], self.kernel_size[0], self.dilation[0])

        left = pad_w // 2
        top = pad_h // 2
        # F.pad order for the last two dims: [left, right, top, bottom].
        x = F.pad(x, [left, pad_w - left, top, pad_h - top])

        return self.conv(x)

class SwishImplementation(torch.autograd.Function):
    """Swish activation ``x * sigmoid(x)`` with a hand-written backward pass.

    Only the input tensor is saved for backward; the sigmoid is recomputed
    there instead of being stored.
    """

    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # `saved_tensors` is the supported accessor; `saved_variables` is deprecated.
        (i,) = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        # d/dx [x * s(x)] = s(x) + x * s(x) * (1 - s(x)) = s(x) * (1 + x * (1 - s(x)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class MemoryEfficientSwish(nn.Module):
    """Swish activation module that delegates to the custom ``SwishImplementation`` autograd function."""

    def forward(self, x):
        swish_fn = SwishImplementation.apply
        return swish_fn(x)


class Swish(nn.Module):
    """Plain autograd Swish, ``x * sigmoid(x)``; SeparableConvBlock uses it when ``onnx_export`` is set."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate




class SeparableConvBlock(nn.Module):
    """
    created by Zylo117

    Depthwise 3x3 convolution followed by a pointwise 1x1 convolution (both
    built with Conv2dStaticSamePadding), then optional BatchNorm and optional
    Swish.

    Args:
        in_channels: input channels; also the group count of the depthwise conv.
        out_channels: output channels; ``None`` means ``in_channels``.
        norm: if truthy, apply ``BatchNorm2d(momentum=0.01, eps=1e-3)``.
        activation: if truthy, apply a Swish activation.
        dilation: dilation of the depthwise conv.
        onnx_export: use the plain ``Swish`` instead of ``MemoryEfficientSwish``.
    """

    def __init__(self, in_channels, out_channels=None, norm=True, activation=True, dilation=1, onnx_export=False):
        super().__init__()
        out_channels = in_channels if out_channels is None else out_channels

        # Only the pointwise conv carries a bias; the depthwise conv is built
        # with bias=False (the original author confirmed this layout).
        self.depthwise_conv = Conv2dStaticSamePadding(
            in_channels, in_channels,
            kernel_size=3, stride=1, groups=in_channels, bias=False, dilation=dilation,
        )
        self.pointwise_conv = Conv2dStaticSamePadding(in_channels, out_channels, kernel_size=1, stride=1)

        self.norm = norm
        if norm:
            # PyTorch momentum is 1 - TensorFlow momentum.
            self.bn = nn.BatchNorm2d(num_features=out_channels, momentum=0.01, eps=1e-3)

        self.activation = activation
        if activation:
            self.swish = Swish() if onnx_export else MemoryEfficientSwish()

    def forward(self, x):
        out = self.pointwise_conv(self.depthwise_conv(x))
        if self.norm:
            out = self.bn(out)
        return self.swish(out) if self.activation else out


# LSFE module
class LSFE(nn.Module):
    """LSFE module: two stacked SeparableConvBlocks (3x3 depthwise + 1x1 pointwise).

    Spatial size is preserved (stride 1, "same" padding); only the channel
    count changes, from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: channels of the incoming feature map (default 64, the
            previously hard-coded value).
        out_channels: channels produced by both conv blocks (default 256, the
            previously hard-coded value).

    NOTE(review): SemanticHead in this notebook feeds 256-channel maps to
    ``LSFE()``; confirm which default widths are intended.
    """

    def __init__(self, in_channels=64, out_channels=256):
        super(LSFE, self).__init__()
        self.conv1 = SeparableConvBlock(in_channels, out_channels)
        self.conv2 = SeparableConvBlock(out_channels, out_channels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


# Smoke test: 64 channels in, 256 channels out, same 128x256 spatial size.
lsfe = LSFE()
lsfe.forward(torch.randn(1,64,128,256)).shape

torch.Size([1, 256, 128, 256])

In [6]:
# Mismatch Correction Module (MC)
class CorrectionModule(nn.Module):
    """Mismatch Correction module: two SeparableConvBlocks, then 2x bilinear upsampling.

    Args:
        in_channels: channels of the incoming map (default 256, the previously
            hard-coded value).
        out_channels: channels produced by both conv blocks (default 256).
    """

    def __init__(self, in_channels=256, out_channels=256):
        super(CorrectionModule, self).__init__()
        self.conv1 = SeparableConvBlock(in_channels, out_channels)
        self.conv2 = SeparableConvBlock(out_channels, out_channels)
        # align_corners=None already behaves as False for bilinear; passing it
        # explicitly silences the nn.Upsample UserWarning seen in this cell's output.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return self.up(x)


# Smoke test: channels unchanged, spatial size doubled (128x256 -> 256x512).
cm = CorrectionModule()
cm.forward(torch.randn(1,256,128,256)).shape

  "See the documentation of nn.Upsample for details.".format(mode))


torch.Size([1, 256, 256, 512])

In [15]:
# Dense Prediction Cells (DPC)
class DPC(nn.Module):
    """Dense Prediction Cell: five dilated SeparableConvBlock branches, fused by a 1x1 conv.

    Branch layout (dilations are (height, width)):
        x1 = conv1(x)   dilation (1, 6)
        x2 = conv2(x1)  dilation (1, 1)
        x3 = conv3(x1)  dilation (6, 21)
        x4 = conv4(x1)  dilation (18, 15)
        x5 = conv5(x4)  dilation (6, 3)
    Every branch output is bilinearly resized to ``(height, width)``. The five
    maps are then concatenated (``5 * channels``) and projected to 128
    channels with a 1x1 conv.

    Args:
        height, width: spatial size of every branch and of the output.
        channels: channels of the input and of every branch (default 256).
    """

    def __init__(self, height, width, channels = 256):
        super(DPC, self).__init__()

        self.height = height
        self.width = width
        size = (height, width)

        # SeparableConvBlock's third positional parameter is `norm`, not a
        # kernel size (its depthwise kernel is always 3), so pass norm by keyword.
        self.conv1 = SeparableConvBlock(channels, channels, norm=True, dilation=(1, 6))
        self.up1 = self._upsample(size)

        self.conv2 = SeparableConvBlock(channels, channels, norm=True, dilation=(1, 1))
        self.up2 = self._upsample(size)

        self.conv3 = SeparableConvBlock(channels, channels, norm=True, dilation=(6, 21))
        self.up3 = self._upsample(size)

        self.up_tocalculate18x3 = self._upsample((36, 64))
        self.conv4 = SeparableConvBlock(channels, channels, norm=True, dilation=(18, 15))
        self.up4 = self._upsample(size)

        self.conv5 = SeparableConvBlock(channels, channels, norm=True, dilation=(6, 3))
        self.up5 = self._upsample(size)

        self.lastconv = nn.Conv2d(5 * channels, 128, 1)

    @staticmethod
    def _upsample(size):
        """Bilinear resize to a fixed size (explicit align_corners=False avoids the warning)."""
        return nn.Upsample(size, mode='bilinear', align_corners=False)

    def forward(self, x):
        x1 = self.up1(self.conv1(x))
        x2 = self.up2(self.conv2(x1))
        x3 = self.up3(self.conv3(x1))

        x4 = x1
        if self.height < 33:
            # Workaround kept from the original: small maps are enlarged to a
            # hard-coded 36x64 before conv4, so its 18-row dilation has room
            # even if the same-padding helper does not pad for dilation.
            # up4 resizes the result back to (height, width).
            x4 = self.up_tocalculate18x3(x4)
        x4 = self.up4(self.conv4(x4))

        x5 = self.up5(self.conv5(x4))

        cat = torch.cat((x1, x2, x3, x4, x5), dim=1)
        return self.lastconv(cat)


# Smoke test: any input size is resized to (128, 256); output has 128 channels.
dpc = DPC(128, 256)
dpc.forward(torch.randn(1,256,64,64)).shape

  "See the documentation of nn.Upsample for details.".format(mode))


torch.Size([1, 128, 128, 256])

In [None]:
class SemanticHead(nn.Module):
    """Semantic head fusing four pyramid levels (p32, p16, p8, p4) into one prediction map.

    - p32 and p16 go through DPC cells; p8 and p4 go through LSFE branches.
    - A top-down path adds each upsampled coarser map to the next finer one,
      passing through a Mismatch Correction module (conv + 2x upsample).
    - Four 128-channel maps are resized to 256x512, concatenated (512
      channels), projected with a 1x1 conv and upsampled to 1024x2048.

    All resize targets are hard-coded for the layout in the smoke test below.

    Args:
        num_classes: output channels of the final 1x1 conv. Defaults to 512,
            the value previously hard-coded (marked "NEED TO CHANGE OUTPUT
            CHANNELS"); set it to the number of semantic classes.
    """

    # Channels of the incoming pyramid maps (per the smoke test) and of the
    # head's internal branches (DPC projects to 128 channels).
    _IN_CHANNELS = 256
    _HEAD_CHANNELS = 128

    def __init__(self, num_classes=512):
        super(SemanticHead, self).__init__()
        self.dpcp32 = DPC(32, 64)
        self.dpcp16 = DPC(64, 128)
        self.lsfep8 = self._lsfe(self._IN_CHANNELS, self._HEAD_CHANNELS)
        self.lsfep4 = self._lsfe(self._IN_CHANNELS, self._HEAD_CHANNELS)

        self.up_p32 = self._upsample((64, 128))

        self.mc1 = self._correction(self._HEAD_CHANNELS)
        self.mc2 = self._correction(self._HEAD_CHANNELS)

        self.up1 = self._upsample((256, 512))
        self.up2 = self._upsample((256, 512))
        self.up3 = self._upsample((256, 512))

        # Four concatenated head-width maps -> num_classes.
        self.lastconv = nn.Conv2d(4 * self._HEAD_CHANNELS, num_classes, 1)
        self.uplast = self._upsample((1024, 2048))

    @staticmethod
    def _upsample(size):
        """Bilinear resize to a fixed size (explicit align_corners=False avoids the warning)."""
        return nn.Upsample(size, mode='bilinear', align_corners=False)

    @staticmethod
    def _lsfe(in_channels, out_channels):
        """Build an LSFE whose two separable convs map in_channels -> out_channels."""
        branch = LSFE()
        # LSFE() is built for 64 -> 256 channels by default, which cannot take
        # the 256-channel p8/p4 maps nor produce the head's 128-channel width.
        # Replace both convs, keeping the conv1/conv2 attribute names.
        branch.conv1 = SeparableConvBlock(in_channels, out_channels)
        branch.conv2 = SeparableConvBlock(out_channels, out_channels)
        return branch

    @staticmethod
    def _correction(channels):
        """Build a CorrectionModule whose two separable convs keep `channels` channels."""
        module = CorrectionModule()
        # CorrectionModule() is built for 256 channels by default, but the
        # top-down path carries 128-channel DPC/LSFE sums; swap both convs.
        module.conv1 = SeparableConvBlock(channels, channels)
        module.conv2 = SeparableConvBlock(channels, channels)
        return module

    def forward(self, p32, p16, p8, p4):
        d32 = self.dpcp32(p32)
        d16 = self.dpcp16(p16)

        lp8 = self.lsfep8(p8)
        lp4 = self.lsfep4(p4)

        # Top-down path: upsample the coarser map and add the next finer one.
        add1 = torch.add(self.up_p32(d32), d16)
        add2 = torch.add(self.mc1(add1), lp8)
        add3 = torch.add(self.mc2(add2), lp4)

        cat = torch.cat((self.up1(d32), self.up2(d16), self.up3(add2), add3), dim=1)
        cat = self.lastconv(cat)
        return self.uplast(cat)


# Smoke test: 256-channel p32/p16/p8/p4 maps -> (1, num_classes, 1024, 2048).
sh = SemanticHead()
sh.forward(torch.randn(1,256,32,64), torch.randn(1,256,64,128), torch.randn(1,256,128,256), torch.randn(1,256,256, 512)).shape

  "See the documentation of nn.Upsample for details.".format(mode))


torch.Size([1, 512, 1024, 2048])

In [None]:
from pytorch_model_summary import summary
# Per-layer summary of the semantic head on the same inputs as its smoke test.
print(
    summary(
        sh,
        torch.randn(1, 256, 32, 64),
        torch.randn(1, 256, 64, 128),
        torch.randn(1, 256, 128, 256),
        torch.randn(1, 256, 256, 512),
        show_input=True,
        show_hierarchical=True,
    )
)

  "See the documentation of nn.Upsample for details.".format(mode))


-----------------------------------------------------------------------------
         Layer (type)            Input Shape         Param #     Tr. Param #
                DPC-1       [1, 256, 32, 64]         505,728         505,728
                DPC-2      [1, 256, 64, 128]         505,728         505,728
               LSFE-3     [1, 256, 128, 256]          53,120          53,120
               LSFE-4     [1, 256, 256, 512]          53,120          53,120
           Upsample-5       [1, 128, 32, 64]               0               0
   CorrectionModule-6      [1, 128, 64, 128]          35,584          35,584
   CorrectionModule-7     [1, 128, 128, 256]          35,584          35,584
           Upsample-8       [1, 128, 32, 64]               0               0
           Upsample-9      [1, 128, 64, 128]               0               0
          Upsample-10     [1, 128, 128, 256]               0               0
            Conv2d-11     [1, 512, 256, 512]         262,656         262,65