<a href="https://colab.research.google.com/github/xoro-o/UNET_architecture/blob/main/Simple_UNET_architecture_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn

In [38]:
def doubleConv(in_c,out_c):
  conv = nn.Sequential(nn.Conv2d(in_c,out_c,3),
                       nn.ReLU(inplace=True),
                       nn.Conv2d(out_c,out_c,3),
                       nn.ReLU(inplace=True)
                       )
  return conv



## here crop_img is not a good resolve as division operation can throw us in some loopholes
## so there we 2 crop_img created as per the size and output changed because of the div operation
def crop_img(tensor,target_tensor):
  target_size = target_tensor.size()[2] # t_s = 100
  tensor_size = tensor.size()[2]        # t_s = 135
  delta = tensor_size-target_size   
  # print("delta was : ",delta)           # delta = 35
  delta = delta//2                      # delta was 17
  # print("delta became :",delta)          
  return tensor[:,:,delta:tensor_size-delta-1,delta:tensor_size-delta-1]

def crop_img2(tensor,target_tensor):
  target_size = target_tensor.size()[2] # t_s = 100
  tensor_size = tensor.size()[2]        # t_s = 135
  delta = tensor_size-target_size   
  # print("delta was : ",delta)           # delta = 35
  delta = delta//2                      # delta was 17
  # print("delta became :",delta)          
  return tensor[:,:,delta:tensor_size-delta,delta:tensor_size-delta]  

In [30]:
class UNET(nn.Module):
  def __init__(self):
    super(UNET,self).__init__()
    self.maxPool = nn.MaxPool2d(3,2)

    self.down_conv_1 = doubleConv(1,64)
    self.down_conv_2 = doubleConv(64,128)
    self.down_conv_3 = doubleConv(128,256)
    self.down_conv_4 = doubleConv(256,512)
    self.down_conv_5 = doubleConv(512,1024)

    self.trans_up_1 = nn.ConvTranspose2d(1024,512,2,2)
    self.up_conv_1 = doubleConv(1024,512)

    self.trans_up_2 = nn.ConvTranspose2d(512,256,2,2)
    self.up_conv_2 = doubleConv(512,256)

    self.trans_up_3 = nn.ConvTranspose2d(256,128,2,2)
    self.up_conv_3 = doubleConv(256,128)

    self.trans_up_4 = nn.ConvTranspose2d(128,64,2,2)
    self.up_conv_4 = doubleConv(128,64)

    self.out = nn.Conv2d(64,2,1)



  def forward(self,image):
    # ENCODER DONE
    x1 = self.down_conv_1(image)
    x2 = self.maxPool(x1)
    x3 = self.down_conv_2(x2)
    x4 = self.maxPool(x3)
    x5 = self.down_conv_3(x4)
    x6 = self.maxPool(x5)
    x7 = self.down_conv_4(x6)
    x8 = self.maxPool(x7)
    x9 = self.down_conv_5(x8)
    # print(x9.size())

    #DECODER
    x = self.trans_up_1(x9)
    # print(x.size())
    y = crop_img(x7,x)
    # print(y.size())
    # print(torch.cat([x,y],1).size())
    x = self.up_conv_1(torch.cat([x,y],1))
    # print(x.size())
    x = self.trans_up_2(x)
    # print(x.size())
    # print(x5.size())
    y = crop_img(x5,x)
    # print(y.size())
    x = self.up_conv_2(torch.cat([x,y],1))
    print(x.size())
    x = self.trans_up_3(x)
    y = crop_img(x3,x)
    x = self.up_conv_3(torch.cat([x,y],1))

    x = self.trans_up_4(x)
    y = crop_img2(x1,x)
    x = self.up_conv_4(torch.cat([x,y],1))

    x = self.out(x)

    print(x.size())
    return x


    




In [39]:

image = torch.rand((1,1,572,572))
model = UNET()
print(model(image))

torch.Size([1, 256, 96, 96])
torch.Size([1, 2, 372, 372])
tensor([[[[-0.1085, -0.1091, -0.1092,  ..., -0.1019, -0.1070, -0.1014],
          [-0.1083, -0.1056, -0.1078,  ..., -0.1084, -0.1023, -0.1064],
          [-0.1069, -0.1048, -0.1101,  ..., -0.1073, -0.1054, -0.1038],
          ...,
          [-0.1090, -0.1076, -0.1112,  ..., -0.1084, -0.1039, -0.1094],
          [-0.1068, -0.1072, -0.1089,  ..., -0.1065, -0.1091, -0.1018],
          [-0.1118, -0.1089, -0.1102,  ..., -0.1069, -0.1102, -0.1056]],

         [[-0.0934, -0.0921, -0.0957,  ..., -0.0992, -0.0978, -0.0960],
          [-0.0930, -0.0991, -0.0980,  ..., -0.0982, -0.0990, -0.0950],
          [-0.0993, -0.0968, -0.0952,  ..., -0.0921, -0.0943, -0.0972],
          ...,
          [-0.0935, -0.0993, -0.0956,  ..., -0.0978, -0.0972, -0.0975],
          [-0.0965, -0.0960, -0.0939,  ..., -0.0973, -0.0941, -0.0936],
          [-0.0951, -0.0936, -0.0962,  ..., -0.0943, -0.0950, -0.0972]]]],
       grad_fn=<ConvolutionBackward0>)
