In [1]:
import torch 
import torch.nn as nn

In [9]:
def double_conv(in_c , out_c):
  conv = nn.Sequential(
      nn.Conv2d(in_c , out_c , kernel_size=3),
      nn.ReLU(inplace=True),
      nn.Conv2d(out_c, out_c , kernel_size=3),
      nn.ReLU(inplace=True),
  )
  return conv

In [18]:
def crop_img(tensor , target_tensor):
  target_size = target_tensor.size()[2]
  tensor_size = tensor.size()[2]
  #cropping 
  delta = tensor_size-target_size
  delta = delta//2
  return tensor[:,:,delta:tensor_size-delta , delta:tensor_size-delta ]


In [33]:
class UNet(nn.Module):
  def __init__(self):
    super(UNet , self).__init__()

    self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2 , stride = 2)

    #down convolution , left half of the U (Contracting path)
    self.down_conv1 = double_conv(1,64)
    self.down_conv2 = double_conv(64,128)
    self.down_conv3 = double_conv(128,256)
    self.down_conv4 = double_conv(256,512)
    self.down_conv5 = double_conv(512,1024)

    #up convolution , right half of the U (Expansive path)
    self.up_trans_1 = nn.ConvTranspose2d(
        in_channels = 1024,
        out_channels = 512,
        kernel_size=2,
        stride=2)
    self.up_conv1 = double_conv(1024,512)

    
    self.up_trans_2 = nn.ConvTranspose2d(
        in_channels = 512,
        out_channels = 256,
        kernel_size=2,
        stride=2)
    self.up_conv2 = double_conv(512,256)


    self.up_trans_3 = nn.ConvTranspose2d(
        in_channels = 256,
        out_channels = 128,
        kernel_size=2,
        stride=2)
    self.up_conv3 = double_conv(256,128)


    self.up_trans_4 = nn.ConvTranspose2d(
        in_channels = 128,
        out_channels = 64,
        kernel_size=2,
        stride=2)
    self.up_conv4 = double_conv(128,64)

    self.out = nn.Conv2d(
        in_channels = 64,
        out_channels = 2,
        kernel_size = 1
    )

  def forward(self , image):
    #encoder 
    x1 = self.down_conv1(image)
    x2 = self.max_pool_2x2(x1)
    x3 = self.down_conv2(x2)
    x4 = self.max_pool_2x2(x3)
    x5 = self.down_conv3(x4)
    x6 = self.max_pool_2x2(x5)
    x7 = self.down_conv4(x6)
    x8 = self.max_pool_2x2(x7)
    x9 = self.down_conv5(x8)

    #decoder
    x = self.up_trans_1(x9)
    y = crop_img(x7,x)
    x = self.up_conv1(torch.cat([x,y],1))

    x = self.up_trans_2(x)
    y = crop_img(x5,x)
    x = self.up_conv2(torch.cat([x,y],1))

    x = self.up_trans_3(x)
    y = crop_img(x3,x)
    x = self.up_conv3(torch.cat([x,y],1))

    x = self.up_trans_4(x)
    y = crop_img(x1,x)
    x = self.up_conv4(torch.cat([x,y],1))

    x = self.out(x)
    print(x.size())
    return x

In [34]:
if __name__ == "__main__":
  image = torch.rand((1,1,572,572)) #(batch size , chanels , height , width)
  model = UNet()
  print(model(image))

torch.Size([1, 2, 388, 388])
tensor([[[[0.1205, 0.1267, 0.1166,  ..., 0.1255, 0.1269, 0.1298],
          [0.1237, 0.1200, 0.1170,  ..., 0.1249, 0.1235, 0.1252],
          [0.1220, 0.1248, 0.1247,  ..., 0.1263, 0.1268, 0.1224],
          ...,
          [0.1268, 0.1227, 0.1234,  ..., 0.1274, 0.1180, 0.1239],
          [0.1209, 0.1234, 0.1272,  ..., 0.1204, 0.1250, 0.1246],
          [0.1206, 0.1210, 0.1236,  ..., 0.1249, 0.1220, 0.1282]],

         [[0.0172, 0.0188, 0.0176,  ..., 0.0208, 0.0199, 0.0238],
          [0.0193, 0.0205, 0.0136,  ..., 0.0231, 0.0202, 0.0238],
          [0.0193, 0.0162, 0.0154,  ..., 0.0194, 0.0189, 0.0254],
          ...,
          [0.0173, 0.0219, 0.0187,  ..., 0.0171, 0.0159, 0.0192],
          [0.0272, 0.0209, 0.0133,  ..., 0.0191, 0.0212, 0.0252],
          [0.0212, 0.0163, 0.0221,  ..., 0.0216, 0.0219, 0.0215]]]],
       grad_fn=<ThnnConv2DBackward0>)
