<a href="https://colab.research.google.com/github/seyma-tas/Brain-Tumor-Segmentation-Project/blob/master/2AdamW_DICE_BrainTumorGenesis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Mount Colab to Drive

This cell mounts Google Drive in Colab so that the notebook can read the data stored on Drive.

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive/ so the files under "My Drive" used by
# the later cells (dataset, pre-trained weights, saved checkpoints) are reachable.
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


## Import Necessary Libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns

# read .mat files
import h5py 

import random
import os

# Pytorch 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch.optim import SGD, lr_scheduler,AdamW

#Train -test split
from sklearn.model_selection import train_test_split
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# Every later cell moves the model and the batches to this device.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

  import pandas.util.testing as tm


## Accuracy and Loss Metrics

In [None]:
def dice_metric(inputs, target):
    """Dice coefficient: 2 * sum(inputs * target) / (sum(inputs) + sum(target)).

    Both arguments only need to support elementwise ``*`` and ``.sum()``
    (the notebook passes NumPy arrays). Values are used as given; nothing
    is thresholded here.

    Returns 1.0 when both ``inputs`` and ``target`` sum to zero.
    """
    overlap = (target * inputs).sum()
    target_total, inputs_total = target.sum(), inputs.sum()

    # Two empty masks agree perfectly; this also avoids dividing by zero.
    if target_total == 0 and inputs_total == 0:
        return 1.0

    return 2.0 * overlap / (target_total + inputs_total)

def dice_loss(inputs, target):
    """Smoothed soft-Dice loss, averaged over the batch.

    ``inputs`` and ``target`` are tensors whose first dimension is the batch;
    all remaining dimensions are flattened per sample. A smoothing constant
    of 1.0 is added to numerator and denominator of each per-sample score.

    Returns ``1 - mean(per-sample Dice)`` as a scalar tensor.
    """
    batch_size = target.size(0)
    flat_inputs = inputs.reshape(batch_size, -1)
    flat_target = target.reshape(batch_size, -1)

    smooth = 1.0
    overlap = (flat_inputs * flat_target).sum(1)
    per_sample = (2. * overlap + smooth) / (flat_inputs.sum(1) + flat_target.sum(1) + smooth)
    return 1 - per_sample.sum() / batch_size

def bce_dice_loss(inputs, target):
    """Combined loss: mean binary cross-entropy plus ``dice_loss``.

    ``inputs`` must already be probabilities (binary cross-entropy is applied
    directly, without a sigmoid); ``target`` has the same shape.
    """
    dice_term = dice_loss(inputs, target)
    # Functional form of nn.BCELoss() with its default mean reduction.
    bce_term = nn.functional.binary_cross_entropy(inputs, target)
    return bce_term + dice_term

## Models Genesis

This is the structure from Models Genesis 

https://github.com/MrGiovanni/ModelsGenesis/blob/master/pytorch/unet3d.py

In [None]:
class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
    """Batch normalisation for 5D inputs that always uses batch statistics.

    ``forward`` passes ``True`` as the training flag of ``F.batch_norm``
    unconditionally, so the module's own ``training`` attribute (and hence
    ``model.eval()``) is not consulted by this layer.
    """

    def _check_input_dim(self, input):
        """Raise ValueError unless ``input`` has exactly five dimensions."""
        ndim = input.dim()
        if ndim == 5:
            return
        raise ValueError('expected 5D input (got {}D input)'.format(ndim))

    def forward(self, input):
        """Normalise ``input`` with the statistics of the current batch."""
        self._check_input_dim(input)
        use_batch_stats = True  # hard-wired: never switches to the running statistics
        return F.batch_norm(
            input,
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            use_batch_stats,
            self.momentum,
            self.eps,
        )


class LUConv(nn.Module):
    """One convolution unit: 3x3x3 Conv3d -> ContBatchNorm3d -> activation.

    Args:
        in_chan: number of input channels of the convolution.
        out_chan: number of output channels (also the number of PReLU
            parameters when ``act == 'prelu'``).
        act: activation name, one of ``'relu'``, ``'prelu'`` or ``'elu'``.

    Raises:
        ValueError: if ``act`` is not one of the supported names.
    """

    def __init__(self, in_chan, out_chan, act):
        super(LUConv, self).__init__()
        # padding=1 with a 3x3x3 kernel keeps the spatial size unchanged.
        self.conv1 = nn.Conv3d(in_chan, out_chan, kernel_size=3, padding=1)
        self.bn1 = ContBatchNorm3d(out_chan)

        if act == 'relu':
            # nn.ReLU's only argument is `inplace`. The previous
            # `nn.ReLU(out_chan)` passed the channel count there, which is
            # truthy for any positive width, so it already behaved as an
            # in-place ReLU; say so explicitly.
            self.activation = nn.ReLU(inplace=True)
        elif act == 'prelu':
            self.activation = nn.PReLU(out_chan)
        elif act == 'elu':
            self.activation = nn.ELU(inplace=True)
        else:
            # A bare `raise` with no active exception only produces
            # "RuntimeError: No active exception to reraise".
            raise ValueError(
                "unsupported activation {!r}; expected 'relu', 'prelu' or 'elu'".format(act))

    def forward(self, x):
        """Apply convolution, batch normalisation and the activation to ``x``."""
        out = self.activation(self.bn1(self.conv1(x)))
        return out


def _make_nConv(in_channel, depth, act, double_chnnel=False):
    """Build the pair of stacked LUConv layers used at one U-Net level.

    With ``double_chnnel=False`` the widths are
    ``in_channel -> 32*2**depth -> 2*32*2**depth``.
    With ``double_chnnel=True`` both layers output ``32*2**(depth+1)``
    channels.

    Returns an ``nn.Sequential`` holding the two layers.
    """
    if double_chnnel:
        mid_channels = 32 * (2 ** (depth + 1))
        out_channels = mid_channels
    else:
        mid_channels = 32 * (2 ** depth)
        out_channels = mid_channels * 2

    return nn.Sequential(
        LUConv(in_channel, mid_channels, act),
        LUConv(mid_channels, out_channels, act),
    )

class DownTransition(nn.Module):
    """Encoder stage: two LUConv layers followed by 2x max-pooling.

    At ``depth == 3`` the pooling step is skipped.
    ``forward`` returns ``(output, features_before_pooling)``; the second item
    is what the decoder later uses as skip connection.
    """

    def __init__(self, in_channel,depth, act):
        super().__init__()
        self.ops = _make_nConv(in_channel, depth, act)
        self.maxpool = nn.MaxPool3d(2)
        self.current_depth = depth

    def forward(self, x):
        features = self.ops(x)
        if self.current_depth == 3:
            # Deepest stage: no pooling, both outputs are the same tensor.
            return features, features
        return self.maxpool(features), features

class UpTransition(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, two LUConvs.

    The transposed convolution (kernel 2, stride 2) maps ``inChans`` to
    ``outChans`` channels; its output is concatenated with the skip tensor
    along the channel axis, and the following convolutions are built for
    ``inChans + outChans // 2`` input channels.
    """

    def __init__(self, inChans, outChans, depth,act):
        super().__init__()
        self.depth = depth
        self.up_conv = nn.ConvTranspose3d(inChans, outChans, kernel_size=2, stride=2)
        self.ops = _make_nConv(inChans + outChans // 2, depth, act, double_chnnel=True)

    def forward(self, x, skip_x):
        upsampled = self.up_conv(x)
        # Channel-wise concatenation: upsampled features first, then the skip tensor.
        merged = torch.cat((upsampled, skip_x), 1)
        return self.ops(merged)


class OutputTransition(nn.Module):
    """Output head: 1x1x1 convolution to ``n_labels`` channels, then a sigmoid.

    The result is therefore a per-voxel value in (0, 1) for each label.
    """

    def __init__(self, inChans, n_labels):
        super().__init__()
        self.final_conv = nn.Conv3d(inChans, n_labels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        logits = self.final_conv(x)
        return self.sigmoid(logits)

class UNet3D(nn.Module):
    """3D U-Net: four DownTransition stages, three UpTransition stages, sigmoid head.

    Args:
        n_class: number of output channels of the final 1x1x1 convolution.
        act: activation name forwarded to every LUConv ('relu', 'prelu' or 'elu').

    The first convolution takes 1 input channel. Pooling happens in the first
    three encoder stages only (DownTransition skips it at depth 3), and each
    decoder stage upsamples by 2 before concatenating with the matching skip
    tensor.
    """
    # the number of convolutions in each layer corresponds
    # to what is in the actual prototxt, not the intent
    def __init__(self, n_class=1, act='relu'):
        super(UNet3D, self).__init__()

        # Encoder; the attribute suffix is the channel count each stage outputs.
        self.down_tr64 = DownTransition(1,0,act)
        self.down_tr128 = DownTransition(64,1,act)
        self.down_tr256 = DownTransition(128,2,act)
        self.down_tr512 = DownTransition(256,3,act)

        # Decoder; UpTransition(inChans, outChans, depth, act).
        self.up_tr256 = UpTransition(512, 512,2,act)
        self.up_tr128 = UpTransition(256,256, 1,act)
        self.up_tr64 = UpTransition(128,128,0,act)
        # Head: 64 feature channels -> n_class sigmoid outputs.
        self.out_tr = OutputTransition(64, n_class)

    def forward(self, x):
        # Every intermediate result is stored on `self`, so the tensors of the
        # latest forward pass stay referenced by the module until the next call.
        # Each encoder stage returns (output, pre-pooling features for the skip path).
        self.out64, self.skip_out64 = self.down_tr64(x)
        self.out128,self.skip_out128 = self.down_tr128(self.out64)
        self.out256,self.skip_out256 = self.down_tr256(self.out128)
        # skip_out512 is stored but not consumed by any decoder stage below.
        self.out512,self.skip_out512 = self.down_tr512(self.out256)

        # Decoder: pair each upsampling stage with the skip of the same resolution.
        self.out_up_256 = self.up_tr256(self.out512,self.skip_out256)
        self.out_up_128 = self.up_tr128(self.out_up_256, self.skip_out128)
        self.out_up_64 = self.up_tr64(self.out_up_128, self.skip_out64)
        self.out = self.out_tr(self.out_up_64)

        return self.out

## Load Data

The `.mat` image files are read with the h5py library.

The shapes of MRI images are (512, 512) but there are 15 images whose shapes are (128, 128).  These 15 images are omitted.



In [None]:
# Load the 3064 scans. Files are .mat files read through h5py and are named
# by number (1.mat, 2.mat, ...). Each file holds a 'cjdata' group with the
# 'image' and 'tumorMask' datasets.
image_data = []
mask_data = []

filenames = range(1,3065)
# Choose random number of images when needed
# filenames = random.sample(filenames, 200)

for name in filenames:
    mat_path = '/content/drive/My Drive/Brain-Tumor-Segmentation-Project/brainTumorData/'+str(name)+'.mat'

    # Open each file in a context manager so its handle is closed again;
    # previously all 3064 handles stayed open for the rest of the session.
    with h5py.File(mat_path, 'r') as mat_file:
        cjdata = mat_file.get('cjdata')
        # `[()]` reads the whole dataset into a NumPy array, which remains
        # valid after the file has been closed.
        input1 = cjdata.get('image')[()]
        mask = cjdata.get('tumorMask')[()]

    if input1.shape == (512, 512) and mask.shape == (512, 512):
        # Re-arrange the 512*512 pixels into a (1, 64, 64, 64) array: one
        # channel plus the cubic volume the 3D U-Net takes as input.
        # NOTE(review): this is a plain row-major reshape, so the third
        # dimension is not a spatial neighbourhood of the 2D slice; plot_mask
        # undoes it with the reverse reshape.
        image_volume = np.expand_dims(np.reshape(input1, (64, 64, 64)), axis=0)
        mask_volume = np.expand_dims(np.reshape(mask, (64, 64, 64)), axis=0)
        image_data.append(image_volume)
        mask_data.append(mask_volume)
    else:
        # Report the files that are skipped because of their size.
        print(name)
# There are 15 images whose sizes are not suitable to the model. I omitted these images.

955
956
957
1070
1071
1072
1073
1074
1075
1076
1203
1204
1205
1206
1207


## Split the data into train, test and validation sets

In [None]:
# Two-stage split: hold out 10% for testing, then 11.1% of the remainder
# (about 10% of the total) for validation. The fixed random_state matters:
# the masks are split separately with the same sizes and seed, which keeps
# images and masks aligned.
image_train_data, image_test_data = train_test_split(
    image_data, test_size=0.1, random_state=123)
image_train_data, image_val_data = train_test_split(
    image_train_data, test_size=0.111, random_state=123)
print(len(image_train_data), len(image_test_data), len(image_val_data))

2439 305 305


In [None]:
# Same two-stage split as for the images. The test sizes and random_state
# must stay identical to the image split so that every mask lands in the
# same subset, at the same position, as its image.
mask_train_data, mask_test_data = train_test_split(
    mask_data, test_size=0.1, random_state=123)
mask_train_data, mask_val_data = train_test_split(
    mask_train_data, test_size=0.111, random_state=123)
print(len(mask_train_data), len(mask_test_data), len(mask_val_data))

2439 305 305


# Train, Validation and Test Datasets and DataLoaders

In [None]:
#Train Data
# Stack the per-sample arrays into one float32 array before handing it to
# torch: torch.Tensor(list_of_ndarrays) converts element by element and is
# very slow for thousands of 64x64x64 volumes. The result is the same
# float32 tensor of shape (N, 1, 64, 64, 64).
image_train_data = torch.from_numpy(np.asarray(image_train_data, dtype=np.float32))
mask_train_data = torch.from_numpy(np.asarray(mask_train_data, dtype=np.float32))

# Pair every image with its mask.
train_dataset = TensorDataset(image_train_data, mask_train_data)
# Reshuffle the training samples every epoch; with shuffle=False the model
# saw the batches in the same fixed order in all epochs.
train_dataloader = DataLoader(train_dataset, batch_size=4, num_workers=2, shuffle=True)

In [None]:
# Test Data
# Stack into one float32 array first; torch.Tensor(list_of_ndarrays) converts
# element by element and is very slow. The resulting tensor is unchanged:
# float32, shape (N, 1, 64, 64, 64).
image_test_data = torch.from_numpy(np.asarray(image_test_data, dtype=np.float32))
mask_test_data = torch.from_numpy(np.asarray(mask_test_data, dtype=np.float32))

# Pair every image with its mask; evaluation order stays fixed (shuffle=False).
test_dataset = TensorDataset(image_test_data, mask_test_data)
test_dataloader = DataLoader(test_dataset, batch_size=4, num_workers=2, shuffle=False)

In [None]:
#Validation Data
# Stack into one float32 array first; torch.Tensor(list_of_ndarrays) converts
# element by element and is very slow. The resulting tensor is unchanged:
# float32, shape (N, 1, 64, 64, 64).
image_val_data = torch.from_numpy(np.asarray(image_val_data, dtype=np.float32))
mask_val_data = torch.from_numpy(np.asarray(mask_val_data, dtype=np.float32))

# Pair every image with its mask; evaluation order stays fixed (shuffle=False).
val_dataset = TensorDataset(image_val_data, mask_val_data)
val_dataloader = DataLoader(val_dataset, batch_size=4, num_workers=2, shuffle=False)

# Model Structure

In [None]:
# https://github.com/MrGiovanni/ModelsGenesis/tree/master/pytorch
# Build the network with its defaults (n_class=1, act='relu'); the bare
# `model` expression on the last line displays the layer structure.
model = UNet3D()
model

UNet3D(
  (down_tr64): DownTransition(
    (ops): Sequential(
      (0): LUConv(
        (conv1): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (bn1): ContBatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU(inplace=True)
      )
      (1): LUConv(
        (conv1): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (bn1): ContBatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU(inplace=True)
      )
    )
    (maxpool): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (down_tr128): DownTransition(
    (ops): Sequential(
      (0): LUConv(
        (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (bn1): ContBatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU(inplace=True)
      )
      (

## The pre-trained weights

In [None]:
# Path of the pre-trained weights.
# Genesis_Chest_CT.pt is a file I downloaded with the permission of Model Genesis
weight_dir = '/content/drive/My Drive/Brain-Tumor-Segmentation-Project/Genesis_Chest_CT.pt'

# Read the checkpoint onto the CPU and take its 'state_dict' entry.
# NOTE: torch.load unpickles the file, so only load checkpoints from a
# source you trust.
checkpoint = torch.load(weight_dir, map_location=torch.device('cpu'))
state_dict = checkpoint['state_dict']

# Drop "module." from every key so the names match this un-wrapped model
# (presumably the checkpoint was saved from an nn.DataParallel wrapper - confirm).
unParalled_state_dict = {
    key.replace("module.", ""): weights
    for key, weights in state_dict.items()
}

# Copy the renamed weights into the model
model.load_state_dict(unParalled_state_dict)


<All keys matched successfully>

## Define parameters

In [None]:
# Loss: binary cross-entropy plus soft-Dice
criterion = bce_dice_loss
# Optimizer: AdamW over all model parameters.
# NOTE(review): lr=0.1 is far above AdamW's default of 1e-3 - confirm this is intended.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.1)
# Halve the learning rate every 5 epochs
scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# Train the model

In [None]:
# Move the model to the selected device (GPU when available)
model.to(device)

# Mean loss of every epoch, for the training and the validation set
loss_history = []
loss_history_val = []

# Lowest validation loss seen so far; decides when a checkpoint is written
best_loss_val = float('inf')

# Train
print("Start train...")
for epoch in range(50):
    # ---- training pass ----
    model.train()
    loss_running = []
    for x, y in train_dataloader:
        x, y = x.float().to(device), y.float().to(device)

        pred = model(x)
        loss = criterion(pred, y)
        loss_running.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    loss_history.append(np.mean(loss_running))

    # ---- validation pass (no gradients needed) ----
    # Note: ContBatchNorm3d always normalises with batch statistics, so
    # eval() does not change how those layers behave.
    model.eval()
    loss_val_running = []
    with torch.no_grad():
        for x_val, y_val in val_dataloader:
            # Same float cast as in the training pass, for consistency.
            x_val, y_val = x_val.float().to(device), y_val.float().to(device)
            # Call the module itself rather than .forward() so that any
            # registered hooks run as well.
            pred_val = model(x_val)
            loss_val = criterion(pred_val, y_val)
            loss_val_running.append(loss_val.item())

    curr_loss_val = np.mean(loss_val_running)
    loss_history_val.append(curr_loss_val)

    # Keep only the weights with the best validation loss so far
    if curr_loss_val < best_loss_val:
        best_loss_val = curr_loss_val
        torch.save(model.state_dict(), '/content/drive/My Drive/Brain-Tumor-Segmentation-Project/best_model.pth')

    # Advance the learning-rate schedule once per epoch
    scheduler.step()

    # Print the results
    print("epoch", epoch, "train loss", loss_history[-1], "val loss", loss_history_val[-1])

Start train...
epoch 0 train loss 0.9750198336898304 val loss 0.9472322549138751
epoch 1 train loss 0.8770960611886666 val loss 0.8264056815729512
epoch 2 train loss 0.8104350524359062 val loss 0.7786627853071535
epoch 3 train loss 0.7722033008688786 val loss 0.7684954346774461
epoch 4 train loss 0.7563204318773551 val loss 0.7626414922150698
epoch 5 train loss 0.7057667979451476 val loss 0.6936392524799744
epoch 6 train loss 0.6887377186632547 val loss 0.6984348324212161
epoch 7 train loss 0.6803527646377439 val loss 0.6699143632665857
epoch 8 train loss 0.6741126392219887 val loss 0.699267801526305
epoch 9 train loss 0.6690489853014712 val loss 0.6775594036300461


## Load the saved best model's weights

In [None]:
# Restore the weights that the training loop saved for the epoch with the
# lowest validation loss.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
best_model_path = '/content/drive/My Drive/Brain-Tumor-Segmentation-Project/best_model.pth'
checkpoint = torch.load(best_model_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint)

## Function to plot the mask

In [None]:
def plot_mask(mask_3d_array, axx): # takes 64*64*64 array
    """Draw a 64x64x64 mask tensor as a rounded 512x512 image on ``axx``.

    The tensor is moved to the CPU, detached and reshaped to (512, 512).
    The maximum and minimum of the un-rounded values are printed, then the
    values rounded to the nearest integer are shown with ``axx.imshow``.
    """
    flat_mask = mask_3d_array.cpu().detach().numpy().reshape((512, 512))
    print(np.max(flat_mask), np.min(flat_mask))
    axx.imshow(np.round(flat_mask))

# Plot 4 random predicted and ground truth masks

In [None]:
# Show the first sample of 4 randomly chosen test batches:
# predicted mask in the top row, ground-truth mask in the bottom row.
dataloader = test_dataloader
ncol = 4
rand_ndx = random.sample(range(0, len(dataloader)), ncol)
fig, ax = plt.subplots(nrows=2,  ncols=ncol, figsize=(20, 10))
i = 0
# Inference only: without no_grad every forward pass would also record an
# autograd graph for the whole 3D U-Net.
with torch.no_grad():
    for n, (x, y) in enumerate(dataloader):
        if n not in rand_ndx:
            continue  # batches that are not plotted are never moved to the device
        x, y = x.to(device), y.to(device)
        # Call the module itself rather than .forward() so registered hooks run.
        pred = model(x)
        plot_mask(pred[0,0,:,:,:], ax[0][i])
        plot_mask(y[0,0,:,:,:], ax[1][i])
        i+=1

In [None]:
# Show the first sample of 4 randomly chosen test batches:
# predicted mask in the top row, ground-truth mask in the bottom row.
dataloader = test_dataloader
ncol = 4
rand_ndx = random.sample(range(0, len(dataloader)), ncol)
fig, ax = plt.subplots(nrows=2,  ncols=ncol, figsize=(20, 10))
i = 0
# Inference only: without no_grad every forward pass would also record an
# autograd graph for the whole 3D U-Net.
with torch.no_grad():
    for n, (x, y) in enumerate(dataloader):
        if n not in rand_ndx:
            continue  # batches that are not plotted are never moved to the device
        x, y = x.to(device), y.to(device)
        pred = model(x)
        plot_mask(pred[0,0,:,:,:], ax[0][i])
        plot_mask(y[0,0,:,:,:], ax[1][i])
        i+=1

In [None]:
# Show the first sample of 4 randomly chosen test batches:
# predicted mask in the top row, ground-truth mask in the bottom row.
dataloader = test_dataloader
ncol = 4
rand_ndx = random.sample(range(0, len(dataloader)), ncol)
fig, ax = plt.subplots(nrows=2,  ncols=ncol, figsize=(20, 10))
i = 0
# Inference only: without no_grad every forward pass would also record an
# autograd graph for the whole 3D U-Net.
with torch.no_grad():
    for n, (x, y) in enumerate(dataloader):
        if n not in rand_ndx:
            continue  # batches that are not plotted are never moved to the device
        x, y = x.to(device), y.to(device)
        pred = model(x)
        plot_mask(pred[0,0,:,:,:], ax[0][i])
        plot_mask(y[0,0,:,:,:], ax[1][i])
        i+=1

In [None]:
# Show the first sample of 4 randomly chosen test batches:
# predicted mask in the top row, ground-truth mask in the bottom row.
dataloader = test_dataloader
ncol = 4
rand_ndx = random.sample(range(0, len(dataloader)), ncol)
fig, ax = plt.subplots(nrows=2,  ncols=ncol, figsize=(20, 10))
i = 0
# Inference only: without no_grad every forward pass would also record an
# autograd graph for the whole 3D U-Net.
with torch.no_grad():
    for n, (x, y) in enumerate(dataloader):
        if n not in rand_ndx:
            continue  # batches that are not plotted are never moved to the device
        x, y = x.to(device), y.to(device)
        pred = model(x)
        plot_mask(pred[0,0,:,:,:], ax[0][i])
        plot_mask(y[0,0,:,:,:], ax[1][i])
        i+=1

In [None]:
# Show the first sample of 4 randomly chosen training batches:
# predicted mask in the top row, ground-truth mask in the bottom row.
dataloader = train_dataloader
ncol = 4
rand_ndx = random.sample(range(0, len(dataloader)), ncol)
fig, ax = plt.subplots(nrows=2,  ncols=ncol, figsize=(20, 10))
i = 0
# Inference only: without no_grad every forward pass would also record an
# autograd graph for the whole 3D U-Net.
with torch.no_grad():
    for n, (x, y) in enumerate(dataloader):
        if n not in rand_ndx:
            continue  # batches that are not plotted are never moved to the device
        x, y = x.to(device), y.to(device)
        pred = model(x)
        plot_mask(pred[0,0,:,:,:], ax[0][i])
        plot_mask(y[0,0,:,:,:], ax[1][i])
        i+=1

In [None]:
# Show the first sample of 4 randomly chosen test batches:
# predicted mask in the top row, ground-truth mask in the bottom row.
dataloader = test_dataloader
ncol = 4
rand_ndx = random.sample(range(0, len(dataloader)), ncol)
fig, ax = plt.subplots(nrows=2,  ncols=ncol, figsize=(20, 10))
i = 0
# Inference only: without no_grad every forward pass would also record an
# autograd graph for the whole 3D U-Net.
with torch.no_grad():
    for n, (x, y) in enumerate(dataloader):
        if n not in rand_ndx:
            continue  # batches that are not plotted are never moved to the device
        x, y = x.to(device), y.to(device)
        pred = model(x)
        plot_mask(pred[0,0,:,:,:], ax[0][i])
        plot_mask(y[0,0,:,:,:], ax[1][i])
        i+=1

In [None]:
# Show the first sample of 4 randomly chosen training batches:
# predicted mask in the top row, ground-truth mask in the bottom row.
dataloader = train_dataloader
ncol = 4
rand_ndx = random.sample(range(0, len(dataloader)), ncol)
fig, ax = plt.subplots(nrows=2,  ncols=ncol, figsize=(20, 10))
i = 0
# Inference only: without no_grad every forward pass would also record an
# autograd graph for the whole 3D U-Net.
with torch.no_grad():
    for n, (x, y) in enumerate(dataloader):
        if n not in rand_ndx:
            continue  # batches that are not plotted are never moved to the device
        x, y = x.to(device), y.to(device)
        # Call the module itself rather than .forward() so registered hooks run.
        pred = model(x)
        plot_mask(pred[0,0,:,:,:], ax[0][i])
        plot_mask(y[0,0,:,:,:], ax[1][i])
        i+=1

In [None]:
# Show the first sample of 4 randomly chosen validation batches:
# predicted mask in the top row, ground-truth mask in the bottom row.
dataloader = val_dataloader
ncol = 4
rand_ndx = random.sample(range(0, len(dataloader)), ncol)
fig, ax = plt.subplots(nrows=2,  ncols=ncol, figsize=(20, 10))
i = 0
# Inference only: without no_grad every forward pass would also record an
# autograd graph for the whole 3D U-Net.
with torch.no_grad():
    for n, (x, y) in enumerate(dataloader):
        if n not in rand_ndx:
            continue  # batches that are not plotted are never moved to the device
        x, y = x.to(device), y.to(device)
        pred = model(x)
        plot_mask(pred[0,0,:,:,:], ax[0][i])
        plot_mask(y[0,0,:,:,:], ax[1][i])
        i+=1

In [None]:
# Show the first sample of 4 randomly chosen validation batches:
# predicted mask in the top row, ground-truth mask in the bottom row.
dataloader = val_dataloader
ncol = 4
rand_ndx = random.sample(range(0, len(dataloader)), ncol)
fig, ax = plt.subplots(nrows=2,  ncols=ncol, figsize=(20, 10))
i = 0
# Inference only: without no_grad every forward pass would also record an
# autograd graph for the whole 3D U-Net.
with torch.no_grad():
    for n, (x, y) in enumerate(dataloader):
        if n not in rand_ndx:
            continue  # batches that are not plotted are never moved to the device
        x, y = x.to(device), y.to(device)
        pred = model(x)
        plot_mask(pred[0,0,:,:,:], ax[0][i])
        plot_mask(y[0,0,:,:,:], ax[1][i])
        i+=1

## Heatmap of one predicted mask

In [None]:
# Heat map of the first sample of `pred`, i.e. of the batch predicted last
# by whichever plotting cell ran most recently.
mask_cpu = pred.cpu().detach().numpy()

# Add the 0.1 offset out of place. The previous in-place `+=` wrote through
# the reshaped view into `mask_cpu` and, when `pred` lives on the CPU (where
# .cpu()/.detach()/.numpy() all share memory), into `pred` itself - so every
# re-run of this cell shifted the prediction by another 0.1.
reshaped_mask_cpu = np.reshape(mask_cpu[0,0,:,:,:],(512, 512)) + .1
# NOTE(review): the rounded mask is computed but not plotted; the heat map
# below shows the shifted, un-rounded values.
reshaped_mask_cpu_bin = np.round(reshaped_mask_cpu)
plt.figure(figsize=(15, 15))
sns.heatmap(reshaped_mask_cpu)

## BCE-Dice Loss Plot (Training vs. Validation)

In [None]:
# Mean BCE-Dice loss per epoch. The curves are labelled so that training and
# validation can be told apart in the figure.
plt.figure(figsize=(15, 7))
plt.plot(loss_history, label='train')
plt.plot(loss_history_val, label='validation')
plt.xlabel('epoch')
plt.ylabel('BCE + Dice loss')
plt.title('Training and validation loss per epoch')
plt.legend();

## Function to compute the dice accuracy

In [None]:
def compute_acc(dataloader, model):
    """Print the mean BCE-Dice loss and the mean Dice score over ``dataloader``.

    For every batch the model output is scored with ``bce_dice_loss`` (on the
    tensors) and with ``dice_metric`` (on NumPy copies of prediction and
    target; the predictions are not thresholded). The two per-batch lists are
    averaged and printed as ``<mean loss> <mean dice>``.

    Args:
        dataloader: iterable yielding ``(image_batch, mask_batch)`` tensors.
        model: the network to evaluate (an ``nn.Module``).

    Returns:
        None; the results are only printed.
    """
    acc = []
    loss = []

    # Evaluate in eval mode and restore the previous mode afterwards, so the
    # call has no lasting effect on the model.
    was_training = model.training
    model.eval()
    try:
        # No gradients are needed here; without no_grad each forward pass
        # would record an autograd graph for the whole network.
        with torch.no_grad():
            for x, y in dataloader:
                x, y = x.to(device), y.to(device)
                pred = model(x)
                loss.append(bce_dice_loss(pred, y).item())
                # Under no_grad the tensors carry no graph, so they convert
                # to NumPy directly (no need for the deprecated `.data`).
                acc.append(dice_metric(pred.cpu().numpy(), y.cpu().numpy()))
    finally:
        model.train(was_training)

    print(np.mean(loss), np.mean(acc))

## Compute dice accuracy for train, validation and test data

In [None]:
# Prints the mean BCE-Dice loss and the mean Dice score on the test set.
compute_acc(test_dataloader, model)

In [None]:
# Prints the mean BCE-Dice loss and the mean Dice score on the training set.
compute_acc(train_dataloader, model)

In [None]:
# Prints the mean BCE-Dice loss and the mean Dice score on the validation set.
compute_acc(val_dataloader, model)