In [None]:
import nibabel as nib
from nibabel import processing

from skimage import io
from skimage.transform import resize
import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split

import numpy as np
import os
import glob
import warnings
import shutil

import SimpleITK as sitk
# from nipype.interfaces.ants import N4BiasFieldCorrection
from scipy import misc

from tqdm import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.utils.data import TensorDataset, DataLoader
import argparse
import re
from pathlib import Path

In [None]:
# Ask PyTorch whether a CUDA-capable GPU is visible and report the answer.
cuda = torch.cuda.is_available()
print("GPU available:", cuda)

In [None]:
# Locations of the raw data. Every folder hangs off one project root so the
# location only has to be edited in a single place.
project_root = Path('C:/Users/Susanna/Documents/Heart_Segmentation')
data_root_folder = project_root / 'data_3d'
scans_dir = project_root / 'lung'
masks_dir = project_root / 'masks'

# List the NIfTI files in each folder. Both lists are sorted by path so that
# position k refers to the same case in each list
# (assumes scans and masks use matching file names — TODO confirm).
scans_nifti = sorted(scans_dir.glob('*.nii'))
masks_nifti = sorted(masks_dir.glob('*.nii'))

In [None]:
# get information from each nifti file

def load_scan(scans_nifti):
  """Load every NIfTI file listed in `scans_nifti`.

  Returns three parallel lists: the voxel arrays (with singleton axes
  squeezed out), the nibabel image objects they were read from, and the
  paths that were loaded.
  """
  paths = list(scans_nifti)
  loaded_imgs = [nib.load(path) for path in paths]
  volumes = [np.squeeze(img.get_fdata()) for img in loaded_imgs]
  return volumes, loaded_imgs, paths

In [None]:
# Load every scan and every mask up front. Each call returns three parallel
# lists: voxel arrays, nibabel image objects, and the file paths.
output_scans, scan_imgs, scan_list = load_scan(scans_nifti)
output_masks, mask_imgs, mask_list = load_scan(masks_nifti)

In [None]:
# Per-split folders holding the prepared volumes. They are derived from
# data_root_folder so they cannot drift from the <root>/<split>/{scan,mask}
# layout that the dataset class reads.
scan_slices_train_dir = data_root_folder / 'train' / 'scan'
mask_slices_train_dir = data_root_folder / 'train' / 'mask'

scan_slices_valid_dir = data_root_folder / 'valid' / 'scan'
mask_slices_valid_dir = data_root_folder / 'valid' / 'mask'

scan_slices_test_dir = data_root_folder / 'test' / 'scan'
mask_slices_test_dir = data_root_folder / 'test' / 'mask'

In [None]:
# Materialise the list of compressed NIfTI volumes found in each split folder.
num_train_scan = list(scan_slices_train_dir.glob('*.nii.gz'))
num_valid_scan = list(scan_slices_valid_dir.glob('*.nii.gz'))
num_test_scan = list(scan_slices_test_dir.glob('*.nii.gz'))


In [None]:
# separate out scan number from organization for train/valid/test
def get_ind_3d(path):
  """Map a per-split file path to the scan's absolute (1-based) index.

  Files are numbered from 1 inside each split folder; the validation split
  starts at absolute index 56 and the test split at 72, anything else is
  treated as the training split starting at 1.

  Args:
    path: str or os.PathLike such as '<root>/valid/scan/scan_2.nii.gz'.

  Returns:
    int: split start + number in the file name - 1.

  Raises:
    ValueError: if the file name contains no digits.
  """
  # Accept both separators: paths built with os.path.join on Windows use '\'.
  parts = re.split(r'[\\/]+', os.fspath(path))
  filename, folders = parts[-1], parts[:-1]

  # Read the number from the file name only. Searching the whole path would
  # pick up digits in parent folders first (e.g. the 3 in 'data_3d').
  digits = re.findall(r'\d+', filename)
  if not digits:
    raise ValueError(f'no scan number found in file name {filename!r}')
  scan_num = int(digits[0])

  # Walk from the folder nearest the file outwards so that an unrelated
  # ancestor folder whose name happens to contain 'test' or 'valid' cannot
  # override the actual split folder.
  start = 1
  for folder in reversed(folders):
    if 'valid' in folder:
      start = 56
      break
    if 'test' in folder:
      start = 72
      break
    if 'train' in folder:
      break
  return start + scan_num - 1

In [None]:
class BasicDataset(TensorDataset):
    """Paired 3D scan / mask volumes stored as NIfTI files.

    Sample k (1-based) is read from
    ``<data_root_folder>/<folder>/scan/scan_<k>.nii.gz`` and
    ``<data_root_folder>/<folder>/mask/mask_<k>.nii.gz``.

    TensorDataset.__init__ is not called; __len__ and __getitem__ are both
    overridden below, so the base class's tensor storage is never used.
    """

    def __init__(self, folder, n_sample=None):
        """Index the files of one split.

        Args:
            folder: split sub-folder of data_root_folder (e.g. 'train').
            n_sample: optional cap on the number of samples; when falsy or
                larger than the number of scan files, all files are used.

        Raises:
            AssertionError: if the scan and mask folders hold a different
                number of '*.nii.gz' files.
        """
        self.folder = os.path.join(data_root_folder, folder)
        self.imgs_dir = os.path.join(self.folder, 'scan')
        self.masks_dir = os.path.join(self.folder, 'mask')

        self.imgs_file = sorted(glob.glob(os.path.join(self.imgs_dir, '*.nii.gz')))
        self.masks_file = sorted(glob.glob(os.path.join(self.masks_dir, '*.nii.gz')))

        assert len(self.imgs_file) == len(self.masks_file), 'There are some missing images or masks in {0}'.format(folder)

        # Fall back to the full split when no usable limit was given.
        if not n_sample or n_sample > len(self.imgs_file):
            n_sample = len(self.imgs_file)

        self.n_sample = n_sample
        # 1-based ids matching the numbers used in the file names.
        self.ids = [i + 1 for i in range(n_sample)]

    def __len__(self):
        """Number of samples exposed by this dataset."""
        return self.n_sample

    @staticmethod
    def _normalise(volume):
        """Scale `volume` by its maximum; returns (scaled_volume, maximum).

        A volume whose maximum is 0 is returned unscaled instead of being
        divided by zero (which would fill it with NaN/inf).
        """
        scan_max = np.max(volume)
        scaled = np.array(volume)
        if scan_max != 0:
            scaled = scaled / scan_max
        return scaled, scan_max

    def __getitem__(self, i):
        """Load sample `i` and return it as a dict of tensors and metadata.

        Keys: 'scan' and 'mask' (float tensors with a leading channel axis),
        'img_id' (1-based id inside the split), 'scan_num' (absolute scan
        index from get_ind_3d) and 'max_pixel' (the value the scan was
        divided by).
        """
        idx = self.ids[i]
        # Build each path once, from the id, so the file that is loaded always
        # matches 'img_id' (also for negative indices).
        scan_path = os.path.join(self.imgs_dir, f"scan_{idx}.nii.gz")
        mask_path = os.path.join(self.masks_dir, f"mask_{idx}.nii.gz")

        scan_num_abs = get_ind_3d(scan_path)

        img = nib.load(scan_path).get_fdata()
        mask = nib.load(mask_path).get_fdata()

        # Keep only the first 32 positions along the third axis.
        img = img[:, :, 0:32]
        mask = mask[:, :, 0:32]

        img, scan_max = self._normalise(img)

        # Prepend a channel axis to both volumes.
        img = np.expand_dims(img, axis=0)
        mask = np.expand_dims(mask, axis=0)

        return {
            'scan': torch.from_numpy(img).type(torch.FloatTensor),
            'mask': torch.from_numpy(mask).type(torch.FloatTensor),
            'img_id': idx,
            'scan_num': scan_num_abs,
            'max_pixel': scan_max
        }

In [None]:
# One dataset instance per split folder, created in train / valid / test order.
train_dataset, valid_dataset, test_dataset = (
    BasicDataset(split) for split in ('train', 'valid', 'test')
)

In [None]:
# Worker processes must re-import the dataset class by name. On Windows they
# are spawned afresh and cannot find a class defined in a notebook cell, so
# load in the main process there.
loader_workers = 0 if os.name == 'nt' else 2
# Pinned host memory only matters when batches are copied to a GPU.
loader_pin_memory = torch.cuda.is_available()

# Only the training loader shuffles; validation and test keep file order.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=loader_workers, pin_memory=loader_pin_memory)
valid_dataloader = DataLoader(valid_dataset, batch_size=1, num_workers=loader_workers, pin_memory=loader_pin_memory)
test_dataloader = DataLoader(test_dataset, batch_size=1, num_workers=loader_workers, pin_memory=loader_pin_memory)

In [None]:
class DoubleConv(nn.Module):
    """Two stacked 3x3x3 convolutions, each followed by batch norm and ReLU.

    Padding of 1 keeps the spatial size unchanged; only the channel count
    goes from `in_channels` to `out_channels`.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First stage maps in -> out channels, second stage out -> out.
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        features = self.double_conv(x)
        return features


class Down(nn.Module):
    """Encoder step: 2x max pooling on every spatial axis, then DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool3d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder step: upsample, merge with the skip connection, double conv.

    `x1` is upsampled by 2 (trilinear) and reduced to `out_channels` with a
    1x1x1 convolution, concatenated with the skip tensor `x2` along the
    channel axis, and passed through DoubleConv(2 * out_channels, out_channels).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up_conv = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True),
            nn.Conv3d(in_channels, out_channels, kernel_size=1, padding=0),
        )
        self.conv = DoubleConv(out_channels * 2, out_channels)

    def forward(self, x1, x2):
        """Merge decoder features `x1` with encoder skip features `x2`."""
        x1 = self.up_conv(x1)

        # Max pooling floors odd sizes, so after upsampling x1 can be one
        # voxel short of the skip tensor. Zero-pad it to x2's spatial size so
        # the concatenation also works for inputs not divisible by 16.
        # F.pad expects (before, after) pairs starting from the LAST axis.
        pad = []
        for up_size, skip_size in zip(reversed(x1.shape[2:]), reversed(x2.shape[2:])):
            diff = skip_size - up_size
            pad.extend([diff // 2, diff - diff // 2])
        if any(pad):
            x1 = F.pad(x1, pad)

        x = torch.cat([x1, x2], dim=1)
        x = self.conv(x)
        return x

class OutConv(nn.Module):
    """Output head: 1x1x1 convolution to `out_channels` maps, then sigmoid."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        projection = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        self.conv_sigmoid = nn.Sequential(projection, nn.Sigmoid())

    def forward(self, x):
        probabilities = self.conv_sigmoid(x)
        return probabilities

In [None]:
# 3D U-Net implementation 
class UNet(nn.Module):
    """3D U-Net: four pooling levels down, four upsampling levels back up.

    Args:
        name: label stored on the model as `self.name`.
        n_channels: channels of the input volume.
        n_classes: channels of the output map produced by OutConv.
    """

    def __init__(self, name, n_channels, n_classes):
        super().__init__()
        self.name = name
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Encoder: channel width doubles at every level.
        self.inputL = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        # Decoder: mirrors the encoder, halving the width at every level.
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        self.outputL = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder pass: keep every level's output for the skip connections.
        skips = [self.inputL(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        # Decoder pass: start at the deepest features and merge the skips
        # back in, deepest first.
        x = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            x = up(x, skips.pop())

        return self.outputL(x)

In [None]:
# Single-channel input, single-channel output network, moved onto the GPU.
my_UNet = UNet(name='MyUNet', n_channels=1, n_classes=1)
my_UNet.cuda()

In [None]:
# Binary cross-entropy on the network's output, optimised with Adam over all
# network parameters.
loss_function = nn.BCELoss()
optimizer = Adam(my_UNet.parameters(), lr=1e-3)

In [None]:
# Smoke test: push the first test batch through the network.
sample_batch = next(iter(test_dataloader))

# Run in eval mode so this forward pass does not update the BatchNorm running
# statistics, and send the input to whichever device holds the model.
my_UNet.eval()
model_device = next(my_UNet.parameters()).device
with torch.no_grad():
    y_pred = my_UNet(sample_batch['scan'].to(model_device))

In [None]:
def dice_coeff_binary(y_pred, y_true):
    """DICE coefficient of two binary masks.

    Computes (2 * |pred AND true| + eps) / (|pred| + |true| + eps), so two
    empty masks score 1.

    Args:
        y_pred, y_true: tensors with the same number of elements whose
            values must be only zero or one (any dtype, any memory layout).

    Returns:
        A 0-dimensional float32 numpy array.
    """
    eps = 0.0001
    # reshape (not view) copes with non-contiguous tensors. The float cast
    # lets bool/int masks be mixed with float ones in torch.dot, and detach
    # allows the final numpy conversion for tensors that track gradients.
    pred_flat = y_pred.detach().reshape(-1).float()
    true_flat = y_true.detach().reshape(-1).float()
    inter = torch.dot(pred_flat, true_flat)
    total = pred_flat.sum() + true_flat.sum()
    return ((2 * inter + eps) / (total + eps)).cpu().numpy()

In [None]:
# training loop
def _run_epoch(net, dataloader, loss_function, device, progress_prefix, optimizer=None):
    """Run one pass over `dataloader`; returns (mean loss, mean DICE).

    With an `optimizer` the network is put in train mode and updated after
    every batch; without one it is put in eval mode and gradients are off.
    """
    training = optimizer is not None
    net.train(training)
    n_batches = len(dataloader)
    batch_losses = []
    batch_dices = []

    with torch.set_grad_enabled(training):
        for i, batch in enumerate(dataloader):
            # Move the batch to the device the network lives on.
            imgs = batch['scan'].to(device)
            true_masks = batch['mask'].to(device)

            y_pred = net(imgs)

            loss = loss_function(y_pred, true_masks)
            batch_loss = loss.item()
            batch_losses.append(batch_loss)

            # Threshold the predicted probabilities at 0.5 for the DICE score.
            pred_binary = (y_pred > 0.5).float()
            batch_dice_score = dice_coeff_binary(pred_binary, true_masks)
            batch_dices.append(batch_dice_score)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # '\r' keeps the per-batch progress on a single console line.
            print(f'{progress_prefix} Batch {i+1}/{n_batches} - Loss: {batch_loss}, DICE score: {batch_dice_score}', end='\r')

    return np.array(batch_losses).mean(), np.array(batch_dices).mean()


def train_net(net, epochs, train_dataloader, valid_dataloader, optimizer, loss_function):
    """Train `net` for `epochs` epochs, validating after each one.

    Batches are dicts with 'scan' and 'mask' tensors; they are moved to the
    device that already holds the network's parameters, so the loop works on
    GPU and CPU alike. A directory named after `net.name` is created in the
    working directory if missing (nothing is written to it here).

    Returns:
        (train_loss, train_dice, valid_loss, valid_dice): four lists with one
        per-epoch mean each.
    """
    os.makedirs(str(net.name), exist_ok=True)

    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    train_loss = list()
    valid_loss = list()
    train_dice = list()
    valid_dice = list()

    for epoch in range(epochs):
        epoch_tag = f'EPOCH {epoch + 1}/{epochs}'

        average_training_loss, average_training_dice = _run_epoch(
            net, train_dataloader, loss_function, device, f'{epoch_tag} - Training', optimizer=optimizer)
        train_loss.append(average_training_loss)
        train_dice.append(average_training_dice)

        average_validation_loss, average_validation_dice = _run_epoch(
            net, valid_dataloader, loss_function, device, f'{epoch_tag} - Validation')
        valid_loss.append(average_validation_loss)
        valid_dice.append(average_validation_dice)

        print(f'EPOCH {epoch + 1}/{epochs} - Training Loss: {average_training_loss}, Training DICE score: {average_training_dice}, Validation Loss: {average_validation_loss}, Validation DICE score: {average_validation_dice}')

    return train_loss, train_dice, valid_loss, valid_dice

In [None]:
model_path = '/content/drive/MyDrive/Heart_Segmentation/3d_unet.pt'
EPOCHS = 50

# Create the checkpoint folder before training starts: a missing directory
# would otherwise make torch.save fail only after every epoch has been run.
Path(model_path).parent.mkdir(parents=True, exist_ok=True)

train_loss, train_dice, valid_loss, valid_dice = train_net(my_UNet, EPOCHS, train_dataloader, valid_dataloader, optimizer, loss_function)
torch.save(my_UNet.state_dict(), model_path)

In [None]:
# Learning curves: loss on the left, DICE score on the right, each showing
# the training and validation series per epoch.
epoch_axis = np.arange(EPOCHS) + 1
fig, (loss_ax, dice_ax) = plt.subplots(1, 2, figsize=(15, 8))
fig.suptitle('Learning Curve', fontsize=18)

loss_ax.plot(epoch_axis, train_loss, '-o', label='Training Loss')
loss_ax.plot(epoch_axis, valid_loss, '-o', label='Validation Loss')
loss_ax.set_xlabel('Epoch', fontsize=15)
loss_ax.set_ylabel('Loss', fontsize=15)
loss_ax.legend()

dice_ax.plot(epoch_axis, train_dice, '-o', label='Training DICE score')
dice_ax.plot(epoch_axis, valid_dice, '-o', label='Validation DICE score')
dice_ax.set_xlabel('Epoch', fontsize=15)
dice_ax.set_ylabel('DICE score', fontsize=15)
dice_ax.set_yticks(np.arange(0.6, 1, 0.05))
dice_ax.legend()

fig.tight_layout()
plt.show()

In [None]:
# store masks in array
# Predict a binary mask for every volume of the test split, in loader order.
masks_test = []
my_UNet.eval()
model_device = next(my_UNet.parameters()).device
for batch in test_dataloader:
    with torch.no_grad():
        y_pred = my_UNet(batch['scan'].to(model_device))
    # Iterate over whatever batch size the loader produced; channel 0 of each
    # sample is thresholded at 0.5 and stored as a uint8 volume.
    for sample in y_pred.cpu().numpy():
        pred_msk = (sample[0] > 0.5).astype('uint8')
        masks_test.append(pred_msk)

In [None]:
# storing outputs
# 0-based position of the first test scan in the full scan/mask lists
# (get_ind_3d starts the test split at absolute index 72, 1-based).
TEST_SCAN_OFFSET = 71
output_dir = '/Heart_Segmentation/output_3d/'
os.makedirs(output_dir, exist_ok=True)

for i, curr_mask in enumerate(masks_test):
    mask_nifti = mask_imgs[i + TEST_SCAN_OFFSET]
    scan_nifti = scan_imgs[i + TEST_SCAN_OFFSET]
    scan_array = output_scans[i + TEST_SCAN_OFFSET]
    # Same crop along the third axis that the dataset applies before inference.
    scan_array_cropped = scan_array[:, :, 0:32]

    # NOTE(review): processing.conform does more than fix the shape. With its
    # defaults it appears to resample to 1 mm voxels in RAS orientation using
    # spline interpolation. Confirm that is wanted, especially for the mask.
    curr_scan_nifti = processing.conform(nib.Nifti1Image(scan_array_cropped, scan_nifti.affine), out_shape=scan_array_cropped.shape)
    curr_mask_nifti = processing.conform(nib.Nifti1Image(curr_mask, mask_nifti.affine), out_shape=curr_mask.shape)

    mask_nifti_path = output_dir + 'mask_' + str(i + 1) + '.nii.gz'
    scan_cropped_path = output_dir + 'scan_cropped_' + str(i + 1) + '.nii.gz'
    # Save the conformed mask image (the original referenced an undefined name).
    nib.save(curr_mask_nifti, mask_nifti_path)
    nib.save(curr_scan_nifti, scan_cropped_path)