In [1]:
import torch
import torch.nn as nn

In [24]:
def doub_conv(i_ch,o_ch):
  conv = nn.Sequential(
      nn.Conv2d(i_ch,o_ch,kernel_size=3),
      nn.ReLU(inplace=True),
      nn.Conv2d(o_ch,o_ch,kernel_size=3),
      nn.ReLU(inplace=True)
  )
  return conv

def crop_img(tensor,tar_tensor):
  target_size=tar_tensor.size()[2]
  tensor_size=tensor.size()[2]
  siz=tensor_size-target_size
  siz=siz//2
  return tensor[ :, :, siz:tensor_size-siz , siz:tensor_size-siz]

In [32]:
class UNet(nn.Module):
  def __init__(self):
    super(UNet,self).__init__()
    self.mx_pl_22= nn.MaxPool2d(kernel_size=2,stride=2)
    self.dwn_conv_1=doub_conv(1,64)
    self.dwn_conv_2=doub_conv(64,128)
    self.dwn_conv_3=doub_conv(128,256)
    self.dwn_conv_4=doub_conv(256,512)
    self.dwn_conv_5=doub_conv(512,1024)
    self.up_trans_1=nn.ConvTranspose2d(in_channels=1024, 
                                      out_channels=512, 
                                      kernel_size=2, 
                                      stride=2)
    
    self.up_conv_1=doub_conv(1024,512)

    self.up_trans_2=nn.ConvTranspose2d(in_channels=512, 
                                      out_channels=256, 
                                      kernel_size=2, 
                                      stride=2)
    
    self.up_conv_2=doub_conv(512,256)

    self.up_trans_3=nn.ConvTranspose2d(in_channels=256, 
                                      out_channels=128, 
                                      kernel_size=2, 
                                      stride=2)
    
    self.up_conv_3=doub_conv(256,128)


    self.up_trans_4=nn.ConvTranspose2d(in_channels=128, 
                                      out_channels=64, 
                                      kernel_size=2, 
                                      stride=2
                                      )
    
    self.up_conv_4=doub_conv(128,64)        

    self.out = nn.Conv2d(in_channels=64,out_channels=2,kernel_size=1)
    

  def forward(self,image):
    #encoder
    x1=self.dwn_conv_1(image)
    print(x1.size())
    x2=self.mx_pl_22(x1)
    x3=self.dwn_conv_2(x2)
    x4=self.mx_pl_22(x3)
    print(x4.size())
    x5=self.dwn_conv_3(x4)
    x6=self.mx_pl_22(x5)
    x7=self.dwn_conv_4(x6)
    x8=self.mx_pl_22(x7)
    x9=self.dwn_conv_5(x8)
    print(x9.size())


    #decoder
    x=self.up_trans_1(x9)
    y=crop_img(x7,x)
    x_1=self.up_conv_1(torch.cat([x,y],1))

    x=self.up_trans_2(x_1)
    y=crop_img(x5,x)
    x_2=self.up_conv_2(torch.cat([x,y],1))

    x=self.up_trans_3(x_2)
    y=crop_img(x3,x)
    x_3=self.up_conv_3(torch.cat([x,y],1))

    x=self.up_trans_4(x_3)
    y=crop_img(x1,x)
    x4=self.up_conv_4(torch.cat([x,y],1))
    out=self.out(x4)
    return out


In [33]:
if __name__ == "__main__":
  image=torch.rand((1,1,572,572))
  model=UNet()
  print(model(image))


torch.Size([1, 64, 568, 568])
torch.Size([1, 128, 140, 140])
torch.Size([1, 1024, 28, 28])
out torch.Size([1, 2, 388, 388])
tensor([[[[ 0.0685,  0.0632,  0.0680,  ...,  0.0684,  0.0640,  0.0624],
          [ 0.0693,  0.0669,  0.0651,  ...,  0.0666,  0.0640,  0.0653],
          [ 0.0695,  0.0697,  0.0652,  ...,  0.0658,  0.0686,  0.0702],
          ...,
          [ 0.0709,  0.0665,  0.0688,  ...,  0.0645,  0.0635,  0.0690],
          [ 0.0667,  0.0653,  0.0651,  ...,  0.0663,  0.0646,  0.0673],
          [ 0.0633,  0.0669,  0.0659,  ...,  0.0659,  0.0719,  0.0639]],

         [[-0.0695, -0.0718, -0.0698,  ..., -0.0669, -0.0685, -0.0701],
          [-0.0766, -0.0710, -0.0707,  ..., -0.0679, -0.0693, -0.0718],
          [-0.0696, -0.0707, -0.0732,  ..., -0.0718, -0.0755, -0.0718],
          ...,
          [-0.0690, -0.0702, -0.0741,  ..., -0.0701, -0.0710, -0.0708],
          [-0.0668, -0.0748, -0.0695,  ..., -0.0730, -0.0711, -0.0697],
          [-0.0701, -0.0700, -0.0660,  ..., -0.0717,