In [None]:
import json
import pathlib
import random

import numpy as np
import matplotlib.pyplot as plt

from IPython import display

In [None]:
# Automatically reload edited modules (e.g. pymedphys) before each cell runs,
# so changes propagate into the notebook without a kernel restart.
from IPython.lib.deepreload import reload
%load_ext autoreload
%autoreload 2

In [None]:
import pymedphys
from pymedphys.labs.autosegmentation import indexing, filtering, pipeline, mask

In [None]:
# Download locations within the `structure-dicom` release of pymedphys/data.
release_url = 'https://github.com/pymedphys/data/releases/download/structure-dicom'

# `{dicom_type}` and `{uid}` stay as literal placeholders; they are filled in
# later with str.format().
dicom_zip_url_pattern = release_url + '/{dicom_type}.{uid}_Anonymised.zip'
mappings_url = release_url + '/mappings.zip'

In [None]:
# Directory (relative to the notebook's working directory) that downloaded
# archives are saved under.
data_download_root = pathlib.Path("auto-segmentation-dicom")

In [None]:
def get_filename_from_url(url):
    """Return the last path component of ``url``, used as a local file name.

    Only the URL's *path* is considered, so any query string or fragment is
    dropped: ``'https://host/dir/file.zip?raw=true'`` gives ``'file.zip'``.
    A path ending in ``'/'`` gives ``''``.
    """
    from urllib.parse import urlsplit

    # Splitting the raw URL on '/' would leak '?query' / '#fragment' text
    # into the file name, so isolate the path component first.
    path = urlsplit(url).path
    return path.rsplit('/', 1)[-1]

In [None]:
# Save the mappings archive under `data_download_root`, keeping the file name
# from its URL, and keep the paths `pymedphys.zip_data_paths` returns for it
# (hash checking enabled; re-download requested on a hash mismatch).
save_filename = data_download_root / get_filename_from_url(mappings_url)

mappings_paths = pymedphys.zip_data_paths(
    save_filename,
    url=mappings_url,
    check_hash=True,
    redownload_on_hash_mismatch=True,
)

In [None]:
# Display the paths returned by `pymedphys.zip_data_paths` for the mappings archive.
mappings_paths

In [None]:
# The mapping files are read from `<data_path_root>/mappings/` in the cells
# below, so the root is taken as two directories above the first returned
# file. NOTE(review): assumes that file sits directly inside `mappings/`.
data_path_root = mappings_paths[0].parent.parent
data_path_root

In [None]:
# Hash file that `download_uid` passes to pymedphys as `hash_filepath`.
hash_path = data_path_root / 'mappings' / 'hashes.json'

In [None]:
# Load the structure-name mapping shipped in the mappings directory.
name_mappings_path = data_path_root / 'mappings' / 'name_mappings.json'
names_map = filtering.load_names_mapping(name_mappings_path)

In [None]:
# Pre-computed UID index, loaded in the next cell.
uid_cache_path = data_path_root / 'mappings' / 'uid-cache.json'

In [None]:
# JSON is UTF-8 by specification; decode it explicitly instead of relying on
# the platform's default locale encoding (which differs on e.g. Windows).
with open(uid_cache_path, encoding='utf-8') as f:
    uid_cache = json.load(f)

In [None]:
# Unpack the index built from the UID cache: (judging by their names) the CT
# image and structure-set paths, plus the CT <-> structure-set UID lookups.
[
    ct_image_paths,
    structure_set_paths,
    ct_uid_to_structure_uid,
    structure_uid_to_ct_uids,
] = indexing.read_uid_cache(data_path_root, uid_cache)

In [None]:
# Cached structure names per CT UID and per structure-set UID (loaded below).
structure_names_mapping_cache_path = (
    data_path_root / 'mappings' / 'structure-names-mapping-cache.json'
)

In [None]:
# JSON is UTF-8 by specification; decode it explicitly instead of relying on
# the platform's default locale encoding (structure names may be non-ASCII).
with open(structure_names_mapping_cache_path, encoding='utf-8') as f:
    structure_names_cache = json.load(f)

structure_names_by_ct_uid = structure_names_cache["structure_names_by_ct_uid"]
structure_names_by_structure_set_uid = structure_names_cache[
    "structure_names_by_structure_set_uid"
]

In [None]:
# Map each structure-set UID to its 'RS' archive URL and each of its CT UIDs
# to the matching 'CT' archive URL.
uid_to_url = {
    uid: dicom_zip_url_pattern.format(dicom_type=dicom_type, uid=uid)
    for structure_uid, ct_uids in structure_uid_to_ct_uids.items()
    for dicom_type, uid in (
        [('RS', structure_uid)] + [('CT', ct_uid) for ct_uid in ct_uids]
    )
}

In [None]:
def download_uid(data_download_root, uid, uid_to_url, hash_path):
    """Fetch the DICOM zip for ``uid`` via ``pymedphys.zip_data_paths``.

    The archive is saved as ``<data_download_root>/dicom/<file name from URL>``.
    Hash checking uses ``hash_path``; the flags request a re-download on a
    hash mismatch and deletion when no hash is found. Returns None.

    Raises:
        KeyError: if ``uid`` has no entry in ``uid_to_url``.
    """
    zip_url = uid_to_url[uid]
    destination = data_download_root / 'dicom' / get_filename_from_url(zip_url)

    download_options = dict(
        check_hash=True,
        redownload_on_hash_mismatch=True,
        delete_when_no_hash_found=True,
        url=zip_url,
        hash_filepath=hash_path,
    )
    pymedphys.zip_data_paths(destination, **download_options)

In [None]:
# Masks are created for these structures, in this order. The final entry
# ('patient') is drawn separately from the others by `diagnostic_plotting`.
structures_to_learn = [
    'lens_left',
    'lens_right',
    'eye_left',
    'eye_right',
    'patient',
]

# Criteria used to select the slices for training, validation and testing.
filters = dict(
    study_set_must_have_all_of=structures_to_learn,
    slice_at_least_one_of=['lens_left', 'lens_right', 'eye_left', 'eye_right'],
    slice_must_have=['patient'],
    slice_cannot_have=[],
)

In [None]:
# Keep only the CT UIDs that satisfy `filters` (the `study_set_*` keys
# constrain the whole structure set, the `slice_*` keys constrain each slice).
filtered_ct_uids = filtering.filter_ct_uids(
    structure_uid_to_ct_uids,
    structure_names_by_structure_set_uid,
    structure_names_by_ct_uid,
    **filters,
)

In [None]:
# Shuffle reproducibly: sort first so the result does not depend on the order
# `filter_ct_uids` happened to return, then shuffle with a dedicated
# fixed-seed generator (global `random` state is left untouched). Rebinding to
# a new list also makes re-running this cell give the same order.
SHUFFLE_SEED = 42
filtered_ct_uids = sorted(filtered_ct_uids)
random.Random(SHUFFLE_SEED).shuffle(filtered_ct_uids)

In [None]:
# TODO: add further filtering here to split filtered_ct_uids into training, validation, and test sets.

In [None]:
# Build the dataset over the filtered CT UIDs. Judging by the display loop
# below, each element unpacks to (ct_uid, x_grid, y_grid, input_array,
# output_array) — TODO confirm against pipeline.create_numpy_generator_dataset.
# `uid_to_url` and `hash_path` are passed in, presumably so DICOM archives can
# be downloaded on demand — verify.
dataset = pipeline.create_numpy_generator_dataset(
    data_path_root,
    structure_set_paths,
    ct_image_paths,
    ct_uid_to_structure_uid,
    names_map,
    filtered_ct_uids,
    structures_to_learn,
    uid_to_url,
    hash_path)

In [None]:
def structure_line_style(structure):
    """Return the matplotlib format string for one eye/lens structure contour.

    Colour encodes laterality ('r' for names ending in 'left', 'b' for 'right').
    Line style encodes the organ ('--' for names starting with 'lens', '-' for
    'eye').

    Raises:
        ValueError: if the name does not end in left/right, or does not start
            with eye/lens.
    """
    if structure.endswith('left'):
        colour = 'r'
    elif structure.endswith('right'):
        colour = 'b'
    else:
        raise ValueError("Expected either left or right")

    if structure.startswith('lens'):
        return colour + '--'
    if structure.startswith('eye'):
        return colour + '-'
    raise ValueError("Expected either eye or lens")


def diagnostic_plotting(x_grid, y_grid, input_array, output_array,
                        structure_names=None, vmin=900, vmax=1200):
    """Plot one image slice with its mask contours overlaid.

    Every structure except the last is drawn in the style given by
    ``structure_line_style``. The view is then zoomed to those contours. The
    last structure (``'patient'`` in ``structures_to_learn``) is drawn as a
    black dashed line. The image underneath is clipped to ``[vmin, vmax]``
    and shown as a filled contour plot with a colourbar.

    Args:
        x_grid, y_grid: coordinate grids; anything with a ``.numpy()`` method
            (presumably tensors yielded by ``dataset`` — confirm).
        input_array: image with a trailing channel axis; only channel 0 is shown.
        output_array: masks stacked on the last axis, in ``structure_names`` order.
        structure_names: names matching ``output_array``'s last axis. Defaults
            to the notebook-global ``structures_to_learn``.
        vmin, vmax: display window the image values are clipped to.

    Raises:
        ValueError: if a non-final structure name is not an eye/lens left/right.
    """
    if structure_names is None:
        structure_names = structures_to_learn

    x_grid = x_grid.numpy()
    y_grid = y_grid.numpy()
    image = input_array.numpy()[:, :, 0]
    masks = output_array.numpy()

    fig, ax = plt.subplots(figsize=(15, 10))

    drew_structure_contour = False
    for i, structure in enumerate(structure_names[:-1]):
        line_style = structure_line_style(structure)
        for contour in mask.get_contours_from_mask(x_grid, y_grid, masks[:, :, i]):
            ax.plot(*contour.T, line_style)
            drew_structure_contour = True

    # Capture the limits spanned by the eye/lens contours so the final view
    # stays zoomed on them after the larger patient outline and image are added.
    ax.axis('equal')
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    for contour in mask.get_contours_from_mask(x_grid, y_grid, masks[:, :, -1]):
        ax.plot(*contour.T, 'k--')

    windowed = np.clip(image, vmin, vmax)
    filled = ax.contourf(x_grid, y_grid, windowed, 50)
    fig.colorbar(filled, ax=ax)

    # Without any eye/lens contour the captured limits are just empty-axes
    # defaults; keep the autoscaled full view in that case.
    if drew_structure_contour:
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

In [None]:
# Show every dataset element: a Markdown heading with its CT UID, followed by
# the diagnostic plot for that slice.
for sample in dataset:
    raw_ct_uid, x_grid, y_grid, input_array, output_array = sample
    ct_uid = raw_ct_uid.numpy().decode()

    heading = display.Markdown(f"## {ct_uid}")
    display.display(heading)

    diagnostic_plotting(x_grid, y_grid, input_array, output_array)
    plt.show()