Source: https://nn.labml.ai/unet/index.html

In [1]:
import torch
import torchvision.transforms.functional
from torch import nn

In [2]:
class DoubleConvolution(nn.Module):
    """Two 3x3 convolutions, each followed by a ReLU.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved; only the channel count changes from ``in_channels`` to
    ``out_channels`` (the second convolution maps ``out_channels`` to itself).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First conv widens in -> out, second conv keeps out -> out.
        for conv_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.ReLU())
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply conv -> ReLU -> conv -> ReLU to ``x``."""
        features = self.conv(x)
        return features

In [3]:
class DownSample(nn.Module):
    """Contracting-path down-sampling step: 2x2 max pooling.

    ``nn.MaxPool2d``'s stride defaults to its kernel size, so windows do not
    overlap and each spatial dimension is halved (rounded down).
    """

    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return the 2x2 max-pooled ``x``."""
        pooled = self.pool(x)
        return pooled

In [4]:
class UpSample(nn.Module):
    """Expansive-path up-sampling step: a 2x2 transposed convolution, stride 2.

    With kernel 2, stride 2 and no padding, each spatial dimension is doubled
    exactly, while channels change from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=2,
            stride=2,
        )

    def forward(self, x):
        """Return ``x`` up-sampled by the learned transposed convolution."""
        upsampled = self.up(x)
        return upsampled

At every step in the expansive path, the corresponding feature map from the contracting path is concatenated with the current feature map.

In [5]:
class CropAndConcat(nn.Module):
    """Skip connection: center-crop a contracting-path map and concatenate it.

    The feature map saved on the contracting path is cropped around its center
    to the spatial size of the current expansive-path map ``x``, then the two
    are joined along dimension 1 (the channel axis for NCHW tensors).
    """

    def forward(self, x, constracting_x):
        """Return ``x`` concatenated with the center crop of ``constracting_x``.

        Args:
            x: current expansive-path feature map; its last two dimensions are
                taken as the target (height, width). Assumes NCHW layout.
            constracting_x: feature map from the contracting path. It must
                share ``x``'s batch size and be at least as large spatially.
                (The misspelled parameter name is kept so keyword callers
                keep working.)

        Returns:
            ``torch.cat([x, cropped], dim=1)``, whose channel count is the sum
            of both inputs' channels.

        Raises:
            ValueError: if ``constracting_x`` is spatially smaller than ``x``.
        """
        target_h, target_w = x.shape[-2], x.shape[-1]
        src_h, src_w = constracting_x.shape[-2], constracting_x.shape[-1]
        if src_h < target_h or src_w < target_w:
            raise ValueError(
                f"contracting feature map ({src_h}x{src_w}) is smaller than "
                f"the expansive feature map ({target_h}x{target_w})"
            )

        # Same offset rule as torchvision.transforms.functional.center_crop,
        # which the original cell tried (and failed) to call correctly.
        top = int(round((src_h - target_h) / 2.0))
        left = int(round((src_w - target_w) / 2.0))
        cropped = constracting_x[..., top:top + target_h, left:left + target_w]

        # Concatenate the feature maps along the channel dimension.
        return torch.cat([x, cropped], dim=1)

In [6]:
class UNet(nn.Module):
    """U-Net with a contracting path, a bottleneck and an expansive path.

    The contracting path has four ``DoubleConvolution`` + 2x2 max-pool stages
    that widen the channels 64 -> 128 -> 256 -> 512. A 512 -> 1024
    bottleneck follows.

    The expansive path has four stages. Each up-samples, concatenates the
    matching contracting-path map (skip connection) and applies a
    ``DoubleConvolution``.

    A final 1x1 convolution maps 64 channels to ``out_channels``; no output
    activation is applied. When the input height and width are divisible by
    16, the output has the same spatial size as the input. Otherwise the skip
    maps are cropped and the output can be smaller.
    """

    def __init__(self, in_channels, out_channels):
        """Build all layers.

        Args:
            in_channels: number of channels of the input tensor.
            out_channels: number of channels produced by the final 1x1 conv.
        """
        super().__init__()

        # Contracting path: (in, out) channels of each DoubleConvolution.
        down_features = [(in_channels, 64), (64, 128), (128, 256), (256, 512)]

        self.down_conv = nn.ModuleList(
            [DoubleConvolution(n_in, n_out) for n_in, n_out in down_features]
        )
        self.down_sample = nn.ModuleList([DownSample() for _ in down_features])

        # The bottom (bottleneck) of the "U".
        self.middle_conv = DoubleConvolution(512, 1024)

        # Expansive path. Each UpSample halves the channels (e.g. 1024 -> 512).
        # Concatenating the skip map doubles them back (512 + 512 = 1024).
        # So the DoubleConvolution after it uses the same (n_in, n_out) pair.
        up_features = [(1024, 512), (512, 256), (256, 128), (128, 64)]
        self.up_sample = nn.ModuleList(
            [UpSample(n_in, n_out) for n_in, n_out in up_features]
        )

        self.up_conv = nn.ModuleList(
            [DoubleConvolution(n_in, n_out) for n_in, n_out in up_features]
        )

        # One crop-and-concatenate module per skip connection.
        self.skip_connection = nn.ModuleList([CropAndConcat() for _ in up_features])

        # The final 1x1 convolution of the expansive path.
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x`` and return the output feature map.

        Args:
            x: input tensor with ``in_channels`` channels (NCHW assumed, as
                required by the Conv2d layers).

        Returns:
            Tensor with ``out_channels`` channels from ``final_conv``.
        """
        # Outputs of each contracting stage, kept for the skip connections.
        residuals = []

        for conv, pool in zip(self.down_conv, self.down_sample):
            x = conv(x)
            residuals.append(x)
            x = pool(x)

        x = self.middle_conv(x)

        # pop() returns the deepest saved map first. That matches the order
        # in which the expansive path climbs back up.
        for up, skip, conv in zip(self.up_sample, self.skip_connection, self.up_conv):
            x = up(x)
            x = skip(x, residuals.pop())
            x = conv(x)

        return self.final_conv(x)

In [7]:
# Build a U-Net for 3-channel inputs with a 2-channel output.
model = UNet(out_channels=2, in_channels=3)

In [8]:
# Display the module tree; the notebook renders the model's repr as the cell output.
model

UNet(
  (down_conv): ModuleList(
    (0): DoubleConvolution(
      (conv): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (1): DoubleConvolution(
      (conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (2): DoubleConvolution(
      (conv): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (3): DoubleConvolution(
      (conv): Sequential(
        (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        

In [9]:
# Random float32 input batch: 20 samples, 3 channels, 64x64 (NCHW, matching in_channels=3).
x = torch.randn(20, 3, 64, 64, dtype=torch.float32)

In [10]:
# Show the input's dimensions (torch.Size).
x.size()

torch.Size([20, 3, 64, 64])

In [None]:
from torchsummary import summary

In [None]:
# torchsummary's input_size excludes the batch dimension (it adds its own).
# Its default device is "cuda", which fails for this CPU-resident model.
# batch_size only changes how the batch dimension is displayed.
summary(model, (3, 64, 64), batch_size=20, device="cpu")