In [1]:
import torch
import torch.nn as nn

In [2]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_conv_1 = self.double_conv(1,64)
        self.down_conv_2 = self.double_conv(64,128)
        self.down_conv_3 = self.double_conv(128,256)
        self.down_conv_4 = self.double_conv(256,512)
        self.down_conv_5 = self.double_conv(512,1024)
        
        self.up_trans_1 = nn.ConvTranspose2d(in_channels=1024, 
                                             out_channels=512,
                                             kernel_size=2,
                                             stride=2)
        
        self.up_conv_1 = self.double_conv(1024,512)
        
        self.up_trans_2 = nn.ConvTranspose2d(in_channels=512, 
                                             out_channels=256,
                                             kernel_size=2,
                                             stride=2)
        
        self.up_conv_2 = self.double_conv(512,256)
        
        self.up_trans_3 = nn.ConvTranspose2d(in_channels=256, 
                                             out_channels=128,
                                             kernel_size=2,
                                             stride=2)
        
        self.up_conv_3 = self.double_conv(256,128)
        
        self.up_trans_4 = nn.ConvTranspose2d(in_channels=128, 
                                             out_channels=64,
                                             kernel_size=2,
                                             stride=2)
        
        self.up_conv_4 = self.double_conv(128,64)
        
        self.out = nn.Conv2d(in_channels=64,
                            out_channels=2,
                            kernel_size=1)
    
    def forward(self, image):
        #bs, c, h, w
        #encoder
        x1 = self.down_conv_1(image)
        x2 = self.max_pool_2x2(x1)
        x3 = self.down_conv_2(x2)
        x4 = self.max_pool_2x2(x3)
        x5 = self.down_conv_3(x4)
        x6 = self.max_pool_2x2(x5)
        x7 = self.down_conv_4(x6)
        x8 = self.max_pool_2x2(x7)
        x9 = self.down_conv_5(x8)
        
        #decoder
        x = self.up_trans_1(x9)
        y = self.crop_img(x7,x)
        x = self.up_conv_1(torch.cat([x,y], 1))

        x = self.up_trans_2(x)
        y = self.crop_img(x5,x)
        x = self.up_conv_2(torch.cat([x,y], 1))
        
        x = self.up_trans_3(x)
        y = self.crop_img(x3,x)
        x = self.up_conv_3(torch.cat([x,y], 1))
        
        x = self.up_trans_4(x)
        y = self.crop_img(x1,x)
        x = self.up_conv_4(torch.cat([x,y], 1))
        
        x = self.out(x)
        return x


    def double_conv(self, in_c,out_c):
        conv = nn.Sequential(nn.Conv2d(in_c,out_c,kernel_size=3),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(out_c,out_c,kernel_size=3),
                        nn.ReLU(inplace=True))
        return conv
    
    def crop_img(self, tensor, target_tensor):
        target_size = target_tensor.size()[2]
        tensor_size = tensor.size()[2]
        delta = tensor_size - target_size
        delta = delta // 2
        return tensor[:, :, delta:tensor_size-delta, delta:tensor_size-delta]

In [12]:
if __name__ =="__main__":
    image = torch.rand((3,1,572,572))
    print(image)
    model = UNet()
    print(model(image))

tensor([[[[0.7145, 0.3768, 0.4524,  ..., 0.3801, 0.3851, 0.9306],
          [0.3022, 0.3406, 0.6886,  ..., 0.9732, 0.8838, 0.4303],
          [0.6610, 0.9261, 0.5715,  ..., 0.1538, 0.2843, 0.8623],
          ...,
          [0.9821, 0.7599, 0.1031,  ..., 0.2992, 0.2371, 0.9410],
          [0.4346, 0.3866, 0.8038,  ..., 0.1587, 0.7556, 0.6859],
          [0.8826, 0.8866, 0.8277,  ..., 0.7483, 0.1735, 0.1148]]],


        [[[0.0988, 0.6071, 0.9464,  ..., 0.5317, 0.5102, 0.0210],
          [0.3567, 0.0619, 0.7091,  ..., 0.1209, 0.0633, 0.2986],
          [0.2093, 0.9127, 0.0765,  ..., 0.5565, 0.0369, 0.6273],
          ...,
          [0.0506, 0.9251, 0.7480,  ..., 0.9654, 0.6507, 0.6901],
          [0.7092, 0.6830, 0.9533,  ..., 0.9709, 0.5614, 0.5447],
          [0.6402, 0.8867, 0.2112,  ..., 0.6493, 0.4130, 0.4670]]],


        [[[0.1370, 0.8204, 0.3310,  ..., 0.8133, 0.4507, 0.6091],
          [0.5000, 0.7159, 0.4919,  ..., 0.5215, 0.0480, 0.5203],
          [0.0095, 0.0808, 0.7193,  ..