<a href="https://colab.research.google.com/github/s183796/AIStudentProjects/blob/final_christine/2d_unet_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import matplotlib
import matplotlib.pyplot as plt
from IPython.display import Image, display, clear_output
import numpy as np
%matplotlib nbagg
%matplotlib inline
import seaborn as sns
import pandas as pd

import torch
import torch.nn as nn
from torchvision import models
from torch.nn.functional import relu
from torch.nn.functional import softmax
import PIL.Image
import os
import torchvision
import cv2

from torchvision import transforms
from sklearn.model_selection import train_test_split

from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, Dataset, Subset

import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms.functional as TF
import glob
from sklearn.model_selection import train_test_split

from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [2]:
pip install albumentations



In [3]:
import albumentations as A
from albumentations.pytorch import ToTensorV2


In [4]:
# Source: https://towardsdatascience.com/cook-your-first-u-net-in-pytorch-b3297a844cf3, visited the 16th of November 2023
# Modifications have been made to the original code by adding batch normalization and dropout, and by changing the input and output sizes

class UNet(nn.Module):
    """2D U-Net for semantic segmentation with a batch-normalised encoder.

    The encoder has five levels of two 3x3 convolutions (each followed by
    batch normalisation and ReLU) separated by 2x2 max pooling. The decoder
    mirrors it with four transposed convolutions, each followed by a skip
    connection (channel-wise concatenation with the matching encoder output)
    and two 3x3 convolutions with ReLU but without batch normalisation.
    A single shared ``nn.Dropout(0.5)`` is applied after the second, third and
    fourth pooling steps and after the first two up-convolutions.

    Because the input is halved four times and the skip connections are
    concatenated without cropping, the input height and width must both be
    divisible by 16.

    Args:
        n_class: Number of output channels (one per segmentation class).
        in_channels: Number of channels of the input image. Defaults to 1
            (grayscale), which is what the original implementation hard-coded.
    """

    def __init__(self, n_class, in_channels=1):
        super().__init__()

        # NOTE: attribute names and creation order are kept as in the original
        # implementation so that saved weights and seeded initialisation stay
        # compatible.

        # Encoder (spatial sizes in the comments assume a 128x128 input)
        self.e11 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.bn11 = nn.BatchNorm2d(64)
        self.e12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn12 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 64x64x64

        self.e21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn21 = nn.BatchNorm2d(128)
        self.e22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn22 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 32x32x128

        self.e31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn31 = nn.BatchNorm2d(256)
        self.e32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn32 = nn.BatchNorm2d(256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # 16x16x256

        self.e41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn41 = nn.BatchNorm2d(512)
        self.e42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn42 = nn.BatchNorm2d(512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)  # 8x8x512

        self.e51 = nn.Conv2d(512, 1024, kernel_size=3, padding=1)
        self.bn51 = nn.BatchNorm2d(1024)
        self.e52 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)
        self.bn52 = nn.BatchNorm2d(1024)

        # One dropout module, reused at several points of the forward pass.
        self.dropout = nn.Dropout(0.5)

        # Decoder: the first convolution of every level takes twice the
        # channels because the skip connection is concatenated to the
        # up-sampled features.
        self.upconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.d11 = nn.Conv2d(1024, 512, kernel_size=3, padding=1)
        self.d12 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

        self.upconv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.d21 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.d22 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.d31 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.d32 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        self.upconv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.d41 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.d42 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        # 1x1 convolution mapping the 64 feature channels to class scores.
        self.outconv = nn.Conv2d(64, n_class, kernel_size=1)

    def forward(self, x):
        """Compute per-pixel class scores.

        Args:
            x: Tensor of shape (batch, in_channels, height, width) with height
                and width divisible by 16.

        Returns:
            Tensor of shape (batch, n_class, height, width) holding raw,
            un-normalised scores (logits). No softmax is applied here; apply
            it (or an argmax) outside the model when probabilities or labels
            are needed.
        """
        # Encoder
        xe11 = F.relu(self.bn11(self.e11(x)))
        xe12 = F.relu(self.bn12(self.e12(xe11)))
        xp1 = self.pool1(xe12)

        xe21 = F.relu(self.bn21(self.e21(xp1)))
        xe22 = F.relu(self.bn22(self.e22(xe21)))
        xp2 = self.dropout(self.pool2(xe22))

        xe31 = F.relu(self.bn31(self.e31(xp2)))
        xe32 = F.relu(self.bn32(self.e32(xe31)))
        xp3 = self.dropout(self.pool3(xe32))

        xe41 = F.relu(self.bn41(self.e41(xp3)))
        xe42 = F.relu(self.bn42(self.e42(xe41)))
        xp4 = self.dropout(self.pool4(xe42))

        # Bottleneck
        xe51 = F.relu(self.bn51(self.e51(xp4)))
        xe52 = F.relu(self.bn52(self.e52(xe51)))

        # Decoder level 1 (skip connection from encoder level 4)
        xup1 = self.dropout(self.upconv1(xe52))
        xcat = torch.cat([xup1, xe42], dim=1)
        xup21 = F.relu(self.d11(xcat))
        xup22 = F.relu(self.d12(xup21))

        # Decoder level 2 (skip connection from encoder level 3)
        xup2 = self.dropout(self.upconv2(xup22))
        xcat2 = torch.cat([xup2, xe32], dim=1)
        xup31 = F.relu(self.d21(xcat2))
        xup32 = F.relu(self.d22(xup31))

        # Decoder level 3 (skip connection from encoder level 2, no dropout)
        xup3 = self.upconv3(xup32)
        xcat3 = torch.cat([xup3, xe22], dim=1)
        xup41 = F.relu(self.d31(xcat3))
        xup42 = F.relu(self.d32(xup41))

        # Decoder level 4 (skip connection from encoder level 1, no dropout)
        xup4 = self.upconv4(xup42)
        xcat4 = torch.cat([xup4, xe12], dim=1)
        xup51 = F.relu(self.d41(xcat4))
        xup52 = F.relu(self.d42(xup51))

        return self.outconv(xup52)

In [5]:
#Setting up the loss function (based on exercise week 6)
#nn.CrossEntropyLoss takes the raw class scores returned by the network together
#with integer class-index targets, so no softmax is applied before the loss.
loss_fn =  nn.CrossEntropyLoss()   #Choosing cross entropy loss

import random


In [6]:
#65% of the 300 slice indices are used for training.
#The remaining 35% is split in half into a test and a validation set.
test_size=0.35
batch_size = 16

#First split: training indices vs. held-out indices
training_idx, holdout_idx = train_test_split(range(300), test_size=test_size, random_state=42)

#Second split: the held-out indices are divided 50/50 into test and validation
test_idx, val_idx = train_test_split(holdout_idx, test_size=0.5, random_state=42)


In [7]:
#Gather the training images together
#Image files are named 'SOCprist' + 4-digit zero-padded index + '.tiff' and label
#files 'slice__' + 3-digit zero-padded index + '.tif'.
#The zero padding is done by the format specification, so every index yields
#exactly one image name and one label name. The earlier if/if/else chain also
#sent single-digit indices into the else branch and appended a second,
#non-existing name such as 'SOCprist05.tiff'.
training_file_names = [f'SOCprist{idx:04d}.tiff' for idx in training_idx]
label_train_names = [f'slice__{idx:03d}.tif' for idx in training_idx]

In [8]:
#Gather the test images together
#Image files are named 'SOCprist' + 4-digit zero-padded index + '.tiff' and label
#files 'slice__' + 3-digit zero-padded index + '.tif'.
#The zero padding is done by the format specification, so every index yields
#exactly one image name and one label name. The earlier if/if/else chain also
#sent single-digit indices into the else branch and appended a second,
#non-existing name such as 'SOCprist05.tiff'.
test_file_names = [f'SOCprist{idx:04d}.tiff' for idx in test_idx]
label_test_names = [f'slice__{idx:03d}.tif' for idx in test_idx]

In [9]:
#Gather the validation images together
#Image files are named 'SOCprist' + 4-digit zero-padded index + '.tiff' and label
#files 'slice__' + 3-digit zero-padded index + '.tif'.
#The zero padding is done by the format specification, so every index yields
#exactly one image name and one label name. The earlier if/if/else chain also
#sent single-digit indices into the else branch and appended a second,
#non-existing name such as 'SOCprist05.tiff'.
val_file_names = [f'SOCprist{idx:04d}.tiff' for idx in val_idx]
label_val_names = [f'slice__{idx:03d}.tif' for idx in val_idx]

In [10]:
#Creating the dataset
class Trainingdataset(Dataset):
    """Grayscale image slices paired with their segmentation label images.

    Images are read from ``<root_dir>/data/`` and labels from
    ``<root_dir>/labels/``. The label belonging to an image is found through
    the last three characters of the image file name (without extension):
    ``...NNN.tiff`` is paired with ``slice__NNN.tif``.

    Args:
        root_dir: Directory containing the ``data/`` and ``labels/`` folders.
        transform: Optional callable invoked as
            ``transform(image=image, mask=label)`` that returns a mapping with
            the keys ``"image"`` and ``"mask"``. When ``None`` the arrays read
            from disk are returned unchanged.
        file_names: Image file names to include. ``None`` includes every file
            found in the image folder.
        label_names: Label file names to include. ``None`` includes every file
            found in the label folder.
    """

    def __init__(self, root_dir, transform=None,file_names=None,label_names=None):
        self.root_dir = root_dir
        self.image_folder = os.path.join(root_dir, 'data/')
        self.label_folder = os.path.join(root_dir, 'labels/')
        self.transform = transform
        self.file_names = file_names
        self.label_names = label_names

        self.image_filenames = self._list_selected(self.image_folder, file_names)
        self.label_filenames = self._list_selected(self.label_folder, label_names)

    @staticmethod
    def _list_selected(folder, wanted):
        """Return the sorted names in ``folder`` that are listed in ``wanted``."""
        present = os.listdir(folder)
        if wanted is None:
            return sorted(present)
        # A set makes every membership test O(1) instead of scanning the list.
        wanted = set(wanted)
        return sorted(name for name in present if name in wanted)

    def __len__(self):
        """Return the number of selected images."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load the image/label pair at position ``idx``.

        Returns:
            ``(image, label)``: the transformed pair when a transform is set,
            otherwise the arrays as read from disk.

        Raises:
            OSError: If the image or its label cannot be read.
        """
        image_filename = self.image_filenames[idx]
        img_name = os.path.join(self.image_folder, image_filename)

        # The three digits right before the extension identify the slice and
        # make sure the label fits the image.
        slice_id = os.path.splitext(image_filename)[0][-3:]
        label_name = os.path.join(self.label_folder, 'slice__' + slice_id + '.tif')

        image = cv2.imread(img_name, cv2.IMREAD_GRAYSCALE)
        label = cv2.imread(label_name, cv2.IMREAD_GRAYSCALE)

        # cv2.imread reports failure by returning None instead of raising.
        if image is None:
            raise OSError(f"Could not read image file: {img_name}")
        if label is None:
            raise OSError(f"Could not read label file: {label_name}")

        if self.transform is None:
            return image, label

        # Inspired by: https://www.youtube.com/watch?v=rAdLwKJBvPM, visited the November 23rd 2023
        # Image and mask go through the transform together so that spatial
        # augmentations stay aligned.
        augmentations = self.transform(image=image, mask=label)
        return augmentations["image"], augmentations["mask"]

In [11]:
#Defining the transformations

def _make_normalize():
    #Normalization shared by both pipelines: mean 0.5, std 0.5, 8-bit pixel range
    return A.Normalize(mean=[0.5], std=[0.5], max_pixel_value=255.0)


#Pipeline used for the training and validation datasets
transform = A.Compose([
    A.RandomCrop(width=128, height=128),  #random 128x128 crop
    A.GaussNoise(p=0.05),                 #Gaussian noise, applied with probability 0.05
    _make_normalize(),
    ToTensorV2(),
])

#Pipeline used for the test dataset: normalization and tensor conversion only
transform_test = A.Compose([
    _make_normalize(),
    ToTensorV2(),
])



In [12]:
#All three datasets read from the same Drive folder and differ only in the
#selected files and the transform pipeline.
DATA_ROOT = 'drive/My Drive//AI data/'

SOC_dataset_train = Trainingdataset(
    root_dir=DATA_ROOT,
    transform=transform,
    file_names=training_file_names,
    label_names=label_train_names,
)
SOC_dataset_test = Trainingdataset(
    root_dir=DATA_ROOT,
    transform=transform_test,
    file_names=test_file_names,
    label_names=label_test_names,
)
SOC_dataset_val = Trainingdataset(
    root_dir=DATA_ROOT,
    transform=transform,
    file_names=val_file_names,
    label_names=label_val_names,
)

In [13]:
pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.2.1-py3-none-any.whl (806 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/806.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m112.6/806.1 kB[0m [31m3.1 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m806.1/806.1 kB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.10.0-py3-none-any.whl (24 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.10.0 torchmetrics-1.2.1


In [14]:
#Importing accuracy metrics
from torchmetrics.classification import JaccardIndex
from torchmetrics.functional.classification import dice
from torchmetrics.classification import MulticlassAccuracy

In [15]:
batch_size = 16 #Number of samples per batch

#Test and validation loaders keep the dataset order and drop the last incomplete batch
eval_loader_options = dict(batch_size=batch_size, drop_last=True)

#Training batches are reshuffled every epoch and the last partial batch is kept
train_loader = DataLoader(SOC_dataset_train, batch_size=batch_size, shuffle=True, drop_last=False)
test_loader = DataLoader(SOC_dataset_test, **eval_loader_options)
val_loader = DataLoader(SOC_dataset_val, **eval_loader_options)


In [17]:
#Parts of this code section are from exercise 4.2-EXE-CNN-CIFAR-10.ipynb
NUM_CLASSES = 3


def _find_label_values(dataset, num_classes):
    """Return the sorted gray values used by the label images of ``dataset``.

    The label files listed in ``dataset.label_filenames`` are read once from
    ``dataset.label_folder`` so that the gray-value -> class-index mapping is
    the same for every batch.

    Raises:
        OSError: If a label file cannot be read.
        ValueError: If the number of distinct gray values differs from
            ``num_classes``.
    """
    values = set()
    for name in dataset.label_filenames:
        label_path = os.path.join(dataset.label_folder, name)
        mask = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f"Could not read label image: {label_path}")
        values.update(np.unique(mask).tolist())
    if len(values) != num_classes:
        raise ValueError(
            f"Expected {num_classes} distinct label values, found {sorted(values)}"
        )
    return torch.tensor(sorted(values))


def _to_class_indices(targets, label_values):
    """Map raw mask gray values to class indices 0..len(label_values)-1.

    Each pixel is compared against every entry of ``label_values``; the
    position of the matching entry is the class index.
    """
    lookup = label_values.to(device=targets.device, dtype=targets.dtype)
    return (targets.unsqueeze(-1) == lookup).to(torch.int64).argmax(dim=-1)


net=UNet(n_class=NUM_CLASSES)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)

optimizer = optim.Adam(net.parameters(), lr=0.0001) #choosing the Adam optimizer

#Defining the different accuracy metrics
jaccard=JaccardIndex(task="multiclass", num_classes=NUM_CLASSES).to(device) #jaccard
accuracy=MulticlassAccuracy(num_classes=NUM_CLASSES).to(device) #pixel wise
num_epochs = 50
validation_every_steps = len(train_loader) #Validate once per pass over the training set

#Gray values of the three classes, in ascending order. The mapping is computed
#once for the whole dataset: deriving it from targets.unique() of each batch
#fails (IndexError) or silently assigns wrong class indices whenever a batch of
#random crops does not contain all three classes.
label_values = _find_label_values(SOC_dataset_train, NUM_CLASSES)

step = 0
net.train()

#Allocating lists for the per-epoch accuracy and loss measures
train_accuracies_jaccard = []
train_accuracies_dice = []
train_accuracies_pixel = []
valid_accuracies_jaccard = []
valid_accuracies_dice = []
valid_accuracies_pixel = []
loss_train=[]
val_loss=[]

#Per-batch values collected since the last validation
train_accuracies_batches_jaccard = []
train_accuracies_batches_dice = []
train_accuracies_batches_pixel = []
loss_epochs=[]

for epoch in range(num_epochs):
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        #Convert the mask gray values into the class indices [0,1,2]
        targets = _to_class_indices(targets, label_values)

        #Forward pass and loss
        output = net(inputs)
        loss = loss_fn(output, targets)

        #Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step += 1

        #Predicted class = highest score. Softmax is monotonic, so the argmax of
        #the raw scores equals the argmax of the probabilities.
        predictions = output.argmax(dim=1)
        train_accuracies_batches_dice.append(dice(predictions,targets).item())
        train_accuracies_batches_jaccard.append(jaccard(predictions,targets).item())
        train_accuracies_batches_pixel.append(accuracy(predictions,targets).item())
        loss_epochs.append(loss.item())

        if step % validation_every_steps == 0:

            # Append average training accuracy and loss to the lists.
            train_accuracies_jaccard.append(np.mean(train_accuracies_batches_jaccard))
            train_accuracies_dice.append(np.mean(train_accuracies_batches_dice))
            train_accuracies_pixel.append(np.mean(train_accuracies_batches_pixel))
            loss_train.append(np.mean(loss_epochs))

            train_accuracies_batches_jaccard = []
            train_accuracies_batches_dice = []
            train_accuracies_batches_pixel = []
            loss_epochs=[]

            # Compute accuracies on the validation set.
            # NOTE(review): SOC_dataset_val is built with the training transform,
            # so validation runs on random crops with occasional Gaussian noise
            # and its metrics vary from run to run.
            valid_accuracies_batches_jaccard = []
            valid_accuracies_batches_dice = []
            valid_accuracies_batches_pixel = []
            val_losses=[]

            net.eval() #Disable dropout and use the BatchNorm running statistics
            with torch.no_grad():
                for inputs, targets in val_loader:
                    inputs, targets = inputs.to(device), targets.to(device)
                    targets = _to_class_indices(targets, label_values)

                    output = net(inputs)
                    loss = loss_fn(output, targets)

                    predictions = output.argmax(dim=1)
                    valid_accuracies_batches_dice.append(dice(predictions,targets).item())
                    valid_accuracies_batches_jaccard.append(jaccard(predictions,targets).item())
                    valid_accuracies_batches_pixel.append(accuracy(predictions,targets).item())
                    val_losses.append(loss.item())
            net.train()

            # Append average validation accuracy and loss to the lists.
            valid_accuracies_jaccard.append(np.mean(valid_accuracies_batches_jaccard))
            valid_accuracies_dice.append(np.mean(valid_accuracies_batches_dice))
            valid_accuracies_pixel.append(np.mean(valid_accuracies_batches_pixel))
            val_loss.append(np.mean(val_losses))

            print(f"Step {step:<5}   training accuracy with jaccard: {train_accuracies_jaccard[-1]}")
            print(f"             training accuracy with dice: {train_accuracies_dice[-1]}")
            print(f"             training accuracy with pixel by pixel: {train_accuracies_pixel[-1]}")
            print(f"             validation accuracy with jaccard: {valid_accuracies_jaccard[-1]}")
            print(f"             validation accuracy with dice: {valid_accuracies_dice[-1]}")
            print(f"             validation accuracy with pixel by pixel: {valid_accuracies_pixel[-1]}")
            print(f"             loss in validation: {val_loss[-1]}")


print("Finished training.")

Step 13      training accuracy with jaccard: 0.5649235248565674
             training accuracy with dice: 0.6819261312484741
             training accuracy with pixel by pixel: 0.7426443696022034
             validation accuracy with jaccard: 0.314875324567159
             validation accuracy with dice: 0.530609130859375
             validation accuracy with pixel by pixel: 0.4886625607808431
             loss in validation: 1.0232611894607544
Step 26      training accuracy with jaccard: 0.8882465958595276
             training accuracy with dice: 0.9393175840377808
             training accuracy with pixel by pixel: 0.9372642040252686
             validation accuracy with jaccard: 0.27284151315689087
             validation accuracy with dice: 0.5155359903971354
             validation accuracy with pixel by pixel: 0.4509158134460449
             loss in validation: 2.83874773979187
Step 39      training accuracy with jaccard: 0.9304141402244568
             training accuracy with dic

In [22]:
#Saving the model
#Only the parameters (state_dict) are written to disk. This keeps the file
#independent of how the UNet class is pickled, and it loads under the
#restricted default of newer torch.load versions.
torch.save(net.state_dict(),'./2dunet.pth')

#Rebuild the architecture and load the saved parameters into it
model = UNet(n_class=3)
model.load_state_dict(torch.load('./2dunet.pth', map_location=device))
model.to(device)
#The reloaded model is only used for inference: switch off dropout and use the
#BatchNorm running statistics.
model.eval()

In [50]:
test_dice=0
test_jaccard=0
test_pixel=0
n_test_images=0 #Number of test images that were actually evaluated

#Patch size and stride used to tile the full-size test images.
#Equal values give non-overlapping tiles.
kernel_size = 128
stride = 128

#Inference only: disable dropout and use the BatchNorm running statistics
model.eval()

#Finding the test accuracy for all test images
with torch.no_grad():
  for inputs, targets in test_loader:
    inputs_val, targets = inputs.to(device), targets.to(device)

    #Divide targets into classes: [0,1,2]
    #NOTE(review): the mapping relies on the batch containing all three label values
    un_target=targets.unique()
    targets[targets==un_target[0]]=0
    targets[targets==un_target[1]]=1
    targets[targets==un_target[2]]=2

    targets_val = targets.to(torch.int64)

    #Padding needed to reach the next multiple of the kernel size.
    #The same padding is applied to both axes, i.e. square images are assumed.
    image_size=inputs_val.size(-1)
    total_pad=int(kernel_size*np.ceil(image_size/kernel_size)-image_size)
    pad_min=total_pad//2
    pad_max=total_pad-pad_min
    #Region of the padded image that holds the original pixels. Explicit end
    #index: a slice ending in "-pad_max" would be empty when pad_max is 0.
    crop=slice(pad_min,pad_min+image_size)

    inputs_val=torch.nn.functional.pad(inputs_val,pad=(pad_min,pad_max,pad_min,pad_max))
    targets_val=torch.nn.functional.pad(targets_val,pad=(pad_min,pad_max,pad_min,pad_max))

    #Source: https://discuss.pytorch.org/t/how-to-split-tensors-with-overlap-and-then-reconstruct-the-original-tensor/70261/7?fbclid=IwAR1rdUAuDnUpVm2OwmXRaFo-l2AMLJ1RLn5bJEp6f1JcU7wR5CHpugMHc6Y
    #Visited the 23rd of November 2023
    B, C, W, H = inputs_val.size(0), inputs_val.size(1), inputs_val.size(2), inputs_val.size(3)

    #Splitting target and test images into kernels.
    #images_split_1: (B, C, patch row, patch column, kernel_size, kernel_size)
    #targets_split:  (B, patch row, patch column, kernel_size, kernel_size)
    images_split_1 = inputs_val.unfold(3, kernel_size, stride).unfold(2, kernel_size, stride).permute(0,1,2,3,5,4)
    targets_split=targets_val.unfold(2, kernel_size, stride).unfold(1, kernel_size, stride).permute(0,1,2,4,3)

    preds=torch.empty(images_split_1.size(0),images_split_1.size(2),images_split_1.size(3),images_split_1.size(-1), images_split_1.size(-1))
    tar=targets_split

    #Predict every patch separately
    for i in range(images_split_1.size(2)):
      for j in range(images_split_1.size(3)):
        output_val = model(images_split_1[:,:,i,j,:,:])
        predicted_val = output_val.argmax(dim=1) #class with the highest score

        preds[:,i,j,:,:]=predicted_val

    #Patching images back together (same source as above)
    patches = preds.contiguous().view(B, -1, kernel_size*kernel_size)
    patches = patches.permute(0,2,1)
    patches = patches.contiguous().view(B, C*kernel_size*kernel_size, -1)

    output = torch.nn.functional.fold(
        patches, output_size=(H, W), kernel_size=kernel_size, stride=stride)
    output = output.to(torch.int64)

    #Score every image on its original (unpadded) region
    for i in range(output.size(0)):
      prediction_i=output[i,0,crop,crop]
      target_i=targets_val[i,crop,crop]
      test_dice+=dice(prediction_i,target_i.cpu())
      test_jaccard+=jaccard(prediction_i.to(device),target_i)
      test_pixel+=accuracy(prediction_i.to(device),target_i)
      n_test_images+=1

if n_test_images==0:
  raise ValueError("test_loader produced no batches, so no test accuracy can be computed")

#Final test accuracies: average over the evaluated test images
tdice=test_dice/n_test_images
tjaccard=test_jaccard/n_test_images
tpixel=test_pixel/n_test_images


In [None]:
preds=torch.empty(images_split_1.size(0),images_split_1.size(2),images_split_1.size(3),images_split_1.size(-1), images_split_1.size(-1))
tar=targets_split

import matplotlib.pylab as pylab
params = {'legend.fontsize': 'x-large',
          'figure.figsize': (15, 5),
         'axes.labelsize': 'x-large',
         'axes.titlesize':'x-large',
         'xtick.labelsize':'x-large',
         'ytick.labelsize':'x-large'}
pylab.rcParams.update(params)


def _show_panel(ax, image, title, colorbar_label, ticks=None, clim=None):
    """Draw ``image`` on ``ax`` with a title, pixel axis labels and a colorbar."""
    im = ax.imshow(image)
    ax.set_title(title)
    ax.set_xlabel('Number of pixels [#]')
    ax.set_ylabel('Number of pixels [#]')
    plt.colorbar(im, ax=ax, ticks=ticks, label=colorbar_label, pad=0.1)
    if clim is not None:
        im.set_clim(*clim)


#Inference only: disable dropout and use the BatchNorm running statistics.
#Without this the plotted predictions change from run to run.
net.eval()

#Plotting the label probabilities, predicted labels and difference between labels for 128x128 kernels of the test images.
#Only the first image of the batch is shown; predictions of the whole batch are stored in preds.
with torch.no_grad():
  for i in range(images_split_1.size(2)):
    for j in range(images_split_1.size(3)):
      patch = images_split_1[:,:,i,j,:,:]
      output_val = net(patch)
      probabilities = softmax(output_val,dim=1)
      predicted_val = probabilities.argmax(dim=1)

      #Arrays of the first image in the batch
      probs_np = probabilities[0].cpu().numpy()
      target_np = tar[0,i,j,:,:].cpu().numpy()
      predicted_np = predicted_val[0].cpu().numpy()
      input_np = patch[0,0,:,:].cpu().numpy()

      #Figure 1: probability map of every class next to the target labels
      fig,axs = plt.subplots(2,2,figsize=(15,10),dpi=200)
      _show_panel(axs[0,0], probs_np[0], 'Label 1, probability', 'Probability', clim=(0,1))
      _show_panel(axs[1,0], probs_np[1], 'Label 2, probability', 'Probability', clim=(0,1))
      _show_panel(axs[0,1], probs_np[2], 'Label 3, probability', 'Probability', clim=(0,1))
      _show_panel(axs[1,1], target_np, 'Target labels', 'Label number', ticks=[0,1,2])
      plt.tight_layout()
      plt.show()

      #Figure 2: input, predicted labels, target labels and their absolute difference
      fig,axs=plt.subplots(2,2,figsize=(15,10),dpi=200)
      _show_panel(axs[1,0], predicted_np, 'Predicted labels', 'Label number', ticks=[0,1,2])
      _show_panel(axs[1,1], target_np, 'Target labels', 'Label number', ticks=[0,1,2])
      _show_panel(axs[0,0], input_np, 'Original image', 'Normalized intensities')
      _show_panel(axs[0,1], np.abs(target_np-predicted_np), 'Difference in labels',
                  'Absolute difference in labels', ticks=[0,1,2])
      plt.tight_layout()
      plt.show()

      preds[:,i,j,:,:]=predicted_val

In [None]:
#Source: https://discuss.pytorch.org/t/how-to-split-tensors-with-overlap-and-then-reconstruct-the-original-tensor/70261/7?fbclid=IwAR1rdUAuDnUpVm2OwmXRaFo-l2AMLJ1RLn5bJEp6f1JcU7wR5CHpugMHc6Y
#Visited the 23rd of November 2023
#Stitch the per-patch predictions back into full-size label maps
patch_area = kernel_size*kernel_size

#One row per patch, then move the patch index to the last dimension as fold expects
flat_patches = preds.contiguous().view(B, -1, patch_area)
patches = flat_patches.permute(0,2,1).contiguous().view(B, C*patch_area, -1)

output = F.fold(patches, output_size=(H, W), kernel_size=kernel_size, stride=stride)
print(output.shape) # [B, C, H, W]


In [None]:
#Example of 1 image that is patched together (used in the poster)
import matplotlib.pylab as pylab
params = {key: 'x-large' for key in ('legend.fontsize', 'axes.labelsize', 'axes.titlesize',
                                     'xtick.labelsize', 'ytick.labelsize')}
params['figure.figsize'] = (15, 5)
pylab.rcParams.update(params)

#First image of the batch: prediction, target and their absolute difference
predicted_map = output[0,0,:,:].detach().cpu()
target_map = targets_val[0,:,:].detach().cpu()
difference_map = np.abs(targets_val[0,:,:].detach().cpu().numpy()-output[0,0,:,:].detach().cpu().numpy())

#One figure per map: (data, title, colorbar label)
figure_specs = (
    (predicted_map, 'Predicted labels', 'Label number'),
    (target_map, 'Target labels', 'Label number'),
    (difference_map, 'Prediction vs. targets', 'Absolute label difference'),
)

for label_map, figure_title, colorbar_label in figure_specs:
    plt.figure(dpi=200)
    plt.imshow(label_map)
    plt.colorbar(ticks=[0,1,2],label=colorbar_label)
    plt.title(figure_title)
    plt.xlabel('Number of pixels [#]')
    plt.ylabel('Number of pixels [#]')


In [None]:
#Plots of training, validation and test accuracies and loss function
import matplotlib.pylab as pylab
params = {key: 'x-large' for key in ('legend.fontsize', 'axes.labelsize', 'axes.titlesize',
                                     'xtick.labelsize', 'ytick.labelsize')}
params['figure.figsize'] = (15, 5)
pylab.rcParams.update(params)

#Curves in the order dice, jaccard, pixel-wise
train_curves = (train_accuracies_dice, train_accuracies_jaccard, train_accuracies_pixel)
valid_curves = (valid_accuracies_dice, valid_accuracies_jaccard, valid_accuracies_pixel)
metric_styles = (('Dice', 'red'), ('Jaccard', 'blue'), ('Pixel-wise', 'green'))

#Figure 1: the three metrics side by side for training (left) and validation (right)
fig,axs=plt.subplots(1,2,figsize=(15,5))
train_ax, valid_ax = axs

for curve, (metric_name, metric_color) in zip(train_curves, metric_styles):
    train_ax.plot(curve,color=metric_color,label=metric_name)
train_ax.grid()
train_ax.legend(loc='lower right')
train_ax.set_title('Training accuracies')
train_ax.set_xlabel('Epoch number')
train_ax.set_ylabel('Accuracy')
train_ax.set_ylim([0,1])

#The validation panel uses explicit font sizes instead of the rcParams defaults
for curve, (metric_name, metric_color) in zip(valid_curves, metric_styles):
    valid_ax.plot(curve,color=metric_color,label=metric_name)
valid_ax.legend(loc='lower right',fontsize=16)
valid_ax.tick_params(axis='both', which='minor', labelsize=16)
valid_ax.grid()
valid_ax.set_title('Validation accuracies',fontsize=20)
valid_ax.set_xlabel('Epoch number',fontsize=18)
valid_ax.set_ylabel('Accuracy',fontsize=18)
valid_ax.set_ylim([0,1])
plt.tight_layout()

#Figure 2: training vs. validation, one panel per metric
comparison_titles = ('Dice coefficient accuracies',
                     'Jaccard coefficient accuracies',
                     'Pixel-wise coefficient accuracies')

fig,axs=plt.subplots(1,3,figsize=(15,5))
for ax, train_curve, valid_curve, panel_title in zip(axs, train_curves, valid_curves, comparison_titles):
    ax.plot(train_curve,color='red',label='Training accuracy')
    ax.plot(valid_curve,color='blue',label='Validation accuracy')
    ax.set_title(panel_title)
    ax.set_xlabel('Epoch number')
    ax.set_ylabel('Accuracy')
    ax.set_ylim([0,1])
    ax.grid()
    ax.legend(loc='lower right')

plt.tight_layout()

#Figure 3: training and validation loss (y-axis limited to [0,1])
fig,axs=plt.subplots(1,1,figsize=(15,5))
for loss_curve, loss_color, loss_label in ((loss_train, 'blue', 'Training loss'),
                                           (val_loss, 'red', 'Validation loss')):
    axs.plot(loss_curve,color=loss_color,label=loss_label)
axs.legend()
axs.grid()
axs.set_title('Loss')
axs.set_xlabel('Epoch number')
axs.set_ylabel('Loss')
axs.set_ylim([0,1])

plt.tight_layout()