<a href="https://colab.research.google.com/github/s183796/Group5_repos/blob/main/3D_unet_test_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:

import matplotlib
import matplotlib.pyplot as plt
from IPython.display import Image, display, clear_output
import numpy as np
import seaborn as sns
import pandas as pd
import torch
import torch.nn as nn
from torchvision import models
from torch.nn.functional import relu
from torch.nn.functional import softmax
import PIL.Image
import os
import torchvision
import cv2

from torchvision import transforms
from sklearn.model_selection import train_test_split

from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, Dataset, Subset

import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms.functional as TF
import glob
import random

# Colab-only: mount the user's Google Drive at /content/drive so the relative
# 'drive/My Drive/...' paths used by later cells resolve.
from google.colab import drive
drive.mount('/content/drive')


# Marker that the import/mount cell ran to completion.
print('loaded packages')

Mounted at /content/drive
loaded packages


In [2]:
pip install torchio

Collecting torchio
  Downloading torchio-0.19.3-py2.py3-none-any.whl (172 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m173.0/173.0 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting Deprecated (from torchio)
  Downloading Deprecated-1.2.14-py2.py3-none-any.whl (9.6 kB)
Collecting SimpleITK!=2.0.*,!=2.1.1.1 (from torchio)
  Downloading SimpleITK-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.7/52.7 MB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
Collecting colorama<0.5.0,>=0.4.3 (from typer[all]->torchio)
  Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)
Collecting shellingham<2.0.0,>=1.3.0 (from typer[all]->torchio)
  Downloading shellingham-1.5.4-py2.py3-none-any.whl (9.8 kB)
Installing collected packages: SimpleITK, shellingham, Deprecated, colorama, torchio
Successfully installed Deprecated-1.2.14 SimpleITK-2.3.1 colorama-0.4.6 shellingham-1.5.4 to

In [3]:
# Source: https://towardsdatascience.com/cook-your-first-u-net-in-pytorch-b3297a844cf3, visited the 16th of November 2023
# Modifications have been made to the original code with changing the input and output sizes

class UNet(nn.Module):
    """3-D U-Net producing per-voxel class logits.

    Layout: three encoder stages (two 3x3x3 convolutions + ReLU each, followed
    by a 2x max-pool), a bottleneck stage, and three decoder stages (2x
    transposed convolution, concatenation with the matching encoder output,
    two 3x3x3 convolutions + ReLU). A final 1x1x1 convolution maps the 64
    feature channels to ``n_class`` output channels.

    Args:
        n_class: number of output channels (classes) of the final convolution.

    Input to ``forward`` is a 5-D tensor ``(batch, 1, D, H, W)``. D, H and W
    must each be divisible by 8, otherwise the up-sampled tensors no longer
    match the encoder tensors they are concatenated with.
    """

    def __init__(self, n_class):
        super().__init__()

        def conv3(in_channels, out_channels):
            # 3x3x3 convolution with padding 1: spatial size is preserved.
            return nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)

        def halve():
            # Halves every spatial dimension.
            return nn.MaxPool3d(kernel_size=2, stride=2)

        def double(in_channels, out_channels):
            # Doubles every spatial dimension.
            return nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)

        # Attribute names and creation order are kept as-is: they determine the
        # state-dict keys and the order in which parameters are initialised.

        # Encoder stage 1: 1 -> 64 channels
        self.e11 = conv3(1, 64)
        self.e12 = conv3(64, 64)
        self.pool1 = halve()

        # Encoder stage 2: 64 -> 128 channels
        self.e21 = conv3(64, 128)
        self.e22 = conv3(128, 128)
        self.pool2 = halve()

        # Encoder stage 3: 128 -> 256 channels
        self.e31 = conv3(128, 256)
        self.e32 = conv3(256, 256)
        self.pool3 = halve()

        # Bottleneck: 256 -> 512 channels
        self.e41 = conv3(256, 512)
        self.e42 = conv3(512, 512)

        # Decoder stage (512 -> 256); d21 sees 256 up-sampled + 256 skip channels
        self.upconv2 = double(512, 256)
        self.d21 = conv3(512, 256)
        self.d22 = conv3(256, 256)

        # Decoder stage (256 -> 128); d31 sees 128 up-sampled + 128 skip channels
        self.upconv3 = double(256, 128)
        self.d31 = conv3(256, 128)
        self.d32 = conv3(128, 128)

        # Decoder stage (128 -> 64); d41 sees 64 up-sampled + 64 skip channels
        self.upconv4 = double(128, 64)
        self.d41 = conv3(128, 64)
        self.d42 = conv3(64, 64)

        # Per-voxel projection to class logits
        self.outconv = nn.Conv3d(64, n_class, kernel_size=1)

    @staticmethod
    def _conv_pair(first, second, x):
        """Apply ``first`` then ``second``, each followed by a ReLU."""
        return F.relu(second(F.relu(first(x))))

    def forward(self, x):
        """Return raw (un-normalised) logits of shape ``(batch, n_class, D, H, W)``."""
        # Encoder: keep each stage's output for the skip connections.
        skip1 = self._conv_pair(self.e11, self.e12, x)
        skip2 = self._conv_pair(self.e21, self.e22, self.pool1(skip1))
        skip3 = self._conv_pair(self.e31, self.e32, self.pool2(skip2))
        bottleneck = self._conv_pair(self.e41, self.e42, self.pool3(skip3))

        # Decoder: up-sample, concatenate along the channel axis, convolve.
        merged = torch.cat([self.upconv2(bottleneck), skip3], dim=1)
        decoded = self._conv_pair(self.d21, self.d22, merged)

        merged = torch.cat([self.upconv3(decoded), skip2], dim=1)
        decoded = self._conv_pair(self.d31, self.d32, merged)

        merged = torch.cat([self.upconv4(decoded), skip1], dim=1)
        decoded = self._conv_pair(self.d41, self.d42, merged)

        return self.outconv(decoded)

print("Defined Unet")

Defined Unet


In [4]:
#Setting up hyper parameters, from exercise week 6
# Cross-entropy loss object. It is not referenced by any of the evaluation
# cells visible in this notebook.
loss_fn =  nn.CrossEntropyLoss()

In [5]:
#Creating dataset
class SOCDataset(Dataset):
    """Dataset of grayscale image slices and their label slices.

    Expects ``root_dir`` to contain a ``data/`` folder with ``*.tiff`` images
    and a ``labels/`` folder with ``*.tif`` label images. The label for an
    image is looked up as ``slice__<NNN>.tif`` where ``<NNN>`` is the last
    three characters of the image file name before its extension.

    Args:
        root_dir: directory holding the ``data/`` and ``labels/`` folders.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.image_folder = os.path.join(root_dir, 'data/')
        self.label_folder = os.path.join(root_dir, 'labels/')
        # Sorted so that index order is stable across runs.
        self.image_filenames = sorted(
            f for f in os.listdir(self.image_folder) if f.endswith('.tiff'))
        self.label_filenames = sorted(
            f for f in os.listdir(self.label_folder) if f.endswith('.tif'))

    def __len__(self):
        """Number of image slices found in ``data/``."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(image, label)`` as tensors for slice ``idx``.

        Raises:
            FileNotFoundError: if the image or its matching label file is missing.
            OSError: if a file exists but could not be decoded.
        """
        image_filename = self.image_filenames[idx]
        img_name = os.path.join(self.image_folder, image_filename)

        # Making sure image and label match: take the 3-character slice id
        # from the file name itself rather than from fixed offsets in the path.
        slice_id = os.path.splitext(image_filename)[0][-3:]
        label_name = os.path.join(self.label_folder, 'slice__' + slice_id + '.tif')

        # cv2.imread returns None instead of raising, so check up front to get
        # an error that names the offending file.
        for path in (img_name, label_name):
            if not os.path.isfile(path):
                raise FileNotFoundError(f"Expected file does not exist: {path}")

        image = cv2.imread(img_name, cv2.IMREAD_GRAYSCALE)
        label = cv2.imread(label_name, cv2.IMREAD_GRAYSCALE)
        for path, array in ((img_name, image), (label_name, label)):
            if array is None:
                raise OSError(f"Could not decode image file: {path}")

        return torch.from_numpy(image), torch.from_numpy(label)


In [6]:
# Dataset rooted at the Drive folder (relative to the notebook's working directory).
SOC_dataset = SOCDataset(root_dir='drive/My Drive//AI data/') #change path if running on the HPC

# Note: this message is emitted before the slices below are actually read.
print('loaded data')

# Read every (image, label) pair once, in index order, into two parallel lists.
images, labels = [], []
for slice_index in range(len(SOC_dataset)):
    slice_image, slice_label = SOC_dataset[slice_index]
    images.append(slice_image)
    labels.append(slice_label)

loaded data


In [7]:
def split_cube(volume, cube_size):
    """Cut a 3-D volume into non-overlapping cubes of edge ``cube_size``.

    Args:
        volume: 3-D tensor indexed ``[d0, d1, d2]``.
        cube_size: edge length of each cube; trailing voxels that do not fill
            a whole cube along an axis are dropped (window step == window size).

    Returns:
        Tensor of shape ``(n0 * n1 * n2, cube_size, cube_size, cube_size)``,
        where ``n_k = volume.size(k) // cube_size``. Cubes are ordered with the
        block index along axis 2 varying fastest. Inside each cube the axes
        come out in reverse order (``cube[a, b, c]`` is the voxel at offset
        ``(c, b, a)`` of the block), because each ``unfold`` appends its
        window dimension at the end.
    """
    vol_split = (
        volume
        .unfold(2, cube_size, cube_size)
        .unfold(1, cube_size, cube_size)
        .unfold(0, cube_size, cube_size)
    )

    # Flatten the three block-index axes. Using -1 instead of size(0)**3 also
    # supports volumes whose axes yield different numbers of cubes.
    return vol_split.reshape(-1, cube_size, cube_size, cube_size)

In [8]:

# Stack the 2-D slices along a new leading axis to form one 3-D volume each
# for the labels and the images.
labels_vol = torch.stack(labels)

# Cut both volumes into 64-voxel cubes with the same function so image cube k
# and label cube k cover the same region.
im_vol = split_cube(torch.stack(images), 64)
label_vol = split_cube(labels_vol, 64)

In [9]:
import torchio as tio
# NOTE(review): `I` is not referenced in any visible cell; this looks like an
# accidental auto-import. Confirm before removing.
from re import I

# Wrap every image cube and its label cube in a torchio Subject.
# Adding a leading axis and then indexing axis 1 yields a 4-D tensor
# (1, cube, cube, cube) per cube, i.e. one channel in front of the cube.
image_cubes = im_vol.unsqueeze(0)
label_cubes = label_vol.unsqueeze(0)
elements = [
    tio.Subject(
        image_sub=tio.ScalarImage(tensor=image_cubes[:, cube_index, :, :]),
        label_sub=tio.LabelMap(tensor=label_cubes[:, cube_index, :, :]),
    )
    for cube_index in range(im_vol.size(0))
]
dataset = tio.SubjectsDataset(elements)

In [10]:

# Intensity normalisation applied when subjects are fetched from the datasets.
# The masking method restricts which voxels the mean/std are computed from
# (per the torchio docs, `ZNormalization.mean` selects voxels above the image
# mean - confirm against the installed version).
transform = tio.ZNormalization(masking_method=tio.ZNormalization.mean)

# NOTE(review): this name shadows the `transforms` module imported from
# torchvision in the first cell; kept so the namespace matches earlier runs.
transforms = transform

In [11]:
#Splitting in training, validation and test data

#splitting ratios
training_size=0.7   # fraction of all subjects used for training
val_split=0.5       # fraction of the remaining hold-out subjects used for testing

# Fixed seed so the split is the same on every run.
# NOTE(review): for the test subset to be truly unseen, this split must match
# the one used when the model was trained - confirm against the training notebook.
SPLIT_SEED = 0
split_generator = torch.Generator().manual_seed(SPLIT_SEED)

#Find number of images in train, val and test
train_number = int(training_size * len(dataset))
val_number = len(dataset) - train_number      # hold-out size (validation + test)
test_number = int(val_number * val_split)     # test share of the hold-out

#splitting of training and hold-out set
train_val = train_number, val_number
train_images, val_images = torch.utils.data.random_split(
    elements, train_val, generator=split_generator)

# Split the hold-out into test and validation. The validation size is the
# remainder, so the two lengths always sum to val_number (the previous
# "half, half + 1" only did so for an odd hold-out size).
val_test = test_number, val_number - test_number
test_images, val_images = torch.utils.data.random_split(
    val_images, val_test, generator=split_generator)

#Creating datasets with pixel normalization
train_set = tio.SubjectsDataset(
    train_images, transform=transform)

val_set = tio.SubjectsDataset(
    val_images, transform=transform)

test_set = tio.SubjectsDataset(
    test_images, transform=transform)

print('Training set:', len(train_set), 'subjects')
print('Validation set:', len(val_set), 'subjects')
print('Test set:', len(test_set), 'subjects')


Training set: 240 subjects
Validation set: 52 subjects
Test set: 51 subjects


In [12]:
import multiprocessing

# Worker processes per loader.
num_workers = 4 #adapted to the HPC

# Training batches hold 16 subjects; validation and test batches hold twice that.
training_batch_size = 16
validation_batch_size = 2 * training_batch_size

# One loader per split. Only the training loader shuffles.
training_loader = DataLoader(
    train_set,
    batch_size=training_batch_size,
    shuffle=True,
    num_workers=num_workers,
)

# Validation and test loaders share the same settings.
evaluation_loader_options = dict(
    batch_size=validation_batch_size,
    num_workers=num_workers,
)
val_loader = DataLoader(val_set, **evaluation_loader_options)
test_loader = DataLoader(test_set, **evaluation_loader_options)



In [13]:
pip install torchmetrics


Collecting torchmetrics
  Downloading torchmetrics-1.2.1-py3-none-any.whl (806 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m806.1/806.1 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.10.0-py3-none-any.whl (24 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.10.0 torchmetrics-1.2.1


In [14]:
# Load the trained network. The file is read with full unpickling
# (weights_only=False), which requires the UNet class defined above and must
# only be used with trusted files: unpickling can execute arbitrary code.
# map_location lets a model saved on a GPU load on a CPU-only runtime.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load(
    'drive/My Drive//Models/3dunet.pth',
    map_location=device,
    weights_only=False,
)
# Inference mode for any mode-dependent layers; assigned so the cell does not
# echo the whole architecture as its output.
model = model.eval()


In [16]:
#Importing accuracy metrics
from torchmetrics.classification import JaccardIndex
from torchmetrics.functional.classification import dice
from torchmetrics.classification import MulticlassAccuracy

In [17]:
from torchvision.utils import make_grid

# Evaluate the loaded model on the test loader and average three per-cube
# metrics (Dice, Jaccard, pixel accuracy) over all test cubes.
# After this cell, `output`, `predicted` and `targets` hold the LAST batch;
# the plotting cells below rely on that.

NUM_CLASSES = 3  # number of classes the metrics are configured for

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
jaccard=JaccardIndex(task="multiclass", num_classes=NUM_CLASSES).to(device) #jaccard
accuracy=MulticlassAccuracy(num_classes=NUM_CLASSES).to(device) #pixel wise

# Sorted raw label values over the WHOLE label volume. Mapping each value to
# its rank gives the same class ids in every batch, even when a batch lacks
# one of the classes (a per-batch unique() would then mislabel or fail).
label_values = label_vol.unique().to(device=device, dtype=torch.int64)
if label_values.numel() > NUM_CLASSES:
    raise ValueError(
        f"Found {label_values.numel()} distinct label values, "
        f"expected at most {NUM_CLASSES}")

model = model.to(device)
model.eval()

test_accuracies_dice=0
test_accuracies_jaccard=0
test_accuracies_pixel=0
l_test=0

# no_grad: without it every activation is kept for back-propagation, which is
# what exhausted GPU memory in the recorded run.
with torch.no_grad():
    for subjects_batch in test_loader:
        inputs = subjects_batch['image_sub'][tio.DATA].to(device)
        targets = subjects_batch['label_sub'][tio.DATA].to(device)
        output = model(inputs)

        # Drop the channel axis and convert raw label values to class ids 0..K-1.
        targets = torch.bucketize(targets[:, 0].to(torch.int64), label_values)

        #Computing test accuracies
        # argmax over logits equals argmax over softmax(logits); skipping the
        # softmax avoids another full-size tensor.
        predicted = torch.argmax(output, dim=1)

        # NOTE(review): `dice` is called with torchmetrics defaults; check that
        # its default averaging is the one intended for a 3-class problem.
        for i in range(predicted.size(0)):
            test_accuracies_dice+=dice(predicted[i,:,:,:],targets[i,:,:,:])
            test_accuracies_jaccard+=jaccard(predicted[i,:,:,:],targets[i,:,:,:])
            test_accuracies_pixel+=accuracy(predicted[i,:,:,:],targets[i,:,:,:])
        l_test+=predicted.size(0) #save size of batchsize

if l_test == 0:
    raise ValueError("test_loader yielded no samples; nothing to evaluate")

#Average test accuracy
test_dice=(test_accuracies_dice/l_test).cpu()
test_jaccard=(test_accuracies_jaccard/l_test).cpu()
test_pixel=(test_accuracies_pixel/l_test).cpu()

print("Finished test loader")



OutOfMemoryError: ignored

In [None]:
# Report the averaged test metrics computed by the evaluation cell above.
print(f"test_accuracies_dice: {test_dice}")
print(f"test_accuracies_jaccard: {test_jaccard}")
print(f"test_accuracies_pixel: {test_pixel}")

In [None]:
# 3-D scatter plot of the predicted labels for the first cube of the last
# evaluated batch.

# Per-class output channels of that cube (label1 is also used for its shape).
label1=output[0,0,:,:,:].cpu().detach().numpy()
label2=output[0,1,:,:,:].cpu().detach().numpy()
label3=output[0,2,:,:,:].cpu().detach().numpy()

# Convert only while `predicted` is still the batch tensor, so re-running the
# cell does not fail on an already-converted array.
if torch.is_tensor(predicted):
    predicted = predicted.cpu().detach().numpy()[0, :, :, :]

params = {'legend.fontsize': 'x-large',
          'figure.figsize': (5, 5),
         'axes.labelsize': 'x-large',
         'axes.titlesize':'x-large',
         'xtick.labelsize':'x-large',
         'ytick.labelsize':'x-large'}
# `pylab` was never imported; plt.rcParams is the same settings object.
plt.rcParams.update(params)

# One (x, y, z) coordinate per voxel of the cube.
x, y, z = np.meshgrid(np.arange(label1.shape[0]),np.arange(label1.shape[1]),np.arange(label1.shape[2]), indexing='ij')
fig = plt.figure(dpi=200)
ax = fig.add_subplot(111, projection='3d')
im=ax.scatter(x, y, z, c=predicted.flatten(), cmap='viridis', marker='o')
ax.set_xlabel('Number of pixels [#]')
fig.colorbar(im,ticks=[0,1,2],label='Label number', pad = 0.1, fraction = 0.05)
plt.title('Prediction')
plt.tight_layout()
# Save before show(): once the figure has been shown, a later plt.savefig
# would write a new, empty figure.
fig.savefig('Prediction.png')
plt.show()


In [None]:

# Plots for the first cube of the last evaluated batch: target labels,
# absolute label difference (target vs. prediction) and four 2-D slices of
# that difference. Expects `predicted` to already be the 3-D numpy array
# produced by the prediction-plot cell above.

params = {'legend.fontsize': 'x-large',
          'figure.figsize': (5, 5),
         'axes.labelsize': 'x-large',
         'axes.titlesize':'x-large',
         'xtick.labelsize':'x-large',
         'ytick.labelsize':'x-large'}
# `pylab` was never imported; plt.rcParams is the same settings object.
plt.rcParams.update(params)


def _scatter_volume(values, title, colorbar_label, filename):
    """Draw a 3-D scatter of a cube, one point per voxel coloured by value, and save it."""
    gx, gy, gz = np.meshgrid(*(np.arange(n) for n in values.shape), indexing='ij')
    fig = plt.figure(dpi=200)
    ax = fig.add_subplot(111, projection='3d')
    points = ax.scatter(gx, gy, gz, c=values.flatten(), cmap='viridis', marker='o')
    ax.set_xlabel('Number of pixels [#]')
    fig.colorbar(points, ticks=[0, 1, 2], label=colorbar_label, pad=0.1, fraction=0.05)
    ax.set_title(title)
    fig.tight_layout()
    # Save before show(): a savefig issued after show() writes an empty figure.
    fig.savefig(filename)
    plt.show()


def _plot_slice(values, axis, layer, title, filename):
    """Show and save the 2-D slice `layer` of a cube taken along `axis`."""
    fig, ax = plt.subplots(dpi=200)
    image = ax.imshow(np.take(values, layer, axis=axis))
    fig.colorbar(image, ax=ax)
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(filename)
    plt.show()


# Convert only while `targets` is still the batch tensor, so re-running the
# cell does not fail on an already-converted array.
if torch.is_tensor(targets):
    targets = targets.cpu().detach().numpy()[0, :, :, :]

_scatter_volume(targets, 'Target', 'Label number', 'Target.png')

#Difference plot, example with one image
label_difference = np.abs(targets - predicted)
_scatter_volume(label_difference, 'Absolute difference in labels',
                'Difference in label number', 'Diff')

#Splitting volume cube in slices
# `targets` is 3-D here, so slices use three indices (the earlier
# `targets[0, :, 32, :]` form raised IndexError).
# NOTE(review): the "Vertical" slices are still saved under
# 'Horizontal_slice_*' names to keep the output file names unchanged.
for slice_axis, slice_layer, slice_title, slice_file in (
    (1, 32, 'Horizontal slice, layer 32', 'Horizontal_slice_32'),
    (2, 15, 'Vertical slice, layer 15', 'Horizontal_slice_15'),
    (1, 45, 'Horizontal slice, layer 45', 'Horizontal_slice_45'),
    (2, 34, 'Vertical slice, layer 34', 'Horizontal_slice_34'),
):
    _plot_slice(label_difference, slice_axis, slice_layer, slice_title, slice_file)