<a href="https://colab.research.google.com/github/princexoleo/u_net_pattern_lab/blob/master/u_net_original_from_scratch_with_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

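# U-Net Architecture Development with the PyTorch Framework
* In this notebook I'll show how to implement the U-Net architecture as described in the original paper (Ronneberger et al., 2015, "U-Net: Convolutional Networks for Biomedical Image Segmentation"), and then adapt it to 1D convolutions for sequence inputs.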

In [92]:
# import necessary libraries and module
import torch
import torch.nn as nn 
from torch import optim
import torch.nn.functional as F
import numpy as np

## Create a simple `UNet` class
* 2D convolutions, unpadded as in the original paper (so the output is spatially smaller than the input)

In [94]:
# Every U-Net level applies two 3x3 convolutions, so build that pair with one helper.
def double_conv(input_channel, output_channel):
  """Build two stacked unpadded 3x3 convolutions, each followed by ReLU.

  With no padding, each convolution trims one pixel from every border, so the
  pair shrinks height and width by 4 in total (e.g. 572 -> 568).

  Args:
    input_channel: number of channels of the incoming feature map.
    output_channel: number of channels produced by both convolutions.

  Returns:
    An ``nn.Sequential`` of Conv2d -> ReLU -> Conv2d -> ReLU.
  """
  layers = []
  in_ch = input_channel
  for _ in range(2):
    layers.append(nn.Conv2d(in_ch, output_channel, kernel_size=3))
    layers.append(nn.ReLU(inplace=True))
    in_ch = output_channel  # the second conv consumes the first conv's output
  return nn.Sequential(*layers)


def crop_img(org_tensor, target_tensor):
  """Centre-crop ``org_tensor`` to the height and width of ``target_tensor``.

  Used on U-Net skip connections: because the convolutions are unpadded, the
  encoder feature map is larger than the upsampled decoder map and must be
  cropped before the two are concatenated along the channel dimension.

  Height and width are cropped independently, so non-square maps work. When
  a size difference is odd, the extra row/column is dropped from the
  bottom/right, so the result always matches the target exactly.

  Args:
    org_tensor: tensor of shape (N, C, H, W) to crop.
    target_tensor: tensor of shape (N', C', target_H, target_W); only its
      last two dimensions are used.

  Returns:
    A slice of ``org_tensor`` with shape (N, C, target_H, target_W).

  Raises:
    ValueError: if ``org_tensor`` is smaller than the target in H or W.
  """
  target_h, target_w = target_tensor.size(2), target_tensor.size(3)
  org_h, org_w = org_tensor.size(2), org_tensor.size(3)
  if org_h < target_h or org_w < target_w:
    raise ValueError(
        "cannot crop {}x{} tensor to larger size {}x{}".format(
            org_h, org_w, target_h, target_w))
  # Floor division puts any odd leftover row/column at the bottom/right.
  top = (org_h - target_h) // 2
  left = (org_w - target_w) // 2
  return org_tensor[:, :, top:top + target_h, left:left + target_w]


In [95]:
class UNet(nn.Module):
  """U-Net with unpadded 3x3 convolutions, following the original paper.

  Because the convolutions are unpadded, every ``double_conv`` shrinks the
  feature map, so the output is spatially smaller than the input (a 572x572
  input gives a 388x388 output). Encoder feature maps are therefore
  centre-cropped with ``crop_img`` before being concatenated into the decoder.

  Args:
    input_c: number of channels of the input image.
    num_class: number of output channels (one logit map per class).
    verbose: if True (the default), print the input and output shapes on
      every forward pass.
  """

  def __init__(self, input_c, num_class, verbose=True):
    super(UNet, self).__init__()
    self.input_c = input_c
    self.num_class = num_class
    self.verbose = verbose

    self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)

    # Contracting path: the channel count doubles at each level.
    self.down_conv1 = double_conv(input_c, 64)
    self.down_conv2 = double_conv(64, 128)
    self.down_conv3 = double_conv(128, 256)
    self.down_conv4 = double_conv(256, 512)
    self.down_conv5 = double_conv(512, 1024)

    # Expanding path: each 2x2 transposed conv doubles height/width and halves
    # the channels. The double conv that follows sees the upsampled map
    # concatenated with the cropped skip connection, hence twice the channels.
    self.up_trans_1 = nn.ConvTranspose2d(
        in_channels=1024, out_channels=512, kernel_size=2, stride=2)
    self.up_conv_1 = double_conv(1024, 512)

    self.up_trans_2 = nn.ConvTranspose2d(
        in_channels=512, out_channels=256, kernel_size=2, stride=2)
    self.up_conv_2 = double_conv(512, 256)

    self.up_trans_3 = nn.ConvTranspose2d(
        in_channels=256, out_channels=128, kernel_size=2, stride=2)
    self.up_conv_3 = double_conv(256, 128)

    self.up_trans_4 = nn.ConvTranspose2d(
        in_channels=128, out_channels=64, kernel_size=2, stride=2)
    self.up_conv_4 = double_conv(128, 64)

    # 1x1 conv producing one logit map per class.
    self.out = nn.Conv2d(
        in_channels=64, out_channels=num_class, kernel_size=1)

  def forward(self, img):
    """Compute per-pixel class logits.

    Args:
      img: tensor of shape (batch_size, input_c, height, width).

    Returns:
      Logits of shape (batch_size, num_class, out_height, out_width). They are
      spatially smaller than ``img`` because the convolutions are unpadded.
    """
    if self.verbose:
      print("Input: {}".format(img.size()))

    # Encoder: keep each level's output (before pooling) as a skip connection.
    skip1 = self.down_conv1(img)
    skip2 = self.down_conv2(self.max_pool_2x2(skip1))
    skip3 = self.down_conv3(self.max_pool_2x2(skip2))
    skip4 = self.down_conv4(self.max_pool_2x2(skip3))
    x = self.down_conv5(self.max_pool_2x2(skip4))

    # Decoder: upsample, crop the matching skip to the upsampled size (as in
    # the paper), concatenate along channels, then apply the double conv.
    decoder_stages = (
        (self.up_trans_1, self.up_conv_1, skip4),
        (self.up_trans_2, self.up_conv_2, skip3),
        (self.up_trans_3, self.up_conv_3, skip2),
        (self.up_trans_4, self.up_conv_4, skip1),
    )
    for up_trans, up_conv, skip in decoder_stages:
      x = up_trans(x)
      x = up_conv(torch.cat([x, crop_img(skip, x)], 1))

    logits = self.out(x)
    if self.verbose:
      print("Out: {}".format(logits.size()))
    return logits
  

In [96]:
# Smoke test: push one random 572x572 single-channel image through a 2-class
# U-Net and check the printed input/output shapes.
torch.manual_seed(0)  # reproducible random input and weight initialisation
sample_img = torch.rand((1, 1, 572, 572))
unet_model = UNet(1, 2)
unet_model(sample_img);  # trailing ';' suppresses display of any returned tensor

Input: torch.Size([1, 1, 572, 572])
Out: torch.Size([1, 2, 388, 388])


## UNet: 1D Convolution

In [97]:
# NOTE(review): this cell re-defines `double_conv` with 1D convolutions, shadowing
# the 2D version defined earlier. `UNet.__init__` looks `double_conv` up when the
# model is instantiated, so re-running the 2D cells after this one builds Conv1d
# layers and fails. Consider renaming this helper (e.g. `double_conv_1d`)
# together with its uses in `UNet_1D`.
def double_conv(input_channel, output_channel):
  """Build two 1D convolutions (kernel 3, padding 1), each followed by ReLU.

  Padding of 1 keeps the sequence length unchanged.

  Args:
    input_channel: number of channels of the incoming sequence.
    output_channel: number of channels produced by both convolutions.

  Returns:
    An ``nn.Sequential`` of Conv1d -> ReLU -> Conv1d -> ReLU.
  """
  conv = nn.Sequential(
      nn.Conv1d(input_channel, output_channel, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv1d(output_channel, output_channel, kernel_size=3, padding=1),
      nn.ReLU(inplace=True)
  )
  return conv


def crop_tensor(x1, x2):
  """Resize ``x1`` along its last (length) dimension to match ``x2``'s length.

  If ``x2`` is longer, ``x1`` is zero-padded. If it is shorter, the negative
  padding crops ``x1``. The change is split between both ends as evenly as
  ``[diff // 2, diff - diff // 2]`` allows. Only the last dimension is touched,
  so the channel dimension of an (N, C, L) tensor is never padded.

  Args:
    x1: tensor of shape (N, C, L1) to resize.
    x2: tensor whose last dimension gives the target length L2.

  Returns:
    ``x1`` resized to shape (N, C, L2).
  """
  diff = x2.size(-1) - x1.size(-1)
  return F.pad(x1, [diff // 2, diff - diff // 2])


### U-Net Model with 1D Convolution

In [98]:
class UNet_1D(nn.Module):
  """U-Net built from 1D convolutions for inputs of shape (batch_size, input_c, seq_length).

  All convolutions use padding=1. In the decoder, each upsampled map is
  zero-padded to the length of its skip connection before concatenation.
  This does not rely on ``output_padding`` tuned for one particular length,
  so the output length equals the input length for any ``seq_length`` of at
  least 16 (four max-pools must leave a non-empty sequence).

  Args:
    input_c: number of input channels (features per position).
    num_class: number of output channels (logits per position).
    verbose: if True (the default), print the input and output shapes on
      every forward pass.
  """

  def __init__(self, input_c, num_class, verbose=True):
    super(UNet_1D, self).__init__()
    self.input_c = input_c
    self.num_class = num_class
    self.verbose = verbose

    # Halves the sequence length (flooring odd lengths); name kept from the 2D model.
    self.max_pool_2x2 = nn.MaxPool1d(kernel_size=2, stride=2)

    # Contracting path: the channel count doubles at each level.
    self.down_conv1 = double_conv(input_c, 64)
    self.down_conv2 = double_conv(64, 128)
    self.down_conv3 = double_conv(128, 256)
    self.down_conv4 = double_conv(256, 512)
    self.down_conv5 = double_conv(512, 1024)

    # Expanding path: ConvTranspose1d(kernel_size=2, stride=2) maps length L to
    # exactly 2L. forward() pads that up to the skip length (2L or 2L + 1,
    # because max-pooling floors odd lengths).
    self.up_trans_1 = nn.ConvTranspose1d(
        in_channels=1024, out_channels=512, kernel_size=2, stride=2)
    self.up_conv_1 = double_conv(1024, 512)

    self.up_trans_2 = nn.ConvTranspose1d(
        in_channels=512, out_channels=256, kernel_size=2, stride=2)
    self.up_conv_2 = double_conv(512, 256)

    self.up_trans_3 = nn.ConvTranspose1d(
        in_channels=256, out_channels=128, kernel_size=2, stride=2)
    self.up_conv_3 = double_conv(256, 128)

    self.up_trans_4 = nn.ConvTranspose1d(
        in_channels=128, out_channels=64, kernel_size=2, stride=2)
    self.up_conv_4 = double_conv(128, 64)

    # 1x1 conv producing num_class logits per position.
    self.out = nn.Conv1d(
        in_channels=64, out_channels=self.num_class, kernel_size=1)

  def forward(self, img):
    """Compute per-position class logits.

    Args:
      img: tensor of shape (batch_size, input_c, seq_length).

    Returns:
      Logits of shape (batch_size, num_class, seq_length).
    """
    if self.verbose:
      print("Input: {}".format(img.size()))

    # Encoder: keep each level's output (before pooling) as a skip connection.
    skip1 = self.down_conv1(img)
    skip2 = self.down_conv2(self.max_pool_2x2(skip1))
    skip3 = self.down_conv3(self.max_pool_2x2(skip2))
    skip4 = self.down_conv4(self.max_pool_2x2(skip3))
    x = self.down_conv5(self.max_pool_2x2(skip4))

    # Decoder: upsample, pad the upsampled map to the skip's length (it is
    # never longer), concatenate along channels, then apply the double conv.
    decoder_stages = (
        (self.up_trans_1, self.up_conv_1, skip4),
        (self.up_trans_2, self.up_conv_2, skip3),
        (self.up_trans_3, self.up_conv_3, skip2),
        (self.up_trans_4, self.up_conv_4, skip1),
    )
    for up_trans, up_conv, skip in decoder_stages:
      x = crop_tensor(up_trans(x), skip)
      x = up_conv(torch.cat([x, skip], 1))

    logits = self.out(x)
    if self.verbose:
      print("Out: {}".format(logits.size()))
    return logits


    

In [99]:
# Smoke test for the 1D model. Input layout is (batch_size, in_channels, seq_length).
torch.manual_seed(0)  # reproducible random input and weight initialisation
sample_input = torch.rand((1, 42, 700))
unet_1D_model = UNet_1D(42, 8)
unet_1D_model(sample_input);  # trailing ';' suppresses display of any returned tensor

Input: torch.Size([1, 42, 700])
Out: torch.Size([1, 8, 700])
