In [4]:
# Detect Google Colab by checking whether the IPython shell's string form mentions 'google.colab'.
IN_COLAB = 'google.colab' in str(get_ipython())
if IN_COLAB:
    # Colab only: mount Google Drive, put the project's code directory on sys.path,
    # and run scripts/install_libs.sh from that Drive copy of the code.
    from google.colab import drive
    drive.mount('/content/drive')
    import sys
    sys.path.append('/content/drive/My Drive/dp_tomastik/code')
    !bash "/content/drive/My Drive/dp_tomastik/code/scripts/install_libs.sh"

import matplotlib.pyplot as plt
import torch
import os
import numpy as np
import pandas as pd
import logging
import datetime

from pathlib import Path
import SimpleITK as sitk
import nrrd
import os
import cv2
import sys

from torchio import RandomAffine, Compose, ZNormalization



from src.dataset import get_cut_lists
from src.dataset import get_full_res_cut
from src.dataset import get_dataset
from src.dataset import get_dataset_info
from src.dataset import get_dataset_transform
from src.dataset import split_dataset, copy_split_dataset

from src.model_and_training import prepare_model
from src.model_and_training import train_loop
from src.model_and_training import show_model_info
from src.model_and_training import load_checkpoint_model_info
from src.model_and_training.getters import get_device

from src.helpers import preview_dataset
from src.helpers import get_threshold_info_df
from src.helpers import preview_model_dataset_pred
from src.helpers import show_cuda_usage
from src.helpers import get_rescaled_preds
from src.helpers import compare_prediction_with_ground_true, compare_one_prediction_with_ground_true
from src.helpers import get_transformed_label_np, create_regis_trans_list, trans_list
from src.helpers import get_img_outliers_pixels
from src.helpers import get_raw_with_prediction

from src.consts import DATASET_MAX_BOUNDING_BOX, DESIRE_BOUNDING_BOX_SIZE


from operator import itemgetter
from IPython.display import display, Markdown
from ipywidgets import widgets

# Reproducibility: seed the RNGs this notebook can draw from.
# torch is seeded for torch-based code (torchio transforms are presumably drawing from it);
# numpy is seeded as well so any np.random usage is repeatable too.
SEED = 20
torch.manual_seed(SEED)
np.random.seed(SEED)

# logging.basicConfig(filename=...) opens the file immediately and raises FileNotFoundError
# when the parent directory does not exist, so create it first.
LOG_PATH = 'logs/pdd_data_check.log'
os.makedirs(os.path.dirname(LOG_PATH), exist_ok=True)
logging.basicConfig(filename=LOG_PATH, level=logging.DEBUG)

print('Dataset biggest bounding box wihtout spinal cord', DATASET_MAX_BOUNDING_BOX)
print('Cut target size', DESIRE_BOUNDING_BOX_SIZE)
print('Done Init')

Dataset biggest bounding box wihtout spinal cord [56, 177, 156]
Cut target size [72, 192, 168]
Done Init


In [5]:
# Toggles for the registration section further down.
TRAIN_REGISTRATION = True      # compute the atlas -> dataset registration transform
TRANSFORM_REGISTRATION = True  # apply that transform to the atlas label
DISPLAY_REGISTRATION = True    # show the transformed label next to the dataset label

# PDDCA

In [None]:
def preview_3d_image(img):
    """Display an interactive slice viewer for a 3-D volume.

    Args:
        img: a SimpleITK image (converted with ``sitk.GetArrayFromImage``) or an
            array-like volume indexed as ``img[slice_index]`` along axis 0.

    Raises:
        ValueError: if the volume has no slices along axis 0 (a slider with
            ``max=-1`` cannot be built).

    A slider selects the slice along axis 0. Each redraw also prints the
    volume-wide min/max and the unique values of the displayed slice.
    """
    if isinstance(img, sitk.Image):
        img = sitk.GetArrayFromImage(img)

    max_slices = img.shape[0]
    if max_slices == 0:
        raise ValueError("preview_3d_image: volume has no slices along axis 0")

    # The volume-wide range does not depend on the slider, so compute it once.
    vol_min, vol_max = img.min(), img.max()

    def show_slice(slice_index):
        plt.figure(figsize=(16, 16))
        plt.imshow(img[slice_index])
        plt.title(f"slice {slice_index}")
        plt.show()
        print(f"debug: {vol_min}, {vol_max}")
        print(f"debug: unique {np.unique(img[slice_index])}")

    # IntSlider takes an int; floor division keeps the initial value integral.
    slice_slider = widgets.IntSlider(min=0, max=max_slices - 1, step=1, value=(max_slices - 1) // 2)
    ui = widgets.VBox([widgets.HBox([slice_slider])])
    out = widgets.interactive_output(show_slice, {'slice_index': slice_slider})
    # noinspection PyTypeChecker
    display(ui, out)

In [None]:
from pathlib import Path
import SimpleITK as sitk
import nrrd
import os
import cv2
import sys

LOAD_PDDCA = True

# Organs-at-risk in label order: PDDCA_OAR_LABELS[i] is written as label value i + 1 (0 = background).
PDDCA_OAR_LABELS = ["BrainStem", "Chiasm", "Mandible", "OpticNerve_L", "OpticNerve_R",
                    "Parotid_L", "Parotid_R", "Submandibular_L", "Submandibular_R"]


def _pddca_to_viewer_axes(volume):
    """Reorder a volume read from a PDDCA .nrrd file into the axis order used in this notebook.

    Reverses the axis order, (a, b, c) -> (c, b, a), and flips the resulting last axis.
    Images and label masks go through the same transform so they stay aligned.
    """
    return np.transpose(volume, axes=[2, 0, 1]).swapaxes(-2, -1)[..., ::-1]


def load_pddca_item(item_dir):
    """Load one PDDCA case directory as ``(data, label)``.

    Args:
        item_dir: case directory containing ``img.nrrd`` and ``structures/<OAR>.nrrd``
            for every name in ``PDDCA_OAR_LABELS``.

    Returns:
        data: image volume cast to int16.
        label: int8 volume, 0 for background and ``i + 1`` where ``PDDCA_OAR_LABELS[i]`` is set.

    Raises:
        Whatever ``nrrd.read`` raises when a file is missing or unreadable.
    """
    item_dir = Path(item_dir)
    data, _ = nrrd.read(str(item_dir / 'img.nrrd'))
    data = _pddca_to_viewer_axes(data.astype(np.int16))

    label = np.zeros(data.shape, dtype=np.int8)
    for oar_index, oar_key in enumerate(PDDCA_OAR_LABELS):
        oar_mask, _ = nrrd.read(str(item_dir / 'structures' / f'{oar_key}.nrrd'))
        oar_mask = _pddca_to_viewer_axes(oar_mask)
        # Assign rather than add mask * (index + 1): summing gave overlapping voxels values
        # that collide with other OAR ids (e.g. 1 + 3 == 4). Later OARs win on overlap.
        label[oar_mask > 0] = oar_index + 1
    return data, label


if LOAD_PDDCA:
    # PDDCA
    d = "./data/PDDCA-1.4.1"
    pddca_dir_items = sorted(o for o in os.listdir(d) if os.path.isdir(os.path.join(d, o)))

    # NOTE(review): the reason these cases are excluded is not recorded here — document it.
    ignore_items = ['0522c0014', '0522c0077', '0522c0079', '0522c0147', '0522c0159', '0522c0161', '0522c0190', '0522c0226',
                    '0522c0329', '0522c0330', '0522c0427', '0522c0433', '0522c0441', '0522c0455', '0522c0457', '0522c0479']
    ignored = set(ignore_items)
    # Filter first so the reported count matches what is actually loaded, even when some
    # ignored ids are not present in the directory.
    items_to_load = [item_id for item_id in pddca_dir_items if item_id not in ignored]
    print(f'Loading {len(items_to_load)} items')

    pddca_items = []
    for item_id in items_to_load:
        pddca_data, pddca_label = load_pddca_item(Path(d) / item_id)
        pddca_items.append((pddca_data, pddca_label))
        print(f"pddca {item_id}: {pddca_data.max()}, {pddca_data.min()}, {pddca_label.max()}, {pddca_label.min()}, {pddca_data.dtype}, {pddca_label.dtype}, {pddca_data.shape}, {pddca_label.shape}")

    print('Done loading')

In [None]:
# Shape of `pddca_label`; right after the loading cell this is the label volume of the last case loaded.
np.shape(pddca_label)

In [None]:
if LOAD_PDDCA:
    item_index = 1
    # Dedicated names, so previewing does not overwrite pddca_data / pddca_label from the loading cell.
    preview_data, preview_label = pddca_items[item_index]

    max_slices = preview_data.shape[0]
    # Volume-wide range is slider-independent; compute once.
    preview_min, preview_max = preview_data.min(), preview_data.max()

    def show_pddca_slice(slice_index):
        """Plot the image, the label map and a red overlay of every labelled voxel for one slice."""
        image_slice = preview_data[slice_index]
        label_slice = preview_label[slice_index]

        plt.figure(figsize=(20, 20))
        plt.subplot(2, 2, 1).title.set_text("Image")
        plt.imshow(image_slice, cmap="gray")
        plt.subplot(2, 2, 2).title.set_text("Label")
        plt.imshow(label_slice)
        plt.subplot(2, 2, 3).title.set_text("Overlay (labelled voxels in red)")

        # Normalise to [0, 1] in float; guard against a constant slice (max 0 after shifting),
        # which previously divided by zero.
        rgb = np.stack((image_slice,) * 3, axis=-1).astype(np.float32)
        rgb -= rgb.min()
        rgb_max = rgb.max()
        if rgb_max > 0:
            rgb /= rgb_max
        # Labels are OAR index + 1 with 0 as background, so `> 0` selects every OAR;
        # the previous `> 1` silently left label 1 (BrainStem) out of the overlay.
        labelled = label_slice > 0
        rgb[labelled, 0] = 1.0

        plt.imshow(rgb)
        plt.show()
        print(f"debug: {preview_min}, {preview_max}")
        print(f"debug: {rgb.min()}, {rgb.max()}")
        print(f"debug: unique {np.unique(label_slice)}")

    # IntSlider takes an int; floor division keeps the initial value integral.
    slice_slider = widgets.IntSlider(min=0, max=max_slices - 1, step=1, value=(max_slices - 1) // 2)
    ui = widgets.VBox([widgets.HBox([slice_slider])])
    out = widgets.interactive_output(show_pddca_slice, {'slice_index': slice_slider})
    # noinspection PyTypeChecker
    display(ui, out)

# STRUCT SEG 2019

In [None]:
# NOTE(review): OARS_LABELS is not imported anywhere in this notebook; add its import to the
# imports cell (presumably from src.consts alongside DATASET_MAX_BOUNDING_BOX — confirm).
# Build a new list instead of calling .remove() on OARS_LABELS.OARS_LABELS_LIST: that mutated
# the shared constant in place, silently dropping SPINAL_CORD for every other user of it.
filter_labels = [oar for oar in OARS_LABELS.OARS_LABELS_LIST if oar != OARS_LABELS.SPINAL_CORD]

full_res_dataset = get_dataset(dataset_size=50, shrink_factor=4, filter_labels=filter_labels, unify_labels=False)
full_res_dataset.to_numpy()

# Registration

In [6]:
_ATLAS_DIR = './data/PDDCA-1.2-atlas/probabilistic_atlas'


def _read_atlas_volume(file_name):
    """Read one probabilistic-atlas .mhd file from _ATLAS_DIR and return it as a numpy array."""
    return sitk.GetArrayFromImage(sitk.ReadImage(f'{_ATLAS_DIR}/{file_name}'))


atlas_ri = _read_atlas_volume('RI.mhd')
atlas_brainstem_map = _read_atlas_volume('brain_stem_map.mhd')
atlas_left_parotid_map = _read_atlas_volume('left_parotid_map.mhd')
atlas_right_parotid_map = _read_atlas_volume('right_parotid_map.mhd')

tuple(volume.shape for volume in (atlas_ri, atlas_brainstem_map, atlas_left_parotid_map, atlas_right_parotid_map))

NameError: name 'sitk' is not defined

In [None]:
# First StructSeg dataset item and first PDDCA case, loaded to compare their array shapes.
moving_data, moving_label = full_res_dataset.get_raw_item_with_label_filter(0)
data, label = pddca_items[0]

tuple(arr.shape for arr in (moving_data, moving_label, data, label))

In [None]:
from src.helpers import get_registration_transform_rigid_sitk, get_registration_transform_non_rigid_sitk

In [None]:
# Crop applied to the atlas volumes before registration (values kept from the original cell).
ATLAS_CROP = np.s_[60:, :, 46:-45]

if TRAIN_REGISTRATION:
    # Fixed image: first dataset item. Moving image: cropped atlas intensity volume, with the
    # cropped brainstem map kept as the moving label.
    fixed_data, fixed_label = full_res_dataset.get_raw_item_with_label_filter(0)
    moving_data = atlas_ri[ATLAS_CROP]
    moving_label = atlas_brainstem_map[ATLAS_CROP]

    # Cast to float32 before handing to SimpleITK; [0] drops the leading axis of the dataset item.
    fixed_data = fixed_data.astype(np.float32)[0]
    moving_data = moving_data.astype(np.float32)
    print(fixed_data.dtype, moving_data.dtype, fixed_data.shape, moving_data.shape)

    fixed_data_sitk, moving_data_sitk = (sitk.GetImageFromArray(arr) for arr in (fixed_data, moving_data))

    # Non-rigid registration; get_registration_transform_rigid_sitk is the rigid alternative.
    output_transform = get_registration_transform_non_rigid_sitk(fixed_data_sitk, moving_data_sitk, show=True)

In [None]:
if TRANSFORM_REGISTRATION:
    # NOTE(review): transform_sitk is not imported anywhere in this notebook (only the two
    # get_registration_transform_* helpers are) — add its import before running this cell.
    moving_label_sitk = sitk.GetImageFromArray(moving_label)
    # Apply the registration transform to the atlas label; the fixed image is passed first,
    # presumably as the reference grid for resampling — confirm in transform_sitk.
    trans_fixed_label = transform_sitk(
        fixed_data_sitk,
        moving_label_sitk,
        output_transform,
    )
    trans_fixed_label_np = sitk.GetArrayFromImage(trans_fixed_label)

In [None]:
if DISPLAY_REGISTRATION:
    # The "Dataset label" panel must show the dataset's own label (fixed_label). `moving_label`
    # was reassigned in the registration cell to the cropped atlas brainstem map, so the
    # previous cell displayed the atlas twice and indexed it with fixed-image slice numbers.
    dataset_label_np = np.asarray(fixed_label)
    if dataset_label_np.ndim == trans_fixed_label_np.ndim + 1:
        # Mirror the `[0]` applied to fixed_data before registration.
        # NOTE(review): assumes the label carries the same leading axis as the data — confirm.
        dataset_label_np = dataset_label_np[0]

    max_slices = trans_fixed_label_np.shape[0]

    def show_registration_slice(slice_index):
        """Show the transformed atlas label next to the dataset label for one slice."""
        plt.figure(figsize=(20, 20))
        plt.subplot(2, 2, 1).title.set_text("Transformed label from atlas")
        plt.imshow(trans_fixed_label_np[slice_index])
        plt.subplot(2, 2, 2).title.set_text("Dataset label")
        plt.imshow(dataset_label_np[slice_index])
        plt.show()

    # IntSlider takes an int; floor division keeps the initial value integral.
    slice_slider = widgets.IntSlider(min=0, max=max_slices - 1, step=1, value=(max_slices - 1) // 2)
    ui = widgets.VBox([widgets.HBox([slice_slider])])
    out = widgets.interactive_output(show_registration_slice, {'slice_index': slice_slider})
    # noinspection PyTypeChecker
    display(ui, out)

# Atlas loading

In [None]:
# atlas_ri and the three probability maps are already loaded from the same four .mhd files in
# the Registration section above, which needs them first. Re-reading them here only duplicated
# that disk I/O, so reuse them and show their shapes as a sanity check.
atlas_ri.shape, atlas_brainstem_map.shape, atlas_left_parotid_map.shape, atlas_right_parotid_map.shape

In [None]:
# (title, volume) per subplot; all volumes are indexed with the same slice number.
atlas_panels = [
    ("RI", atlas_ri),
    ("Brainstem map", atlas_brainstem_map),
    ("Left parotid map", atlas_left_parotid_map),
    ("Right parotid map", atlas_right_parotid_map),
]
max_slices = atlas_ri.shape[0]


def show_atlas_slice(slice_index):
    """Plot one slice of the atlas intensity volume and of each probability map in a 2x2 grid."""
    plt.figure(figsize=(30, 16))
    for panel_index, (title, volume) in enumerate(atlas_panels, start=1):
        plt.subplot(2, 2, panel_index).title.set_text(title)
        plt.imshow(volume[slice_index], cmap="gray")
    plt.show()


# IntSlider takes an int; floor division keeps the initial value integral.
slice_slider = widgets.IntSlider(min=0, max=max_slices - 1, step=1, value=(max_slices - 1) // 2)
ui = widgets.VBox([widgets.HBox([slice_slider])])
out = widgets.interactive_output(show_atlas_slice, {'slice_index': slice_slider})
# noinspection PyTypeChecker
display(ui, out)

In [None]:
## TODO: register the atlas to the NN input; consider speeding this up because of data augmentation
## TODO: implement the CRNF architecture

In [None]:
# Compare the atlas volume shape with the cut target size (DESIRE_BOUNDING_BOX_SIZE).
(atlas_ri.shape, DESIRE_BOUNDING_BOX_SIZE)