In [None]:
import nibabel as nib
from nibabel import processing

from skimage import io
from skimage.transform import resize
import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split

import numpy as np
import os
import glob
import warnings
import shutil

import SimpleITK as sitk
# from nipype.interfaces.ants import N4BiasFieldCorrection
from scipy import misc

from tqdm import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.utils.data import TensorDataset, DataLoader
import argparse
from pathlib import Path
import albumentations as A
from sklearn.metrics import confusion_matrix, accuracy_score
import re

In [None]:
# Locations of the dataset on disk.
# data_root_folder: root that BasicDataset joins with a split name ('train' / 'valid' / 'test').
# scans_dir / masks_dir: folders holding the full 3D volumes and their segmentation masks.
data_root_folder = Path('C:/Users/Susanna/Documents/Heart_Segmentation/data/')
scans_dir = Path('C:/Users/Susanna/Documents/Heart_Segmentation/lung/')
masks_dir = Path('C:/Users/Susanna/Documents/Heart_Segmentation/masks/')

# Collect every uncompressed NIfTI file in each folder; sorting keeps the two lists in a
# deterministic order (presumably scan i pairs with mask i - verify file naming).
scans_nifti = sorted(scans_dir.glob('*.nii'))
masks_nifti = sorted(masks_dir.glob('*.nii'))

In [None]:
# get information from each nifti file
def load_scan(scans_nifti):
  """Load every NIfTI file in ``scans_nifti``.

  Args:
    scans_nifti: iterable of paths to NIfTI files.

  Returns:
    A tuple ``(result_data, result_img_format, paths)`` of three parallel lists:
    the squeezed voxel arrays, the corresponding images after
    ``processing.conform``, and the paths that were loaded.
  """
  result_data = []
  result_img_format = []
  paths = []
  for scan in scans_nifti:
    loaded_scan = nib.load(scan)
    # Squeeze once and reuse the result, so the shape handed to conform() is the
    # shape of the image actually being conformed (the un-squeezed shape can
    # carry singleton axes, which conform() does not accept).
    data = np.squeeze(loaded_scan.get_fdata())
    result_data.append(data)
    result_img_format.append(
        processing.conform(nib.Nifti1Image(data, loaded_scan.affine), out_shape=data.shape))
    paths.append(scan)
  return result_data, result_img_format, paths

In [None]:
# Load all scan and mask volumes: voxel arrays, conformed NIfTI images, and source paths (parallel lists).
output_scans, scan_imgs, scan_list = load_scan(scans_nifti)
output_masks, mask_imgs, mask_list = load_scan(masks_nifti)

In [None]:
# helper functions for working with data 

def is_binary_data(file_path):
    """Return True when the volume stored at ``file_path`` contains exactly two distinct voxel values."""
    voxels = nib.load(file_path).get_fdata()
    return len(np.unique(voxels)) == 2

def find_max_voxel_value(file_path):
    """Return ``(max, min)`` of the voxel values in the volume stored at ``file_path``."""
    voxels = nib.load(file_path).get_fdata()
    return np.max(voxels), np.min(voxels)

In [None]:
# Folders holding the un-split 2D slices.
# NOTE(review): mask_slices_dir is nested inside scan_slices_dir - confirm this is intentional.
scan_slices_dir = Path('C:/Users/Susanna/Documents/Heart_Segmentation/old/scan_slices/')
mask_slices_dir = Path('C:/Users/Susanna/Documents/Heart_Segmentation/old/scan_slices/mask_slices/')

# Training split: 2D scan slices and their masks.
scan_slices_train_dir = Path('C:/Users/Susanna/Documents/Heart_Segmentation/data/train/scan/')
mask_slices_train_dir = Path('C:/Users/Susanna/Documents/Heart_Segmentation/data/train/mask/')

# Validation split.
scan_slices_valid_dir = Path('C:/Users/Susanna/Documents/Heart_Segmentation/data/valid/scan/')
mask_slices_valid_dir = Path('C:/Users/Susanna/Documents/Heart_Segmentation/data/valid/mask/')

# Test split.
scan_slices_test_dir = Path('C:/Users/Susanna/Documents/Heart_Segmentation/data/test/scan/')
mask_slices_test_dir = Path('C:/Users/Susanna/Documents/Heart_Segmentation/data/test/mask/')

In [None]:
# Sanity check: print how many mask slices each split folder contains.
num_train_masks = list(mask_slices_train_dir.glob('*.nii.gz'))
print(len(num_train_masks))

num_valid_masks = list(mask_slices_valid_dir.glob('*.nii.gz'))
print(len(num_valid_masks))

num_test_masks = list(mask_slices_test_dir.glob('*.nii.gz'))
print(len(num_test_masks))

In [None]:
# The data are stored as 2D slices, but we want to be able to reconstruct them into 3D scans (and the masks into 3D objects).
# So we need a way to retrieve, for each 2D slice, both the number of the scan it belongs to and its position within the stack of slices that make up that scan.

def get_ind(path, slices_per_scan=36):
  """Map a 2D slice file to its position inside its 3D scan.

  The slice number is the first run of digits in the file name (e.g.
  ``scan_37.nii.gz`` -> 37) and is 1-based within its split folder. Slices are
  assumed to be stored scan after scan, ``slices_per_scan`` slices each.

  Args:
    path: path (str or path-like) of the slice file. The split is detected
      from the substrings ``'valid'`` / ``'test'`` anywhere in the path;
      anything else is treated as the training split.
    slices_per_scan: number of 2D slices that make up one 3D scan.

  Returns:
    ``(slice_in_scan, scan_num)``: the 1-based position of the slice within
    its scan, and the 1-based global scan number (the first scan of the
    train / valid / test split is numbered 1 / 56 / 72).
  """
  path = str(path)
  # Only look at the file name so digits in a directory name cannot be
  # mistaken for the slice number.
  filename = re.split(r'[\\/]', path)[-1]
  slice_num = int(re.findall(r'\d+', filename)[0])

  if 'valid' in path:
    start = 56
  elif 'test' in path:
    start = 72
  else:
    start = 1

  # Zero-based arithmetic keeps slice 36 in the first scan and slice 37 as
  # the first slice of the second scan, for every split.
  scan_offset, slice_offset = divmod(slice_num - 1, slices_per_scan)
  slice_in_scan = slice_offset + 1
  scan_num = start + scan_offset
  return slice_in_scan, scan_num

In [None]:
# creating the dataset 
class BasicDataset(TensorDataset):
    """Dataset of paired 2D scan / mask slices stored as ``.nii.gz`` files.

    Slices are read from ``<data_root_folder>/<folder>/scan/scan_<k>.nii.gz`` and
    ``<data_root_folder>/<folder>/mask/mask_<k>.nii.gz`` with ``k`` starting at 1.

    NOTE(review): this subclasses TensorDataset without calling its
    ``__init__``; only ``__len__`` / ``__getitem__`` defined here are used.

    Args:
        folder: split sub-folder name under ``data_root_folder``.
        n_sample: optional cap on the number of slices; falsy or too large
            values fall back to every slice found.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``"image"`` and ``"mask"`` entries.
    """

    def __init__(self, folder, n_sample=None, transform=None):
        self.folder = os.path.join(data_root_folder, folder)
        self.imgs_dir = os.path.join(self.folder, 'scan')
        self.masks_dir = os.path.join(self.folder, 'mask')

        self.imgs_file = sorted(glob.glob(os.path.join(self.imgs_dir, '*.nii.gz')))
        self.masks_file = sorted(glob.glob(os.path.join(self.masks_dir, '*.nii.gz')))

        assert len(self.imgs_file) == len(self.masks_file), 'There are some missing images or masks in {0}'.format(folder)

        if not n_sample or n_sample > len(self.imgs_file):
            n_sample = len(self.imgs_file)

        self.n_sample = n_sample
        # File numbering is 1-based: sample i lives in scan_<i+1>.nii.gz.
        self.ids = list(range(1, n_sample + 1))
        self.transform = transform

    def __len__(self):
        return self.n_sample

    def __getitem__(self, i):
        """Return a dict with the scaled scan slice, its mask and bookkeeping indices."""
        idx = self.ids[i]
        # Build both paths from the resolved id so negative indices address the
        # same files as their positive counterparts, and read the mask from the
        # mask directory.
        scan_path = os.path.join(self.imgs_dir, f"scan_{idx}.nii.gz")
        mask_path = os.path.join(self.masks_dir, f"mask_{idx}.nii.gz")

        slice_in_scan, scan_num = get_ind(scan_path)

        img = np.array(nib.load(scan_path).get_fdata())
        mask = nib.load(mask_path).get_fdata()

        # Scale between 0 and 1. An all-zero slice is left untouched instead of
        # being turned into NaNs by a 0/0 division.
        scan_max = np.max(img)
        if scan_max != 0:
            img = img / scan_max

        # data augmentation
        if self.transform is not None:
            augmented = self.transform(image=img, mask=mask)
            img = augmented["image"]
            mask = augmented["mask"]

        # Add a leading channel axis: [channel, width, height]. ascontiguousarray
        # guards torch.from_numpy against arrays left with negative strides by a flip.
        img = np.ascontiguousarray(np.expand_dims(img, axis=0))
        mask = np.ascontiguousarray(np.expand_dims(mask, axis=0))

        # return the slice position within its scan, as well as scan, mask, and ID
        return {
            'scan': torch.from_numpy(img).type(torch.FloatTensor),
            'mask': torch.from_numpy(mask).type(torch.FloatTensor),
            'img_id': idx,
            'slice_in_scan': slice_in_scan,
            'scan_num': scan_num,
            'max_pixel': scan_max
        }

In [None]:
# Image augmentation: each flip is applied independently with probability 0.5,
# to the image and its mask together.
# is_check_shapes=False turns off albumentations' image/mask shape consistency check
# (NOTE(review): confirm scans and masks really differ in shape, otherwise drop it).
transforms = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
], is_check_shapes=False)

In [None]:
# Create train, validation, and test dataset instances.
# Only the training split is augmented: random flips on the validation/test slices would make
# the metrics non-deterministic and would leave the predicted masks randomly mirrored relative
# to the original scans when the slices are recombined into 3D volumes.
train_dataset = BasicDataset('train', transform=transforms)
valid_dataset = BasicDataset('valid')
test_dataset = BasicDataset('test')

In [None]:
# Preview: slice index 1 of volume 10, scan next to its mask.
plt.figure(figsize=(12, 8), dpi=100)
for panel_idx, (panel_title, volume) in enumerate((('Image', output_scans[10]), ('Mask', output_masks[10])), start=1):
    plt.subplot(1, 2, panel_idx)
    plt.title(panel_title)
    plt.imshow(volume[:,:,1], cmap='gray')
    plt.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# recombine 2D slices into 3D scans - helper function for putting together 3D masks after training
def recombine_slices(num_imgs, dataset_type, output_path):
  """Rebuild 3D scan and mask volumes from consecutive 2D dataset slices and save them as NIfTI.

  Args:
    num_imgs: number of 3D volumes to rebuild.
    dataset_type: indexable dataset whose items are dicts with ``'scan'`` and
      ``'mask'`` tensors; items ``36*i .. 36*i+35`` form volume ``i``.
    output_path: directory to write ``scan_<k>.nii.gz`` / ``mask_<k>.nii.gz``
      into. A falsy value falls back to the project's default output folder.

  NOTE(review): the affine for volume ``i`` is taken from ``scan_imgs[i]`` /
  ``mask_imgs[i]``; that only lines up when the dataset starts at the first
  scan - confirm for the valid/test splits.
  """
  slices_per_scan = 36
  output_dir = Path(output_path) if output_path else Path('C:/Users/Susanna/Documents/Heart_Segmentation/output/')

  for i in range(num_imgs): # loop that goes through 3D imgs
    scan_nifti = scan_imgs[i]
    mask_nifti = mask_imgs[i]

    start_slice = i * slices_per_scan
    end_slice = start_slice + slices_per_scan

    print(f"The start slice is {start_slice}")
    print(f"The end slice is {end_slice}")

    scan_slices = []
    mask_slices = []
    for j in range(start_slice, end_slice): # loop that goes through slices
      data = dataset_type[j]
      scan_slices.append(data['scan'].squeeze().numpy())
      mask_slices.append(data['mask'].squeeze().numpy())

    # Stack once along a new last axis instead of growing an array seeded
    # with uninitialised memory; float64 matches what the dstack-based
    # version produced.
    curr_img = np.stack(scan_slices, axis=-1).astype(np.float64)
    curr_mask = np.stack(mask_slices, axis=-1).astype(np.float64)

    print(f"Current image shape {curr_img.shape}")
    print(f"Current mask shape {curr_mask.shape}")

    # NOTE(review): conform() resamples with its default interpolation order,
    # which may blur binary mask values - confirm this is acceptable.
    curr_img_nifti = processing.conform(nib.Nifti1Image(curr_img, scan_nifti.affine), out_shape=curr_img.shape)
    curr_mask_nifti = processing.conform(nib.Nifti1Image(curr_mask, mask_nifti.affine), out_shape=curr_mask.shape)

    # The mask file name previously contained a stray unary '+' that raised a
    # TypeError; it now follows the same 'scan_<k>' / 'mask_<k>' convention
    # as the slice files.
    img_nifti_path = output_dir / f'scan_{i + 1}.nii.gz'
    mask_nifti_path = output_dir / f'mask_{i + 1}.nii.gz'
    nib.save(curr_img_nifti, img_nifti_path)
    nib.save(curr_mask_nifti, mask_nifti_path)

    print("outer loop done")

In [None]:
# Batches of two slices per step; only the training loader shuffles, so validation and test
# batches come back in dataset order. num_workers=0 loads in the main process.
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=2, num_workers=0, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size=2, num_workers=0, pin_memory=True)

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) stages; spatial size is preserved."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.double_conv(x)
        return features


class Down(nn.Module):
    """Encoder step: 2x2 max-pool followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        pooled_features = self.maxpool_conv(x)
        return pooled_features


class Up(nn.Module):
    """Decoder step: bilinear 2x upsampling + 1x1 conv, concatenation with the skip tensor, then DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        reduce_channels = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        self.up_conv = nn.Sequential(upsample, reduce_channels)
        # The concatenated tensor carries out_channels from each of the two inputs.
        self.conv = DoubleConv(out_channels * 2, out_channels)

    def forward(self, x1, x2):
        upsampled = self.up_conv(x1)
        merged = torch.cat([upsampled, x2], dim=1)
        return self.conv(merged)

class OutConv(nn.Module):
    """Output head: 1x1 convolution followed by a sigmoid."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        projection = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv_sigmoid = nn.Sequential(projection, nn.Sigmoid())

    def forward(self, x):
        probabilities = self.conv_sigmoid(x)
        return probabilities

In [None]:
# define model - standard unet architecture 
class UNet(nn.Module):
    """U-Net built from DoubleConv / Down / Up / OutConv with four down and four up steps.

    Args:
        name: label stored on the instance as ``self.name``.
        n_channels: number of input channels.
        n_classes: number of output channels of the final OutConv.
    """

    def __init__(self, name, n_channels, n_classes):
        super().__init__()
        self.name = name
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Encoder: channel width doubles at every down step.
        self.inputL = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        # Decoder: channel width halves at every up step.
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        self.outputL = OutConv(64, n_classes)

    def forward(self, x):
        # Run the encoder, remembering every intermediate feature map.
        skips = [self.inputL(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        # The deepest map seeds the decoder; the rest are consumed deepest-first.
        x = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            x = up(x, skips.pop())

        return self.outputL(x)

In [None]:
my_UNet = UNet('MyUNet', n_channels=1, n_classes=1)
# Use the GPU when one is present instead of failing outright on CPU-only machines.
my_UNet.to('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Adam over all U-Net parameters. BCELoss expects probabilities as input,
# which the Sigmoid at the end of OutConv provides.
optimizer = torch.optim.Adam(my_UNet.parameters(), lr=0.001)
loss_function = nn.BCELoss()

In [None]:
def dice_coeff_binary(y_pred, y_true):
    """Smoothed DICE coefficient ``(2*|A.B| + eps) / (|A| + |B| + eps)`` of two masks.

    Args:
        y_pred: predicted mask (torch tensor or array-like), any shape.
        y_true: reference mask with the same number of elements.

    Returns:
        A 0-dimensional NumPy array holding the DICE score. Two empty masks
        score 1 because of the smoothing term.
    """
    eps = 0.0001
    # Accept NumPy arrays as well as tensors, and bring both operands to one
    # dtype and device so torch.dot is always valid.
    y_pred = torch.as_tensor(y_pred).detach().float()
    y_true = torch.as_tensor(y_true).detach().float().to(y_pred.device)
    # reshape (unlike view) also works for non-contiguous tensors.
    inter = torch.dot(y_pred.reshape(-1), y_true.reshape(-1))
    union = torch.sum(y_pred) + torch.sum(y_true)
    return ((2 * inter + eps) / (union + eps)).cpu().numpy()

In [None]:
def train_net(net, epochs, train_dataloader, valid_dataloader, optimizer, loss_function):
    """Train ``net`` for ``epochs`` epochs and evaluate it on the validation set after each one.

    Args:
        net: model to train; must expose a ``name`` attribute (a directory of
            that name is created in the working directory).
        epochs: number of passes over ``train_dataloader``.
        train_dataloader: yields dict batches with ``'scan'`` and ``'mask'`` tensors.
        valid_dataloader: same batch format, used for validation.
        optimizer: optimizer already bound to ``net``'s parameters.
        loss_function: callable ``loss_function(prediction, target)``.

    Returns:
        ``(train_loss, train_dice, valid_loss, valid_dice)``: four lists with
        one per-epoch mean each.
    """
    # exist_ok avoids the check-then-create race of isdir() + mkdir().
    os.makedirs('{0}'.format(net.name), exist_ok=True)

    # Send batches to wherever the model lives rather than assuming CUDA.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    n_train = len(train_dataloader)
    n_valid = len(valid_dataloader)

    train_loss = list()
    valid_loss = list()
    train_dice = list()
    valid_dice = list()

    for epoch in range(epochs):
        # training loop
        net.train()
        train_batch_loss = list()
        train_batch_dice = list()

        for i, batch in enumerate(train_dataloader):

            # Load a batch and move it to the model's device
            imgs = batch['scan'].to(device)
            true_masks = batch['mask'].to(device)

            # Produce the estimated mask using current weights
            y_pred = net(imgs)

            # Compute the loss for this batch and append it to the epoch loss
            loss = loss_function(y_pred, true_masks)
            batch_loss = loss.item()
            train_batch_loss.append(batch_loss)

            # Make the thresholded mask to compute the DICE score
            pred_binary = (y_pred > 0.5).float()

            # Compute the DICE score
            batch_dice_score = dice_coeff_binary(pred_binary, true_masks)
            train_batch_dice.append(batch_dice_score)

            # Reset gradients, back-propagate, and update the weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Print the progress
            print(f'EPOCH {epoch + 1}/{epochs} - Training Batch {i+1}/{n_train} - Loss: {batch_loss}, DICE score: {batch_dice_score}', end='\r')

        average_training_loss = np.array(train_batch_loss).mean()
        average_training_dice = np.array(train_batch_dice).mean()
        train_loss.append(average_training_loss)
        train_dice.append(average_training_dice)

        # validation loop
        net.eval()
        valid_batch_loss = list()
        valid_batch_dice = list()

        with torch.no_grad():
            for i, batch in enumerate(valid_dataloader):

                imgs = batch['scan'].to(device)
                true_masks = batch['mask'].to(device)

                y_pred = net(imgs)

                # Compute the loss for this batch and append it to the epoch loss
                loss = loss_function(y_pred, true_masks)
                batch_loss = loss.item()
                valid_batch_loss.append(batch_loss)

                # Make the thresholded mask to compute the DICE score
                pred_binary = (y_pred > 0.5).float()

                # Compute the DICE score
                batch_dice_score = dice_coeff_binary(pred_binary, true_masks)
                valid_batch_dice.append(batch_dice_score)

                print(f'EPOCH {epoch + 1}/{epochs} - Validation Batch {i+1}/{n_valid} - Loss: {batch_loss}, DICE score: {batch_dice_score}', end='\r')

        average_validation_loss = np.array(valid_batch_loss).mean()
        average_validation_dice = np.array(valid_batch_dice).mean()
        valid_loss.append(average_validation_loss)
        valid_dice.append(average_validation_dice)

        print(f'EPOCH {epoch + 1}/{epochs} - Training Loss: {average_training_loss}, Training DICE score: {average_training_dice}, Validation Loss: {average_validation_loss}, Validation DICE score: {average_validation_dice}')

    return train_loss, train_dice, valid_loss, valid_dice

In [None]:
# Train the vanilla U-Net for EPOCHS epochs, keep the per-epoch curves for plotting,
# then persist the learned weights (state_dict only) to model_path.
model_path = Path('C:/Users/Susanna/Documents/Heart_Segmentation/2d_unet.pth')
EPOCHS = 30
train_loss, train_dice, valid_loss, valid_dice = train_net(my_UNet, EPOCHS, train_dataloader, valid_dataloader, optimizer, loss_function)
torch.save(my_UNet.state_dict(), model_path)

In [None]:
# Learning curves: loss on the left panel, DICE score on the right, one point per epoch.
epoch_axis = np.arange(EPOCHS)+1
curve_panels = (
    (train_loss, valid_loss, 'Loss', 'Training Loss', 'Validation Loss', np.arange(0, 0.1, 0.02)),
    (train_dice, valid_dice, 'DICE score', 'Training DICE score', 'Validation DICE score', np.arange(0.6, 1, 0.05)),
)

plt.figure(figsize=(15,8))
plt.suptitle('Learning Curve', fontsize=18)

for panel_idx, (train_curve, valid_curve, y_label, train_label, valid_label, y_ticks) in enumerate(curve_panels, start=1):
    plt.subplot(1,2,panel_idx)
    plt.plot(epoch_axis, train_curve, '-o', label=train_label)
    plt.plot(epoch_axis, valid_curve, '-o', label=valid_label)
    plt.xlabel('Epoch', fontsize=15)
    plt.ylabel(y_label, fontsize=15)
    plt.yticks(y_ticks)
    plt.legend()

plt.tight_layout()
plt.show()

In [None]:
# Second copy of the trained weights, written relative to the current working directory.
torch.save(my_UNet.state_dict(), 'model_weights_2d_UNet.pth')

In [None]:
# Collect every test batch (the list is reused by the next cell) and visualise one sample.
batches = list(test_dataloader)
sample_batch = batches[0]

# Index of the sample to display within the batch (0 or 1 with batch_size=2).
test_num = 1

# Generate the network prediction on whatever device the model lives on.
model_device = next(my_UNet.parameters()).device
my_UNet.eval()
with torch.no_grad():
    y_pred = my_UNet(sample_batch['scan'].to(model_device))

# Convert PyTorch tensors to numpy arrays, then reverse the 0-1 scaling for display
img = (sample_batch['scan'][test_num].numpy().transpose(1,2,0) * 255).astype('uint8')
msk = (sample_batch['mask'][test_num][0,:,:].numpy() * 255).astype('uint8')

# Extract the matching prediction and threshold the probabilities (> 0.5)
pred_probs = y_pred.cpu().numpy()[test_num][0,:,:]
pred_msk = (pred_probs * 255).astype('uint8')
pred_msk_binary = ((pred_probs > 0.5) * 255).astype('uint8')

# Take the id of the sample that is actually displayed, not of the first sample in the batch
img_id = sample_batch['img_id'][test_num]

plt.figure(figsize=(24,9))
plt.suptitle(f'Test sample Image {img_id}', fontsize=18)

plt.subplot(1,4,1)
plt.title('Input Image', fontsize=15)
plt.imshow(img, cmap='gray')
plt.axis('off')

plt.subplot(1,4,2)
plt.title('Ground Truth', fontsize=15)
plt.imshow(msk, cmap='gray')
plt.axis('off')

plt.subplot(1,4,3)
plt.title('Final Thresholdded Binary Prediction (threshold > 0.5)', fontsize=15)
plt.imshow(pred_msk_binary, cmap='gray')
plt.axis('off')

# Paint the predicted foreground white on top of a copy of the input
input_overlayed_Pred = img.copy()
input_overlayed_Pred[pred_msk_binary == 255] = [255]
plt.subplot(1,4,4)
plt.title('Input Image overlayed with Prediction', fontsize=15)
plt.imshow(input_overlayed_Pred, cmap='gray')
plt.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Store the thresholded (0 / 255) predicted mask of every test slice, in dataset order.
masks_test = []
model_device = next(my_UNet.parameters()).device
for batch in batches:
    with torch.no_grad():
        y_pred = my_UNet(batch['scan'].to(model_device))
    batch_probs = y_pred.cpu().numpy()
    # Iterate over the real batch size: the last batch can hold fewer samples than batch_size.
    for i in range(batch_probs.shape[0]):
        pred_msk = ((batch_probs[i][0,:,:] > 0.5) * 255).astype('uint8')
        masks_test.append(pred_msk)

In [None]:
def recombine_slices(num_imgs, first_scan_index=71, slices_per_scan=36,
                     output_dir='C:/Users/Susanna/Documents/Heart_Segmentation/output_vanilla/'):
  """Stack the predicted 2D masks in ``masks_test`` into 3D volumes and save them as NIfTI.

  NOTE: this definition replaces the earlier three-argument ``recombine_slices``
  defined in this notebook.

  Args:
    num_imgs: number of 3D volumes to rebuild.
    first_scan_index: index into ``mask_imgs`` of the reference image whose
      affine belongs to the first rebuilt volume (71 by default, i.e. the
      72nd scan).
    slices_per_scan: number of consecutive entries of ``masks_test`` per volume.
    output_dir: directory receiving ``mask_<k>.nii.gz`` (k starts at 1).
  """
  output_dir = Path(output_dir)
  for i in range(num_imgs): # loop that goes through 3D imgs
    mask_nifti = mask_imgs[i + first_scan_index]

    start_slice = i * slices_per_scan
    end_slice = start_slice + slices_per_scan

    # Stack once along a new last axis instead of growing an array seeded
    # with uninitialised memory; float64 matches the dstack-based result.
    volume_slices = [np.squeeze(masks_test[j]) for j in range(start_slice, end_slice)]
    curr_mask = np.stack(volume_slices, axis=-1).astype(np.float64)

    # NOTE(review): conform() resamples with its default interpolation order,
    # which may blur the binary mask - confirm this is acceptable.
    curr_mask_nifti = processing.conform(nib.Nifti1Image(curr_mask, mask_nifti.affine), out_shape=curr_mask.shape)

    nib.save(curr_mask_nifti, output_dir / f'mask_{i + 1}.nii.gz')



In [None]:
# Rebuild as many complete 3D volumes as the predicted slices allow (36 slices each).
recombine_slices(len(masks_test) // 36)

In [None]:
def test_net(net, test_dataloader, loss_function):
    """Evaluate ``net`` on ``test_dataloader`` and return the mean per-batch DICE score.

    Args:
        net: trained model; switched to eval mode here.
        test_dataloader: yields dict batches with ``'scan'`` and ``'mask'`` tensors.
        loss_function: callable ``loss_function(prediction, target)``; the
            per-batch loss is computed but not returned.

    Returns:
        Mean of the per-batch DICE scores (predictions thresholded at 0.5).
    """
    net.eval()

    # Send batches to wherever the model lives rather than assuming CUDA.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    n_test = len(test_dataloader)
    test_batch_loss = list()
    test_batch_dice = list()

    with torch.no_grad():
        for i, batch in enumerate(test_dataloader):

            # Load a batch
            imgs = batch['scan'].to(device)
            true_masks = batch['mask'].to(device)

            # Produce the estimated mask
            y_pred = net(imgs)

            # Compute the loss for this batch
            loss = loss_function(y_pred, true_masks)
            test_batch_loss.append(loss.item())

            # Make the thresholded mask
            pred_binary = (y_pred > 0.5).float()

            # Compute the DICE score
            batch_dice_score = dice_coeff_binary(pred_binary, true_masks)
            test_batch_dice.append(batch_dice_score)

            # Print the progress
            print(f'Test Batch {i+1}/{n_test} - DICE score: {batch_dice_score}', end='\r')

    test_dice = np.array(test_batch_dice).mean()

    return test_dice

In [None]:
# Mean per-batch DICE score of the vanilla U-Net on the test split.
test_dice = test_net(my_UNet, test_dataloader, loss_function)

print(f'Test DICE score: {test_dice}')

In [None]:
########################################################################################## Attention U-Net #################################################################################

In [None]:
class ConvBlock(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) stages mapping ``ch_in`` to ``ch_out`` channels."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        stages = []
        # First stage changes the channel count, the second keeps it.
        for stage_in in (ch_in, ch_out):
            stages.extend([
                nn.Conv2d(stage_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

In [None]:
class UpConvBlock(nn.Module):
    """2x upsampling followed by 3x3 conv -> batch norm -> ReLU, mapping ``ch_in`` to ``ch_out`` channels."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        layers = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        ]
        self.up = nn.Sequential(*layers)

    def forward(self, x):
        return self.up(x)

In [None]:
class AttentionBlock(nn.Module):
    """Attention gate: scales the skip features ``x`` by a per-pixel weight derived from ``g`` and ``x``.

    Args:
        f_g: channels of the gating signal ``g``.
        f_l: channels of the skip features ``x``.
        f_int: channels of the intermediate representation.
    """

    @staticmethod
    def _project(ch_in, ch_out):
        """1x1 convolution followed by batch norm."""
        return nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(ch_out),
        )

    def __init__(self, f_g, f_l, f_int):
        super().__init__()
        # projection of the gating signal
        self.w_g = self._project(f_g, f_int)
        # projection of the skip features
        self.w_x = self._project(f_l, f_int)
        # single-channel attention map in (0, 1)
        self.psi = nn.Sequential(
            nn.Conv2d(f_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        gate = self.w_g(g)
        skip = self.w_x(x)
        attention = self.psi(self.relu(gate + skip))
        # Broadcast the one-channel map over every channel of x.
        return attention * x

In [None]:
class AttentionUNet(nn.Module):
    """U-Net whose skip connections are filtered by attention gates.

    Args:
        n_classes: accepted for interface compatibility; not used by the layers.
        in_channel: number of input channels.
        out_channel: number of output channels (probability maps).

    The forward pass returns per-pixel probabilities in [0, 1]: the network is
    trained with ``nn.BCELoss`` and its output is thresholded at 0.5 elsewhere
    in this notebook, both of which require probabilities rather than logits.
    """

    def __init__(self, n_classes=1, in_channel=1, out_channel=1):
        super().__init__()

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # encoder
        self.conv1 = ConvBlock(ch_in=in_channel, ch_out=64)
        self.conv2 = ConvBlock(ch_in=64, ch_out=128)
        self.conv3 = ConvBlock(ch_in=128, ch_out=256)
        self.conv4 = ConvBlock(ch_in=256, ch_out=512)
        self.conv5 = ConvBlock(ch_in=512, ch_out=1024)

        # decoder: upsample, gate the skip connection, fuse
        self.up5 = UpConvBlock(ch_in=1024, ch_out=512)
        self.att5 = AttentionBlock(f_g=512, f_l=512, f_int=256)
        self.upconv5 = ConvBlock(ch_in=1024, ch_out=512)

        self.up4 = UpConvBlock(ch_in=512, ch_out=256)
        self.att4 = AttentionBlock(f_g=256, f_l=256, f_int=128)
        self.upconv4 = ConvBlock(ch_in=512, ch_out=256)

        self.up3 = UpConvBlock(ch_in=256, ch_out=128)
        self.att3 = AttentionBlock(f_g=128, f_l=128, f_int=64)
        self.upconv3 = ConvBlock(ch_in=256, ch_out=128)

        self.up2 = UpConvBlock(ch_in=128, ch_out=64)
        self.att2 = AttentionBlock(f_g=64, f_l=64, f_int=32)
        self.upconv2 = ConvBlock(ch_in=128, ch_out=64)

        self.conv_1x1 = nn.Conv2d(64, out_channel, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoder
        x1 = self.conv1(x)
        x2 = self.conv2(self.maxpool(x1))
        x3 = self.conv3(self.maxpool(x2))
        x4 = self.conv4(self.maxpool(x3))
        x5 = self.conv5(self.maxpool(x4))

        # decoder: each level upsamples, gates the matching encoder features
        # with the upsampled signal, concatenates both and fuses them
        d5 = self.up5(x5)
        d5 = self.upconv5(torch.cat((self.att5(g=d5, x=x4), d5), dim=1))

        d4 = self.up4(d5)
        d4 = self.upconv4(torch.cat((self.att4(g=d4, x=x3), d4), dim=1))

        d3 = self.up3(d4)
        d3 = self.upconv3(torch.cat((self.att3(g=d3, x=x2), d3), dim=1))

        d2 = self.up2(d3)
        d2 = self.upconv2(torch.cat((self.att2(g=d2, x=x1), d2), dim=1))

        # Squash the logits to probabilities; BCELoss rejects values outside [0, 1].
        return torch.sigmoid(self.conv_1x1(d2))

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Instantiate the Attention U-Net (single input / output channel by default) on the selected device.
attention_unet = AttentionUNet(n_classes=1).to(device)

In [None]:
# Show the prediction of the still-untrained (randomly initialised) Attention U-Net
sample_batch = next(iter(train_dataloader))

# Generate the network prediction. The dataset stores the slice under the 'scan' key,
# and the batch must sit on the same device as the model.
with torch.no_grad():
    y_pred = attention_unet(sample_batch['scan'].to(device))

img = (sample_batch['scan'][0].numpy().transpose(1,2,0) * 255).astype('uint8')
msk = (sample_batch['mask'][0][0,:,:].numpy() * 255).astype('uint8')

pred_probs = y_pred.cpu().numpy()[0][0,:,:]
pred_msk = (pred_probs * 255).astype('uint8')
pred_msk_binary = ((pred_probs > 0.5) * 255).astype('uint8')

# Take the image id for display
img_id = sample_batch['img_id'][0]

plt.figure(figsize=(24,9))
plt.suptitle(f'Test sample Image {img_id}', fontsize=18)

plt.subplot(1,4,1)
plt.title('Input Image', fontsize=15)
plt.imshow(img,cmap='gray')
plt.axis('off')

plt.subplot(1,4,2)
plt.title('Ground Truth', fontsize=15)
plt.imshow(msk, cmap='gray')
plt.axis('off')

plt.subplot(1,4,3)
plt.title('Non-trained Network Prediction Output \n(probability [0, 1])', fontsize=15)
plt.imshow(pred_msk, cmap='gray')
plt.axis('off')

plt.subplot(1,4,4)
plt.title('Non-trained Thresholdded Binary Prediction (threshold > 0.5)', fontsize=15)
plt.imshow(pred_msk_binary, cmap='gray')
plt.axis('off')

plt.tight_layout()
plt.show()

In [None]:
def get_dice(model, loader, threshold=0.5):
    """Mean per-batch DICE score of ``model`` over ``loader``.

    Args:
        model: segmentation model producing one score per pixel.
        loader: yields dict batches with ``'scan'`` and ``'mask'`` tensors.
        threshold: outputs ``>= threshold`` count as foreground.

    Returns:
        The average of the per-batch DICE scores, or ``0.0`` for an empty loader.
    """
    # Send batches to wherever the model lives rather than assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    valloss = 0
    n_batches = 0

    # Evaluate with inference-time layer behaviour (e.g. BatchNorm running
    # statistics), then restore whatever mode the caller had set.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in loader:
                data = batch['scan'].to(device)
                target = batch['mask'].to(device)

                outputs = model(data)

                # Threshold in torch: dice_coeff_binary operates on tensors.
                out_cut = (outputs >= threshold).float()
                valloss += dice_coeff_binary(out_cut, target)
                n_batches += 1
    finally:
        model.train(was_training)

    if n_batches == 0:
        return 0.0
    # Divide by the number of batches, not by the index of the last one.
    return valloss / n_batches

In [None]:
def _evaluate_segmentation(model, loader, loss_fn, device, threshold=0.5):
    """Return ``(mean loss, mean DICE)`` of ``model`` over ``loader``, evaluated in eval mode."""
    model.eval()
    batch_losses = []
    batch_dice = []
    with torch.no_grad():
        for batch in loader:
            data = batch['scan'].to(device)
            target = batch['mask'].to(device)
            outputs = model(data)
            batch_losses.append(loss_fn(outputs, target).item())
            batch_dice.append(dice_coeff_binary((outputs >= threshold).float(), target))
    if not batch_losses:
        return float('nan'), float('nan')
    return float(np.mean(batch_losses)), float(np.mean(batch_dice))


def train_model(model_name, model, train_loader, val_loader, train_loss, optimizer, lr_scheduler, num_epochs):
    """Train ``model`` and evaluate it on ``val_loader`` after every epoch.

    Args:
        model_name: descriptive label; not used by the training logic.
        model: network to train.
        train_loader / val_loader: yield dict batches with ``'scan'`` and ``'mask'`` tensors.
        train_loss: loss callable ``train_loss(prediction, target)``.
        optimizer: optimizer already bound to ``model``'s parameters.
        lr_scheduler: optional scheduler stepped once per epoch; pass a falsy
            value to disable.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        ``(loss_history, train_history, val_history, val_loss_history)``: per-epoch
        mean training loss, mean training DICE, validation DICE and validation loss.
    """
    loss_history = []
    train_history = []
    val_history = []
    val_loss_history = []

    # Send batches to wherever the model lives rather than assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    for epoch in range(num_epochs):
        model.train()

        losses = []
        train_iou = []

        for i_step, batch in enumerate(tqdm(train_loader)):

            data = batch['scan'].to(device)
            target = batch['mask'].to(device)

            outputs = model(data)

            # Threshold in torch: dice_coeff_binary operates on tensors.
            out_cut = (outputs.detach() >= 0.5).float()
            train_dice = dice_coeff_binary(out_cut, target)

            loss = train_loss(outputs, target)

            losses.append(loss.item())
            train_iou.append(train_dice)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if lr_scheduler:
            lr_scheduler.step()

        # Validation metrics come from the validation loader, not from the
        # last training batch.
        val_loss, val_mean_iou = _evaluate_segmentation(model, val_loader, train_loss, device)

        loss_history.append(np.array(losses).mean())
        train_history.append(np.array(train_iou).mean())
        val_history.append(val_mean_iou)
        val_loss_history.append(val_loss)

        print("Epoch [%d]" % (epoch+1))
        print("Training loss :", np.array(losses).mean(),
              "DICE score training:", np.array(train_iou).mean(),
              "DICE score validation", val_mean_iou,
              "Validation loss:", val_loss)

    return loss_history, train_history, val_history, val_loss_history

In [None]:
opt = torch.optim.Adamax(attention_unet.parameters(), lr=1e-3) # Adamax optimizer (an Adam variant)
loss_function = nn.BCELoss() # BCE loss; rebinds the `loss_function` name already used for the vanilla U-Net

In [None]:
# Train the Attention U-Net for num_ep epochs; False is passed in the lr_scheduler position.
# The four returned lists are per-epoch histories: training loss, training DICE, validation DICE, validation loss.
num_ep = 30
aun_lh, aun_th, aun_vh, aun_vl = train_model("Attention UNet", attention_unet, train_dataloader, valid_dataloader, loss_function, opt, False, num_ep)

In [None]:
# Get the test DICE score. It is a fraction between 0 and 1, so it is printed without a '%' sign.
test_iou = get_dice(attention_unet, test_dataloader)
print(f"""Test DICE  - {test_iou}""")

In [None]:
# Visualize model output.
# Collect every test batch (the list is reused by the next cell) and pick one to display.
batches = list(test_dataloader)
sample_batch = batches[3]

# Index of the sample to display within the batch.
test_num = 0

# Generate the network prediction. The dataset stores the slice under the 'scan' key,
# and the batch must sit on the same device as the model.
model_device = next(attention_unet.parameters()).device
attention_unet.eval()
with torch.no_grad():
    y_pred = attention_unet(sample_batch['scan'].to(model_device))

img = (sample_batch['scan'][test_num].numpy().transpose(1,2,0) * 255).astype('uint8')
msk = (sample_batch['mask'][test_num][0,:,:].numpy() * 255).astype('uint8')

pred_probs = y_pred.cpu().numpy()[test_num][0,:,:]
pred_msk = (pred_probs * 255).astype('uint8')
pred_msk_binary = ((pred_probs > 0.5) * 255).astype('uint8')

# Take the image id for display
img_id = sample_batch['img_id'][test_num]

plt.figure(figsize=(24,9))
plt.suptitle(f'Test sample Image {img_id}', fontsize=18)

plt.subplot(1,4,1)
plt.title('Input Image', fontsize=15)
plt.imshow(img,cmap='gray')
plt.axis('off')

plt.subplot(1,4,2)
plt.title('Ground Truth', fontsize=15)
plt.imshow(msk, cmap='gray')
plt.axis('off')

# The model has been trained by this point, so the panel uses the same title as the
# vanilla U-Net figure rather than "Non-trained".
plt.subplot(1,4,3)
plt.title('Final Thresholdded Binary Prediction (threshold > 0.5)', fontsize=15)
plt.imshow(pred_msk_binary, cmap='gray')
plt.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Store the thresholded (0 / 255) predicted mask of every test slice, in dataset order.
masks_test = []
model_device = next(attention_unet.parameters()).device
for batch in batches:
    with torch.no_grad():
        # The dataset stores the slice under the 'scan' key.
        y_pred = attention_unet(batch['scan'].to(model_device))
    batch_probs = y_pred.cpu().numpy()
    # Take every sample of the batch (not the first one repeatedly); the last
    # batch can hold fewer samples than batch_size.
    for i in range(batch_probs.shape[0]):
        pred_msk = ((batch_probs[i][0,:,:] > 0.5) * 255).astype('uint8')
        masks_test.append(pred_msk)

In [None]:
# recombine slices into full scans 
#recombine_slices(int(len(masks_test)/36))