In [6]:
import torch
import torch.nn as nn

In [7]:
def double_conv(in_c,out_c):    #convolutional layer of Conv2d and reLU
  conv=nn.Sequential(
      nn.Conv2d(in_c,out_c,kernel_size=3),
      nn.ReLU(inplace=True),
      nn.Conv2d(out_c,out_c,kernel_size=3),
      nn.ReLU(inplace=True)

  )
  return conv

In [16]:
def crop_img(tensor,target_tensor):     #cropping image 
  target_size=target_tensor.size()[2]
  tensor_size=tensor.size()[2]
  delta=tensor_size - target_size
  delta=delta//2
  return tensor[:,:,delta:tensor_size-delta,delta:tensor_size-delta] 

In [25]:


class UNet(nn.Module):
  def __init__(self):

    super(UNet,self).__init__()

    self.max_pool_2x2=nn.MaxPool2d(kernel_size=2,stride=2)  # downscaling the image
    self.down_conv_1=double_conv(1,64)
    self.down_conv_2=double_conv(64,128)
    self.down_conv_3=double_conv(128,256)
    self.down_conv_4=double_conv(256,512)
    self.down_conv_5=double_conv(512,1024)


    self.up_trans_1=nn.ConvTranspose2d(          # to upscale the image 
        in_channels=1024,
        out_channels=512,kernel_size=2,stride=2
    )
    self.up_conv_1=double_conv(1024,512)
    
    self.up_trans_2=nn.ConvTranspose2d(
        in_channels=512,
        out_channels=256,kernel_size=2,stride=2
    )
    self.up_conv_2=double_conv(512,256)

    self.up_trans_3=nn.ConvTranspose2d(
        in_channels=256,
        out_channels=128,kernel_size=2,stride=2
    )
    self.up_conv_3=double_conv(256,128)

    self.up_trans_4=nn.ConvTranspose2d(
        in_channels=128,
        out_channels=64,kernel_size=2,stride=2
    )
    self.up_conv_4=double_conv(128,64)
    


    self.out=nn.Conv2d(     # for the output image size
        in_channels=64,
        out_channels=2,   #here you can chaneg the out channels as per the number of segmentation we want
        kernel_size=1
    )


  def forward(self,image):
    #bs,c,h,w
    x1=self.down_conv_1(image)

    x2=self.max_pool_2x2(x1)
    x3=self.down_conv_2(x2)
    x4=self.max_pool_2x2(x3)
    x5=self.down_conv_3(x4)
    
    x6=self.max_pool_2x2(x5)
    x7=self.down_conv_4(x6)
    x8=self.max_pool_2x2(x7)
    x9=self.down_conv_5(x8) 
    
  
    x=self.up_trans_1(x9)
    y=crop_img(x7,x)           #need to crop the image to be of expected size
    x=self.up_conv_1(torch.cat([x,y],1))
    
    
    x=self.up_trans_2(x)
    y=crop_img(x5,x)
    x=self.up_conv_2(torch.cat([x,y],1))


    x=self.up_trans_3(x)
    y=crop_img(x3,x)
    x=self.up_conv_3(torch.cat([x,y],1))


    x=self.up_trans_4(x)
    y=crop_img(x1,x)
    x=self.up_conv_4(torch.cat([x,y],1))

    self.out(x)     # to convert to expexted output size
    
    print(x.size())
    return x
if __name__=="__main__":
  image=torch.rand((1,1,572,572))
  model=UNet()
  print(model(image))









torch.Size([1, 64, 388, 388])
tensor([[[[0.0094, 0.0070, 0.0048,  ..., 0.0066, 0.0029, 0.0077],
          [0.0049, 0.0144, 0.0000,  ..., 0.0110, 0.0074, 0.0015],
          [0.0163, 0.0033, 0.0044,  ..., 0.0032, 0.0065, 0.0024],
          ...,
          [0.0109, 0.0073, 0.0119,  ..., 0.0000, 0.0080, 0.0131],
          [0.0000, 0.0135, 0.0038,  ..., 0.0197, 0.0055, 0.0196],
          [0.0179, 0.0000, 0.0214,  ..., 0.0250, 0.0078, 0.0000]],

         [[0.0506, 0.0586, 0.0488,  ..., 0.0564, 0.0531, 0.0485],
          [0.0605, 0.0623, 0.0477,  ..., 0.0472, 0.0611, 0.0529],
          [0.0455, 0.0649, 0.0460,  ..., 0.0534, 0.0383, 0.0509],
          ...,
          [0.0595, 0.0586, 0.0479,  ..., 0.0496, 0.0546, 0.0611],
          [0.0509, 0.0514, 0.0558,  ..., 0.0442, 0.0530, 0.0515],
          [0.0549, 0.0603, 0.0434,  ..., 0.0474, 0.0425, 0.0665]],

         [[0.0157, 0.0279, 0.0089,  ..., 0.0074, 0.0165, 0.0049],
          [0.0130, 0.0175, 0.0207,  ..., 0.0128, 0.0125, 0.0200],
          [0