In [1]:
import sys
from src.consts import IN_COLAB, MAX_PADDING_SLICES, DATASET_MAX_BOUNDING_BOX, DATASET_PADDING_VALUE

# Colab-only setup: runs when IN_COLAB (imported from src.consts) is truthy.
# Installs the required packages with pip and mounts Google Drive at /content/drive.
if IN_COLAB:
    print('Found Google Colab')
    !pip3 install torch torchvision torchsummary
    !pip3 install simpleitk

    # noinspection PyUnresolvedReferences
    from google.colab import drive
    drive.mount('/content/drive')

import numpy as np
import matplotlib.pyplot as plt
import torch
import SimpleITK as sitk
import cv2 as cv

from operator import itemgetter
from IPython.display import display, Markdown
from ipywidgets import widgets
from functools import reduce

import src.helpers.oars_labels_consts as OARS_LABELS
from src.training_helpers import loss_batch, show_model_info
from src.helpers.prepare_model import prepare_model
from src.helpers.train_loop import train_loop
from src.helpers.get_dataset import get_dataset, get_dataloaders
from src.helpers.get_dataset_info import get_dataset_info
from src.helpers.preview_dataset import preview_dataset
from src.helpers.get_bounding_box import get_bounding_box, get_bounding_box_3D, get_bounding_box_3D_size

# Fix the PyTorch random number generator seed.
torch.manual_seed(20)
print('Done Init')

# %matplotlib widget

Done Init


# Full Dataset loading

In [2]:
from pathlib import Path

# Dataset location; the shrink factor is part of the folder name.
dataset_shrink = 1
root_dir_path = Path(f'./data/{"HaN_OAR"}_shrink{dataset_shrink}x_padded160')

# Number of cases to read; case folders are numbered 1..size.
size = 50

# Read the label volume of every case as a SimpleITK image.
label_list = [
    sitk.ReadImage(str(root_dir_path / str(case_number) / 'label.nii.gz'))
    for case_number in range(1, size + 1)
]


## Max sizes of individual organs

In [4]:
# For every organ label, report the largest bounding-box extent (per axis)
# found across the first `size` label volumes.
for label_name, label_const in OARS_LABELS.OARS_LABELS_DICT.items():
    # Binary mask of the current organ in each scan, converted to numpy arrays.
    tmp_label_list = [
        sitk.GetArrayFromImage(label_list[scan_index] == label_const)
        for scan_index in range(size)
    ]
    # Bounding-box size of that mask in each scan.
    tmp_label_list_box = [
        get_bounding_box_3D_size(*get_bounding_box_3D(organ_mask))
        for organ_mask in tmp_label_list
    ]

    print(label_name, label_const)
    # Element-wise maximum over all scans.
    print('max size', np.array(tmp_label_list_box).max(axis=0))


BRAIN_STEM 1
max size [21 37 44]
EYE_L 2
max size [10 27 27]
EYE_R 3
max size [11 27 28]
LENS_L 4
max size [ 4  7 11]
LENS_R 5
max size [ 3  7 11]
OPT_NERVE_L 6
max size [ 3 34 17]
OPT_NERVE_R 7
max size [ 5 33 18]
OPT_CHIASMA 8
max size [ 2 20 31]
TEMPORAL_LOBES_L 9
max size [ 22 107  58]
TEMPORAL_LOBES_R 10
max size [ 21 108  57]
PITUITARY 11
max size [ 3 12 16]
PAROTID_GLAND_L 12
max size [22 71 39]
PAROTID_GLAND_R 13
max size [24 76 42]
INNER_EAR_L 14
max size [ 5 23 23]
INNER_EAR_R 15
max size [ 5 23 26]
MID_EAR_L 16
max size [17 61 39]
MID_EAR_R 17
max size [15 61 40]
T_M_JOINT_L 18
max size [ 8 21 30]
T_M_JOINT_R 19
max size [ 7 21 31]
SPINAL_CORD 20
max size [ 89 104  30]
MANDIBLE_L 21
max size [31 99 72]
MANDIBLE_R 22
max size [ 33 102  70]


In [5]:
# Labels merged into the single foreground mask (here: every organ label).
filter_labels = OARS_LABELS.OARS_LABELS_LIST

tmp_label_list = []
tmp_label_list_box = []

for scan_index in range(size):
    scan = label_list[scan_index]

    # OR the per-organ binary masks together; 0 is the neutral starting value.
    merged = 0
    for organ_id in filter_labels:
        merged = merged | (scan == organ_id)

    merged_array = sitk.GetArrayFromImage(merged)
    tmp_label_list.append(merged_array)
    tmp_label_list_box.append(get_bounding_box_3D_size(*get_bounding_box_3D(merged_array)))

print('All organs in single mask')
# Element-wise maximum of the bounding-box sizes over all scans.
print('max size', np.array(tmp_label_list_box).max(axis=0))

All organs in single mask
max size [119 212 156]


In [6]:
# Every organ label except the spinal cord.
# A new list is built on purpose: OARS_LABELS.OARS_LABELS_LIST is the module's own
# list object, so calling .remove() on it would drop the spinal cord for every other
# cell that reads the list afterwards (e.g. a re-run of the "all organs" cell) and
# would make the notebook's results depend on execution order.
filter_labels = [
    organ_id
    for organ_id in OARS_LABELS.OARS_LABELS_LIST
    if organ_id != OARS_LABELS.SPINAL_CORD
]

tmp_label_list = [None] * size
tmp_label_list_box = [None] * size

for i in range(size):
    # OR the per-organ binary masks of scan i together (0 is the neutral seed for `|`).
    tmp_label_list[i] = reduce(lambda a, b: a | (label_list[i] == b), filter_labels, 0)
    tmp_label_list[i] = sitk.GetArrayFromImage(tmp_label_list[i])
    tmp_label_list_box[i] = get_bounding_box_3D_size(*get_bounding_box_3D(tmp_label_list[i]))

print('All organs without spinal cord in single mask')
# Element-wise maximum of the bounding-box sizes over all scans.
print('max size', np.array(tmp_label_list_box).max(axis=0))

All organs without spinal cord in single mask
max size [ 56 177 156]


## Bounding box preview

In [107]:
data_index = 2
# NOTE: `tmp_label_list` is whatever the most recently executed mask cell left behind,
# so the mask previewed here depends on which of those cells ran last.
label = tmp_label_list[data_index]
data = sitk.GetArrayFromImage(sitk.ReadImage(str(Path.joinpath(root_dir_path, f'./{data_index + 1}/data.nii.gz'))))
# Replicate the volume into 3 channels (last axis), shift so the minimum is 0,
# then divide by the maximum so the values lie in [0, 1].
data = np.stack((data,)*3, axis=-1) + (-1 * data.min())
data = data / data.max()
# Highlight the mask voxels by saturating the first two colour channels.
idx = (label == 1)
data[idx, 0] = 1
data[idx, 1] = 1

# As used below, box[0:2] bound the slice axis, box[2:4] the row axis and
# box[4:6] the column axis — TODO confirm against get_bounding_box_3D.
box = get_bounding_box_3D(label)

print(f'label max {label.max()}, min {label.min()}')
print(f'box {box}')
print('box size', *get_bounding_box_3D_size(*box))


def f(slice_index):
    """Show slice `slice_index` of the preview volume with the bounding box drawn on it.

    The rectangle is only drawn when the slice lies within the box's range
    along the slice axis (box[0]..box[1], both ends included).
    """
    plt.figure(figsize=(5, 5))

    # Draw on a copy so the rectangle does not accumulate in `data`.
    tmp = data[slice_index].copy()
    if box[0] <= slice_index <= box[1]:
        # OpenCV points are (x, y) = (column, row).
        tmp = cv.rectangle(tmp, (box[4], box[2]), (box[5], box[3]), (1, 0, 0), 1)
    plt.imshow(tmp)
    plt.show()


# Bound the slider by the number of slices actually loaded, so data[slice_index]
# can never go out of range, and clamp the initial position accordingly.
last_slice_index = data.shape[0] - 1
sliceSlider = widgets.IntSlider(min=0, max=last_slice_index, step=1, value=min(91, last_slice_index))
ui = widgets.VBox([widgets.HBox([sliceSlider])])
out = widgets.interactive_output(f, {'slice_index': sliceSlider})
display(ui, out)

label max 1, min 0
box (55, 105, 156, 314, 183, 333)
box size 50 158 150


VBox(children=(HBox(children=(IntSlider(value=91, max=159),)),))

Output()