In [1]:
import math
import torch 
from torch import nn
import torch.nn.functional as F

In [18]:
class Encoder(nn.Module):
    """Five-stage residual encoder producing a bottleneck map and a shallow skip map.

    Each stage is: entry convolution -> residual block (conv/BN/conv/BN/ReLU,
    added back onto the entry conv's output) -> 2x2 max-pool. Channel widths go
    in_channels -> 32 -> 64 -> 128 -> 256 -> 512, followed by a final 3x3 conv
    (``conv_last``) down to 32 channels.

    A 2-channel skip map is tapped from stage 2 *before* its max-pool, via a
    1x1 conv (``conv_skip``).

    NOTE(review): ``Conv2dSamePadding`` is defined outside this cell; the shape
    notes below assume it is called as (in_channels, out_channels, kernel_size)
    and preserves spatial size -- TODO confirm.

    Args:
        in_channels: number of channels of the input image (default 3).

    ``forward`` returns ``(bottleneck, skip)``:
        bottleneck: 32 channels, spatial size divided by 32 (five 2x2 pools,
            floor division at each pool).
        skip: 2 channels, spatial size divided by 2.
    """

    def __init__(self, in_channels=3):
        super(Encoder, self).__init__()

        self.conv = Conv2dSamePadding(in_channels, 32, 3)

        self.conv1 = Conv2dSamePadding(32, 64, 3)
        self.conv2 = Conv2dSamePadding(64, 128, 3)
        self.conv3 = Conv2dSamePadding(128, 256, 3)
        self.conv4 = Conv2dSamePadding(256, 512, 3)

        self.conv_last = Conv2dSamePadding(512, 32, 3)

        self.conv_skip = Conv2dSamePadding(64, 2, 1)

        self.max_pool = nn.MaxPool2d(2, 2)

        # Kept as named nn.Sequential attributes so parameter names stay
        # block{N}.{index}.* (e.g. block1.1.running_mean).
        self.block1 = self._residual_block(32)
        self.block2 = self._residual_block(64)
        self.block3 = self._residual_block(128)
        self.block4 = self._residual_block(256)
        self.block5 = self._residual_block(512)

    @staticmethod
    def _residual_block(channels):
        """Build the conv -> BN -> conv -> BN -> ReLU body of one residual stage."""
        # NOTE(review): no nonlinearity between the first BatchNorm and the
        # second conv; kept to preserve the model, but confirm it is intended.
        return nn.Sequential(
            Conv2dSamePadding(channels, channels, 3),
            nn.BatchNorm2d(channels),
            Conv2dSamePadding(channels, channels, 3),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )

    @staticmethod
    def _residual_stage(x, entry_conv, block):
        """Apply ``entry_conv``, then add ``block`` of that result back onto it."""
        x = entry_conv(x)
        # Must be out-of-place. ``block`` ends in ReLU, whose backward reuses
        # its own output; ``out += x`` would modify that saved tensor and make
        # loss.backward() raise an in-place-modification RuntimeError.
        return block(x) + x

    def forward(self, x):
        """Encode ``x`` into ``(bottleneck, skip)``; see the class docstring."""
        x = self.max_pool(self._residual_stage(x, self.conv, self.block1))

        # Stage 2: the skip map is taken at this resolution, before pooling.
        x = self._residual_stage(x, self.conv1, self.block2)
        x_skip = self.conv_skip(x)
        x = self.max_pool(x)

        # Stages 3-5.
        for entry_conv, block in (
            (self.conv2, self.block3),
            (self.conv3, self.block4),
            (self.conv4, self.block5),
        ):
            x = self.max_pool(self._residual_stage(x, entry_conv, block))

        return self.conv_last(x), x_skip
    

In [None]:
class Decoder(nn.Module):
    """Full model: runs an encoder, then fuses its bottleneck and skip maps.

    ``forward(x)``:
      1. ``bottleneck, skip = encoder(x)``.
      2. The bottleneck is resized (``F.interpolate``, default nearest mode)
         to the skip map's spatial size and projected 32 -> 2 channels by a
         1x1 conv.
      3. It is concatenated with the 2-channel skip map along the channel axis
         (4 channels total).
      4. The result is resized to the input's spatial size and reduced to
         1 channel by a 1x1 conv.

    For a 256x256 input this reproduces the previously hard-coded 128x128 /
    256x256 interpolation targets (assuming ``Conv2dSamePadding`` preserves
    spatial size -- it is defined outside this cell; TODO confirm).

    NOTE(review): ``conv2`` now takes 4 input channels because fusion is a
    channel-wise concat. Checkpoints saved with the old 2-channel ``conv2``
    will not load into it.

    Args:
        encoder: module returning ``(bottleneck, skip)`` with 32 and 2
            channels respectively; defaults to a fresh ``Encoder()``.
    """

    def __init__(self, encoder=None):
        super(Decoder, self).__init__()

        # Honour the caller's encoder (e.g. a pretrained or shared one).
        self.encoder = encoder if encoder is not None else Encoder()

        # 1x1 projection of the 32-channel bottleneck to 2 channels, matching
        # the skip map's channel count.
        self.conv1 = Conv2dSamePadding(32, 2, 1)

        # 2 (projected bottleneck) + 2 (skip) channels after concatenation.
        self.conv2 = Conv2dSamePadding(4, 1, 1)

    def forward(self, x):
        output_size = x.shape[-2:]

        bottleneck, skip = self.encoder(x)

        # Bring the bottleneck up to the skip map's resolution so they can be
        # concatenated, whatever the input size.
        bottleneck = F.interpolate(bottleneck, size=skip.shape[-2:])
        bottleneck = self.conv1(bottleneck)

        # dim=1 is the channel axis; concatenating on dim 0 would stack the
        # two maps as extra batch samples.
        fused = torch.cat((bottleneck, skip), dim=1)

        fused = F.interpolate(fused, size=output_size)
        return self.conv2(fused)
  