source: https://medium.com/@fernandopalominocobo/mastering-u-net-a-step-by-step-guide-to-segmentation-from-scratch-with-pytorch-6a17c5916114

![unet-arch](../assets/unet_arch.png)

In [2]:
import copy
import os
import random
import shutil
import zipfile
from math import atan2, cos, sin, sqrt, pi, log

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from numpy import linalg as LA
from torch import optim, nn
from torch.utils.data import DataLoader, random_split
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from tqdm import tqdm

# Create Model

Double Convolutions with ReLU

In [2]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by an in-place ReLU.

    Both convolutions use ``padding=1`` (stride 1), so the spatial size is
    preserved: (N, in_channels, H, W) -> (N, out_channels, H, W).

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        layers = []
        # First conv maps in_channels -> out_channels, second keeps out_channels.
        for conv_in in (in_channels, out_channels):
            layers.append(nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))

        # Attribute name and layer order are kept so state_dict keys
        # (conv_op.0.*, conv_op.2.*) stay checkpoint-compatible.
        self.conv_op = nn.Sequential(*layers)

    def forward(self, x):
        """Apply conv -> ReLU -> conv -> ReLU to ``x``."""
        out = self.conv_op(x)
        return out

Downsampling applies a double convolution followed by 2x2 max pooling (stride 2).
- We also return the convolved tensor from before max pooling, so it can be used as a skip connection that joins low-level (high-resolution) features with high-level features on the upsampling path.

In [3]:
class Downsample(nn.Module):
    """Contracting-path stage: a DoubleConv followed by 2x2 max pooling.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the double convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.double_conv = DoubleConv(in_channels, out_channels)
        # kernel 2 / stride 2 halves height and width (flooring odd sizes).
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(features, pooled)``.

        ``features`` is the pre-pooling output, kept for the skip connection;
        ``pooled`` is the max-pooled tensor fed to the next stage.
        """
        features = self.double_conv(x)
        return features, self.max_pool(features)

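Upsampling applies a transposed convolution (often called "deconvolution") that doubles the spatial size and halves the channel count. It then concatenates the result with the matching skip connection from the downsampling path and applies a double convolution.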

In [4]:
class Upsample(nn.Module):
    """Expanding-path stage: transposed-conv upsampling, skip concat, DoubleConv.

    Args:
        in_channels: channels of the incoming (lower-resolution) feature map.
            The transposed convolution reduces this to ``in_channels // 2``;
            the skip tensor is expected to supply the remaining channels so
            the concatenation has ``in_channels`` channels for DoubleConv.
        out_channels: channels produced by the double convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # kernel 2 / stride 2 doubles height and width.
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.double_conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, concatenate it with skip tensor ``x2``, convolve.

        Args:
            x1: feature map from the previous (deeper) stage.
            x2: skip-connection feature map from the matching
                contracting-path stage.

        Returns:
            The DoubleConv output on the channel-wise concatenation.
        """
        x1 = self.up(x1)

        # If an encoder feature map had an odd height/width, max pooling
        # floored it, so the upsampled x1 can be smaller than x2 and
        # torch.cat would fail. Pad x1 to x2's spatial size (extra pixel on
        # the bottom/right; negative differences crop). No-op when sizes match.
        diff_h = x2.size(-2) - x1.size(-2)
        diff_w = x2.size(-1) - x1.size(-1)
        if diff_h or diff_w:
            x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2,
                            diff_h // 2, diff_h - diff_h // 2])

        x = torch.cat([x1, x2], 1)

        return self.double_conv(x)

Define full model

In [6]:
class UNet(nn.Module):
    """U-Net segmentation model with a 4-stage contracting/expanding path.

    Channel widths are 64-128-256-512 on the contracting (encoder) path, with
    a 1024-channel bottleneck, mirrored on the expanding (decoder) path.
    Each decoder stage is joined to the encoder stage of the same resolution
    via a skip connection.

    With four 2x2 poolings, input heights/widths divisible by 16 keep every
    upsampled tensor the same size as its skip connection.

    Args:
        in_channels: channels of the input image (e.g. 3 for RGB).
        num_classes: number of output channels of the final 1x1 convolution.

    The output of ``forward`` is the raw 1x1-conv result (no softmax or
    sigmoid is applied).
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        # encoder (contracting path)
        self.down_conv_1 = Downsample(in_channels, out_channels=64)
        self.down_conv_2 = Downsample(in_channels=64, out_channels=128)
        self.down_conv_3 = Downsample(in_channels=128, out_channels=256)
        self.down_conv_4 = Downsample(in_channels=256, out_channels=512)

        # bottleneck
        self.bottle_neck = DoubleConv(in_channels=512, out_channels=1024)

        # decoder (expanding path)
        self.up_conv_1 = Upsample(in_channels=1024, out_channels=512)
        self.up_conv_2 = Upsample(in_channels=512, out_channels=256)
        self.up_conv_3 = Upsample(in_channels=256, out_channels=128)
        self.up_conv_4 = Upsample(in_channels=128, out_channels=64)

        # segmentation map: per-pixel class scores via a 1x1 convolution
        self.out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        """Run the full U-Net and return per-pixel class scores."""

        # encoder: keep each pre-pooling tensor (down_*) for skip connections
        down_1, p1 = self.down_conv_1(x)
        down_2, p2 = self.down_conv_2(p1)
        down_3, p3 = self.down_conv_3(p2)
        down_4, p4 = self.down_conv_4(p3)

        # bottleneck
        b = self.bottle_neck(p4)

        # decoder: consume skip connections in reverse (deepest first)
        up_1 = self.up_conv_1(b, down_4)
        up_2 = self.up_conv_2(up_1, down_3)
        up_3 = self.up_conv_3(up_2, down_2)
        up_4 = self.up_conv_4(up_3, down_1)

        # segmentation map
        out = self.out(up_4)
        return out

Let's test with dummy data

In [None]:
# Smoke test: push one random 512x512 RGB image through a 10-class U-Net.
# torch.no_grad() skips building the autograd graph, which a shape check does
# not need and which would otherwise keep every intermediate activation alive.
input_image = torch.rand((1, 3, 512, 512))
model = UNet(3, 10)  # rgb 3 channel / 10 classes
with torch.no_grad():
    output = model(input_image)
print(output.size())
# Spatial size is preserved end to end; channels become the class count.
assert output.shape == (1, 10, 512, 512), output.shape