In [1]:
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print('Found Google Colab')
    !pip3 install torch torchvision torchsummary
    !pip3 install simpleitk

    # noinspection PyUnresolvedReferences
    from google.colab import drive
    drive.mount('/content/drive')

import numpy as np
import matplotlib.pyplot as plt
import torch

import src.helpers.oars_labels_consts as OARS_LABELS

from operator import itemgetter
from IPython.display import display, Markdown
from ipywidgets import widgets

from src.training_helpers import loss_batch, show_model_info
from src.helpers.prepare_model import prepare_model
from src.helpers.train_loop import train_loop
from src.helpers.get_dataset import get_dataset, get_dataloaders
from src.helpers.get_dataset_info import get_dataset_info
from src.helpers.preview_dataset import preview_dataset


# Fix PyTorch's global RNG seed so repeated runs start from the same state.
torch.manual_seed(20)

# Slice count shared by the notebook: passed to preview_dataset and show_model_info,
# and used as the upper bound (minus one) of the slice slider in the evaluation cell.
MAX_PADDING_SLICES = 160

print('Done Init')

Done Init


# Neural Network

## dataset

**Labels:** the dataset annotates 22 organs at risk (OARs): Brain_Stem, Eye_L, Eye_R, Lens_L, Lens_R, Opt_Nerve_L, Opt_Nerve_R, Opt_Chiasma, Temporal_Lobes_L, Temporal_Lobes_R, Pituitary, Parotid_Gland_L, Parotid_Gland_R, Inner_Ear_L, Inner_Ear_R, Mid_Ear_L, Mid_Ear_R, T_M_Joint_L, T_M_Joint_R, Spinal_Cord, Mandible_L, Mandible_R. These correspond to labels 1 to 22 in the annotation file.

In [2]:
# Labels handed to get_dataset as `filter_labels`: both eyes and both lenses.
filter_labels = [
    OARS_LABELS.EYE_L,
    OARS_LABELS.EYE_R,
    OARS_LABELS.LENS_L,
    OARS_LABELS.LENS_R,
]
dataset = get_dataset(filter_labels=filter_labels, shrink_factor=8)

# Return value is unused -- presumably converts the dataset in place; TODO confirm.
dataset.to_numpy()
dataloaders_obj = get_dataloaders(dataset)

get_dataset_info(dataset, dataloaders_obj)

# Pull the three splits out of the loader bundle in one step, so that nothing is
# assigned if any key is missing.
train_dataset, valid_dataset, test_dataset = (
    dataloaders_obj[split_key]
    for split_key in ('train_dataset', 'valid_dataset', 'test_dataset')
)

CUDA using 8x dataset
train 40, valid_size 5, test 5, full 50
train indeces [0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 14, 15, 17, 18, 20, 21, 22, 23, 24, 28, 30, 31, 32, 33, 34, 35, 36, 37, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
valid indeces [6, 13, 19, 25, 38]
test indeces [16, 26, 27, 29, 39]


In [3]:
# Preview sample 2 of the dataset without the histogram.
preview_dataset(
    dataset,
    preview_index=2,
    show_hist=False,
    max_padding_slices=MAX_PADDING_SLICES,
)

data max 11.757370914300283, min -0.4249158796367933
label max 1, min 0


VBox(children=(HBox(children=(IntSlider(value=101, max=159),)),))

Output()

## training

### params and architecture preparation

In [4]:
model_info = prepare_model(
    in_channels=16,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
    test_dataset=test_dataset,
)
show_model_info(model_info, MAX_PADDING_SLICES)

# train_loop receives every model_info entry except the two parameter-count entries.
train_loop_params = {
    name: value
    for name, value in model_info.items()
    if name not in ('model_total_params', 'model_total_trainable_params')
}
train_loop(**train_loop_params)

Device running "cuda"
max output channels 256
Model number of params: 4770177, trainable 4770177
Running training loop
Batch eval [1] loss 0.99936, dsc 0.00064
Batch eval [2] loss 0.99957, dsc 0.00043
Batch eval [3] loss 0.99937, dsc 0.00063
Batch eval [4] loss 0.99920, dsc 0.00080
Batch eval [5] loss 0.99889, dsc 0.00111
Epoch [1] T 34.87s, deltaT 34.87s, loss: train 0.99940, valid 0.99928, dsc: train 0.00060, valid 0.00072
Batch eval [1] loss 0.99915, dsc 0.00085
Batch eval [2] loss 0.99944, dsc 0.00056
Batch eval [3] loss 0.99920, dsc 0.00080
Batch eval [4] loss 0.99903, dsc 0.00097
Batch eval [5] loss 0.99863, dsc 0.00137
Epoch [2] T 70.49s, deltaT 35.62s, loss: train 0.99922, valid 0.99909, dsc: train 0.00078, valid 0.00091
Batch eval [1] loss 0.99899, dsc 0.00101
Batch eval [2] loss 0.99930, dsc 0.00070
Batch eval [3] loss 0.99901, dsc 0.00099
Batch eval [4] loss 0.99875, dsc 0.00125
Batch eval [5] loss 0.99827, dsc 0.00173
Epoch [3] T 106.87s, deltaT 36.38s, loss: train 0.99902,

RuntimeError: can't export a trace that didn't finish running

## testing and evaluating

In [None]:
# Unpack the objects the evaluation helper below reads as globals.
model, device, optimizer, criterion = (
    model_info[key] for key in ('model', 'device', 'optimizer', 'criterion')
)


def eval_image_dataset(dataset, def_slice=90, figfile=None, item_index=0):
    """Run the model on one sample of ``dataset`` and show the result interactively.

    The function relies on the notebook globals ``model``, ``device``,
    ``optimizer``, ``criterion`` and ``MAX_PADDING_SLICES``. It prints the loss
    and Dice score reported by ``loss_batch``, plots a histogram of the
    prediction values, displays a slider-driven input / label / prediction
    viewer, and finally prints a set of debug statistics.

    Args:
        dataset: Object exposing ``indices`` and integer indexing that yields an
            ``(inputs, labels)`` pair convertible with ``np.array``. The
            indexing below treats inputs and predictions as
            (batch, channel, slice, ...) and labels as (batch, slice, ...).
        def_slice: Initial position of the slice slider.
        figfile: Optional path; when given, the viewer figure is saved there
            every time it is redrawn (so the file reflects the last slice shown).
        item_index: Position of the sample inside ``dataset`` to evaluate.
            Defaults to 0, the sample the function always used before.

    Returns:
        None. All results are printed or displayed.
    """
    model.eval()
    torch.cuda.empty_cache()

    slice_slider = widgets.IntSlider(min=0, max=MAX_PADDING_SLICES - 1, step=1, value=def_slice)
    ui = widgets.VBox([
        widgets.HBox([slice_slider])
    ])

    print(f'showing number {dataset.indices[item_index]}')
    inputs, labels = dataset[item_index]
    # Wrap the single sample in a batch dimension before moving it to the device.
    inputs = torch.from_numpy(np.array([inputs])).to(device).float()
    labels = torch.from_numpy(np.array([labels])).to(device).float()

    with torch.no_grad():
        prediction = model(inputs)
        # NOTE(review): loss_batch is handed the optimizer; its implementation is not
        # visible here -- confirm it does not take an optimisation step during eval.
        item_loss, item_dsc, inputs_len = loss_batch(model, optimizer, criterion, inputs, labels)
    print(f'loss {item_loss}, dsc {item_dsc}, inputs_len {inputs_len}')

    # Move everything to the host exactly once; all later statistics reuse these
    # tensors instead of repeating the device-to-host copy.
    inputs = inputs.cpu()
    labels = labels.cpu()
    prediction = prediction.detach().cpu()
    prediction_np = prediction.numpy()
    pred_channel = prediction[:, 0]

    plt.hist(prediction_np.flatten(), 20)
    plt.title('Distribution of prediction values')
    plt.show()

    def f(a):
        """Draw input, label and prediction for slice index ``a`` side by side."""
        fig, axes = plt.subplots(1, 3, figsize=(30, 16))
        panels = (
            ('Input', inputs[0, 0, a], {}),
            ('Label', labels[0, a], {}),
            # Fixed colour range so the prediction is not auto-rescaled per slice.
            ('Prediction', prediction_np[0, 0, a], {'vmin': 0, 'vmax': 1}),
        )
        for ax, (title, image, extra) in zip(axes, panels):
            ax.set_title(title)
            ax.imshow(image, cmap="gray", **extra)

        if figfile is not None:
            fig.savefig(figfile, dpi=96)
        plt.show()

    # The dict key must match the callback's parameter name.
    out = widgets.interactive_output(f, {'a': slice_slider})
    display(ui, out)

    print('DEBUG shapes', pred_channel.shape, labels.shape, inputs.shape)
    print(f'DEBUG prediction max {prediction.max().item()}, min {prediction.min().item()}')
    print('DEBUG intersection', (labels * pred_channel).sum().item())
    print('DEBUG label sum', labels.sum().item())
    print('DEBUG prediction sum', pred_channel.sum().item())

    # Dice coefficient recomputed by hand; `smooth` keeps the ratio defined when
    # both the prediction and the label sum to zero.
    smooth = 1e-6
    y_pred = pred_channel.contiguous().view(-1)
    y_true = labels.contiguous().view(-1)
    intersection = (y_pred * y_true).sum()
    dsc = (2. * intersection + smooth) / (
        y_pred.sum() + y_true.sum() + smooth
    )

    print('DEBUG intersection2', intersection.item())
    print('DEBUG dsc', dsc.item())
    print('DEBUG MSE', (labels - pred_channel).pow(2).mean().item())

                
# Render a heading followed by the slice viewer for the training and validation splits.
for heading, split in (("## Train Eval", train_dataset), ("## Valid Eval", valid_dataset)):
    # noinspection PyTypeChecker
    display(Markdown(heading))
    eval_image_dataset(split, 90)
# Test-split evaluation is currently commented out:
# display(Markdown("## Test Eval"))
# eval_image_dataset(test_dataset, 78, 'test_plot.png')