<a href="https://colab.research.google.com/github/therishabhmittal-05/camera/blob/main/Semantic_Segmentation/UNet_Semantic_Segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torchvision
from torch import nn, optim

In [2]:
# Configuration cell.
# Seed PyTorch's RNG so weight initialisation (and anything stochastic later) is
# reproducible under Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
class DoubleConv(nn.Module):
  """Two 3x3 convolutions (padding=1, stride 1), each followed by a ReLU.

  The spatial size is preserved; the channel count goes from `in_channels`
  to `out_channels` in the first conv and stays there in the second.
  """
  def __init__(self, in_channels:int, out_channels:int):
    super().__init__()
    layers = []
    # First conv: in -> out; second conv: out -> out. Each is paired with a ReLU,
    # which yields Sequential indices 0: conv, 1: relu, 2: conv, 3: relu.
    for conv_in in (in_channels, out_channels):
      layers.extend([nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1), nn.ReLU()])
    self.doubleconv = nn.Sequential(*layers)

  def forward(self, x):
    """Apply conv -> ReLU -> conv -> ReLU to `x`."""
    out = self.doubleconv(x)
    return out

In [4]:
class DownSample(nn.Module):
  """Halve the spatial resolution with a 2x2 max-pool (stride defaults to the kernel size).

  Odd spatial sizes are floored, e.g. 7x9 -> 3x4.
  """
  def __init__(self):
    super().__init__()
    self.downsample = nn.MaxPool2d(kernel_size=2)

  def forward(self, x):
    """Return the 2x2 max-pooled `x`."""
    pooled = self.downsample(x)
    return pooled

In [6]:
class UpSample(nn.Module):
  """Learned 2x upsampling with a 2x2 transposed convolution (stride 2).

  Maps `in_channels` to `out_channels`; an H x W input becomes 2H x 2W.
  """
  def __init__(self, in_channels:int, out_channels:int):
    super().__init__()
    self.upsample = nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=2,
        stride=2,
    )

  def forward(self, x):
    """Return `x` upsampled by a factor of two in each spatial dimension."""
    upsampled = self.upsample(x)
    return upsampled

In [16]:
class CropandConcat(nn.Module):
  """Merge a decoder tensor with its encoder skip tensor.

  The skip tensor `contract_x` is center-cropped to the spatial size of `x`.
  The two are then concatenated along dim 1, with `x`'s channels first and the
  cropped skip channels after them.
  """
  def forward(self, x, contract_x):
    height, width = x.shape[2], x.shape[3]
    cropped_skip = torchvision.transforms.functional.center_crop(contract_x, (height, width))
    merged = torch.cat([x, cropped_skip], dim=1)
    return merged

In [17]:
class Unet(nn.Module):
  """U-Net: a convolutional encoder/decoder with skip connections.

  Encoder: for each width in `features`, a DoubleConv whose output is kept as a
  skip connection, followed by a 2x2 max-pool. Bottleneck: a DoubleConv that
  doubles the deepest width. Decoder, deepest level first: a 2x2 transposed
  conv, a center-crop + channel concat with the matching skip tensor, and a
  DoubleConv. A final 1x1 conv maps to `out_channels`.

  Args:
    in_channels: number of channels of the input tensor.
    out_channels: number of channels produced by the final 1x1 conv.
    features: encoder channel widths, shallowest first. The default
      (64, 128, 256, 512) builds exactly the original fixed architecture:
      same submodules, same state_dict keys, same parameter-creation order.

  Raises:
    ValueError: if `features` is empty.
  """
  def __init__(self, in_channels:int, out_channels:int, features=(64, 128, 256, 512)):
    super().__init__()
    widths = [int(w) for w in features]
    if not widths:
      raise ValueError("features must contain at least one channel width")
    depth = len(widths)
    bottleneck = 2 * widths[-1]

    # Encoder level k maps widths[k-1] (in_channels at k=0) -> widths[k].
    self.downconv = nn.ModuleList(
        [DoubleConv(i, o) for i, o in zip([in_channels] + widths[:-1], widths)]
    )
    self.downsample = nn.ModuleList(
       [DownSample() for _ in range(depth)]
    )

    self.middleconv = DoubleConv(widths[-1], bottleneck)
    # The decoder runs deepest level first. Each transposed conv takes the width
    # produced by the previous step (the bottleneck width at the first step).
    reversed_widths = widths[::-1]
    self.upsample = nn.ModuleList(
        [UpSample(i, o) for i, o in zip([bottleneck] + reversed_widths[:-1], reversed_widths)]
    )
    # After concatenation each decoder DoubleConv sees o upsampled + o skip channels.
    self.upconv = nn.ModuleList(
        [DoubleConv(2 * o, o) for o in reversed_widths]
    )
    self.concat = nn.ModuleList(
        [CropandConcat() for _ in range(depth)]
    )
    self.finalconv = nn.Conv2d(widths[0], out_channels, kernel_size=1)

  def forward(self, x):
    """Run encoder, bottleneck and decoder; the result has `out_channels` channels.

    Skip tensors are consumed in LIFO order, so the deepest encoder output meets
    the first decoder level. Odd sizes floored by max-pooling are reconciled by
    the center-crop in CropandConcat.
    NOTE(review): very small inputs (below 2**len(features) per side) presumably
    pool down to zero size — confirm the minimum supported input size.
    """
    skips = []
    for conv, pool in zip(self.downconv, self.downsample):
      x = conv(x)
      skips.append(x)  # saved before pooling, at this level's full resolution
      x = pool(x)
    x = self.middleconv(x)
    for up, merge, conv in zip(self.upsample, self.concat, self.upconv):
      x = up(x)
      x = merge(x, skips.pop())
      x = conv(x)
    return self.finalconv(x)

In [18]:
# Build a U-Net with 64 input and 64 output channels and move its parameters to `device`.
model = Unet(in_channels=64, out_channels=64).to(device)
model


Unet(
  (downconv): ModuleList(
    (0): DoubleConv(
      (doubleconv): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (1): DoubleConv(
      (doubleconv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (2): DoubleConv(
      (doubleconv): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (3): DoubleConv(
      (doubleconv): Sequential(
        (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2):

In [19]:
from torchsummary import summary

# Per-layer output shapes and parameter counts for one 64x256x256 input
# (the batch dimension is reported as -1). Pass the model's device type explicitly.
summary(model, (64, 256, 256), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]          36,928
              ReLU-2         [-1, 64, 256, 256]               0
            Conv2d-3         [-1, 64, 256, 256]          36,928
              ReLU-4         [-1, 64, 256, 256]               0
        DoubleConv-5         [-1, 64, 256, 256]               0
         MaxPool2d-6         [-1, 64, 128, 128]               0
        DownSample-7         [-1, 64, 128, 128]               0
            Conv2d-8        [-1, 128, 128, 128]          73,856
              ReLU-9        [-1, 128, 128, 128]               0
           Conv2d-10        [-1, 128, 128, 128]         147,584
             ReLU-11        [-1, 128, 128, 128]               0
       DoubleConv-12        [-1, 128, 128, 128]               0
        MaxPool2d-13          [-1, 128, 64, 64]               0
       DownSample-14          [-1, 128,