In [3]:
IN_COLAB = 'google.colab' in str(get_ipython())
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    import sys
    sys.path.append('/content/drive/My Drive/dp_tomastik/code')
    !bash "/content/drive/My Drive/dp_tomastik/code/scripts/install_libs.sh"

import matplotlib.pyplot as plt
import torch
import os
import numpy as np
import pandas as pd
import logging
import datetime
from torchio import RandomAffine, Compose, ZNormalization

import src.dataset.oars_labels_consts as OARS_LABELS
from src.consts import DATASET_MAX_BOUNDING_BOX, DESIRE_BOUNDING_BOX_SIZE

from src.dataset import get_cut_lists
from src.dataset import get_full_res_cut
from src.dataset import get_dataset
from src.dataset import get_dataset_info
from src.dataset import get_dataset_transform
from src.dataset import split_dataset, copy_split_dataset

from src.model_and_training import prepare_model
from src.model_and_training import train_loop
from src.model_and_training import show_model_info
from src.model_and_training import load_checkpoint_model_info

from src.helpers import preview_dataset
from src.helpers import get_threshold_info_df
from src.helpers import preview_model_dataset_pred
from src.helpers import show_cuda_usage
from src.helpers import get_rescaled_preds
from src.helpers import compare_prediction_with_ground_true, compare_one_prediction_with_ground_true
from src.helpers import get_img_outliers_pixels
from src.helpers import get_raw_with_prediction

from src.model_and_training.getters.get_device import get_device
from src.model_and_training.architectures.unet_architecture_v3v1 import UNetV3v1


from operator import itemgetter
from IPython.display import display, Markdown
from ipywidgets import widgets

torch.manual_seed(20)
logging.basicConfig(filename='logs/all_organs_jupyter.log', level=logging.DEBUG)

print('Dataset biggest bounding box wihtout spinal cord', DATASET_MAX_BOUNDING_BOX)
print('Cut target size', DESIRE_BOUNDING_BOX_SIZE)
print('Done Init')

Dataset biggest bounding box wihtout spinal cord [56, 177, 156]
Cut target size [72, 192, 168]
Done Init


In [11]:
def get_possible_models(oar_key):
    possible_models = [folder_name for folder_name in os.listdir('./models') if oar_key in folder_name]    
    
    return possible_models

# Loading precourse neural network with datasets

In [12]:
datasets_params = ['train_dataset', 'valid_dataset', 'test_dataset']
filter_labels = OARS_LABELS.OARS_LABELS_LIST
if OARS_LABELS.SPINAL_CORD in filter_labels:
    filter_labels.remove(OARS_LABELS.SPINAL_CORD)

# low res
low_res_dataset = get_dataset(dataset_size=50, shrink_factor=16, filter_labels=filter_labels, unify_labels=True)
low_res_dataset.dilatate_labels(repeat=1)
low_res_dataset.to_numpy()
low_res_split_dataset_obj = split_dataset(low_res_dataset, train_size=40, valid_size=5, test_size=5)
train_low_res_dataset, valid_low_res_dataset, test_low_res_dataset = itemgetter(*datasets_params)(low_res_split_dataset_obj)

# full res
full_res_dataset = get_dataset(dataset_size=50, shrink_factor=1, filter_labels=filter_labels, unify_labels=False)
full_res_dataset.to_numpy()
full_res_split_dataset_obj = copy_split_dataset(full_res_dataset, low_res_split_dataset_obj)

# low res model - precourse model
epoch = 500
log_date = datetime.datetime(year=2020, month=10, day=27, hour=11, minute=45, second=30).strftime("%Y%m%d-%H%M%S")
model_name = f'{log_date}_3d_unet_PRECOURSE'

low_res_model_info = load_checkpoint_model_info(model_name, epoch, train_low_res_dataset, valid_low_res_dataset, test_low_res_dataset)
show_model_info(low_res_model_info)

# moving low res to gpu
low_res_model_info['device'] = get_device()
# low_res_model_info['device'] = 'cpu'
low_res_model_info['model'] = low_res_model_info['model'].to(low_res_model_info['device'])
low_res_model_info['model'].eval()

# cut res
cut_full_res_dataset = full_res_dataset.copy(copy_lists=False)
cut_full_res_dataset = get_cut_lists(low_res_model_info['model'],
                                     low_res_model_info['device'],
                                     low_res_dataset, 
                                     full_res_dataset, 
                                     cut_full_res_dataset, 
                                     low_res_mask_threshold=0.5)
cut_full_res_dataset.set_output_label(None)
cut_split_dataset_obj = copy_split_dataset(cut_full_res_dataset, low_res_split_dataset_obj)
cut_train_dataset, cut_valid_dataset, cut_test_dataset = itemgetter(*datasets_params)(cut_split_dataset_obj)

# moving low res model to cpu
low_res_model_info['device'] = 'cpu'
low_res_model_info['model'] = low_res_model_info['model'].to(low_res_model_info['device'])

CUDA using 16x dataset
filtering labels
filtering labels done
dilatating 1x dataset
parsing dataset to numpy
numpy parsing done
CUDA using 1x dataset
filtering labels
filtering labels done
parsing dataset to numpy
numpy parsing done




Model number of params: 298881, trainable 298881
get_cut_lists: Cutting index 0
get_full_res_cut: Removing 10/1335 outlier pixels
get_final_bounding_box_slice: box delta [21 48 24]
get_full_res_cut: Does cut and original label contain the same amount of pixels? True 1223526 1223526
get_cut_lists: Cutting index 1
get_full_res_cut: Removing 0/1416 outlier pixels
get_final_bounding_box_slice: box delta [24 16  8]
get_full_res_cut: Does cut and original label contain the same amount of pixels? True 1326052 1326052
get_cut_lists: Cutting index 2
get_full_res_cut: Removing 0/1873 outlier pixels
get_final_bounding_box_slice: box delta [ 20   0 -24]
get_full_res_cut: Does cut and original label contain the same amount of pixels? True 1890464 1890464
get_cut_lists: Cutting index 3
get_full_res_cut: Removing 0/1545 outlier pixels
get_final_bounding_box_slice: box delta [17 32  8]
get_full_res_cut: Does cut and original label contain the same amount of pixels? True 1560217 1560217
get_cut_lists: 

In [13]:
get_dataset_info(low_res_dataset, low_res_split_dataset_obj)

train 40, valid_size 5, test 5, full 50
train indices [0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 14, 15, 17, 18, 20, 21, 22, 23, 24, 28, 30, 31, 32, 33, 34, 35, 36, 37, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
valid indices [6, 13, 19, 25, 38]
test indices [16, 26, 27, 29, 39]


In [14]:
preview_dataset(cut_full_res_dataset)

data max 3071, min -1024
label max 22, min 0


VBox(children=(HBox(children=(IntSlider(value=35, max=71), IntSlider(value=0, max=0))),))

Output()

# Training all organs models

checking list used for training single models

In [15]:
filter_labels = OARS_LABELS.OARS_LABELS_DICT
if 'SPINAL_CORD' in filter_labels:
    del filter_labels['SPINAL_CORD']

for OAR_KEY, OAR_VALUE in list(filter_labels.items())[:]:
    cut_full_res_dataset.set_output_label(OAR_VALUE)
    print(f'dataset label \'{OAR_KEY}\', \t value \'{OAR_VALUE}\'')

dataset label 'BRAIN_STEM', 	 value '1'
dataset label 'EYE_L', 	 value '2'
dataset label 'EYE_R', 	 value '3'
dataset label 'LENS_L', 	 value '4'
dataset label 'LENS_R', 	 value '5'
dataset label 'OPT_NERVE_L', 	 value '6'
dataset label 'OPT_NERVE_R', 	 value '7'
dataset label 'OPT_CHIASMA', 	 value '8'
dataset label 'TEMPORAL_LOBES_L', 	 value '9'
dataset label 'TEMPORAL_LOBES_R', 	 value '10'
dataset label 'PITUITARY', 	 value '11'
dataset label 'PAROTID_GLAND_L', 	 value '12'
dataset label 'PAROTID_GLAND_R', 	 value '13'
dataset label 'INNER_EAR_L', 	 value '14'
dataset label 'INNER_EAR_R', 	 value '15'
dataset label 'MID_EAR_L', 	 value '16'
dataset label 'MID_EAR_R', 	 value '17'
dataset label 'T_M_JOINT_L', 	 value '18'
dataset label 'T_M_JOINT_R', 	 value '19'
dataset label 'MANDIBLE_L', 	 value '21'
dataset label 'MANDIBLE_R', 	 value '22'


training each model

In [16]:
filter_labels = OARS_LABELS.OARS_LABELS_DICT
if 'SPINAL_CORD' in filter_labels:
    del filter_labels['SPINAL_CORD']

tmp_list = list(filter_labels.items())
labels_list = [tmp_list[5], tmp_list[6], tmp_list[7], tmp_list[11], tmp_list[12]]
for OAR_KEY, OAR_VALUE in labels_list:
    print(f"{OAR_KEY}, {OAR_VALUE}")

OPT_NERVE_L, 6
OPT_NERVE_R, 7
OPT_CHIASMA, 8
PAROTID_GLAND_L, 12
PAROTID_GLAND_R, 13


In [17]:
TRAIN_MODELS = False
if TRAIN_MODELS:
    for OAR_KEY, OAR_VALUE in labels_list:
        cut_full_res_dataset.set_output_label(OAR_VALUE)
        log_date = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        model_name = f'{log_date}_3d_unet_lowres_model3v1_{OAR_KEY}'

        print(f'Training model with dataset label \'{OAR_KEY}\', value \'{OAR_VALUE}\'')
        print(f'folder \'{model_name}\'')
        cut_model_info = prepare_model(epochs=75,
                                       learning_rate=3e-4,
                                       in_channels=8,
                                       input_data_channels=1,
                                       output_label_channels=1,
                                       dropout_rate=0.2,
                                       train_batch_size=1,
                                       model_name=model_name,
                                       train_dataset=cut_train_dataset, 
                                       valid_dataset=cut_valid_dataset, 
                                       test_dataset=cut_test_dataset,
                                       model_class=UNetV3v1)
        show_model_info(cut_model_info)
        print('\n\n')
        train_loop(cut_model_info)
        print('\n\n')

        # clearing memory
        del cut_model_info
        torch.cuda.empty_cache()

## DSC

In [18]:
filter_labels = OARS_LABELS.OARS_LABELS_DICT
if 'SPINAL_CORD' in filter_labels:
    del filter_labels['SPINAL_CORD']

tmp_list = list(filter_labels.items())
labels_list = [tmp_list[5], tmp_list[6], tmp_list[7], tmp_list[10], tmp_list[11], tmp_list[12]]
for OAR_KEY, OAR_VALUE in labels_list:
    print(f"{OAR_KEY}, {OAR_VALUE}")

OPT_NERVE_L, 6
OPT_NERVE_R, 7
OPT_CHIASMA, 8
PITUITARY, 11
PAROTID_GLAND_L, 12
PAROTID_GLAND_R, 13


### loading models to CPU 

In [19]:
models = dict()
for OAR_KEY, OAR_VALUE in labels_list:
    epoch = 75
    possible_models = get_possible_models(f"model3v1_{OAR_KEY}")
    if len(possible_models) <= 0:
        print(f'{OAR_KEY} Model: No avaiable model')
        continue

    model_name = possible_models[0]
    print(f'{OAR_KEY} Model: Loading model {model_name}')

    # loading model checkpoint
    cut_model_info = load_checkpoint_model_info(model_name, epoch, cut_train_dataset, cut_valid_dataset, cut_test_dataset, model_class=UNetV3v1)

    # moving model to cpu/cuda with eval mode
    cut_model_info['device'] = 'cpu'
    cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])
    cut_model_info['model'].eval()
    cut_model_info['model'].disable_tensorboard_writing = True
    
    models[OAR_KEY] = cut_model_info

OPT_NERVE_L Model: Loading model 20210309-012048_3d_unet_lowres_model3v1_OPT_NERVE_L
OPT_NERVE_R Model: Loading model 20210309-034637_3d_unet_lowres_model3v1_OPT_NERVE_R
OPT_CHIASMA Model: Loading model 20210309-061213_3d_unet_lowres_model3v1_OPT_CHIASMA
PITUITARY Model: Loading model 20210308-182550_3d_unet_lowres_model3v1_PITUITARY
PAROTID_GLAND_L Model: Loading model 20210309-083743_3d_unet_lowres_model3v1_PAROTID_GLAND_L
PAROTID_GLAND_R Model: Loading model 20210309-110322_3d_unet_lowres_model3v1_PAROTID_GLAND_R


In [24]:
SHOW_DSC_INFO = True
if SHOW_DSC_INFO:
    info_per_organs_df = {}
    models_info = list()
    for OAR_KEY, OAR_VALUE in labels_list:
        if OAR_KEY not in models:
            print(f'{OAR_KEY} Model: No avaiable model')
            continue

        # getting model to gpu
        cut_model_info = models[OAR_KEY]
        cut_model_info['device'] = get_device()
        cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])
        cut_model_info['model'].train()
        cut_model_info['model'].disable_tensorboard_writing = True

        # preparing dataset for comparison
        cut_full_res_dataset.set_output_label(OAR_VALUE)

        # calculating dsc predictions        
        info_df, preds, rescaled_preds = get_threshold_info_df(model=cut_model_info['model'], 
                                    dataset=cut_full_res_dataset, 
                                    device=cut_model_info['device'], 
                                    train_indices=cut_train_dataset.indices, 
                                    valid_indices=cut_valid_dataset.indices, 
                                    test_indices=cut_test_dataset.indices,
                                    step=0.5)
        info_per_organs_df[OAR_KEY] = info_df

        # moving model back to cpu
        cut_model_info['device'] = 'cpu'
        cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])

        # parsing data
        best_threshold_col = 'thres_rescaled_dsc_0.50'
        train_tmp_df = info_df[info_df['is_train']][best_threshold_col]
        valid_tmp_df = info_df[info_df['is_valid']][best_threshold_col]
        train_dsc = train_tmp_df.mean()
        valid_dsc = valid_tmp_df.mean()
        print(f'{OAR_KEY} Model: DSC train {round(train_dsc, 4)} valid {round(valid_dsc, 4)}')

        models_info.append({
            'oar_key': OAR_KEY,
            'model_name': model_name,
            # Train
            'train_dsc_mean': train_dsc,
            'train_dsc_std': train_tmp_df.std(),
            'train_dsc_median': train_tmp_df.median(),
            'train_dsc_min': train_tmp_df.min(),
            'train_dsc_max': train_tmp_df.max(),
            # Valid
            'valid_dsc_mean': valid_dsc,
            'valid_dsc_std': valid_tmp_df.std(),
            'valid_dsc_median': valid_tmp_df.median(),
            'valid_dsc_min': valid_tmp_df.min(),
            'valid_dsc_max': valid_tmp_df.max(),
            # Both
            'train_valid_mean_delta': train_dsc - valid_dsc
        })

    models_info_df = pd.DataFrame(models_info)
    
    tmp_df = models_info_df[['oar_key', 'train_dsc_mean', 'train_dsc_std', 'valid_dsc_mean', 'valid_dsc_std']].copy()
    tmp_df['train_dsc_mean'] = (tmp_df['train_dsc_mean'] * 100).round(2)
    tmp_df['valid_dsc_mean'] = (tmp_df['valid_dsc_mean'] * 100).round(2)
    tmp_df['train_dsc_std'] = (tmp_df['train_dsc_std'] * 100).round(2)
    tmp_df['valid_dsc_std'] = (tmp_df['valid_dsc_std'] * 100).round(2)
    
    display(tmp_df.mean().round(2))
    display(tmp_df.round(2))
    display(tmp_df.sort_values(by=['train_dsc_std']).round(2))
    display(models_info_df.sort_values(by=['train_dsc_mean']).drop(columns=['model_name']).round(2))
    display(models_info_df.sort_values(by=['train_valid_mean_delta']).drop(columns=['model_name']).round(2))

OPT_NERVE_L Model: DSC train 0.752 valid 0.6352
OPT_NERVE_R Model: DSC train 0.7527 valid 0.6131
OPT_CHIASMA Model: DSC train 0.6444 valid 0.4578
PITUITARY Model: DSC train 0.6468 valid 0.4909
PAROTID_GLAND_L Model: DSC train 0.8978 valid 0.8702
PAROTID_GLAND_R Model: DSC train 0.8923 valid 0.846


train_dsc_mean    76.43
train_dsc_std      8.11
valid_dsc_mean    65.22
valid_dsc_std      9.79
dtype: float64

Unnamed: 0,oar_key,train_dsc_mean,train_dsc_std,valid_dsc_mean,valid_dsc_std
0,OPT_NERVE_L,75.2,7.44,63.52,11.2
1,OPT_NERVE_R,75.27,9.5,61.31,11.69
2,OPT_CHIASMA,64.44,7.23,45.78,13.46
3,PITUITARY,64.68,21.17,49.09,14.84
4,PAROTID_GLAND_L,89.78,1.61,87.02,3.09
5,PAROTID_GLAND_R,89.23,1.73,84.6,4.44


Unnamed: 0,oar_key,train_dsc_mean,train_dsc_std,valid_dsc_mean,valid_dsc_std
4,PAROTID_GLAND_L,89.78,1.61,87.02,3.09
5,PAROTID_GLAND_R,89.23,1.73,84.6,4.44
2,OPT_CHIASMA,64.44,7.23,45.78,13.46
0,OPT_NERVE_L,75.2,7.44,63.52,11.2
1,OPT_NERVE_R,75.27,9.5,61.31,11.69
3,PITUITARY,64.68,21.17,49.09,14.84


Unnamed: 0,oar_key,train_dsc_mean,train_dsc_std,train_dsc_median,train_dsc_min,train_dsc_max,valid_dsc_mean,valid_dsc_std,valid_dsc_median,valid_dsc_min,valid_dsc_max,train_valid_mean_delta
2,OPT_CHIASMA,0.64,0.07,0.64,0.46,0.77,0.46,0.13,0.47,0.31,0.61,0.19
3,PITUITARY,0.65,0.21,0.73,0.13,0.91,0.49,0.15,0.56,0.27,0.63,0.16
0,OPT_NERVE_L,0.75,0.07,0.76,0.53,0.86,0.64,0.11,0.63,0.52,0.8,0.12
1,OPT_NERVE_R,0.75,0.09,0.78,0.49,0.87,0.61,0.12,0.59,0.47,0.79,0.14
5,PAROTID_GLAND_R,0.89,0.02,0.89,0.85,0.92,0.85,0.04,0.86,0.79,0.89,0.05
4,PAROTID_GLAND_L,0.9,0.02,0.9,0.87,0.93,0.87,0.03,0.87,0.82,0.9,0.03


Unnamed: 0,oar_key,train_dsc_mean,train_dsc_std,train_dsc_median,train_dsc_min,train_dsc_max,valid_dsc_mean,valid_dsc_std,valid_dsc_median,valid_dsc_min,valid_dsc_max,train_valid_mean_delta
4,PAROTID_GLAND_L,0.9,0.02,0.9,0.87,0.93,0.87,0.03,0.87,0.82,0.9,0.03
5,PAROTID_GLAND_R,0.89,0.02,0.89,0.85,0.92,0.85,0.04,0.86,0.79,0.89,0.05
0,OPT_NERVE_L,0.75,0.07,0.76,0.53,0.86,0.64,0.11,0.63,0.52,0.8,0.12
1,OPT_NERVE_R,0.75,0.09,0.78,0.49,0.87,0.61,0.12,0.59,0.47,0.79,0.14
3,PITUITARY,0.65,0.21,0.73,0.13,0.91,0.49,0.15,0.56,0.27,0.63,0.16
2,OPT_CHIASMA,0.64,0.07,0.64,0.46,0.77,0.46,0.13,0.47,0.31,0.61,0.19


In [25]:
if SHOW_DSC_INFO:
    tmp_column = 'is_train' 
    
    print('OARS_LABELS.PAROTID_GLAND_L')
    tmp_df = info_per_organs_df[OARS_LABELS.OARS_LABELS_R_DICT[OARS_LABELS.PAROTID_GLAND_L]]
    display(tmp_df[tmp_df[tmp_column]].sort_values(by='thres_rescaled_dsc_0.50'))
    
    print('OARS_LABELS.OPT_NERVE_L')
    tmp_df = info_per_organs_df[OARS_LABELS.OARS_LABELS_R_DICT[OARS_LABELS.OPT_NERVE_L]]
    display(tmp_df[tmp_df[tmp_column]].sort_values(by='thres_rescaled_dsc_0.50'))
    
    print('OARS_LABELS.PITUITARY')
    tmp_df = info_per_organs_df[OARS_LABELS.OARS_LABELS_R_DICT[OARS_LABELS.PITUITARY]]
    display(tmp_df[tmp_df[tmp_column]].sort_values(by='thres_rescaled_dsc_0.50'))

OARS_LABELS.PAROTID_GLAND_L


Unnamed: 0_level_0,dsc,rescaled_dsc,is_train,is_valid,is_test,thres_rescaled_dsc_0.00,thres_rescaled_dsc_0.50,thres_rescaled_dsc_1.00
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
24,0.860782,0.86079,True,False,False,0.006555,0.865229,1.309243e-10
47,0.859524,0.859549,True,False,False,0.004808,0.865879,1.786352e-10
45,0.867004,0.867019,True,False,False,0.004914,0.872036,1.748252e-10
20,0.868213,0.868231,True,False,False,0.005839,0.872653,1.470588e-10
49,0.865662,0.865677,True,False,False,0.003226,0.873873,2.665245e-10
37,0.866939,0.866969,True,False,False,0.003423,0.874292,2.511301e-10
41,0.879059,0.879085,True,False,False,0.006546,0.882623,1.311303e-10
8,0.878797,0.878823,True,False,False,0.006569,0.883691,1.306677e-10
43,0.879491,0.879511,True,False,False,0.004044,0.885221,2.125398e-10
1,0.880485,0.880504,True,False,False,0.005453,0.886276,1.574803e-10


OARS_LABELS.OPT_NERVE_L


Unnamed: 0_level_0,dsc,rescaled_dsc,is_train,is_valid,is_test,thres_rescaled_dsc_0.00,thres_rescaled_dsc_0.50,thres_rescaled_dsc_1.00
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
8,0.524342,0.524342,True,False,False,6.8e-05,0.526316,1.265823e-08
4,0.597179,0.597179,True,False,False,0.000105,0.599407,8.196722e-09
33,0.605275,0.605275,True,False,False,8.4e-05,0.609319,1.030928e-08
24,0.66349,0.66349,True,False,False,0.000226,0.665198,3.816794e-09
31,0.642596,0.642596,True,False,False,6e-05,0.666667,1.428571e-08
5,0.677889,0.677889,True,False,False,0.000209,0.676123,4.115226e-09
14,0.679609,0.679609,True,False,False,0.000207,0.678663,4.166667e-09
1,0.683899,0.683899,True,False,False,0.000347,0.686244,2.48139e-09
2,0.689796,0.689796,True,False,False,0.000435,0.693168,1.980198e-09
40,0.699361,0.699361,True,False,False,0.000378,0.699422,2.277904e-09


OARS_LABELS.PITUITARY


Unnamed: 0_level_0,dsc,rescaled_dsc,is_train,is_valid,is_test,thres_rescaled_dsc_0.00,thres_rescaled_dsc_0.50,thres_rescaled_dsc_1.00
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
48,0.126755,0.126755,True,False,False,1.1e-05,0.126829,7.692307e-08
15,0.150495,0.150495,True,False,False,2.6e-05,0.147465,3.333333e-08
42,0.237654,0.237654,True,False,False,2.6e-05,0.238095,3.333333e-08
22,0.263028,0.263028,True,False,False,2.4e-05,0.267943,3.571428e-08
8,0.349595,0.349595,True,False,False,4.7e-05,0.352941,1.851852e-08
11,0.397872,0.397872,True,False,False,4.7e-05,0.398551,1.818182e-08
31,0.417295,0.417295,True,False,False,4.2e-05,0.418803,2.040816e-08
12,0.412157,0.412157,True,False,False,4.6e-05,0.418972,1.886792e-08
17,0.424459,0.424459,True,False,False,7.1e-05,0.423773,1.219512e-08
2,0.476153,0.476153,True,False,False,7.7e-05,0.488235,1.123596e-08


# Predictions merging and checking

In [55]:
def custom_preview_dataset(dataset, preview_index=0, show_hist=False, use_transform=False):
    if use_transform:
        data, label = dataset[preview_index]
    else:
        data, label = dataset.get_raw_item_with_label_filter(preview_index)
    max_channels = label.shape[0]
    max_slices = label.shape[1]

    print(f'data max {data.max()}, min {data.min()}')
    print(f'label max {label.max()}, min {label.min()}')

    def f(slice_index, label_channel):
        plt.figure(figsize=(20, 10))
        plt.subplot(1, 3, 1)
        plt.imshow(data[0, slice_index], cmap="gray")
        plt.subplot(1, 3, 2)
        plt.imshow(data[label_channel+1, slice_index])
        plt.subplot(1, 3, 3)
        plt.imshow(label[label_channel, slice_index])
        plt.show()

        if show_hist:
            plt.figure(figsize=(20, 10))
            plt.subplot(1, 2, 1)
            plt.hist(data.flatten(), 128)
            plt.subplot(1, 2, 2)
            plt.hist(label.flatten(), 128)
            plt.show()

    sliceSlider = widgets.IntSlider(min=0, max=max_slices - 1, step=1, value=(max_slices - 1) / 2)
    labelChannelSlider = widgets.IntSlider(min=0, max=max_channels - 1, step=1, value=(max_channels - 1) / 2)
    ui = widgets.VBox([widgets.HBox([sliceSlider, labelChannelSlider])])
    out = widgets.interactive_output(f, {'slice_index': sliceSlider, 'label_channel': labelChannelSlider})
    # noinspection PyTypeChecker
    display(ui, out)

In [15]:
filter_labels_dict = OARS_LABELS.OARS_LABELS_DICT
if 'SPINAL_CORD' in filter_labels:
    del filter_labels_dict['SPINAL_CORD']

cut_full_res_dataset.set_output_label(filter_labels_dict)
preview_dataset(cut_full_res_dataset)

data max 3071, min -1024
label max 1, min 0


VBox(children=(HBox(children=(IntSlider(value=35, max=71), IntSlider(value=10, max=20))),))

Output()

In [56]:
from src.dataset.get_norm_transform import get_norm_transform
from src.dataset.transform_input import transform_input
from src.helpers.get_rescaled_pred import get_rescaled_pred

PARSE_CUT_DATASET = True
if PARSE_CUT_DATASET:
    extended_cut_full_res_dataset = cut_full_res_dataset.copy()
    
    # preparing cut dataset
    for index in range(len(extended_cut_full_res_dataset)):
        tmp_label = extended_cut_full_res_dataset.label_list[index]
        new_data_channels = len(extended_cut_full_res_dataset.output_label) + 1
        new_data_shape = (new_data_channels, *tmp_label[0].shape)
        new_data = np.zeros(new_data_shape, dtype=np.int16)
        new_data[0] = extended_cut_full_res_dataset.data_list[index][0]

        extended_cut_full_res_dataset.data_list[index] = new_data

    prediction_threshold = 0.5
    output_label_items = list(extended_cut_full_res_dataset.output_label.items())[:]
    # for each label
    for label_index, val in enumerate(output_label_items[:]):
        OAR_KEY, OAR_VALUE = val
        # loading model
        if OAR_KEY not in models:
            print(f'{label_index+1}/{len(output_label_items)}: {OAR_KEY} Model: No avaiable model')
            continue
        print(f'{label_index+1}/{len(output_label_items)}: {OAR_KEY} Model: got model {datetime.datetime.now()}')

        # getting model to gpu
        cut_model_info = models[OAR_KEY]
        cut_model_info['device'] = get_device()
        cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])
        cut_model_info['model'].eval()

        # for label in whole dataset
        for index in range(len(extended_cut_full_res_dataset)):
            prediction, rescaled_pred = get_rescaled_pred(cut_model_info['model'], 
                                                          cut_full_res_dataset, 
                                                          cut_model_info['device'], 
                                                          index,
                                                          use_only_one_dimension=True)
            
            extended_cut_full_res_dataset.data_list[index][label_index + 1] = ((rescaled_pred > prediction_threshold) * 1).astype(np.int8)

        # moving model back to cpu
        cut_model_info['device'] = 'cpu'
        cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])
        
    custom_preview_dataset(extended_cut_full_res_dataset)

1/21: BRAIN_STEM Model: No avaiable model
2/21: EYE_L Model: No avaiable model
3/21: EYE_R Model: No avaiable model
4/21: LENS_L Model: No avaiable model
5/21: LENS_R Model: No avaiable model
6/21: OPT_NERVE_L Model: got model 2021-03-15 16:06:59.917339
7/21: OPT_NERVE_R Model: got model 2021-03-15 16:07:19.936131
8/21: OPT_CHIASMA Model: got model 2021-03-15 16:07:39.733704
9/21: TEMPORAL_LOBES_L Model: No avaiable model
10/21: TEMPORAL_LOBES_R Model: No avaiable model
11/21: PITUITARY Model: got model 2021-03-15 16:07:59.351204
12/21: PAROTID_GLAND_L Model: got model 2021-03-15 16:08:19.769163
13/21: PAROTID_GLAND_R Model: got model 2021-03-15 16:08:39.509099
14/21: INNER_EAR_L Model: No avaiable model
15/21: INNER_EAR_R Model: No avaiable model
16/21: MID_EAR_L Model: No avaiable model
17/21: MID_EAR_R Model: No avaiable model
18/21: T_M_JOINT_L Model: No avaiable model
19/21: T_M_JOINT_R Model: No avaiable model
20/21: MANDIBLE_L Model: No avaiable model
21/21: MANDIBLE_R Model: No

VBox(children=(HBox(children=(IntSlider(value=35, max=71), IntSlider(value=10, max=20))),))

Output()

## Merging predictions

In [53]:
MERGE_PREDICTIONS = True
if MERGE_PREDICTIONS:
    merged_predictions = [None] * len(extended_cut_full_res_dataset)
    for index in range(len(extended_cut_full_res_dataset)):
        # print(f"{index+1}/{len(extended_cut_full_res_dataset)}: Merging predictions to single label")
        data, label = extended_cut_full_res_dataset.get_raw_item_with_label_filter(index)

        new_data = np.zeros(data[0].shape, dtype=np.int16)
        for i in range(1, 22):
            new_data += data[i]

        merged_predictions[index] = new_data
    print('Merging done')

    # checking how many masks are overlapping
    for i, tmp_merged in enumerate(merged_predictions):
        display(f'{i}, {np.where(tmp_merged == 1)[0].shape[0]}, {np.where(tmp_merged == 2)[0].shape[0]}, {np.where(tmp_merged == 3)[0].shape[0]}, {np.where(tmp_merged == 4)[0].shape[0]}')

Merging done


'0, 13120, 0, 0, 0'

'1, 14534, 0, 0, 0'

'2, 6077, 0, 0, 0'

'3, 11339, 3, 0, 0'

'4, 20124, 12, 0, 0'

'5, 15714, 0, 0, 0'

'6, 20554, 0, 0, 0'

'7, 7725, 0, 0, 0'

'8, 16483, 0, 0, 0'

'9, 11301, 0, 0, 0'

'10, 21334, 0, 0, 0'

'11, 17808, 0, 0, 0'

'12, 19614, 0, 0, 0'

'13, 15096, 0, 0, 0'

'14, 14050, 0, 0, 0'

'15, 11787, 0, 0, 0'

'16, 7616, 0, 0, 0'

'17, 20557, 0, 0, 0'

'18, 18881, 0, 0, 0'

'19, 10633, 17, 0, 0'

'20, 15394, 20, 0, 0'

'21, 19341, 0, 0, 0'

'22, 7701, 0, 0, 0'

'23, 9684, 3, 0, 0'

'24, 12251, 0, 0, 0'

'25, 17332, 10, 0, 0'

'26, 13218, 0, 0, 0'

'27, 14729, 0, 0, 0'

'28, 7625, 0, 0, 0'

'29, 14825, 0, 0, 0'

'30, 9987, 0, 0, 0'

'31, 8540, 0, 0, 0'

'32, 16243, 0, 0, 0'

'33, 11370, 0, 0, 0'

'34, 15393, 23, 0, 0'

'35, 20892, 0, 0, 0'

'36, 17766, 13, 0, 0'

'37, 9148, 10, 0, 0'

'38, 20507, 0, 0, 0'

'39, 8702, 0, 0, 0'

'40, 13085, 41, 0, 0'

'41, 13142, 0, 0, 0'

'42, 9131, 0, 0, 0'

'43, 7946, 0, 0, 0'

'44, 24287, 0, 0, 0'

'45, 7299, 0, 0, 0'

'46, 13661, 0, 0, 0'

'47, 9608, 0, 0, 0'

'48, 7284, 0, 0, 0'

'49, 7523, 0, 0, 0'

In [54]:
from src.helpers.calc_dsc import calc_dsc

def custom_preview_dataset2(dataset, preview_index=0, show_hist=False, use_transform=False):
    if use_transform:
        data, label = dataset[preview_index]
    else:
        data, label = dataset.get_raw_item_with_label_filter(preview_index)
        
    cut_data, cut_label = cut_full_res_dataset.get_raw_item_with_label_filter(preview_index)
    max_channels = label.shape[0]
    max_slices = label.shape[1]
    
    print(f'data max {data.max()}, min {data.min()}')
    print(f'label max {label.max()}, min {label.min()}')
    print(f'{data.shape}, {cut_data.shape}, {label.shape}, {cut_label.shape}')
    print(f'{data.dtype}, {cut_data.dtype}, {label.dtype}, {cut_label.dtype}')
    tmp_merged = merged_predictions[preview_index]
    print(f'{np.where(tmp_merged == 1)[0].shape},{np.where(tmp_merged == 2)[0].shape},{np.where(tmp_merged == 3)[0].shape},{np.where(tmp_merged == 4)[0].shape}')

    def f(slice_index, label_channel):
        print(f'{OARS_LABELS.OARS_LABELS_R_DICT[label_channel+1]}')
        tmp_tensor_label = torch.tensor(label[label_channel])
        tmp_tensor_prediciton = torch.tensor(data[label_channel+1])
        tmp_dsc = calc_dsc(tmp_tensor_label, tmp_tensor_prediciton)
        print(f'dsc {tmp_dsc}')

        plt.figure(figsize=(30, 20))

        plt.subplot(2, 3, 1).title.set_text('data')
        plt.imshow(cut_data[0, slice_index], cmap="gray")
        plt.subplot(2, 3, 2).title.set_text('label')
        plt.imshow(label[label_channel, slice_index])
        plt.subplot(2, 3, 3).title.set_text('prediciton')
        plt.imshow(data[label_channel+1, slice_index])
        # print(data.shape, np.sum(data[label_channel+1]), np.unique(data[1])[-1])
        print(f'slices with values > 0', (np.where(data[label_channel+1] > 0))[0])
        
        plt.subplot(2, 3, 4).title.set_text('merged prediction labels')
        plt.imshow(tmp_merged[slice_index], vmin=0, vmax=np.unique(tmp_merged)[-1])
        plt.subplot(2, 3, 5).title.set_text('merged labels ')
        plt.imshow(np.sum(label, axis=0)[slice_index])

        plt.show()

    sliceSlider = widgets.IntSlider(min=0, max=max_slices - 1, step=1, value=(max_slices - 1) / 2)
    labelChannelSlider = widgets.IntSlider(min=0, max=max_channels - 1, step=1, value=(max_channels - 1) / 2)
    ui = widgets.VBox([widgets.HBox([sliceSlider, labelChannelSlider])])
    out = widgets.interactive_output(f, {'slice_index': sliceSlider, 'label_channel': labelChannelSlider})
    # noinspection PyTypeChecker
    display(ui, out)

custom_preview_dataset2(extended_cut_full_res_dataset, preview_index=0)

data max 3071, min -1024
label max 1, min 0
(22, 72, 192, 168), (1, 72, 192, 168), (21, 72, 192, 168), (21, 72, 192, 168)
int16, int16, int8, int8
(13120,),(0,),(0,),(0,)


VBox(children=(HBox(children=(IntSlider(value=35, max=71), IntSlider(value=10, max=20))),))

Output()