In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
import os
from functools import reduce

import torch
import torch.nn as nn

from mobilenetv2 import MobileNetV2


class BaseBackbone(nn.Module):
    """ Superclass of Replaceable Backbone Model for Semantic Estimation
    """

    def __init__(self, in_channels):
        super(BaseBackbone, self).__init__()
        self.in_channels = in_channels

        self.model = None
        self.enc_channels = []

    def forward(self, x):
        raise NotImplementedError

    def load_pretrained_ckpt(self):
        raise NotImplementedError


class MobileNetV2Backbone(BaseBackbone):
    """ MobileNetV2 Backbone 
    """

    def __init__(self, in_channels):
        super(MobileNetV2Backbone, self).__init__(in_channels)

        self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None)
        
        self.enc_channels = [16, 24, 32, 96, 1280]       ####################### ENCODER CHANNELS #######################

    def forward(self, x):
        # x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x)
        x = self.model.features[0](x)
        x = self.model.features[1](x)
        enc2x = x

        # x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x)
        x = self.model.features[2](x)
        x = self.model.features[3](x)
        enc4x = x

        # x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x)
        x = self.model.features[4](x)
        x = self.model.features[5](x)
        x = self.model.features[6](x)
        enc8x = x

        # x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x)
        x = self.model.features[7](x)
        x = self.model.features[8](x)
        x = self.model.features[9](x)
        x = self.model.features[10](x)
        x = self.model.features[11](x)
        x = self.model.features[12](x)
        x = self.model.features[13](x)
        enc16x = x

        # x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x)
        x = self.model.features[14](x)
        x = self.model.features[15](x)
        x = self.model.features[16](x)
        x = self.model.features[17](x)
        x = self.model.features[18](x)
        enc32x = x
        
        return [enc2x, enc4x, enc8x, enc16x, enc32x]

    def load_pretrained_ckpt(self):
        # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch 
        ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt'
        if not os.path.exists(ckpt_path):
            print('cannot find the pretrained mobilenetv2 backbone')
            exit()
        
        ckpt = torch.load(ckpt_path)
        self.model.load_state_dict(ckpt)


In [3]:
class IBNorm(nn.Module):
    """ Combine Instance Norm and Batch Norm into One Layer
    """

    def __init__(self, in_channels):
        super(IBNorm, self).__init__()
        in_channels = in_channels
        self.bnorm_channels = int(in_channels / 2)
        self.inorm_channels = in_channels - self.bnorm_channels

        self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
        self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)
        
    def forward(self, x):
        bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
        in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())

        return torch.cat((bn_x, in_x), 1)

class Conv2dIBNormRelu(nn.Module):
    """ Convolution + IBNorm + ReLu
    """

    def __init__(self, in_channels, out_channels, kernel_size, 
                 stride=1, padding=0, dilation=1, groups=1, bias=True, 
                 with_ibn=True, with_relu=True):
        super(Conv2dIBNormRelu, self).__init__()

        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, 
                      stride=stride, padding=padding, dilation=dilation, 
                      groups=groups, bias=bias)
        ]

        if with_ibn:       
            layers.append(IBNorm(out_channels))
        if with_relu:
            layers.append(nn.ReLU(inplace=True))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x) 

class SEBlock(nn.Module):
    """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf 
    """

    def __init__(self, in_channels, out_channels, reduction=1):
        super(SEBlock, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, int(in_channels // reduction), bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels // reduction), out_channels, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        b, c, _, _ = x.size()
        w = self.pool(x).view(b, c)
        w = self.fc(w).view(b, c, 1, 1)

        return x * w.expand_as(x)

In [49]:
SUPPORTED_BACKBONES = {
    'mobilenetv2': MobileNetV2Backbone,
}

class LRBranch(nn.Module):
    """ Low Resolution Branch of MODNet
    """

    def __init__(self, backbone):
        super(LRBranch, self).__init__()

        enc_channels = backbone.enc_channels
        
        print("---------", enc_channels)
        
        self.backbone = backbone
        self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
        self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)
        self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2)
        self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False)

    def forward(self, img, inference=False):
        enc_features = self.backbone.forward(img)
        enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4]

        enc32x = self.se_block(enc32x)
        lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)
        lr16x = self.conv_lr16x(lr16x)
        lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)
        lr8x = self.conv_lr8x(lr8x)

        pred_semantic = None
        if not inference:
            lr = self.conv_lr(lr8x)
            pred_semantic = torch.sigmoid(lr)
        
        return pred_semantic, lr8x, [enc2x, enc4x] 
    

class HRBranch(nn.Module):
    """ High Resolution Branch of MODNet
    """

    def __init__(self, hr_channels, enc_channels):
        super(HRBranch, self).__init__()

        self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)
        self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)

        self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)
        self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)

        self.conv_hr4x = nn.Sequential(
            Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
        )

        self.conv_hr2x = nn.Sequential(
            Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
        )

        self.conv_hr = nn.Sequential(
            Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),
        )

    def forward(self, img, enc2x, enc4x, lr8x, inference):
        img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False)
        img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False)

        enc2x = self.tohr_enc2x(enc2x)
        hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1))

        enc4x = self.tohr_enc4x(enc4x)
        hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1))

        lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
        hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1))

        hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)
        hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1))

        pred_detail = None
        if not inference:
            hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False)
            hr = self.conv_hr(torch.cat((hr, img), dim=1))
            pred_detail = torch.sigmoid(hr)

        return pred_detail, hr2x


class FusionBranch(nn.Module):
    """ Fusion Branch of MODNet
    """

    def __init__(self, hr_channels, enc_channels):
        super(FusionBranch, self).__init__()
        self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)
        
        self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
        self.conv_f = nn.Sequential(
            Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
            Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),
        )

    def forward(self, img, lr8x, hr2x):
        lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
        lr4x = self.conv_lr4x(lr4x)
        lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False)

        f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1))
        f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)
        f = self.conv_f(torch.cat((f, img), dim=1))
        #pred_matte = torch.sigmoid(f)

        #return pred_matte
        return f

In [50]:
#------------------------------------------------------------------------------
#  MODNet
#------------------------------------------------------------------------------

class MODNet(nn.Module):
    """ Architecture of MODNet
    """

    def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=False):
        super(MODNet, self).__init__()

        self.in_channels = in_channels
        self.hr_channels = hr_channels
        self.backbone_arch = backbone_arch 
        self.backbone_pretrained = backbone_pretrained

        self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels)

        self.lr_branch = LRBranch(self.backbone)
        self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)
        self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                self._init_conv(m)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
                self._init_norm(m)

        if self.backbone_pretrained:
            self.backbone.load_pretrained_ckpt()                

    def forward(self, img, inference=False):
        pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)
        pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference)
        pred_matte = self.f_branch(img, lr8x, hr2x)

        return pred_semantic, pred_detail, pred_matte
    
    def freeze_norm(self):
        norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]
        for m in self.modules():
            for n in norm_types:
                if isinstance(m, n):
                    m.eval()
                    continue

    def _init_conv(self, conv):
        nn.init.kaiming_uniform_(
            conv.weight, a=0, mode='fan_in', nonlinearity='relu')
        if conv.bias is not None:
            nn.init.constant_(conv.bias, 0)

    def _init_norm(self, norm):
        if norm.weight is not None:
            nn.init.constant_(norm.weight, 1)
            nn.init.constant_(norm.bias, 0)

In [51]:
from torchsummary import summary

model = MODNet()
summary(model, (3, 224, 224))

--------- [16, 24, 32, 96, 1280]
Layer (type:depth-idx)                        Output Shape              Param #
├─LRBranch: 1-1                               [-1, 1, 14, 14]           --
|    └─SEBlock: 2-1                           [-1, 1280, 7, 7]          --
|    |    └─AdaptiveAvgPool2d: 3-1            [-1, 1280, 1, 1]          --
|    |    └─Sequential: 3-2                   [-1, 1280]                819,200
|    └─Conv2dIBNormRelu: 2-2                  [-1, 96, 14, 14]          --
|    |    └─Sequential: 3-3                   [-1, 96, 14, 14]          3,072,192
|    └─Conv2dIBNormRelu: 2-3                  [-1, 32, 28, 28]          --
|    |    └─Sequential: 3-4                   [-1, 32, 28, 28]          76,864
|    └─Conv2dIBNormRelu: 2-4                  [-1, 1, 14, 14]           --
|    |    └─Sequential: 3-5                   [-1, 1, 14, 14]           289
├─HRBranch: 1-2                               [-1, 1, 224, 224]         --
|    └─Conv2dIBNormRelu: 2-5                 

Layer (type:depth-idx)                        Output Shape              Param #
├─LRBranch: 1-1                               [-1, 1, 14, 14]           --
|    └─SEBlock: 2-1                           [-1, 1280, 7, 7]          --
|    |    └─AdaptiveAvgPool2d: 3-1            [-1, 1280, 1, 1]          --
|    |    └─Sequential: 3-2                   [-1, 1280]                819,200
|    └─Conv2dIBNormRelu: 2-2                  [-1, 96, 14, 14]          --
|    |    └─Sequential: 3-3                   [-1, 96, 14, 14]          3,072,192
|    └─Conv2dIBNormRelu: 2-3                  [-1, 32, 28, 28]          --
|    |    └─Sequential: 3-4                   [-1, 32, 28, 28]          76,864
|    └─Conv2dIBNormRelu: 2-4                  [-1, 1, 14, 14]           --
|    |    └─Sequential: 3-5                   [-1, 1, 14, 14]           289
├─HRBranch: 1-2                               [-1, 1, 224, 224]         --
|    └─Conv2dIBNormRelu: 2-5                  [-1, 32, 112, 112]        --
|  

## Dissecting each branch and trying a lighter alernative

## LR Branch

In [65]:
SUPPORTED_BACKBONES = {
    'mobilenetv2': MobileNetV2Backbone,
}

class LRBranch(nn.Module):
    """ Low Resolution Branch of MODNet
    """

    def __init__(self, backbone):
        super(LRBranch, self).__init__()

        enc_channels = backbone.enc_channels
        
        print("---------", enc_channels)
        
        self.backbone = backbone
        self.se_block = SEBlock(enc_channels[3], enc_channels[3], reduction=4) # enc_channels[4]
        self.conv_lr16x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2) # enc_channels[4],  enc_channels[3]
        self.conv_lr8x = Conv2dIBNormRelu(enc_channels[2], enc_channels[2], 5, stride=1, padding=2) # enc_channels[3]
        self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False)

    def forward(self, img, inference=False):
        enc_features = self.backbone.forward(img)
        enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[3] # enc_features[3]

        enc32x = self.se_block(enc32x)
        lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)
        lr16x = self.conv_lr16x(lr16x)
        lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)
        lr8x = self.conv_lr8x(lr8x)

        pred_semantic = None
        if not inference:
            lr = self.conv_lr(lr8x)
            pred_semantic = torch.sigmoid(lr)
        
        return pred_semantic, lr8x, [enc2x, enc4x] 
    

In [66]:
in_channels=3
hr_channels=32
backbone_arch='mobilenetv2'
backbone_pretrained=False

backbone = SUPPORTED_BACKBONES[backbone_arch](in_channels)

lr_branch = LRBranch(backbone)

--------- [16, 24, 32, 96, 1280]


In [67]:
from torchsummary import summary

summary(lr_branch, (3, 224, 224))

Layer (type:depth-idx)                        Output Shape              Param #
├─SEBlock: 1-1                                [-1, 96, 14, 14]          --
|    └─AdaptiveAvgPool2d: 2-1                 [-1, 96, 1, 1]            --
|    └─Sequential: 2-2                        [-1, 96]                  --
|    |    └─Linear: 3-1                       [-1, 24]                  2,304
|    |    └─ReLU: 3-2                         [-1, 24]                  --
|    |    └─Linear: 3-3                       [-1, 96]                  2,304
|    |    └─Sigmoid: 3-4                      [-1, 96]                  --
├─Conv2dIBNormRelu: 1-2                       [-1, 32, 28, 28]          --
|    └─Sequential: 2-3                        [-1, 32, 28, 28]          --
|    |    └─Conv2d: 3-5                       [-1, 32, 28, 28]          76,832
|    |    └─IBNorm: 3-6                       [-1, 32, 28, 28]          32
|    |    └─ReLU: 3-7                         [-1, 32, 28, 28]          --
├─Conv2dIB

Layer (type:depth-idx)                        Output Shape              Param #
├─SEBlock: 1-1                                [-1, 96, 14, 14]          --
|    └─AdaptiveAvgPool2d: 2-1                 [-1, 96, 1, 1]            --
|    └─Sequential: 2-2                        [-1, 96]                  --
|    |    └─Linear: 3-1                       [-1, 24]                  2,304
|    |    └─ReLU: 3-2                         [-1, 24]                  --
|    |    └─Linear: 3-3                       [-1, 96]                  2,304
|    |    └─Sigmoid: 3-4                      [-1, 96]                  --
├─Conv2dIBNormRelu: 1-2                       [-1, 32, 28, 28]          --
|    └─Sequential: 2-3                        [-1, 32, 28, 28]          --
|    |    └─Conv2d: 3-5                       [-1, 32, 28, 28]          76,832
|    |    └─IBNorm: 3-6                       [-1, 32, 28, 28]          32
|    |    └─ReLU: 3-7                         [-1, 32, 28, 28]          --
├─Conv2dIB

## HR Branch

In [104]:
class HRBranch(nn.Module):
    """ High Resolution Branch of MODNet
    """

    def __init__(self, hr_channels, enc_channels):
        super(HRBranch, self).__init__()

        self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)
        self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)

        self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)
        self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)

        self.conv_hr4x = nn.Sequential(
            Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
        )

        self.conv_hr2x = nn.Sequential(
            Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
        )

        self.conv_hr = nn.Sequential(
            Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),
        )

    def forward(self, img, enc2x, enc4x, lr8x, inference):
        #img2x = F.interpolate(img, scale_factor=1, mode='bilinear', align_corners=False)
        #img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False)

        enc2x = self.tohr_enc2x.cuda()(enc2x)
        img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False)
        hr4x = self.conv_enc2x.cuda()(torch.cat((img2x, enc2x), dim=1))

        enc4x = self.tohr_enc4x.cuda()(enc4x)
        hr4x = self.conv_enc4x.cuda()(torch.cat((hr4x, enc4x), dim=1))

        lr4x = F.interpolate(lr8x, scale_factor=1, mode='bilinear', align_corners=False)
        img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False)
        
        print("---", hr4x.shape, lr4x.shape, img4x.shape)
        
        hr4x = self.conv_hr4x.cuda()(torch.cat((hr4x, lr4x, img4x), dim=1))

        hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)
        hr2x = self.conv_hr2x.cuda()(torch.cat((hr2x, enc2x), dim=1))

        pred_detail = None
        if not inference:
            hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False)
            hr = self.conv_hr.cuda()(torch.cat((hr, img), dim=1))
            pred_detail = torch.sigmoid(hr)

        return pred_detail, hr2x

In [105]:
hr_channels=32
backbone_arch='mobilenetv2'
backbone_pretrained=False

backbone = SUPPORTED_BACKBONES[backbone_arch](in_channels)

img = torch.randn(1, 3, 224, 224).cuda()
pred_semantic, lr8x, [enc2x, enc4x] = lr_branch(img, inference=False)


hr_branch = HRBranch(hr_channels, backbone.enc_channels)


#### Note - Below Flop, Param counting implementation has been modified for "HR_Branch" specifically in "get_params" section

In [118]:
"""Implementation of the PyTorch Models Profiler"""
import torch
import torch.nn as nn


class ModelProfiler(nn.Module):
    """ Profile PyTorch models.

    Compute FLOPs (FLoating OPerations) and number of trainable parameters of model.

    Arguments:
        model (nn.Module): model which will be profiled.

    Example:
        model = torchvision.models.resnet50()
        profiler = ModelProfiler(model)
        var = torch.zeros(1, 3, 224, 224)
        profiler(var)
        print("FLOPs: {0:.5}; #Params: {1:.5}".format(profiler.get_flops('G'), profiler.get_params('M')))

    Warning:
        Model profiler doesn't work with models, wrapped by torch.nn.DataParallel.
    """
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.flops = 0
        self.units = {'K': 10.**3, 'M': 10.**6, 'G': 10.**9}
        self.hooks = None
        self._remove_hooks()

    def get_flops(self, units='G'):
        """ Get number of floating operations per inference.

        Arguments:
            units (string): units of the flops value ('K': Kilo (10^3), 'M': Mega (10^6), 'G': Giga (10^9)).

        Returns:
            Floating operations per inference at the choised units.
        """
        assert units in self.units
        return self.flops / self.units[units]

    def get_params(self, units='K'):
        """ Get number of trainable parameters of the model.

        Arguments:
            units (string): units of the flops value ('K': Kilo (10^3), 'M': Mega (10^6), 'G': Giga (10^9)).

        Returns:
            Number of trainable parameters of the model at the choised units.
        """
        assert units in self.units
        model = HRBranch(hr_channels, backbone.enc_channels) # remove this line for general implementation
        params = sum(p.numel() for p in model.parameters() if p.requires_grad) # self.model rather than model
        if units is not None:
            params = params / self.units[units]
        return params

    def forward(self, *args, **kwargs):
        self.flops = 0
        self._init_hooks()
        output = self.model(*args, **kwargs)
        self._remove_hooks()
        return output

    def _remove_hooks(self):
        if self.hooks is not None:
            for hook in self.hooks:
                hook.remove()
        self.hooks = None

    def _init_hooks(self):
        self.hooks = []

        def hook_compute_flop(module, _, output):
            self.flops += module.weight.size()[1:].numel() * output.size()[1:].numel()

        def add_hooks(module):
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                self.hooks.append(module.register_forward_hook(hook_compute_flop))

        self.model.apply(add_hooks)


def profile_model(model, input_size, cuda):
    """ Compute FLOPS and #Params of the CNN.

    Arguments:
        model (nn.Module): model which should be profiled.
        input_size (tuple): size of the input variable.
        cuda (bool): if True then variable will be upload to the GPU.

    Returns:
        dict:
            dict["flops"] (float): number of GFLOPs.
            dict["params"] (int): number of million parameters.
    """
    profiler = ModelProfiler(model)
    var = torch.zeros(input_size)
    if cuda:
        var = var.cuda()
    profiler(var)
    return {"flops": profiler.get_flops('G'), "params": profiler.get_params('M')}


In [120]:
profiler = ModelProfiler(hr_branch(img, enc2x, enc4x, lr8x, False))
{"flops": profiler.get_flops('G'), "params": profiler.get_params('K')} #  # 245,400 params

--- torch.Size([1, 64, 56, 56]) torch.Size([1, 32, 56, 56]) torch.Size([1, 3, 56, 56])


{'flops': 0.0, 'params': 245.409}

## Fusion Branch

In [134]:
class FusionBranch(nn.Module):
    """ Fusion Branch of MODNet
    """

    def __init__(self, hr_channels, enc_channels):
        super(FusionBranch, self).__init__()
        self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)
        
        self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
        self.conv_f = nn.Sequential(
            Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
            Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),
        )

    def forward(self, img, lr8x, hr2x):
        lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
        lr4x = self.conv_lr4x.cuda()(lr4x)
        lr2x = F.interpolate(lr4x, scale_factor=1, mode='bilinear', align_corners=False)

        f2x = self.conv_f2x.cuda()(torch.cat((lr2x, hr2x), dim=1))
        f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)
        f = self.conv_f.cuda()(torch.cat((f, img), dim=1))
        #pred_matte = torch.sigmoid(f)

        #return pred_matte
        return f

In [135]:
pred_semantic, lr8x, [enc2x, enc4x] = lr_branch(img, inference=False)
pred_detail, hr2x = hr_branch(img, enc2x, enc4x, lr8x, inference=False)

hr_channels=32
backbone_arch='mobilenetv2'
backbone_pretrained=False

backbone = SUPPORTED_BACKBONES[backbone_arch](in_channels)

f_branch = FusionBranch(hr_channels, backbone.enc_channels)

--- torch.Size([1, 64, 56, 56]) torch.Size([1, 32, 56, 56]) torch.Size([1, 3, 56, 56])


In [136]:
profiler = ModelProfiler(f_branch(img, lr8x, hr2x))
{"flops": profiler.get_flops('G'), "params": profiler.get_params('K')} #  # 245,400 params

{'flops': 0.0, 'params': 245.409}

## Combine all the sections (modified) together

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [8]:
import os
from functools import reduce

import torch
import torch.nn as nn

from backbones.mobilenetv2 import MobileNetV2


class BaseBackbone(nn.Module):
    """ Superclass of Replaceable Backbone Model for Semantic Estimation
    """

    def __init__(self, in_channels):
        super(BaseBackbone, self).__init__()
        self.in_channels = in_channels

        self.model = None
        self.enc_channels = []

    def forward(self, x):
        raise NotImplementedError

    def load_pretrained_ckpt(self):
        raise NotImplementedError


class MobileNetV2Backbone(BaseBackbone):
    """ MobileNetV2 Backbone 
    """

    def __init__(self, in_channels):
        super(MobileNetV2Backbone, self).__init__(in_channels)

        self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None)
        
        self.enc_channels = [16, 24, 32, 96, 1280]       ####################### ENCODER CHANNELS #######################

    def forward(self, x):
        # x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x)
        x = self.model.features[0](x)
        x = self.model.features[1](x)
        enc2x = x

        # x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x)
        x = self.model.features[2](x)
        x = self.model.features[3](x)
        enc4x = x

        # x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x)
        x = self.model.features[4](x)
        x = self.model.features[5](x)
        x = self.model.features[6](x)
        enc8x = x

        # x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x)
        x = self.model.features[7](x)
        x = self.model.features[8](x)
        x = self.model.features[9](x)
        x = self.model.features[10](x)
        x = self.model.features[11](x)
        x = self.model.features[12](x)
        x = self.model.features[13](x)
        enc16x = x

        # x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x)
        x = self.model.features[14](x)
        x = self.model.features[15](x)
        x = self.model.features[16](x)
        x = self.model.features[17](x)
        x = self.model.features[18](x)
        enc32x = x
        
        return [enc2x, enc4x, enc8x, enc16x, enc32x]

    def load_pretrained_ckpt(self):
        # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch 
        ckpt_path = 'mobilenetv2_human_seg.ckpt' #./pretrained/
        if not os.path.exists(ckpt_path):
            print('cannot find the pretrained mobilenetv2 backbone')
            exit()
        
        ckpt = torch.load(ckpt_path)
        self.model.load_state_dict(ckpt)



In [9]:
class IBNorm(nn.Module):
    """ Combine Instance Norm and Batch Norm into One Layer
    """

    def __init__(self, in_channels):
        super(IBNorm, self).__init__()
        in_channels = in_channels
        self.bnorm_channels = int(in_channels / 2)
        self.inorm_channels = in_channels - self.bnorm_channels

        self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
        self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)
        
    def forward(self, x):
        bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
        in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())

        return torch.cat((bn_x, in_x), 1)

class Conv2dIBNormRelu(nn.Module):
    """ Convolution + IBNorm + ReLu
    """

    def __init__(self, in_channels, out_channels, kernel_size, 
                 stride=1, padding=0, dilation=1, groups=1, bias=True, 
                 with_ibn=True, with_relu=True):
        super(Conv2dIBNormRelu, self).__init__()

        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, 
                      stride=stride, padding=padding, dilation=dilation, 
                      groups=groups, bias=bias)
        ]

        if with_ibn:       
            layers.append(IBNorm(out_channels))
        if with_relu:
            layers.append(nn.ReLU(inplace=True))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x) 

class SEBlock(nn.Module):
    """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf 
    """

    def __init__(self, in_channels, out_channels, reduction=1):
        super(SEBlock, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, int(in_channels // reduction), bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels // reduction), out_channels, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        b, c, _, _ = x.size()
        w = self.pool(x).view(b, c)
        w = self.fc(w).view(b, c, 1, 1)

        return x * w.expand_as(x)

In [10]:
SUPPORTED_BACKBONES = {
    'mobilenetv2': MobileNetV2Backbone,
}

class LRBranch(nn.Module):
    """ Low Resolution Branch of MODNet
    """

    def __init__(self, backbone):
        super(LRBranch, self).__init__()

        enc_channels = backbone.enc_channels
        
        print("---------", enc_channels)
        
        self.backbone = backbone
        self.se_block = SEBlock(enc_channels[3], enc_channels[3], reduction=4) # enc_channels[4]
        self.conv_lr16x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2) # enc_channels[4],  enc_channels[3]
        self.conv_lr8x = Conv2dIBNormRelu(enc_channels[2], enc_channels[2], 5, stride=1, padding=2) # enc_channels[3]
        self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False)

    def forward(self, img, inference=False):
        enc_features = self.backbone.forward(img)
        enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[3] # enc_features[4]

        enc32x = self.se_block(enc32x)
        lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)
        lr16x = self.conv_lr16x(lr16x)
        lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)
        lr8x = self.conv_lr8x(lr8x)

        pred_semantic = None
        if not inference:
            lr = self.conv_lr(lr8x)
            pred_semantic = torch.sigmoid(lr)
        
        return pred_semantic, lr8x, [enc2x, enc4x] 
  
class HRBranch(nn.Module):
    """ High Resolution Branch of MODNet
    """

    def __init__(self, hr_channels, enc_channels):
        super(HRBranch, self).__init__()

        self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)
        self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)

        self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)
        self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)

        self.conv_hr4x = nn.Sequential(
            Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
        )

        self.conv_hr2x = nn.Sequential(
            Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
        )

        self.conv_hr = nn.Sequential(
            Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),
        )

    def forward(self, img, enc2x, enc4x, lr8x, inference):
        #img2x = F.interpolate(img, scale_factor=1, mode='bilinear', align_corners=False)
        #img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False)

        enc2x = self.tohr_enc2x.cuda()(enc2x)
        img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False)
        hr4x = self.conv_enc2x.cuda()(torch.cat((img2x, enc2x), dim=1))

        enc4x = self.tohr_enc4x.cuda()(enc4x)
        hr4x = self.conv_enc4x.cuda()(torch.cat((hr4x, enc4x), dim=1))

        lr4x = F.interpolate(lr8x, scale_factor=1, mode='bilinear', align_corners=False)
        img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False)
        
        print("---", hr4x.shape, lr4x.shape, img4x.shape)
        
        hr4x = self.conv_hr4x.cuda()(torch.cat((hr4x, lr4x, img4x), dim=1))

        hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)
        hr2x = self.conv_hr2x.cuda()(torch.cat((hr2x, enc2x), dim=1))

        pred_detail = None
        if not inference:
            hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False)
            hr = self.conv_hr.cuda()(torch.cat((hr, img), dim=1))
            pred_detail = torch.sigmoid(hr)

        return pred_detail, hr2x
    
class FusionBranch(nn.Module):
    """ Fusion Branch of MODNet
    """

    def __init__(self, hr_channels, enc_channels):
        super(FusionBranch, self).__init__()
        self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)
        
        self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
        self.conv_f = nn.Sequential(
            Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
            Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),
        )

    def forward(self, img, lr8x, hr2x):
        lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
        lr4x = self.conv_lr4x.cuda()(lr4x)
        lr2x = F.interpolate(lr4x, scale_factor=1, mode='bilinear', align_corners=False)

        f2x = self.conv_f2x.cuda()(torch.cat((lr2x, hr2x), dim=1))
        f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)
        f = self.conv_f.cuda()(torch.cat((f, img), dim=1))
        #pred_matte = torch.sigmoid(f)

        #return pred_matte
        return f

In [11]:
#------------------------------------------------------------------------------
#  MODNet
#------------------------------------------------------------------------------

class MODNet(nn.Module):
    """ Architecture of MODNet
    """

    def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True):
        super(MODNet, self).__init__()

        self.in_channels = in_channels
        self.hr_channels = hr_channels
        self.backbone_arch = backbone_arch 
        self.backbone_pretrained = backbone_pretrained

        self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels)

        self.lr_branch = LRBranch(self.backbone)
        self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)
        self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                self._init_conv(m)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
                self._init_norm(m)

        if self.backbone_pretrained:
            self.backbone.load_pretrained_ckpt()                

    def forward(self, img, inference=False):
        pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)
        pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference)
        pred_matte = self.f_branch(img, lr8x, hr2x)

        return pred_semantic, pred_detail, pred_matte
    
    def freeze_norm(self):
        norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]
        for m in self.modules():
            for n in norm_types:
                if isinstance(m, n):
                    m.eval()
                    continue

    def _init_conv(self, conv):
        nn.init.kaiming_uniform_(
            conv.weight, a=0, mode='fan_in', nonlinearity='relu')
        if conv.bias is not None:
            nn.init.constant_(conv.bias, 0)

    def _init_norm(self, norm):
        if norm.weight is not None:
            nn.init.constant_(norm.weight, 1)
            nn.init.constant_(norm.bias, 0)

In [12]:
from torchsummary import summary

model = MODNet()
summary(model, (3, 224, 224))

--------- [16, 24, 32, 96, 1280]
--- torch.Size([2, 64, 56, 56]) torch.Size([2, 32, 56, 56]) torch.Size([2, 3, 56, 56])
Layer (type:depth-idx)                        Output Shape              Param #
├─LRBranch: 1-1                               [-1, 1, 28, 28]           --
|    └─SEBlock: 2-1                           [-1, 96, 14, 14]          --
|    |    └─AdaptiveAvgPool2d: 3-1            [-1, 96, 1, 1]            --
|    |    └─Sequential: 3-2                   [-1, 96]                  4,608
|    └─Conv2dIBNormRelu: 2-2                  [-1, 32, 28, 28]          --
|    |    └─Sequential: 3-3                   [-1, 32, 28, 28]          76,864
|    └─Conv2dIBNormRelu: 2-3                  [-1, 32, 56, 56]          --
|    |    └─Sequential: 3-4                   [-1, 32, 56, 56]          25,664
|    └─Conv2dIBNormRelu: 2-4                  [-1, 1, 28, 28]           --
|    |    └─Sequential: 3-5                   [-1, 1, 28, 28]           289
├─HRBranch: 1-2                       

Layer (type:depth-idx)                        Output Shape              Param #
├─LRBranch: 1-1                               [-1, 1, 28, 28]           --
|    └─SEBlock: 2-1                           [-1, 96, 14, 14]          --
|    |    └─AdaptiveAvgPool2d: 3-1            [-1, 96, 1, 1]            --
|    |    └─Sequential: 3-2                   [-1, 96]                  4,608
|    └─Conv2dIBNormRelu: 2-2                  [-1, 32, 28, 28]          --
|    |    └─Sequential: 3-3                   [-1, 32, 28, 28]          76,864
|    └─Conv2dIBNormRelu: 2-3                  [-1, 32, 56, 56]          --
|    |    └─Sequential: 3-4                   [-1, 32, 56, 56]          25,664
|    └─Conv2dIBNormRelu: 2-4                  [-1, 1, 28, 28]           --
|    |    └─Sequential: 3-5                   [-1, 1, 28, 28]           289
├─HRBranch: 1-2                               [-1, 1, 224, 224]         --
|    └─Conv2dIBNormRelu: 2-5                  [-1, 32, 112, 112]        --
|    |  