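## U-Net
- ### Consider a U-Net segmentation network built from three paths: a contracting path (encoder), a bridge, and an expanding path (decoder)
- ### To use it on volumetric MRI data it would have to become a 3D convolutional network (`Conv3d`, `MaxPool3d`, `ConvTranspose3d` in place of the 2D layers used below)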

In [None]:
import os
import cv2
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchsummary import summary

### Contracting Path
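- #### The first part of the network is the contracting path: four stages of two unpadded 3x3 convolutions with 2x2 max pooling between them; the output of each stage is kept as a skip connection for the expanding path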

In [None]:
class ContractingPath(nn.Module):
    """Encoder (contracting path) of the U-Net.

    Four stages of two unpadded 3x3 convolutions, each followed by a ReLU,
    with 2x2 max pooling between consecutive stages. Every unpadded
    convolution trims 2 pixels from each spatial dimension.

    Args:
        n_channels: number of channels of the input image.

    ``forward`` returns a list ``[out_1, out_2, out_3, out_4]`` holding the
    output of each stage (64, 128, 256 and 512 channels). These are the skip
    connections consumed by the expanding path. For a 572x572 input their
    spatial sizes are 568, 280, 136 and 64.
    """

    def __init__(self, n_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(n_channels, 64, kernel_size=3)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3)
        self.maxpool = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3)
        self.conv7 = nn.Conv2d(256, 512, kernel_size=3)
        self.conv8 = nn.Conv2d(512, 512, kernel_size=3)

    def forward(self, x):
        # A ReLU follows every convolution. Without it the stacked
        # convolutions collapse into a single affine map per stage.
        out = F.relu(self.conv1(x))
        out_1 = F.relu(self.conv2(out))
        out = self.maxpool(out_1)
        out = F.relu(self.conv3(out))
        out_2 = F.relu(self.conv4(out))
        out = self.maxpool(out_2)
        out = F.relu(self.conv5(out))
        out_3 = F.relu(self.conv6(out))
        out = self.maxpool(out_3)
        out = F.relu(self.conv7(out))
        out_4 = F.relu(self.conv8(out))
        return [out_1, out_2, out_3, out_4]
model = ContractingPath(3)
# torchsummary defaults to device="cuda". On a GPU machine it would then feed
# CUDA tensors to this CPU-resident model, so summarise on the CPU explicitly.
summary(model, input_size = (3, 256, 256), device = "cpu")

### Bridge
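- #### The bridge takes the deepest feature map of the contracting path, max-pools it, applies two 3x3 convolutions, and upsamples the result with a transposed convolution so that it can be joined to the expanding path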

In [None]:
class Bridge(nn.Module):
    """Bottleneck between the contracting and the expanding path.

    Max-pools the deepest encoder feature map, then applies two unpadded 3x3
    convolutions (``n_channels -> 1024 -> 512``), each followed by a ReLU. It
    then doubles the spatial size with a transposed convolution (kernel 4,
    stride 2, padding 1). An ``H x H`` input therefore yields a
    ``2 * (H // 2 - 4)`` square output with 512 channels; for example,
    64 -> 56.

    Args:
        n_channels: number of channels of the incoming feature map.
    """

    def __init__(self, n_channels):
        super().__init__()
        self.maxpool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(n_channels, 1024, kernel_size=3)
        self.conv2 = nn.Conv2d(1024, 512, kernel_size=3)
        self.UpConv1 = nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        out = self.maxpool(x)
        out = F.relu(self.conv1(out))
        out = F.relu(self.conv2(out))
        # No activation after the up-convolution: its output is concatenated
        # with the skip connection and convolved again by the decoder.
        return self.UpConv1(out)
# (512, 64, 64) is the shape of ContractingPath's last output for a 572x572 input.
bridge = Bridge(512)
# torchsummary defaults to device="cuda"; the model lives on the CPU.
summary(bridge, (512, 64, 64), device = "cpu")

### Putting everything together
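- #### Using the two modules above, we can put the whole model together
- #### The expanding path is defined here: at each level the upsampled features are concatenated (not added) with the cropped skip connection from the contracting path, and a final 1x1 convolution produces the output map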

In [None]:
class U_Net(nn.Module):
    """Full U-Net: contracting path, bridge and expanding path.

    All 3x3 convolutions are unpadded, so the output map is smaller than the
    input: a 572x572 input gives a 388x388 output. At each decoder level the
    matching contracting-path feature map is center-cropped to the size of
    the upsampled features and concatenated with them along the channel
    dimension.

    Args:
        in_channels: number of channels of the input image.

    ``forward`` returns raw single-channel logits of shape
    ``(N, 1, H_out, W_out)``. Apply a sigmoid to obtain per-pixel
    probabilities.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.contracting_path = ContractingPath(in_channels)
        self.bridge = Bridge(512)
        self.conv1 = nn.Conv2d(1024, 512, kernel_size=3)
        self.conv2 = nn.Conv2d(512, 256, kernel_size=3)
        self.UpConv1 = nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(512, 256, kernel_size=3)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3)
        self.UpConv2 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.conv5 = nn.Conv2d(256, 128, kernel_size=3)
        self.conv6 = nn.Conv2d(128, 64, kernel_size=3)
        self.UpConv3 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1)
        self.conv_f1 = nn.Conv2d(128, 64, kernel_size=3)
        self.conv_f2 = nn.Conv2d(64, 64, kernel_size=3)
        self.conv_final = nn.Conv2d(64, 1, kernel_size=1)
        # Opt-in: uncomment to replace PyTorch's default Conv2d initialisation
        # with He-normal initialisation.
        # self.apply(self.init_weights_)

    def init_weights_(self, module):
        """He-normal initialisation for ``nn.Conv2d`` layers (use via ``self.apply``).

        Draws the layer's weight tensor from a normal distribution with
        std ``sqrt(2 / fan_in)`` (suited to ReLU activations) and zeroes its
        bias. Any other module type is left untouched.
        """
        if isinstance(module, nn.Conv2d):
            # Initialise the weight *tensor*; passing the module itself to an
            # nn.init function fails because modules have no ``normal_``.
            nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    @staticmethod
    def _center_crop(skip, height, width):
        """Return the central ``height`` x ``width`` window of the last two dims of ``skip``."""
        top = (skip.shape[-2] - height) // 2
        left = (skip.shape[-1] - width) // 2
        return skip[..., top:top + height, left:left + width]

    def forward(self, x):
        # Shape comments assume an (N, 3, 572, 572) input.
        out_c1, out_c2, out_c3, out_c4 = self.contracting_path(x)
        out_b = self.bridge(out_c4)  # (N, 512, 56, 56)

        # Skip connections are center-cropped so that they line up spatially
        # with the upsampled decoder features.
        out_c4 = self._center_crop(out_c4, *out_b.shape[2:])  # (N, 512, 56, 56)
        out_4 = torch.cat([out_c4, out_b], dim=1)  # (N, 1024, 56, 56)
        out_4 = F.relu(self.conv1(out_4))  # (N, 512, 54, 54)
        out_4 = F.relu(self.conv2(out_4))  # (N, 256, 52, 52)

        out_3 = self.UpConv1(out_4)  # (N, 256, 104, 104)
        out_c3 = self._center_crop(out_c3, *out_3.shape[2:])  # (N, 256, 104, 104)
        out_3 = torch.cat([out_c3, out_3], dim=1)  # (N, 512, 104, 104)
        out_3 = F.relu(self.conv3(out_3))  # (N, 256, 102, 102)
        out_3 = F.relu(self.conv4(out_3))  # (N, 256, 100, 100)

        out_2 = self.UpConv2(out_3)  # (N, 128, 200, 200)
        out_c2 = self._center_crop(out_c2, *out_2.shape[2:])  # (N, 128, 200, 200)
        out_2 = torch.cat([out_c2, out_2], dim=1)  # (N, 256, 200, 200)
        out_2 = F.relu(self.conv5(out_2))  # (N, 128, 198, 198)
        out_1 = F.relu(self.conv6(out_2))  # (N, 64, 196, 196)

        out_1 = self.UpConv3(out_1)  # (N, 64, 392, 392)
        out_c1 = self._center_crop(out_c1, *out_1.shape[2:])  # (N, 64, 392, 392)
        out = torch.cat([out_c1, out_1], dim=1)  # (N, 128, 392, 392)
        out = F.relu(self.conv_f1(out))  # (N, 64, 390, 390)
        out = F.relu(self.conv_f2(out))  # (N, 64, 388, 388)
        # 1x1 projection to a single channel of logits; no activation here.
        final_output = self.conv_final(out)  # (N, 1, 388, 388)
        return final_output
u_net = U_Net(3)
# torchsummary defaults to device="cuda"; the model lives on the CPU.
summary(u_net, (3, 256, 256), device = "cpu")

In [None]:
# Sanity check of the output size. With unpadded convolutions a 572x572 input
# yields a 388x388 map, which is the size the shape comments in U_Net.forward
# assume. A 256x256 input would only give a 68x68 map and fail the assert.
samp_in = torch.randint(0, 10, (1, 3, 572, 572), dtype = torch.float32)
with torch.no_grad():  # inference only; avoid storing activations for backprop
    sample_out = u_net(samp_in)
print(sample_out.shape)
assert sample_out.shape == torch.Size([1, 1, 388, 388])

### Training Loop
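- #### Get random batches of images and masks from the dataset and run the images through the network to obtain the predicted masks
- #### The network has a single output channel, so this is binary segmentation: its raw outputs are logits, a sigmoid (not a softmax over a class dimension) turns them into per-pixel probabilities, and the loss is binary cross-entropy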

In [None]:
training_data_dict = dict(images = [], masks = [])
testing_data_dict = dict(images = [], masks = [])
validation_data_dict = dict(images = [], masks = [])
# Per-split datasets. Each maps 'images' and 'masks' to parallel lists of
# tensors: the i-th mask belongs to the i-th image. The lists start empty and
# must be filled before sampling.
data_dict = dict(train = training_data_dict, valid = validation_data_dict,
                test = testing_data_dict)


def get_images_masks(split, batch_size):
    """Sample a random batch, with replacement, from one split of ``data_dict``.

    Args:
        split: one of ``'train'``, ``'valid'`` or ``'test'``.
        batch_size: number of samples to draw.

    Returns:
        ``(images, masks)``: the sampled tensors stacked along a new leading
        batch dimension. Image ``k`` and mask ``k`` come from the same sample.

    Raises:
        ValueError: if ``split`` is unknown or the split holds no samples.
    """
    if split not in data_dict:
        raise ValueError(f"split must be one of {sorted(data_dict)}, got {split!r}")
    dataset_dict = data_dict[split]
    n_samples = len(dataset_dict['images'])
    if n_samples == 0:
        # torch.randint(0, 0, ...) would fail with an obscure RuntimeError.
        raise ValueError(f"the {split!r} split is empty; populate data_dict[{split!r}] first")
    # Independent draws, i.e. sampling with replacement. .tolist() gives plain
    # ints for list indexing.
    idx = torch.randint(0, n_samples, (batch_size,)).tolist()
    images = torch.stack([dataset_dict['images'][i] for i in idx], dim = 0)
    masks = torch.stack([dataset_dict['masks'][i] for i in idx], dim = 0)
    return images, masks

In [None]:
# Each iteration draws one random batch, so ``n_epochs`` counts optimisation
# steps rather than full passes over the training set.
n_epochs, steps, losses, batch_size = 100, [], [], 64
optimizer = torch.optim.AdamW(u_net.parameters(), lr = 0.0075)
u_net.train()
# ``tqdm`` is imported as a module, so the progress-bar class is ``tqdm.tqdm``.
for step in tqdm.tqdm(range(n_epochs)):
    images, ground_truth_masks = get_images_masks('train', batch_size)
    # (B, H, W) -> (B, 1, H, W). BCE targets must be floating point.
    # Assumes binary 0/1 masks; TODO confirm against the dataset.
    ground_truth_masks = ground_truth_masks.unsqueeze(1).float()
    # Forward pass: conv_final has no activation, so these are raw logits.
    predicted_masks = u_net(images)
    # Unpadded convolutions make the prediction smaller than the input, so
    # compare it against the matching central region of the mask.
    out_h, out_w = predicted_masks.shape[-2:]
    top = (ground_truth_masks.shape[-2] - out_h) // 2
    left = (ground_truth_masks.shape[-1] - out_w) // 2
    ground_truth_masks = ground_truth_masks[..., top:top + out_h, left:left + out_w]
    optimizer.zero_grad(set_to_none = True)
    # binary_cross_entropy expects probabilities in [0, 1]; the *_with_logits
    # variant applies the sigmoid internally in a numerically stable way.
    loss = F.binary_cross_entropy_with_logits(predicted_masks, ground_truth_masks)
    losses.append(loss.item())
    steps.append(step)

    # Backpropagation
    loss.backward()

    # Gradient update
    optimizer.step()
plt.plot(steps, losses)
plt.xlabel('step')
plt.ylabel('BCE loss')