In [1]:
IN_COLAB = 'google.colab' in str(get_ipython())
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    import sys
    sys.path.append('/content/drive/My Drive/dp_tomastik/code')
    !bash "/content/drive/My Drive/dp_tomastik/code/scripts/install_libs.sh"
    
import SimpleITK as sitk
import matplotlib.pyplot as plt
import torch
import os
import numpy as np
import pandas as pd
import logging
import datetime
from torchio import RandomAffine, Compose, ZNormalization
from operator import itemgetter
from IPython.display import display, Markdown
from ipywidgets import widgets

from src.helpers import preview_3d_image
from src.helpers import show_cuda_usage, preview_model_dataset_pred, preview_dataset
from src.helpers import get_threshold_info_df, get_rescaled_preds
from src.helpers import compare_prediction_with_ground_true, compare_one_prediction_with_ground_true
from src.helpers import get_img_outliers_pixels, get_raw_with_prediction
from src.helpers import get_rescaled_pred
from src.helpers import get_transformed_label_np, create_regis_trans_list, trans_list

from src.dataset import HaNOarsDataset, transform_input_with_registration, get_norm_transform
from src.dataset import get_full_res_cut, get_cut_lists, OARS_LABELS, get_dataset, get_dataset_info, get_dataset_transform
from src.dataset import split_dataset, copy_split_dataset

from src.model_and_training import prepare_model, train_loop, show_model_info, load_checkpoint_model_info
from src.model_and_training import iterate_model_v3v2
from src.model_and_training.getters.get_device import get_device
from src.model_and_training.architectures.unet_architecture_v3v2 import UNetV3v2

from src.consts import DATASET_MAX_BOUNDING_BOX, DESIRE_BOUNDING_BOX_SIZE
  
torch.manual_seed(20)
logging.basicConfig(filename='logs/model3v2_all_organs_jupyter.log', level=logging.DEBUG)

print('Dataset biggest bounding box wihtout spinal cord', DATASET_MAX_BOUNDING_BOX)
print('Cut target size', DESIRE_BOUNDING_BOX_SIZE)
print('Done Init')

If you use TorchIO for your research, please cite the following paper:
Pérez-García et al., TorchIO: a Python library for efficient loading,
preprocessing, augmentation and patch-based sampling of medical images
in deep learning. Credits instructions: https://torchio.readthedocs.io/#credits

Dataset biggest bounding box wihtout spinal cord [56, 177, 156]
Cut target size [72, 192, 168]
Done Init


# Loading precourse neural network with datasets:
- loading fullres dataset (512x512)
- loading lowres dataset (32x32)
- loading precourse model
- parsing dataset to create cut dataset

In [2]:
CALCULATE_DATASET = False
data_path = './data/HaN_OAR_cut_72_192_168'

if CALCULATE_DATASET:
    datasets_params = ['train_dataset', 'valid_dataset', 'test_dataset']
    filter_labels = OARS_LABELS.OARS_LABELS_LIST
    if OARS_LABELS.SPINAL_CORD in filter_labels:
        filter_labels.remove(OARS_LABELS.SPINAL_CORD)

    # low res
    low_res_dataset = get_dataset(dataset_size=50, shrink_factor=16, filter_labels=filter_labels, unify_labels=True)
    low_res_dataset.dilatate_labels(repeat=1)
    low_res_dataset.to_numpy()
    low_res_split_dataset_obj = split_dataset(low_res_dataset, train_size=40, valid_size=5, test_size=5)
    train_low_res_dataset, valid_low_res_dataset, test_low_res_dataset = itemgetter(*datasets_params)(low_res_split_dataset_obj)

    # full res
    full_res_dataset = get_dataset(dataset_size=50, shrink_factor=1, filter_labels=filter_labels, unify_labels=False)
    full_res_dataset.to_numpy()
    full_res_split_dataset_obj = copy_split_dataset(full_res_dataset, low_res_split_dataset_obj)

    # low res model - precourse model
    epoch = 500
    log_date = datetime.datetime(year=2020, month=10, day=27, hour=11, minute=45, second=30).strftime("%Y%m%d-%H%M%S")
    model_name = f'{log_date}_3d_unet_PRECOURSE'

    low_res_model_info = load_checkpoint_model_info(model_name, epoch, train_low_res_dataset, valid_low_res_dataset, test_low_res_dataset)
    show_model_info(low_res_model_info)

    # moving low res to gpu
    low_res_model_info['device'] = get_device()
    # low_res_model_info['device'] = 'cpu'
    low_res_model_info['model'] = low_res_model_info['model'].to(low_res_model_info['device'])
    low_res_model_info['model'].eval()

    # cut res
    cut_full_res_dataset = full_res_dataset.copy(copy_lists=False)
    cut_full_res_dataset = get_cut_lists(low_res_model_info['model'],
                                         low_res_model_info['device'],
                                         low_res_dataset, 
                                         full_res_dataset, 
                                         cut_full_res_dataset, 
                                         low_res_mask_threshold=0.5)
    cut_full_res_dataset.set_output_label(None)
    cut_split_dataset_obj = copy_split_dataset(cut_full_res_dataset, low_res_split_dataset_obj)
    cut_train_dataset, cut_valid_dataset, cut_test_dataset = itemgetter(*datasets_params)(cut_split_dataset_obj)

    # moving low res model to cpu
    low_res_model_info['device'] = 'cpu'
    low_res_model_info['model'] = low_res_model_info['model'].to(low_res_model_info['device'])
    
    # saving parsed data
    cut_full_res_dataset.save_to_file(data_path)
else:
    print('loading cut dataset')
    cut_full_res_dataset = HaNOarsDataset(data_path, size=50, load_images=False)
    cut_full_res_dataset.load_from_file(data_path)

loading cut dataset


In [3]:
cut_full_res_dataset.data_list[0].shape, cut_full_res_dataset.label_list[0].shape

((1, 72, 192, 168), (1, 72, 192, 168))

In [4]:
try:
    get_dataset_info(low_res_dataset, low_res_split_dataset_obj)
except NameError:
    print('low_res_dataset not loaded')

low_res_dataset not loaded


In [None]:
import matplotlib.pyplot as plt

from IPython.display import display
from ipywidgets import widgets

from src.dataset.dataset_transforms import get_dataset_transform
from src.dataset.transform_input import transform_input


def preview_dataset(dataset, preview_index=0, show_hist=False, use_transform=False):
    data, label = dataset.get_raw_item_with_label_filter(preview_index)  # equivalent dataset[preview_index]
    if use_transform:
        transform = get_dataset_transform()
        data, label = transform_input(data, label, transform)

    max_channels = label.shape[0]
    max_slices = label.shape[1]

    print(f'data max {data.max()}, min {data.min()}')
    print(f'label max {label.max()}, min {label.min()}')

    def f(slice_index, label_channel):
        plt.figure(figsize=(20, 10))
        plt.subplot(1, 2, 1)
        plt.imshow(data[0, slice_index], cmap="gray")
        plt.subplot(1, 2, 2)
        plt.imshow(label[label_channel, slice_index])
        plt.show()

        if show_hist:
            plt.figure(figsize=(20, 10))
            plt.subplot(1, 2, 1)
            plt.hist(data.flatten(), 128)
            plt.subplot(1, 2, 2)
            plt.hist(label.flatten(), 128)
            plt.show()

    sliceSlider = widgets.IntSlider(min=0, max=max_slices - 1, step=1, value=(max_slices - 1) / 2)
    labelChannelSlider = widgets.IntSlider(min=0, max=max_channels - 1, step=1, value=(max_channels - 1) / 2)
    ui = widgets.VBox([widgets.HBox([sliceSlider, labelChannelSlider])])
    out = widgets.interactive_output(f, {'slice_index': sliceSlider, 'label_channel': labelChannelSlider})
    # noinspection PyTypeChecker
    display(ui, out)


In [5]:
preview_dataset(cut_full_res_dataset)

data max 3071, min -1024
label max 22, min 0


VBox(children=(HBox(children=(IntSlider(value=35, max=71), IntSlider(value=0, max=0))),))

Output()

## Adding registration to dataset

In [6]:
atlas_ri_sitk = sitk.ReadImage('./data/PDDCA-1.2-atlas/probabilistic_atlas/RI.mhd')
atlas_brainstem_map_sitk = sitk.ReadImage('./data/PDDCA-1.2-atlas/probabilistic_atlas/brain_stem_map.mhd')
atlas_left_parotid_map_sitk = sitk.ReadImage('./data/PDDCA-1.2-atlas/probabilistic_atlas/left_parotid_map.mhd')
atlas_right_parotid_map_sitk = sitk.ReadImage('./data/PDDCA-1.2-atlas/probabilistic_atlas/right_parotid_map.mhd')

atlas_ri = sitk.GetArrayFromImage(atlas_ri_sitk)
atlas_brainstem_map = sitk.GetArrayFromImage(atlas_brainstem_map_sitk)
atlas_left_parotid_map = sitk.GetArrayFromImage(atlas_left_parotid_map_sitk)
atlas_right_parotid_map = sitk.GetArrayFromImage(atlas_right_parotid_map_sitk)

print('Atlas shape', atlas_ri.shape, atlas_brainstem_map.shape, atlas_left_parotid_map.shape, atlas_right_parotid_map.shape)

Atlas shape (136, 120, 219) (136, 120, 219) (136, 120, 219) (136, 120, 219)


In [7]:
# resampling atlas to higher spacing in slices
atlas_resampler = sitk.ResampleImageFilter()
atlas_resampler.SetReferenceImage(atlas_ri_sitk)
atlas_resampler.SetInterpolator(sitk.sitkLinear)
new_spacing = atlas_ri_sitk.GetSpacing()
new_spacing = (new_spacing[0], new_spacing[1], 5.75)
atlas_resampler.SetOutputSpacing(new_spacing)

atlas_input_data_sitk = atlas_resampler.Execute(atlas_ri_sitk)
atlas_input_label_sitk = atlas_resampler.Execute(atlas_right_parotid_map_sitk)

# parsing to numpy
atlas_input_data_np = sitk.GetArrayFromImage(atlas_input_data_sitk)
atlas_input_label_np = sitk.GetArrayFromImage(atlas_input_label_sitk)
atlas_input_np = (atlas_input_data_np, atlas_input_label_np)

print('Atlas input', atlas_input_data_np.shape, atlas_input_label_np.shape)

Atlas input (136, 120, 219) (136, 120, 219)


showing registration example

In [8]:
cut_full_res_dataset.spacing_list[1]

(0.9765625, 0.9765625, 3.0)

In [9]:
print('Atlas and dataset spacing', atlas_ri_sitk.GetSpacing(), cut_full_res_dataset.spacing_list[0]) 

Atlas and dataset spacing (2.389269, 2.5155, 3.216912) (1.1796875, 1.1796875, 3.0)


In [10]:
preview_3d_image(atlas_ri, figsize=(5,5))
preview_3d_image(atlas_brainstem_map, figsize=(5,5))
preview_3d_image(atlas_left_parotid_map, figsize=(5,5))
preview_3d_image(atlas_right_parotid_map, figsize=(5,5))

VBox(children=(HBox(children=(IntSlider(value=67, max=135),)),))

Output()

VBox(children=(HBox(children=(IntSlider(value=67, max=135),)),))

Output()

VBox(children=(HBox(children=(IntSlider(value=67, max=135),)),))

Output()

VBox(children=(HBox(children=(IntSlider(value=67, max=135),)),))

Output()

In [11]:
cut_full_res_dataset[0][0].shape, cut_full_res_dataset[0][1].shape

((1, 72, 192, 168), (1, 72, 192, 168))

In [16]:
preview_3d_image(dataset_input[0][0])

VBox(children=(HBox(children=(IntSlider(value=35, max=71),)),))

Output()

In [12]:
dataset_input = cut_full_res_dataset.get_raw_item_with_label_filter(0)
tmp = get_transformed_label_np(dataset_input, atlas_input_np, numberOfIterations=2500, show=False, preview=True, figsize=(10, 10))

Optimizer stop condition: RegularStepGradientDescentOptimizerv4: Step too small after 211 iterations. Current step (6.10352e-05) is less than minimum step (0.0001).
  Iteration: 212
  Metric value: -0.7815303303309206


VBox(children=(HBox(children=(IntSlider(value=35, max=71),)),))

Output()

In [13]:
preview_3d_image(atlas_brainstem_map_sitk + atlas_left_parotid_map_sitk + atlas_right_parotid_map_sitk)

VBox(children=(HBox(children=(IntSlider(value=67, max=135),)),))

Output()

registering whole dataset to probabilistic atlas

In [19]:
CALC_REGISTRATION = False
if CALC_REGISTRATION:
    regis_trans_list = create_regis_trans_list(cut_full_res_dataset, atlas_input_data_np, numberOfIterations=3000)

Optimizer stop condition: RegularStepGradientDescentOptimizerv4: Step too small after 211 iterations. Current step (6.10352e-05) is less than minimum step (0.0001).
  Iteration: 212
  Metric value: -0.7815303303309206
Registration done for index: 0
Optimizer stop condition: RegularStepGradientDescentOptimizerv4: Gradient magnitude tolerance met after 281 iterations. Gradient magnitude (9.97276e-07) is less than gradient magnitude tolerance (1e-06).
  Iteration: 282
  Metric value: -0.6014645117096772
Registration done for index: 1
Optimizer stop condition: RegularStepGradientDescentOptimizerv4: Step too small after 2058 iterations. Current step (6.10352e-05) is less than minimum step (0.0001).
  Iteration: 2059
  Metric value: -0.6610746228301189
Registration done for index: 2
Optimizer stop condition: RegularStepGradientDescentOptimizerv4: Gradient magnitude tolerance met after 161 iterations. Gradient magnitude (9.94478e-07) is less than gradient magnitude tolerance (1e-06).
  Iterat

In [20]:
if CALC_REGISTRATION:
    cut_full_res_dataset = HaNOarsDataset(data_path, size=50, load_images=False)
    cut_full_res_dataset.load_from_file('./data/HaN_OAR_cut_72_192_168')

    # atlas_brainstem_map_sitk, atlas_left_parotid_map_sitk, atlas_right_parotid_map_sitk
    atlas_input_data_sitk = atlas_resampler.Execute(atlas_ri_sitk)
    atlas_input_labels = list()
    atlas_input_labels.append(('brainstem', atlas_brainstem_map_sitk))
    atlas_input_labels.append(('left_parotid', atlas_left_parotid_map_sitk))
    atlas_input_labels.append(('right_parotid', atlas_right_parotid_map_sitk))
    atlas_input_labels.append(('parotids', atlas_left_parotid_map_sitk + atlas_right_parotid_map_sitk))
    atlas_input_labels.append(('all_maps', atlas_brainstem_map_sitk + atlas_left_parotid_map_sitk + atlas_right_parotid_map_sitk))
    
    for atlas_name, atlas_label_map in atlas_input_labels:
        print(f'Calculating {atlas_name}')
        atlas_input_label_sitk = atlas_resampler.Execute(atlas_label_map)

        # parsing to numpy
        atlas_input_data_np = sitk.GetArrayFromImage(atlas_input_data_sitk)
        atlas_input_label_np = sitk.GetArrayFromImage(atlas_input_label_sitk)
        atlas_input_np = (atlas_input_data_np, atlas_input_label_np)
        print('Atlas input', atlas_input_data_np.shape, atlas_input_label_np.shape)

        # registration_list = create_regis_list(cut_full_res_dataset, atlas_input_np, numberOfIterations=3000)
        atlas_input_np = (atlas_input_data_np, atlas_input_label_np)
        registration_list = trans_list(cut_full_res_dataset, atlas_input_np, regis_trans_list)

        cut_full_res_dataset.data_list = registration_list
        cut_full_res_dataset.save_to_file(f'./data/HaN_OAR_cut_{atlas_name}_reg')

        print(f'cut_full_res_dataset.is_numpy, {cut_full_res_dataset.output_label}')

Calculating brainstem
Atlas input (136, 120, 219) (136, 120, 219)
Transform done for index: 0
Transform done for index: 1
Transform done for index: 2
Transform done for index: 3
Transform done for index: 4
Transform done for index: 5
Transform done for index: 6
Transform done for index: 7
Transform done for index: 8
Transform done for index: 9
Transform done for index: 10
Transform done for index: 11
Transform done for index: 12
Transform done for index: 13
Transform done for index: 14
Transform done for index: 15
Transform done for index: 16
Transform done for index: 17
Transform done for index: 18
Transform done for index: 19
Transform done for index: 20
Transform done for index: 21
Transform done for index: 22
Transform done for index: 23
Transform done for index: 24
Transform done for index: 25
Transform done for index: 26
Transform done for index: 27
Transform done for index: 28
Transform done for index: 29
Transform done for index: 30
Transform done for index: 31
Transform done f

In [21]:
preview_3d_image(registration_list[0][0], figsize=(10, 10))
preview_3d_image(registration_list[0][1], figsize=(10, 10))

preview_3d_image(cut_full_res_dataset.label_list[1][0] == OARS_LABELS.PAROTID_GLAND_R, figsize=(10, 10))

VBox(children=(HBox(children=(IntSlider(value=35, max=71),)),))

Output()

VBox(children=(HBox(children=(IntSlider(value=35, max=71),)),))

Output()

VBox(children=(HBox(children=(IntSlider(value=35, max=71),)),))

Output()

In [22]:
print('loading cut dataset')
data_path = './data/HaN_OAR_cut_brainstem_reg'
l_parotid_cut_full_res_dataset = HaNOarsDataset(data_path, size=50, load_images=False)
l_parotid_cut_full_res_dataset.load_from_file(data_path)

loading cut dataset


In [None]:
# preview_3d_image(l_parotid_cut_full_res_dataset[0][0][0], figsize=(5, 5))
# preview_3d_image(l_parotid_cut_full_res_dataset[0][0][1], figsize=(5, 5))
# preview_3d_image(l_parotid_cut_full_res_dataset[0][1], figsize=(5, 5))

# Training all organs models

checking list used for training single models

training each model

In [None]:
filter_labels = OARS_LABELS.OARS_LABELS_DICT
if 'SPINAL_CORD' in filter_labels:
    del filter_labels['SPINAL_CORD']

tmp_list = list(filter_labels.items())
labels_list = [tmp_list[10]]
for OAR_KEY, OAR_VALUE in labels_list:
    print(f"{OAR_KEY}, {OAR_VALUE}")

In [None]:
TRAIN_MODELS = False
if TRAIN_MODELS:
    for OAR_KEY, OAR_VALUE in labels_list:
        cut_full_res_dataset.set_output_label(OAR_VALUE)
        log_date = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        model_name = f'{log_date}_3d_unet_lowres_model3v2_{OAR_KEY}'

        print(f'Training model with dataset label \'{OAR_KEY}\', value \'{OAR_VALUE}\'')
        print(f'folder \'{model_name}\'')
        cut_model_info = prepare_model(epochs=75,
                                       learning_rate=3e-4,
                                       in_channels=8,
                                       input_data_channels=1,
                                       output_label_channels=1,
                                       dropout_rate=0.2,
                                       train_batch_size=1,
                                       model_name=model_name,
                                       train_dataset=cut_train_dataset, 
                                       valid_dataset=cut_valid_dataset, 
                                       test_dataset=cut_test_dataset,
                                       model_class=UNetV3v2)
        show_model_info(cut_model_info)
        print('\n\n')
        train_loop(cut_model_info, iterate_model_fn=iterate_model_v3v2)
        print('\n\n')

        # clearing memory
        torch.cuda.empty_cache()

# Evaluating Model

In [None]:
filter_labels = OARS_LABELS.OARS_LABELS_DICT
if 'SPINAL_CORD' in filter_labels:
    del filter_labels['SPINAL_CORD']

tmp_list = list(filter_labels.items())
labels_list = [tmp_list[5], tmp_list[6], tmp_list[7], tmp_list[10], tmp_list[11], tmp_list[12]]
for OAR_KEY, OAR_VALUE in labels_list:
    print(f"{OAR_KEY}, {OAR_VALUE}")

In [None]:
def get_possible_models(oar_key):
    possible_models = [folder_name for folder_name in os.listdir('./models') if oar_key in folder_name]    
    
    return possible_models

### loading models to CPU 

In [None]:
models = dict()
for OAR_KEY, OAR_VALUE in labels_list:
    epoch = 75
    possible_models = get_possible_models(f"model3v2_{OAR_KEY}")
    if len(possible_models) <= 0:
        print(f'{OAR_KEY} Model: No avaiable model')
        continue

    model_name = possible_models[0]
    print(f'{OAR_KEY} Model: Loading model {model_name}')

    # loading model checkpoint
    cut_model_info = load_checkpoint_model_info(model_name, epoch, cut_train_dataset, cut_valid_dataset, cut_test_dataset, model_class=UNetV3v2)

    # moving model to cpu/cuda with eval mode
    cut_model_info['device'] = 'cpu'
    cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])
    cut_model_info['model'].eval()
    cut_model_info['model'].disable_tensorboard_writing = True
    
    models[OAR_KEY] = cut_model_info

## Testing Eval vs Train Mode

testing iteration function

In [None]:
cut_full_res_dataset.set_output_label(OARS_LABELS.PITUITARY)
cut_model_info = models['PITUITARY']
cut_model_info['device'] = get_device()
cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])
cut_model_info['model'].disable_tensorboard_writing = True
    
model, model_name, optimizer, criterion = itemgetter('model', 'model_name', 'optimizer', 'criterion')(cut_model_info)
epochs, device, tensorboard_writer = itemgetter('epochs', 'device', 'tensorboard_writer')(cut_model_info)
train_dataloader, valid_dataloader, test_dataloader = itemgetter('train_dataloader',
                                                                 'valid_dataloader',
                                                                 'test_dataloader')(cut_model_info)
model.actual_epoch = 100

valid_loss, valid_dsc = iterate_model_v3v2(valid_dataloader, model, optimizer, criterion, device, is_eval=True)
print(valid_loss, valid_dsc)

cut_model_info['model'].disable_tensorboard_writing = True
cut_model_info['device'] = 'cpu'
cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])


In [None]:
SHOW_DSC_INFO = True
if SHOW_DSC_INFO:
    info_per_organs_df = {}
    models_info = list()
    for OAR_KEY, OAR_VALUE in labels_list:
        if OAR_KEY not in models:
            print(f'{OAR_KEY} Model: No avaiable model')
            continue

        # getting model to gpu
        cut_model_info = models[OAR_KEY]
        cut_model_info['device'] = get_device()
        cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])
        cut_model_info['model'].eval()
        cut_model_info['model'].disable_tensorboard_writing = True

        # preparing dataset for comparison
        cut_full_res_dataset.set_output_label(OAR_VALUE)

        # calculating dsc predictions        
        info_df, preds, rescaled_preds = get_threshold_info_df(
                                    model=cut_model_info['model'], 
                                    dataset=cut_full_res_dataset, 
                                    device=cut_model_info['device'], 
                                    train_indices=cut_train_dataset.indices, 
                                    valid_indices=cut_valid_dataset.indices, 
                                    test_indices=cut_test_dataset.indices,
                                    step=0.5,
                                    transform_input_fn=transform_input_with_registration)
        info_per_organs_df[OAR_KEY] = info_df

        # moving model back to cpu
        cut_model_info['device'] = 'cpu'
        cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])

        # parsing data
        best_threshold_col = 'thres_rescaled_dsc_0.50'
        train_tmp_df = info_df[info_df['is_train']][best_threshold_col]
        valid_tmp_df = info_df[info_df['is_valid']][best_threshold_col]
        train_dsc = train_tmp_df.mean()
        valid_dsc = valid_tmp_df.mean()
        print(f'{OAR_KEY} Model: DSC train {round(train_dsc, 4)} valid {round(valid_dsc, 4)}')

        models_info.append({
            'oar_key': OAR_KEY,
            'model_name': model_name,
            # Train
            'train_dsc_mean': train_dsc,
            'train_dsc_std': train_tmp_df.std(),
            'train_dsc_median': train_tmp_df.median(),
            'train_dsc_min': train_tmp_df.min(),
            'train_dsc_max': train_tmp_df.max(),
            # Valid
            'valid_dsc_mean': valid_dsc,
            'valid_dsc_std': valid_tmp_df.std(),
            'valid_dsc_median': valid_tmp_df.median(),
            'valid_dsc_min': valid_tmp_df.min(),
            'valid_dsc_max': valid_tmp_df.max(),
            # Both
            'train_valid_mean_delta': train_dsc - valid_dsc
        })

    models_info_df = pd.DataFrame(models_info)
    
    tmp_df = models_info_df[['oar_key', 'train_dsc_mean', 'train_dsc_std', 'valid_dsc_mean', 'valid_dsc_std']].copy()
    tmp_df['train_dsc_mean'] = (tmp_df['train_dsc_mean'] * 100).round(2)
    tmp_df['valid_dsc_mean'] = (tmp_df['valid_dsc_mean'] * 100).round(2)
    tmp_df['train_dsc_std'] = (tmp_df['train_dsc_std'] * 100).round(2)
    tmp_df['valid_dsc_std'] = (tmp_df['valid_dsc_std'] * 100).round(2)
    
    display(tmp_df.mean().round(2))
    display(tmp_df.round(2))
    display(tmp_df.sort_values(by=['train_dsc_std']).round(2))
    display(models_info_df.sort_values(by=['train_dsc_mean']).drop(columns=['model_name']).round(2))
    display(models_info_df.sort_values(by=['train_valid_mean_delta']).drop(columns=['model_name']).round(2))

In [None]:
if SHOW_DSC_INFO:
    tmp_column = 'is_train'
    
    try:
        print('OARS_LABELS.PAROTID_GLAND_R')
        tmp_df = info_per_organs_df[OARS_LABELS.OARS_LABELS_R_DICT[OARS_LABELS.PAROTID_GLAND_R]]
        display(tmp_df[tmp_df[tmp_column]].sort_values(by='thres_rescaled_dsc_0.50'))
    except:
        pass

    try:   
        print('OARS_LABELS.PAROTID_GLAND_L')
        tmp_df = info_per_organs_df[OARS_LABELS.OARS_LABELS_R_DICT[OARS_LABELS.PAROTID_GLAND_L]]
        display(tmp_df[tmp_df[tmp_column]].sort_values(by='thres_rescaled_dsc_0.50'))
    except:
        pass
        
    try: 
        print('OARS_LABELS.OPT_NERVE_L')
        tmp_df = info_per_organs_df[OARS_LABELS.OARS_LABELS_R_DICT[OARS_LABELS.OPT_NERVE_L]]
        display(tmp_df[tmp_df[tmp_column]].sort_values(by='thres_rescaled_dsc_0.50'))
    except:
        pass
        
    try: 
        print('OARS_LABELS.PITUITARY')
        tmp_df = info_per_organs_df[OARS_LABELS.OARS_LABELS_R_DICT[OARS_LABELS.PITUITARY]]
        display(tmp_df[tmp_df[tmp_column]].sort_values(by='thres_rescaled_dsc_0.50'))
    except:
        pass

# Predictions merging and checking

In [None]:
filter_labels_dict = OARS_LABELS.OARS_LABELS_DICT
if 'SPINAL_CORD' in filter_labels:
    del filter_labels_dict['SPINAL_CORD']

cut_full_res_dataset.set_output_label(filter_labels_dict)
preview_dataset(cut_full_res_dataset, preview_index=35)

In [None]:
PARSE_CUT_DATASET = True
if PARSE_CUT_DATASET:
    from collections import defaultdict
    
    prediction_threshold = 0.5
    output_label_items = list(cut_full_res_dataset.output_label.items())[:]
    cut_dataset_predictions = defaultdict(lambda: defaultdict(lambda: np.zeros(cut_full_res_dataset[0][0][0].shape)))
    
    # for each label
    for label_index, val in enumerate(output_label_items[:]):
        OAR_KEY, OAR_VALUE = val
        # loading model
        if OAR_KEY not in models:
            print(f'{label_index+1}/{len(output_label_items)}: {OAR_KEY} Model: No avaiable model')
            continue
        print(f'{label_index+1}/{len(output_label_items)}: {OAR_KEY} Model: got model {datetime.datetime.now()}')

        # getting model to gpu
        cut_model_info = models[OAR_KEY]
        cut_model_info['device'] = get_device()
        cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])
        cut_model_info['model'].eval()
        cut_model_info['model'].disable_tensorboard_writing = True

        # for label in whole dataset
        for index in range(len(cut_full_res_dataset)):
            prediction, rescaled_pred = get_rescaled_pred(cut_model_info['model'], cut_full_res_dataset, 
                                                          cut_model_info['device'], index, use_only_one_dimension=False)
    
            cut_dataset_predictions[index][OAR_VALUE] = prediction[0]
            # extended_cut_full_res_dataset.data_list[index][label_index + 1] = prediction
            # extended_cut_full_res_dataset.data_list[index][label_index + 1] = ((rescaled_pred > prediction_threshold) * 1).astype(np.int8)

        # moving model back to cpu
        cut_model_info['device'] = 'cpu'
        cut_model_info['model'] = cut_model_info['model'].to(cut_model_info['device'])

In [None]:
if PARSE_CUT_DATASET:
    def custom_preview_dataset(dataset, predictions, preview_index=0, show_hist=False, use_transform=False):
        data, label = dataset.get_raw_item_with_label_filter(preview_index)  # equivalent dataset[preview_index]
        if use_transform:
            transform = get_dataset_transform()
            data, label = transform_input(data, label, transform)

        prediction = predictions[preview_index]
        max_channels = label.shape[0]
        max_slices = label.shape[1]

        print(f'data max {data.max()}, min {data.min()}')
        print(f'label max {label.max()}, min {label.min()}')

        def f(slice_index, label_channel):
            plt.figure(figsize=(20, 10))
            plt.subplot(1, 3, 1)
            plt.imshow(data[0, slice_index], cmap="gray")
            plt.subplot(1, 3, 2)
            plt.imshow(prediction[label_channel+1][slice_index])
            plt.subplot(1, 3, 3)
            plt.imshow(label[label_channel, slice_index])
            plt.show()

            if show_hist:
                plt.figure(figsize=(20, 10))
                plt.subplot(1, 2, 1)
                plt.hist(data.flatten(), 128)
                plt.subplot(1, 2, 2)
                plt.hist(label.flatten(), 128)
                plt.show()

        sliceSlider = widgets.IntSlider(min=0, max=max_slices - 1, step=1, value=(max_slices - 1) / 2)
        labelChannelSlider = widgets.IntSlider(min=0, max=max_channels - 1, step=1, value=(max_channels - 1) / 2)
        ui = widgets.VBox([widgets.HBox([sliceSlider, labelChannelSlider])])
        out = widgets.interactive_output(f, {'slice_index': sliceSlider, 'label_channel': labelChannelSlider})
        # noinspection PyTypeChecker
        display(ui, out)
    
    index = cut_valid_dataset.indices[3]
    index = 35
    custom_preview_dataset(cut_full_res_dataset, cut_dataset_predictions, preview_index=index)

## Merging predictions

# TODO: refactor

MERGE_PREDICTIONS = False
if MERGE_PREDICTIONS:
    merged_predictions = [None] * len(extended_cut_full_res_dataset)
    for index in range(len(extended_cut_full_res_dataset)):
        # print(f"{index+1}/{len(extended_cut_full_res_dataset)}: Merging predictions to single label")
        data, label = extended_cut_full_res_dataset.get_raw_item_with_label_filter(index)

        new_data = np.zeros(data[0].shape, dtype=np.int16)
        for i in range(1, 22):
            new_data += data[i]

        merged_predictions[index] = new_data
    print('Merging done')

    # checking how many masks are overlapping
    for i, tmp_merged in enumerate(merged_predictions):
        display(f'scan id {i}: {np.where(tmp_merged == 1)[0].shape[0]}, {np.where(tmp_merged == 2)[0].shape[0]}, {np.where(tmp_merged == 3)[0].shape[0]}, {np.where(tmp_merged == 4)[0].shape[0]}')

# TODO: refactor

from src.losses import calc_dscm

def custom_preview_dataset2(dataset, preview_index=0, use_transform=False):
    data, label = dataset.get_raw_item_with_label_filter(preview_index)
    if use_transform:
        transform = get_dataset_transform()
        data, label = transform_input(data, label, transform)
        
    cut_data, cut_label = cut_full_res_dataset.get_raw_item_with_label_filter(preview_index)
    max_channels = label.shape[0]
    max_slices = label.shape[1]
    
    print(f'data max {data.max()}, min {data.min()}')
    print(f'label max {label.max()}, min {label.min()}')
    print(f'{data.shape}, {cut_data.shape}, {label.shape}, {cut_label.shape}')
    print(f'{data.dtype}, {cut_data.dtype}, {label.dtype}, {cut_label.dtype}')
    tmp_merged = merged_predictions[preview_index]
    print(f'{np.where(tmp_merged == 1)[0].shape},{np.where(tmp_merged == 2)[0].shape},{np.where(tmp_merged == 3)[0].shape},{np.where(tmp_merged == 4)[0].shape}')

    def f(slice_index, label_channel):
        print(f'{OARS_LABELS.OARS_LABELS_R_DICT[label_channel+1]}')
        tmp_tensor_label = torch.tensor(label[label_channel])
        tmp_tensor_prediciton = torch.tensor(data[label_channel+1])
        tmp_dsc = calc_dsc(tmp_tensor_label, tmp_tensor_prediciton)
        print(f'dsc {tmp_dsc}')

        plt.figure(figsize=(30, 20))

        plt.subplot(2, 3, 1).title.set_text('data')
        plt.imshow(cut_data[0, slice_index], cmap="gray")
        plt.subplot(2, 3, 2).title.set_text('label')
        plt.imshow(label[label_channel, slice_index])
        plt.subplot(2, 3, 3).title.set_text('prediciton')
        plt.imshow(data[label_channel+1, slice_index])
        # print(data.shape, np.sum(data[label_channel+1]), np.unique(data[1])[-1])
        print(f'slices with values > 0', (np.where(data[label_channel+1] > 0))[0])
        
        plt.subplot(2, 3, 4).title.set_text('merged prediction labels')
        plt.imshow(tmp_merged[slice_index], vmin=0, vmax=np.unique(tmp_merged)[-1])
        plt.subplot(2, 3, 5).title.set_text('merged labels ')
        plt.imshow(np.sum(label, axis=0)[slice_index])

        plt.show()

    sliceSlider = widgets.IntSlider(min=0, max=max_slices - 1, step=1, value=(max_slices - 1) / 2)
    labelChannelSlider = widgets.IntSlider(min=0, max=max_channels - 1, step=1, value=(max_channels - 1) / 2)
    ui = widgets.VBox([widgets.HBox([sliceSlider, labelChannelSlider])])
    out = widgets.interactive_output(f, {'slice_index': sliceSlider, 'label_channel': labelChannelSlider})
    # noinspection PyTypeChecker
    display(ui, out)

custom_preview_dataset2(extended_cut_full_res_dataset, preview_index=35, use_transform=True)