This notebook was motivated by

[4] Olaf Ronneberger, Philipp Fischer, and Thomas Brox. “U-Net: Convolutional Networks for Biomedical Image Segmentation”.
In: CoRR abs/1505.04597 (2015). arXiv: 1505.04597. URL: http://arxiv.org/abs/1505.04597

Implementation: Oleh Bakumenko, University of Duisburg-Essen

In [1]:
import sys
sys.path.append("../")  # make the parent folder, where the `utility` package lives, importable
import PIL
import matplotlib.pyplot as plt
import os
import time
import torch, torch.nn as nn
import albumentations
from typing import List
from torch.multiprocessing import Manager
torch.multiprocessing.set_sharing_strategy("file_system")

from utility import utils as uu
from utility.eval import evaluate_segmentation_model
from utility.segloss import ExampleSegmentationLoss
from pathlib import Path

from utility.plotImageModel import *
from torchsummary import summary

# Data augmentations

Albumentations is a Python library for image augmentation that applies spatial transformations to the target masks as well as to the image. Commonly used transformations include random cropping, random horizontal or vertical flips, rotation by a given angle, rescaling by a given factor or resizing to a fixed size, brightness and contrast adjustment, Gaussian blur, and so on.
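The key property is that image and mask receive the *same* random geometric transform, otherwise the labels would no longer line up with the pixels. In Albumentations this happens when the composed transform is called as `augments(image=..., mask=...)`, which returns a dict with both entries; pixel-level transforms such as `ColorJitter` only change the image. The sketch below illustrates the principle with plain NumPy; `joint_random_flip` is a hypothetical helper, not Albumentations code.

```python
import numpy as np

def joint_random_flip(image: np.ndarray, mask: np.ndarray, rng: np.random.Generator, p: float = 0.5):
    """Apply the same random vertical/horizontal flips to an image and its mask (illustration only)."""
    if rng.random() < p:  # vertical flip: reverse the rows of both arrays
        image, mask = image[::-1, :], mask[::-1, :]
    if rng.random() < p:  # horizontal flip: reverse the columns of both arrays
        image, mask = image[:, ::-1], mask[:, ::-1]
    return image, mask

rng = np.random.default_rng(0)
image = np.arange(16, dtype=np.float32).reshape(4, 4)
mask = (image > 7).astype(np.uint8)  # toy mask derived from the image intensities
aug_image, aug_mask = joint_random_flip(image, mask, rng)

# The flipped mask still matches the flipped image pixel for pixel
print(np.array_equal(aug_mask, (aug_image > 7).astype(np.uint8)))
```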


In [None]:
cur_path = Path("plots_and_graphs.ipynb")
parent_dir = cur_path.parent.absolute()
masterThesis_folder = str(parent_dir.parent.absolute())+'/'
data_dir = masterThesis_folder+"data/Clean_LiTS/"

augments = albumentations.Compose([
    albumentations.VerticalFlip(p=0.5),
    albumentations.HorizontalFlip(p=0.5),
    albumentations.ColorJitter(brightness=(0.5,1.5), contrast=(1), hue=(-0.1,0.1)),
])

data_augments = augments

# Train, Val, and Test datasets are all contained within this dataset.
# They are selected by calling 'ds.set_mode(selection)'.

# We could also cache any data we read from disk to shared memory, or
# to regular memory, where each dataloader worker caches the entire
# dataset.

cache_me = False
if cache_me is True:
    cache_mgr = Manager()
    cache_mgr.data = cache_mgr.dict()
    cache_mgr.cached = cache_mgr.dict()
    for k in ["train", "val", "test"]:
        cache_mgr.data[k] = cache_mgr.dict()
        cache_mgr.cached[k] = False

# function from utils, credit: Institute for Artificial Intelligence in Medicine.
# url: https://mml.ikim.nrw/

# dataset outputs a tensor image (dimensions [1,256,256]) and a target list
ds = uu.LiTS_Segmentation_Dataset(
    data_dir = data_dir,
    transforms = data_augments,
    verbose = True,
    cache_data = cache_me,
    cache_mgr = (cache_mgr if cache_me is True else None),
    debug = True,
)


### Hyperparameters 

In [None]:
# Default settings
batch_size = 16
learning_rate = 1e-4
weight_decay = 1e-5
epochs = 6
run_name = "UNet"
device = ("cuda" if torch.cuda.is_available() else "cpu")
time_me = True   # print and log timing information
mod_step = 500   # print the training loss every mod_step steps

`torch.utils.data.DataLoader` is a PyTorch utility class that wraps a dataset and takes care of batching, optional reshuffling at every epoch, and parallel loading with several worker processes (`num_workers`), which speeds up data loading during training. We specify the dataset, the batch size (16 here), and whether the data should be shuffled; further parameters such as `pin_memory`, `persistent_workers`, and `prefetch_factor` customize the loading process.

In [None]:
# Dataloader
dl = torch.utils.data.DataLoader(
    dataset = ds, 
    batch_size = batch_size, 
    num_workers = 4, 
    shuffle = True, 
    drop_last = False, 
    pin_memory = True,
    persistent_workers = (not cache_me),
    prefetch_factor = 1
    )
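A minimal sketch of the batching behaviour on a toy `TensorDataset` (the 40 random 8x8 images are an assumption for illustration, not LiTS data): with `batch_size=16` and `drop_last=False`, every sample is seen once per epoch and the final batch is smaller. Such a small last batch is why the training loop below drops the last batch manually, since it can be a problem for BatchNorm.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for the segmentation dataset: 40 single-channel 8x8 "images" with integer ids
images = torch.randn(40, 1, 8, 8)
labels = torch.arange(40)
toy_ds = TensorDataset(images, labels)

toy_dl = DataLoader(toy_ds, batch_size=16, shuffle=True, drop_last=False)
batch_sizes = [x.shape[0] for x, _ in toy_dl]
print(batch_sizes)  # [16, 16, 8]: 40 = 16 + 16 + 8, so the last batch is incomplete
```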

# U-Net

The U-Net model is designed for semantic segmentation, i.e. pixel-level classification of an input image: it produces a segmentation map in which every pixel is assigned a label indicating its class. Here the input dimensions are (B x 1 x 256 x 256) and the output dimensions are (B x 3 x 256 x 256), one channel of raw class scores (logits) per class.
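Turning the (B x 3 x 256 x 256) output into the segmentation map described above can be sketched as follows, using random logits as a stand-in for real network output. The class order 0 = background, 1 = liver, 2 = tumor follows the loss function used later in this notebook.

```python
import torch

# Stand-in for the raw network output: batch of 2 images, 3 classes, 256x256 pixels
logits = torch.randn(2, 3, 256, 256)

# Per-pixel class probabilities; they sum to 1 along the channel dimension
probs = torch.softmax(logits, dim=1)

# Segmentation map: index of the most likely class for every pixel
seg_map = probs.argmax(dim=1)
print(seg_map.shape)  # torch.Size([2, 256, 256])
```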

The main features of the U-Net architecture are:

1. Encoder: The encoder (contracting) path of the U-Net consists of several down-sampling stages. Each stage applies two 3x3 convolutional layers followed by a 2x2 max-pooling operation. Each stage halves the spatial dimensions of the feature maps while increasing the number of channels. This lets the model aggregate information over progressively larger regions and extract increasingly abstract features, at the cost of spatial detail.

2. Decoder: The decoder (expanding) path mirrors the encoder path. Each stage up-samples the feature maps with a transposed convolution, which doubles their spatial dimensions, and then applies convolutional layers. The decoder path recovers the spatial resolution lost during down-sampling and reconstructs a full-resolution segmentation map.

3. Skip Connections: The U-Net architecture connects the encoder and decoder paths with skip connections: at each scale, the encoder feature maps are concatenated along the channel dimension with the up-sampled decoder feature maps of the same spatial size. By fusing high-resolution features from the encoder with the up-sampled, context-rich features from the decoder, the U-Net combines local detail with global context, which leads to accurate segmentation.
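The up-sampling and skip-connection steps can be sketched with standalone layers. The shapes are chosen to match the lowest skip connection of the model defined below for a 256x256 input (encoder features 256 x 64 x 64, bottom features 512 x 32 x 32); the random tensors are placeholders for real feature maps.

```python
import torch
import torch.nn as nn

encoder_features = torch.randn(1, 256, 64, 64)  # saved encoder output (skip connection)
bottom_features = torch.randn(1, 512, 32, 32)   # output of the lowest encoder block

# Transposed convolution with kernel 2 and stride 2: doubles H and W, here halves the channels
up = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
upsampled = up(bottom_features)                 # 1 x 256 x 64 x 64

# Concatenate along the channel dimension: 256 + 256 = 512 channels for the next double conv
merged = torch.cat([encoder_features, upsampled], dim=1)
print(merged.shape)  # torch.Size([1, 512, 64, 64])
```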

Our modifications to the model include the following:

1. The horizontal 3x3 convolutions use padding (`padding=1`), so they preserve the spatial size and the encoder feature maps do not need to be cropped for the skip connections.
2. There are 3 down-sampling steps with corresponding skip connections, rather than 4. Therefore, the maximum number of feature channels, reached in the bottom layer, is 512 (instead of 1024).
3. The final output has 3 channels, as we predict background, liver, and liver tumors.
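The effect of the padding and of the three down-sampling steps follows from the standard output-size rule for convolution and pooling along one axis, out = floor((n + 2p - k) / s) + 1. The sketch below checks it against real PyTorch layers; `conv_out_size` is a helper introduced only for this illustration. It also shows why the input side length should be divisible by 2^3 = 8, otherwise the floor would make the skip-connection shapes disagree.

```python
import torch
import torch.nn as nn

def conv_out_size(n: int, k: int, s: int = 1, p: int = 0) -> int:
    """Output length along one spatial axis: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * p - k) // s + 1

x = torch.randn(1, 1, 256, 256)
padded = nn.Conv2d(1, 8, kernel_size=3, padding=1)(x)    # padding=1 keeps 256x256
unpadded = nn.Conv2d(1, 8, kernel_size=3, padding=0)(x)  # original U-Net style: loses 2 pixels
pooled = nn.MaxPool2d(kernel_size=2)(x)                  # stride defaults to kernel_size: halves H, W

print(padded.shape[-1], unpadded.shape[-1], pooled.shape[-1])  # 256 254 128

# Three down-sampling steps: 256 -> 128 -> 64 -> 32 at the bottom of the network
sizes = [256]
for _ in range(3):
    sizes.append(conv_out_size(sizes[-1], k=2, s=2))
print(sizes)  # [256, 128, 64, 32]
```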

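For each horizontal convolution stage, the same generic double convolution block is used. It consists of two convolutions with a 3x3 kernel, each followed by batch normalization and a ReLU activation. Batch normalization mainly stabilizes and speeds up training, with a mild regularizing effect. The block also applies Kaiming (He) normal initialization to the convolution weights, a scheme designed for ReLU networks that has proven effective in classification networks, and initializes the BatchNorm scale to 0.5 and its shift to 0.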

In [2]:
class Conv_Block(nn.Module):
    """Double convolution block: (3x3 conv -> BatchNorm -> ReLU) x 2; padding=1 preserves H and W."""
    def __init__(self, n_chans_in, n_chans_out):
        super().__init__()
        # bias=False because each convolution is followed by a BatchNorm with its own learnable shift
        self.conv1 = nn.Conv2d(n_chans_in, n_chans_out, kernel_size=3, padding=1, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(num_features=n_chans_out)
        self.relu = nn.ReLU()  # parameter-free, so one instance is reused after both convolutions
        self.conv2 = nn.Conv2d(n_chans_out, n_chans_out, kernel_size=3, padding=1, bias=False)
        self.batch_norm2 = nn.BatchNorm2d(num_features=n_chans_out)

        # Kaiming (He) initialization for the convolutions, BatchNorm scale 0.5 and shift 0
        nn.init.kaiming_normal_(self.conv1.weight, nonlinearity='relu')
        nn.init.kaiming_normal_(self.conv2.weight, nonlinearity='relu')
        nn.init.constant_(self.batch_norm1.weight, 0.5)
        nn.init.zeros_(self.batch_norm1.bias)
        nn.init.constant_(self.batch_norm2.weight, 0.5)
        nn.init.zeros_(self.batch_norm2.bias)

    def forward(self, x):
        out1 = self.relu(self.batch_norm1(self.conv1(x)))
        out = self.relu(self.batch_norm2(self.conv2(out1)))
        return out
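What `kaiming_normal_(..., nonlinearity='relu')` does can be checked on a single layer. In its default `fan_in` mode it draws weights from a normal distribution with standard deviation sqrt(2 / fan_in), where fan_in = in_channels x kernel_h x kernel_w. The sketch uses a 64 to 128 channel convolution as an example; its 73,728 weights match the `Conv2d-10` row in the summary further below. Only the convolution part of the initialization is checked here.

```python
import math
import torch.nn as nn

conv = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)
nn.init.kaiming_normal_(conv.weight, nonlinearity="relu")

# Default mode="fan_in": fan_in = 64 * 3 * 3 = 576
fan_in = conv.in_channels * conv.kernel_size[0] * conv.kernel_size[1]
expected_std = math.sqrt(2.0 / fan_in)  # ReLU gain sqrt(2) divided by sqrt(fan_in)
empirical_std = conv.weight.std().item()

print(conv.weight.numel())  # 73728 = 128 * 64 * 3 * 3
print(f"expected std {expected_std:.4f}, empirical std {empirical_std:.4f}")
```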

In [3]:
class UNetMLMed(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.maxPool = nn.MaxPool2d(kernel_size=2)

        self.encoder_layer1 = Conv_Block(1,64)
        self.crop1 = nn.Identity()
        self.encoder_layer2 = Conv_Block(64,128)
        self.crop2 = nn.Identity()
        self.encoder_layer3 = Conv_Block(128,256)
        self.crop3 = nn.Identity()
        self.encoder_layer4 = Conv_Block(256,512)

        self.decoder_upwards4 = torch.nn.ConvTranspose2d(in_channels=512,out_channels=256,kernel_size=2,stride=2)
        self.decoder_layer3 = Conv_Block(512,256)

        self.decoder_upwards3 = torch.nn.ConvTranspose2d(in_channels=256,out_channels=128,kernel_size=2,stride=2)
        self.decoder_layer2 = Conv_Block(256,128)

        self.decoder_upwards2 = torch.nn.ConvTranspose2d(in_channels=128,out_channels=64 ,kernel_size=2,stride=2)
        self.decoder_layer1 = Conv_Block(128,64)

        self.output_segmentation = nn.Conv2d(in_channels=64, out_channels= 3, kernel_size=1)


    def forward(self, x):
        # a list to save the skip connections
        save = list()
        # encoder part = [Conv_Block() + MaxPool()] x3
        encoder_out1 = self.encoder_layer1(x)
        
        save.append(self.crop1(encoder_out1))
        encoder_in2 = self.maxPool(encoder_out1)
        

        encoder_out2 = self.encoder_layer2(encoder_in2)
        save.append(self.crop2(encoder_out2))
        encoder_in3 = self.maxPool(encoder_out2)
        
        encoder_out3 = self.encoder_layer3(encoder_in3)
        save.append(self.crop3(encoder_out3))
        encoder_in4 = self.maxPool(encoder_out3)
        
        # the bottom of the network (no max-pooling after this block)
        encoder_out4 = self.encoder_layer4(encoder_in4)

        # decoder part = [TransposedConv() + skip connection + Conv_Block()] x3
        upwards4 = self.decoder_upwards4(encoder_out4)
        decoder3_in = torch.concat([save[-1], upwards4], dim=1)  # skip connection
        decoder3_out = self.decoder_layer3(decoder3_in)
        
        upwards3 = self.decoder_upwards3(decoder3_out)
        decoder2_in = torch.concat([save[-2],upwards3], dim=1)# skip connection
        decoder2_out = self.decoder_layer2(decoder2_in)
        
        upwards2 = self.decoder_upwards2(decoder2_out)
        decoder1_in = torch.concat([save[-3],upwards2], dim=1)# skip connection
        decoder1_out = self.decoder_layer1(decoder1_in)

        output = self.output_segmentation(decoder1_out)
        return output


In [5]:
model = UNetMLMed()
summary(model, (1, 256, 256), device="cpu")  # the model is still on the CPU at this point

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]             576
       BatchNorm2d-2         [-1, 64, 256, 256]             128
              ReLU-3         [-1, 64, 256, 256]               0
            Conv2d-4         [-1, 64, 256, 256]          36,864
       BatchNorm2d-5         [-1, 64, 256, 256]             128
              ReLU-6         [-1, 64, 256, 256]               0
        Conv_Block-7         [-1, 64, 256, 256]               0
          Identity-8         [-1, 64, 256, 256]               0
         MaxPool2d-9         [-1, 64, 128, 128]               0
           Conv2d-10        [-1, 128, 128, 128]          73,728
      BatchNorm2d-11        [-1, 128, 128, 128]             256
             ReLU-12        [-1, 128, 128, 128]               0
           Conv2d-13        [-1, 128, 128, 128]         147,456
      BatchNorm2d-14        [-1, 128, 1

In [None]:
model = model.to(device)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate, weight_decay = weight_decay)

# Weighted pixel-wise cross-entropy loss

In [None]:
class CrossEntropyPixelWiseLoss(torch.nn.Module):
    """
        Computes a class-weighted pixel-wise Cross-Entropy Loss
        Inputs:
            predictions: raw logits, Torch tensor of size torch.Size([batch, 3, height, width])
            targets: list of 2 binary masks [liver, cancer], each of size torch.Size([batch, 1, height, width])
        Outputs:
            pw_loss = sum over pixels and classes of (-1) * w_c * target_c * log(softmax(pred)_c),
                      divided by batch*height*width; the background target is 1 - liver
    """
    def __init__(self, classes: int = 3, w_l: torch.Tensor = None):
        super().__init__()
        self.classes = classes
        if w_l is None:
            w_l = torch.ones(self.classes)
        # registered as a buffer so that criterion.to(device) also moves the class weights
        self.register_buffer("weights", w_l.float())

    def forward(self, predictions: torch.Tensor, targets: List[torch.Tensor]):
        batch, _, height, width = predictions.shape

        # log_softmax is numerically more stable than log(softmax(x))
        log_probs = torch.log_softmax(predictions, dim=1)

        # squeeze only the channel dimension, so that a batch of size 1 keeps its batch dimension
        target_liver = targets[0].squeeze(1).float()
        target_cancer = targets[1].squeeze(1).float()
        target_bg = 1.0 - target_liver

        product_bg = -(target_bg * log_probs[:, 0]).sum()
        product_liver = -(target_liver * log_probs[:, 1]).sum()
        product_cancer = -(target_cancer * log_probs[:, 2]).sum()

        pw_loss = (self.weights[0] * product_bg
                   + self.weights[1] * product_liver
                   + self.weights[2] * product_cancer) / (batch * height * width)

        return pw_loss
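Written out, the loss above is

$$\mathcal{L} = -\frac{1}{B\,H\,W}\sum_{b,i,j}\;\sum_{c\in\{\text{bg},\,\text{liver},\,\text{tumor}\}} w_c\, y_{b,c,i,j}\,\log p_{b,c,i,j}, \qquad y_{\text{bg}} = 1 - y_{\text{liver}},$$

with $p$ the softmax of the logits. When every pixel has exactly one active target, for example on a tumor-free slice, this is ordinary class-weighted cross-entropy summed over pixels and divided by the pixel count. The sketch below checks that special case against `torch.nn.functional.cross_entropy`; `weighted_pixelwise_ce` is a hypothetical functional restatement that takes already squeezed B x H x W masks, and the random data are placeholders.

```python
import torch
import torch.nn.functional as F

def weighted_pixelwise_ce(logits, liver, cancer, weights):
    """Functional restatement of the pixel-wise loss; background target = 1 - liver."""
    b, _, h, w = logits.shape
    log_p = torch.log_softmax(logits, dim=1)
    targets = torch.stack([1 - liver, liver, cancer], dim=1)  # B x 3 x H x W
    per_class = -(targets * log_p).sum(dim=(0, 2, 3))        # summed loss per class
    return (weights * per_class).sum() / (b * h * w)

torch.manual_seed(0)
logits = torch.randn(2, 3, 8, 8)
liver = (torch.rand(2, 8, 8) > 0.5).float()  # random toy liver mask
cancer = torch.zeros(2, 8, 8)                 # tumor-free slices
weights = torch.tensor([1.0, 3.0, 10.0])

ours = weighted_pixelwise_ce(logits, liver, cancer, weights)
labels = liver.long()  # 0 = background, 1 = liver
reference = F.cross_entropy(logits, labels, weight=weights, reduction="sum") / labels.numel()
print(torch.allclose(ours, reference))  # True
```

Note that `F.cross_entropy` with `weight` and the default `reduction="mean"` divides by the sum of the target weights rather than by the pixel count, so the two would differ in that setting.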


In [None]:
criterion = CrossEntropyPixelWiseLoss(w_l = torch.Tensor([1, 3, 10])).to(device)

In [None]:
if time_me is True:
    c_start = time.time()

num_steps = len(ds.file_names['train'])//batch_size

for epoch in range(epochs):
    
    # If we are caching, we now have all data and let the (potentially non-persistent) workers know
    if cache_me is True and epoch > 0:
        dl.dataset.set_cached("train")
        dl.dataset.set_cached("val")
    
    # Time me
    if time_me is True:
        e_start = time.time()

    # Go to train mode
    ds.set_mode("train")
    model.train()

    # Train loop
    for step, (data, targets) in enumerate(dl):

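        # Manually drop the last, incomplete batch (this is for example relevant with BatchNorm).
        # Full batches have indices 0 .. num_steps-1, so an incomplete batch, if any, has index num_steps.
        # It is not skipped in the first epoch when caching, so that every file is read once.
        if step == num_steps and (epoch > 0 or ds.cache_data is False):
            continue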

        # Train loop: Zero gradients, forward step, evaluate, log, backward step
        optimizer.zero_grad()
        data = data.to(device)
        targets = [target.to(device) for target in targets]
        if time_me is True:
            c_end = time.time()
            if step % 20 == 0:
                print(f"CPU time: {c_end-c_start:.4f}s")
            g_start = time.time()
        predictions = model(data)
        if time_me is True:
            # CUDA kernels run asynchronously; wait for them so that the GPU time is meaningful
            if device == "cuda":
                torch.cuda.synchronize()
            g_end = time.time()
            c_start = time.time()
        if step % 20 == 0 and time_me is True:
            print(f"GPU time: {g_end-g_start:.4f}s")
        loss = criterion(predictions, targets)
        if step % mod_step == 0:
            print(f"Epoch [{epoch+1}/{epochs}]\t Step [{step+1}/{num_steps}]\t Train Loss: {loss.item():.4f}")
        uu.csv_logger(
            logfile = f"../logs/{run_name}_train.csv",
            content = {"epoch": epoch, "step": step, "loss": loss.item()},
            first = (epoch == 0 and step == 0),
            overwrite = (epoch == 0 and step == 0)
                )
        loss.backward()
        optimizer.step()

    # Go to eval mode
    ds.set_mode("val")
    model.eval()

    # Validation loop
    metrics = {"epoch": epoch}
    metrics.update(evaluate_segmentation_model(model = model, dataloader = dl, device = device))
    print('\n'.join([f'{m}: {v}' for m, v in metrics.items() if not m.startswith("#")]))
    uu.csv_logger(
        logfile = f"../logs/{run_name}_val.csv",
        content = {m: v for m, v in metrics.items() if not m.startswith("#")},
        first = (epoch == 0),
        overwrite = (epoch == 0)
            )
        
    if time_me is True:
        cur_time = time.time()-e_start        
        uu.csv_logger(
            logfile = f"../logs/{run_name}_runtime.csv",
            content = {"epoch": epoch, "time": cur_time},
            first = (epoch == 0),
            overwrite = (epoch == 0)
                )
        print(f"Epoch nr {epoch+1} time: {time.time()-e_start:.4f}s")

# Finally, test time
ds.set_mode("test")
model.eval()
metrics = evaluate_segmentation_model(model = model, dataloader = dl, device = device)
print("Test-time metrics:")
print('\n'.join([f'{m}: {v}' for m, v in metrics.items() if not m.startswith("#")]))
uu.csv_logger(
    logfile = f"../logs/{run_name}_test.csv",
    content = {m: v for m, v in metrics.items() if not m.startswith("#")},
    first = True,
    overwrite = True
        )