In [1]:
# Detect whether the notebook runs in Google Colab. If so, mount Google Drive,
# make the project sources stored on Drive importable and run the project's
# install script (IPython shell escape, not plain Python).
IN_COLAB = 'google.colab' in str(get_ipython())
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    import sys
    # project code lives on the mounted Drive; needed for the `src.*` imports below
    sys.path.append('/content/drive/My Drive/dp_tomastik/code')
    # presumably installs the notebook's third-party dependencies - see install_libs.sh
    !bash "/content/drive/My Drive/dp_tomastik/code/scripts/install_libs.sh"

import matplotlib.pyplot as plt
import torch
import os
import numpy as np
import pandas as pd
import logging
import datetime
from torchio import RandomAffine, Compose, ZNormalization

import src.dataset.oars_labels_consts as OARS_LABELS
from src.consts import DATASET_MAX_BOUNDING_BOX, DESIRE_BOUNDING_BOX_SIZE
from src.helpers.threshold_calc_helpers import get_threshold_info_df
from src.helpers.show_model_dataset_pred_preview import show_model_dataset_pred_preview
from src.dataset.get_cut_lists import get_cut_lists
from src.dataset.get_full_res_cut import get_full_res_cut
from src.dataset.get_dataset import get_dataset
from src.dataset.get_dataset_info import get_dataset_info
from src.dataset.preview_dataset import preview_dataset
from src.dataset.get_dataset_transform import get_dataset_transform
from src.model_and_training.prepare_model import prepare_model
from src.model_and_training.train_loop import train_loop
from src.model_and_training.show_model_info import show_model_info
from src.model_and_training.load_checkpoint_model_info import load_checkpoint_model_info
from src.helpers.show_cuda_usage import show_cuda_usage
from src.helpers.get_rescaled_pred import get_rescaled_preds
from src.dataset.split_dataset import split_dataset, copy_split_dataset
from src.helpers.compare_prediction_with_ground_true import compare_prediction_with_ground_true, compare_one_prediction_with_ground_true
from src.helpers.get_img_outliers_pixels import get_img_outliers_pixels
from src.helpers.get_raw_with_prediction import get_raw_with_prediction
from src.model_and_training.getters.get_device import get_device


from operator import itemgetter
from IPython.display import display, Markdown
from ipywidgets import widgets

# Fixed seed so torch-based randomness is repeatable across runs.
torch.manual_seed(20)

# logging.basicConfig(filename=...) opens the file immediately and raises
# FileNotFoundError when its directory is missing (e.g. on a fresh Colab
# runtime), so create the log directory first.
LOG_PATH = 'logs/all_organs_jupyter.log'
os.makedirs(os.path.dirname(LOG_PATH), exist_ok=True)
logging.basicConfig(filename=LOG_PATH, level=logging.DEBUG)

print('Dataset biggest bounding box wihtout spinal cord', DATASET_MAX_BOUNDING_BOX)
print('Cut target size', DESIRE_BOUNDING_BOX_SIZE)
print('Done Init')

If you use TorchIO for your research, please cite the following paper:
Pérez-García et al., TorchIO: a Python library for efficient loading,
preprocessing, augmentation and patch-based sampling of medical images
in deep learning. Credits instructions: https://torchio.readthedocs.io/#credits

Dataset biggest bounding box wihtout spinal cord [56, 177, 156]
Cut target size [72, 192, 168]
Done Init


In [2]:
# Switches that enable/disable the optional sections of this notebook.
PARSE_CUT_DATASET = True     # build the multi-channel dataset from per-organ predictions
SHOW_DSC_INFO = True         # compute and tabulate DSC for every loaded organ model
PREVIEW_ORGAN_MODEL = False  # inspect a single organ model in detail
TRAIN_MODELS = False         # (re)train one U-Net per organ

In [3]:
def get_possible_models(oar_key, models_dir='./models'):
    """Return the saved model folders that belong to one organ.

    A folder matches when ``oar_key`` occurs anywhere in its name (training
    creates folders named ``{timestamp}_3d_unet_{OAR_KEY}``).

    The result is sorted. ``os.listdir`` makes no ordering guarantee, and
    callers pick ``[0]``, so an unsorted list would make that choice
    machine/filesystem dependent.

    Args:
        oar_key: organ key, e.g. ``'EYE_L'``.
        models_dir: directory containing the model folders.

    Returns:
        list[str]: matching folder names in ascending lexicographic order.
        With timestamp-prefixed names this is oldest first.
    """
    return sorted(name for name in os.listdir(models_dir) if oar_key in name)

# Loading the low-resolution precourse (coarse localization) network and the datasets

In [4]:
datasets_params = ['train_dataset', 'valid_dataset', 'test_dataset']
# All organ labels except the spinal cord. Built as a new list: calling
# .remove() on OARS_LABELS.OARS_LABELS_LIST would mutate the imported module
# constant for the rest of the kernel session.
filter_labels = [label for label in OARS_LABELS.OARS_LABELS_LIST if label != OARS_LABELS.SPINAL_CORD]

# low res: shrink_factor=16, all organs unified into one label, labels dilated once
low_res_dataset = get_dataset(dataset_size=50, shrink_factor=16, filter_labels=filter_labels, unify_labels=True)
low_res_dataset.dilatate_labels(repeat=1)
low_res_dataset.to_numpy()
low_res_split_dataset_obj = split_dataset(low_res_dataset, train_size=40, valid_size=5, test_size=5)
train_low_res_dataset, valid_low_res_dataset, test_low_res_dataset = itemgetter(*datasets_params)(low_res_split_dataset_obj)

# full res: per-organ labels kept; split presumably mirrors the low res split (copy_split_dataset)
full_res_dataset = get_dataset(dataset_size=50, shrink_factor=1, filter_labels=filter_labels, unify_labels=False)
full_res_dataset.to_numpy()
full_res_split_dataset_obj = copy_split_dataset(full_res_dataset, low_res_split_dataset_obj)

# low res model - precourse model (fixed checkpoint identified by its training timestamp)
epoch = 500
log_date = datetime.datetime(year=2020, month=10, day=27, hour=11, minute=45, second=30).strftime("%Y%m%d-%H%M%S")
model_name = f'{log_date}_3d_unet_PRECOURSE'

low_res_model_info = load_checkpoint_model_info(model_name, epoch, train_low_res_dataset, valid_low_res_dataset, test_low_res_dataset)
show_model_info(low_res_model_info)

# moving low res to gpu (set the device to 'cpu' instead to force CPU inference)
low_res_model_info['device'] = get_device()
low_res_model_info['model'] = low_res_model_info['model'].to(low_res_model_info['device'])
low_res_model_info['model'].eval()

# cut res: use the low res model's masks to cut the full res volumes
cut_full_res_dataset = full_res_dataset.copy(copy_lists=False)
cut_full_res_dataset = get_cut_lists(low_res_model_info['model'],
                                     low_res_model_info['device'],
                                     low_res_dataset,
                                     full_res_dataset,
                                     cut_full_res_dataset,
                                     low_res_mask_threshold=0.5)
cut_full_res_dataset.set_output_label(None)
cut_split_dataset_obj = copy_split_dataset(cut_full_res_dataset, low_res_split_dataset_obj)
cut_train_dataset, cut_valid_dataset, cut_test_dataset = itemgetter(*datasets_params)(cut_split_dataset_obj)

# moving low res model to cpu to free the GPU for the organ models
low_res_model_info['device'] = 'cpu'
low_res_model_info['model'] = low_res_model_info['model'].to(low_res_model_info['device'])

CUDA using 16x dataset
filtering labels
filtering labels done
dilatating 1x dataset
parsing dataset to numpy
numpy parsing done
CUDA using 1x dataset
filtering labels
filtering labels done
parsing dataset to numpy
numpy parsing done




Model number of params: 298881, trainable 298881
get_cut_lists: Cutting index 0
get_full_res_cut: Removing 10/1335 outlier pixels
get_final_bounding_box_slice: box delta [21 48 24]
get_full_res_cut: Does cut and original label contain the same amount of pixels? True 1223526 1223526
get_cut_lists: Cutting index 1
get_full_res_cut: Removing 0/1416 outlier pixels
get_final_bounding_box_slice: box delta [24 16  8]
get_full_res_cut: Does cut and original label contain the same amount of pixels? True 1326052 1326052
get_cut_lists: Cutting index 2
get_full_res_cut: Removing 0/1873 outlier pixels
get_final_bounding_box_slice: box delta [ 20   0 -24]
get_full_res_cut: Does cut and original label contain the same amount of pixels? True 1890464 1890464
get_cut_lists: Cutting index 3
get_full_res_cut: Removing 0/1545 outlier pixels
get_final_bounding_box_slice: box delta [17 32  8]
get_full_res_cut: Does cut and original label contain the same amount of pixels? True 1560217 1560217
get_cut_lists: 

In [5]:
# Show split sizes and which dataset indices went to train / valid / test.
get_dataset_info(low_res_dataset, low_res_split_dataset_obj)

train 40, valid_size 5, test 5, full 50
train indices [0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 14, 15, 17, 18, 20, 21, 22, 23, 24, 28, 30, 31, 32, 33, 34, 35, 36, 37, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
valid indices [6, 13, 19, 25, 38]
test indices [16, 26, 27, 29, 39]


In [6]:
# Interactive slice viewer for the cut full-resolution dataset. The output label
# was reset with set_output_label(None), so the label presumably still holds
# every organ's value (the printed label max is 22).
preview_dataset(cut_full_res_dataset)

data max 3071, min -1024
label max 22, min 0


VBox(children=(HBox(children=(IntSlider(value=35, max=71), IntSlider(value=0, max=0))),))

Output()

# Training one model per organ

In [7]:
# All organ labels except SPINAL_CORD. This is a filtered copy, so the
# module-level OARS_LABELS.OARS_LABELS_DICT is not modified by this cell.
filter_labels = {key: value for key, value in OARS_LABELS.OARS_LABELS_DICT.items() if key != 'SPINAL_CORD'}

# List every label key/value a model is trained for. set_output_label runs for
# each one, so afterwards the dataset is left on the last listed label.
for OAR_KEY, OAR_VALUE in filter_labels.items():
    cut_full_res_dataset.set_output_label(OAR_VALUE)
    print(f'dataset label \'{OAR_KEY}\', \t value \'{OAR_VALUE}\'')

dataset label 'BRAIN_STEM', 	 value '1'
dataset label 'EYE_L', 	 value '2'
dataset label 'EYE_R', 	 value '3'
dataset label 'LENS_L', 	 value '4'
dataset label 'LENS_R', 	 value '5'
dataset label 'OPT_NERVE_L', 	 value '6'
dataset label 'OPT_NERVE_R', 	 value '7'
dataset label 'OPT_CHIASMA', 	 value '8'
dataset label 'TEMPORAL_LOBES_L', 	 value '9'
dataset label 'TEMPORAL_LOBES_R', 	 value '10'
dataset label 'PITUITARY', 	 value '11'
dataset label 'PAROTID_GLAND_L', 	 value '12'
dataset label 'PAROTID_GLAND_R', 	 value '13'
dataset label 'INNER_EAR_L', 	 value '14'
dataset label 'INNER_EAR_R', 	 value '15'
dataset label 'MID_EAR_L', 	 value '16'
dataset label 'MID_EAR_R', 	 value '17'
dataset label 'T_M_JOINT_L', 	 value '18'
dataset label 'T_M_JOINT_R', 	 value '19'
dataset label 'MANDIBLE_L', 	 value '21'
dataset label 'MANDIBLE_R', 	 value '22'


In [8]:
if TRAIN_MODELS:
    # Filtered copy of the organ labels (no SPINAL_CORD). Building a new dict
    # avoids deleting the key from the module-level OARS_LABELS.OARS_LABELS_DICT.
    filter_labels = {key: value for key, value in OARS_LABELS.OARS_LABELS_DICT.items() if key != 'SPINAL_CORD'}

    # Train one single-output U-Net per organ on the cut full-resolution dataset.
    for OAR_KEY, OAR_VALUE in filter_labels.items():
        # select which organ label the dataset returns for this model
        cut_full_res_dataset.set_output_label(OAR_VALUE)
        log_date = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        model_name = f'{log_date}_3d_unet_{OAR_KEY}'

        print(f'Training model with dataset label \'{OAR_KEY}\', value \'{OAR_VALUE}\'')
        print(f'folder \'{model_name}\'')
        cut_model_info = prepare_model(epochs=75,
                                       learning_rate=3e-4,
                                       in_channels=8,
                                       input_data_channels=1,
                                       output_label_channels=1,
                                       dropout_rate=0.2,
                                       train_batch_size=2,
                                       model_name=model_name,
                                       train_dataset=cut_train_dataset,
                                       valid_dataset=cut_valid_dataset,
                                       test_dataset=cut_test_dataset)
        show_model_info(cut_model_info)
        print('\n\n')
        train_loop(cut_model_info)
        print('\n\n')

        # clearing memory before the next organ's model is built
        del cut_model_info
        torch.cuda.empty_cache()

# Previewing a single organ model

In [9]:
if PREVIEW_ORGAN_MODEL:
    # Pick the organ to inspect by leaving exactly one of these lines active.
    # OAR_VALUE = OARS_LABELS.EYE_L
    # OAR_VALUE = OARS_LABELS.OPT_NERVE_L
    # OAR_VALUE = OARS_LABELS.INNER_EAR_L
    # OAR_VALUE = OARS_LABELS.T_M_JOINT_L
    # OAR_VALUE = OARS_LABELS.MID_EAR_R
    # OAR_VALUE = OARS_LABELS.MID_EAR_L
    # OAR_VALUE = OARS_LABELS.BRAIN_STEM
    OAR_VALUE = OARS_LABELS.OPT_CHIASMA
    # OAR_VALUE = OARS_LABELS.PITUITARY
    # OAR_VALUE = OARS_LABELS.MANDIBLE_L
    # OAR_VALUE = OARS_LABELS.MANDIBLE_R

    OAR_KEY = OARS_LABELS.OARS_LABELS_R_DICT[OAR_VALUE]
    epoch = 75
    possible_models = get_possible_models(OAR_KEY)
    if not possible_models:
        # fail with a clear message instead of a bare IndexError from [0]
        raise FileNotFoundError(f'No saved model folder matching {OAR_KEY!r} in ./models')
    model_name = possible_models[0]
    print(f'Loading {OAR_KEY} model')

    # loading model checkpoint
    cut_model_info = load_checkpoint_model_info(model_name, epoch, cut_train_dataset, cut_valid_dataset, cut_test_dataset)

    # moving model to cuda (or set 'cpu') with eval mode
    cut_model_info['device'] = get_device()
    cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])
    cut_model_info['model'].eval()

    # preparing dataset for comparison: only this organ's label
    OAR_VALUE = OARS_LABELS.OARS_LABELS_DICT[OAR_KEY]
    cut_full_res_dataset.set_output_label(OAR_VALUE)

    # visual check on the first item of the train and of the valid split
    for split_title, split_key in (('Train', 'train_dataset'), ('Valid', 'valid_dataset')):
        preview_idx = low_res_split_dataset_obj[split_key].indices[0]
        print(f'{split_title} index {preview_idx}')
        raw_data, raw_label, raw_prediction = get_raw_with_prediction(cut_model_info['model'], cut_full_res_dataset, cut_model_info["device"], preview_idx)
        compare_one_prediction_with_ground_true(raw_data,
                                                raw_label,
                                                raw_prediction,
                                                pred_threshold=0.5)

    # show dsc for model
    info_df, preds, rescaled_preds = get_threshold_info_df(model=cut_model_info['model'],
                                                           dataset=cut_full_res_dataset,
                                                           device=cut_model_info['device'],
                                                           train_indices=cut_train_dataset.indices,
                                                           valid_indices=cut_valid_dataset.indices,
                                                           test_indices=cut_test_dataset.indices,
                                                           step=0.5)

    # final results with threshold 0.5; tables sorted worst item first
    best_threshold_col = 'thres_rescaled_dsc_0.50'
    hidden_cols = ['is_train', 'is_valid', 'is_test', 'thres_rescaled_dsc_0.00', 'thres_rescaled_dsc_1.00']
    train_info_df = info_df[info_df['is_train']]
    valid_info_df = info_df[info_df['is_valid']]
    train_dsc = train_info_df[best_threshold_col].mean()
    valid_dsc = valid_info_df[best_threshold_col].mean()
    print(f'{OAR_KEY} Model: DSC train {round(train_dsc, 4)} valid {round(valid_dsc, 4)}')
    display(train_info_df.sort_values(by=best_threshold_col).drop(columns=hidden_cols))
    display(valid_info_df.sort_values(by=best_threshold_col).drop(columns=hidden_cols))
    

# Loading all models

In [10]:
# NOTE: `filter_labels` aliases the module-level OARS_LABELS_DICT, so the `del`
# below removes SPINAL_CORD from it in place. Later cells may rely on that key
# already being gone from the module dict.
filter_labels = OARS_LABELS.OARS_LABELS_DICT
if 'SPINAL_CORD' in filter_labels:
    del filter_labels['SPINAL_CORD']

# Load one checkpoint per organ (kept on CPU, eval mode) into `models`, keyed by organ.
models = {}
for OAR_KEY, OAR_VALUE in list(filter_labels.items()):
    epoch = 75
    possible_models = get_possible_models(OAR_KEY)
    if not possible_models:
        print(f'{OAR_KEY} Model: No avaiable model')
        continue

    model_name = possible_models[0]
    print(f'{OAR_KEY} Model: Loading model {model_name}')

    cut_model_info = load_checkpoint_model_info(model_name, epoch, cut_train_dataset, cut_valid_dataset, cut_test_dataset)

    # park on CPU; cells that run inference move one model at a time to the GPU
    cut_model_info['device'] = 'cpu'
    cut_model_info['model'] = cut_model_info['model'].to('cpu')
    cut_model_info['model'].eval()

    models[OAR_KEY] = cut_model_info

BRAIN_STEM Model: Loading model 20201102-135642_3d_unet_BRAIN_STEM
EYE_L Model: Loading model 20201102-151945_3d_unet_EYE_L
EYE_R Model: Loading model 20201130-143833_3d_unet_EYE_R
LENS_L Model: Loading model 20201130-160023_3d_unet_LENS_L
LENS_R Model: Loading model 20201130-174740_3d_unet_LENS_R
OPT_NERVE_L Model: Loading model 20201102-180129_3d_unet_OPT_NERVE_L
OPT_NERVE_R Model: Loading model 20201102-192217_3d_unet_OPT_NERVE_R
OPT_CHIASMA Model: Loading model 20201102-215932_3d_unet_OPT_CHIASMA
TEMPORAL_LOBES_L Model: Loading model 20201102-231903_3d_unet_TEMPORAL_LOBES_L
TEMPORAL_LOBES_R Model: Loading model 20201103-003758_3d_unet_TEMPORAL_LOBES_R
PITUITARY Model: Loading model 20201103-105237_3d_unet_PITUITARY
PAROTID_GLAND_L Model: Loading model 20201103-121517_3d_unet_PAROTID_GLAND_L
PAROTID_GLAND_R Model: Loading model 20201103-160718_3d_unet_PAROTID_GLAND_R
INNER_EAR_L Model: Loading model 20201103-172811_3d_unet_INNER_EAR_L
INNER_EAR_R Model: Loading model 20201103-184739

# Calculating DSC for all models

In [11]:
def _dsc_stats(prefix, dsc_series):
    """Summary statistics of one split's per-item DSC, keyed '<prefix>_dsc_<stat>'."""
    return {
        f'{prefix}_dsc_mean': dsc_series.mean(),
        f'{prefix}_dsc_std': dsc_series.std(),
        f'{prefix}_dsc_median': dsc_series.median(),
        f'{prefix}_dsc_min': dsc_series.min(),
        f'{prefix}_dsc_max': dsc_series.max(),
    }


if SHOW_DSC_INFO:
    # DSC column for threshold 0.5, produced by get_threshold_info_df(step=0.5)
    best_threshold_col = 'thres_rescaled_dsc_0.50'
    info_per_organs_df = {}
    models_info = list()
    for OAR_KEY, OAR_VALUE in list(filter_labels.items()):
        if OAR_KEY not in models:
            print(f'{OAR_KEY} Model: No avaiable model')
            continue

        # Folder of THIS organ's checkpoint (same lookup as the loading cell). The
        # global `model_name` only holds the last folder some earlier cell loaded,
        # so using it would stamp every row with the same name.
        organ_model_name = get_possible_models(OAR_KEY)[0]

        # getting model to gpu
        cut_model_info = models[OAR_KEY]
        cut_model_info['device'] = get_device()
        cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])
        cut_model_info['model'].eval()

        try:
            # preparing dataset for comparison, then calculating dsc predictions
            cut_full_res_dataset.set_output_label(OAR_VALUE)
            info_df, preds, rescaled_preds = get_threshold_info_df(model=cut_model_info['model'],
                                                                   dataset=cut_full_res_dataset,
                                                                   device=cut_model_info['device'],
                                                                   train_indices=cut_train_dataset.indices,
                                                                   valid_indices=cut_valid_dataset.indices,
                                                                   test_indices=cut_test_dataset.indices,
                                                                   step=0.5)
        finally:
            # always move the model back to cpu, even on error, so the GPU is freed
            cut_model_info['device'] = 'cpu'
            cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])
        info_per_organs_df[OAR_KEY] = info_df

        # parsing data
        train_dsc_series = info_df[info_df['is_train']][best_threshold_col]
        valid_dsc_series = info_df[info_df['is_valid']][best_threshold_col]
        train_dsc = train_dsc_series.mean()
        valid_dsc = valid_dsc_series.mean()
        print(f'{OAR_KEY} Model: DSC train {round(train_dsc, 4)} valid {round(valid_dsc, 4)}')

        models_info.append({
            'oar_key': OAR_KEY,
            'model_name': organ_model_name,
            **_dsc_stats('train', train_dsc_series),
            **_dsc_stats('valid', valid_dsc_series),
            'train_valid_mean_delta': train_dsc - valid_dsc,
        })

    if not models_info:
        print('No models were evaluated')
    else:
        models_info_df = pd.DataFrame(models_info)

        # percent-scaled overview of the main statistics
        percent_cols = ['train_dsc_mean', 'train_dsc_std', 'valid_dsc_mean', 'valid_dsc_std']
        tmp_df = models_info_df[['oar_key', *percent_cols]].copy()
        tmp_df[percent_cols] = (tmp_df[percent_cols] * 100).round(2)

        # numeric_only: 'oar_key' is a string column; pandas >= 2.0 raises
        # TypeError on mean() instead of dropping it silently.
        display(tmp_df.mean(numeric_only=True).round(2))
        display(tmp_df.round(2))
        display(tmp_df.sort_values(by=['train_dsc_std']).round(2))
        display(models_info_df.sort_values(by=['train_dsc_mean']).drop(columns=['model_name']).round(2))
        display(models_info_df.sort_values(by=['train_valid_mean_delta']).drop(columns=['model_name']).round(2))


BRAIN_STEM Model: DSC train 0.8973 valid 0.8357
EYE_L Model: DSC train 0.9096 valid 0.8659
EYE_R Model: DSC train 0.8742 valid 0.8615
LENS_L Model: DSC train 0.6698 valid 0.6012
LENS_R Model: DSC train 0.704 valid 0.5999
OPT_NERVE_L Model: DSC train 0.657 valid 0.6066
OPT_NERVE_R Model: DSC train 0.7528 valid 0.6549
OPT_CHIASMA Model: DSC train 0.6481 valid 0.4717
TEMPORAL_LOBES_L Model: DSC train 0.8495 valid 0.7969
TEMPORAL_LOBES_R Model: DSC train 0.8332 valid 0.7857
PITUITARY Model: DSC train 0.6532 valid 0.5429
PAROTID_GLAND_L Model: DSC train 0.8696 valid 0.8503
PAROTID_GLAND_R Model: DSC train 0.8628 valid 0.8007
INNER_EAR_L Model: DSC train 0.825 valid 0.7889
INNER_EAR_R Model: DSC train 0.815 valid 0.7507
MID_EAR_L Model: DSC train 0.8491 valid 0.8326
MID_EAR_R Model: DSC train 0.7623 valid 0.7623
T_M_JOINT_L Model: DSC train 0.7649 valid 0.7352
T_M_JOINT_R Model: DSC train 0.8162 valid 0.7818
MANDIBLE_L Model: DSC train 0.9227 valid 0.9152
MANDIBLE_R Model: DSC train 0.9302 v

train_dsc_mean    80.32
train_dsc_std      5.87
valid_dsc_mean    75.05
valid_dsc_std      8.28
dtype: float64

Unnamed: 0,oar_key,train_dsc_mean,train_dsc_std,valid_dsc_mean,valid_dsc_std
0,BRAIN_STEM,89.73,1.97,83.57,4.15
1,EYE_L,90.96,1.85,86.59,3.14
2,EYE_R,87.42,2.95,86.15,5.46
3,LENS_L,66.98,8.68,60.12,16.21
4,LENS_R,70.4,9.65,59.99,19.03
5,OPT_NERVE_L,65.7,9.44,60.66,6.86
6,OPT_NERVE_R,75.28,9.89,65.49,11.11
7,OPT_CHIASMA,64.81,7.9,47.17,9.21
8,TEMPORAL_LOBES_L,84.95,4.56,79.69,5.85
9,TEMPORAL_LOBES_R,83.32,3.31,78.57,9.54


Unnamed: 0,oar_key,train_dsc_mean,train_dsc_std,valid_dsc_mean,valid_dsc_std
19,MANDIBLE_L,92.27,1.44,91.52,1.11
20,MANDIBLE_R,93.02,1.56,91.91,2.14
1,EYE_L,90.96,1.85,86.59,3.14
0,BRAIN_STEM,89.73,1.97,83.57,4.15
12,PAROTID_GLAND_R,86.28,2.55,80.07,8.52
15,MID_EAR_L,84.91,2.91,83.26,3.37
11,PAROTID_GLAND_L,86.96,2.94,85.03,2.65
2,EYE_R,87.42,2.95,86.15,5.46
9,TEMPORAL_LOBES_R,83.32,3.31,78.57,9.54
14,INNER_EAR_R,81.5,4.4,75.07,5.51


Unnamed: 0,oar_key,train_dsc_mean,train_dsc_std,train_dsc_median,train_dsc_min,train_dsc_max,valid_dsc_mean,valid_dsc_std,valid_dsc_median,valid_dsc_min,valid_dsc_max,train_valid_mean_delta
7,OPT_CHIASMA,0.65,0.08,0.65,0.39,0.82,0.47,0.09,0.45,0.39,0.63,0.18
10,PITUITARY,0.65,0.2,0.72,0.0,0.87,0.54,0.32,0.69,0.02,0.85,0.11
5,OPT_NERVE_L,0.66,0.09,0.66,0.47,0.81,0.61,0.07,0.62,0.49,0.68,0.05
3,LENS_L,0.67,0.09,0.66,0.48,0.83,0.6,0.16,0.62,0.34,0.74,0.07
4,LENS_R,0.7,0.1,0.71,0.41,0.87,0.6,0.19,0.62,0.29,0.8,0.1
6,OPT_NERVE_R,0.75,0.1,0.79,0.45,0.9,0.65,0.11,0.64,0.51,0.79,0.1
16,MID_EAR_R,0.76,0.09,0.79,0.43,0.86,0.76,0.04,0.78,0.71,0.8,0.0
17,T_M_JOINT_L,0.76,0.07,0.77,0.52,0.86,0.74,0.06,0.72,0.67,0.8,0.03
14,INNER_EAR_R,0.81,0.04,0.83,0.69,0.87,0.75,0.06,0.72,0.69,0.82,0.06
18,T_M_JOINT_R,0.82,0.06,0.82,0.66,0.89,0.78,0.11,0.74,0.65,0.92,0.03


Unnamed: 0,oar_key,train_dsc_mean,train_dsc_std,train_dsc_median,train_dsc_min,train_dsc_max,valid_dsc_mean,valid_dsc_std,valid_dsc_median,valid_dsc_min,valid_dsc_max,train_valid_mean_delta
16,MID_EAR_R,0.76,0.09,0.79,0.43,0.86,0.76,0.04,0.78,0.71,0.8,0.0
19,MANDIBLE_L,0.92,0.01,0.92,0.86,0.95,0.92,0.01,0.92,0.9,0.92,0.01
20,MANDIBLE_R,0.93,0.02,0.93,0.88,0.96,0.92,0.02,0.92,0.89,0.95,0.01
2,EYE_R,0.87,0.03,0.88,0.78,0.92,0.86,0.05,0.88,0.78,0.92,0.01
15,MID_EAR_L,0.85,0.03,0.85,0.79,0.9,0.83,0.03,0.85,0.79,0.87,0.02
11,PAROTID_GLAND_L,0.87,0.03,0.88,0.8,0.91,0.85,0.03,0.86,0.82,0.88,0.02
17,T_M_JOINT_L,0.76,0.07,0.77,0.52,0.86,0.74,0.06,0.72,0.67,0.8,0.03
18,T_M_JOINT_R,0.82,0.06,0.82,0.66,0.89,0.78,0.11,0.74,0.65,0.92,0.03
13,INNER_EAR_L,0.83,0.05,0.82,0.64,0.89,0.79,0.06,0.8,0.7,0.88,0.04
1,EYE_L,0.91,0.02,0.91,0.84,0.94,0.87,0.03,0.88,0.83,0.9,0.04


In [12]:
# Per-item DSC table of the EYE_R model, train items only, lowest DSC first.
tmp_df = info_per_organs_df[OARS_LABELS.OARS_LABELS_R_DICT[OARS_LABELS.EYE_R]]
tmp_df.loc[tmp_df['is_train']].sort_values('thres_rescaled_dsc_0.50')

Unnamed: 0_level_0,dsc,rescaled_dsc,is_train,is_valid,is_test,thres_rescaled_dsc_0.00,thres_rescaled_dsc_0.50,thres_rescaled_dsc_1.00
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
36,0.754958,0.756221,True,False,False,0.002161,0.776856,3.980892e-10
11,0.777232,0.778236,True,False,False,0.001738,0.801601,4.950495e-10
12,0.803535,0.804218,True,False,False,0.002372,0.829318,3.625816e-10
35,0.815841,0.816227,True,False,False,0.002685,0.83393,3.203075e-10
47,0.815123,0.815965,True,False,False,0.001594,0.841491,5.396654e-10
5,0.823594,0.824237,True,False,False,0.001895,0.849975,4.539265e-10
8,0.825789,0.826466,True,False,False,0.001735,0.851137,4.957858e-10
10,0.830481,0.831037,True,False,False,0.002515,0.852392,3.419973e-10
40,0.830494,0.831411,True,False,False,0.002411,0.853592,3.566334e-10
2,0.839977,0.84043,True,False,False,0.003058,0.855355,2.811358e-10


# Merging and checking predictions

In [13]:
# All organ labels except SPINAL_CORD, built as a filtered copy. This no longer
# checks one variable while deleting from another, and it does not mutate the
# module-level OARS_LABELS.OARS_LABELS_DICT. Key order is preserved.
filter_labels_dict = {key: value for key, value in OARS_LABELS.OARS_LABELS_DICT.items() if key != 'SPINAL_CORD'}

# The dataset now returns one label channel per organ in filter_labels_dict order.
cut_full_res_dataset.set_output_label(filter_labels_dict)
preview_dataset(cut_full_res_dataset)

data max 3071, min -1024
label max 1, min 0


VBox(children=(HBox(children=(IntSlider(value=35, max=71), IntSlider(value=10, max=20))),))

Output()

In [14]:
def custom_preview_dataset(dataset, preview_index=0, show_hist=False, use_transform=False):
    """Show an interactive slice viewer for one item of a multi-channel dataset.

    Three panels are drawn: data channel 0, data channel ``label_channel + 1``
    and label channel ``label_channel``. Two sliders select the slice index and
    the label channel.

    Args:
        dataset: object whose ``dataset[index]`` and
            ``get_raw_item_with_label_filter(index)`` both return ``(data, label)``.
        preview_index: index of the item to show.
        show_hist: also plot histograms of all data and all label values.
        use_transform: read the item through ``dataset[index]`` instead of the
            raw, label-filtered accessor.
    """
    if use_transform:
        data, label = dataset[preview_index]
    else:
        data, label = dataset.get_raw_item_with_label_filter(preview_index)
    n_channels, n_slices = label.shape[:2]

    for volume_name, volume in (('data', data), ('label', label)):
        print(f'{volume_name} max {volume.max()}, min {volume.min()}')

    def show_slice(slice_index, label_channel):
        panels = [
            (data[0, slice_index], "gray"),
            (data[label_channel + 1, slice_index], "gray"),
            (label[label_channel, slice_index], None),
        ]
        plt.figure(figsize=(20, 10))
        for position, (image, cmap) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.imshow(image, cmap=cmap)
        plt.show()

        if not show_hist:
            return
        plt.figure(figsize=(20, 10))
        for position, volume in enumerate((data, label), start=1):
            plt.subplot(1, 2, position)
            plt.hist(volume.flatten(), 128)
        plt.show()

    slice_slider = widgets.IntSlider(min=0, max=n_slices - 1, step=1, value=(n_slices - 1) / 2)
    channel_slider = widgets.IntSlider(min=0, max=n_channels - 1, step=1, value=(n_channels - 1) / 2)
    ui = widgets.VBox([widgets.HBox([slice_slider, channel_slider])])
    out = widgets.interactive_output(show_slice, {'slice_index': slice_slider, 'label_channel': channel_slider})
    # noinspection PyTypeChecker
    display(ui, out)


In [None]:
from src.dataset.get_norm_transform import get_norm_transform
from src.dataset.transform_input import transform_input
from src.helpers.get_rescaled_pred import get_rescaled_pred

if PARSE_CUT_DATASET:
    # Copy of the cut dataset whose items get 1 + N data channels: channel 0 is
    # the original image, channel i + 1 holds the binarised prediction of the
    # i-th organ in `output_label` order (zeros when no model is available).
    extended_cut_full_res_dataset = cut_full_res_dataset.copy()

    prediction_threshold = 0.5
    output_label_items = list(extended_cut_full_res_dataset.output_label.items())
    # loop-invariant: the number of output labels does not change per item
    new_data_channels = len(output_label_items) + 1

    # preparing cut dataset
    for index in range(len(extended_cut_full_res_dataset)):
        tmp_label = extended_cut_full_res_dataset.label_list[index]
        new_data = np.zeros((new_data_channels, *tmp_label[0].shape), dtype=np.int16)
        new_data[0] = extended_cut_full_res_dataset.data_list[index][0]
        extended_cut_full_res_dataset.data_list[index] = new_data

    # for each label
    for label_index, (OAR_KEY, OAR_VALUE) in enumerate(output_label_items):
        if OAR_KEY not in models:
            print(f'{label_index+1}/{len(output_label_items)}: {OAR_KEY} Model: No avaiable model')
            continue
        print(f'{label_index+1}/{len(output_label_items)}: {OAR_KEY} Model: got model {datetime.datetime.now()}')

        # getting model to gpu
        cut_model_info = models[OAR_KEY]
        cut_model_info['device'] = get_device()
        cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])
        cut_model_info['model'].eval()

        try:
            # Predict from the single-channel cut dataset (organ models are
            # trained with input_data_channels=1), not from the extended copy.
            for index in range(len(extended_cut_full_res_dataset)):
                prediction, rescaled_pred = get_rescaled_pred(cut_model_info['model'],
                                                              cut_full_res_dataset,
                                                              cut_model_info['device'],
                                                              index,
                                                              use_only_one_dimension=True)
                extended_cut_full_res_dataset.data_list[index][label_index + 1] = (rescaled_pred > prediction_threshold).astype(np.int8)
        finally:
            # moving model back to cpu even on error, so the next model fits on the GPU
            cut_model_info['device'] = 'cpu'
            cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])

    custom_preview_dataset(extended_cut_full_res_dataset)

1/21: BRAIN_STEM Model: got model 2020-12-14 15:02:47.750288


## Merging predictions

In [None]:
# For every item, sum all prediction channels (every data channel after the
# image in channel 0). Each voxel value is then the number of organ masks
# predicted there; values > 1 indicate overlapping predictions.
merged_predictions = [None] * len(extended_cut_full_res_dataset)
for index in range(len(extended_cut_full_res_dataset)):
    print(f"{index+1}/{len(extended_cut_full_res_dataset)}: Merging predictions to single label")
    data, _ = extended_cut_full_res_dataset.get_raw_item_with_label_filter(index)

    # data[1:] uses the real channel count instead of a hard-coded range(1, 22),
    # so this keeps working when organs are added to or removed from output_label.
    merged_predictions[index] = data[1:].sum(axis=0, dtype=np.int16)

#### Checking how many masks overlap

In [None]:
for tmp_merged in merged_predictions:
    # number of voxels covered by exactly 1, 2, 3 and 4 predicted organ masks
    overlap_counts = [np.count_nonzero(tmp_merged == n_masks) for n_masks in range(1, 5)]
    display(', '.join(str(count) for count in overlap_counts))

In [None]:
from src.helpers.calc_dsc import calc_dsc

def custom_preview_dataset2(dataset, preview_index=0, show_hist=False, use_transform=False):
    """Interactively compare one item's per-organ predictions with its ground truth.

    ``dataset`` is expected to hold the image in data channel 0 and the
    prediction for label channel ``i`` in data channel ``i + 1``. The image
    panel is read from the global ``cut_full_res_dataset``. The merged
    prediction panel is read from the global ``merged_predictions``.

    For the selected label channel, the organ name and the DSC between label
    and prediction are printed. Five panels are then drawn: image, label,
    prediction, merged predictions and merged labels.

    Args:
        dataset: object whose ``dataset[index]`` and
            ``get_raw_item_with_label_filter(index)`` return ``(data, label)``.
        preview_index: index of the item to show.
        show_hist: unused; kept for signature parity with ``custom_preview_dataset``.
        use_transform: read the item through ``dataset[index]`` instead of the
            raw, label-filtered accessor.
    """
    if use_transform:
        data, label = dataset[preview_index]
    else:
        data, label = dataset.get_raw_item_with_label_filter(preview_index)

    cut_data, cut_label = cut_full_res_dataset.get_raw_item_with_label_filter(preview_index)
    max_channels = label.shape[0]
    max_slices = label.shape[1]

    # Label channel i follows the order of dataset.output_label, which skips
    # excluded labels (SPINAL_CORD). So `label_channel + 1` is not the label
    # value for every channel; use the dict order when it is available.
    output_label = getattr(dataset, 'output_label', None)
    channel_names = list(output_label) if isinstance(output_label, dict) else None

    print(f'data max {data.max()}, min {data.min()}')
    print(f'label max {label.max()}, min {label.min()}')
    print(f'{data.shape}, {cut_data.shape}, {label.shape}, {cut_label.shape}')
    print(f'{data.dtype}, {cut_data.dtype}, {label.dtype}, {cut_label.dtype}')
    tmp_merged = merged_predictions[preview_index]
    print(f'{np.where(tmp_merged == 1)[0].shape},{np.where(tmp_merged == 2)[0].shape},{np.where(tmp_merged == 3)[0].shape},{np.where(tmp_merged == 4)[0].shape}')

    def f(slice_index, label_channel):
        if channel_names is not None:
            print(f'{channel_names[label_channel]}')
        else:
            print(f'{OARS_LABELS.OARS_LABELS_R_DICT[label_channel+1]}')
        prediction = data[label_channel + 1]
        tmp_tensor_label = torch.tensor(label[label_channel])
        tmp_tensor_prediction = torch.tensor(prediction)
        tmp_dsc = calc_dsc(tmp_tensor_label, tmp_tensor_prediction)
        print(f'dsc {tmp_dsc}')

        plt.figure(figsize=(30, 20))

        plt.subplot(2, 3, 1).title.set_text('data')
        plt.imshow(cut_data[0, slice_index], cmap="gray")
        plt.subplot(2, 3, 2).title.set_text('label')
        plt.imshow(label[label_channel, slice_index])
        plt.subplot(2, 3, 3).title.set_text('prediciton')
        # Colour limit from the displayed channel itself (at least 1, so an empty
        # binary mask still renders as "all background").
        plt.imshow(prediction[slice_index], vmin=0, vmax=max(1.0, float(prediction.max())))
        plt.subplot(2, 3, 4).title.set_text('merged prediction labels')
        plt.imshow(tmp_merged[slice_index], vmin=0, vmax=np.unique(tmp_merged)[-1])
        plt.subplot(2, 3, 5).title.set_text('merged labels ')
        plt.imshow(np.sum(label, axis=0)[slice_index])

        plt.show()

    sliceSlider = widgets.IntSlider(min=0, max=max_slices - 1, step=1, value=(max_slices - 1) / 2)
    labelChannelSlider = widgets.IntSlider(min=0, max=max_channels - 1, step=1, value=(max_channels - 1) / 2)
    ui = widgets.VBox([widgets.HBox([sliceSlider, labelChannelSlider])])
    out = widgets.interactive_output(f, {'slice_index': sliceSlider, 'label_channel': labelChannelSlider})
    # noinspection PyTypeChecker
    display(ui, out)

In [None]:
# Interactive prediction vs. ground-truth comparison for the first dataset item.
custom_preview_dataset2(extended_cut_full_res_dataset, preview_index=0)

## Merging U-Net model

In [None]:
TRAIN_COMBINE_MODEL = False
if TRAIN_COMBINE_MODEL:
    # The combine model takes the image plus one binarised prediction channel per
    # organ as input and predicts one mask per organ. It must therefore be trained
    # on the extended dataset built in the predictions-merging section (requires
    # PARSE_CUT_DATASET), not on the single-channel cut dataset. That dataset
    # would not match input_data_channels.
    combine_split_dataset_obj = copy_split_dataset(extended_cut_full_res_dataset, low_res_split_dataset_obj)
    combine_train_dataset, combine_valid_dataset, combine_test_dataset = itemgetter(*datasets_params)(combine_split_dataset_obj)
    # 21 organs -> 22 input channels / 21 output channels, derived instead of hard-coded
    n_organ_channels = len(extended_cut_full_res_dataset.output_label)

    log_date = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    model_name = f'{log_date}_3d_unet_COMBINE'

    print(f'Training model with dataset all labels at input')
    print(f'folder \'{model_name}\'')
    combine_model_info = prepare_model(epochs=5,
                                       learning_rate=3e-4,
                                       in_channels=8,
                                       input_data_channels=n_organ_channels + 1,
                                       output_label_channels=n_organ_channels,
                                       dropout_rate=0.2,
                                       train_batch_size=1,
                                       model_name=model_name,
                                       train_dataset=combine_train_dataset,
                                       valid_dataset=combine_valid_dataset,
                                       test_dataset=combine_test_dataset)
    show_model_info(combine_model_info)
    print('\n\n')
    train_loop(combine_model_info)
    print('\n\n')

    # clearing memory
    del combine_model_info
    torch.cuda.empty_cache()