<a href="https://colab.research.google.com/github/xoro-o/UNET-PyTorch-doodle/blob/main/UNET_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn 
from torchvision import transforms

In [39]:
class DoubleConv(nn.Module):
  def __init__(self,in_channels,out_channels):
    super(DoubleConv,self).__init__()
    self.conv  =  nn.Sequential(
        nn.Conv2d(in_channels,out_channels,3,1,1,bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels,out_channels,3,1,1,bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )
  def forward(self,X):
    return self.conv(X) 


   

In [42]:
class UNET(nn.Module):
  def __init__(self,in_channels=3,out_channels=1,features = [64,128,256,512]):
    super(UNET,self).__init__()
    self.ups = nn.ModuleList()
    self.downs = nn.ModuleList()
    self.pool = nn.MaxPool2d(kernel_size = 2,stride = 2)

    #Down part of UNET
    #here we will define the DOWN module
    for feature in features:
      self.downs.append(DoubleConv(in_channels,feature))
      in_channels = feature
    

    #Up part of UNET: 
    #here we define the UP module

    for feature in reversed(features):
      self.ups.append(nn.ConvTranspose2d(feature*2,feature,2,2))
      self.ups.append(DoubleConv(feature*2,feature))


    # bottleneck part
    self.bottleneck  = DoubleConv(features[-1],features[-1]*2)
    #final part
    self.final = nn.Conv2d(64,1,1)


  # here again as we divide delta by 2 it will floor results so we have to cover up our implementation with 2 crop images
  def crop_img(tensor,target_tensor):
    target_size = target_tensor.size()[2]
    tensor_size = tensor.size()[2]
    delta = tensor_size-target_size
    delta = delta//2
    return tensor[:,:,delta:tensor_size-delta-1,delta:tensor_size-delta-1]
  def crop_img2(tensor,target_tensor):
    target_size = target_tensor.size()[2]
    tensor_size = tensor.size()[2]
    delta = tensor_size-target_size
    delta = delta//2
    return tensor[:,:,delta:tensor_size-delta,delta:tensor_size-delta]    


  def forward(self,x):
    skip_connections = []
    for down in self.downs:
      x = down(x)
      print(x.size())
      skip_connections.append(x)
      x = self.pool(x)
    
    x = self.bottleneck(x)

    skip_connections = skip_connections[::-1]
    # here my self.ups will be like :
    # [conv_transpos1,doubleconv_1,conv_transpose2,doubleconv_2,conv_transpos3,doubleconv_3]
    # so here i want to access the transpose part first i.e the indices : 0,2,4,6..
    # then i want to access the doubleconv (after the transpose part) i. indices 1,3,5 ..
    for idx in range(0,len(self.ups),2):
      x = self.ups[idx](x)
      skip_connection = skip_connections[idx//2]
      

      # here we can run into the problem of invalid matching size between the skip_connection dimensions and upsampled transpose convs
      # so we can do two things:
      # first we can crop and then concatenate
      # or we can resize 
      if(x.shape != skip_connection.shape):
        #either we can crop and concatenat using the crop_img function defined above or we can resize
        x = transforms.functional.resize(x,size = skip_connection.shape[2:])

      concat_skip = torch.cat((skip_connection,x),1)
      x = self.ups[idx+1](concat_skip)

    return self.final(x)


    



In [41]:
iamge = torch.rand((3,1,160,160))
model  = UNET(in_channels = 1,out_channels = 1)
pred  = model(iamge)
print(model)
print(iamge.shape)
print(pred.shape)
assert iamge.shape == pred.shape

torch.Size([3, 64, 160, 160])
torch.Size([3, 128, 80, 80])
torch.Size([3, 256, 40, 40])
torch.Size([3, 512, 20, 20])
UNET(
  (ups): ModuleList(
    (0): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
    (1): DoubleConv(
      (conv): Sequential(
        (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (2): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
    (3): DoubleConv(
      (conv): Sequential(
        (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_r