In [1]:
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print('Found Google Colab')
    !pip3 install torch torchvision torchsummary
    !pip3 install simpleitk

    # noinspection PyUnresolvedReferences
    from google.colab import drive
    drive.mount('/content/drive')

import numpy as np
import matplotlib.pyplot as plt
import torch
import SimpleITK as sitk

import src.helpers.oars_labels_consts as OARS_LABELS

from operator import itemgetter
from IPython.display import display, Markdown
from ipywidgets import widgets
from importlib import reload

from src.training_helpers import loss_batch, show_model_info
from src.helpers.prepare_model import prepare_model
from src.helpers.train_loop import train_loop
from src.helpers.get_dataset import get_dataset, get_dataloaders
from src.helpers.get_dataset_info import get_dataset_info
from src.helpers.preview_dataset import preview_dataset
from src.helpers.get_bounding_box import get_bounding_box, get_bounding_box_3D, get_bounding_box_3D_size, get_dividable_bounding_box, get_final_bounding_box_slice

MAX_PADDING_SLICES = 160
torch.manual_seed(20)

# [56 177 156] is bounding box size without spinal cord in dataset, so we get bounding box which can be divided by pooling/unpooling layers and in the end still persist size
dataset_max_bounding_box = [56, 177, 156]
desire_bounding_box_size = get_dividable_bounding_box(dataset_max_bounding_box, pooling_layers=3, offset=12)
print('Dataset biggest bounding box wihtout spinal cord', dataset_max_bounding_box)
print('Cut target size', desire_bounding_box_size)

print('Done Init')

Dataset biggest bounding box wihtout spinal cord [56, 177, 156]
Cut target size [72, 192, 168]
Done Init


# Neural Network

## loading low res dataset

In [2]:
filter_labels = OARS_LABELS.OARS_LABELS_LIST
if OARS_LABELS.SPINAL_CORD in filter_labels:
    filter_labels.remove(OARS_LABELS.SPINAL_CORD)

low_res_dataset = get_dataset(dataset_size=50, shrink_factor=16, filter_labels=filter_labels)
low_res_dataset.dilatate_labels(repeat=1)
low_res_dataset.to_numpy()
low_res_dataloaders_obj = get_dataloaders(low_res_dataset, train_size=40, valid_size=5, test_size=5)

get_dataset_info(low_res_dataset, low_res_dataloaders_obj)
train_low_res_dataset, valid_low_res_dataset, test_low_res_dataset = itemgetter('train_dataset', 'valid_dataset', 'test_dataset')(low_res_dataloaders_obj)

CUDA using 16x dataset
normalizing dataset
parsing dataset to numpy
train 40, valid_size 5, test 5, full 50
train indeces [0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 14, 15, 17, 18, 20, 21, 22, 23, 24, 28, 30, 31, 32, 33, 34, 35, 36, 37, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
valid indeces [6, 13, 19, 25, 38]
test indeces [16, 26, 27, 29, 39]


In [3]:
preview_dataset(low_res_dataset, preview_index=0, show_hist=False, max_padding_slices=MAX_PADDING_SLICES)

data max 12.505709639268096, min -0.40698009878688973
label max 1, min 0


VBox(children=(HBox(children=(IntSlider(value=101, max=159),)),))

Output()

## training low res model

In [4]:
# preparing model loop params
low_res_model_info = prepare_model(epochs=30, in_channels=8, train_dataset=train_low_res_dataset, valid_dataset=valid_low_res_dataset, test_dataset=test_low_res_dataset)
show_model_info(low_res_model_info, MAX_PADDING_SLICES)

# getting everything necessary for model training
low_res_train_loop_params = {k:v for k,v in low_res_model_info.items() if k not in ['model_total_params', 'model_total_trainable_params']}
# running training loop
train_loop(**low_res_train_loop_params)

low_res_model = itemgetter('model')(low_res_model_info)

Device running "cuda"
max output channels 128
Model number of params: 1193537, trainable 1193537
Running training loop
Batch eval [1] loss 0.95115, dsc 0.04885
Batch eval [2] loss 0.96262, dsc 0.03738
Batch eval [3] loss 0.96249, dsc 0.03751
Batch eval [4] loss 0.96250, dsc 0.03750
Batch eval [5] loss 0.95840, dsc 0.04160
Epoch [1] T 6.97s, deltaT 6.97s, loss: train 0.96122, valid 0.95943, dsc: train 0.03878, valid 0.04057
Batch eval [1] loss 0.94273, dsc 0.05727
Batch eval [2] loss 0.95209, dsc 0.04791
Batch eval [3] loss 0.94611, dsc 0.05389
Batch eval [4] loss 0.95591, dsc 0.04409
Batch eval [5] loss 0.94411, dsc 0.05589
Epoch [2] T 13.98s, deltaT 7.01s, loss: train 0.95455, valid 0.94819, dsc: train 0.04545, valid 0.05181
Batch eval [1] loss 0.94189, dsc 0.05811
Batch eval [2] loss 0.95095, dsc 0.04905
Batch eval [3] loss 0.94386, dsc 0.05614
Batch eval [4] loss 0.95524, dsc 0.04476
Batch eval [5] loss 0.94215, dsc 0.05785
Epoch [3] T 20.97s, deltaT 6.99s, loss: train 0.95045, vali

## loading high/full res dataset

In [5]:
full_res_dataset = get_dataset(dataset_size=5, shrink_factor=1)
full_res_dataset.to_numpy()

print('dataset data and label shapes', low_res_dataset.data_list[0].shape, full_res_dataset.data_list[0].shape)

CUDA using 1x dataset
normalizing dataset
parsing dataset to numpy
dataset data and label shapes (1, 160, 32, 32) (1, 160, 512, 512)


## testing precoarse network


In [6]:
# moving model to cpu and setting to eval mode, preventing model params changes/training
low_res_model.to('cpu')
low_res_model.eval()
print('moved model to cpu')

moved model to cpu


In [11]:
def expand_image(input_img, expand_factor=16): # input numpy shape (1, 1, MAX_PADDING_SLICES, x, x)
    # expanded_input_img = np.resize(tmp_output.copy(), (1, 1, 160, 512, 512))
    # for i in range(MAX_PADDING_SLICES):
    #    expanded_input_img[0, 0, i] = scipy.ndimage.zoom(tmp_output[0, 0, i], expand_factor, order=0) # TODO: zoom is using some interpolation, thus its not exact

    expanded_input_img = np.repeat(np.repeat(input_img, expand_factor, axis=3), expand_factor, axis=4)
        
    return expanded_input_img


def debug_preview_low_expand(model_output_img, exp_model_output_img, img_slice=100):
    # preview of 32x32 segmentation and his expanded 512x512 version
    model_output_img_percents = model_output_img[0, 0, img_slice].sum() / model_output_img[0, 0, img_slice].size
    exp_model_output_img_percents = exp_model_output_img[0, 0, img_slice].sum() / exp_model_output_img[0, 0, img_slice].size
    print(model_output_img_percents, exp_model_output_img_percents)
    
    plt.figure(figsize=(12, 12))
    
    plt.subplot(1, 2, 1)
    plt.imshow(model_output_img[0, 0, img_slice], cmap="gray", vmin=0, vmax=1)
    
    plt.subplot(1, 2, 2)
    plt.imshow(exp_model_output_img[0, 0, img_slice], cmap="gray", vmin=0, vmax=1)

    plt.show()


def debug_preview_cuts(exp_model_output_img, new_bounding_box, data_cut, label_cut):
    def f(slice_index):
        tmp_cut = exp_model_output_img[0, 0, new_bounding_box[0]:new_bounding_box[1] + 1, new_bounding_box[2]:new_bounding_box[3] + 1, new_bounding_box[4]:new_bounding_box[5] + 1]
        
        plt.figure(figsize=(18, 12))

        plt.subplot(1, 3, 1)
        plt.imshow(tmp_cut[slice_index], cmap="gray", vmin=0, vmax=1)
        
        plt.subplot(1, 3, 2)
        plt.imshow(data_cut[slice_index], cmap="gray")

        plt.subplot(1, 3, 3)
        plt.imshow(label_cut[slice_index])

        plt.show()

    slices_count = label_cut.shape[0]-1
    sliceSlider = widgets.IntSlider(min=0, max=slices_count, step=1, value=slices_count // 2)
    ui = widgets.VBox([widgets.HBox([sliceSlider])])
    out = widgets.interactive_output(f, {'slice_index': sliceSlider})
    # noinspection PyTypeChecker
    display(ui, out)


def get_high_res_cut(low_res_model, low_res_data_img, full_res_data_img, full_res_label_img, low_res_mask_threshold, desire_bounding_box_size):
    # getting low res segmentation
    exp_low_res_data_img = np.expand_dims(low_res_data_img, axis=0)
    model_output_img = low_res_model(torch.from_numpy(exp_low_res_data_img).float())
    model_output_img = model_output_img.cpu().detach().numpy()
    
    # parsing low res float to int mask
    model_output_img = (model_output_img > low_res_mask_threshold) * 1 # shape (1, 1, 160, 32, 32)

    # expanding low res int mask to high res
    exp_model_output_img = expand_image(model_output_img, expand_factor=16) # shape (1, 1, 160, 512, 512)

    # getting bounding box
    bounding_box = get_bounding_box_3D(exp_model_output_img[0][0])
    new_bounding_box = get_final_bounding_box_slice(bounding_box, desire_bounding_box_size)

    # getting bounding box cut
    data_cut = full_res_data_img[0, new_bounding_box[0]:new_bounding_box[1] + 1, new_bounding_box[2]:new_bounding_box[3] + 1, new_bounding_box[4]:new_bounding_box[5] + 1]
    label_cut = full_res_label_img[new_bounding_box[0]:new_bounding_box[1] + 1, new_bounding_box[2]:new_bounding_box[3] + 1, new_bounding_box[4]:new_bounding_box[5] + 1]
    print('debug, does cut and original label contain the same amount of pixels?', label_cut.sum(), full_res_label_img.sum())

    # debug
    # debug_preview_low_expand(model_output_img, exp_model_output_img)
    print('debug bounding box sizes', get_bounding_box_3D_size(*bounding_box), get_bounding_box_3D_size(*new_bounding_box))
    debug_preview_cuts(exp_model_output_img, new_bounding_box, data_cut, label_cut)
    
    return data_cut, label_cut, new_bounding_box


### getting bounding box cut in high res

In [14]:
dataset_index = 4
low_res_mask_threshold = 0.5
low_res_data_img = low_res_dataset.data_list[dataset_index]
full_res_data_img = full_res_dataset.data_list[dataset_index]
full_res_label_img = full_res_dataset.label_list[dataset_index]

data_cut, label_cut, new_bounding_box = get_high_res_cut(low_res_model, low_res_data_img, full_res_data_img, full_res_label_img, low_res_mask_threshold, desire_bounding_box_size)

debug box delta [21 32 -8]
debug, does cut and original label contain the same amount of pixels? 1494767 1560867
debug bounding box sizes (51, 160, 176) (72, 192, 168)


VBox(children=(HBox(children=(IntSlider(value=35, max=71),)),))

Output()