In [1]:
%%javascript
if (IPython.notebook.kernel) {
    IPython.notebook.kernel.execute('book_name = "' + IPython.notebook.notebook_name+'"')
}

<IPython.core.display.Javascript object>

In [2]:
import sys
sys.path.append('..')

In [3]:
print(book_name)

experiment_id = int(book_name[book_name.rfind('seed')+4:len(book_name)-6:1]) # Experiment 1 - ResNet replicate
print("Experiment:",experiment_id)

Train_UNet_seed1.ipynb
Experiment: 1


In [4]:
# np seed max 2**32 - 1
# https://numpy.org/doc/stable/reference/random/legacy.html
seed_factor = ((2**32 - 1)/60)

def set_seeds(seed, experiment_id=None):
    if not seed:
        seed = 10
        
    seed = int(seed * experiment_id)

    print("[ Using Seed : ", seed, " ]")

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

In [5]:
import os
import json
import numpy as np
import torch
from sklearn.model_selection import train_test_split

from experiments.UNetExperiment import UNetExperiment
from data_prep.HippocampusDatasetLoader import LoadHippocampusData

"""
This module represents a UNet experiment and contains a class that handles
the experiment lifecycle
"""
import os
import time
import nibabel as nib

import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from data_prep.SlicesDataset import SlicesDataset
from utils.utils import log_to_tensorboard
from utils.volume_stats import Dice3d, Jaccard3d, Sensitivity, Specificity #, F1_score
# from networks.RecursiveUNet import UNet
from inference.UNetInferenceAgent import UNetInferenceAgent
from torch.nn import init
import random
from pathlib import Path
from glob import glob

from torchvision import transforms
import torchvision.transforms.functional as TF
from scipy.ndimage import map_coordinates
from scipy.ndimage import gaussian_filter

from collections import OrderedDict
from torch import nn
import copy

In [6]:
class Config:
    """
    Holds configuration parameters
    """
    def __init__(self):
        self.name = "basic_unet_"+"seed"+str(experiment_id)
        self.root_dir = r"../../data"
        self.n_epochs = 100 # 10
        self.learning_rate = 0.0002
        self.batch_size = 64
        self.patch_size = 64 #64
        self.test_results_dir = r"out"

In [7]:
c = Config()

from glob import glob
print('num_train = {}'.format(len(glob(c.root_dir+"/images/*"))))
print('Image dimensions:')
print('Train:')
import matplotlib.pyplot as plt

# Dimensions
for idx, img_name in enumerate(glob(c.root_dir+"/images/*")):
    img = nib.load(img_name)
    print(idx+1,":",img.shape)

# Find number of sub labels

# Images
img = nib.load(glob(c.root_dir+"/images/*")[0]).get_fdata()
label = nib.load(glob(c.root_dir+"/labels/*")[0]).get_fdata()
print('Image Min-Max values: Image={},{} and label={},{}'.format(img.max(), img.min(), label.max(), label.min()))
print('Number of subclasses = ', int(label.max())+1)

num_train = 260
Image dimensions:
Train:
1 : (33, 50, 35)
2 : (35, 55, 41)
3 : (34, 51, 38)
4 : (35, 56, 28)
5 : (36, 45, 39)
6 : (38, 51, 37)
7 : (37, 50, 38)
8 : (33, 55, 29)
9 : (36, 46, 43)
10 : (35, 44, 41)
11 : (38, 47, 37)
12 : (35, 51, 36)
13 : (35, 51, 36)
14 : (36, 47, 39)
15 : (34, 49, 36)
16 : (38, 48, 33)
17 : (34, 52, 35)
18 : (34, 52, 40)
19 : (39, 50, 40)
20 : (36, 51, 31)
21 : (37, 51, 35)
22 : (38, 51, 35)
23 : (31, 50, 36)
24 : (32, 51, 28)
25 : (35, 52, 34)
26 : (36, 51, 35)
27 : (37, 49, 34)
28 : (37, 52, 34)
29 : (36, 40, 43)
30 : (34, 53, 36)
31 : (34, 53, 34)
32 : (37, 48, 37)
33 : (35, 46, 42)
34 : (37, 51, 35)
35 : (33, 48, 38)
36 : (35, 49, 34)
37 : (35, 46, 39)
38 : (39, 41, 42)
39 : (38, 51, 37)
40 : (35, 52, 38)
41 : (34, 52, 38)
42 : (33, 51, 34)
43 : (35, 50, 36)
44 : (32, 45, 38)
45 : (36, 52, 37)
46 : (36, 51, 34)
47 : (34, 56, 31)
48 : (37, 47, 32)
49 : (35, 52, 34)
50 : (37, 56, 36)
51 : (33, 47, 34)
52 : (40, 52, 35)
53 : (36, 53, 37)
54 : (36, 48, 

In [8]:
c = Config()

# Load data
print("Loading data...")

set_seeds(seed_factor, experiment_id)

# TASK: LoadHippocampusData is not complete. Go to the implementation and complete it. 
data = LoadHippocampusData(c.root_dir, y_shape = c.patch_size, z_shape = c.patch_size)

Loading data...
[ Using Seed :  71582788  ]
Processed 260 files, total 9198 image slices, total 9198 mask slices


In [9]:
split_idx = np.arange((len(data)))

set_seeds(seed_factor, experiment_id)
# 60:20:20 split using train_test_split()
train, test = train_test_split(split_idx, test_size=0.2, shuffle=True)
train, val = train_test_split(train, test_size=0.25, shuffle=True)

print("Train Size:", len(train), "; Validation Size:", len(val), "; Test Size:", len(test))    
split=dict({'train': np.array(train),
            'val': np.array(val),
            'test': np.array(test)}
          )

[ Using Seed :  71582788  ]
Train Size: 156 ; Validation Size: 52 ; Test Size: 52


In [10]:
def save_model_parameters(self):
    """
    Saves model parameters to a file in results directory
    """
    path = os.path.join(self.out_dir, "model.pth")

    torch.save(self.model.state_dict(), path)

def load_model_parameters(self, path=''):
    """
    Loads model parameters from a supplied path or a
    results directory
    """
    if not path:
        model_path = os.path.join(self.out_dir, "model.pth")
    else:
        model_path = path

    if os.path.exists(model_path):
        self.model.load_state_dict(torch.load(model_path))
    else:
        raise Exception(f"Could not find path {model_path}")

In [11]:
# Do we have CUDA available?
if not torch.cuda.is_available():
    print("WARNING: No CUDA device is found. This may take significantly longer!")
else:
    print("GPU Status:",torch.cuda.is_available())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

GPU Status: True


In [12]:
n_epochs = c.n_epochs
split = split
_time_start = ""
_time_end = ""
epoch = 0
name = c.name

# Create output folders
# _{time.strftime("%Y_%m_%d_%H%M", time.gmtime())}
dirname = f'{c.name}'
out_dir = os.path.join(c.test_results_dir, dirname)
os.makedirs(out_dir, exist_ok=True)
out_dir

'out/basic_unet_seed1'

In [13]:
# Reference: https://github.com/hayashimasa/UNet-PyTorch/blob/main/augmentation.py
class DoubleHorizontalFlip:
    """Apply horizontal flips to both image and segmentation mask."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, image, mask, weight=None):
        p = random.random()
        if p < self.p:
            image = TF.hflip(image)
            mask = TF.hflip(mask)
        if weight is None:
            return image, mask
        elif p > self.p:
            weight = TF.hflip(weight)
        return image, mask, weight

    def __repr__(self):
        return self.__class__.__name__ + f'(p={self.p})' 

class DoubleElasticTransform:
    """Based on implimentation on
    https://gist.github.com/erniejunior/601cdf56d2b424757de5"""

    def __init__(self, alpha=250, sigma=10, p=0.5, seed=None, randinit=True):
        if not seed:
            seed = random.randint(1, 100)
        self.random_state = np.random.RandomState(seed)
        self.alpha = alpha
        self.sigma = sigma
        self.p = p
        self.randinit = randinit

    def __call__(self, image, mask, weight=None):
        if random.random() < self.p:
            if self.randinit:
                seed = random.randint(1, 100)
                self.random_state = np.random.RandomState(seed)
                self.alpha = random.uniform(100, 300)
                self.sigma = random.uniform(10, 15)
                # print(self.alpha)
                # print(self.sigma)

            dim = image.shape
            dx = self.alpha * gaussian_filter(
                (self.random_state.rand(*dim[1:]) * 2 - 1),
                self.sigma,
                mode="constant",
                cval=0
            )
            dy = self.alpha * gaussian_filter(
                (self.random_state.rand(*dim[1:]) * 2 - 1),
                self.sigma,
                mode="constant",
                cval=0
            )
            image = image.view(*dim[1:]).numpy()
            mask = mask.view(*dim[1:]).numpy()
            x, y = np.meshgrid(np.arange(dim[1]), np.arange(dim[2]))
            indices = np.reshape(y+dy, (-1, 1)), np.reshape(x+dx, (-1, 1))
            image = map_coordinates(image, indices, order=1)
            mask = map_coordinates(mask, indices, order=1)
            image, mask = image.reshape(dim), mask.reshape(dim)
            image, mask = torch.Tensor(image), torch.Tensor(mask)
            if weight is None:
                return image, mask
            weight = weight.view(*dim[1:]).numpy()
            weight = map_coordinates(weight, indices, order=1)
            weight = weight.reshape(dim)
            weight = torch.Tensor(weight)

        return (image, mask) if weight is None else (image, mask, weight)

class DoubleCompose(transforms.Compose):

    def __call__(self, image, mask, weight=None):
        if weight is None:
            for t in self.transforms:
                image, mask = t(image, mask)
            return image, mask
        for t in self.transforms:
            image, mask, weight = t(image, mask, weight)
        return image, mask, weight

In [14]:
set_seeds(seed_factor, experiment_id)

# reference: https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py
# https://github.com/hayashimasa/UNet-PyTorch
# https://gist.github.com/ernestum/601cdf56d2b424757de5

train_transform = DoubleCompose([
    # DoubleElasticTransform(),
    # DoubleHorizontalFlip(p = 0.25),
])

valid_transform = DoubleCompose([
    # DoubleHorizontalFlip(p = 0.25),
])

[ Using Seed :  71582788  ]


In [15]:
train_df = copy.deepcopy(data[split["train"]])
val_df   = copy.deepcopy(data[split["val"]])
test_df  = copy.deepcopy(data[split["test"]])

train_loader = DataLoader(SlicesDataset(train_df, train_transform), 
                          batch_size=c.batch_size, shuffle=True, num_workers=0)
val_loader   = DataLoader(SlicesDataset(val_df, valid_transform), 
                          batch_size=c.batch_size, shuffle=True, num_workers=0)

# we will access volumes directly for testing
val_data  = val_df
test_data = test_df

In [16]:
val_loss_list = []
val_dc_list = []
val_dc_ap_mean_list = []
val_jc_list = []
val_jc_ap_mean_list = []

def train(epoch, model):
    """
    This method is executed once per epoch and takes 
    care of model weight update cycle
    """
    global val_loss_list, val_dc_list, val_jc_list
    
    print(f"\nTraining epoch {epoch}...")
    model.train()

    # Loop over our minibatches
    for i, batch in enumerate(train_loader):
        optimizer.zero_grad()

        data = batch['images'].to('cuda')
        target = batch['segs'].to('cuda')

        # prediction
        prediction = model(data)

        prediction_softmax = F.softmax(prediction, dim=1)
        loss = loss_function(prediction, target[:, 0, :, :])

        loss.backward()
        optimizer.step()

        if (i % 10) == 0:
            # Output to console on every 10th batch
            print(f"\nEpoch: {epoch} Train loss: {loss}, {100*(i+1)/len(train_loader):.1f}% complete")

            counter = 100*epoch + 100*(i/len(train_loader))

        print(".", end='')

    print("\nTraining complete")
    
    print(f"\nValidating epoch {epoch}...")

    # Turn off gradient accumulation by switching model to "eval" mode
    loss_list = []
    dc_list = []
    dc_mean_list = []
    jc_list = []
    jc_mean_list = []
    
    model.eval()
    
    with torch.no_grad():

        for i, batch in enumerate(val_loader):

            data = batch['images'].to('cuda')
            target = batch['segs'].to('cuda')

            prediction = model(data)
            prediction_softmax = F.softmax(prediction, dim=1)
            loss = loss_function(prediction, target[:, 0, :, :])
            # loss.requires_grad = True
            # loss.backward()
            print(f"Batch {i}. Loss {loss}") # Data shape {data.shape} 

            loss_list.append(loss.item())
    
    model.eval()
    
    with torch.no_grad():
        
        inference_agent = UNetInferenceAgent(model=model, device=device)
        
        for i, x in enumerate(val_data):

            gt = x["seg"]   # val image ground truth        
            ti = x["image"] # val image data

            mask3d = np.zeros(ti.shape)
            pred = inference_agent.single_volume_inference(ti)
            mask3d = np.array(torch.argmax(pred, dim=1))

            dc, dc_a, dc_p, _ = Dice3d(mask3d, gt)
            dc_list.append(dc)
            dc_mean_list.append((dc_a+dc_p)/2)

            jc, jc_a, jc_p = Jaccard3d(mask3d, gt)
            jc_list.append(jc)
            jc_mean_list.append((jc_a+jc_p)/2)
            
    scheduler.step(np.mean(loss_list))
    
    #print("Epoch DC:", np.sum(dc_list))
    
    val_loss_list.append(np.sum(loss_list))
    val_dc_list.append(np.sum(dc_list))
    val_dc_ap_mean_list.append(np.sum(dc_mean_list))
    val_jc_list.append(np.sum(jc_list))
    val_jc_ap_mean_list.append(np.sum(jc_mean_list))
    
    measure_ = "dc_mean" # loss
    
    if measure_ == "loss":
        if epoch==0:
            min_val_lose = np.min(val_loss_list)
            max_val_dc = np.max(val_dc_list)
        else:
            min_val_lose = np.min(val_loss_list[:-1])
            max_val_dc = np.max(val_dc_list[:-1])

        current_val_loss = np.sum(loss_list)
        current_val_dc = np.sum(dc_list)

        print(f"Current Validation Loss {round(current_val_loss,8)} | Minimum Validation Loss {round(min_val_lose,8)}")
        print(f"Current Validation DC {round(current_val_dc,8)} | Maximum Validation DC {round(max_val_dc,8)}")

        if current_val_loss < min_val_lose:
            print(f"Current Validation Loss improved from {round(min_val_lose,8)} to {round(current_val_loss,8)}")
            path = os.path.join(out_dir, f"model_epoch{epoch+1}.pth")
            torch.save(model.state_dict(), path)

    elif measure_ == "dc":
        if epoch==0:
            max_val_dc = np.max(val_dc_list)
        else:
            max_val_dc = np.max(val_dc_list[:-1])

        current_val_dc = np.sum(dc_list)

        print(f"Current Validation DC {round(current_val_dc,8)} | Maximum Validation DC {round(max_val_dc,8)}")

        if current_val_dc > max_val_dc:
            print(f"Current Validation DC improved from {round(max_val_dc,8)} to {round(current_val_dc,8)}")
            path = os.path.join(out_dir, f"model_epoch{epoch+1}.pth")
            torch.save(model.state_dict(), path)

    elif measure_ == "dc_mean":
        if epoch==0:
            max_val_dc = np.max(val_dc_ap_mean_list)
            max_val_dc_all = np.max(val_dc_list)
        else:
            max_val_dc = np.max(val_dc_ap_mean_list[:-1])
            max_val_dc_all = np.max(val_dc_list[:-1])

        current_val_dc = np.sum(dc_mean_list)

        print(f"Current Validation AP Mean DC {round(current_val_dc,8)} | Maximum Validation AP Mean DC {round(max_val_dc,8)}")
        print(f"Maximum Validation DC {round(max_val_dc_all,8)}")
        if current_val_dc > max_val_dc:
            print(f"Current Validation AP Mean DC improved from {round(max_val_dc,8)} to {round(current_val_dc,8)}")
            path = os.path.join(out_dir, f"model_epoch{epoch+1}.pth")
            torch.save(model.state_dict(), path)

    elif measure_ == "jc":
        if epoch==0:
            max_val_jc = np.max(val_jc_list)
        else:
            max_val_jc = np.max(val_jc_list[:-1])

        current_val_jc = np.sum(jc_list)

        print(f"Current Validation JC {round(current_val_jc,8)} | Maximum Validation JC {round(max_val_jc,8)}")

        if current_val_jc > max_val_jc:
            print(f"Current Validation JC improved from {round(max_val_jc,8)} to {round(current_val_jc,8)}")
            path = os.path.join(out_dir, f"model_epoch{epoch+1}.pth")
            torch.save(model.state_dict(), path)

    print("\nValidation complete")

In [17]:
def run_test(model):
    """
    This runs test cycle on the test dataset.
    Note that process and evaluations are quite different
    Here we are computing a lot more metrics and returning
    a dictionary that could later be persisted as JSON
    """
    model.eval()

    inference_agent = UNetInferenceAgent(model=model, device=device)

    out_dict = {}
    out_dict["volume_stats"] = []
    
    dc_list = []
    dc_anterior_list = []
    dc_posterior_list = []
    
    jc_list = []
    jc_anterior_list = []
    jc_posterior_list = []
    
    sens_list = []
    sens_anterior_list = []
    sens_posterior_list = []
    
    spec_list = []
    spec_anterior_list = []
    spec_posterior_list = []
    # f1_list = []

    # for every in test set
    for i, x in enumerate(test_data):

        gt = x["seg"]   # test image ground truth        
        ti = x["image"] # test image data
        original_filename = x['filename'] # test image file name
        pred_filename = 'predicted_'+x['filename'] # test image file name

        file_path = os.path.join("../../data","images",original_filename)

        original_images = nib.load(file_path)

        mask3d = np.zeros(ti.shape)
        pred = inference_agent.single_volume_inference(ti)
        mask3d = np.array(torch.argmax(pred, dim=1))

        # Save predicted labels to local environment for further verification 
        # with the original image NIFTI coordinate system
        pred_coord = nib.Nifti1Image(mask3d, original_images.affine, dtype=np.int16)    

        pred_out_path = os.path.join("","preds","seed"+str(experiment_id))
        pred_out_file = os.path.join(pred_out_path,pred_filename)

        if not os.path.exists(pred_out_path):
            os.makedirs(pred_out_path)

        nib.save(pred_coord, pred_out_file)

        dc, dc_a, dc_p, _ = Dice3d(mask3d, gt)
        dc_list.append(dc)
        dc_anterior_list.append(dc_a)
        dc_posterior_list.append(dc_p)

        jc, jc_a, jc_p = Jaccard3d(mask3d, gt)
        jc_list.append(jc)
        jc_anterior_list.append(jc_a)
        jc_posterior_list.append(jc_p)

        sens, sens_a, sens_p = Sensitivity(mask3d, gt)
        sens_list.append(sens)
        sens_anterior_list.append(sens_a)
        sens_posterior_list.append(sens_p)

        spec, spec_a, spec_p = Specificity(mask3d, gt)
        spec_list.append(spec)
        spec_anterior_list.append(spec_a)
        spec_posterior_list.append(spec_p)

        # f1 = F1_score(mask3d, gt)
        # f1_list.append(f1)

        out_dict["volume_stats"].append({
            "filename": x['filename'],
            
            "dice": dc,
            "dice_anterior": dc_a,
            "dice_posterior": dc_p,
            
            "jaccard": jc,
            "jaccard_anterior": jc_a,
            "jaccard_posterior": jc_p,
            
            "sensitivity": sens,
            "sensitivity_anterior": sens_a,
            "sensitivity_posterior": sens_p,
            
            "specificity": spec,
            "specificity_anterior": spec_a,
            "specificity_posterior": spec_p,
            # "f1": f1,
            })

        print(f"{x['filename']} Dice {dc:.4f}, Dice Anterior {dc_a:.4f}, Dice Posterior {dc_p:.4f}, Jaccard {jc:.4f}, Sensitivity {sens:.4f}, and Specificity {spec:.4f}. {100*(i+1)/len(test_data):.2f}% complete")

    avg_dc = np.mean(dc_list)
    avg_dc_a = np.mean(dc_anterior_list)
    avg_dc_p = np.mean(dc_posterior_list)

    avg_jc = np.mean(jc_list)
    avg_jc_a = np.mean(jc_anterior_list)
    avg_jc_p = np.mean(jc_posterior_list)
    
    avg_sens = np.mean(sens_list)
    avg_sens_a = np.mean(sens_anterior_list)
    avg_sens_p = np.mean(sens_posterior_list)
    
    avg_spec = np.mean(spec_list)
    avg_spec_a = np.mean(spec_anterior_list)
    avg_spec_p = np.mean(spec_posterior_list)
    
    # avg_f1 = np.mean(f1_list)

    out_dict["overall"] = {
        "mean_dice": avg_dc,
        "mean_dice_anterior": avg_dc_a,
        "mean_dice_posterior": avg_dc_p,
        "dice_ap_mean": np.mean([avg_dc_a,avg_dc_p]),
        
        "mean_jaccard": avg_jc,
        "mean_jaccard_anterior": avg_jc_a,
        "mean_jaccard_posterior": avg_jc_p,
        
        "mean_sensitivity": avg_sens,
        "mean_sensitivity_anterior": avg_sens_a,
        "mean_sensitivity_posterior": avg_sens_p,
        
        "mean_specificity": avg_spec,
        "mean_specificity_anterior": avg_spec_a,
        "mean_specificity_posterior": avg_spec_p,
        # "mean_f1": avg_f1,
        }

    print("\nTesting complete.")
    print("------------------------------")
    print(f"Average Dice {avg_dc:.4f}, Average Dice Anterior {avg_dc_a:.4f}, Average Dice Posterior {avg_dc_p:.4f}, Average Dice AP Mean {(avg_dc_a+avg_dc_p)/2:.4f}, Average Jaccard {avg_jc:.4f}, Average Sensitivity {avg_sens:.4f}, and Average Specificity {avg_spec:.4f}")

    return out_dict

In [18]:
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, in_channels=None, out_channels=None, num_classes=1, kernel_size=3,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.InstanceNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # downconv
        pool = nn.MaxPool2d(2, stride=2)
        conv1 = self.contract(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)
        conv2 = self.contract(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)

        # upconv
        conv3 = self.expand(in_channels=out_channels*2, out_channels=out_channels, kernel_size=kernel_size)
        conv4 = self.expand(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size)

        if outermost:
            final = nn.Conv2d(out_channels, num_classes, kernel_size=1)
            down = [conv1, conv2]
            up = [conv3, conv4, final]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(in_channels*2, in_channels,
                                        kernel_size=2, stride=2)
            model = [pool, conv1, conv2, upconv]
        else:
            upconv = nn.ConvTranspose2d(in_channels*2, in_channels, kernel_size=2, stride=2)

            down = [pool, conv1, conv2]
            up = [conv3, conv4, upconv]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    @staticmethod
    def contract(in_channels, out_channels, kernel_size=3, norm_layer=nn.InstanceNorm2d):
        layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
            norm_layer(out_channels),
            nn.LeakyReLU(inplace=True))
        return layer

    @staticmethod
    def expand(in_channels, out_channels, kernel_size=3):
        layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
            nn.LeakyReLU(inplace=True),
        )
        return layer

    @staticmethod
    def center_crop(layer, target_width, target_height):
        batch_size, n_channels, layer_width, layer_height = layer.size()
        xy1 = (layer_width - target_width) // 2
        xy2 = (layer_height - target_height) // 2
        return layer[:, :, xy1:(xy1 + target_width), xy2:(xy2 + target_height)]

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            crop = self.center_crop(self.model(x), x.size()[2], x.size()[3])
            return torch.cat([x, crop], 1)

In [19]:
class UNet(nn.Module):
    def __init__(self, num_classes=3, in_channels=1, initial_filter_size=64, kernel_size=3, num_downs=4, norm_layer=nn.InstanceNorm2d):
        # norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UNet, self).__init__()

        # construct unet structure
        unet_block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs-1), out_channels=initial_filter_size * 2 ** num_downs,
                                             num_classes=num_classes, kernel_size=kernel_size, norm_layer=norm_layer, innermost=True)
        for i in range(1, num_downs):
            unet_block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs-(i+1)),
                                                 out_channels=initial_filter_size * 2 ** (num_downs-i),
                                                 num_classes=num_classes, kernel_size=kernel_size, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(in_channels=in_channels, out_channels=initial_filter_size,
                                             num_classes=num_classes, kernel_size=kernel_size, submodule=unet_block, norm_layer=norm_layer,
                                             outermost=True)

        self.model = unet_block

    def forward(self, x):
        return self.model(x)

In [20]:
set_seeds(seed_factor, experiment_id)
model = UNet(num_classes=3, initial_filter_size=64, num_downs=4)

model.to(device)

if device.type == "cuda":
    model = torch.nn.DataParallel(model)

loss_function = torch.nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=c.learning_rate)
# eps (float) â€“ Minimal decay applied to lr. 
# If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True, eps=1e-20)

[ Using Seed :  71582788  ]


In [21]:
def get_n_params(model):

    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

get_n_params(model)

31030723

In [21]:
_time_start = time.time()
set_seeds(seed_factor, experiment_id)
for i in range(50):
    train(epoch=i, model=model)

_time_end = time.time()
print(f"Run complete. Total time: {time.strftime('%H:%M:%S', time.gmtime(_time_end - _time_start))}")

[ Using Seed :  71582788  ]

Training epoch 0...

Epoch: 0 Train loss: 1.055749535560608, 1.1% complete
..........
Epoch: 0 Train loss: 0.16964256763458252, 12.6% complete
..........
Epoch: 0 Train loss: 0.0869569256901741, 24.1% complete
..........
Epoch: 0 Train loss: 0.07537630200386047, 35.6% complete
..........
Epoch: 0 Train loss: 0.06616166234016418, 47.1% complete
..........
Epoch: 0 Train loss: 0.04276079311966896, 58.6% complete
..........
Epoch: 0 Train loss: 0.050687406212091446, 70.1% complete
..........
Epoch: 0 Train loss: 0.04397191107273102, 81.6% complete
..........
Epoch: 0 Train loss: 0.02751169539988041, 93.1% complete
.......
Training complete

Validating epoch 0...
Batch 0. Loss 0.03216904029250145
Batch 1. Loss 0.03226707875728607
Batch 2. Loss 0.03060026653110981
Batch 3. Loss 0.040598802268505096
Batch 4. Loss 0.03779638186097145
Batch 5. Loss 0.03463990241289139
Batch 6. Loss 0.03537403792142868
Batch 7. Loss 0.037249695509672165
Batch 8. Loss 0.0393788777291

.......
Training complete

Validating epoch 4...
Batch 0. Loss 0.016492921859025955
Batch 1. Loss 0.013751785270869732
Batch 2. Loss 0.012152030132710934
Batch 3. Loss 0.014509651809930801
Batch 4. Loss 0.017340969294309616
Batch 5. Loss 0.017274653539061546
Batch 6. Loss 0.0162795502692461
Batch 7. Loss 0.012872434221208096
Batch 8. Loss 0.013331882655620575
Batch 9. Loss 0.015448863618075848
Batch 10. Loss 0.014994904398918152
Batch 11. Loss 0.014806190505623817
Batch 12. Loss 0.012778949923813343
Batch 13. Loss 0.014600779861211777
Batch 14. Loss 0.016192782670259476
Batch 15. Loss 0.018945002928376198
Batch 16. Loss 0.014050074853003025
Batch 17. Loss 0.015744756907224655
Batch 18. Loss 0.013200989924371243
Batch 19. Loss 0.011862007901072502
Batch 20. Loss 0.014370213262736797
Batch 21. Loss 0.018234629184007645
Batch 22. Loss 0.016084538772702217
Batch 23. Loss 0.016106151044368744
Batch 24. Loss 0.01580096036195755
Batch 25. Loss 0.014223374426364899
Batch 26. Loss 0.01436166558

Batch 16. Loss 0.01738051138818264
Batch 17. Loss 0.016847608610987663
Batch 18. Loss 0.017040885984897614
Batch 19. Loss 0.01331520639359951
Batch 20. Loss 0.013407624326646328
Batch 21. Loss 0.014204456470906734
Batch 22. Loss 0.014599418267607689
Batch 23. Loss 0.01535065472126007
Batch 24. Loss 0.014696854166686535
Batch 25. Loss 0.016181712970137596
Batch 26. Loss 0.011269218288362026
Batch 27. Loss 0.011945679783821106
Batch 28. Loss 0.012746275402605534
Current Validation AP Mean DC 45.20473049 | Maximum Validation AP Mean DC 44.73743788
Maximum Validation DC 46.21535485
Current Validation AP Mean DC improved from 44.73743788 to 45.20473049

Validation complete

Training epoch 9...

Epoch: 9 Train loss: 0.008464332669973373, 1.1% complete
..........
Epoch: 9 Train loss: 0.010789196006953716, 12.6% complete
..........
Epoch: 9 Train loss: 0.01072657946497202, 24.1% complete
..........
Epoch: 9 Train loss: 0.011522332206368446, 35.6% complete
..........
Epoch: 9 Train loss: 0.0111

..........
Epoch: 13 Train loss: 0.008431646972894669, 24.1% complete
..........
Epoch: 13 Train loss: 0.008686269633471966, 35.6% complete
..........
Epoch: 13 Train loss: 0.007753417827188969, 47.1% complete
..........
Epoch: 13 Train loss: 0.007835796102881432, 58.6% complete
..........
Epoch: 13 Train loss: 0.010843868367373943, 70.1% complete
..........
Epoch: 13 Train loss: 0.008174520917236805, 81.6% complete
..........
Epoch: 13 Train loss: 0.010309431701898575, 93.1% complete
.......
Training complete

Validating epoch 13...
Batch 0. Loss 0.014743686653673649
Batch 1. Loss 0.014961124397814274
Batch 2. Loss 0.014322347939014435
Batch 3. Loss 0.017053136602044106
Batch 4. Loss 0.015568814240396023
Batch 5. Loss 0.015975836664438248
Batch 6. Loss 0.015324709936976433
Batch 7. Loss 0.013157888315618038
Batch 8. Loss 0.012589134275913239
Batch 9. Loss 0.01405167393386364
Batch 10. Loss 0.013699720613658428
Batch 11. Loss 0.011730986647307873
Batch 12. Loss 0.01329169049859047
Batc

Batch 1. Loss 0.013317720964550972
Batch 2. Loss 0.014977630227804184
Batch 3. Loss 0.016823239624500275
Batch 4. Loss 0.02063438668847084
Batch 5. Loss 0.015344977378845215
Batch 6. Loss 0.015413099899888039
Batch 7. Loss 0.014477328397333622
Batch 8. Loss 0.012898465618491173
Batch 9. Loss 0.015698393806815147
Batch 10. Loss 0.018572421744465828
Batch 11. Loss 0.016285177320241928
Batch 12. Loss 0.015180273912847042
Batch 13. Loss 0.014955602586269379
Batch 14. Loss 0.013392829336225986
Batch 15. Loss 0.01619706302881241
Batch 16. Loss 0.018608063459396362
Batch 17. Loss 0.01267162524163723
Batch 18. Loss 0.014594155363738537
Batch 19. Loss 0.015031712129712105
Batch 20. Loss 0.01421604584902525
Batch 21. Loss 0.018775662407279015
Batch 22. Loss 0.01506762858480215
Batch 23. Loss 0.012578858993947506
Batch 24. Loss 0.019046630710363388
Batch 25. Loss 0.01766008324921131
Batch 26. Loss 0.020952625200152397
Batch 27. Loss 0.014937667176127434
Batch 28. Loss 0.012178434990346432
Current

Batch 22. Loss 0.014443938620388508
Batch 23. Loss 0.016700461506843567
Batch 24. Loss 0.018118131905794144
Batch 25. Loss 0.017872439697384834
Batch 26. Loss 0.01888856664299965
Batch 27. Loss 0.022059623152017593
Batch 28. Loss 0.012083862908184528
Current Validation AP Mean DC 45.05146321 | Maximum Validation AP Mean DC 45.28226554
Maximum Validation DC 46.67858658

Validation complete

Training epoch 22...

Epoch: 22 Train loss: 0.006635845638811588, 1.1% complete
..........
Epoch: 22 Train loss: 0.006146200932562351, 12.6% complete
..........
Epoch: 22 Train loss: 0.006315212696790695, 24.1% complete
..........
Epoch: 22 Train loss: 0.007590541150420904, 35.6% complete
..........
Epoch: 22 Train loss: 0.006443270482122898, 47.1% complete
..........
Epoch: 22 Train loss: 0.0046949381940066814, 58.6% complete
..........
Epoch: 22 Train loss: 0.006877305917441845, 70.1% complete
..........
Epoch: 22 Train loss: 0.004903652239590883, 81.6% complete
..........
Epoch: 22 Train loss: 0.0

..........
Epoch: 26 Train loss: 0.006790658459067345, 70.1% complete
..........
Epoch: 26 Train loss: 0.0070107365027070045, 81.6% complete
..........
Epoch: 26 Train loss: 0.005967090371996164, 93.1% complete
.......
Training complete

Validating epoch 26...
Batch 0. Loss 0.01649881713092327
Batch 1. Loss 0.016465242952108383
Batch 2. Loss 0.016941005364060402
Batch 3. Loss 0.017211973667144775
Batch 4. Loss 0.0158708319067955
Batch 5. Loss 0.016232576221227646
Batch 6. Loss 0.017269087955355644
Batch 7. Loss 0.017080850899219513
Batch 8. Loss 0.015066277235746384
Batch 9. Loss 0.019804853945970535
Batch 10. Loss 0.01779145933687687
Batch 11. Loss 0.01888125017285347
Batch 12. Loss 0.014147325418889523
Batch 13. Loss 0.013752075843513012
Batch 14. Loss 0.019256072118878365
Batch 15. Loss 0.018923470750451088
Batch 16. Loss 0.01657470129430294
Batch 17. Loss 0.015846816822886467
Batch 18. Loss 0.013376295566558838
Batch 19. Loss 0.0173171479254961
Batch 20. Loss 0.020395971834659576
B

Batch 13. Loss 0.0157126747071743
Batch 14. Loss 0.01712738163769245
Batch 15. Loss 0.018360471352934837
Batch 16. Loss 0.01626560091972351
Batch 17. Loss 0.019205622375011444
Batch 18. Loss 0.014142856933176517
Batch 19. Loss 0.01716024987399578
Batch 20. Loss 0.015022937208414078
Batch 21. Loss 0.019270554184913635
Batch 22. Loss 0.017750713974237442
Batch 23. Loss 0.01709337718784809
Batch 24. Loss 0.020222589373588562
Batch 25. Loss 0.02328125387430191
Batch 26. Loss 0.02152642048895359
Batch 27. Loss 0.019916148856282234
Batch 28. Loss 0.014527395367622375
Current Validation AP Mean DC 45.02439341 | Maximum Validation AP Mean DC 45.28226554
Maximum Validation DC 46.67858658

Validation complete

Training epoch 31...

Epoch: 31 Train loss: 0.006574876140803099, 1.1% complete
..........
Epoch: 31 Train loss: 0.006564803887158632, 12.6% complete
..........
Epoch: 31 Train loss: 0.006637916900217533, 24.1% complete
..........
Epoch: 31 Train loss: 0.00688869459554553, 35.6% complete
.

..........
Epoch: 35 Train loss: 0.005598034244030714, 12.6% complete
..........
Epoch: 35 Train loss: 0.006181231699883938, 24.1% complete
..........
Epoch: 35 Train loss: 0.006694058421999216, 35.6% complete
..........
Epoch: 35 Train loss: 0.006951282266527414, 47.1% complete
..........
Epoch: 35 Train loss: 0.006942233070731163, 58.6% complete
..........
Epoch: 35 Train loss: 0.00666380301117897, 70.1% complete
..........
Epoch: 35 Train loss: 0.0069564939476549625, 81.6% complete
..........
Epoch: 35 Train loss: 0.007180798798799515, 93.1% complete
.......
Training complete

Validating epoch 35...
Batch 0. Loss 0.018354106694459915
Batch 1. Loss 0.012289020232856274
Batch 2. Loss 0.02116401493549347
Batch 3. Loss 0.022917116060853004
Batch 4. Loss 0.01538698747754097
Batch 5. Loss 0.012789459899067879
Batch 6. Loss 0.018497906625270844
Batch 7. Loss 0.01855531521141529
Batch 8. Loss 0.018161233514547348
Batch 9. Loss 0.017990950495004654
Batch 10. Loss 0.014321698807179928
Batch 1

Batch 4. Loss 0.015978269279003143
Batch 5. Loss 0.01734272390604019
Batch 6. Loss 0.01702059991657734
Batch 7. Loss 0.015983060002326965
Batch 8. Loss 0.018035979941487312
Batch 9. Loss 0.016940472647547722
Batch 10. Loss 0.015885332599282265
Batch 11. Loss 0.01746520958840847
Batch 12. Loss 0.01768764853477478
Batch 13. Loss 0.016749363392591476
Batch 14. Loss 0.018423564732074738
Batch 15. Loss 0.01599244773387909
Batch 16. Loss 0.017256606370210648
Batch 17. Loss 0.015952015295624733
Batch 18. Loss 0.016334695741534233
Batch 19. Loss 0.018865151330828667
Batch 20. Loss 0.017978642135858536
Batch 21. Loss 0.020965764299035072
Batch 22. Loss 0.018160579726099968
Batch 23. Loss 0.017856301739811897
Batch 24. Loss 0.0156431682407856
Batch 25. Loss 0.01747976243495941
Batch 26. Loss 0.019026516005396843
Batch 27. Loss 0.01447012834250927
Batch 28. Loss 0.01927766390144825
Current Validation AP Mean DC 45.02013713 | Maximum Validation AP Mean DC 45.28226554
Maximum Validation DC 46.67858

Batch 28. Loss 0.013961070217192173
Current Validation AP Mean DC 45.01944718 | Maximum Validation AP Mean DC 45.28226554
Maximum Validation DC 46.67858658

Validation complete

Training epoch 44...

Epoch: 44 Train loss: 0.006223106756806374, 1.1% complete
..........
Epoch: 44 Train loss: 0.007634714711457491, 12.6% complete
..........
Epoch: 44 Train loss: 0.006018831394612789, 24.1% complete
..........
Epoch: 44 Train loss: 0.007811223156750202, 35.6% complete
..........
Epoch: 44 Train loss: 0.006596884690225124, 47.1% complete
..........
Epoch: 44 Train loss: 0.0069845207035541534, 58.6% complete
..........
Epoch: 44 Train loss: 0.006002056412398815, 70.1% complete
..........
Epoch: 44 Train loss: 0.006698923651129007, 81.6% complete
..........
Epoch: 44 Train loss: 0.0056399148888885975, 93.1% complete
.......
Training complete

Validating epoch 44...
Batch 0. Loss 0.015684084966778755
Batch 1. Loss 0.01954255998134613
Batch 2. Loss 0.016407761722803116
Batch 3. Loss 0.0187328662

..........
Epoch: 48 Train loss: 0.006096189375966787, 93.1% complete
.......
Training complete

Validating epoch 48...
Batch 0. Loss 0.016181794926524162
Batch 1. Loss 0.018568024039268494
Batch 2. Loss 0.014603356830775738
Batch 3. Loss 0.018802287057042122
Batch 4. Loss 0.019340602681040764
Batch 5. Loss 0.016261694952845573
Batch 6. Loss 0.014216084964573383
Batch 7. Loss 0.01545658614486456
Batch 8. Loss 0.018109487369656563
Batch 9. Loss 0.018974803388118744
Batch 10. Loss 0.016574107110500336
Batch 11. Loss 0.013351657427847385
Batch 12. Loss 0.017027894034981728
Batch 13. Loss 0.01688573695719242
Batch 14. Loss 0.017778653651475906
Batch 15. Loss 0.017447181046009064
Batch 16. Loss 0.022184208035469055
Batch 17. Loss 0.015528670512139797
Batch 18. Loss 0.01690077781677246
Batch 19. Loss 0.020200518891215324
Batch 20. Loss 0.015293719246983528
Batch 21. Loss 0.021165132522583008
Batch 22. Loss 0.01477949321269989
Batch 23. Loss 0.01830805279314518
Batch 24. Loss 0.01702262833714

In [22]:
weight_list = glob(f"out/{dirname}/*")
weight_list.sort(reverse=True, key=os.path.getmtime)
weight_list

['out/basic_unet_seed1/model_epoch16.pth',
 'out/basic_unet_seed1/model_epoch15.pth',
 'out/basic_unet_seed1/model_epoch9.pth',
 'out/basic_unet_seed1/model_epoch8.pth',
 'out/basic_unet_seed1/model_epoch6.pth',
 'out/basic_unet_seed1/model_epoch5.pth',
 'out/basic_unet_seed1/model_epoch4.pth',
 'out/basic_unet_seed1/model_epoch3.pth',
 'out/basic_unet_seed1/model_epoch2.pth',
 'out/basic_unet_seed1/model_epoch13.pth',
 'out/basic_unet_seed1/model_epoch7.pth',
 'out/basic_unet_seed1/model_epoch11.pth',
 'out/basic_unet_seed1/model_epoch14.pth',
 'out/basic_unet_seed1/model_epoch12.pth']

In [23]:
weight_list = glob(f"out/{dirname}/*")
weight_list.sort(reverse=True, key=os.path.getmtime)
best_weight = weight_list[0]
print(best_weight)
model.load_state_dict(torch.load(best_weight))
set_seeds(seed_factor, experiment_id)
test_out = run_test(model)

out/basic_unet_seed1/model_epoch16.pth
[ Using Seed :  71582788  ]
hippocampus_260.nii.gz Dice 0.8978, Dice Anterior 0.8556, Dice Posterior 0.8749, Jaccard 0.8145, Sensitivity 0.8821, and Specificity 0.9982. 1.92% complete
hippocampus_378.nii.gz Dice 0.8930, Dice Anterior 0.8589, Dice Posterior 0.8760, Jaccard 0.8067, Sensitivity 0.9039, and Specificity 0.9976. 3.85% complete
hippocampus_321.nii.gz Dice 0.8905, Dice Anterior 0.8435, Dice Posterior 0.8420, Jaccard 0.8025, Sensitivity 0.8997, and Specificity 0.9974. 5.77% complete
hippocampus_319.nii.gz Dice 0.8985, Dice Anterior 0.9102, Dice Posterior 0.8556, Jaccard 0.8156, Sensitivity 0.9096, and Specificity 0.9979. 7.69% complete
hippocampus_041.nii.gz Dice 0.8362, Dice Anterior 0.8270, Dice Posterior 0.8290, Jaccard 0.7185, Sensitivity 0.8220, and Specificity 0.9962. 9.62% complete
hippocampus_088.nii.gz Dice 0.9149, Dice Anterior 0.9128, Dice Posterior 0.8896, Jaccard 0.8431, Sensitivity 0.9162, and Specificity 0.9979. 11.54% compl

In [24]:
# del test; del train; del val

In [25]:
mean_dice = []
for i in test_out['volume_stats']:
    mean_dice.append(i['dice_anterior'])
    mean_dice.append(i['dice_posterior'])

np.mean(mean_dice)

0.8660525620157116

In [26]:
test_out_path = 'test/test_out_'+book_name.replace("ipynb",'txt')
test_out_path

'test/test_out_Train_UNet_seed1.txt'

In [27]:
with open(test_out_path, 'w') as convert_file:
     convert_file.write(json.dumps(test_out))

In [28]:
import json
  
# reading the data from the file
with open(test_out_path) as f:
    data = f.read()

data_js = json.loads(data)

In [29]:
print(f"Experiment ID: {experiment_id}")
data_js['overall']

Experiment ID: 1


{'mean_dice': 0.8955788039180423,
 'mean_dice_anterior': 0.8687932304178995,
 'mean_dice_posterior': 0.8633118936135236,
 'dice_ap_mean': 0.8660525620157116,
 'mean_jaccard': 0.8117259549545448,
 'mean_jaccard_anterior': 0.7697578409115746,
 'mean_jaccard_posterior': 0.7608864530712249,
 'mean_sensitivity': 0.8829941936590842,
 'mean_sensitivity_anterior': 0.8601237510859008,
 'mean_sensitivity_posterior': 0.8518153621023841,
 'mean_specificity': 0.9979490919754727,
 'mean_specificity_anterior': 0.9986252232556727,
 'mean_specificity_posterior': 0.9986942574434164}

In [34]:
mean_jac = []
for i in test_out['volume_stats']:
    mean_jac.append(i['jaccard_anterior'])
    mean_jac.append(i['jaccard_posterior'])

np.mean(mean_jac)

0.7653221469913998