[![deep-learning-notes](https://github.com/semilleroCV/deep-learning-notes/raw/main/assets/banner-notebook.png)](https://github.com/semilleroCV/deep-learning-notes)

# <font color='#4C5FDA'>**U-Net from scratch**</font> <a name="tema1"></a>

In [1]:
#@title **Import required libraries**

# Pytorch essentials
import torch # 2.2.1
import torch.nn as nn


In [2]:
print(torch.__version__)

2.2.1


## <font color='#ECA702'>**U-net architecture**</font>

The U-Net consists of a contracting
path (left side) and an expansive path (right side). The contracting path follows
the typical architecture of a convolutional network. It consists of the repeated
application of two 3x3 convolutions (unpadded convolutions), each followed by
a rectified linear unit (ReLU) and a 2x2 max pooling operation with stride 2
for downsampling. At each downsampling step we double the number of feature
channels. Every step in the expansive path consists of an upsampling of the
feature map followed by a 2x2 convolution (“up-convolution”) that halves the
number of feature channels, a concatenation with the correspondingly cropped
feature map from the contracting path, and two 3x3 convolutions, each followed by a ReLU. The cropping is necessary due to the loss of border pixels
in every convolution. At the final layer a 1x1 convolution is used to map each
64-component feature vector to the desired number of classes. In total the network has 23 convolutional layers.

<div align="center"> <img src="https://imgs.search.brave.com/6lbIK-xzYuzh28AextLXfu6l0sxRrVbSexgE3eSLp_Q/rs:fit:860:0:0/g:ce/aHR0cHM6Ly9tZWRp/YS5nZWVrc2Zvcmdl/ZWtzLm9yZy93cC1j/b250ZW50L3VwbG9h/ZHMvMjAyMjA2MTQx/MjEyMzEvR3JvdXAx/NC5qcGc" alt="U-Net architecture diagram" width="800"> </div>

Each blue
box corresponds to a multi-channel feature map. The number of channels is denoted
on top of the box. The x-y-size is provided at the lower left edge of the box. White
boxes represent copied feature maps. The arrows denote the different operations.


### <font color='#52F17F'>**Architecture Modules**</font>

In [3]:
class DoubleConvBlock(nn.Module):
  """Two stacked 3x3 convolutions, each followed by an in-place ReLU.

  Both convolutions use the default padding of 0 and stride of 1, so each
  one trims one pixel from every border: a feature map of spatial size
  (H, W) leaves the block with size (H - 4, W - 4).

  Args:
    in_channels: Number of channels of the incoming feature map.
    out_channels: Number of channels produced by each convolution.
  """

  def __init__(self, in_channels: int, out_channels: int):
    super().__init__()
    # Built in execution order; nn.Sequential registers them as indices 0-3.
    stages = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3),
        nn.ReLU(inplace=True),
    ]
    self.double_conv = nn.Sequential(*stages)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run ``x`` through conv -> ReLU -> conv -> ReLU and return the result."""
    features = self.double_conv(x)
    return features


class Downscaling(nn.Module):
  """Encoder step: 2x2 max pooling with stride 2, then a DoubleConvBlock.

  Args:
    in_channels: Number of channels of the incoming feature map.
    out_channels: Number of channels produced by the convolution block.
  """

  def __init__(self, in_channels: int, out_channels: int):
    super().__init__()
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
    self.conv = DoubleConvBlock(in_channels, out_channels)

  def forward(self, x):
    """Return ``(x, down)``.

    ``x`` is handed back untouched so the caller can keep it as the skip
    connection for the decoder; ``down`` is the pooled and convolved
    feature map that feeds the next stage.
    """
    return x, self.conv(self.pool(x))


class Upscaling(nn.Module):
  """Decoder step: up-convolution, skip concatenation, then a DoubleConvBlock.

  The transposed convolution doubles the spatial size and halves the channel
  count (``in_channels // 2``). The skip feature map is center-cropped to the
  upsampled size and concatenated along the channel axis. The convolution
  block expects ``in_channels`` input channels, so the skip map must supply
  the remaining ``in_channels - in_channels // 2`` channels.

  Args:
    in_channels: Channels of the feature map coming from the deeper stage.
    out_channels: Channels produced by the convolution block.
  """

  def __init__(self, in_channels: int, out_channels: int):
    super().__init__()

    # A transposed convolution is used to scale the feature maps up.
    self.up_conv = nn.ConvTranspose2d(
        in_channels, in_channels // 2, kernel_size=2, stride=2
    )
    self.conv = DoubleConvBlock(in_channels, out_channels)

  def center_crop(self, layer, target_size):
    """Cut the central ``target_size`` window out of a 4-D feature map.

    Args:
      layer: Tensor indexed as (batch, channel, height, width).
      target_size: Sequence whose first two items are the wanted height
        and width.

    Returns:
      A view of ``layer`` with spatial size ``target_size``. When the size
      difference is odd, the extra row/column is dropped from the
      bottom/right.

    Raises:
      ValueError: If ``layer`` is smaller than ``target_size`` in either
        spatial dimension. Without this check the negative offsets would
        silently produce a wrongly sized slice.
    """
    _, _, layer_height, layer_width = layer.size()
    target_height, target_width = target_size[0], target_size[1]
    if layer_height < target_height or layer_width < target_width:
      raise ValueError(
          f"cannot center-crop a {layer_height}x{layer_width} feature map "
          f"to the larger size {target_height}x{target_width}"
      )
    top = (layer_height - target_height) // 2
    left = (layer_width - target_width) // 2
    return layer[
        :, :, top : top + target_height, left : left + target_width
    ]

  def forward(self, x1, x2):
    """Upsample ``x1``, merge it with the cropped skip map ``x2``, convolve.

    Args:
      x1: Feature map from the deeper stage.
      x2: Skip feature map from the encoder; must be at least as large
        spatially as the upsampled ``x1``.
    """
    x1 = self.up_conv(x1)
    x2 = self.center_crop(x2, x1.shape[2:])
    # Channel-wise concatenation: upsampled features first, skip second.
    x = torch.cat([x1, x2], 1)
    return self.conv(x)


class OutConv(nn.Module):
  """Output head: a 1x1 convolution that remaps the channel dimension.

  The spatial size is left unchanged; only the number of channels goes from
  ``in_channels`` to ``out_channels``.

  Args:
    in_channels: Number of channels of the incoming feature map.
    out_channels: Number of channels to produce.
  """

  def __init__(self, in_channels: int, out_channels: int):
    super().__init__()
    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Project every spatial position of ``x`` to ``out_channels`` values."""
    projected = self.conv(x)
    return projected

In [4]:
class UNet(nn.Module):
  """U-Net assembled from the convolution, down/up-scaling and output blocks.

  Args:
    n_channels: Number of channels of the input image.
    n_classes: Number of channels of the output map.
  """

  def __init__(self, n_channels, n_classes):
    super().__init__()

    self.n_channels = n_channels
    self.n_classes = n_classes

    # Contracting path: the channel count doubles at every stage.
    self.inc = DoubleConvBlock(n_channels, 64)
    self.down1 = Downscaling(64, 128)
    self.down2 = Downscaling(128, 256)
    self.down3 = Downscaling(256, 512)
    self.down4 = Downscaling(512, 1024)

    # Expansive path: the channel count halves at every stage.
    self.up_conv1 = Upscaling(1024, 512)
    self.up_conv2 = Upscaling(512, 256)
    self.up_conv3 = Upscaling(256, 128)
    self.up_conv4 = Upscaling(128, 64)

    # Output head mapping the 64 remaining channels to n_classes.
    self.out = OutConv(64, n_classes)

  def forward(self, x):
    """Run the encoder, then the decoder fed by the skip connections.

    The shapes in the comments are (channels x height x width) for a
    572x572 input.
    """
    # n_channels x 572x572 -> 64x568x568
    features = self.inc(x)

    # Every encoder stage hands back its own input (kept as the skip map)
    # together with its downscaled output:
    #   64x568x568 -> 128x280x280 -> 256x136x136 -> 512x64x64 -> 1024x28x28
    skips = []
    for encoder in (self.down1, self.down2, self.down3, self.down4):
      skip, features = encoder(features)
      skips.append(skip)

    # The decoder consumes the skip maps deepest-first:
    #   1024x28x28 -> 512x52x52 -> 256x100x100 -> 128x196x196 -> 64x388x388
    for decoder in (self.up_conv1, self.up_conv2, self.up_conv3, self.up_conv4):
      features = decoder(features, skips.pop())

    # 64x388x388 -> n_classes x 388x388
    return self.out(features)

In [5]:
# Smoke test: push a random batch of two single-channel 572x572 images
# through the network and print the input and output sizes.
# (Variable names and printed labels are kept exactly as they were.)

input_image = torch.rand(2, 1, 572, 572)
print(f"Input: {input_image.shape}")
model = UNet(n_channels=1, n_classes=1)
ouput = model(input_image)
print(f"Ouput: {ouput.shape}") # Expected: [2, 1, 388, 388]

Input: torch.Size([2, 1, 572, 572])
Ouput: torch.Size([2, 1, 388, 388])


In [6]:
# Print the model's repr to inspect the nested module hierarchy.
print(model)

UNet(
  (inc): DoubleConvBlock(
    (double_conv): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (3): ReLU(inplace=True)
    )
  )
  (down1): Downscaling(
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv): DoubleConvBlock(
      (double_conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (3): ReLU(inplace=True)
      )
    )
  )
  (down2): Downscaling(
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv): DoubleConvBlock(
      (double_conv): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
        (3): ReLU(inplace=T