In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import sys
import re
import os
import torch.optim as optim
import time
import nibabel as nib
import matplotlib.pylab as plt
import math
from torch.utils.data import DataLoader
from tqdm import tqdm
from scipy import ndimage
from datetime import datetime
from glob import glob


# 3D UNET
- The 3D UNET model is provided by <mark> from classes.models.unet3d import UNet3D </mark>
There are a few important parameters that are essential for extracting features well.
- The UNET model uses 3D convolutions. It has 4 layers.
- The default kernel size for the double convolution is 3, i.e. (3x3x3).
- Number of feature channels: increasing the number of feature channels enables the model to predict more classes.
- Channel selector 0, with channels (4, 8, 16, 32, 64) and a kernel size of 3, failed to extract any class except the background.
- Channel selector 1, with channels (8, 16, 32, 64, 128) and a kernel size of 3, can effectively segment the kidney. However, it is unable to successfully predict features for tumors and cysts.
- Channel selector 2, with channels (16, 32, 64, 128, 256) and a kernel size of 5, segments the kidney well and can also predict the tumor. However, cyst features still cannot be extracted.


In [2]:
# Dataset locations, relative to base_dir (base_dir may be replaced by the Drive folder when on Colab).
base_dir = "./"
raw_dataset_dir = "dataset/"
transformed_dataset_dir_path = os.path.join(raw_dataset_dir, "affine_transformed/")

In [3]:
# Detect Colab automatically instead of hard-coding it, so the notebook also runs locally.
try:
    import google.colab  # noqa: F401  (only importable inside a Colab runtime)
    is_colab = True
except ImportError:
    is_colab = False

if is_colab:
    base_dir = "/content/drive/MyDrive/Colab Notebooks/"
    # Mount Google Drive only when the notebook folder is not visible yet.
    if not os.path.isdir(base_dir):
        from google.colab import drive
        drive.mount('/content/drive')

# Resolve the dataset paths against base_dir. os.path.join returns its second argument unchanged
# when that argument is already absolute, so re-running this cell still points at the same folders.
raw_dataset_dir = os.path.join(base_dir, raw_dataset_dir)
transformed_dataset_dir_path = os.path.join(base_dir, transformed_dataset_dir_path)

if not (os.path.isdir(raw_dataset_dir) and os.path.isdir(transformed_dataset_dir_path)):
    raise FileNotFoundError("check path for dataset:{} \n path for transformed dataset: {}"
                            .format(raw_dataset_dir, transformed_dataset_dir_path))
print("dataset folder exists, OK")




dataset folder exists, OK


In [4]:
# Make the project's `classes` package importable from base_dir.
# Guarded so that re-running this cell does not keep appending duplicate sys.path entries.
if base_dir not in sys.path:
    sys.path.append(base_dir)
from classes.dataset_utils.toTorchDataset import ProcessedKit23TorchDataset
from classes.models.unet3d import UNet3D
from classes.epoch_results import EpochResult
from classes.epoch_results import EpochResult

In [5]:
# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [6]:
# Fraction of cases held out for testing. Both splits must use the same value so they are
# complementary (presumably the split inside ProcessedKit23TorchDataset is deterministic — confirm).
test_split_size = 0.25
training_data = ProcessedKit23TorchDataset(train_data=True, test_size=test_split_size,
                                           dataset_dir=transformed_dataset_dir_path)
test_data = ProcessedKit23TorchDataset(train_data=False, test_size=test_split_size,
                                       dataset_dir=transformed_dataset_dir_path)
print("size of training data:{}    size of testing data:{}".format(len(training_data), len(test_data)))

size of training data:366    size of testing dat:123


## Reduce Training Cases and Test Cases
- The following cell reduces the number of training and test cases (for quick demo runs).

In [7]:
# Set to False to use the full training and test splits.
is_simplified = True
# For a quick demo, keep only the first n_train_cases training cases and n_test_cases test cases.
n_train_cases = 100
n_test_cases = 10
if is_simplified:
    # case_dirs and case_names are parallel lists, so both are truncated to the same length.
    # NOTE(review): assumes the dataset's __len__ follows case_dirs — confirm in ProcessedKit23TorchDataset.
    training_data.case_dirs = training_data.case_dirs[:n_train_cases]
    training_data.case_names = training_data.case_names[:n_train_cases]
    test_data.case_dirs = test_data.case_dirs[:n_test_cases]
    test_data.case_names = test_data.case_names[:n_test_cases]

In [8]:
# UNet3D configuration; see the notes above for how channel selector and kernel size affect results.
channel_selection = 1
ks = 3                 # kernel size of each double convolution
is_upsampling = True   # For NOT using CV transpose, use True
# 1 input channel and 4 output classes (presumably background, kidney, tumor and cyst, as in the notes above).
# Pass the flag itself: the original passed a literal True, so changing is_upsampling had no effect.
model = UNet3D(1, 4, channel_selection=channel_selection, double_conv_kernel_size=ks,
               is_upsampling=is_upsampling).to(device)
model._initialize_weights()

## Optimizer or Gradient Descent Model
- Choose between Adam and SGD with `is_ADAM`.
- The learning-rate decay (gamma of the exponential scheduler) is adjusted manually. Use a higher gamma (slower decay) when there are more training cases; for 100 cases, gamma 0.95 is used.

In [9]:
# Adam is the default; set is_ADAM = False to train with SGD + momentum instead.
# Only the selected optimizer is constructed. Previously SGD was always built and then thrown away.
is_ADAM = True
if is_ADAM:
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
else:
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-3)
# Mask voxels labelled -1 are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-1)
# Multiplies the learning rate by gamma each time scheduler.step() is called (once per epoch in the training loop).
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

In [10]:
# Resume training from a saved checkpoint, or keep the freshly initialised weights.
continue_from_checkpoint = True
epoch_res = EpochResult()
epoch_start = 0
if continue_from_checkpoint:
    print("Unet3D - loading from trained weight")
    # NOTE(review): the resume file is hard-coded and must match channel_selection / ks /
    # is_upsampling and the optimizer chosen above. The SGD name below does not follow the
    # "..._ks{}_SGD_epoch{}" pattern used when saving checkpoints — confirm it is a legacy file.
    if is_ADAM:
        checkpoint_ref_filepath = "training_checkpoints/Model_UNET_ch1_ks3_up_epoch30.pth.tar"
    else:
        checkpoint_ref_filepath = "training_checkpoints/Model_UNET_SGD_ch1_ks3_epoch40.pth.tar"
    checkpoint_file = os.path.join(base_dir, checkpoint_ref_filepath)
    if not os.path.isfile(checkpoint_file):
        raise FileNotFoundError("checkpoint not found: {}".format(checkpoint_file))
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only runtime (and vice versa).
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    # Restores the optimizer state, including the already-decayed learning rate.
    optimizer.load_state_dict(checkpoint['optimizer'])
    # Load the custom training history stored alongside the weights.
    ep_list = checkpoint['epoch_list']
    loss_list = checkpoint['loss_list']
    lr_list = checkpoint['lr_list']
    epoch_res = EpochResult(_epoch_list=ep_list, _loss_list=loss_list, _lr_list=lr_list)
    # A new ExponentialLR starts from the optimizer's current (restored) lr, so decay continues
    # from the saved epoch rather than restarting at the initial learning rate.
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    # Continue with the epoch after the last recorded one; start from 0 if the history is empty.
    epoch_start = epoch_res.epoch_list[-1] + 1 if len(epoch_res.epoch_list) > 0 else 0
else:
    print("Unet3D - was initialised with weight")


Unet3D - loading from trained weight


## Check Points filename
- The checkpoint filename encodes the channel selector,
- the kernel size of the double convolution,
- whether Adam or SGD is used,
- and whether CV Transpose (transposed convolution) or upsampling is used.

## Training params
Batch size used  
- channel selector 0: batch size 6
- channel selector 1: batch size 3
- channel selector 2: batch size 1

## Dataset into Dataloader
- The DataLoader sets the batch size, which is another useful training parameter.
- `shuffle=True` changes the order of the samples every epoch; it is only useful for training data.

In [11]:
# Batch size for channel selector 1 (see the batch-size table above).
batch_size = 3
total_batches = math.ceil(len(training_data) / batch_size)
num_epochs = 100

# Build the checkpoint filename from the run configuration so that different runs do not overwrite
# each other. "_SGD" marks SGD runs and "_up" marks upsampling (instead of CV transpose) runs.
# Before this change, "_up" was keyed on is_simplified, which contradicted the naming notes above.
optimizer_tag = "" if is_ADAM else "_SGD"
upsampling_tag = "_up" if is_upsampling else ""
checkpoint_dir = os.path.join(base_dir, "training_checkpoints")
# torch.save does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)
# The remaining {} placeholders are filled with (channel_selection, ks, epoch) when saving.
model_unet_save_path = os.path.join(
    checkpoint_dir, "Model_UNET_ch{}_ks{}" + optimizer_tag + upsampling_tag + "_epoch{}.pth.tar")

train_loader = DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)

## Training Loop
- Validation on the test set during training is disabled by default, because training is extremely costly and the team has already used a Colab Tesla T4 GPU for the task.
- Please note that compute units are not free in Colab.
- The loop saves the model weights at the end of every epoch.
- Therefore, training can be stopped at any time and resumed from the latest checkpoint.

In [12]:
# Set to True to evaluate on test_loader after every epoch (costs extra GPU time).
run_validation = False

batches_per_epoch = len(train_loader)
if batches_per_epoch == 0:
    raise ValueError("train_loader is empty - check the training dataset")

train_time_start = time.time()
for epoch in range(epoch_start, num_epochs):
    model.train()
    # LR used throughout this epoch; the scheduler is only stepped after the epoch finishes.
    current_lr = scheduler.get_last_lr()[0]
    epoch_loss_sum = 0.0
    for batch_idx, (images, masks) in enumerate(train_loader):
        images, masks = images.to(device), masks.to(device)
        # CrossEntropyLoss needs integer class indices; squeeze(1) drops a singleton axis 1
        # (assumes masks carry a channel axis of size 1 — TODO confirm with the dataset).
        masks = masks.long().squeeze(1)
        optimizer.zero_grad()
        outputs = model(images.float())
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        running_loss = loss.item()
        epoch_loss_sum += running_loss

        # Mean wall-clock time per batch since this cell started (includes data loading and checkpointing).
        total_processed_batches = (epoch - epoch_start) * batches_per_epoch + batch_idx + 1
        avg_batch_time = (time.time() - train_time_start) / total_processed_batches
        if batch_idx % 5 == 0:
            print("Epoch:{}/{} batch:{}/{}   Loss:{:.4f}  avg batch time:{:.1f} LR={:.6f}".format(
                epoch, num_epochs, batch_idx, batches_per_epoch, running_loss, avg_batch_time, current_lr))

    # Record the mean loss over the whole epoch, not the loss of its last (possibly tiny) batch.
    # NOTE: checkpoints written by earlier versions of this cell stored the last-batch loss instead.
    epoch_loss = epoch_loss_sum / batches_per_epoch
    scheduler.step()
    epoch_res.append_result(epoch, epoch_loss, current_lr)
    # Save after every epoch so training can be interrupted and resumed at any point.
    model_checkpoint_path = model_unet_save_path.format(channel_selection, ks, epoch)
    torch.save({'epoch_list': epoch_res.epoch_list, 'loss_list': epoch_res.loss_list,
                'lr_list': epoch_res.lr_list, 'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()}, model_checkpoint_path, _use_new_zipfile_serialization=True)

    if run_validation:
        model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for images, masks in test_loader:
                images, masks = images.to(device), masks.to(device)
                masks = masks.long().squeeze(1)
                outputs = model(images.float())
                total_loss += criterion(outputs, masks).item()
        average_loss = total_loss / max(len(test_loader), 1)
        print("Epoch:{}/{} validation loss:{:.4f}".format(epoch, num_epochs, average_loss))

print('Finished Training')


Epoch:31/100 batch:0/34   Loss:0.0592  avg batch time:18.1 LR=0.000204


KeyboardInterrupt: ignored