In [4]:
import torch
import torch.nn as nn

""" Convolutional block:
    It consists of two 3x3 convolutional layers, each followed by batch normalization and a ReLU activation.
"""
class conv_block(nn.Module):
    """Two (3x3 convolution -> batch norm -> ReLU) stages.

    Both convolutions use ``padding=1`` with the default stride, so height
    and width are preserved. The first stage maps ``in_c`` channels to
    ``out_c``; the second keeps ``out_c``.

    Args:
        in_c: Number of channels of the incoming tensor.
        out_c: Number of channels produced by both convolutions.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        # Registration order (conv1, bn1, conv2, bn2, relu) is kept so the
        # random parameter initialisation order and the state_dict keys stay
        # the same as before.
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_c)
        # A single parameter-free ReLU module is reused by both stages.
        self.relu = nn.ReLU()

    def forward(self, inputs):
        out = inputs
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(norm(conv(out)))
        return out

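""" Encoder block:
    It consists of a conv_block followed by 2x2 max pooling. It returns both the conv_block output
    (used later as the skip connection) and the pooled output.
    In build_unet, the number of filters doubles from one block to the next, and the height and width are halved by every block.
"""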
class encoder_block(nn.Module):
    """Feature extraction followed by 2x2 downsampling.

    ``forward`` returns a pair ``(features, pooled)``:

    - ``features`` is the ``conv_block`` output, which the decoder consumes
      as a skip connection.
    - ``pooled`` is ``features`` after 2x2 max pooling, which halves height
      and width (flooring odd sizes).

    Args:
        in_c: Number of channels of the incoming tensor.
        out_c: Number of channels produced by the internal ``conv_block``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        features = self.conv(inputs)
        return features, self.pool(features)

""" Decoder block:
    The decoder block begins with a transposed convolution, followed by a concatenation with the skip
    connection from the corresponding encoder block. Next comes the conv_block.
    Here the number of filters is halved and the height and width are doubled.
"""
class decoder_block(nn.Module):
    """Upsample, merge with an encoder skip connection, then refine.

    A 2x2, stride-2 transposed convolution maps ``in_c`` channels to
    ``out_c`` channels and doubles the spatial size. The result is
    concatenated with ``skip`` along the channel dimension, giving
    ``2 * out_c`` channels when ``skip`` has ``out_c`` channels. It is then
    passed through a ``conv_block`` that outputs ``out_c`` channels.

    Args:
        in_c: Number of channels of the decoder input.
        out_c: Number of output channels. ``skip`` must also carry ``out_c``
            channels for the concatenation to match the ``conv_block``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = conv_block(out_c + out_c, out_c)

    def forward(self, inputs, skip):
        """Decode ``inputs`` with the help of the encoder feature map ``skip``.

        Args:
            inputs: Decoder input, assumed (N, in_c, H, W).
            skip: Encoder feature map, assumed (N, out_c, H', W'), where H'
                and W' are normally 2*H and 2*W.

        Returns:
            Tensor with the same spatial size as ``skip`` and ``out_c``
            channels.
        """
        x = self.up(inputs)

        # When an encoder stage pooled an odd-sized map, flooring makes the
        # upsampled tensor smaller than the skip connection and torch.cat
        # would fail. Pad x with zeros (negative amounts crop) so its spatial
        # size matches skip exactly; nothing changes when the sizes agree.
        diff_h = skip.size(-2) - x.size(-2)
        diff_w = skip.size(-1) - x.size(-1)
        if diff_h or diff_w:
            x = nn.functional.pad(
                x,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )

        x = torch.cat([x, skip], dim=1)
        x = self.conv(x)

        return x


class build_unet(nn.Module):
    """U-Net producing a per-pixel score map.

    Architecture:

    - Four encoder stages (``base_filters``, 2x, 4x and 8x channels), each
      halving H and W with 2x2 max pooling.
    - A bottleneck ``conv_block`` that widens to 16x ``base_filters``.
    - Four decoder stages that upsample back, each concatenating the
      matching encoder output as a skip connection.
    - A final 1x1 convolution that maps to ``out_channels``.

    No activation is applied to the output. Callers presumably apply a
    sigmoid/softmax or use a with-logits loss — confirm at the call site.

    Each pooling floors odd sizes, so inputs whose height and width are
    divisible by 16 keep every upsampled decoder tensor the same size as its
    skip connection.

    Args:
        in_channels: Channels of the input image (default 3, as before).
        out_channels: Channels of the output score map (default 1, as before).
        base_filters: Filters in the first encoder stage (default 64, as
            before). Each deeper stage doubles it.
    """

    def __init__(self, in_channels=3, out_channels=1, base_filters=64):
        super().__init__()
        f = base_filters

        # Encoder: channels double at every stage.
        self.e1 = encoder_block(in_channels, f)
        self.e2 = encoder_block(f, f * 2)
        self.e3 = encoder_block(f * 2, f * 4)
        self.e4 = encoder_block(f * 4, f * 8)

        # Bottleneck
        self.b = conv_block(f * 8, f * 16)

        # Decoder: channels halve at every stage.
        self.d1 = decoder_block(f * 16, f * 8)
        self.d2 = decoder_block(f * 8, f * 4)
        self.d3 = decoder_block(f * 4, f * 2)
        self.d4 = decoder_block(f * 2, f)

        # Classifier: 1x1 convolution producing one score map per output channel.
        self.outputs = nn.Conv2d(f, out_channels, kernel_size=1, padding=0)

    def forward(self, inputs):
        """Run the U-Net on ``inputs``.

        Args:
            inputs: Image batch, assumed (N, in_channels, H, W).

        Returns:
            Score map from the 1x1 classifier. For H and W divisible by 16,
            its shape is (N, out_channels, H, W).
        """
        # Encoder: keep each stage's pre-pool features for the skip connections.
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        # Bottleneck
        b = self.b(p4)

        # Decoder: the deepest skip connection is used first.
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        # Classifier
        return self.outputs(d4)

if __name__ == "__main__":
    # Smoke test: push a random batch through the model and print the output
    # shape (expected: same height/width as the input, one channel).
    inputs = torch.randn((2, 3, 512, 512))
    model = build_unet()
    # Only the shape is inspected, so skip autograd bookkeeping; otherwise
    # every intermediate activation would be kept for a backward pass.
    with torch.no_grad():
        y = model(inputs)
    print(y.shape)

torch.Size([2, 1, 512, 512])




1. **Importing Libraries**: The code begins by importing the necessary libraries: `torch` for PyTorch and `torch.nn` (as `nn`) for defining neural network layers.

2. **conv_block Class**: This class defines a convolutional block, which is the basic building block of the U-Net. It contains two 3x3 convolutional layers, each followed by batch normalization and a ReLU activation. The purpose of this block is to learn features from the input.

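3. **encoder_block Class**: An encoder block consists of a conv_block followed by 2x2 max pooling. Max pooling halves the height and width of the feature maps; the increase in the number of filters comes from the conv_block, whose output channel count doubles from one encoder block to the next in build_unet. The block returns both the conv_block output (used as a skip connection) and the pooled output, and is responsible for downsampling the input.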

4. **decoder_block Class**: A decoder block is used in the decoding part of the U-Net. It begins with a 2x2, stride-2 transposed convolution (sometimes called deconvolution), which doubles the height and width of the feature maps. This is followed by concatenation with the skip connection from the corresponding encoder block. Then another conv_block is applied to learn features. This block is responsible for upsampling the feature maps.

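5. **build_unet Class**: This class defines the complete U-Net architecture. It includes the encoder, the bottleneck (central block), and the decoder. The number of filters doubles with each encoder block (64, 128, 256, 512, then 1024 in the bottleneck) and is halved with each decoder block. The classifier at the end is a 1x1 convolution that maps the 64 feature channels to a single output channel. No activation is applied, so the output consists of unnormalized scores.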

6. **forward Method in build_unet Class**: This method specifies the forward pass of the U-Net. It first passes the input through the encoder, then the bottleneck, and finally the decoder. Skip connections concatenate the feature maps from each encoder block with the corresponding decoder block. The classifier layer produces the final output.

7. **`if __name__ == "__main__"` Block**: This block provides an example of how to use the defined classes. It creates a random input batch of shape (2, 3, 512, 512), passes it through the U-Net model, and prints the shape of the output.

