In [1]:
IN_COLAB = 'google.colab' in str(get_ipython())
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    import sys
    sys.path.append('/content/drive/My Drive/dp_tomastik/code')
    !bash "/content/drive/My Drive/dp_tomastik/code/scripts/install_libs.sh"

import matplotlib.pyplot as plt
import torch
import os
import numpy as np
import pandas as pd
import logging
import datetime
from torchio import RandomAffine, Compose, ZNormalization

import src.dataset.oars_labels_consts as OARS_LABELS
from src.consts import DATASET_MAX_BOUNDING_BOX, DESIRE_BOUNDING_BOX_SIZE
from src.helpers.threshold_calc_helpers import get_threshold_info_df
from src.helpers.show_model_dataset_pred_preview import show_model_dataset_pred_preview
from src.dataset.get_cut_lists import get_cut_lists
from src.dataset.get_full_res_cut import get_full_res_cut
from src.dataset.get_dataset import get_dataset
from src.dataset.get_dataset_info import get_dataset_info
from src.dataset.preview_dataset import preview_dataset
from src.dataset.get_dataset_transform import get_dataset_transform
from src.model_and_training.prepare_model import prepare_model
from src.model_and_training.train_loop import train_loop
from src.model_and_training.show_model_info import show_model_info
from src.model_and_training.load_checkpoint_model_info import load_checkpoint_model_info
from src.helpers.show_cuda_usage import show_cuda_usage
from src.helpers.get_rescaled_pred import get_rescaled_preds
from src.dataset.split_dataset import split_dataset, copy_split_dataset
from src.helpers.compare_prediction_with_ground_true import compare_prediction_with_ground_true, compare_one_prediction_with_ground_true
from src.helpers.get_img_outliers_pixels import get_img_outliers_pixels
from src.helpers.get_raw_with_prediction import get_raw_with_prediction
from src.model_and_training.getters.get_device import get_device


from operator import itemgetter
from IPython.display import display, Markdown
from ipywidgets import widgets

torch.manual_seed(20)
logging.basicConfig(filename='logs/pdd_data_check.log', level=logging.DEBUG)

print('Dataset biggest bounding box wihtout spinal cord', DATASET_MAX_BOUNDING_BOX)
print('Cut target size', DESIRE_BOUNDING_BOX_SIZE)
print('Done Init')

If you use TorchIO for your research, please cite the following paper:
Pérez-García et al., TorchIO: a Python library for efficient loading,
preprocessing, augmentation and patch-based sampling of medical images
in deep learning. Credits instructions: https://torchio.readthedocs.io/#credits

Dataset biggest bounding box wihtout spinal cord [56, 177, 156]
Cut target size [72, 192, 168]
Done Init


# PDDCA

In [2]:
def preview_3d_image(img):
    if type(img) is sitk.SimpleITK.Image:
        img = sitk.GetArrayFromImage(img)

    max_slices = img.shape[0]
    def f(slice_index):    
        plt.figure(figsize=(16, 16))
        plt.imshow(img[slice_index])
        plt.show()
        print(f"debug: {img.min()}, {img.max()}")
        print(f"debug: unique {np.unique(img[slice_index])}")

    sliceSlider = widgets.IntSlider(min=0, max=max_slices - 1, step=1, value=(max_slices - 1) / 2)
    ui = widgets.VBox([widgets.HBox([sliceSlider])])
    out = widgets.interactive_output(f, {'slice_index': sliceSlider})
    # noinspection PyTypeChecker
    display(ui, out)

In [3]:
from pathlib import Path
import SimpleITK as sitk
import nrrd
import os
import cv2
import sys


# PDDCA
d ="./data/PDDCA-1.4.1"
pddca_dir_items = sorted([o for o in os.listdir(d) if os.path.isdir(os.path.join(d,o))])

ignore_items = ['0522c0014', '0522c0077', '0522c0079', '0522c0147', '0522c0159', '0522c0161', '0522c0190', '0522c0226', 
                '0522c0329', '0522c0330', '0522c0427', '0522c0433', '0522c0441', '0522c0455', '0522c0457', '0522c0479']
print(f'Loading {len(pddca_dir_items) - len(ignore_items)} items')

pddca_items = list() 
for item_id in pddca_dir_items:
    if item_id in ignore_items:
        # print(f"pddca {item_id}: ignoring")
        continue
    
    # parsing data
    data_filepath = Path.joinpath(Path(d), f'./{item_id}/img.nrrd')
    pddca_data, header = nrrd.read(data_filepath)
    pddca_data = pddca_data.astype(np.int16)
    pddca_data = np.transpose(pddca_data, axes=[2, 0, 1]).swapaxes(-2,-1)[...,::-1]
    
    # parsing labels
    oar_labels = ["BrainStem", "Chiasm", "Mandible", "OpticNerve_L", "OpticNerve_R", "Parotid_L", "Parotid_R", "Submandibular_L", "Submandibular_R"]
    pddca_label = np.zeros(pddca_data.shape, dtype=np.int8)

    for OAR_INDEX, OAR_KEY in enumerate(oar_labels):
        label_filepath = Path.joinpath(Path(d), f'./{item_id}/structures/{OAR_KEY}.nrrd')
        oar_pddca_label, header = nrrd.read(label_filepath)
        oar_pddca_label = oar_pddca_label.astype(np.int8)
        oar_pddca_label = np.transpose(oar_pddca_label, axes=[2, 0, 1]).swapaxes(-2,-1)[...,::-1]
        pddca_label += oar_pddca_label*(OAR_INDEX+1)
    
    # appending
    pddca_items.append((pddca_data, pddca_label))
    print(f"pddca {item_id}: {pddca_data.max()}, {pddca_data.min()}, {pddca_label.max()}, {pddca_label.min()}, {pddca_data.dtype}, {pddca_label.dtype}, {pddca_data.shape}, {pddca_label.shape}")

print('Done loading')

Loading 32 items
pddca 0522c0001: 3051, -1024, 16, 0, int16, int8, (107, 512, 512), (107, 512, 512)
pddca 0522c0002: 3059, -1024, 9, 0, int16, int8, (130, 512, 512), (130, 512, 512)
pddca 0522c0003: 3059, -1024, 9, 0, int16, int8, (134, 512, 512), (134, 512, 512)
pddca 0522c0009: 3051, -1024, 9, 0, int16, int8, (144, 512, 512), (144, 512, 512)
pddca 0522c0013: 3059, -1024, 9, 0, int16, int8, (138, 512, 512), (138, 512, 512)
pddca 0522c0017: 2117, -1024, 10, 0, int16, int8, (156, 512, 512), (156, 512, 512)
pddca 0522c0057: 2955, -1024, 9, 0, int16, int8, (145, 512, 512), (145, 512, 512)
pddca 0522c0070: 3051, -1024, 9, 0, int16, int8, (128, 512, 512), (128, 512, 512)
pddca 0522c0081: 3075, -1024, 9, 0, int16, int8, (160, 512, 512), (160, 512, 512)
pddca 0522c0125: 3075, -1024, 16, 0, int16, int8, (76, 512, 512), (76, 512, 512)
pddca 0522c0132: 3051, -1024, 9, 0, int16, int8, (115, 512, 512), (115, 512, 512)
pddca 0522c0149: 2980, -1024, 9, 0, int16, int8, (129, 512, 512), (129, 512, 512

In [4]:
item_index = 1
pddca_data, pddca_label = pddca_items[item_index]

max_slices = pddca_data.shape[0]
def f(slice_index):    
    plt.figure(figsize=(20, 20))
    plt.subplot(2, 2, 1)
    plt.imshow(pddca_data[slice_index], cmap="gray")
    plt.subplot(2, 2, 2)
    plt.imshow(pddca_label[slice_index])
    plt.subplot(2, 2, 3)
    
    tmp_combine = np.stack((pddca_data[slice_index],) * 3, axis=-1)
    tmp_combine -= tmp_combine.min()
    tmp_combine = tmp_combine / tmp_combine.max()    
    tmp = (pddca_label[slice_index] > 1) * 1
    tmp_cond = tmp > 0
    tmp_combine[tmp_cond, 0] = tmp[tmp_cond]
    
    plt.imshow(tmp_combine)
    plt.show()
    print(f"debug: {pddca_data.min()}, {pddca_data.max()}")
    print(f"debug: {tmp_combine.min()}, {tmp_combine.max()}")
    print(f"debug: unique {np.unique(pddca_label[slice_index])}")

sliceSlider = widgets.IntSlider(min=0, max=max_slices - 1, step=1, value=(max_slices - 1) / 2)
ui = widgets.VBox([widgets.HBox([sliceSlider])])
out = widgets.interactive_output(f, {'slice_index': sliceSlider})
# noinspection PyTypeChecker
display(ui, out)

VBox(children=(HBox(children=(IntSlider(value=64, max=129),)),))

Output()

# STRUCT SEG 2019

In [5]:
filter_labels = OARS_LABELS.OARS_LABELS_LIST
if OARS_LABELS.SPINAL_CORD in filter_labels:
    filter_labels.remove(OARS_LABELS.SPINAL_CORD)

full_res_dataset = get_dataset(dataset_size=50, shrink_factor=1, filter_labels=filter_labels, unify_labels=False)
full_res_dataset.to_numpy()

CUDA using 1x dataset
filtering labels
filtering labels done
parsing dataset to numpy
numpy parsing done


<src.dataset.han_oars_dataset.HaNOarsDataset at 0x7fe21d16bc70>

# Registration

In [6]:
def get_registration_transform_sitk(fixed, moving, show=True):
    """https://simpleitk.readthedocs.io/en/master/link_ImageRegistrationMethod3_docs.html"""
    def command_iteration(method):
        if (method.GetOptimizerIteration() == 0):
            print("Estimated Scales: ", method.GetOptimizerScales())
        print("{0:3} = {1:7.5f} : {2}".format(method.GetOptimizerIteration(),
                                              method.GetMetricValue(),
                                              method.GetOptimizerPosition()))

    R = sitk.ImageRegistrationMethod()
    R.SetMetricAsCorrelation()
    R.SetOptimizerAsRegularStepGradientDescent(learningRate=2.0,
                                               minStep=1e-4,
                                               numberOfIterations=500,
                                               gradientMagnitudeTolerance=1e-6)
    R.SetOptimizerScalesFromIndexShift()
    tx = sitk.CenteredTransformInitializer(fixed, moving, sitk.Similarity3DTransform())
    R.SetInitialTransform(tx)
    R.SetInterpolator(sitk.sitkLinear)
    R.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(R))

    output_transform = R.Execute(fixed, moving)

    print("-------")
    print(output_transform)
    print("Optimizer stop condition: {0}".format(R.GetOptimizerStopConditionDescription()))
    print(" Iteration: {0}".format(R.GetOptimizerIteration()))
    print(" Metric value: {0}".format(R.GetMetricValue()))

    if show:
        resampler = sitk.ResampleImageFilter()
        resampler.SetReferenceImage(fixed)
        resampler.SetInterpolator(sitk.sitkLinear)
        resampler.SetDefaultPixelValue(1)
        resampler.SetTransform(output_transform)

        out = resampler.Execute(moving)

        simg1 = sitk.Cast(sitk.RescaleIntensity(fixed), sitk.sitkUInt8)
        simg2 = sitk.Cast(sitk.RescaleIntensity(out), sitk.sitkUInt8)
        cimg = sitk.Compose(simg1, simg2, simg1 // 2. + simg2 // 2.)
        preview_3d_image(cimg)

    return output_transform

In [7]:
fixed_data, fixed_label = pddca_items[0]
moving_data, moving_label = full_res_dataset.get_raw_item_with_label_filter(0)

fixed_data = fixed_data.astype(np.float32)
moving_data = moving_data.astype(np.float32)[0]
print(fixed_data.dtype, moving_data.dtype, fixed_data.shape, moving_data.shape)

fixed_sitk = sitk.GetImageFromArray(fixed_data)
moving_sitk = sitk.GetImageFromArray(moving_data)

output_transform = get_registration_transform_sitk(fixed_sitk, moving_sitk)

float32 float32 (107, 512, 512) (160, 512, 512)
Estimated Scales:  (69930.39030288723, 68534.22609481464, 132764.98430403895, 0.999999999998181, 0.999999999998181, 1.0000000000010232, 135615.9304798581)
  0 = -0.24074 : (-9.03999541615783e-05, -0.0006956644143428822, 1.8119611505115703e-05, -0.30084559640486885, -2.5238117596468914, 27.543387185992074, 1.0000102782554676)
  1 = -0.26220 : (-4.734718303809073e-05, -0.0013934479569414964, 1.303580288686701e-05, -0.1724647745955242, -1.5527427578484887, 25.799672357628484, 1.00003384788021)
  2 = -0.28217 : (0.00015230925029894636, -0.002061645242968341, -2.6736294277833434e-05, -0.060547843633535256, -0.35929056635263934, 24.198690002692118, 1.0001547867202332)
  3 = -0.30296 : (0.0005230718264163379, -0.0026729484163188397, -0.00010278069967748866, 0.032421779372570556, 1.0753295267904637, 22.808287294940605, 1.0004377329345466)
  4 = -0.32377 : (0.0010243264180359148, -0.0032354094372611802, -0.00020411087734637675, 0.11489799178957057

VBox(children=(HBox(children=(IntSlider(value=53, max=106),)),))

Output()

In [8]:
def transform(image, transform):
    ref_image = image
    interpolator = sitk.sitkNearestNeighbor
    default_value = 0
    return sitk.Resample(image, ref_image, transform, interpolator, default_value)

fixed_label_sitk = sitk.GetImageFromArray(fixed_label)
trans_fixed_label = transform(fixed_label_sitk, output_transform.GetInverse())
trans_fixed_label_np = sitk.GetArrayFromImage(trans_fixed_label)

In [9]:
max_slices = trans_fixed_label_np.shape[0]

def f(slice_index):
    plt.figure(figsize=(20, 20))
    plt.subplot(2, 2, 1).title.set_text("Transformed label from atlas")
    plt.imshow(trans_fixed_label_np[slice_index])
    plt.subplot(2, 2, 2).title.set_text("Dataset label")
    plt.imshow(moving_label[0, slice_index])
    plt.show()

sliceSlider = widgets.IntSlider(min=0, max=max_slices - 1, step=1, value=(max_slices - 1) / 2)
ui = widgets.VBox([widgets.HBox([sliceSlider])])
out = widgets.interactive_output(f, {'slice_index': sliceSlider})
# noinspection PyTypeChecker
display(ui, out)

VBox(children=(HBox(children=(IntSlider(value=53, max=106),)),))

Output()

# Atlas loading

In [10]:
atlas_ri = sitk.GetArrayFromImage(sitk.ReadImage('./data/PDDCA-1.2-atlas/probabilistic_atlas/RI.mhd'))
atlas_brainstem_map = sitk.GetArrayFromImage(sitk.ReadImage('./data/PDDCA-1.2-atlas/probabilistic_atlas/brain_stem_map.mhd'))
atlas_left_parotid_map = sitk.GetArrayFromImage(sitk.ReadImage('./data/PDDCA-1.2-atlas/probabilistic_atlas/left_parotid_map.mhd'))
atlas_right_parotid_map = sitk.GetArrayFromImage(sitk.ReadImage('./data/PDDCA-1.2-atlas/probabilistic_atlas/right_parotid_map.mhd'))

atlas.shape, atlas_brainstem_map.shape, atlas_left_parotid_map.shape, atlas_right_parotid_map.shape

NameError: name 'atlas' is not defined

In [None]:
max_slices = atlas.shape[0]
def f(slice_index):
    plt.figure(figsize=(30, 16))
    plt.subplot(2, 2, 1)
    plt.imshow(atlas[slice_index], cmap="gray")
    plt.subplot(2, 2, 2)
    plt.imshow(atlas_brainstem_map[slice_index], cmap="gray")
    plt.subplot(2, 2, 3)
    plt.imshow(atlas_left_parotid_map[slice_index], cmap="gray")
    plt.subplot(2, 2, 4)
    plt.imshow(atlas_right_parotid_map[slice_index], cmap="gray")
    plt.show()

sliceSlider = widgets.IntSlider(min=0, max=max_slices - 1, step=1, value=(max_slices - 1) / 2)
ui = widgets.VBox([widgets.HBox([sliceSlider])])
out = widgets.interactive_output(f, {'slice_index': sliceSlider})
# noinspection PyTypeChecker
display(ui, out)