In [3]:
import torch
import torch.nn as nn

In [18]:
def double_conv(in_c,out_c):
    conv = nn.Sequential(
        nn.Conv2d(in_c,out_c,kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c,out_c,kernel_size=3),
        nn.ReLU(inplace=True)
    )
    
    return conv

def crop_img(tensor, target_tensor):
    target_size = target_tensor.size()[2]
    tensor_size = tensor.size()[2]
    delta = tensor_size - target_size
    delta = delta // 2
    return tensor[:,:,delta:tensor_size - delta, delta:tensor_size - delta]
    



class UNet(nn.Module):
    
    def __init__(self):
        super(UNet,self).__init__()
        # encoder
        self.max_pool_2 = nn.MaxPool2d(kernel_size=2,stride=2)
        self.down_conv_1 = double_conv(1,64)
        self.down_conv_2 = double_conv(64,128)
        self.down_conv_3 = double_conv(128,256)
        self.down_conv_4 = double_conv(256,512)
        self.down_conv_5 = double_conv(512,1024)
        
        #decoder 
        self.up_trans_1 = nn.ConvTranspose2d(in_channels=1024,out_channels=512,kernel_size=2,stride=2)
        self.up_conv_1 = double_conv(1024,512)
        self.up_trans_2 = nn.ConvTranspose2d(in_channels=512,out_channels=256,kernel_size=2,stride=2)
        self.up_conv_2 = double_conv(512,256)
        self.up_trans_3 = nn.ConvTranspose2d(in_channels=256,out_channels=128,kernel_size=2,stride=2)
        self.up_conv_3 = double_conv(256,128)
        self.up_trans_4 = nn.ConvTranspose2d(in_channels=128,out_channels=64,kernel_size=2,stride=2)
        self.up_conv_4 = double_conv(128,64)
        self.out = nn.Conv2d(in_channels=64,out_channels=2,kernel_size=1)



    def forward(self,image):
        # encoder 
        x1 = self.down_conv_1(image) #
        x2 = self.max_pool_2(x1)
        x3 = self.down_conv_2(x2) # 
        x4 = self.max_pool_2(x3)
        x5 = self.down_conv_3(x4) # 
        x6 = self.max_pool_2(x5)
        x7 = self.down_conv_4(x6) # 
        x8 = self.max_pool_2(x7) 
        x9 = self.down_conv_5(x8) # 
        
        x = self.up_trans_1(x9)
        y = crop_img(x7,x)
        x = self.up_conv_1(torch.cat([x,y],1))

        x = self.up_trans_2(x)
        y = crop_img(x5,x)
        x = self.up_conv_2(torch.cat([x,y],1))

        x = self.up_trans_3(x)
        y = crop_img(x3,x)
        x = self.up_conv_3(torch.cat([x,y],1))

        x = self.up_trans_4(x)
        y = crop_img(x1,x)
        x = self.up_conv_4(torch.cat([x,y],1))
        
        x = self.out(x)
        print(x.size())
        return x
    

image = torch.rand((1,1,572,572))
model = UNet()
print(model(image))
    

torch.Size([1, 2, 388, 388])
tensor([[[[ 0.0689,  0.0713,  0.0713,  ...,  0.0740,  0.0676,  0.0647],
          [ 0.0697,  0.0710,  0.0721,  ...,  0.0656,  0.0722,  0.0654],
          [ 0.0716,  0.0629,  0.0675,  ...,  0.0755,  0.0662,  0.0702],
          ...,
          [ 0.0650,  0.0674,  0.0686,  ...,  0.0706,  0.0666,  0.0743],
          [ 0.0750,  0.0657,  0.0680,  ...,  0.0691,  0.0713,  0.0701],
          [ 0.0694,  0.0733,  0.0671,  ...,  0.0674,  0.0710,  0.0644]],

         [[-0.0279, -0.0254, -0.0221,  ..., -0.0255, -0.0251, -0.0267],
          [-0.0203, -0.0255, -0.0228,  ..., -0.0257, -0.0231, -0.0233],
          [-0.0209, -0.0317, -0.0203,  ..., -0.0226, -0.0298, -0.0222],
          ...,
          [-0.0243, -0.0253, -0.0236,  ..., -0.0215, -0.0259, -0.0246],
          [-0.0264, -0.0290, -0.0234,  ..., -0.0235, -0.0221, -0.0226],
          [-0.0232, -0.0183, -0.0253,  ..., -0.0236, -0.0209, -0.0262]]]],
       grad_fn=<ConvolutionBackward0>)
