In [3]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# import cv2

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import torch.nn.functional as F
from torch.utils.data.sampler import Sampler

from pathlib import Path
import sys
from IPython.display import display


# Primary Model

In [4]:
class convBlock(nn.Module):
    """Two stacked 3x3 convolutions with a single ReLU between them.

    Both convolutions use ``nn.Conv2d`` with its default padding of 0, so each
    one trims 2 pixels from the height and the width; a full pass therefore
    shrinks each spatial dimension by 4. No activation is applied after the
    second convolution.

    Args:
        input_channel: Number of channels of the incoming feature map.
        output_channel: Number of channels produced by both convolutions.
    """

    def __init__(self, input_channel, output_channel):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channel, output_channel, kernel_size=3)
        self.conv2 = nn.Conv2d(output_channel, output_channel, kernel_size=3)
        self.relu = nn.ReLU()

    # ``torch.Tensor`` is the tensor type; ``torch.tensor`` is a factory
    # function and is not valid as a type annotation.
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply conv1 -> ReLU -> conv2 and return the result.

        Args:
            input: Feature map accepted by ``nn.Conv2d`` with
                ``input_channel`` channels.

        Returns:
            Tensor with ``output_channel`` channels and each spatial
            dimension reduced by 4.
        """
        return self.conv2(self.relu(self.conv1(input)))


In [5]:
class UnetEncoder(nn.Module):
    """Contracting path: a chain of ``convBlock`` stages, each followed by 2x2 max-pooling.

    Args:
        channel_list: Channel counts along the path. Stage ``i`` maps
            ``channel_list[i]`` channels to ``channel_list[i + 1]``, so a list
            of length ``n`` creates ``n - 1`` stages.
    """

    def __init__(self, channel_list: list):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2)
        # The stages must live in an nn.ModuleList: modules kept in a plain
        # Python list are not registered, so their weights would be missing
        # from .parameters() / .state_dict() and ignored by .to(device).
        self.block_list = nn.ModuleList(
            [
                convBlock(in_ch, out_ch)
                for in_ch, out_ch in zip(channel_list[:-1], channel_list[1:])
            ]
        )

    def forward(self, input: torch.Tensor) -> list:
        """Run every stage and collect the pre-pooling feature maps.

        Args:
            input: Tensor with ``channel_list[0]`` channels.

        Returns:
            List with one tensor per stage, in execution order (first stage
            first). Each entry is the stage output *before* pooling; the
            pooled tensor is only used as input to the next stage.
        """
        layered_encoder_out = []

        for block in self.block_list:
            input = block(input)
            layered_encoder_out.append(input)
            input = self.pool(input)

        return layered_encoder_out

In [9]:
class UnetDecoder(nn.Module):
    """Expanding path: up-sample, concatenate a cropped skip feature, then convolve.

    Stage ``i`` consists of a stride-2 ``nn.ConvTranspose2d`` mapping
    ``channel_list[i]`` to ``channel_list[i + 1]`` channels, followed by a
    ``convBlock(channel_list[i], channel_list[i + 1])``. Because the block
    runs on the channel-wise concatenation of the up-sampled tensor and the
    skip feature, the skip feature for stage ``i`` must contribute
    ``channel_list[i] - channel_list[i + 1]`` channels.

    Args:
        channel_list: Channel counts along the path; a list of length ``n``
            creates ``n - 1`` stages.
    """

    def __init__(self, channel_list: list):
        super().__init__()
        self.channel_list = channel_list
        stage_channels = list(zip(channel_list[:-1], channel_list[1:]))

        # Both containers must be nn.ModuleList: modules kept in plain Python
        # lists are not registered, so their weights would be missing from
        # .parameters() / .state_dict() and ignored by .to(device).
        self.up_sampler = nn.ModuleList(
            [nn.ConvTranspose2d(in_ch, out_ch, 2, 2) for in_ch, out_ch in stage_channels]
        )
        self.block_list = nn.ModuleList(
            [convBlock(in_ch, out_ch) for in_ch, out_ch in stage_channels]
        )

    def forward(self, input, layered_concat):
        """Decode ``input`` using skip features taken from the encoder's layered output.

        Args:
            input: Tensor with ``channel_list[0]`` channels.
            layered_concat: Indexable collection of skip features;
                ``layered_concat[i]`` is concatenated in stage ``i``. It is
                center-cropped to the up-sampled tensor's height and width first.

        Returns:
            Output of the last stage (``channel_list[-1]`` channels), or
            ``input`` unchanged when there are no stages.
        """
        for i, (up_sample, block) in enumerate(zip(self.up_sampler, self.block_list)):
            input = up_sample(input)
            concat_feature = self.crop(layered_concat[i], input)
            input = torch.cat([input, concat_feature], dim=1)
            input = block(input)
        return input

    def crop(self, concat_feature, input):
        """Center-crop ``concat_feature`` to the height and width of the 4-D ``input``."""
        _, _, height, width = input.shape
        return torchvision.transforms.CenterCrop([height, width])(concat_feature)


In [13]:
class UNET(nn.Module):
    """U-Net assembled from ``UnetEncoder``, ``UnetDecoder`` and a 1x1 output convolution.

    Args:
        in_channel_list: Channel list passed to ``UnetEncoder``.
        out_channel_list: Channel list passed to ``UnetDecoder``; its last
            entry is the channel count fed to the 1x1 output convolution.
        classes: Number of output channels of the 1x1 convolution.
        keep_dim: When true, the result is resized to ``output_size`` with
            ``F.interpolate`` (default interpolation mode).
        output_size: Target size used when ``keep_dim`` is true.
    """

    def __init__(self, in_channel_list:list, out_channel_list:list, classes=1, keep_dim=True, output_size=(1024, 1024)):
        super().__init__()
        self.encoder = UnetEncoder(in_channel_list)
        self.decoder = UnetDecoder(out_channel_list)
        self.compressor = nn.Conv2d(out_channel_list[-1], classes, kernel_size=1)
        self.keep_dim = keep_dim
        self.output_size = output_size

    def forward(self, input):
        """Encode, decode with skip features, project to ``classes`` channels, optionally resize."""
        # Put the last (deepest) encoder feature first: it seeds the decoder,
        # and the remaining features serve as skip inputs in that order.
        features = list(self.encoder(input))
        features.reverse()
        decoded = self.decoder(features[0], features[1:])

        logits = self.compressor(decoded)
        return F.interpolate(logits, self.output_size) if self.keep_dim else logits


In [7]:
# Smoke test: push one random single-channel 572x572 image through a conv block
# and show the resulting shape.
enc_block = convBlock(input_channel=1, output_channel=64)
x = enc_block(torch.randn(1, 1, 572, 572))
display(x.shape)

torch.Size([1, 64, 568, 568])

In [8]:
# Smoke test: run the encoder on a random 3-channel 572x572 tensor and print
# the shape of every stage's (pre-pooling) output.
chan_list = [3, 64, 128, 256, 512, 1024]
encoder = UnetEncoder(chan_list)
x = torch.randn(1, 3, 572, 572)  # input image
ftrs = encoder(x)
for ftr in ftrs:
    print(ftr.shape)

torch.Size([1, 64, 568, 568])
torch.Size([1, 128, 280, 280])
torch.Size([1, 256, 136, 136])
torch.Size([1, 512, 64, 64])
torch.Size([1, 1024, 28, 28])


In [10]:
# Smoke test: decode a random 1024-channel 28x28 tensor. The skip inputs are
# the encoder features `ftrs` in reverse order with the last (deepest) one
# dropped, which is what ftrs[-2::-1] selects.
channel_list = [1024, 512, 256, 128, 64]
decoder = UnetDecoder(channel_list)
x = torch.randn(1, 1024, 28, 28)
decoder(x, ftrs[-2::-1]).shape

torch.Size([1, 64, 388, 388])

In [16]:
# Smoke test: full U-Net on a random 3-channel 572x572 tensor. With keep_dim=True
# the result is resized to the default output_size of (1024, 1024).
in_chan = [3, 64, 128, 256, 512, 1024]
out_chan = [1024, 512, 256, 128, 64]
unet = UNET(in_channel_list=in_chan, out_channel_list=out_chan, keep_dim=True)
x = torch.randn(1, 3, 572, 572)
unet(x).shape

torch.Size([1, 1, 1024, 1024])