In [None]:
import functools
import json
import logging
import os
import pathlib
import random

import numpy as np
import matplotlib.pyplot as plt

import shapely.geometry
import skimage.draw
import skimage.filters

import tensorflow as tf

import pydicom

import pymedphys
import pymedphys._dicom.structure as dcm_struct

In [None]:
# Automatically re-import edited modules (e.g. pymedphys) before each cell
# runs, so changes to the library are picked up here without a kernel reset.
# NOTE(review): `reload` from IPython.lib.deepreload is imported but not used
# in any of the cells shown; the autoreload extension below is what provides
# the live-reload behaviour.
from IPython.lib.deepreload import reload
%load_ext autoreload
%autoreload 2

In [None]:
from pymedphys.labs.autosegmentation import indexing, softdice, filtering, pipeline, tfrecord

In [None]:
# Root directory holding the SASH DICOM data (CT images and their structure
# sets) that is indexed below. Defaults to ~/.data/dicom-ct-and-structures;
# set the DICOM_CT_AND_STRUCTURES_DIR environment variable to use a different
# location without editing the notebook ("~" is expanded).
data_path_root = pathlib.Path(
    os.environ.get(
        'DICOM_CT_AND_STRUCTURES_DIR',
        str(pathlib.Path.home().joinpath('.data/dicom-ct-and-structures')),
    )
).expanduser()

In [None]:
# Look up the CT image and structure-set paths under data_path_root, plus the
# UID mappings between CT slices and structure sets, via indexing.get_uid_cache.
(
    ct_image_paths,
    structure_set_paths,
    ct_uid_to_structure_uid,
    structure_uid_to_ct_uids,
) = indexing.get_uid_cache(data_path_root)

In [None]:
# Load the structure-name mapping. NOTE(review): the JSON path is relative to
# the notebook's working directory.
names_map = filtering.load_names_mapping(
    'name_mappings.json'
)

# Sanity check: every structure name found in the structure sets must either
# be mapped to a name or explicitly ignored.
filtering.verify_all_names_have_mapping(
    structure_set_paths,
    names_map,
)

In [None]:
# Structure names indexed two ways: by CT UID and by structure-set UID.
# names_map is passed in, presumably so the names are the mapped ones.
(
    structure_names_by_ct_uid,
    structure_names_by_structure_set_uid,
) = indexing.get_cached_structure_names_by_uids(
    data_path_root,
    structure_set_paths,
    names_map,
)

In [None]:
# Every distinct target name in the names mapping, sorted for a stable
# listing. Entries mapped to None are excluded (those are the ignored
# structures).
full_list_of_structures = sorted(
    {name for name in names_map.values() if name is not None}
)
full_list_of_structures

In [None]:
# Create masks for the following structures, in the following order.
# The order matters: masks are created in exactly this order.
structures_to_learn = [
    'lens_left', 'lens_right', 'eye_left', 'eye_right', 'patient']

# Only use a study set if all of the following are defined on that study set.
# This is a copy rather than an alias of `structures_to_learn`, so anything
# that mutates this filter list cannot silently change or reorder the
# structures (and hence the masks) being learnt.
study_set_must_have_all_of = list(structures_to_learn)

# Only use a slice if at least one of the following contours exists on it.
slice_at_least_one_of = [
    'lens_left', 'lens_right', 'eye_left', 'eye_right']
# Presumably every contour listed here must also exist on a used slice
# (per the name; see filtering.filter_ct_uids).
slice_must_have = ['patient']
# Presumably slices containing any of these contours are rejected
# (per the name); empty means no exclusions.
slice_cannot_have = []

In [None]:
# Select the CT slice UIDs that satisfy the study-set and per-slice structure
# requirements configured above.
filtered_ct_uids = filtering.filter_ct_uids(
    structure_uid_to_ct_uids,
    structure_names_by_structure_set_uid,
    structure_names_by_ct_uid,
    study_set_must_have_all_of,
    slice_at_least_one_of,
    slice_must_have,
    slice_cannot_have,
)

# Fail loudly here rather than silently building (and writing to disk) an
# empty dataset further down when the filters reject every slice.
assert len(filtered_ct_uids) > 0, (
    "No CT slices satisfied the structure filters; check the names mapping "
    "and the filter lists."
)

In [None]:
# Shuffle the selected CT slices before building the dataset.
# - A dedicated seeded RNG makes the order reproducible across kernel restarts.
# - Shuffling a sorted copy, rather than `filtered_ct_uids` in place, keeps
#   this cell idempotent: re-running it gives the same order instead of
#   reshuffling an already-shuffled list.
# - Sorting first makes the result independent of the iteration order the
#   filter happened to return.
#   NOTE(review): assumes the CT UIDs are mutually comparable (e.g. all str).
SHUFFLE_SEED = 42
shuffled_ct_uids = sorted(filtered_ct_uids)
random.Random(SHUFFLE_SEED).shuffle(shuffled_ct_uids)

dataset = pipeline.create_numpy_generator_dataset(
    data_path_root,
    structure_set_paths,
    ct_image_paths,
    ct_uid_to_structure_uid,
    names_map,
    shuffled_ct_uids,
    structures_to_learn,
)

In [None]:
# Preview: print the CT UID of each of the first 15 generated examples.
# The other unpacked values are not printed here.
for (
    ct_uid, x_grid, y_grid, input_array, output_array
) in dataset.take(15):
    print(ct_uid)

In [None]:
# Serialise the first 15 examples to a TFRecord file inside data_path_root.
tfrecord_path = str(data_path_root / 'lense-eye-patient.tfrecord')
tfrecord.write(tfrecord_path, dataset.take(15))

In [None]:
# Load the TFRecord file written above back into a dataset.
# NOTE(review): presumably tfrecord.read yields the same
# (ct_uid, x_grid, y_grid, input_array, output_array) tuples that were
# written; the loop that follows unpacks each element that way.
loaded_dataset = tfrecord.read(tfrecord_path)

In [None]:
# Round-trip check: print the CT UIDs read back from the TFRecord file so
# they can be compared with the UIDs printed before writing.
for (
    ct_uid, x_grid, y_grid, input_array, output_array
) in loaded_dataset.take(15):
    print(ct_uid)