<a href="https://colab.research.google.com/github/taravatp/Multi_Spectral_Image_Segmentation/blob/main/networks/Unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import math
import numpy as np

In [None]:
def convolution_module(in_channels,out_channels):
  """Build a double-convolution stage: (Conv2d 3x3 -> BatchNorm2d -> ReLU) twice.

  The first convolution maps ``in_channels`` to ``out_channels``; the second
  keeps ``out_channels``. Both use ``padding=1`` with a 3x3 kernel, so the
  spatial size of the input is preserved. Convolutions carry no bias term
  because each one is followed directly by batch normalization.

  Args:
    in_channels: Number of channels of the incoming feature map.
    out_channels: Number of channels produced by both convolutions.

  Returns:
    An ``nn.Sequential`` holding the six layers in execution order.
  """
  layers = []
  stage_in = in_channels
  for _ in range(2):
    layers.append(
        nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False)
    )
    layers.append(nn.BatchNorm2d(out_channels))
    layers.append(nn.ReLU(inplace=True))
    # Only the first convolution changes the channel count.
    stage_in = out_channels
  return nn.Sequential(*layers)

In [None]:
class CNN_encoder(nn.Module):
  """Contracting path of the U-Net: four double-convolution stages.

  Stages 1-3 are each followed by a 2x2 max pool (stride 2); stage 4 is the
  bottleneck and is not pooled.

  Args:
    in_channels: Number of channels (spectral bands) of the input image.
    feat_channels: Output width of each of the four stages, shallowest first.
      Exactly four values are required. Any decoder consuming this encoder's
      outputs must be built for the same widths.

  Attributes:
    out_channels: Channel count of the bottleneck feature map.

  Raises:
    ValueError: If ``feat_channels`` does not contain exactly four entries.
  """

  def __init__(self, in_channels=12, feat_channels=(32, 64, 128, 256)):

    super(CNN_encoder, self).__init__()

    # Copy into a tuple so a caller-owned list is never shared or mutated.
    feat_channels = tuple(feat_channels)
    if len(feat_channels) != 4:
      raise ValueError(
          "feat_channels must contain exactly 4 values, got %d"
          % len(feat_channels)
      )
    width1, width2, width3, width4 = feat_channels

    # Encoder convolutions
    self.down_conv1 = convolution_module(in_channels, width1)
    self.down_conv2 = convolution_module(width1, width2)
    self.down_conv3 = convolution_module(width2, width3)
    self.down_conv4 = convolution_module(width3, width4)
    self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
    self.out_channels = width4

  def forward(self,x):
    """Encode ``x`` and return the bottleneck plus the skip features.

    Args:
      x: Input tensor indexed as (batch, channels, height, width).

    Returns:
      A tuple ``(bottleneck, skips)`` where ``skips`` is the list
      ``[x3, x2, x1]`` of pre-pooling stage outputs, deepest first.
    """

    x1 = self.down_conv1(x)
    x_low1 = self.max_pool(x1)

    x2 = self.down_conv2(x_low1)
    x_low2 = self.max_pool(x2)

    x3 = self.down_conv3(x_low2)
    x_low3 = self.max_pool(x3)

    # The bottleneck
    x4 = self.down_conv4(x_low3)

    # Skips are returned deepest-first, the order a decoder consumes them in.
    return x4,[x3,x2,x1]

In [None]:
class CNN_decoder(nn.Module):
  """Expanding path of the U-Net: three upsample + skip-merge stages.

  Each stage doubles the spatial size with a stride-2 transposed convolution,
  concatenates the matching encoder skip feature along the channel axis, and
  fuses the result with a double convolution. A final 1x1 convolution maps to
  the output channels.

  Args:
    feat_channels: Encoder stage widths, shallowest first (four values). The
      last entry is the bottleneck width. Defaults match the widths that were
      previously hard-coded here.
    num_classes: Number of output channels of the final 1x1 convolution.

  Raises:
    ValueError: If ``feat_channels`` does not contain exactly four entries.
  """

  def __init__(self, feat_channels=(32, 64, 128, 256), num_classes=3):
    super(CNN_decoder,self).__init__()

    feat_channels = tuple(feat_channels)
    if len(feat_channels) != 4:
      raise ValueError(
          "feat_channels must contain exactly 4 values, got %d"
          % len(feat_channels)
      )
    width1, width2, width3, width4 = feat_channels

    # After concatenation with the skip, each fusion conv sees twice the
    # width produced by the preceding upsample.
    self.upsample1 = nn.ConvTranspose2d(width4, width3, kernel_size=2, stride=2)
    self.conv1 = convolution_module(2 * width3, width3)

    self.upsample2 = nn.ConvTranspose2d(width3, width2, kernel_size=2, stride=2)
    self.conv2 = convolution_module(2 * width2, width2)

    self.upsample3 = nn.ConvTranspose2d(width2, width1, kernel_size=2, stride=2)
    self.conv3 = convolution_module(2 * width1, width1)

    self.out = nn.Conv2d(width1, num_classes, kernel_size=1)

  def crop_tensor(self,target_tensor,input_tensor):
    """Center-crop ``input_tensor`` to the spatial size of ``target_tensor``.

    Height (dim 2) and width (dim 3) are handled independently, and the
    result always has exactly the target's spatial size, including when the
    size difference is odd (the extra row/column is dropped from the
    bottom/right).

    Args:
      target_tensor: Tensor whose dims 2 and 3 give the desired size.
      input_tensor: Tensor to crop; must be at least as large as the target
        in dims 2 and 3.

    Returns:
      A view of ``input_tensor`` with dims 2 and 3 matching the target.

    Raises:
      ValueError: If ``input_tensor`` is smaller than the target in either
        spatial dimension.
    """
    target_h, target_w = target_tensor.shape[2], target_tensor.shape[3]
    input_h, input_w = input_tensor.shape[2], input_tensor.shape[3]
    if input_h < target_h or input_w < target_w:
      raise ValueError(
          "cannot crop a %dx%d tensor to the larger size %dx%d"
          % (input_h, input_w, target_h, target_w)
      )
    top = (input_h - target_h) // 2
    left = (input_w - target_w) // 2
    # Slice by explicit length so an odd difference cannot leave the crop
    # one element larger than the target.
    return input_tensor[:, :, top:top + target_h, left:left + target_w]

  def forward(self,x, features):
    """Decode the bottleneck ``x`` using the encoder skip features.

    Args:
      x: Bottleneck feature map.
      features: Three skip feature maps ordered deepest first.

    Returns:
      The output of the final 1x1 convolution.
    """

    skip3, skip2, skip1 = features

    x = self.upsample1(x)
    y = self.crop_tensor(x, skip3)
    x = self.conv1(torch.cat([x,y],1))

    x = self.upsample2(x)
    y = self.crop_tensor(x, skip2)
    x = self.conv2(torch.cat([x,y],1))

    x = self.upsample3(x)
    y = self.crop_tensor(x, skip1)
    x = self.conv3(torch.cat([x,y],1))

    x = self.out(x)

    return x

In [None]:
class unet(nn.Module):
  """U-Net assembled from ``CNN_encoder`` and ``CNN_decoder``.

  Args:
    in_channels: Number of channels (spectral bands) of the input image,
      forwarded to the encoder. Defaults to 12, the encoder's own default.
  """

  def __init__(self, in_channels=12):
    super(unet, self).__init__()
    self.encoder = CNN_encoder(in_channels=in_channels)
    self.decoder = CNN_decoder()

  def forward(self,x):
    """Run the encoder, then decode its bottleneck with the skip features.

    Args:
      x: Input tensor indexed as (batch, channels, height, width).

    Returns:
      The decoder output for ``x``.
    """
    bottleneck, skips = self.encoder(x)
    return self.decoder(bottleneck, skips)

In [None]:
if __name__ == "__main__":

  # Smoke test: push a random batch of eight 12-channel 64x64 images through
  # the network and report the shape of the result.
  batch = torch.rand((8,12,64,64))
  net = unet()
  prediction = net(batch)
  print('output:',prediction.shape)

output: torch.Size([8, 3, 64, 64])
