# MIA Project 1

In [None]:
# Mount Google Drive at /content/drive (the data and model paths used below live under it)
from google.colab import drive
drive.mount('/content/drive')


Install the library for U-Net Model (segmentation_models_pytorch)

In [None]:
!pip install -q -U segmentation-models-pytorch albumentations > /dev/null
import segmentation_models_pytorch as smp


In [None]:
!pip install SimpleITK
import SimpleITK as sitk


In [None]:
import numpy as np
import os

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as functional
from torch.utils.data import Dataset, DataLoader

import torchvision.models
from torchvision import transforms
import torchvision.transforms.functional as tF

from skimage import io
import matplotlib.pyplot as plt
from PIL import Image

from pathlib import Path

import time


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")


# Set Training Data Path

In [None]:
# Root folder of the training data; its entries are listed in sorted order
root_dir = Path('/content/drive/MyDrive/MIA23_Project1_data') # Data path
dir_list = sorted(os.listdir(root_dir))
print(dir_list)


In [None]:
# Build two parallel lists: training_data[j] is an image path and
# ground_truth[j] is the path of the matching "_gt" mask.
# Each row is (image suffix, mask suffix); drop the rows of a chamber view
# to restrict the run to the other view.
frame_suffixes = (
    # 2 Chamber view
    ("_2CH_ED.mhd", "_2CH_ED_gt.mhd"),
    ("_2CH_ES.mhd", "_2CH_ES_gt.mhd"),
    # 4 Chamber view
    ("_4CH_ED.mhd", "_4CH_ED_gt.mhd"),
    ("_4CH_ES.mhd", "_4CH_ES_gt.mhd"),
)

training_data = []
ground_truth = []
for entry_name in dir_list:
    # Only entries whose name contains 'patient' are used
    if 'patient' not in entry_name:
        continue
    patient_dir = root_dir / entry_name
    for image_suffix, mask_suffix in frame_suffixes:
        training_data.append(patient_dir / Path(entry_name + image_suffix))
        ground_truth.append(patient_dir / Path(entry_name + mask_suffix))

print(training_data)
print(ground_truth)


# Data Loader

In [None]:
## Image Dataloader

# Preprocessing pipeline: convert the array to a tensor, then resize it to 256x256
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
])



class ImageDataset(Dataset):
    """
    Dataset of (image, binary mask) pairs read from two parallel lists of paths.

    Every item is read with ``io.imread``, transposed so that the first axis
    becomes the last one, reduced to a single binary mask and finally passed
    through the module-level ``transform`` pipeline.
    """

    def __init__(self,
                 img_dir,
                 mask_dir,
                 num_classes = 4,
                 data_augmentation = False):
        """
        Args:
            img_dir (sequence): Paths to the images.
            mask_dir (sequence): Paths to the masks, in the same order as ``img_dir``.
            num_classes (int): Stored on the instance; not used by this class.
            data_augmentation (bool): Stored on the instance; not used by this class.
        """
        self.img_dir = img_dir
        self.mask_dir = mask_dir

        self.num_classes = num_classes
        self.data_augmentation = data_augmentation

    def __len__(self):
        """Return the number of samples, i.e. the number of mask paths."""
        return len(self.mask_dir)

    @staticmethod
    def _binary_mask(mask):
        """
        Reduce a label map to the binary mask of its second-smallest label value.

        Args:
            mask (np.ndarray): Label map indexed as ``mask[row, col, 0]``.

        Returns:
            np.ndarray: Boolean array with the shape of ``mask[:, :, 0]``.
            If the label map holds fewer than two distinct values there is no
            second label to select, so an all-False mask is returned instead of
            failing with an IndexError.
        """
        list_label = np.unique(mask)
        plane = mask[:, :, 0]
        if len(list_label) < 2:
            return np.zeros(plane.shape, dtype=bool)
        return plane == list_label[1]

    def __getitem__(self,
                    idx):
        """
        Load, binarize and transform the sample at position ``idx``.

        Returns:
            tuple: ``(new_img, new_mask)`` as produced by ``transform``.
        """
        ## Load Images and Masks
        img = io.imread(self.img_dir[idx])
        mask = io.imread(self.mask_dir[idx])

        # Move the first axis to the end: (a, b, c) -> (b, c, a)
        img = np.transpose(img, (1, 2, 0))
        mask = np.transpose(mask, (1, 2, 0))

        new_label = self._binary_mask(mask)

        new_img = transform(img)
        new_mask = transform(new_label)

        return new_img, new_mask


## Sanity Test

In [None]:
## Batch Size
train_batch_size = 32

## Initialize Dataloaders
train_dataset = ImageDataset(img_dir=training_data, mask_dir=ground_truth, data_augmentation = False)
train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)

# Fetch a single batch and show channel 0 of its first image next to its mask
dataiter = iter(train_dataloader)
images, labels = next(dataiter)

first_image = images[0].cpu().detach()[0]
first_mask = labels[0].cpu().detach()[0]

plt.figure(figsize=(15,7))
for position, panel in ((121, first_image), (122, first_mask)):
    plt.subplot(position)
    plt.imshow(panel, cmap='gray')
# The title lands on the last subplot drawn, i.e. the mask
plt.title('Target/Mask')
plt.show()


# Loss Function

In [None]:
class DICELoss(nn.Module):
    """
    Soft Dice loss.

    The Dice ratio is computed separately for every (sample, channel) pair by
    summing over dimensions 2 and 3, using squared terms in the denominator,
    and the loss is one minus the mean of those ratios.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets):
        """Return ``1 - mean(dice)`` for 4-D ``inputs`` and ``targets`` of equal shape."""
        eps = 1e-7  # keeps the ratio defined when both tensors are all zero
        spatial_dims = [2, 3]
        overlap = (inputs * targets).sum(dim=spatial_dims)
        energy = inputs.pow(2).sum(dim=spatial_dims) + targets.pow(2).sum(dim=spatial_dims)
        dice = (2.0 * overlap + eps) / (energy + eps)
        return 1. - dice.mean()


class DiceBCELoss(nn.Module):
    """
    Sum of binary cross-entropy and soft Dice loss.

    ``inputs`` must already be probabilities, because they are passed directly
    to ``binary_cross_entropy``.
    """

    def __init__(self, weight=None, size_average=True):
        # `weight` and `size_average` are accepted but not used.
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``BCE(inputs, targets) + (1 - mean(dice))``; `smooth` is accepted but not used."""
        eps = 1e-7
        spatial_dims = [2, 3]

        # Dice term, one ratio per (sample, channel)
        overlap = (inputs * targets).sum(dim=spatial_dims)
        energy = inputs.pow(2).sum(dim=spatial_dims) + targets.pow(2).sum(dim=spatial_dims)
        dice_loss = 1. - ((2. * overlap + eps) / (energy + eps)).mean()

        # Cross-entropy term, averaged over every element
        bce = nn.functional.binary_cross_entropy(inputs, targets, reduction='mean')

        return bce + dice_loss

class IOULoss(nn.Module):
    """
    Soft intersection-over-union loss.

    The ratio is computed per (sample, channel) by summing over dimensions 2
    and 3; the loss is one minus the mean ratio.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets):
        """Return ``1 - mean(IOU)`` for 4-D ``inputs`` and ``targets`` of equal shape."""
        eps = 1e-7
        spatial_dims = [2, 3]
        overlap = (inputs * targets).sum(dim=spatial_dims)
        union = inputs.sum(dim=spatial_dims) + targets.sum(dim=spatial_dims) - overlap
        iou = (overlap + eps) / (union + eps)
        return 1. - iou.mean()




# Metrics

In [None]:
def IOU_metric(inputs, targets):
    """
    Mean intersection-over-union over the first dimension of the inputs.

    Args:
        inputs (torch.Tensor): Predictions; one sample per index of dim 0.
        targets (torch.Tensor): Ground truth with the same shape as ``inputs``.

    Returns:
        float: Average of the per-sample scores. A sample whose target sums to
        zero scores 1 when the prediction also sums to zero and 0 otherwise.
    """
    batch_size = inputs.shape[0]
    IOU_batch = np.zeros(batch_size)
    for i in range(batch_size):
        prediction = inputs[i]
        target = targets[i]

        if torch.sum(target) == 0:
            # Nothing to find: only an empty prediction counts as correct
            IOU_batch[i] = 1.0 if torch.sum(prediction) == 0 else 0.0
            continue

        intersection = torch.sum(prediction * target)
        union = torch.sum(prediction) + torch.sum(target) - intersection
        # Convert explicitly to a Python float before writing into the NumPy
        # buffer instead of relying on an implicit tensor -> NumPy conversion.
        IOU_batch[i] = (intersection / union).item()
    return (IOU_batch).mean()



def DICE_metric(inputs, targets):
    """
    Mean Dice score over the first dimension of the inputs.

    Args:
        inputs (torch.Tensor): Binary predictions (values 0/1); one sample per
            index of dim 0.
        targets (torch.Tensor): Binary ground truth with the same shape.

    Returns:
        float: Average of the per-sample scores. A sample whose target sums to
        zero scores 1 when there are no false positives and 0 otherwise; every
        other sample scores ``2*TP / (2*TP + FP + FN)``.
    """
    eps = 1e-7
    batch_size = inputs.shape[0]
    dice_batch = np.zeros(batch_size)
    for i in range(batch_size):
        prediction = inputs[i]
        target = targets[i]
        TP = torch.sum(prediction * target).item()
        FP = torch.sum((prediction == 1) & (target == 0)).item()
        FN = torch.sum((prediction == 0) & (target == 1)).item()

        if torch.sum(target) == 0:
            # Nothing to find: only a prediction without false positives is correct
            dice_score = 1.0 if FP == 0 else 0.0
        else:
            # This branch must pair with the empty-target check above; when it
            # was nested under the FP test it never ran and the score stayed 0.
            dice_score = (2 * TP + eps) / (TP + FP + TP + FN + eps)
        dice_batch[i] = dice_score
    return (dice_batch).mean()


# Training Procedure

In [None]:
## Batch Size
train_batch_size = 32
validation_batch_size = 32

# K-fold cross-validation: every fold holds out `sub_size` samples for validation
kfold = 5
total_size = len(training_data)
sub_size = int(total_size/kfold)

# Number of random seeds, i.e. how many times the whole cross-validation is repeated
num_seed = 1

# Maximum number of epochs per fold (the training loop stops earlier when the
# validation loss stops improving)
#num_epochs = 3000
num_epochs = 100

# Cut-off used to binarize the model output before the metrics are computed
threshold = 0.5

# Loss used for both training and validation; alternatives are kept for quick switching
# loss_fn = nn.BCELoss()
# loss_fn = DICELoss()
# loss_fn = IOULoss()
loss_fn = DiceBCELoss()
# loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
# loss_fn = smp.losses.JaccardLoss(smp.losses.BINARY_MODE, from_logits=True)


# Final per-fold results, accumulated over every seed
all_seed_IOU = []
all_seed_DICE = []
all_seed_epoch = []


# Repeated k-fold cross-validation: for every seed the (image, mask) pairs are
# shuffled, split into `kfold` folds, and a fresh model is trained per fold with
# plateau-based learning-rate decay and early stopping on the validation loss.
for s in range(num_seed):
    # Seed both generators: torch.manual_seed alone does not affect the
    # np.random.shuffle call that decides the fold split below.
    torch.manual_seed(s)
    np.random.seed(s)

    print("\n------------------------------------------------")
    print(f"Seed Number: {s+1}")

    # Shuffle images and masks together so that the pairs stay aligned
    paired_files = list(zip(training_data, ground_truth))
    np.random.shuffle(paired_files)
    all_training_files, all_training_truth_files = zip(*paired_files)

    all_fold_IOU = []
    all_fold_DICE = []
    all_fold_epoch = []

    for k in range(kfold):
        print("\n-------------------------------")
        print(f"Fold Number: {k+1}")

        # Fold k is the validation slice; everything before and after it is
        # training data (for k == 0 the "before" part is simply empty).
        val_start, val_stop = k*sub_size, (k+1)*sub_size
        val_img_dir = all_training_files[val_start:val_stop]
        val_mask_dir = all_training_truth_files[val_start:val_stop]
        train_img_dir = all_training_files[:val_start] + all_training_files[val_stop:]
        train_mask_dir = all_training_truth_files[:val_start] + all_training_truth_files[val_stop:]

        ## Initialize Dataloaders
        train_dataset = ImageDataset(img_dir=train_img_dir, mask_dir=train_mask_dir, data_augmentation = True)
        train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)

        validation_dataset = ImageDataset(img_dir=val_img_dir, mask_dir=val_mask_dir, data_augmentation = False)
        validation_dataloader = DataLoader(validation_dataset, batch_size=validation_batch_size, shuffle=True)

        print("Start Training...")

        # The model ends with a sigmoid, so its output can be fed to the
        # probability-based losses and thresholded directly.
        model = smp.Unet("resnet34", encoder_weights='imagenet', in_channels = 1, classes=1, activation='sigmoid')
        model = model.to(device)

        learning_rate = 0.001

        ## Initialize Optimizer
        optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)

        # Early-stopping state: `patience` counts consecutive epochs without a
        # new best validation loss, `previous_val_loss` is the best seen so far.
        patience = 0
        current_val_loss = 100000000000
        previous_val_loss = current_val_loss

        epoch_train_loss = []
        epoch_val_loss = []
        epoch_IOU_score = []
        epoch_DICE_score = []

        for epoch in range(num_epochs):
            batch_train_loss = 0
            batch_train_length = 0
            batch_val_loss = 0
            batch_val_length = 0
            batch_IOU = 0
            batch_DICE = 0

            ########################### Training ###################################
            model.train()
            for images, masks in train_dataloader:
                images = images.to(device, dtype=torch.float)
                masks = masks.to(device, dtype=torch.float)

                # Compute prediction error
                pred = model(images)
                train_loss = loss_fn(pred, masks)
                batch_train_loss += train_loss.item()
                batch_train_length += 1

                # Backpropagation
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

            ########################### Validation #################################
            model.eval()
            # No gradients are needed here; skipping them saves memory and time
            with torch.no_grad():
                for images, masks in validation_dataloader:
                    images = images.to(device, dtype=torch.float)
                    masks = masks.to(device, dtype=torch.float)

                    # Compute prediction error
                    pred = model(images)
                    val_loss = loss_fn(pred, masks)
                    batch_val_loss += val_loss.item()

                    # Binarize before computing the overlap metrics
                    pred = (pred>threshold).float()

                    batch_IOU += IOU_metric(pred, masks)
                    batch_DICE += DICE_metric(pred, masks)
                    batch_val_length += 1

            current_val_loss = float(batch_val_loss)/float(batch_val_length)
            print(current_val_loss)

            epoch_train_loss.append(float(batch_train_loss)/float(batch_train_length))
            epoch_val_loss.append(current_val_loss)
            epoch_IOU_score.append(float(batch_IOU)/float(batch_val_length))
            epoch_DICE_score.append(float(batch_DICE)/float(batch_val_length))

            if current_val_loss >= previous_val_loss:
                patience += 1
            else:
                patience = 0
                previous_val_loss = current_val_loss

            if patience > 10:
                # Halve the step size on every stalled epoch past the 10th and
                # write it into the optimizer: changing the local variable
                # alone has no effect on an optimizer that is already built.
                learning_rate = learning_rate * 0.5
                for param_group in optimizer.param_groups:
                    param_group['lr'] = learning_rate

            if patience > 20:
                print("stop at: ", epoch+1, " epoch")
                break

        # plot the loss
        plt.figure(figsize=(15,4))
        plt.title("Training and Validation Loss")
        plt.plot(epoch_train_loss,label="train loss")
        plt.plot(epoch_val_loss,label="validation loss")
        plt.xlabel("Number of Epoch")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()

        # plot the metrics
        plt.figure(figsize=(15,4))
        plt.title("Metrics on Validation dataset")
        plt.plot(epoch_IOU_score,label="IOU Score")
        plt.plot(epoch_DICE_score,label="DICE Score")
        plt.xlabel("Number of Epoch")
        plt.ylabel("Metrics")
        plt.legend()
        plt.show()

        # The fold result is the validation score of the last epoch that ran
        fold_IOU = float(batch_IOU)/float(batch_val_length)
        fold_DICE = float(batch_DICE)/float(batch_val_length)

        all_fold_IOU.append(fold_IOU)
        all_fold_DICE.append(fold_DICE)
        all_fold_epoch.append(epoch+1)

        all_seed_IOU.append(fold_IOU)
        all_seed_DICE.append(fold_DICE)
        all_seed_epoch.append(epoch+1)

    print("\n-------------------------------")
    print(f"Result of Seed Number: {s+1}")

    print("IOU:")
    print("mean: {:0.3f}".format(np.mean(all_fold_IOU)))
    print("std: {:0.3f}".format(np.std(all_fold_IOU)))
    print("max: {:0.3f}".format(max(all_fold_IOU)))
    print("min: {:0.3f}".format(min(all_fold_IOU)))
    print("ALL: ", np.around(np.array(all_fold_IOU),3))

    print("\n-------------------------------")
    print("DICE:")
    print("mean: {:0.3f}".format(np.mean(all_fold_DICE)))
    print("std: {:0.3f}".format(np.std(all_fold_DICE)))
    print("max: {:0.3f}".format(max(all_fold_DICE)))
    print("min: {:0.3f}".format(min(all_fold_DICE)))
    print("ALL: ", np.around(np.array(all_fold_DICE),3))

    print("\n-------------------------------")
    print("Average Epochs: {:0.3f}".format(np.mean(all_fold_epoch)))


Save Model

In [None]:
# Save the whole model object (not only its state_dict) under a timestamped file name.
# NOTE(review): `model` is whatever the training cell bound last, i.e. presumably
# the model of the final fold only — confirm that this is the one meant to be kept.
PATH = '/content/drive/MyDrive/Colab Notebooks/Project1Model_' + time.strftime("%Y%m%d-%H%M") + '.pth'
torch.save(model, PATH)

# Results

In [None]:
# Summary statistics of the per-fold scores collected over every seed
print("All seed result")
for metric_name, scores in (("IOU", all_seed_IOU), ("DICE", all_seed_DICE)):
    print("\n-------------------------------")
    print(metric_name + ":")
    print("mean: {:0.3f}".format(np.mean(scores)))
    print("std: {:0.3f}".format(np.std(scores)))
    print("max: {:0.3f}".format(max(scores)))
    print("min: {:0.3f}".format(min(scores)))
    print("ALL: ", np.around(np.array(scores),3))
    if metric_name == "IOU":
        # The IOU section is followed by one extra empty line
        print()

print("\n-------------------------------")
print("Average Epochs: {:0.3f}".format(np.mean(all_seed_epoch)))

## Visualization

In [None]:
def evaluation(model, dataloader):
    """
    Plot input / target / prediction for every sample and print mean IOU and DICE.

    The model output is binarized with the module-level ``threshold`` and the
    tensors are moved to the module-level ``device``. The printed scores are
    the average of the per-batch metric values.

    Args:
        model: Network called as ``model(batch)``.
        dataloader: Iterable of ``(input, target)`` batches that supports ``len()``.

    Returns:
        None
    """
    n_batches = len(dataloader)
    IOU_scores = np.zeros(n_batches)
    DICE_scores = np.zeros(n_batches)

    ## Evaluate
    model.eval()
    # Inference only: no gradients are needed
    with torch.no_grad():
        for idx, (images, target) in enumerate(dataloader):
            ## Format Data (same float cast as in the training loop)
            images = images.to(device, dtype=torch.float)
            target = target.to(device, dtype=torch.float)

            ## Make Predictions
            out = (model(images) > threshold).float()

            # Visualization: channel 0 of input, target and prediction side by side
            for i in range(len(images)):
                plt.figure(figsize=(15,15))
                plt.subplot(131)
                plt.imshow(images[i, 0].cpu(), cmap='gray')

                plt.subplot(132)
                plt.imshow(target[i, 0].cpu(), cmap='gray')
                plt.title('target')

                plt.subplot(133)
                plt.imshow(out[i, 0].cpu(), cmap='gray')
                plt.title('prediction'+str(IOU_metric(out[i,:,:,:], target[i,:,:,:])))
                plt.show()

            IOU_scores[idx] = IOU_metric(out, target)
            DICE_scores[idx] = DICE_metric(out, target)

    ## Average IOU and Dice Score Over Batches
    m_IOU = IOU_scores.mean()
    m_dice = DICE_scores.mean()
    print("IOU: ",m_IOU)
    print("DICE: ", m_dice)
    return


In [None]:
# Uses the `model` and `validation_dataloader` left bound by the training cell
evaluation(model, validation_dataloader)

# Test (Run on testset)

Set testset path

In [None]:
# NOTE(review): this is an absolute path on a local machine, while the model
# below is loaded from the mounted drive — confirm where the test set lives.
root_dir = Path('/Users/ychen215/Desktop/Project1/testset') # Testset Path
dir_list = sorted(os.listdir(root_dir))
print(dir_list)

Load trained model

In [None]:
# Restore the complete pickled model. Building a new smp.Unet first is not
# needed because torch.load returns the whole module and replaces it.
# The tensors are mapped straight onto `device`, so that the model ends up on
# the same device the inference cell sends its inputs to.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# NOTE(review): recent PyTorch releases may need weights_only=False to load a
# full model object; confirm against the installed version.
model = torch.load('/content/drive/MyDrive/Colab Notebooks/Project1Model_20230508-0501_4CH.pth', map_location=device)
model = model.to(device)
# Inference mode for layers such as batch norm and dropout
model.eval()

Save prediction to .mhd/.raw file

In [None]:
# Segment every test sequence and write the result next to it as .mhd/.raw.

# Chamber view to process: "4CH" or "2CH"
view = "4CH"

# Same preprocessing as used for training: tensor conversion, then 256x256 resize.
# Built once here under its own name, so the shared `transform` is not rebound.
to_model_input = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((256, 256))
    ]
)

# Feed the inputs to whatever device the model's parameters live on
model_device = next(model.parameters()).device
model.eval()

# The first directory entry is skipped.
# NOTE(review): presumably a non-patient file or folder — confirm.
for case_name in dir_list[1:]:
    case_dir = root_dir / case_name
    seq_name = case_dir / Path(case_name + "_" + view + "_sequence.mhd")
    # Output path without extension; ".mhd" is appended when writing
    mdh_data = str(case_dir / Path("R_" + view + "_sequence"))
    print(mdh_data)

    seq = io.imread(seq_name)
    original_shape = seq.shape

    # Move the first axis last so that ToTensor turns it back into the leading axis
    seq = np.transpose(seq, (1, 2, 0))
    new_seq = to_model_input(seq)

    # Add a channel axis: every entry along dim 0 becomes one single-channel sample
    frames = torch.unsqueeze(new_seq, dim=1).to(model_device)

    # Inference only: no gradients are needed
    with torch.no_grad():
        out = model(frames)
    out = (out>0.5).int()

    # Resize the binary masks back to the in-plane size of the source sequence
    resize_back = transforms.Resize((original_shape[1], original_shape[2]))
    out = resize_back(out)

    # Drop the channel axis and copy into an array shaped like the source sequence
    output = np.zeros(original_shape)
    output[:] = out[:, 0].cpu().numpy()

    # Create a SimpleITK image from the numpy array
    image = sitk.GetImageFromArray(output)

    # Set the image origin and spacing
    # NOTE(review): the spacing is hard-coded rather than read from the input
    # header — confirm that it holds for every sequence.
    image.SetOrigin((0,0,0))
    image.SetSpacing((0.308, 0.154,  1.54))

    # Set the data type of the image to be MET_CHAR
    image = sitk.Cast(image, sitk.sitkInt8)

    # Save the image in .mhd and .raw format
    sitk.WriteImage(image, mdh_data + ".mhd")

