# Designing a U-Net

## Import Statements

In [1]:
import torch
import torch.nn as nn
import numpy as np

## U-Net Implementation

In [5]:
def build_double_conv(inp_channels: int, output_channels: int):
    """Build the two-stage convolution block used at every U-Net level.

    The block is ``Conv2d(3x3) -> ReLU -> Conv2d(3x3) -> ReLU``. Neither
    convolution pads its input, so each one trims one pixel from every
    border and the block as a whole shrinks height and width by 4.

    Args:
        inp_channels: Number of channels in the incoming feature map.
        output_channels: Number of channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` holding the four layers in order.
    """
    first_conv = nn.Conv2d(inp_channels, output_channels, kernel_size=3)
    second_conv = nn.Conv2d(output_channels, output_channels, kernel_size=3)
    return nn.Sequential(
        first_conv,
        nn.ReLU(inplace=True),
        second_conv,
        nn.ReLU(inplace=True),
    )

def crop_tensor(orignal_tensor, target_tensor):
    """Center-crop ``orignal_tensor`` to the spatial size of ``target_tensor``.

    Only dimensions 2 and 3 (height and width) are cropped; leading
    dimensions are returned untouched. Height and width are handled
    independently, so non-square feature maps are supported.

    Args:
        orignal_tensor: Tensor to crop, with height in dim 2 and width in
            dim 3.
        target_tensor: Tensor whose dim-2 and dim-3 sizes give the desired
            output height and width. Its values are never read.

    Returns:
        A view of ``orignal_tensor`` whose height and width equal those of
        ``target_tensor``.

    Raises:
        ValueError: If ``target_tensor`` is larger than ``orignal_tensor``
            along height or width, since cropping cannot enlarge a tensor.
    """
    orig_h, orig_w = orignal_tensor.shape[2], orignal_tensor.shape[3]
    target_h, target_w = target_tensor.shape[2], target_tensor.shape[3]
    if target_h > orig_h or target_w > orig_w:
        raise ValueError(
            f"cannot crop spatial size ({orig_h}, {orig_w}) "
            f"to larger size ({target_h}, {target_w})"
        )

    # Derive the end index from the target size rather than trimming the same
    # margin from both sides: when the size difference is odd, symmetric
    # trimming would leave the result one pixel larger than the target.
    top = (orig_h - target_h) // 2
    left = (orig_w - target_w) // 2
    return orignal_tensor[:, :, top:top + target_h, left:left + target_w]



class UNet(nn.Module):
    """U-Net encoder/decoder built from unpadded 3x3 double convolutions.

    The encoder applies five double-convolution blocks (64, 128, 256, 512 and
    1024 channels) with a 2x2 max-pool between consecutive blocks. The decoder
    mirrors it with four stages; each stage doubles the spatial size with a
    stride-2 transposed convolution, concatenates the center-cropped encoder
    feature map of the same level, and applies another double convolution.
    A final 1x1 convolution maps the 64 decoder channels to the output
    channels.

    Because no convolution pads its input, the output is spatially smaller
    than the input: a ``(N, in_channels, 572, 572)`` batch yields
    ``(N, out_channels, 388, 388)``.

    Args:
        in_channels: Number of channels in the input image. Defaults to 1.
        out_channels: Number of channels in the output map. Defaults to 2.
    """

    def __init__(self, in_channels: int = 1, out_channels: int = 2):
        super().__init__()
        # The pooling layer has no parameters, so one instance serves every
        # encoder level.
        self.MaxPool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_conv_1 = build_double_conv(in_channels, 64)
        self.down_conv_2 = build_double_conv(64, 128)
        self.down_conv_3 = build_double_conv(128, 256)
        self.down_conv_4 = build_double_conv(256, 512)
        self.down_conv_5 = build_double_conv(512, 1024)

        ## Second Part of the architecture
        # Each transposed convolution halves the channel count; concatenating
        # the skip connection restores it, which is why every up_conv block
        # takes twice the channels its transposed convolution emits.
        self.up_transpose_1 = nn.ConvTranspose2d(
            in_channels=1024, out_channels=512, kernel_size=2, stride=2
        )
        self.up_conv_1 = build_double_conv(1024, 512)

        self.up_transpose_2 = nn.ConvTranspose2d(
            in_channels=512, out_channels=256, kernel_size=2, stride=2
        )
        self.up_conv_2 = build_double_conv(512, 256)

        self.up_transpose_3 = nn.ConvTranspose2d(
            in_channels=256, out_channels=128, kernel_size=2, stride=2
        )
        self.up_conv_3 = build_double_conv(256, 128)

        self.up_transpose_4 = nn.ConvTranspose2d(
            in_channels=128, out_channels=64, kernel_size=2, stride=2
        )
        self.up_conv_4 = build_double_conv(128, 64)

        ## Output
        self.output = nn.Conv2d(
            in_channels=64, out_channels=out_channels, kernel_size=1
        )

    @staticmethod
    def _decode_stage(x, skip, up_transpose, up_conv):
        """Upsample ``x``, merge the cropped ``skip`` map, then convolve."""
        x = up_transpose(x)
        # The encoder map is larger than the upsampled one (unpadded
        # convolutions), so it is cropped to match before concatenating
        # along the channel dimension.
        cropped_skip = crop_tensor(skip, x)
        return up_conv(torch.cat([x, cropped_skip], dim=1))

    def forward(self, image):
        """Run the network on a batch of images.

        Args:
            image: Tensor of shape ``(N, in_channels, H, W)``.

        Returns:
            Tensor with ``out_channels`` channels holding the raw output of
            the final 1x1 convolution; no activation is applied.
        """
        ## Encoder
        # The output of every level except the deepest is kept as a skip
        # connection for the decoder.
        skip_1 = self.down_conv_1(image)
        skip_2 = self.down_conv_2(self.MaxPool_2x2(skip_1))
        skip_3 = self.down_conv_3(self.MaxPool_2x2(skip_2))
        skip_4 = self.down_conv_4(self.MaxPool_2x2(skip_3))
        x = self.down_conv_5(self.MaxPool_2x2(skip_4))

        ## Decoder
        x = self._decode_stage(x, skip_4, self.up_transpose_1, self.up_conv_1)
        x = self._decode_stage(x, skip_3, self.up_transpose_2, self.up_conv_2)
        x = self._decode_stage(x, skip_2, self.up_transpose_3, self.up_conv_3)
        x = self._decode_stage(x, skip_1, self.up_transpose_4, self.up_conv_4)

        ## Output
        return self.output(x)
        
        
        
        
        
        
        
        
        
        
        
        
    

In [8]:
# Smoke test: one random single-channel 572x572 image through a freshly
# initialised network.
image = torch.rand(1, 1, 572, 572)
model = UNet()
# Left as the cell's last expression so the notebook displays the output.
model(image)

tensor([[[[-0.0771, -0.0800, -0.0728,  ..., -0.0756, -0.0763, -0.0764],
          [-0.0751, -0.0731, -0.0729,  ..., -0.0740, -0.0707, -0.0783],
          [-0.0756, -0.0729, -0.0755,  ..., -0.0755, -0.0748, -0.0759],
          ...,
          [-0.0724, -0.0765, -0.0723,  ..., -0.0778, -0.0732, -0.0745],
          [-0.0753, -0.0744, -0.0774,  ..., -0.0784, -0.0741, -0.0751],
          [-0.0781, -0.0761, -0.0787,  ..., -0.0726, -0.0768, -0.0796]],

         [[-0.0725, -0.0709, -0.0767,  ..., -0.0739, -0.0784, -0.0722],
          [-0.0727, -0.0750, -0.0772,  ..., -0.0726, -0.0729, -0.0741],
          [-0.0752, -0.0756, -0.0733,  ..., -0.0728, -0.0716, -0.0722],
          ...,
          [-0.0740, -0.0731, -0.0741,  ..., -0.0731, -0.0704, -0.0760],
          [-0.0707, -0.0752, -0.0701,  ..., -0.0741, -0.0705, -0.0753],
          [-0.0721, -0.0786, -0.0725,  ..., -0.0763, -0.0759, -0.0739]]]],
       grad_fn=<MkldnnConvolutionBackward>)

## Dataset

In [None]:
import os
