In [1]:
# True when the IPython shell reports a Google Colab runtime.
IN_COLAB = 'google.colab' in str(get_ipython())
if IN_COLAB:
    # Mount Google Drive, where the project checkout is stored.
    from google.colab import drive
    drive.mount('/content/drive')
    import sys
    # Add the project code folder to the import path
    # (presumably so the `src.*` imports below resolve - confirm).
    sys.path.append('/content/drive/My Drive/dp_tomastik/code')
    # Run the project's library install script (its contents are not visible here).
    !bash "/content/drive/My Drive/dp_tomastik/code/scripts/install_libs.sh"

import SimpleITK as sitk
import matplotlib.pyplot as plt
import torch
import os
import numpy as np
import pandas as pd
import logging
import datetime

from torchio import RandomAffine, Compose, ZNormalization
from operator import itemgetter
from IPython.display import display, Markdown
from ipywidgets import widgets

from src.helpers import preview_dataset, get_threshold_info_df, preview_model_dataset_pred
from src.helpers import show_cuda_usage, get_rescaled_preds
from src.helpers import compare_prediction_with_ground_true, compare_one_prediction_with_ground_true
from src.helpers import get_img_outliers_pixels, get_raw_with_prediction

from src.dataset import get_cut_lists, get_full_res_cut, get_dataset
from src.dataset import get_dataset_info, get_dataset_transform
from src.dataset import split_dataset, copy_split_dataset, OARS_LABELS
from src.dataset import HaNOarsDataset

from src.model_and_training import prepare_model, train_loop
from src.model_and_training import load_checkpoint_model_info, show_model_info

from src.model_and_training.getters.get_device import get_device
from src.consts import DATASET_MAX_BOUNDING_BOX, DESIRE_BOUNDING_BOX_SIZE


# Seed PyTorch's RNG so repeated runs of the notebook are comparable.
torch.manual_seed(20)

# logging.basicConfig opens the log file immediately and raises
# FileNotFoundError when its parent directory is missing (e.g. on a fresh
# checkout or a new Colab session), so create the directory first.
LOG_FILE = 'logs/model2_all_organs_jupyter.log'
os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)
logging.basicConfig(filename=LOG_FILE, level=logging.DEBUG)

# Echo the size constants imported from src.consts.
print('Dataset biggest bounding box wihtout spinal cord', DATASET_MAX_BOUNDING_BOX)
print('Cut target size', DESIRE_BOUNDING_BOX_SIZE)
print('Done Init')

If you use TorchIO for your research, please cite the following paper:
Pérez-García et al., TorchIO: a Python library for efficient loading,
preprocessing, augmentation and patch-based sampling of medical images
in deep learning. Credits instructions: https://torchio.readthedocs.io/#credits

Dataset biggest bounding box wihtout spinal cord [56, 177, 156]
Cut target size [72, 192, 168]
Done Init


In [2]:
def get_possible_models(oar_key, models_dir='./models'):
    """Return the names of entries in ``models_dir`` that contain ``oar_key``.

    Args:
        oar_key: Substring that a folder name must contain to be selected.
        models_dir: Directory to scan. Defaults to ``'./models'``.

    Returns:
        Alphabetically sorted list of matching entry names. The list is empty
        when nothing matches or when ``models_dir`` does not exist.
    """
    if not os.path.isdir(models_dir):
        return []

    # os.listdir yields entries in arbitrary order; sorting makes the result
    # (and any caller that simply takes the first match) reproducible.
    return sorted(folder_name for folder_name in os.listdir(models_dir) if oar_key in folder_name)

# Training models for all organs

In [3]:
# Folder holding the pre-cut dataset.
data_path = './data/HaN_OAR_cut_72_192_168'

# Create the dataset with image loading disabled, then load it from the folder.
cut_dataset = HaNOarsDataset(data_path, size=50, load_images=False)
cut_dataset.load_from_file(data_path)

# 40 / 5 / 5 split; the individual subsets are pulled out of the result by key.
cut_dataset_obj = split_dataset(cut_dataset, train_size=40, valid_size=5, test_size=5)
cut_train_dataset = cut_dataset_obj['train_dataset']
cut_valid_dataset = cut_dataset_obj['valid_dataset']
cut_test_dataset = cut_dataset_obj['test_dataset']

In [4]:
# Report on the dataset and its train/valid/test split, then open the dataset
# preview. The recorded output of this cell shows split sizes and indices,
# data/label value ranges, and an ipywidgets slider box.
get_dataset_info(cut_dataset, cut_dataset_obj)
preview_dataset(cut_dataset)

train 40, valid_size 5, test 5, full 50
train indices [0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 14, 15, 17, 18, 20, 21, 22, 23, 24, 28, 30, 31, 32, 33, 34, 35, 36, 37, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
valid indices [6, 13, 19, 25, 38]
test indices [16, 26, 27, 29, 39]
data max 3071, min -1024
label max 22, min 0


VBox(children=(HBox(children=(IntSlider(value=35, max=71), IntSlider(value=0, max=0))),))

Output()

computing dataset

In [5]:
# Build a filtered copy instead of deleting from OARS_LABELS.OARS_LABELS_DICT:
# that dict is shared module-level state, and removing a key in place would
# silently change it for every other piece of code that reads it.
filter_labels = {
    oar_key: oar_value
    for oar_key, oar_value in OARS_LABELS.OARS_LABELS_DICT.items()
    if oar_key != 'SPINAL_CORD'
}

tmp_list = list(filter_labels.items())
# Restrict the run to a single organ; in the recorded run index 10 of the
# filtered labels printed "PITUITARY, 11".
labels_list = [tmp_list[10]]
for OAR_KEY, OAR_VALUE in labels_list:
    print(f"{OAR_KEY}, {OAR_VALUE}")

PITUITARY, 11


In [6]:
# Master switch: flip to True to train one model per entry of labels_list.
TRAIN_MODELS = False
if TRAIN_MODELS:
    for OAR_KEY, OAR_VALUE in labels_list:
        # Select which label value the dataset should output for this run.
        cut_dataset.set_output_label(OAR_VALUE)

        # The timestamp prefix keeps model folders of separate runs distinguishable.
        log_date = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        model_name = f'{log_date}_3d_unet_lowres_model2__cloud-{OAR_KEY}'

        print(f'Training model with dataset label \'{OAR_KEY}\', value \'{OAR_VALUE}\'')
        print(f'folder \'{model_name}\'')

        # Everything prepare_model needs, collected in one place.
        training_setup = dict(
            epochs=175,
            learning_rate=3e-4,
            in_channels=8,
            input_data_channels=1,
            output_label_channels=1,
            dropout_rate=0.2,
            train_batch_size=2,
            model_name=model_name,
            train_dataset=cut_train_dataset,
            valid_dataset=cut_valid_dataset,
            test_dataset=cut_test_dataset,
        )
        cut_model_info = prepare_model(**training_setup)

        show_model_info(cut_model_info)
        print('\n\n')
        train_loop(cut_model_info)
        print('\n\n')

        # Drop the model reference and release cached CUDA memory before the next organ.
        del cut_model_info
        torch.cuda.empty_cache()

# Preview organ model

# Loading all models

In [7]:
# Build a filtered copy instead of deleting from OARS_LABELS.OARS_LABELS_DICT:
# that dict is shared module-level state and must not be altered as a side
# effect of running this cell.
filter_labels = {
    oar_key: oar_value
    for oar_key, oar_value in OARS_LABELS.OARS_LABELS_DICT.items()
    if oar_key != 'SPINAL_CORD'
}

# Checkpoint epoch restored for every organ model.
epoch = 100

# Maps organ key -> loaded model info, for every organ that has a saved model.
models = dict()
for OAR_KEY, OAR_VALUE in filter_labels.items():
    possible_models = get_possible_models(f"model2__cloud-{OAR_KEY}")
    if not possible_models:
        print(f'{OAR_KEY} Model: No avaiable model')
        continue

    # When several folders match, the first one returned is used.
    model_name = possible_models[0]
    print(f'{OAR_KEY} Model: Loading model {model_name}')

    # loading model checkpoint
    cut_model_info = load_checkpoint_model_info(model_name, epoch, cut_train_dataset, cut_valid_dataset, cut_test_dataset)

    # keep the loaded model on the CPU in eval mode; it is moved to the
    # compute device only while it is being evaluated
    cut_model_info['device'] = 'cpu'
    cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])
    cut_model_info['model'].eval()

    models[OAR_KEY] = cut_model_info

BRAIN_STEM Model: No avaiable model
EYE_L Model: No avaiable model
EYE_R Model: No avaiable model
LENS_L Model: No avaiable model
LENS_R Model: No avaiable model
OPT_NERVE_L Model: No avaiable model
OPT_NERVE_R Model: No avaiable model
OPT_CHIASMA Model: No avaiable model
TEMPORAL_LOBES_L Model: No avaiable model
TEMPORAL_LOBES_R Model: No avaiable model
PITUITARY Model: Loading model 20210426-171235_3d_unet_lowres_model2__cloud-PITUITARY
PAROTID_GLAND_L Model: No avaiable model
PAROTID_GLAND_R Model: No avaiable model
INNER_EAR_L Model: No avaiable model
INNER_EAR_R Model: No avaiable model
MID_EAR_L Model: No avaiable model
MID_EAR_R Model: No avaiable model
T_M_JOINT_L Model: No avaiable model
T_M_JOINT_R Model: No avaiable model
MANDIBLE_L Model: No avaiable model
MANDIBLE_R Model: No avaiable model


# Calculating DSC for all models

In [8]:
# Switch for the (slow) per-organ DSC evaluation and its summary tables.
SHOW_DSC_INFO = True
if SHOW_DSC_INFO:
    # Column of the table returned by get_threshold_info_df that is reported as
    # "the" DSC (presumably the rescaled DSC at threshold 0.50 - confirm).
    best_threshold_col = 'thres_rescaled_dsc_0.50'

    info_per_organs_df = {}
    models_info = list()
    for OAR_KEY, OAR_VALUE in filter_labels.items():
        if OAR_KEY not in models:
            print(f'{OAR_KEY} Model: No avaiable model')
            continue

        # moving the model to the compute device for evaluation
        cut_model_info = models[OAR_KEY]
        cut_model_info['device'] = get_device()
        cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])
        cut_model_info['model'].eval()

        # preparing dataset for comparison
        cut_dataset.set_output_label(OAR_VALUE)

        # calculating dsc predictions
        info_df, preds, rescaled_preds = get_threshold_info_df(
            model=cut_model_info['model'],
            dataset=cut_dataset,
            device=cut_model_info['device'],
            train_indices=cut_train_dataset.indices,
            valid_indices=cut_valid_dataset.indices,
            test_indices=cut_test_dataset.indices,
            step=0.5)
        info_per_organs_df[OAR_KEY] = info_df

        # moving model back to cpu
        cut_model_info['device'] = 'cpu'
        cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])

        # per-split series of the reported DSC column
        train_tmp_df = info_df[info_df['is_train']][best_threshold_col]
        valid_tmp_df = info_df[info_df['is_valid']][best_threshold_col]
        test_tmp_df = info_df[info_df['is_test']][best_threshold_col]

        train_dsc = train_tmp_df.mean()
        valid_dsc = valid_tmp_df.mean()
        test_dsc = test_tmp_df.mean()
        print(f'{OAR_KEY} Model: DSC train {round(train_dsc, 4)} valid {round(valid_dsc, 4)}')

        # Resolve the folder name for THIS organ with the same lookup the
        # loading cell uses. Reading the leftover global `model_name` would
        # label every row with whichever model happened to be loaded last.
        oar_model_candidates = get_possible_models(f"model2__cloud-{OAR_KEY}")
        oar_model_name = oar_model_candidates[0] if oar_model_candidates else None

        # One summary row per organ; key insertion order fixes the column order.
        oar_summary = {'oar_key': OAR_KEY, 'model_name': oar_model_name}
        for split_name, split_dsc in (('train', train_tmp_df),
                                      ('valid', valid_tmp_df),
                                      ('test', test_tmp_df)):
            oar_summary[f'{split_name}_dsc_mean'] = split_dsc.mean()
            oar_summary[f'{split_name}_dsc_std'] = split_dsc.std()
            oar_summary[f'{split_name}_dsc_median'] = split_dsc.median()
            oar_summary[f'{split_name}_dsc_min'] = split_dsc.min()
            oar_summary[f'{split_name}_dsc_max'] = split_dsc.max()
        oar_summary['train_valid_mean_delta'] = train_dsc - valid_dsc
        models_info.append(oar_summary)

    models_info_df = pd.DataFrame(models_info)

    if models_info_df.empty:
        # Without any evaluated model the frame has no columns, and the column
        # selection below would raise a KeyError.
        print('No model was evaluated, nothing to summarise')
    else:
        percent_columns = ['train_dsc_mean', 'train_dsc_std',
                           'valid_dsc_mean', 'valid_dsc_std',
                           'test_dsc_mean', 'test_dsc_std']
        tmp_df = models_info_df[['model_name'] + percent_columns].copy()
        # express the DSC statistics in percent
        tmp_df[percent_columns] = (tmp_df[percent_columns] * 100).round(2)

        # numeric_only=True: the frame also holds the string column
        # 'model_name', which recent pandas versions refuse to average.
        display(tmp_df.mean(numeric_only=True).round(2))
        display(tmp_df.round(2))
        display(tmp_df.sort_values(by=['train_dsc_std']).round(2))
        display(models_info_df.sort_values(by=['train_dsc_mean']).drop(columns=['model_name']).round(2))
        display(models_info_df.sort_values(by=['train_valid_mean_delta']).drop(columns=['model_name']).round(2))

BRAIN_STEM Model: No avaiable model
EYE_L Model: No avaiable model
EYE_R Model: No avaiable model
LENS_L Model: No avaiable model
LENS_R Model: No avaiable model
OPT_NERVE_L Model: No avaiable model
OPT_NERVE_R Model: No avaiable model
OPT_CHIASMA Model: No avaiable model
TEMPORAL_LOBES_L Model: No avaiable model
TEMPORAL_LOBES_R Model: No avaiable model
PITUITARY Model: DSC train 0.6222 valid 0.4519
PAROTID_GLAND_L Model: No avaiable model
PAROTID_GLAND_R Model: No avaiable model
INNER_EAR_L Model: No avaiable model
INNER_EAR_R Model: No avaiable model
MID_EAR_L Model: No avaiable model
MID_EAR_R Model: No avaiable model
T_M_JOINT_L Model: No avaiable model
T_M_JOINT_R Model: No avaiable model
MANDIBLE_L Model: No avaiable model
MANDIBLE_R Model: No avaiable model


train_dsc_mean    62.22
train_dsc_std     20.47
valid_dsc_mean    45.19
valid_dsc_std     21.10
test_dsc_mean     33.40
test_dsc_std      22.32
dtype: float64

Unnamed: 0,model_name,train_dsc_mean,train_dsc_std,valid_dsc_mean,valid_dsc_std,test_dsc_mean,test_dsc_std
0,20210426-171235_3d_unet_lowres_model2__cloud-P...,62.22,20.47,45.19,21.1,33.4,22.32


Unnamed: 0,model_name,train_dsc_mean,train_dsc_std,valid_dsc_mean,valid_dsc_std,test_dsc_mean,test_dsc_std
0,20210426-171235_3d_unet_lowres_model2__cloud-P...,62.22,20.47,45.19,21.1,33.4,22.32


Unnamed: 0,oar_key,train_dsc_mean,train_dsc_std,train_dsc_median,train_dsc_min,train_dsc_max,valid_dsc_mean,valid_dsc_std,valid_dsc_median,valid_dsc_min,valid_dsc_max,test_dsc_mean,test_dsc_std,test_dsc_median,test_dsc_min,test_dsc_max,train_valid_mean_delta
0,PITUITARY,0.62,0.2,0.65,0.0,0.92,0.45,0.21,0.35,0.28,0.76,0.33,0.22,0.35,0.03,0.56,0.17


Unnamed: 0,oar_key,train_dsc_mean,train_dsc_std,train_dsc_median,train_dsc_min,train_dsc_max,valid_dsc_mean,valid_dsc_std,valid_dsc_median,valid_dsc_min,valid_dsc_max,test_dsc_mean,test_dsc_std,test_dsc_median,test_dsc_min,test_dsc_max,train_valid_mean_delta
0,PITUITARY,0.62,0.2,0.65,0.0,0.92,0.45,0.21,0.35,0.28,0.76,0.33,0.22,0.35,0.03,0.56,0.17


In [9]:
if SHOW_DSC_INFO:
    # Split-mask column name, kept for ad-hoc filtering of the table below,
    # e.g. tmp_df[tmp_df[tmp_column]].
    tmp_column = 'is_train'

    # Organ whose per-scan DSC table is shown. To inspect another organ, swap
    # the label constant (e.g. OARS_LABELS.PAROTID_GLAND_L, OARS_LABELS.OPT_NERVE_L).
    print('OARS_LABELS.PITUITARY')
    preview_oar_key = OARS_LABELS.OARS_LABELS_R_DICT[OARS_LABELS.PITUITARY]

    # info_per_organs_df only has entries for organs whose model was evaluated;
    # an unguarded lookup raises KeyError for any other organ.
    if preview_oar_key in info_per_organs_df:
        tmp_df = info_per_organs_df[preview_oar_key]
        display(tmp_df.sort_values(by='thres_rescaled_dsc_0.50'))
    else:
        print(f'No DSC info computed for {preview_oar_key}')

OARS_LABELS.PITUITARY


Unnamed: 0_level_0,dsc,rescaled_dsc,is_train,is_valid,is_test,thres_rescaled_dsc_0.00,thres_rescaled_dsc_0.50,thres_rescaled_dsc_1.00
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
15,8e-06,8e-06,True,False,False,2.6e-05,2.083333e-08,3.333333e-08
29,4.8e-05,0.002946,False,False,True,0.000177,0.02884616,4.878049e-09
48,0.117911,0.117911,True,False,False,1.1e-05,0.15,7.692307e-08
39,0.217418,0.217728,False,False,True,0.000127,0.2011173,6.756757e-09
12,0.023223,0.165507,True,False,False,4.6e-05,0.2295082,1.886792e-08
6,0.296809,0.296809,False,True,False,9.6e-05,0.2810811,8.928572e-09
38,0.297914,0.297914,False,True,False,4.3e-05,0.2941177,2e-08
19,0.361915,0.370673,False,True,False,9.9e-05,0.3453237,8.695652e-09
27,0.334105,0.334105,False,False,True,3.4e-05,0.3529412,2.5e-08
22,0.35081,0.35081,True,False,False,2.4e-05,0.3733334,3.571428e-08
