# F-Res-UNet

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.padding import ReplicationPad2d

In [2]:
def conv3x3(in_planes, out_planes, stride=1, padding=1):
    """Build a 3x3 ``nn.Conv2d`` layer.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).
        padding: padding added to each spatial border (default 1, which keeps
            H and W unchanged when ``stride == 1``).

    Returns:
        A new ``nn.Conv2d`` module with ``kernel_size=3``.
    """
    layer_kwargs = dict(kernel_size=3, stride=stride, padding=padding)
    return nn.Conv2d(in_planes, out_planes, **layer_kwargs)

In [3]:
class BasicBlock_ss(nn.Module):
    """Residual block with optional max-pool sub-sampling.

    Main path: conv3x3 -> BN -> ReLU -> [max-pool] -> conv3x3 -> BN.
    Shortcut: identity, or a 1x1 conv -> BN when the channel count changes,
    followed by the same max-pool when sub-sampling. The two paths are summed
    and passed through a final ReLU.

    Args:
        in_planes: number of input channels.
        planes: number of output channels. Defaults to
            ``in_planes * sub_sample`` (so sub-sampling by 2 doubles channels).
        sub_sample: kernel size and stride of the max-pool applied after the
            first conv and to the shortcut; ``1`` disables pooling.
    """

    def __init__(self, in_planes, planes=None, sub_sample=1):
        super().__init__()
        if planes is None:
            planes = in_planes * sub_sample

        self.conv1 = conv3x3(in_planes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.sub_sample = sub_sample
        # A projection shortcut is only needed when the channel count changes.
        self.doit = planes != in_planes
        if self.doit:
            self.couple = nn.Conv2d(in_planes, planes, kernel_size=1)
            self.bnc = nn.BatchNorm2d(planes)

    def forward(self, x):
        """Apply the block.

        Args:
            x: input tensor; BatchNorm2d / max_pool2d expect (N, C, H, W).

        Returns:
            Tensor with ``planes`` channels and H, W divided (floored) by
            ``sub_sample``.
        """
        if self.doit:
            residual = self.bnc(self.couple(x))
        else:
            residual = x

        out = self.relu(self.bn1(self.conv1(x)))

        if self.sub_sample > 1:
            k = self.sub_sample
            # Pool both paths identically so their spatial sizes match.
            out = F.max_pool2d(out, kernel_size=k, stride=k)
            residual = F.max_pool2d(residual, kernel_size=k, stride=k)

        out = self.bn2(self.conv2(out))

        out += residual
        return self.relu(out)

In [4]:
class BasicBlock_us(nn.Module):
    """Residual up-sampling block built from transposed convolutions.

    Main path: 3x3 transposed conv (stride ``up_sample``) -> BN -> ReLU ->
    conv3x3 -> BN. Shortcut: a separate 3x3 transposed conv with the same
    stride -> BN. The two paths are summed and passed through a final ReLU.

    For an input of spatial size (H, W) the output is
    (H * up_sample, W * up_sample) with ``in_planes // up_sample`` channels.

    Args:
        in_planes: number of input channels.
        up_sample: integer up-sampling factor, used as the stride of both
            transposed convolutions; ``1`` keeps the spatial size.
    """

    def __init__(self, in_planes, up_sample=1):
        super().__init__()
        planes = in_planes // up_sample
        # With kernel 3 / padding 1 the transposed-conv output size is
        # (H - 1) * s + 1 + output_padding, so output_padding = s - 1 yields
        # exactly H * s. That is 1 for s = 2 and 0 for s = 1 (PyTorch requires
        # output_padding to be smaller than the stride or the dilation).
        out_pad = up_sample - 1
        self.conv1 = nn.ConvTranspose2d(
            in_planes, planes, kernel_size=3, padding=1,
            stride=up_sample, output_padding=out_pad,
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.up_sample = up_sample
        self.couple = nn.ConvTranspose2d(
            in_planes, planes, kernel_size=3, padding=1,
            stride=up_sample, output_padding=out_pad,
        )
        self.bnc = nn.BatchNorm2d(planes)

    def forward(self, x):
        """Up-sample ``x`` (expected shape (N, C, H, W)) by ``up_sample``."""
        residual = self.bnc(self.couple(x))

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += residual
        return self.relu(out)

In [6]:
class F_res_UNet(nn.Module):
    """F - residual - UNet segmentation network.

    ``forward(x1, x2)`` concatenates its two inputs along dim 1, so
    ``input_nbr`` must equal the sum of their channel counts.

    Encoder: four stages, each a ``BasicBlock_ss`` followed by a 2x
    sub-sampling ``BasicBlock_ss`` (channels 8 -> 16 -> 32 -> 64 -> 128).
    Decoder: mirrors the encoder with 2x ``BasicBlock_us`` up-sampling blocks,
    concatenating the matching encoder feature map (skip connection) at each
    stage. Output: a 1x1 conv to ``label_nbr`` channels followed by
    log-softmax over the channel dimension.

    Args:
        input_nbr: channel count of ``torch.cat((x1, x2), 1)``.
        label_nbr: number of output channels (classes).
    """

    def __init__(self, input_nbr, label_nbr):
        """Init F-res-UNet field."""
        super().__init__()

        self.input_nbr = input_nbr

        base_depth = 8
        cur_depth = input_nbr

        # Encoding stage 1: project to base_depth, then halve H/W and double C.
        self.encres1_1 = BasicBlock_ss(cur_depth, planes=base_depth)
        cur_depth = base_depth
        d1 = cur_depth  # channels of the stage-1 skip connection
        self.encres1_2 = BasicBlock_ss(cur_depth, sub_sample=2)
        cur_depth *= 2

        # Encoding stage 2
        self.encres2_1 = BasicBlock_ss(cur_depth)
        d2 = cur_depth
        self.encres2_2 = BasicBlock_ss(cur_depth, sub_sample=2)
        cur_depth *= 2

        # Encoding stage 3
        self.encres3_1 = BasicBlock_ss(cur_depth)
        d3 = cur_depth
        self.encres3_2 = BasicBlock_ss(cur_depth, sub_sample=2)
        cur_depth *= 2

        # Encoding stage 4
        self.encres4_1 = BasicBlock_ss(cur_depth)
        d4 = cur_depth
        self.encres4_2 = BasicBlock_ss(cur_depth, sub_sample=2)
        cur_depth *= 2

        # Decoding stage 4 (bottleneck): refine, then up-sample 2x and halve C.
        self.decres4_1 = BasicBlock_ss(cur_depth)
        self.decres4_2 = BasicBlock_us(cur_depth, up_sample=2)
        cur_depth //= 2

        # Decoding stage 3: fuse with the stage-4 skip (d4 channels).
        self.decres3_1 = BasicBlock_ss(cur_depth + d4, planes=cur_depth)
        self.decres3_2 = BasicBlock_us(cur_depth, up_sample=2)
        cur_depth //= 2

        # Decoding stage 2: fuse with the stage-3 skip (d3 channels).
        self.decres2_1 = BasicBlock_ss(cur_depth + d3, planes=cur_depth)
        self.decres2_2 = BasicBlock_us(cur_depth, up_sample=2)
        cur_depth //= 2

        # Decoding stage 1: fuse with the stage-2 skip (d2 channels).
        self.decres1_1 = BasicBlock_ss(cur_depth + d2, planes=cur_depth)
        self.decres1_2 = BasicBlock_us(cur_depth, up_sample=2)
        cur_depth //= 2

        # Output head: fuse with the stage-1 skip (d1 channels), map to classes.
        self.coupling = nn.Conv2d(cur_depth + d1, label_nbr, kernel_size=1)
        self.sm = nn.LogSoftmax(dim=1)

    @staticmethod
    def _pad_to_match(x, ref):
        """Replication-pad ``x`` on the bottom/right to ``ref``'s H and W.

        Max-pooling floors odd sizes, so an up-sampled decoder map can be
        smaller than the encoder skip tensor it is concatenated with.
        """
        pad_h = ref.size(2) - x.size(2)
        pad_w = ref.size(3) - x.size(3)
        if pad_h == 0 and pad_w == 0:
            return x
        # ReplicationPad2d padding order is (left, right, top, bottom).
        return ReplicationPad2d((0, pad_w, 0, pad_h))(x)

    def forward(self, x1, x2):
        """Return per-pixel log-probabilities for the input pair (x1, x2).

        Both inputs are concatenated on dim 1. Presumably they share N, H, W;
        this is needed for ``torch.cat``.
        The output has ``label_nbr`` channels and the same H, W as the inputs.
        """
        x = torch.cat((x1, x2), 1)

        # Encoder: keep each stage's pre-pooling features as skip connections.
        skip1 = self.encres1_1(x)
        x = self.encres1_2(skip1)

        skip2 = self.encres2_1(x)
        x = self.encres2_2(skip2)

        skip3 = self.encres3_1(x)
        x = self.encres3_2(skip3)

        skip4 = self.encres4_1(x)
        x = self.encres4_2(skip4)

        # Decoder: up-sample, pad to the skip's size, concatenate, refine.
        x = self.decres4_2(self.decres4_1(x))
        x = self._pad_to_match(x, skip4)

        x = self.decres3_1(torch.cat((x, skip4), 1))
        x = self.decres3_2(x)
        x = self._pad_to_match(x, skip3)

        x = self.decres2_1(torch.cat((x, skip3), 1))
        x = self.decres2_2(x)
        x = self._pad_to_match(x, skip2)

        x = self.decres1_1(torch.cat((x, skip2), 1))
        x = self.decres1_2(x)
        x = self._pad_to_match(x, skip1)

        x = self.coupling(torch.cat((x, skip1), 1))
        return self.sm(x)

In [1]:
"""
from torchvision import models
from torchsummary import summary

vgg = models.vgg16()
summary(vgg, (3, 224, 224))
"""

'\nfrom torchvision import models\nfrom torchsummary import summary\n\nvgg = models.vgg16()\nsummary(vgg, (3, 224, 224))\n'