In [2]:
import torch
import torch.nn as nn

In [24]:
def double_conv(in_channel, out_channel):
  conv = nn.Sequential(
      nn.Conv2d(in_channel,out_channel,kernel_size = 3),
      nn.ReLU(inplace=True),
      nn.Conv2d(out_channel,out_channel,kernel_size = 3),
      nn.ReLU(inplace=True),
  )
  return conv

In [4]:
def crop_tensor(tensor,target_tensor):
  target_size = target_tensor.size()[2]
  tensor_size = tensor.size()[2]
  delta = tensor_size - target_size
  delta = delta // 2
  return tensor[:,:,delta:tensor_size-delta,delta:tensor_size-delta]

In [30]:
class UNet(nn.Module):
  def __init__(self):
    super(UNet,self).__init__()

    # Forwoard Layers
    self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2,stride=2)
    self.down_conv_1 = double_conv(1,64)
    self.down_conv_2 = double_conv(64,128)
    self.down_conv_3 = double_conv(128,256)
    self.down_conv_4 = double_conv(256,512)
    self.down_conv_5 = double_conv(512,1024)
    
    # Upscaling layers
    self.up_trans_1 = nn.ConvTranspose2d(in_channels=1024,
                                         out_channels=512,
                                         kernel_size=2,
                                         stride=2)
    self.up_conv_1 = double_conv(1024,512)

    self.up_trans_2 = nn.ConvTranspose2d(in_channels=512,
                                         out_channels=256,
                                         kernel_size=2,
                                         stride=2)
    self.up_conv_2 = double_conv(512,256)

    self.up_trans_3 = nn.ConvTranspose2d(in_channels=256,
                                         out_channels=128,
                                         kernel_size=2,
                                         stride=2)
    self.up_conv_3 = double_conv(256,128)

    self.up_trans_4 = nn.ConvTranspose2d(in_channels=128,
                                         out_channels=64,
                                         kernel_size=2,
                                         stride=2)
    self.up_conv_4 = double_conv(128,64)
    
    self.out = nn.Conv2d(
        in_channels=64,
        out_channels=2,
        kernel_size=1
    )

  def forward(self,image):
    # image encoder
    x1 =  self.down_conv_1(image)
    x2 =  self.max_pool_2x2(x1)
    x3 =  self.down_conv_2(x2)
    x4 =  self.max_pool_2x2(x3)
    x5 =  self.down_conv_3(x4)
    x6 =  self.max_pool_2x2(x5)
    x7 =  self.down_conv_4(x6)
    x8 =  self.max_pool_2x2(x7)
    x9 =  self.down_conv_5(x8)

    #image Decorder
    x = self.up_trans_1(x9)
    y = crop_tensor(x7,x)
    x = self.up_conv_1(torch.cat([x,y],1))

    x = self.up_trans_2(x)
    y = crop_tensor(x5,x)
    x = self.up_conv_2(torch.cat([x,y],1))

    x = self.up_trans_3(x)
    y = crop_tensor(x3,x)
    x = self.up_conv_3(torch.cat([x,y],1))

    x = self.up_trans_4(x)
    y = crop_tensor(x1,x)
    x = self.up_conv_4(torch.cat([x,y],1))
    x = self.out(x)
    return x

In [31]:
image = torch.rand((1,1,572,572))
model = UNet()

# print(model(image))
model.forward(image)

torch.Size([1, 64, 388, 388])


tensor([[[[0.1071, 0.1079, 0.1105,  ..., 0.1100, 0.1042, 0.1101],
          [0.1128, 0.1120, 0.1109,  ..., 0.1108, 0.1108, 0.1103],
          [0.1121, 0.1090, 0.1072,  ..., 0.1134, 0.1115, 0.1098],
          ...,
          [0.1102, 0.1111, 0.1052,  ..., 0.1099, 0.1118, 0.1110],
          [0.1080, 0.1072, 0.1077,  ..., 0.1082, 0.1094, 0.1138],
          [0.1094, 0.1077, 0.1086,  ..., 0.1066, 0.1114, 0.1131]],

         [[0.0645, 0.0617, 0.0599,  ..., 0.0628, 0.0621, 0.0550],
          [0.0587, 0.0573, 0.0568,  ..., 0.0600, 0.0569, 0.0558],
          [0.0578, 0.0594, 0.0628,  ..., 0.0544, 0.0553, 0.0580],
          ...,
          [0.0568, 0.0550, 0.0594,  ..., 0.0607, 0.0581, 0.0617],
          [0.0547, 0.0572, 0.0604,  ..., 0.0593, 0.0583, 0.0604],
          [0.0586, 0.0644, 0.0641,  ..., 0.0629, 0.0610, 0.0614]]]],
       grad_fn=<ThnnConv2DBackward0>)