## UNet/FCN Based on ResNet

This repository contains simple PyTorch implementations of U-Net and FCN, deep learning segmentation methods proposed by Ronneberger et al. and Long et al., respectively.

In [11]:
import torch
import torch.nn as nn
from torchvision import models

In [12]:
# Instantiate a ResNet-18 with pretrained weights; when the checkpoint is not
# cached locally it is downloaded first (see the "Downloading" output of this cell).
base_model = models.resnet18(pretrained=True)

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /Users/ZRC/.cache/torch/checkpoints/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))




In [26]:
# Dump every direct child of the backbone with its position, followed by a
# horizontal rule, to see which indices can be sliced out as encoder stages.
for position, (child_name, child_module) in enumerate(base_model.named_children()):
    print(position, child_name, " -> ", child_module)
    print("--" * 40)

0 conv1  ->  Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
--------------------------------------------------------------------------------
1 bn1  ->  BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
--------------------------------------------------------------------------------
2 relu  ->  ReLU(inplace=True)
--------------------------------------------------------------------------------
3 maxpool  ->  MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
--------------------------------------------------------------------------------
4 layer1  ->  Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64

## Reference

- https://github.com/usuyama/pytorch-unet

In [7]:
# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [22]:
# Number of direct child modules of the backbone (valid indices into base_layers).
len(tuple(base_model.children()))

10

In [30]:
import torch
import torch.nn as nn
from torchvision import models

def convrelu(in_channels, out_channels, kernel, padding):
    """Build a two-layer block: a 2-D convolution followed by an in-place ReLU.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        kernel: kernel size forwarded to ``nn.Conv2d``.
        padding: padding forwarded to ``nn.Conv2d``.

    Returns:
        An ``nn.Sequential`` holding the convolution and the ReLU, in that order.
    """
    convolution = nn.Conv2d(in_channels, out_channels, kernel, padding=padding)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(convolution, activation)


class ResNetUNetZrc(nn.Module):
    """U-Net style segmentation network with a torchvision ResNet-18 encoder.

    The first eight children of the ResNet-18 are reused as five encoder
    stages (``layer0`` .. ``layer4``). Each stage output passes through a
    1x1 ``convrelu`` before being concatenated with the upsampled decoder
    feature map. A separate full-resolution branch (``conv_original_size*``)
    feeds the last decoder step, and ``conv_last`` maps to ``n_class`` channels.

    Note: ``base_model`` stays registered as a submodule, so the backbone
    children with index 8 and above remain part of the module although
    ``forward`` never calls them.

    Args:
        n_class: number of output channels (one per segmentation class).
        pretrained: forwarded to ``models.resnet18``. Defaults to ``True``
            (the previous hard-coded behaviour); pass ``False`` to build the
            network without loading the pretrained checkpoint.
    """

    def __init__(self, n_class, pretrained=True):
        super().__init__()

        self.base_model = models.resnet18(pretrained=pretrained)
        self.base_layers = list(self.base_model.children())

        # Encoder stages sliced out of the backbone, each paired with a 1x1
        # conv that processes the skip connection.
        self.layer0 = nn.Sequential(*self.base_layers[:3])  # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5])  # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        # One shared, parameter-free 2x bilinear upsampler for every decoder step.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder convs; input channels = skip channels + upsampled channels.
        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        # Full-resolution branch applied directly to the input image.
        self.conv_original_size0 = convrelu(3, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, input):
        """Compute per-pixel class scores for a batch of images.

        Args:
            input: tensor with 3 channels, i.e. shape (N, 3, H, W). H and W
                are presumably required to be multiples of 32 so that every
                2x-upsampled map matches its skip connection in ``torch.cat``
                (the deepest stage is annotated as 1/32 resolution) — confirm
                before feeding other sizes.

        Returns:
            Tensor with ``n_class`` channels at the input's spatial resolution.
        """
        # Full-resolution features, kept for the final skip connection.
        x_original = self.conv_original_size0(input)
        x_original = self.conv_original_size1(x_original)  # 64

        # Encoder path.
        layer0 = self.layer0(input)    # 64  /2
        layer1 = self.layer1(layer0)   # 64  /4
        layer2 = self.layer2(layer1)   # 128 /8
        layer3 = self.layer3(layer2)   # 256 /16
        layer4 = self.layer4(layer3)   # 512 /32

        # Decoder step 3: /32 -> /16.
        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)          # 512 /16
        layer3 = self.layer3_1x1(layer3)   # 256 /16
        x = torch.cat([x, layer3], dim=1)  # 512 + 256 /16
        x = self.conv_up3(x)               # 512 /16

        # Decoder step 2: /16 -> /8.
        x = self.upsample(x)               # 512 /8
        layer2 = self.layer2_1x1(layer2)   # 128 /8
        x = torch.cat([x, layer2], dim=1)  # 512 + 128 /8
        x = self.conv_up2(x)               # 256 /8

        # Decoder step 1: /8 -> /4.
        x = self.upsample(x)               # 256 /4
        layer1 = self.layer1_1x1(layer1)   # 64 /4
        x = torch.cat([x, layer1], dim=1)  # 256 + 64 /4
        x = self.conv_up1(x)               # 256 /4

        # Decoder step 0: /4 -> /2.
        x = self.upsample(x)               # 256 /2
        layer0 = self.layer0_1x1(layer0)   # 64 /2
        x = torch.cat([x, layer0], dim=1)  # 256 + 64 /2
        x = self.conv_up0(x)               # 128 /2

        # Back to full resolution and merge with the full-resolution branch.
        x = self.upsample(x)                   # 128
        x = torch.cat([x, x_original], dim=1)  # 128 + 64
        x = self.conv_original_size2(x)        # 64

        out = self.conv_last(x)

        return out

In [31]:
from torchsummary import summary

# Build a 10-class network and print its per-layer output shapes and parameter
# counts for a single 3-channel 224x224 input.
model = ResNetUNetZrc(n_class=10)
summary(model, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
            Conv2d-5         [-1, 64, 112, 112]           9,408
            Conv2d-6         [-1, 64, 112, 112]           9,408
       BatchNorm2d-7         [-1, 64, 112, 112]             128
       BatchNorm2d-8         [-1, 64, 112, 112]             128
              ReLU-9         [-1, 64, 112, 112]               0
             ReLU-10         [-1, 64, 112, 112]               0
        MaxPool2d-11           [-1, 64, 56, 56]               0
        MaxPool2d-12           [-1, 64, 56, 56]               0
           Conv2d-13           [-1, 64, 56, 56]          36,864
           Conv2d-14           [-1, 64,