In [None]:
import pathlib
import shutil

import numpy as np
import matplotlib.pyplot as plt

from IPython import display

import pydicom

In [None]:
# Automatically reload pymedphys so that any changes to it propagate into
# the notebook without needing a kernel restart.
from IPython.lib.deepreload import reload
%load_ext autoreload
%autoreload 2

In [None]:
import pymedphys
from pymedphys.labs.autosegmentation import pipeline, filtering, indexing, mask

In [None]:
# All of the DICOM data lives under a directory called 'dicom', split into
# 'training', 'validation', and 'testing' subdirectories:
data_path_root = pathlib.Path.home() / '.data' / 'dicom-ct-and-structures'
dicom_directory = data_path_root / 'dicom'

training_directory, validation_directory, testing_directory = (
    dicom_directory / split for split in ('training', 'validation', 'testing')
)

# Beyond being placed somewhere within one of 'training', 'validation', or
# 'testing', the DICOM files need no particular directory structure. They
# may be grouped into per-patient directories, but that is not required.

In [None]:
# JSON file used below to standardise structure names.
name_mappings_path = data_path_root / 'name_mappings.json'

In [None]:
# Fetch the auto-segmentation dataset and copy it into the
# training/validation/testing layout: each .dcm file is routed by the
# integer name of its parent directory, and name_mappings.json is copied to
# name_mappings_path. Files that already exist at their destination are not
# copied again. Any other file is treated as an error.

# Parent-directory ids below TESTING_ID_LIMIT go to testing, ids below
# VALIDATION_ID_LIMIT go to validation, and all remaining ids go to training.
TESTING_ID_LIMIT = 4
VALIDATION_ID_LIMIT = 8

dicom_paths = pymedphys.zenodo_data_paths("auto-segmentation")

for path in dicom_paths:
    if path.suffix == '.dcm':
        dataset_id = path.parent.name
        try:
            dataset_number = int(dataset_id)
        except ValueError as error:
            raise ValueError(
                "Expected the parent directory of a DICOM file to be an "
                f"integer dataset id. {path}."
            ) from error

        if dataset_number < TESTING_ID_LIMIT:
            split_directory = testing_directory
        elif dataset_number < VALIDATION_ID_LIMIT:
            split_directory = validation_directory
        else:
            split_directory = training_directory

        # Keep the <dataset_id>/<file name> part of the path so files from
        # different datasets do not collide within a split.
        new_path = split_directory.joinpath(*path.parts[-2:])

    elif path.name == 'name_mappings.json':
        new_path = name_mappings_path

    else:
        raise ValueError(f"Unexpected file found. {path}.")

    if not new_path.exists():
        new_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(path, new_path)

In [None]:
# Mapping used to standardise the structure names found in the DICOM files.
names_map = filtering.load_names_mapping(
    name_mappings_path
)

In [None]:
# Eye and lens structures. Defined once so the mask list and the slice
# filter below cannot drift out of sync.
eye_and_lens_structures = ['lens_left', 'lens_right', 'eye_left', 'eye_right']

# Create masks for the following structures, in the following order. The
# patient outline is kept as the final entry; the diagnostic plotting in this
# notebook treats the last mask channel as the patient outline.
structures_to_learn = eye_and_lens_structures + ['patient']

# Use the following to filter the slices used for training, validation,
# and testing
filters = {
    "study_set_must_have_all_of": structures_to_learn,
    "slice_at_least_one_of": list(eye_and_lens_structures),
    "slice_must_have": ['patient'],
    "slice_cannot_have": []
}

In [None]:
# Build the datasets (indexed below by 'training', 'validation' and
# 'testing'), presumably keeping only the slices that pass `filters`.
datasets = pipeline.create_datasets(
    data_path_root,
    names_map,
    structures_to_learn,
    filters,
)

In [None]:
# Path and UID lookups relating CT image series to their structure sets, as
# returned by indexing.get_uid_cache.
[
    ct_image_paths,
    structure_set_paths,
    ct_uid_to_structure_uid,
    structure_uid_to_ct_uids,
] = indexing.get_uid_cache(data_path_root)

In [None]:
# Standardised structure names, looked up either by CT UID or by structure
# set UID.
[
    structure_names_by_ct_uid,
    structure_names_by_structure_set_uid,
] = indexing.get_cached_structure_names_by_uids(
    data_path_root,
    structure_set_paths,
    names_map,
)

In [None]:
def diagnostic_plotting(
    x_grid, y_grid, input_array, output_array,
    structure_names=None, vmin=900, vmax=1200,
):
    """Plot an image slice with its eye, lens and patient mask contours.

    Draws the eye and lens contours, fixes the view to them, then overlays
    the patient outline and a windowed filled-contour plot of the image.

    Parameters
    ----------
    x_grid, y_grid
        Coordinate grids for the slice. Anything with a ``.numpy()`` method,
        such as the values yielded by ``datasets[...]`` in this notebook.
    input_array
        Image with a trailing channel axis (``.numpy()``-able); only channel 0
        is plotted.
    output_array
        Stacked masks (``.numpy()``-able), channel ``i`` corresponding to
        ``structure_names[i]``. The final channel is drawn as the patient
        outline.
    structure_names
        Names matching the channels of ``output_array``. Defaults to the
        notebook-level ``structures_to_learn``. Every name except the last
        must end with 'left' or 'right' and start with 'lens' or 'eye'.
    vmin, vmax
        Display window; image values outside ``[vmin, vmax]`` are clamped.

    Raises
    ------
    ValueError
        If a non-final structure name does not follow the eye/lens,
        left/right naming convention.
    """
    if structure_names is None:
        structure_names = structures_to_learn

    plt.figure(figsize=(15, 10))

    x_grid = x_grid.numpy()
    y_grid = y_grid.numpy()
    input_array = input_array.numpy()[:, :, 0]
    output_array = output_array.numpy()

    # Colour encodes side (red = left, blue = right); line style encodes
    # type (dashed = lens, solid = eye).
    for i, structure in enumerate(structure_names[:-1]):
        if structure.endswith('left'):
            colour = 'r'
        elif structure.endswith('right'):
            colour = 'b'
        else:
            raise ValueError("Expected either left or right")

        if structure.startswith('lens'):
            colour += '--'
        elif structure.startswith('eye'):
            colour += '-'
        else:
            raise ValueError("Expected either eye or lens")

        contours = mask.get_contours_from_mask(
            x_grid, y_grid, output_array[:, :, i])
        for contour in contours:
            plt.plot(*contour.T, colour)

    # Capture the view fitted to the eye/lens contours now; the patient
    # outline and the full-grid image drawn below would otherwise expand
    # the axes.
    plt.axis('equal')
    ax = plt.gca()
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    # The final mask channel is the patient outline.
    contours = mask.get_contours_from_mask(
        x_grid, y_grid, output_array[:, :, -1])
    for contour in contours:
        plt.plot(*contour.T, 'k--')

    # Clamp intensities to the display window (returns a new array, leaving
    # input_array untouched).
    windowed = np.clip(input_array, vmin, vmax)

    plt.contourf(x_grid, y_grid, windowed, 50)
    plt.colorbar()
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

In [None]:
def display_paths(ct_uid):
    """Print the image path entry for ``ct_uid``, then its structure set path."""
    print(ct_image_paths[ct_uid])

    linked_structure_uid = ct_uid_to_structure_uid[ct_uid]
    structure_set_path = structure_set_paths[linked_structure_uid]
    print(structure_set_path)

# Display Results

In [None]:
# Show the first example of the testing split and its source file paths.
for example in datasets['testing'].take(1):
    ct_uid_tensor, x_grid, y_grid, input_array, output_array = example
    ct_uid = ct_uid_tensor.numpy().decode()

    display.display(display.Markdown(f"## {ct_uid}"))
    diagnostic_plotting(x_grid, y_grid, input_array, output_array)
    plt.show()
    display_paths(ct_uid)

In [None]:
# Show the first example of the validation split and its source file paths.
for example in datasets['validation'].take(1):
    ct_uid_tensor, x_grid, y_grid, input_array, output_array = example
    ct_uid = ct_uid_tensor.numpy().decode()

    display.display(display.Markdown(f"## {ct_uid}"))
    diagnostic_plotting(x_grid, y_grid, input_array, output_array)
    plt.show()
    display_paths(ct_uid)

In [None]:
# Show the first ten examples of the training split and their source paths.
for example in datasets['training'].take(10):
    ct_uid_tensor, x_grid, y_grid, input_array, output_array = example
    ct_uid = ct_uid_tensor.numpy().decode()

    display.display(display.Markdown(f"## {ct_uid}"))
    diagnostic_plotting(x_grid, y_grid, input_array, output_array)
    plt.show()
    display_paths(ct_uid)