In [None]:
import json
import pathlib
import random

import numpy as np
import matplotlib.pyplot as plt

from IPython import display

In [None]:
# Ensures that any changes made to pymedphys are automatically
# propagated into the notebook without needing a kernel restart.
from IPython.lib.deepreload import reload  # NOTE(review): `reload` is not used in the cells visible here -- confirm it is still needed
%load_ext autoreload
# Mode 2 asks the autoreload extension to re-import modules before each
# cell is executed.
%autoreload 2

In [None]:
import pymedphys
from pymedphys.labs.autosegmentation import indexing, filtering, pipeline, mask

In [None]:
# Masks are created for these structures, in exactly this order.
structures_to_learn = [
    'lens_left',
    'lens_right',
    'eye_left',
    'eye_right',
    'patient',
]

# Rules used to select the slices that feed the training, validation
# and testing datasets.
filters = dict(
    study_set_must_have_all_of=structures_to_learn,
    slice_at_least_one_of=['lens_left', 'lens_right', 'eye_left', 'eye_right'],
    slice_must_have=['patient'],
    slice_cannot_have=[],
)

In [None]:
dataset_all, used_structure_uids = pipeline.create_dataset(structures_to_learn, filters)

# `sorted` returns a *new* list rather than sorting in place, so its result
# must be bound back to the name. A fixed ordering matters here because the
# testing / validation / training splits are taken by position from this list.
used_structure_uids = sorted(used_structure_uids)
used_structure_uids

In [None]:
# Positional split of the structure-set UIDs: the first two are held out
# for testing, the next three for validation, and the rest are trained on.
testing_uids = used_structure_uids[:2]
validation_uids = used_structure_uids[2:5]
training_uids = used_structure_uids[5:]

# Echo each split, in the order testing -> validation -> training.
for split_uids in (testing_uids, validation_uids, training_uids):
    print(split_uids)

In [None]:
# Bind the UIDs reported for this split to a split-specific name so the
# full `used_structure_uids` list is not overwritten; otherwise re-running
# the cell that slices the splits would slice this subset instead.
testing, testing_used_uids = pipeline.create_dataset(
    structures_to_learn, filters, structure_uids=testing_uids)

# Every requested structure set must have contributed to the dataset,
# and nothing else.
assert set(testing_used_uids) == set(testing_uids), (
    "Testing dataset does not cover exactly the requested structure UIDs")

In [None]:
# Bind the UIDs reported for this split to a split-specific name so the
# full `used_structure_uids` list is not overwritten; otherwise re-running
# the cell that slices the splits would slice this subset instead.
validation, validation_used_uids = pipeline.create_dataset(
    structures_to_learn, filters, structure_uids=validation_uids)

# Every requested structure set must have contributed to the dataset,
# and nothing else.
assert set(validation_used_uids) == set(validation_uids), (
    "Validation dataset does not cover exactly the requested structure UIDs")

In [None]:
# Bind the UIDs reported for this split to a split-specific name so the
# full `used_structure_uids` list is not overwritten; otherwise re-running
# the cell that slices the splits would slice this subset instead.
training, training_used_uids = pipeline.create_dataset(
    structures_to_learn, filters, structure_uids=training_uids)

# Every requested structure set must have contributed to the dataset,
# and nothing else.
assert set(training_used_uids) == set(training_uids), (
    "Training dataset does not cover exactly the requested structure UIDs")

In [None]:
def _contour_line_style(structure):
    """Return the matplotlib format string used to draw ``structure``.

    Left-sided structures are red and right-sided ones blue; lenses are
    dashed and eyes solid.

    Raises
    ------
    ValueError
        If the name does not end in ``left``/``right`` or does not start
        with ``lens``/``eye``.
    """
    if structure.endswith('left'):
        colour = 'r'
    elif structure.endswith('right'):
        colour = 'b'
    else:
        raise ValueError("Expected either left or right")

    if structure.startswith('lens'):
        return colour + '--'
    if structure.startswith('eye'):
        return colour + '-'
    raise ValueError("Expected either eye or lens")


def diagnostic_plotting(x_grid, y_grid, input_array, output_array,
                        structures=None, vmin=900, vmax=1200):
    """Plot one slice with its structure contours overlaid on the image.

    The view is zoomed to the extent of the left/right structure contours:
    the axis limits are captured after those are drawn and restored once
    the final-channel contour and the filled image have been added.

    Parameters
    ----------
    x_grid, y_grid, input_array, output_array
        Objects exposing ``.numpy()``. Only channel 0 of the last axis of
        ``input_array`` is shown. The last axis of ``output_array`` is
        indexed per structure, in the order given by ``structures``.
    structures : sequence of str, optional
        Structure names matching the channels of ``output_array``. All but
        the last must be ``lens``/``eye`` names ending in ``left``/``right``;
        the last channel is drawn as a black dashed outline. Defaults to
        the notebook-level ``structures_to_learn``.
    vmin, vmax : number, optional
        Display window; image values are clamped to ``[vmin, vmax]`` before
        being drawn.

    Raises
    ------
    ValueError
        If a non-final structure name is not a left/right eye or lens.
    """
    if structures is None:
        structures = structures_to_learn

    x_grid = x_grid.numpy()
    y_grid = y_grid.numpy()
    image = input_array.numpy()[:, :, 0]
    masks = output_array.numpy()

    fig, ax = plt.subplots(figsize=(15, 10))

    for i, structure in enumerate(structures[:-1]):
        line_style = _contour_line_style(structure)
        contours = mask.get_contours_from_mask(x_grid, y_grid, masks[:, :, i])
        for contour in contours:
            ax.plot(*contour.T, line_style)

    # Remember the limits that frame just the structures drawn so far; the
    # remaining artists would otherwise expand the view to the whole slice.
    ax.axis('equal')
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    contours = mask.get_contours_from_mask(x_grid, y_grid, masks[:, :, -1])
    for contour in contours:
        ax.plot(*contour.T, 'k--')

    # Clamp to the display window; np.clip returns a new array, so the
    # caller's data is left untouched.
    windowed = np.clip(image, vmin, vmax)

    filled = ax.contourf(x_grid, y_grid, windowed, 50)
    fig.colorbar(filled, ax=ax)

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

In [None]:
# Render every training slice under a heading that names its CT UID.
for raw_uid, x_grid, y_grid, input_array, output_array in training:
    # The UID arrives as a bytes-valued tensor; turn it into a plain string.
    ct_uid = raw_uid.numpy().decode()

    heading = display.Markdown(f"## {ct_uid}")
    display.display(heading)

    diagnostic_plotting(x_grid, y_grid, input_array, output_array)
    plt.show()