<a href="https://colab.research.google.com/github/satyajitghana/ProjektDepth/blob/master/notebooks/06_DepthModel_ModelZoo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DepthModel - Model Zoo

In [0]:
! pip install funcy



In [0]:
import torch
import torch.nn as nn
import funcy
from torch.autograd import Variable

from collections import OrderedDict
import numpy as np


def summary(model, input_size, batch_size=-1, device="cuda"):

    def register_hook(module):

        def hook(module, input, output):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]["input_shape"] = list(input[0].size())
            summary[m_key]["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                summary[m_key]["output_shape"] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                summary[m_key]["output_shape"] = list(output.size())
                summary[m_key]["output_shape"][0] = batch_size

            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            summary[m_key]["nb_params"] = params

        if (
            not isinstance(module, nn.Sequential)
            and not isinstance(module, nn.ModuleList)
            and not (module == model)
        ):
            hooks.append(module.register_forward_hook(hook))

    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"

    if device == "cuda" and torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # multiple inputs to the network
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm
    x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
    # print(type(x[0]))

    # create properties
    summary = OrderedDict()
    hooks = []

    # register hook
    model.apply(register_hook)

    # make a forward pass
    # print(x.shape)
    model(*x)

    # remove these hooks
    for h in hooks:
        h.remove()

    print("----------------------------------------------------------------")
    line_new = "{:>20}  {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
    print(line_new)
    print("================================================================")
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_output = 0
    for layer in summary:
        # input_shape, output_shape, trainable, nb_params
        line_new = "{:>20}  {:>25} {:>15}".format(
            layer,
            str(summary[layer]["output_shape"]),
            "{0:,}".format(summary[layer]["nb_params"]),
        )
        total_output += np.prod(list(funcy.flatten(summary[layer]["output_shape"])))
        print(line_new)

    # assume 4 bytes/number (float on cuda).
    total_input_size = abs(sum([np.prod(input_item) for input_item in input_size]) * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    print("================================================================")
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("----------------------------------------------------------------")
    print("Input size (MB): %0.2f" % total_input_size)
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("Estimated Total Size (MB): %0.2f" % total_size)
    print("----------------------------------------------------------------")
    # return summary

In [0]:
! nvidia-smi

Mon May 18 09:19:33 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.82       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   40C    P8    27W / 149W |      0MiB / 11441MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

Here we make all the different models we could use with our dataset

In [0]:
import torch
import torch.nn as nn
from tqdm.auto import tqdm
import gc

In [0]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Custom Unet - ResNet Backbone

Note: All the Resnet Blocks use the ResNetV2 Architecure, i.e. preactivated blocks

In [0]:
class ResDoubleConv(nn.Module):
    '''Basic DoubleConv of a ResNetV2'''
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.double_conv = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        )

    def forward(self, x):
        return self.double_conv(x)

In [0]:
summary(ResDoubleConv(512, 1024).to(device), (512, 12, 12))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
       BatchNorm2d-1          [-1, 512, 12, 12]           1,024
              ReLU-2          [-1, 512, 12, 12]               0
            Conv2d-3         [-1, 1024, 12, 12]       4,718,592
       BatchNorm2d-4         [-1, 1024, 12, 12]           2,048
              ReLU-5         [-1, 1024, 12, 12]               0
            Conv2d-6         [-1, 1024, 12, 12]       9,437,184
Total params: 14,158,848
Trainable params: 14,158,848
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.28
Forward/backward pass size (MB): 5.62
Params size (MB): 54.01
Estimated Total Size (MB): 59.92
----------------------------------------------------------------


In [0]:
class ResDownBlock(nn.Module):
    '''Basic DownBlock of a ResNetV2'''
    def __init__(self, in_channels, out_channels):
        super().__init__()
        
        self.double_conv = ResDoubleConv(in_channels, out_channels)

        self.proj_layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

        self.down_sample = nn.MaxPool2d(2)

    def forward(self, input):
        identity = self.proj_layer(input)
        out = self.double_conv(input)
        out = out + identity

        del identity

        return self.down_sample(out), out

In [0]:
summary(ResDownBlock(256, 512).to(device), (256, 24, 24))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 512, 24, 24]         131,072
       BatchNorm2d-2          [-1, 512, 24, 24]           1,024
       BatchNorm2d-3          [-1, 256, 24, 24]             512
              ReLU-4          [-1, 256, 24, 24]               0
            Conv2d-5          [-1, 512, 24, 24]       1,179,648
       BatchNorm2d-6          [-1, 512, 24, 24]           1,024
              ReLU-7          [-1, 512, 24, 24]               0
            Conv2d-8          [-1, 512, 24, 24]       2,359,296
     ResDoubleConv-9          [-1, 512, 24, 24]               0
        MaxPool2d-10          [-1, 512, 12, 12]               0
Total params: 3,672,576
Trainable params: 3,672,576
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.56
Forward/backward pass size (MB): 18.56
Params size (MB): 14.01
Estim

In [0]:
class ResUpBlock(nn.Module):
    '''Basic UpBlock of a ResNetV2'''
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.upsample_1 = nn.PixelShuffle(2)
        self.upsample_2 = nn.PixelShuffle(2)
        self.upsample_3 = nn.PixelShuffle(2)
        self.upsample_4 = nn.PixelShuffle(2)

        self.upscale = nn.Upsample(scale_factor=2, mode='bilinear')

        self.double_conv = ResDoubleConv(in_channels, out_channels)

        self.proj_layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, down_input, skip_input, decoder_input=None):

        upsampled = [self.upsample_1(down_input), self.upsample_2(down_input), self.upsample_3(down_input), self.upsample_4(down_input)]
        x = torch.cat(upsampled, dim=1)
        x = torch.cat([x, skip_input], dim=1)

        if decoder_input is not None:
            x = torch.cat([x, decoder_input], dim=1)

        identity = self.proj_layer(x)

        out = self.double_conv(x) + identity

        del identity, upsampled, x

        return out

In [0]:
summary(ResUpBlock(512 + 256, 256).to(device), [(512, 24, 24), (256, 48, 48)])

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
      PixelShuffle-1          [-1, 128, 48, 48]               0
      PixelShuffle-2          [-1, 128, 48, 48]               0
      PixelShuffle-3          [-1, 128, 48, 48]               0
      PixelShuffle-4          [-1, 128, 48, 48]               0
            Conv2d-5          [-1, 256, 48, 48]         196,608
       BatchNorm2d-6          [-1, 256, 48, 48]             512
       BatchNorm2d-7          [-1, 768, 48, 48]           1,536
              ReLU-8          [-1, 768, 48, 48]               0
            Conv2d-9          [-1, 256, 48, 48]       1,769,472
      BatchNorm2d-10          [-1, 256, 48, 48]             512
             ReLU-11          [-1, 256, 48, 48]               0
           Conv2d-12          [-1, 256, 48, 48]         589,824
    ResDoubleConv-13          [-1, 256, 48, 48]               0
Total params: 2,558,464
Trainable param

In [0]:
class ResUNet(nn.Module):
    def __init__(self):
        super().__init__()

        # Encoder
        self.res_down1 = ResDownBlock(6, 64)     # H / 2   ; input = 192x192x6 ; output = 96x96x64   ; skip1 = 192x192x64
        self.res_down2 = ResDownBlock(64, 128)   # H / 4   ; input = 96x96x64  ; output = 48x48x128  ; skip2 = 96x96x128
        self.res_down3 = ResDownBlock(128, 256)  # H / 8   ; input = 48x48x128 ; output = 24x24x256  ; skip3 = 48x48x256
        self.res_down4 = ResDownBlock(256, 512)  # H / 16  ; input = 24x24x256 ; output = 12x12x512  ; skip4 = 24x24x512 

        # Bridge
        self.bridge =  ResDoubleConv(512, 512)

        # Depth Decoder
        self.d_res_up4 = ResUpBlock(512 + 512, 512)  # H / 8  ; input = 24x24x1024(upscaled)  24x24x512(skip4)  ; output = 24x24x512(dskip4)
        self.d_res_up3 = ResUpBlock(512 + 256, 256)   # H / 4  ; input = 48x48x512(upscaled)   48x48x256(skip3)  ; output = 48x48x256(dskip3)
        self.d_res_up2 = ResUpBlock(256 + 128, 128)   # H / 2  ; input = 96x96x256(upscaled)   96x96x128(skip2)  ; output = 96x96x128(dskip2)
        self.d_res_up1 = ResUpBlock(128 + 64, 64)     # H / 1  ; input = 192x192x128(upscaled) 192x192x64(skip1) ; output = 192x192x64(dskip1)

        # Depth Output
        self.depth_output = nn.Conv2d(64, 1, kernel_size=1, stride=1, bias=False)  # output = 192x192x1

        # Segmentation Decoder
        self.s_res_up4 = ResUpBlock(512 + 512 + 512, 512)  # H / 8  ; input = 24x24x1024(upscaled)  24x24x512(dskip4)  24x24x512(skip4)  ; output = 24x24x512
        self.s_res_up3 = ResUpBlock(512 + 256 + 256, 256)   # H / 4  ; input = 48x48x512(upscaled)   48x48x256(dskip3)  48x48x256(skip3)  ; output = 48x48x256
        self.s_res_up2 = ResUpBlock(256 + 128 + 128, 128)   # H / 2  ; input = 96x96x256(upscaled)   96x96x128(dskip2)  96x96x128(skip2)  ; output = 96x96x128
        self.s_res_up1 = ResUpBlock(128 + 64 + 64, 64)      # H / 1  ; input = 192x192x128(upscaled) 192x192x64(dskip1) 192x192x64(skip1) ; output = 192x192x64


        # Segmentation Output
        self.segment_output = nn.Conv2d(64, 1, kernel_size=1, stride=1, bias=False) # output = 192x192x1

    def forward(self, input):

        # Encoder
        rd1, skip1_out = self.res_down1(input)
        rd2, skip2_out = self.res_down2(rd1)
        rd3, skip3_out = self.res_down3(rd2)
        rd4, skip4_out = self.res_down4(rd3)

        # Bridge
        bridge = self.bridge(rd4)

        # Depth Decoder
        dru4 = self.d_res_up4(bridge, skip4_out)
        dru3 = self.d_res_up3(dru4, skip3_out)
        dru2 = self.d_res_up2(dru3, skip2_out)
        dru1 = self.d_res_up1(dru2, skip1_out)

        d_out = self.depth_output(dru1)

        # Segmentation Decoder
        sru4 = self.s_res_up4(bridge, skip4_out, dru4)
        sru3 = self.s_res_up3(sru4, skip3_out, dru3)
        sru2 = self.s_res_up2(sru3, skip2_out, dru2)
        sru1 = self.s_res_up1(sru2, skip1_out, dru1)


        s_out = self.segment_output(sru1)


        return d_out, s_out

In [0]:
model = ResUNet().to(device)

In [0]:
summary(model, (6, 192, 192))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 192, 192]             384
       BatchNorm2d-2         [-1, 64, 192, 192]             128
       BatchNorm2d-3          [-1, 6, 192, 192]              12
              ReLU-4          [-1, 6, 192, 192]               0
            Conv2d-5         [-1, 64, 192, 192]           3,456
       BatchNorm2d-6         [-1, 64, 192, 192]             128
              ReLU-7         [-1, 64, 192, 192]               0
            Conv2d-8         [-1, 64, 192, 192]          36,864
     ResDoubleConv-9         [-1, 64, 192, 192]               0
        MaxPool2d-10           [-1, 64, 96, 96]               0
     ResDownBlock-11  [[-1, 64, 96, 96], [-1, 64, 192, 192]]               0
           Conv2d-12          [-1, 128, 96, 96]           8,192
      BatchNorm2d-13          [-1, 128, 96, 96]             256
      BatchNorm2d-14      

In [0]:
summary(model, (6, 192, 192))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 192, 192]             384
       BatchNorm2d-2         [-1, 64, 192, 192]             128
       BatchNorm2d-3          [-1, 6, 192, 192]              12
              ReLU-4          [-1, 6, 192, 192]               0
            Conv2d-5         [-1, 64, 192, 192]           3,456
       BatchNorm2d-6         [-1, 64, 192, 192]             128
              ReLU-7         [-1, 64, 192, 192]               0
            Conv2d-8         [-1, 64, 192, 192]          36,864
     ResDoubleConv-9         [-1, 64, 192, 192]               0
        MaxPool2d-10           [-1, 64, 96, 96]               0
     ResDownBlock-11  [[-1, 64, 96, 96], [-1, 64, 192, 192]]               0
           Conv2d-12          [-1, 128, 96, 96]           8,192
      BatchNorm2d-13          [-1, 128, 96, 96]             256
      BatchNorm2d-14      

# Custom UNet - ResNeXt Backbone

The only change is the DoubleConv Backbone, everything else remains the same

In [0]:
class ResDoubleConv(nn.Module):
    '''Basic DoubleConv of a ResNeXt'''
    def __init__(self, in_channels, out_channels):
        super(ResDoubleConv, self).__init__()

        cardinality = 32

        widen_factor = 6
        base_width = 64

        width_ratio = out_channels / (widen_factor * 64.)
        D = cardinality * int(base_width * width_ratio)

        self.double_conv = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, D, kernel_size=1, bias=False),
            nn.BatchNorm2d(D),
            nn.ReLU(inplace=True),
            nn.Conv2d(D, D, kernel_size=3, padding=1, groups=cardinality, bias=False),
            nn.BatchNorm2d(D),
            nn.ReLU(inplace=True),
            nn.Conv2d(D, out_channels, kernel_size=1, bias=False)
        )

    def forward(self, x):
        out = self.double_conv(x)

        return out

In [0]:
summary(ResDoubleConv(512, 1024).to(device), (512, 12, 12))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
       BatchNorm2d-1          [-1, 512, 12, 12]           1,024
              ReLU-2          [-1, 512, 12, 12]               0
            Conv2d-3         [-1, 5440, 12, 12]       2,785,280
       BatchNorm2d-4         [-1, 5440, 12, 12]          10,880
              ReLU-5         [-1, 5440, 12, 12]               0
            Conv2d-6         [-1, 5440, 12, 12]       8,323,200
       BatchNorm2d-7         [-1, 5440, 12, 12]          10,880
              ReLU-8         [-1, 5440, 12, 12]               0
            Conv2d-9         [-1, 1024, 12, 12]       5,570,560
Total params: 16,701,824
Trainable params: 16,701,824
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.28
Forward/backward pass size (MB): 38.11
Params size (MB): 63.71
Estimated Total Size (MB): 102.10
---------------------------------

In [0]:
model = ResUNet().to(device)

In [0]:
summary(model, (6, 192, 192))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 192, 192]             384
       BatchNorm2d-2         [-1, 64, 192, 192]             128
       BatchNorm2d-3          [-1, 6, 192, 192]              12
              ReLU-4          [-1, 6, 192, 192]               0
            Conv2d-5        [-1, 320, 192, 192]           1,920
       BatchNorm2d-6        [-1, 320, 192, 192]             640
              ReLU-7        [-1, 320, 192, 192]               0
            Conv2d-8        [-1, 320, 192, 192]          28,800
       BatchNorm2d-9        [-1, 320, 192, 192]             640
             ReLU-10        [-1, 320, 192, 192]               0
           Conv2d-11         [-1, 64, 192, 192]          20,480
    ResDoubleConv-12         [-1, 64, 192, 192]               0
        MaxPool2d-13           [-1, 64, 96, 96]               0
     ResDownBlock-14  [[-1, 64, 96, 96]

In [0]:
model = ResUNet()
model.to(device)
for i in tqdm(range(6250)):
    input = torch.randn(64, 6, 96, 96)
    input = input.to(device)
    model(input)

    del input
    torch.cuda.empty_cache()

HBox(children=(FloatProgress(value=0.0, max=6250.0), HTML(value='')))

KeyboardInterrupt: ignored

In [0]:
gc.collect()
torch.cuda.empty_cache()

In [0]:
torch.cuda.memory_allocated()

0

In [0]:
import torch
import torch.nn as nn
from torchvision import models

def convrelu(in_channels, out_channels, kernel, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.ReLU(inplace=True),
    )


class ResNetUNet(nn.Module):
    def __init__(self, n_class):
        super().__init__()

        self.base_model = models.resnet18(pretrained=True)
        self.base_layers = list(self.base_model.children())

        self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        self.conv_original_size0 = convrelu(3, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, input):
        x_original = self.conv_original_size0(input)
        x_original = self.conv_original_size1(x_original)

        layer0 = self.layer0(input)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)

        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)

        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)

        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        out = self.conv_last(x)

        return out

In [0]:
summary(ResNetUNet(10).to(device), (3, 192, 192))

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/checkpoints/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 192, 192]           1,792
              ReLU-2         [-1, 64, 192, 192]               0
            Conv2d-3         [-1, 64, 192, 192]          36,928
              ReLU-4         [-1, 64, 192, 192]               0
            Conv2d-5           [-1, 64, 96, 96]           9,408
            Conv2d-6           [-1, 64, 96, 96]           9,408
       BatchNorm2d-7           [-1, 64, 96, 96]             128
       BatchNorm2d-8           [-1, 64, 96, 96]             128
              ReLU-9           [-1, 64, 96, 96]               0
             ReLU-10           [-1, 64, 96, 96]               0
        MaxPool2d-11           [-1, 64, 48, 48]               0
        MaxPool2d-12           [-1, 64, 48, 48]               0
           Conv2d-13           [-1, 64, 48, 48]          36,864
           Conv2d-14           [-1, 64

In [0]:
from __future__ import absolute_import, division, print_function
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import importlib


class conv(nn.Module):
    def __init__(self, num_in_layers, num_out_layers, kernel_size, stride):
        super(conv, self).__init__()
        self.kernel_size = kernel_size
        self.conv_base = nn.Conv2d(num_in_layers, num_out_layers, kernel_size=kernel_size, stride=stride)
        self.normalize = nn.BatchNorm2d(num_out_layers)

    def forward(self, x):
        p = int(np.floor((self.kernel_size-1)/2))
        p2d = (p, p, p, p)
        x = self.conv_base(F.pad(x, p2d))
        x = self.normalize(x)
        return F.elu(x, inplace=True)


class convblock(nn.Module):
    def __init__(self, num_in_layers, num_out_layers, kernel_size):
        super(convblock, self).__init__()
        self.conv1 = conv(num_in_layers, num_out_layers, kernel_size, 1)
        self.conv2 = conv(num_out_layers, num_out_layers, kernel_size, 2)

    def forward(self, x):
        x = self.conv1(x)
        return self.conv2(x)


class maxpool(nn.Module):
    def __init__(self, kernel_size):
        super(maxpool, self).__init__()
        self.kernel_size = kernel_size

    def forward(self, x):
        p = int(np.floor((self.kernel_size-1) / 2))
        p2d = (p, p, p, p)
        return F.max_pool2d(F.pad(x, p2d), self.kernel_size, stride=2)


class resconv(nn.Module):
    def __init__(self, num_in_layers, num_out_layers, stride):
        super(resconv, self).__init__()
        self.num_out_layers = num_out_layers
        self.stride = stride
        self.conv1 = conv(num_in_layers, num_out_layers, 1, 1)
        self.conv2 = conv(num_out_layers, num_out_layers, 3, stride)
        self.conv3 = nn.Conv2d(num_out_layers, 4*num_out_layers, kernel_size=1, stride=1)
        self.conv4 = nn.Conv2d(num_in_layers, 4*num_out_layers, kernel_size=1, stride=stride)
        self.normalize = nn.BatchNorm2d(4*num_out_layers)

    def forward(self, x):
        # do_proj = x.size()[1] != self.num_out_layers or self.stride == 2
        do_proj = True
        shortcut = []
        x_out = self.conv1(x)
        x_out = self.conv2(x_out)
        x_out = self.conv3(x_out)
        if do_proj:
            shortcut = self.conv4(x)
        else:
            shortcut = x
        return F.elu(self.normalize(x_out + shortcut), inplace=True)


class resconv_basic(nn.Module):
    # for resnet18
    def __init__(self, num_in_layers, num_out_layers, stride):
        super(resconv_basic, self).__init__()
        self.num_out_layers = num_out_layers
        self.stride = stride
        self.conv1 = conv(num_in_layers, num_out_layers, 3, stride)
        self.conv2 = conv(num_out_layers, num_out_layers, 3, 1)
        self.conv3 = nn.Conv2d(num_in_layers, num_out_layers, kernel_size=1, stride=stride)
        self.normalize = nn.BatchNorm2d(num_out_layers)

    def forward(self, x):
        #         do_proj = x.size()[1] != self.num_out_layers or self.stride == 2
        do_proj = True
        shortcut = []
        x_out = self.conv1(x)
        x_out = self.conv2(x_out)
        if do_proj:
            shortcut = self.conv3(x)
        else:
            shortcut = x
        return F.elu(self.normalize(x_out + shortcut), inplace=True)


def resblock(num_in_layers, num_out_layers, num_blocks, stride):
    layers = []
    layers.append(resconv(num_in_layers, num_out_layers, stride))
    for i in range(1, num_blocks - 1):
        layers.append(resconv(4 * num_out_layers, num_out_layers, 1))
    layers.append(resconv(4 * num_out_layers, num_out_layers, 1))
    return nn.Sequential(*layers)


def resblock_basic(num_in_layers, num_out_layers, num_blocks, stride):
    layers = []
    layers.append(resconv_basic(num_in_layers, num_out_layers, stride))
    for i in range(1, num_blocks):
        layers.append(resconv_basic(num_out_layers, num_out_layers, 1))
    return nn.Sequential(*layers)


class upconv(nn.Module):
    def __init__(self, num_in_layers, num_out_layers, kernel_size, scale):
        super(upconv, self).__init__()
        self.scale = scale
        self.conv1 = conv(num_in_layers, num_out_layers, kernel_size, 1)

    def forward(self, x):
        x = nn.functional.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=True)
        return self.conv1(x)


class get_disp(nn.Module):
    def __init__(self, num_in_layers):
        super(get_disp, self).__init__()
        self.conv1 = nn.Conv2d(num_in_layers, 2, kernel_size=3, stride=1)
        self.normalize = nn.BatchNorm2d(2)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        p = 1
        p2d = (p, p, p, p)
        x = self.conv1(F.pad(x, p2d))
        x = self.normalize(x)
        return 0.3 * self.sigmoid(x)


class Resnet50_md(nn.Module):
    def __init__(self, num_in_layers):
        super(Resnet50_md, self).__init__()
        # encoder
        self.conv1 = conv(num_in_layers, 64, 7, 2)  # H/2  -   64D
        self.pool1 = maxpool(3)  # H/4  -   64D
        self.conv2 = resblock(64, 64, 3, 2)  # H/8  -  256D
        self.conv3 = resblock(256, 128, 4, 2)  # H/16 -  512D
        self.conv4 = resblock(512, 256, 6, 2)  # H/32 - 1024D
        self.conv5 = resblock(1024, 512, 3, 2)  # H/64 - 2048D

        # decoder
        self.upconv6 = upconv(2048, 512, 3, 2)
        self.iconv6 = conv(1024 + 512, 512, 3, 1)

        self.upconv5 = upconv(512, 256, 3, 2)
        self.iconv5 = conv(512+256, 256, 3, 1)

        self.upconv4 = upconv(256, 128, 3, 2)
        self.iconv4 = conv(256+128, 128, 3, 1)
        self.disp4_layer = get_disp(128)

        self.upconv3 = upconv(128, 64, 3, 2)
        self.iconv3 = conv(64+64+2, 64, 3, 1)
        self.disp3_layer = get_disp(64)

        self.upconv2 = upconv(64, 32, 3, 2)
        self.iconv2 = conv(32+64+2, 32, 3, 1)
        self.disp2_layer = get_disp(32)

        self.upconv1 = upconv(32, 16, 3, 2)
        self.iconv1 = conv(16+2, 16, 3, 1)
        self.disp1_layer = get_disp(16)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        # encoder
        x1 = self.conv1(x)
        x_pool1 = self.pool1(x1)
        x2 = self.conv2(x_pool1)
        x3 = self.conv3(x2)
        x4 = self.conv4(x3)
        x5 = self.conv5(x4)

        # skips
        skip1 = x1
        skip2 = x_pool1
        skip3 = x2
        skip4 = x3
        skip5 = x4

        # decoder
        upconv6 = self.upconv6(x5)
        concat6 = torch.cat((upconv6, skip5), 1)
        iconv6 = self.iconv6(concat6)

        upconv5 = self.upconv5(iconv6)
        concat5 = torch.cat((upconv5, skip4), 1)
        iconv5 = self.iconv5(concat5)

        upconv4 = self.upconv4(iconv5)
        concat4 = torch.cat((upconv4, skip3), 1)
        iconv4 = self.iconv4(concat4)
        self.disp4 = self.disp4_layer(iconv4)
        self.udisp4 = nn.functional.interpolate(self.disp4, scale_factor=2, mode='bilinear', align_corners=True)

        upconv3 = self.upconv3(iconv4)
        concat3 = torch.cat((upconv3, skip2, self.udisp4), 1)
        iconv3 = self.iconv3(concat3)
        self.disp3 = self.disp3_layer(iconv3)
        self.udisp3 = nn.functional.interpolate(self.disp3, scale_factor=2, mode='bilinear', align_corners=True)

        upconv2 = self.upconv2(iconv3)
        concat2 = torch.cat((upconv2, skip1, self.udisp3), 1)
        iconv2 = self.iconv2(concat2)
        self.disp2 = self.disp2_layer(iconv2)
        self.udisp2 = nn.functional.interpolate(self.disp2, scale_factor=2, mode='bilinear', align_corners=True)

        upconv1 = self.upconv1(iconv2)
        concat1 = torch.cat((upconv1, self.udisp2), 1)
        iconv1 = self.iconv1(concat1)
        self.disp1 = self.disp1_layer(iconv1)
        return self.disp1, self.disp2, self.disp3, self.disp4

In [0]:
summary(Resnet50_md(3).to(device), (3, 192, 192))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 96, 96]           9,472
       BatchNorm2d-2           [-1, 64, 96, 96]             128
              conv-3           [-1, 64, 96, 96]               0
           maxpool-4           [-1, 64, 48, 48]               0
            Conv2d-5           [-1, 64, 48, 48]           4,160
       BatchNorm2d-6           [-1, 64, 48, 48]             128
              conv-7           [-1, 64, 48, 48]               0
            Conv2d-8           [-1, 64, 24, 24]          36,928
       BatchNorm2d-9           [-1, 64, 24, 24]             128
             conv-10           [-1, 64, 24, 24]               0
           Conv2d-11          [-1, 256, 24, 24]          16,640
           Conv2d-12          [-1, 256, 24, 24]          16,640
      BatchNorm2d-13          [-1, 256, 24, 24]             512
          resconv-14          [-1, 256,

In [0]:
model = Resnet50_md(3)
model = model.to(device)

In [0]:
model = Resnet50_md(3)
model = model.to(device)

model.train()
for i in tqdm(range(1000)):
    input = torch.randn(128, 3, 192, 192)
    input = input.to(device)

    model(input)

    del input

    torch.cuda.empty_cache()

In [0]:
torch.cuda.memory_allocated()

0

In [0]:
del input

NameError: ignored

In [0]:
gc.collect()
torch.cuda.empty_cache()

In [0]:
import torch
import torch.nn as nn


class ResDoubleConv(nn.Module):
    '''Basic DoubleConv of a ResNetV2'''

    def __init__(self, in_channels, out_channels):
        super(ResDoubleConv, self).__init__()

        self.double_conv = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels,
                      kernel_size=3, padding=1, bias=False)
        )

    def forward(self, x):
        out = self.double_conv(x)

        return out


class ResDownBlock(nn.Module):
    '''Basic DownBlock of a ResNetV2'''

    def __init__(self, in_channels, out_channels):
        super(ResDownBlock, self).__init__()

        self.double_conv = ResDoubleConv(in_channels, out_channels)

        self.proj_layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

        self.down_sample = nn.MaxPool2d(2)

    def forward(self, input):
        identity = self.proj_layer(input)
        out = self.double_conv(input)
        out = out + identity

        return self.down_sample(out), out


class ResUpBlock(nn.Module):
    '''Basic UpBlock of a ResNetV2'''

    def __init__(self, in_channels, out_channels):
        super(ResUpBlock, self).__init__()

        self.upsample_1 = nn.PixelShuffle(2)
        self.upsample_2 = nn.PixelShuffle(2)
        self.upsample_3 = nn.PixelShuffle(2)
        self.upsample_4 = nn.PixelShuffle(2)

        self.double_conv = ResDoubleConv(in_channels, out_channels)

        self.proj_layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, down_input, skip_input, decoder_input=None):

        upsampled = [self.upsample_1(down_input), self.upsample_2(
            down_input), self.upsample_3(down_input), self.upsample_4(down_input)]
        x = torch.cat(upsampled, dim=1)
        x = torch.cat([x, skip_input], dim=1)

        if decoder_input is not None:
            x = torch.cat([x, decoder_input], dim=1)

        identity = self.proj_layer(x)

        out = self.double_conv(x) + identity

        return out


class ResUNet(nn.Module):
    def __init__(self):
        super(ResUNet, self).__init__()

        # Encoder
        # H / 2   ; input = 192x192x6 ; output = 96x96x64   ; skip1 = 192x192x64
        self.res_down1 = ResDownBlock(6, 64)
        # H / 4   ; input = 96x96x64  ; output = 48x48x128  ; skip2 = 96x96x128
        self.res_down2 = ResDownBlock(64, 128)
        # H / 8   ; input = 48x48x128 ; output = 24x24x256  ; skip3 = 48x48x256
        self.res_down3 = ResDownBlock(128, 256)
        # H / 16  ; input = 24x24x256 ; output = 12x12x512  ; skip4 = 24x24x512
        self.res_down4 = ResDownBlock(256, 512)

        # Bridge
        self.bridge = ResDoubleConv(512, 512)

        # Depth Decoder
        # H / 8  ; input = 24x24x1024(upscaled)  24x24x512(skip4)  ; output = 24x24x512(dskip4)
        self.d_res_up4 = ResUpBlock(512 + 512, 512)
        # H / 4  ; input = 48x48x512(upscaled)   48x48x256(skip3)  ; output = 48x48x256(dskip3)
        self.d_res_up3 = ResUpBlock(512 + 256, 256)
        # H / 2  ; input = 96x96x256(upscaled)   96x96x128(skip2)  ; output = 96x96x128(dskip2)
        self.d_res_up2 = ResUpBlock(256 + 128, 128)
        # H / 1  ; input = 192x192x128(upscaled) 192x192x64(skip1) ; output = 192x192x64(dskip1)
        self.d_res_up1 = ResUpBlock(128 + 64, 64)

        # Depth Output
        self.depth_output = nn.Conv2d(
            64, 1, kernel_size=1, stride=1, bias=False)  # output = 192x192x1

        # Segmentation Decoder
        # H / 8  ; input = 24x24x1024(upscaled)  24x24x512(dskip4)  24x24x512(skip4)  ; output = 24x24x512
        self.s_res_up4 = ResUpBlock(512 + 512 + 512, 512)
        # H / 4  ; input = 48x48x512(upscaled)   48x48x256(dskip3)  48x48x256(skip3)  ; output = 48x48x256
        self.s_res_up3 = ResUpBlock(512 + 256 + 256, 256)
        # H / 2  ; input = 96x96x256(upscaled)   96x96x128(dskip2)  96x96x128(skip2)  ; output = 96x96x128
        self.s_res_up2 = ResUpBlock(256 + 128 + 128, 128)
        # H / 1  ; input = 192x192x128(upscaled) 192x192x64(dskip1) 192x192x64(skip1) ; output = 192x192x64
        self.s_res_up1 = ResUpBlock(128 + 64 + 64, 64)

        # Segmentation Output
        self.segment_output = nn.Conv2d(
            64, 1, kernel_size=1, stride=1, bias=False)  # output = 192x192x1

    def forward(self, input):

        # Encoder
        rd1, skip1_out = self.res_down1(input)
        rd2, skip2_out = self.res_down2(rd1)
        rd3, skip3_out = self.res_down3(rd2)
        rd4, skip4_out = self.res_down4(rd3)

        # Bridge
        bridge = self.bridge(rd4)

        # Depth Decoder
        dru4 = self.d_res_up4(bridge, skip4_out)
        dru3 = self.d_res_up3(dru4, skip3_out)
        dru2 = self.d_res_up2(dru3, skip2_out)
        dru1 = self.d_res_up1(dru2, skip1_out)

        d_out = self.depth_output(dru1)

        # Segmentation Decoder
        sru4 = self.s_res_up4(bridge, skip4_out, dru4)
        sru3 = self.s_res_up3(sru4, skip3_out, dru3)
        sru2 = self.s_res_up2(sru3, skip2_out, dru2)
        sru1 = self.s_res_up1(sru2, skip1_out, dru1)

        s_out = self.segment_output(sru1)
        return d_out, s_out

In [0]:
model = ResUNet().to(device)

In [0]:
summary(model, (6, 192, 192))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 192, 192]             384
       BatchNorm2d-2         [-1, 64, 192, 192]             128
       BatchNorm2d-3          [-1, 6, 192, 192]              12
              ReLU-4          [-1, 6, 192, 192]               0
            Conv2d-5         [-1, 64, 192, 192]           3,456
       BatchNorm2d-6         [-1, 64, 192, 192]             128
              ReLU-7         [-1, 64, 192, 192]               0
            Conv2d-8         [-1, 64, 192, 192]          36,864
     ResDoubleConv-9         [-1, 64, 192, 192]               0
        MaxPool2d-10           [-1, 64, 96, 96]               0
     ResDownBlock-11  [[-1, 64, 96, 96], [-1, 64, 192, 192]]               0
           Conv2d-12          [-1, 128, 96, 96]           8,192
      BatchNorm2d-13          [-1, 128, 96, 96]             256
      BatchNorm2d-14      

In [0]:
print(sum(p.numel() for p in model.parameters()))

34997388


In [0]:
34997388 * 4 / 1024/1024

133.5044403076172