UNet architecture (uses the DoubleConv, Down, Up and OutConv building blocks defined in the cell further below):

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as Func

class UNet(nn.Module):
    """U-Net: four downsampling stages, four upsampling stages, skip links.

    Args:
        input_channels: number of channels in the input tensor.
        output_classes: number of channels in the returned logits.
        use_bilinear: when True each ``Up`` block upsamples with bilinear
            interpolation; when False it uses a learned ``ConvTranspose2d``.

    Each ``Up`` block pads its upsampled input to the size of the matching
    skip connection, so the logits have the spatial size of the input.
    """

    def __init__(self, input_channels, output_classes, use_bilinear=True):
        super().__init__()
        self.input_channels = input_channels
        self.output_classes = output_classes
        self.use_bilinear = use_bilinear

        # Feature widths of the five resolution levels.
        c1, c2, c3, c4, c5 = 64, 128, 256, 512, 1024
        # Bilinear upsampling keeps the channel count of its input, so the
        # bottleneck and the first three decoders emit half as many channels;
        # this keeps (upsampled + skip) equal to each Up's channels_in.
        factor = 2 if use_bilinear else 1

        # Contracting path (attribute names/order fix the state_dict layout).
        self.initial_conv = DoubleConv(input_channels, c1)
        self.encoder1 = Down(c1, c2)
        self.encoder2 = Down(c2, c3)
        self.encoder3 = Down(c3, c4)
        self.encoder4 = Down(c4, c5 // factor)

        # Expanding path.
        self.decoder1 = Up(c5, c4 // factor, use_bilinear)
        self.decoder2 = Up(c4, c3 // factor, use_bilinear)
        self.decoder3 = Up(c3, c2 // factor, use_bilinear)
        self.decoder4 = Up(c2, c1, use_bilinear)

        # 1x1 projection to per-class logits.
        self.final_conv = OutConv(c1, output_classes)

    def forward(self, input_tensor):
        # Keep every encoder output except the bottleneck for the skip links.
        skips = [self.initial_conv(input_tensor)]
        for encoder in (self.encoder1, self.encoder2, self.encoder3):
            skips.append(encoder(skips[-1]))

        features = self.encoder4(skips[-1])

        # Deepest decoder pairs with the deepest skip, and so on upward.
        decoders = (self.decoder1, self.decoder2, self.decoder3, self.decoder4)
        for decoder, skip in zip(decoders, reversed(skips)):
            features = decoder(features, skip)

        return self.final_conv(features)

if __name__ == '__main__':
    # Smoke test: run one random batch through the network and print the
    # output shape.
    # NOTE(review): in notebook order this cell executes before the cell that
    # defines DoubleConv/Down/Up/OutConv; run that cell first.
    # Fall back to CPU instead of crashing when no GPU is present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    network = UNet(39, 39, True).to(device)
    # torch.Tensor(*sizes) returns uninitialised memory (it can hold NaN/inf);
    # use well-defined random input instead.
    image_tensor = torch.randn(32, 39, 64, 64, device=device)
    # Only the shape is inspected, so skip building the autograd graph.
    with torch.no_grad():
        result_tensor = network(image_tensor)
    print(result_tensor.shape)


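Convolutional building blocks used by UNet (DoubleConv, Down, Up, OutConv). Run this cell before the UNet cell's smoke test, which instantiates these classes.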

In [None]:
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """Two rounds of 3x3 Conv2d -> BatchNorm2d -> ReLU.

    Args:
        channels_in: channels of the input tensor.
        channels_out: channels produced by the second convolution.
        channels_mid: channels produced by the first convolution; any falsy
            value (None or 0) falls back to ``channels_out``.

    ``padding=1`` with a 3x3 kernel keeps the spatial size unchanged.
    """

    def __init__(self, channels_in, channels_out, channels_mid=None):
        super().__init__()
        hidden = channels_mid or channels_out

        # Build both conv/bn/relu triples in order so the Sequential indices
        # (and therefore the state_dict keys) match 0..5.
        layers = []
        for c_in, c_out in ((channels_in, hidden), (hidden, channels_out)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.conv_layers = nn.Sequential(*layers)

    def forward(self, tensor_in):
        output = self.conv_layers(tensor_in)
        return output


class Down(nn.Module):
    """Halve the spatial size with 2x2 max pooling, then apply DoubleConv.

    Args:
        channels_in: channels of the input tensor.
        channels_out: channels produced by the DoubleConv.
    """

    def __init__(self, channels_in, channels_out):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(channels_in, channels_out)
        # Attribute name and element order fix the state_dict keys.
        self.pool_conv = nn.Sequential(pool, convs)

    def forward(self, tensor_in):
        pooled_and_convolved = self.pool_conv(tensor_in)
        return pooled_and_convolved


class Up(nn.Module):
    """Upsample by 2, concatenate with a skip connection, then apply DoubleConv.

    Args:
        channels_in: channels DoubleConv receives after concatenation. With
            ``use_bilinear=False`` this must also be the channel count of the
            tensor being upsampled; ConvTranspose2d reduces it to
            ``channels_in // 2`` before concatenation.
        channels_out: channels produced by the block.
        use_bilinear: if True, upsample with parameter-free bilinear
            interpolation (channel count unchanged) and use
            ``channels_in // 2`` middle channels in DoubleConv; otherwise use
            a learned 2x2, stride-2 ConvTranspose2d.
    """

    def __init__(self, channels_in, channels_out, use_bilinear=True):
        super().__init__()
        if use_bilinear:
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(channels_in, channels_out, channels_in // 2)
        else:
            self.upsample = nn.ConvTranspose2d(channels_in, channels_in // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(channels_in, channels_out)

    def forward(self, tensor_in, tensor_skip):
        """Upsample ``tensor_in``, match it to ``tensor_skip``, fuse, convolve.

        The result has the spatial size of ``tensor_skip``. Channels are
        concatenated in the order (skip, upsampled) along dim 1.
        """
        upsampled = self.upsample(tensor_in)

        # Max pooling floors odd sizes, so the upsampled map can be smaller
        # than the skip; pad it (split as evenly as possible) to match.
        diff_y = tensor_skip.size(2) - upsampled.size(2)
        diff_x = tensor_skip.size(3) - upsampled.size(3)
        # Use nn.functional (reachable through this cell's own `nn` import)
        # rather than an alias that is only imported in another cell.
        upsampled = nn.functional.pad(
            upsampled,
            [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2],
        )

        merged = torch.cat([tensor_skip, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Pointwise (1x1) convolution mapping feature channels to output channels.

    Args:
        channels_in: channels of the incoming feature map.
        channels_out: channels of the result (e.g. one per class).
    """

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.conv = nn.Conv2d(channels_in, channels_out, kernel_size=(1, 1))

    def forward(self, tensor_in):
        projected = self.conv(tensor_in)
        return projected
