<a href="https://colab.research.google.com/github/KeisukeShimokawa/papers-challenge/blob/master/src/cv/UNet/notebooks/UNet_model_forward.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn

In [2]:
def double_conv(in_c, out_c):
    conv = nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=0),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=0),
        nn.ReLU(inplace=True)
    )
    return conv

In [13]:
def crop_img(tensor, target_tensor):
    target_size = target_tensor.shape[2]
    tensor_size = tensor.shape[2]

    delta = tensor_size - target_size
    delta = delta // 2

    return tensor[:, :, delta:tensor_size-delta, delta:tensor_size-delta]

In [30]:
class UNet(nn.Module):

    def __init__(self):
        super(UNet, self).__init__()

        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=(2, 2))
        self.down_conv1 = double_conv(  1,   64)
        self.down_conv2 = double_conv( 64,  128)
        self.down_conv3 = double_conv(128,  256)
        self.down_conv4 = double_conv(256,  512)
        self.down_conv5 = double_conv(512, 1024)

        self.up_trans_1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up_conv_1 = double_conv(1024, 512)
        self.up_trans_2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_conv_2 = double_conv(512, 256)
        self.up_trans_3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_conv_3 = double_conv(256, 128)
        self.up_trans_4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv_4 = double_conv(128, 64)

        self.out = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, image):
        # bs, c, h, w
        # encoder
        x1 = self.down_conv1(image) # 
        # print(x1.shape)
        x2 = self.max_pool_2x2(x1)
        x3 = self.down_conv2(x2)    # 
        x4 = self.max_pool_2x2(x3)
        x5 = self.down_conv3(x4)    # 
        x6 = self.max_pool_2x2(x5)
        x7 = self.down_conv4(x6)    # 
        x8 = self.max_pool_2x2(x7)
        x9 = self.down_conv5(x8)
        # print(x9.shape)

        # decoder
        x = self.up_trans_1(x9)
        y = crop_img(x7, x)
        x = self.up_conv_1(torch.cat([x, y], dim=1))

        x = self.up_trans_2(x)
        y = crop_img(x5, x)
        x = self.up_conv_2(torch.cat([x, y], dim=1))

        x = self.up_trans_3(x)
        y = crop_img(x3, x)
        x = self.up_conv_3(torch.cat([x, y], dim=1))

        x = self.up_trans_4(x)
        y = crop_img(x1, x)
        x = self.up_conv_4(torch.cat([x, y], dim=1))

        x = self.out(x)
        print(x.shape)

In [31]:
image = torch.rand((1, 1, 572, 572))
model = UNet()
print(model(image))

torch.Size([1, 1, 388, 388])
None
