In [1]:
import  torch
import  torch.nn as  nn
import  torchvision.transforms.functional as func

In [70]:
class UNet(nn.Module):
    def __init__(self, img_channel, out_dim):
        super(UNet, self).__init__()
        self.img_channel = img_channel
        self.out_dim = out_dim
        self.config = [64, 128, 256, 512]


        self.maxpool_2x2 = nn.MaxPool2d(kernel_size=(2,2), stride=2)
        self.down_block = self._down_block()
        self.final_conv1x1 = nn.Conv2d(self.config[0], self.out_dim, (3,3), (1,1), 1)

        self.bottle_neck = self._conv_block(512, 1024)
        # print(self.down_block)
        self.up_block = self._up_block()
        # print(self.up_block)


    def _up_block(self):
        layers = []
        # backward iter the self.config
        for i in range(len(self.config) - 1, -1, -1):
            in_channel = self.config[i]
            layers.append(nn.ConvTranspose2d(in_channel * 2, in_channel, (4,4), (2,2), (1,1)))
            layers.append(self._conv_block(in_channel * 2, in_channel))
        return nn.Sequential(*layers)

    def _down_block(self):
        in_channel = self.img_channel
        layers = []
        for i in range(len(self.config)):
            layers.append(self._conv_block(in_channel, self.config[i]))
            layers.append(self.maxpool_2x2)
            in_channel = self.config[i]
        return nn.Sequential(*layers)


    def _conv_block(self, inchannel, outchannel, k_s=(3,3), s=1, p=0):
        return  nn.Sequential(
            nn.Conv2d(inchannel, outchannel, k_s, s, p),
            nn.ReLU(),
            nn.Conv2d(outchannel, outchannel, k_s, s, p),
            nn.ReLU()
        )

    def forward(self, x):
        down_data = []
        for layer in self.down_block:
            x = layer(x)
            print(x.shape)
            if isinstance(layer, nn.Sequential):
                down_data.append(x)

        x = self.bottle_neck(x)

        back_count = len(down_data) - 1
        for layer in self.up_block:
            if isinstance(layer, nn.Sequential):
                # if is Conv_block
                # resize accord down_data to x's size
                # then concat resized tensor and x
                temp = func.resize(down_data[back_count], x.shape[2:])
                x = torch.cat((x, temp), dim=1)
                x = layer(x)
                back_count -= 1
            else:
                # if is convtranspose, go through it
                x = layer(x)
        x = self.final_conv1x1(x)
        return x

In [71]:
model = UNet(3, 2)
x = torch.randn(8, 3, 572, 572)
model(x).shape

torch.Size([8, 64, 568, 568])
torch.Size([8, 64, 284, 284])
torch.Size([8, 128, 280, 280])
torch.Size([8, 128, 140, 140])
torch.Size([8, 256, 136, 136])
torch.Size([8, 256, 68, 68])
torch.Size([8, 512, 64, 64])
torch.Size([8, 512, 32, 32])


torch.Size([8, 2, 388, 388])

In [2]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
DEVICE

'cpu'