In [None]:
import functools
import pathlib
import json
import random
import logging

import numpy as np
import matplotlib.pyplot as plt

import shapely.geometry
import skimage.draw
import skimage.filters

import tensorflow as tf

import pydicom

import pymedphys
import pymedphys._dicom.structure as dcm_struct

In [None]:
# Automatically reload edited modules (such as pymedphys) so that any changes
# are propagated into the notebook without needing a kernel restart.
from IPython.lib.deepreload import reload
%load_ext autoreload
%autoreload 2

In [None]:
from pymedphys.labs.autosegmentation import indexing, softdice, filtering, pipeline

In [None]:
# Read the structure-name configuration; the file only needs to stay open
# for the duration of the JSON parse.
with open('name_mappings.json') as f:
    name_mappings_config = json.load(f)

names_map = name_mappings_config["names_map"]
ignore_list = name_mappings_config["ignore_list"]

# Structures on the ignore list are registered in the same mapping with
# None as their target name.
names_map.update(dict.fromkeys(ignore_list))

In [None]:
# Used to verify that all structures have either been ignored or mapped to a name

# names = set()

# for uid, path in structure_set_paths.items():
#     dcm = pydicom.read_file(
#         path, force=True, specific_tags=['StructureSetROISequence'])
#     for item in dcm.StructureSetROISequence:
#         names.add(item.ROIName)

# mapped_names = set(names_map.keys())
# print(mapped_names.difference(names))
# names.difference(mapped_names)

In [None]:
# full_list_of_structures = list(set([item for key, item in names_map.items()]).difference({None}))
# full_list_of_structures = sorted(full_list_of_structures)
# full_list_of_structures

In [None]:
# Root directory for the data used by this notebook. All of the SASH DICOM
# data needs to be placed within a directory called 'dicom' inside it.
data_path_root = pathlib.Path.home() / '.data/dicom-ct-and-structures'

In [None]:
# Fetch the four UID lookups for everything under the data root.
# NOTE(review): judging by the names these are path-by-UID mappings plus the
# CT <-> structure-set cross references -- confirm against indexing.get_uid_cache.
(
    ct_image_paths,
    structure_set_paths,
    ct_uid_to_structure_uid,
    structure_uid_to_ct_uids,
) = indexing.get_uid_cache(data_path_root)

In [None]:
# Look up the (mapped) structure names, keyed once by CT UID and once by
# structure-set UID.
(
    structure_names_by_ct_uid,
    structure_names_by_structure_set_uid,
) = indexing.get_cached_structure_names_by_uids(
    data_path_root,
    structure_set_paths,
    names_map,
)

In [None]:
# Masks are created for these structures, in exactly this order.
structures_to_learn = [
    'lens_left',
    'lens_right',
    'eye_left',
    'eye_right',
    'patient',
]

# A study set is only used when every one of these structures is defined
# on it (this is the same list object as `structures_to_learn`).
study_set_must_have_all_of = structures_to_learn

# Slice-level criteria: the slice must carry every contour in
# `slice_must_have`, none of those in `slice_cannot_have`, and at least
# one of the contours in `slice_at_least_one_of`.
slice_must_have = ['patient']
slice_cannot_have = []
slice_at_least_one_of = [
    'lens_left',
    'lens_right',
    'eye_left',
    'eye_right',
]

In [None]:
# UID / structure-name lookups built by the indexing cells above.
uid_lookups = (
    structure_uid_to_ct_uids,
    structure_names_by_structure_set_uid,
    structure_names_by_ct_uid,
)

# Study-set and slice selection criteria configured above.
selection_criteria = (
    study_set_must_have_all_of,
    slice_at_least_one_of,
    slice_must_have,
    slice_cannot_have,
)

# Arguments are passed positionally, in the same order as before.
filtered_ct_uids = filtering.filter_ct_uids(*uid_lookups, *selection_criteria)

In [None]:
# ct_uids_to_train_on

In [None]:
# len({1, 2,3, 4}.intersection({2,3,5}))

In [None]:
# Randomise the order of the selected CT UIDs. The shuffle happens in place,
# so `filtered_ct_uids` itself is reordered (and re-running this cell
# reshuffles it again).
random.shuffle(filtered_ct_uids)

# Everything the pipeline needs to build the dataset, in positional order.
dataset_inputs = (
    data_path_root,
    structure_set_paths,
    ct_image_paths,
    ct_uid_to_structure_uid,
    names_map,
    filtered_ct_uids,
    structures_to_learn,
)
dataset = pipeline.create_numpy_generator_dataset(*dataset_inputs)

In [None]:
# Smoke test: pull a single example from the dataset and show which CT it
# came from. The unpacked names stay available to later cells.
for example in dataset.take(1):
    ct_uid, x_grid, y_grid, input_array, output_array = example
    print(ct_uid)

In [None]:
# serialised_dataset = from_numpy_dataset.map(tf_serialise)

In [None]:
# serialised_dataset

In [None]:
# `tfrecord_directory` was never defined anywhere in this notebook, so this
# cell raised a NameError on a fresh kernel (Restart & Run All).
# NOTE(review): keeping the records in a 'tfrecords' directory next to the
# DICOM data is an assumption -- point this at wherever the TFRecord files
# actually live.
tfrecord_directory = data_path_root.joinpath('tfrecords')

# Path handed to TensorFlow as a plain string.
tfrecord_path = str(tfrecord_directory.joinpath(
    'lense-eye-patient.tfrecord'))
# writer = tf.data.experimental.TFRecordWriter(tfrecord_path)
# writer.write(serialised_dataset)

In [None]:
%%timeit

for ct_uid, input_array, output_array in from_numpy_dataset.take(100):
    pass

In [None]:
# Open the TFRecord file as a dataset of serialised records, and display
# the dataset object as the cell output.
raw_dataset = tf.data.TFRecordDataset(filenames=tfrecord_path)
raw_dataset

In [None]:
# Each feature is declared as a scalar string; the strings are decoded into
# tensors by `_parse_dataset` below.
parse_features = {
    feature_name: tf.io.FixedLenFeature([], tf.string)
    for feature_name in ('ct_uid', 'input_array', 'output_array')
}


def _parse_dataset(example_proto):
    """Decode one serialised example into its three tensors.

    Parameters
    ----------
    example_proto
        A single serialised record, parsed against ``parse_features``.

    Returns
    -------
    tuple
        ``(ct_uid, input_array, output_array)`` where each entry is the
        result of ``tf.io.parse_tensor`` on the matching feature, decoded
        as ``tf.string``, ``tf.int32`` and ``tf.float64`` respectively.
    """
    features = tf.io.parse_single_example(example_proto, parse_features)

    return (
        tf.io.parse_tensor(features['ct_uid'], tf.string),
        tf.io.parse_tensor(features['input_array'], tf.int32),
        tf.io.parse_tensor(features['output_array'], tf.float64),
    )


# Apply the decoder to every record of the raw TFRecord dataset.
parsed_dataset = raw_dataset.map(_parse_dataset)

In [None]:
%%timeit

for ct_uid, input_array, output_array in parsed_dataset.take(100):
    pass