<a href="https://colab.research.google.com/github/yashika-git/Deep_Learning/blob/main/UNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
  """U-Net style encoder/decoder built from UNetConvBlock and UNetUpBlock.

  The contracting path stacks `depth` convolution blocks, with a 2x2 max
  pool between consecutive blocks. The expanding path has `depth - 1`
  up-blocks, each of which also receives the feature map saved at the
  matching contracting stage. A final 1x1 convolution maps the result to
  `n_classes` channels.

  Args:
    in_channels: number of channels of the input tensor.
    n_classes: number of channels of the output tensor.
    depth: number of convolution blocks in the contracting path.
    wf: the i-th contracting block has 2**(wf + i) filters, so the first
      block has 2**wf filters.
    padding: if True the 3x3 convolutions are padded by one pixel (so they
      keep the spatial size); the original note on this flag warns that
      this may introduce artifacts.
    batch_norm: if True, add BatchNorm2d after each activation.
    up_mode: 'upconv' (transposed convolution) or 'upsample' (bilinear
      upsampling followed by a 1x1 convolution).

  Raises:
    ValueError: if `up_mode` is not one of the two supported values.
  """

  def __init__(
      self,
      in_channels=1,
      n_classes=2,
      depth=5,
      wf=6,
      padding=False,
      batch_norm=False,
      up_mode='upconv'
    ):
    super(UNet, self).__init__()
    # Explicit raise instead of `assert`: asserts vanish under `python -O`,
    # which would let an unsupported mode through unnoticed.
    if up_mode not in ('upconv', 'upsample'):
      raise ValueError(
          "up_mode must be 'upconv' or 'upsample', got {!r}".format(up_mode)
      )
    self.padding = padding
    self.depth = depth

    # Contracting path: the channel count doubles at every stage.
    prev_channels = in_channels
    self.down_path = nn.ModuleList()
    for i in range(depth):
      stage_channels = 2 ** (wf + i)
      self.down_path.append(
          UNetConvBlock(prev_channels, stage_channels, padding, batch_norm)
      )
      prev_channels = stage_channels

    # Expanding path: one block fewer than the contracting path, walking the
    # channel counts back down.
    self.up_path = nn.ModuleList()
    for i in reversed(range(depth - 1)):
      stage_channels = 2 ** (wf + i)
      self.up_path.append(
          UNetUpBlock(prev_channels, stage_channels, up_mode, padding, batch_norm)
      )
      prev_channels = stage_channels

    # 1x1 convolution producing the requested number of output channels.
    self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

  def forward(self, x):
    """Run the contracting path, then the expanding path, then the 1x1 head."""
    skips = []
    last_stage = len(self.down_path) - 1
    for i, down in enumerate(self.down_path):
      x = down(x)
      # The deepest stage is neither saved as a skip nor pooled.
      if i != last_stage:
        skips.append(x)
        x = F.max_pool2d(x, 2)

    # Skips are consumed in reverse order: the most recently saved (deepest)
    # feature map feeds the first up-block.
    for up, bridge in zip(self.up_path, reversed(skips)):
      x = up(x, bridge)

    return self.last(x)

class UNetConvBlock(nn.Module):
  """Two 3x3 convolutions, each followed by ReLU and optional BatchNorm2d.

  Args:
    in_size: number of input channels of the first convolution.
    out_size: number of output channels of both convolutions.
    padding: truthy to pad each convolution by one pixel, falsy for no
      padding (the value is passed to Conv2d as `int(padding)`).
    batch_norm: if True, append BatchNorm2d(out_size) after each ReLU.
  """

  def __init__(self, in_size, out_size, padding, batch_norm):
    super(UNetConvBlock, self).__init__()
    pad = int(padding)
    layers = []

    # The second convolution consumes the first one's output, so its input
    # channel count must be out_size (not in_size).
    for conv_in in (in_size, out_size):
      layers.append(nn.Conv2d(conv_in, out_size, kernel_size=3, padding=pad))
      layers.append(nn.ReLU())
      if batch_norm:
        layers.append(nn.BatchNorm2d(out_size))

    self.block = nn.Sequential(*layers)

  def forward(self, x):
    """Apply the stacked conv/activation layers to `x`."""
    return self.block(x)
    

class UNetUpBlock(nn.Module):
  """Upsample by 2, concatenate a cropped skip feature map, then convolve.

  Args:
    in_size: channels of the tensor being upsampled; also the channel count
      the convolution block is built for (the concatenated tensor).
    out_size: channels produced by the upsampling step and by the block.
    up_mode: 'upconv' for a stride-2 transposed convolution, 'upsample' for
      bilinear upsampling followed by a 1x1 convolution.
    padding: forwarded to UNetConvBlock.
    batch_norm: forwarded to UNetConvBlock.

  Raises:
    ValueError: if `up_mode` is not 'upconv' or 'upsample'.
  """

  def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
    super(UNetUpBlock, self).__init__()
    if up_mode == 'upconv':
      self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
    elif up_mode == 'upsample':
      self.up = nn.Sequential(
          nn.Upsample(mode='bilinear', scale_factor=2),
          nn.Conv2d(in_size, out_size, kernel_size=1),
      )
    else:
      # Fail at construction rather than leaving `self.up` undefined.
      raise ValueError(
          "up_mode must be 'upconv' or 'upsample', got {!r}".format(up_mode)
      )

    # Needed by forward() in BOTH modes, so it must live outside the
    # if/elif above.
    self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

  def center_crop(self, layer, target_size):
    """Return the central (height, width) = target_size window of `layer`.

    `layer` is indexed as a 4-D tensor whose last two dimensions are height
    and width; the first two dimensions are kept whole.
    """
    _, _, layer_height, layer_width = layer.size()
    diff_y = (layer_height - target_size[0]) // 2
    diff_x = (layer_width - target_size[1]) // 2
    return layer[
        :, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])
    ]

  def forward(self, x, bridge):
    """Upsample `x`, join it with the cropped `bridge`, and convolve."""
    up = self.up(x)
    # Crop the skip connection to the upsampled spatial size so the two can
    # be concatenated along the channel dimension.
    crop1 = self.center_crop(bridge, up.shape[2:])
    out = torch.cat([up, crop1], 1)
    out = self.conv_block(out)

    return out

In [10]:
import torch
import torch.nn.functional as F

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# 5-channel input, 2 output channels, 5 contracting stages starting at
# 2**3 = 8 filters, padded convolutions, bilinear upsampling in the decoder.
model = UNet(
    in_channels=5,
    n_classes=2,
    depth=5,
    wf=3,
    padding=True,
    up_mode='upsample',
)
model = model.to(device)
