In [None]:
from __future__ import print_function

In [None]:
import torch as tc
import torch.nn as nn
import torch.nn.functional as F


In [None]:
class EncoderStack(nn.Module):
    """One contracting-path stage of a U-Net.

    Two convolutions, each followed by ReLU, then a 2x2 max-pool with
    stride 2.

    Args:
        in_ch: number of input channels.
        out_ch: number of channels produced by both convolutions.
        kernel_size: kernel size of both convolutions.
        padding: padding of both convolutions (0 = "valid" convolution).
        stride: stride of both convolutions.
    """

    def __init__(self, in_ch, out_ch, kernel_size, padding=0, stride=1):
        super(EncoderStack, self).__init__()

        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size, stride, padding)
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, input):
        """Run the stage.

        Returns:
            A ``(pooled, conv_saved)`` pair: ``pooled`` is the max-pooled
            output fed to the next stage, and ``conv_saved`` is the
            pre-pooling activation, kept for the decoder's skip connection.
        """
        x = F.relu(self.conv1(input))
        x = F.relu(self.conv2(x))
        conv_saved = x
        x = self.maxPool(x)
        return x, conv_saved

In [None]:
class DecoderStack(nn.Module):
    """One expansive-path stage of a U-Net.

    Bilinearly upsamples the input to ``upsample_size`` and maps it from
    ``in_ch`` to ``out_ch`` channels with a 1x1 convolution. It then
    concatenates the centre-cropped skip tensor and applies two
    convolutions, each followed by ReLU.

    Args:
        in_ch: channels of the tensor passed as ``input`` to ``forward``.
        out_ch: channels produced by this stage. The skip tensor passed as
            ``bypass`` is expected to carry ``out_ch`` channels too (as in a
            symmetric U-Net), so ``conv1`` sees ``2 * out_ch`` channels.
        kernel_size: kernel size of the two convolutions.
        upsample_size: target spatial size of the upsampling, either an int
            or an ``(h, w)`` pair (anything ``nn.Upsample(size=...)`` takes).
        padding: padding of the two convolutions.
        stride: stride of the two convolutions.
    """

    def __init__(self, in_ch, out_ch, kernel_size, upsample_size, padding=0, stride=1):
        super(DecoderStack, self).__init__()

        # nn.Upsample accepts either `size` or `scale_factor`, not both; the
        # caller fixes the exact size so it lines up with the skip tensor.
        self.upsample = nn.Upsample(size=upsample_size, mode='bilinear',
                                    align_corners=False)
        # Without this projection the concat would have in_ch + out_ch
        # channels, which conv1 cannot consume.
        self.upconv = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        self.conv1 = nn.Conv2d(2 * out_ch, out_ch, kernel_size, stride, padding)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size, stride, padding)

    def _crop_concat(self, upsampled, bypass):
        """Centre-crop ``bypass`` to the (h, w) of ``upsampled`` and concat.

        Slicing handles odd size differences and non-square tensors, which
        a symmetric negative pad cannot.

        Returns:
            ``cat((upsampled, cropped_bypass), dim=1)``.
        """
        h, w = upsampled.shape[2], upsampled.shape[3]
        top = (bypass.shape[2] - h) // 2
        left = (bypass.shape[3] - w) // 2
        bypass = bypass[:, :, top:top + h, left:left + w]
        return tc.cat((upsampled, bypass), dim=1)

    def forward(self, input, bypass):
        """Upsample ``input``, merge the skip tensor ``bypass``, and convolve."""
        x = self.upconv(self.upsample(input))
        x = self._crop_concat(x, bypass)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x

In [None]:
class UNet(nn.Module):
    """U-Net with unpadded ("valid") 3x3 convolutions.

    The network has four encoder stages (64-128-256-512 channels), a
    1024-channel bottleneck, four decoder stages, and a final 1x1
    convolution.

    Args:
        in_shape: shape of one input sample without the batch dimension.
            It is presumably ``(channels, height, width)`` in PyTorch order
            (TODO confirm with callers); for square inputs the order of the
            last two entries does not matter. ``in_shape[1]`` is used for
            tensor dim 2 and ``in_shape[2]`` for tensor dim 3.
        n_classes: number of output channels of the final 1x1 convolution.

    Raises:
        ValueError: if the spatial input size is too small for four levels
            of valid convolutions and pooling.
    """

    def __init__(self, in_shape, n_classes=2):
        super(UNet, self).__init__()

        channels, height, width = in_shape
        # Per-level decoder upsample targets, derived from the input size
        # instead of hard-coded (572 -> 56, 104, 200, 392).
        up_sizes = list(zip(self._upsample_sizes(height),
                            self._upsample_sizes(width)))

        self.enc1 = EncoderStack(channels, 64, 3)
        self.enc2 = EncoderStack(64, 128, 3)
        self.enc3 = EncoderStack(128, 256, 3)
        self.enc4 = EncoderStack(256, 512, 3)

        # Bottleneck: two conv+ReLU layers, no pooling. EncoderStack cannot
        # be used here: it pools and returns a tuple, which nn.Sequential
        # would feed into the next layer as-is.
        self.center = nn.Sequential(
            nn.Conv2d(512, 1024, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(1024, 1024, kernel_size=3),
            nn.ReLU(),
        )

        self.dec1 = DecoderStack(1024, 512, 3, up_sizes[0])
        self.dec2 = DecoderStack(512, 256, 3, up_sizes[1])
        self.dec3 = DecoderStack(256, 128, 3, up_sizes[2])
        self.dec4 = DecoderStack(128, 64, 3, up_sizes[3])

        self.conv = nn.Conv2d(64, n_classes, kernel_size=1)

    @staticmethod
    def _upsample_sizes(size):
        """Return the four decoder upsample targets for one spatial dim.

        Each double 3x3 valid convolution shrinks a dimension by 4, each
        2x2 max-pool halves it (flooring), and each decoder stage doubles it.

        Raises:
            ValueError: if any intermediate size would drop below 1.
        """
        msg = "spatial input size %d is too small for this U-Net" % size
        for _ in range(4):  # contracting path: double conv, then pool
            size = (size - 4) // 2
            if size < 1:
                raise ValueError(msg)
        size -= 4  # bottleneck double conv
        if size < 1:
            raise ValueError(msg)
        targets = []
        for _ in range(4):  # expansive path: upsample x2, then double conv
            size *= 2
            targets.append(size)
            size -= 4
            if size < 1:
                raise ValueError(msg)
        return targets

    def forward(self, inp):
        """Map a batch ``inp`` to per-pixel scores with ``n_classes`` channels."""
        x, enc_saved1 = self.enc1(inp)
        x, enc_saved2 = self.enc2(x)
        x, enc_saved3 = self.enc3(x)
        x, enc_saved4 = self.enc4(x)
        x = self.center(x)
        x = self.dec1(x, enc_saved4)
        x = self.dec2(x, enc_saved3)
        x = self.dec3(x, enc_saved2)
        x = self.dec4(x, enc_saved1)
        x = self.conv(x)
        return x
    