In [None]:
import functools
import pathlib
import json
import random
import logging

import numpy as np
import matplotlib.pyplot as plt

import shapely.geometry
import skimage.draw
import skimage.filters

import tensorflow as tf

import pydicom

import pymedphys
import pymedphys._dicom.structure as dcm_struct

In [None]:
# Load the ROI-name normalisation config. Raw ROI names listed under
# "ignore_list" are mapped to None so later code can recognise and skip them.
with open('name_mappings.json') as config_file:
    name_mappings_config = json.load(config_file)

names_map = name_mappings_config["names_map"]
ignore_list = name_mappings_config["ignore_list"]
names_map.update({ignored_name: None for ignored_name in ignore_list})

In [None]:
# Used to verify that all structures have either been ignored or mapped to a name

# names = set()

# for uid, path in structure_set_paths.items():
#     dcm = pydicom.read_file(
#         path, force=True, specific_tags=['StructureSetROISequence'])
#     for item in dcm.StructureSetROISequence:
#         names.add(item.ROIName)

# mapped_names = set(names_map.keys())
# print(mapped_names.difference(names))
# names.difference(mapped_names)

In [None]:
# full_list_of_structures = list(set([item for key, item in names_map.items()]).difference({None}))
# full_list_of_structures = sorted(full_list_of_structures)
# full_list_of_structures

In [None]:
# Masks are created for these structures; the list order is the channel order.
structures_to_learn = ['lens_left', 'lens_right', 'eye_left', 'eye_right', 'patient']

# A study set is used only if every one of these structures is defined on it.
study_set_filter_must_have_all_of = set(structures_to_learn)

# Per-slice filters:
# - at least one of `slice_filter_at_least_one_of` must be contoured on the slice,
# - every name in `slice_filter_must_have` must be contoured on the slice,
# - no name in `slice_filter_cannot_have` may be contoured on the slice.
slice_filter_at_least_one_of = {'lens_left', 'lens_right', 'eye_left', 'eye_right'}
slice_filter_must_have = {'patient'}
slice_filter_cannot_have = set()

In [None]:
# Put all of the SASH DICOM data within a directory called 'dicom' in here:
data_path_root = pathlib.Path.home().joinpath('.data/dicom-ct-and-structures')
# Every .dcm file below a 'dicom' directory anywhere under the data root.
dcm_paths = list(data_path_root.rglob('dicom/**/*.dcm'))

# Locations of the numpy (npz) cache and the TFRecord cache; both are
# created up front so later cells can write into them.
npz_directory = data_path_root / 'npz_cache'
tfrecord_directory = data_path_root / 'tfrecord_cache'
for cache_directory in (npz_directory, tfrecord_directory):
    cache_directory.mkdir(parents=True, exist_ok=True)

In [None]:
# JSON cache of the DICOM header UID lookups (written by get_uid_cache).
uid_cache_path = data_path_root / "uid-cache.json"

# JSON cache of mapped structure names keyed by UID
# (written by get_cached_structure_names_by_uids).
structure_names_cache_path = data_path_root / "structure-names-cache.json"

In [None]:
def soft_surface_dice(reference, evaluation):
    """Edge-based similarity score between two (soft) masks.

    Edges of both masks are extracted with a Scharr filter and the score is
    ``1 - sum(|edge_eval - edge_ref|) / sum(edge_eval + edge_ref)``, so
    identical edge maps score 1.

    Parameters
    ----------
    reference, evaluation : array_like
        Masks of the same shape.

    Returns
    -------
    float
        The similarity score. When neither mask has any edge response (both
        are constant), the ratio would be 0/0. The two edge maps are then
        identical, so 1.0 is returned instead of NaN.
    """
    edge_reference = skimage.filters.scharr(reference)
    edge_evaluation = skimage.filters.scharr(evaluation)

    total_edge = np.sum(edge_evaluation + edge_reference)
    if total_edge == 0:
        # Both edge maps are all-zero: avoid 0/0 -> NaN.
        return 1.0

    score = np.sum(np.abs(edge_evaluation - edge_reference)) / total_edge

    return 1 - score

In [None]:
def get_uid_cache(relative_paths):
    """Return CT / structure-set UID lookups, reusing a JSON cache when possible.

    The cache stored at ``uid_cache_path`` is returned unchanged when the set
    of paths it was built from equals ``relative_paths``. Otherwise every file
    in ``relative_paths`` is re-read (only the SOPInstanceUID, SOPClassUID
    and StudyInstanceUID tags), and the cache is rebuilt and rewritten.

    Parameters
    ----------
    relative_paths : iterable of path-like
        DICOM file paths relative to ``data_path_root``.

    Returns
    -------
    dict
        ``"ct_image_paths"`` and ``"structure_set_paths"`` map SOP Instance
        UID to relative path. ``"ct_uid_to_structure_uid"`` maps a CT SOP
        Instance UID to the RT structure set SOP Instance UID of the same
        study; CTs whose study has no structure set are left out.
        ``"paths_when_run"`` lists the paths the cache was built from.
    """
    relative_paths = [str(path) for path in relative_paths]

    try:
        with open(uid_cache_path) as f:
            uid_cache = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        uid_cache = {
            "ct_image_paths": {},
            "structure_set_paths": {},
            "ct_uid_to_structure_uid": {},
            "paths_when_run": []
        }

    if set(uid_cache["paths_when_run"]) == set(relative_paths):
        return uid_cache

    ct_image_paths = {}
    structure_set_paths = {}
    ct_uid_to_study_instance_uid = {}
    study_instance_uid_to_structure_uid = {}

    # Read each header from the very path it will be recorded under, so the
    # header <-> path pairing cannot drift out of step.
    for path in relative_paths:
        header = pydicom.read_file(
            data_path_root.joinpath(path), force=True,
            specific_tags=['SOPInstanceUID', 'SOPClassUID', 'StudyInstanceUID'])

        sop_class_name = header.SOPClassUID.name
        sop_instance_uid = str(header.SOPInstanceUID)

        if sop_class_name == "CT Image Storage":
            ct_image_paths[sop_instance_uid] = path
            ct_uid_to_study_instance_uid[sop_instance_uid] = str(header.StudyInstanceUID)
        elif sop_class_name == "RT Structure Set Storage":
            structure_set_paths[sop_instance_uid] = path
            # If a study has several structure sets, the last one read wins.
            study_instance_uid_to_structure_uid[
                str(header.StudyInstanceUID)] = sop_instance_uid

    # Only CTs whose study actually has a structure set can be paired; the
    # others are skipped rather than raising KeyError.
    ct_uid_to_structure_uid = {
        ct_uid: study_instance_uid_to_structure_uid[study_uid]
        for ct_uid, study_uid in ct_uid_to_study_instance_uid.items()
        if study_uid in study_instance_uid_to_structure_uid
    }

    unmatched_count = len(ct_uid_to_study_instance_uid) - len(ct_uid_to_structure_uid)
    if unmatched_count:
        logging.warning(
            "%d CT images have no RT structure set in their study; skipping them.",
            unmatched_count)

    uid_cache["ct_image_paths"] = ct_image_paths
    uid_cache["structure_set_paths"] = structure_set_paths
    uid_cache["ct_uid_to_structure_uid"] = ct_uid_to_structure_uid
    uid_cache["paths_when_run"] = relative_paths

    with open(uid_cache_path, "w") as f:
        json.dump(uid_cache, f)

    return uid_cache

In [None]:
relative_paths = [path.relative_to(data_path_root) for path in dcm_paths]

uid_cache = get_uid_cache(relative_paths)
ct_image_paths = uid_cache["ct_image_paths"]
structure_set_paths = uid_cache["structure_set_paths"]
ct_uid_to_structure_uid = uid_cache["ct_uid_to_structure_uid"]

# Invert the CT -> structure set mapping: each structure set UID lists the
# CT slice UIDs paired with it, in the order they appear in the cache.
structure_uid_to_ct_uids = {}
for ct_uid, structure_uid in ct_uid_to_structure_uid.items():
    structure_uid_to_ct_uids.setdefault(structure_uid, []).append(ct_uid)


In [None]:
def get_image_transformation_parameters(dcm_ct):
    """Return the in-plane pixel <-> patient coordinate parameters of a CT slice.

    Adapted from Matthew Cooper's work in ../old/data_generator.py.

    Parameters
    ----------
    dcm_ct : pydicom.Dataset
        CT image dataset providing ``ImagePositionPatient``, ``PixelSpacing``
        and ``ImageOrientationPatient``. ``SliceThickness`` is not needed.

    Returns
    -------
    tuple
        ``(dx, dy, Cx, Cy, Ox, Oy)``:

        - ``dx``: column spacing, i.e. the step in x between adjacent columns.
        - ``dy``: row spacing, i.e. the step in y between adjacent rows.
        - ``Cx``, ``Cy``: the x and y entries of ``ImagePositionPatient``.
        - ``Ox``: element 0 of ``ImageOrientationPatient``.
        - ``Oy``: element 4 of ``ImageOrientationPatient``.
    """
    # DICOM PixelSpacing is (adjacent row spacing, adjacent column spacing).
    # The first value is the distance between rows (the y step) and the
    # second is the distance between columns (the x step).
    pixel_spacing = dcm_ct.PixelSpacing
    dy, dx = pixel_spacing[0], pixel_spacing[1]

    position = dcm_ct.ImagePositionPatient
    Cx, Cy = position[0], position[1]

    orientation = dcm_ct.ImageOrientationPatient
    Ox, Oy = orientation[0], orientation[4]

    return dx, dy, Cx, Cy, Ox, Oy

In [None]:
def reduce_expanded_mask(expanded_mask, img_size, expansion):
    """Average an up-sampled mask back down onto the original pixel grid.

    Each non-overlapping ``expansion x expansion`` tile of ``expanded_mask``
    is replaced by its mean. For a boolean mask this gives the fraction of
    each original pixel that is covered.

    Parameters
    ----------
    expanded_mask : array_like
        Mask of shape ``(rows * expansion, cols * expansion)``.
    img_size : int or tuple of int
        Size of the original image. An int means a square
        ``img_size x img_size`` image; a ``(rows, cols)`` pair allows
        non-square images.
    expansion : int
        Up-sampling factor that was used to build ``expanded_mask``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(rows, cols)``.
    """
    if np.ndim(img_size) == 0:
        rows = cols = int(img_size)
    else:
        rows, cols = (int(size) for size in img_size)

    # A plain NumPy reshape is enough; no TensorFlow round trip is needed.
    tiles = np.reshape(np.asarray(expanded_mask), (rows, expansion, cols, expansion))

    return tiles.mean(axis=(1, 3))

In [None]:
def calculate_aliased_mask(contours, dcm_ct, expansion=5):
    """Rasterise planar contours onto a CT slice as a soft (anti-aliased) mask.

    The polygons are drawn on a grid ``expansion`` times finer than the CT
    pixel grid and averaged back down, so edge pixels get fractional coverage.
    The coverage fraction ``f`` in [0, 1] is returned as ``2 * f - 1``: -1
    outside the contours and +1 fully inside.

    Parameters
    ----------
    contours : sequence of sequence of float
        Each item is a flat ``[x0, y0, z0, x1, y1, z1, ...]`` list of patient
        coordinates. All points of one contour must share a single z value
        (asserted).
    dcm_ct : pydicom.Dataset
        The CT slice the contours are drawn on.
    expansion : int, optional
        Super-sampling factor per axis.

    Returns
    -------
    x_grid : numpy.ndarray
        Patient x coordinate of each pixel column.
    y_grid : numpy.ndarray
        Patient y coordinate of each pixel row.
    mask : numpy.ndarray
        Values in [-1, 1] on the CT pixel grid.
    """
    dx, dy, Cx, Cy, Ox, Oy = get_image_transformation_parameters(dcm_ct)

    # pixel_array is (rows, columns): x varies along columns, y along rows.
    ct_size = np.shape(dcm_ct.pixel_array)
    n_rows, n_cols = ct_size[0], ct_size[1]

    # Build the grids from integer indices. np.arange with a float step can
    # produce one element too many because of rounding.
    x_grid = Cx + np.arange(n_cols) * dx * Ox
    y_grid = Cy + np.arange(n_rows) * dy * Oy

    new_ct_size = np.array(ct_size) * expansion
    expanded_mask = np.zeros(new_ct_size, dtype=bool)

    for xyz in contours:
        x = np.array(xyz[0::3])
        y = np.array(xyz[1::3])
        z = xyz[2::3]

        assert len(set(z)) == 1

        # Patient coordinates -> expanded-grid (row, col). The half-tile offset
        # places each original pixel centre at the centre of its
        # expansion x expansion tile.
        r = ((y - Cy) / dy * Oy) * expansion + (expansion - 1) * 0.5
        c = ((x - Cx) / dx * Ox) * expansion + (expansion - 1) * 0.5

        expanded_mask |= skimage.draw.polygon2mask(
            new_ct_size, np.column_stack((r, c)))

    # NOTE: passes a single size, i.e. assumes a square slice.
    mask = reduce_expanded_mask(expanded_mask, ct_size[0], expansion)
    mask = 2 * mask - 1

    return x_grid, y_grid, mask

In [None]:
def get_contours_from_mask(x_grid, y_grid, mask):
    """Trace the zero level set of ``mask`` as polylines in patient coordinates.

    Parameters
    ----------
    x_grid, y_grid : array_like
        Coordinates of the mask's columns and rows respectively, e.g. as
        returned by ``calculate_aliased_mask``.
    mask : array_like
        2D array whose level 0 is traced, e.g. a [-1, 1] soft mask.

    Returns
    -------
    list of numpy.ndarray
        One ``(N, 2)`` array of ``(x, y)`` vertices per traced line. The list
        is empty if the mask never crosses 0.
    """
    # Draw on a private figure and always close it. Repeated calls then
    # neither leak figures nor touch whatever figure the notebook has open.
    fig, ax = plt.subplots()
    try:
        contour_set = ax.contour(x_grid, y_grid, mask, [0])
        # allsegs[level] holds one vertex array per traced line. Unlike
        # ContourSet.collections, it is not deprecated in recent Matplotlib.
        contours = [np.asarray(segment) for segment in contour_set.allsegs[0]]
    finally:
        plt.close(fig)

    return contours

In [None]:
def get_structure_names_by_uids(structure_set_paths, names_map):
    """Find which mapped structure names occur on each CT slice and structure set.

    Parameters
    ----------
    structure_set_paths : dict
        Structure set SOP Instance UID -> path relative to ``data_path_root``.
    names_map : dict
        Raw ROIName -> normalised structure name, or ``None`` for ROIs to
        ignore. Every ROIName in a structure set must be a key; a missing
        one raises ``KeyError``.

    Returns
    -------
    structure_names_by_ct_uid : dict
        Referenced CT SOP Instance UID -> sorted list of mapped structure
        names that have at least one contour referencing that slice.
    structure_names_by_structure_set_uid : dict
        Structure set UID -> list of mapped structure names it defines.
    """
    structure_names_by_ct_uid = {}
    structure_names_by_structure_set_uid = {}

    for structure_set_uid, relative_structure_set_path in structure_set_paths.items():
        structure_set_path = data_path_root.joinpath(relative_structure_set_path)

        structure_set = pydicom.read_file(
            structure_set_path,
            force=True,
            specific_tags=['ROIContourSequence', 'StructureSetROISequence'])

        # ROI number -> mapped name, dropping ROIs the names map ignores.
        number_to_name_map = {}
        for roi_sequence_item in structure_set.StructureSetROISequence:
            mapped_name = names_map[roi_sequence_item.ROIName]
            if mapped_name is not None:
                number_to_name_map[roi_sequence_item.ROINumber] = mapped_name

        structure_names_by_structure_set_uid[structure_set_uid] = list(
            number_to_name_map.values())

        for roi_contour_sequence_item in structure_set.ROIContourSequence:
            structure_name = number_to_name_map.get(
                roi_contour_sequence_item.ReferencedROINumber)
            if structure_name is None:
                continue

            # ContourSequence and ContourImageSequence are optional in DICOM.
            # An ROI without them contributes no slice references.
            contour_sequence = getattr(roi_contour_sequence_item, 'ContourSequence', [])
            for contour_sequence_item in contour_sequence:
                contour_image_sequence = getattr(
                    contour_sequence_item, 'ContourImageSequence', [])
                if len(contour_image_sequence) == 0:
                    continue

                ct_uid = contour_image_sequence[0].ReferencedSOPInstanceUID
                structure_names_by_ct_uid.setdefault(ct_uid, set()).add(structure_name)

    # Sorted lists: JSON-serialisable and independent of set iteration order.
    structure_names_by_ct_uid = {
        key: sorted(names) for key, names in structure_names_by_ct_uid.items()
    }

    return structure_names_by_ct_uid, structure_names_by_structure_set_uid

In [None]:
def get_cached_structure_names_by_uids(structure_set_paths, names_map):
    """Return the structure-name lookups, reusing the JSON cache when valid.

    The cache at ``structure_names_cache_path`` is valid when it was built
    from exactly the same ``structure_set_paths`` and ``names_map``.
    Otherwise ``get_structure_names_by_uids`` is re-run and the cache file is
    rewritten. A missing, unparsable or incomplete cache file is treated as
    stale rather than raising.

    Parameters
    ----------
    structure_set_paths : dict
        Structure set UID -> relative path. Keys and values are compared as
        strings.
    names_map : dict
        Raw ROIName -> mapped name (or ``None``).

    Returns
    -------
    dict
        With keys ``"structure_names_by_ct_uid"``,
        ``"structure_names_by_structure_set_uid"``,
        ``"structure_set_paths_when_run"`` and ``"names_map_when_run"``.
    """
    # Stringify so the comparison matches what a JSON round trip produces.
    structure_set_paths = {
        str(key): str(item) for key, item in structure_set_paths.items()
    }

    try:
        with open(structure_names_cache_path) as f:
            structure_names_cache = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        structure_names_cache = {}

    # .get / membership checks: a cache written by an older version of this
    # notebook may lack keys; treat it as stale instead of raising KeyError.
    cache_valid = (
        isinstance(structure_names_cache, dict)
        and structure_names_cache.get("structure_set_paths_when_run") == structure_set_paths
        and structure_names_cache.get("names_map_when_run") == names_map
        and "structure_names_by_ct_uid" in structure_names_cache
        and "structure_names_by_structure_set_uid" in structure_names_cache)

    if cache_valid:
        return structure_names_cache

    structure_names_by_ct_uid, structure_names_by_structure_set_uid = (
        get_structure_names_by_uids(structure_set_paths, names_map))

    structure_names_cache = {
        "structure_names_by_ct_uid": structure_names_by_ct_uid,
        "structure_names_by_structure_set_uid": structure_names_by_structure_set_uid,
        "structure_set_paths_when_run": structure_set_paths,
        "names_map_when_run": names_map,
    }

    with open(structure_names_cache_path, "w") as f:
        json.dump(structure_names_cache, f)

    return structure_names_cache

In [None]:
# Build (or load from the JSON cache) the mapped structure names per CT slice
# and per structure set.
structure_names_cache = get_cached_structure_names_by_uids(
    structure_set_paths=structure_set_paths, names_map=names_map)

In [None]:
# Unpack the two lookup tables from the cache dict.
structure_names_by_ct_uid, structure_names_by_structure_set_uid = (
    structure_names_cache["structure_names_by_ct_uid"],
    structure_names_cache["structure_names_by_structure_set_uid"],
)

In [None]:
# structure_names_by_ct_uid

In [None]:
# structure_names_by_structure_set_uid

In [None]:
# Select CT slices for training using the study-set and slice filters
# configured above.
ct_uids_to_train_on = []

for structure_uid, ct_uids in structure_uid_to_ct_uids.items():
    names_in_study_set = set(structure_names_by_structure_set_uid[structure_uid])

    # Skip the whole study set unless it defines every required structure.
    if not study_set_filter_must_have_all_of <= names_in_study_set:
        continue

    for ct_uid in ct_uids:
        # Slices without any mapped contour have no entry in the lookup.
        names_on_slice = set(structure_names_by_ct_uid.get(ct_uid, []))

        keep_slice = (
            not names_on_slice.isdisjoint(slice_filter_at_least_one_of)
            and names_on_slice >= slice_filter_must_have
            and names_on_slice.isdisjoint(slice_filter_cannot_have)
        )
        if keep_slice:
            ct_uids_to_train_on.append(ct_uid)

In [None]:
# ct_uids_to_train_on

In [None]:
# len({1, 2,3, 4}.intersection({2,3,5}))

In [None]:
def create_numpy_input_output(ct_uid):
    """Build the (CT pixels, structure masks) training pair for one CT slice.

    Parameters
    ----------
    ct_uid : str
        SOP Instance UID of the CT slice. It must be a key of both
        ``ct_uid_to_structure_uid`` and ``ct_image_paths``.

    Returns
    -------
    input_array : numpy.ndarray
        The slice's ``pixel_array``.
    output_array : numpy.ndarray
        Shape ``(*pixel_array.shape, len(structures_to_learn))``. Channel
        ``i`` holds the soft mask of ``structures_to_learn[i]`` from
        ``calculate_aliased_mask``, or all -1 when that structure has no
        contour on this slice.
    """
    structure_uid = ct_uid_to_structure_uid[ct_uid]
    structure_set_path = data_path_root.joinpath(structure_set_paths[structure_uid])

    structure_set = pydicom.read_file(
        structure_set_path,
        force=True,
        specific_tags=['ROIContourSequence', 'StructureSetROISequence'])

    number_to_name_map = {
        roi_sequence_item.ROINumber: names_map[roi_sequence_item.ROIName]
        for roi_sequence_item in structure_set.StructureSetROISequence
        if names_map[roi_sequence_item.ROIName] is not None
    }

    # Collect only the contours drawn on the requested slice, grouped by mapped
    # structure name. The loop variable must NOT be named `ct_uid`: that would
    # overwrite the argument, and the wrong CT slice would be loaded below.
    contours_on_slice = {}
    for roi_contour_sequence_item in structure_set.ROIContourSequence:
        structure_name = number_to_name_map.get(
            roi_contour_sequence_item.ReferencedROINumber)
        if structure_name is None:
            continue

        for contour_sequence_item in getattr(roi_contour_sequence_item, 'ContourSequence', []):
            contour_image_sequence = getattr(
                contour_sequence_item, 'ContourImageSequence', [])
            if len(contour_image_sequence) == 0:
                continue

            referenced_ct_uid = contour_image_sequence[0].ReferencedSOPInstanceUID
            if referenced_ct_uid != ct_uid:
                continue

            contours_on_slice.setdefault(structure_name, []).append(
                contour_sequence_item.ContourData)

    ct_path = data_path_root.joinpath(ct_image_paths[ct_uid])
    dcm_ct = pydicom.read_file(ct_path, force=True)
    # Force a transfer syntax so pixel_array can be decoded. Presumably some
    # files are read without usable file meta -- TODO confirm.
    dcm_ct.file_meta.TransferSyntaxUID = pydicom.uid.ImplicitVRLittleEndian

    ct_size = np.shape(dcm_ct.pixel_array)

    # -1 everywhere means "structure absent". Structures contoured on this
    # slice overwrite their channel with the aliased mask.
    masks = np.full((*ct_size, len(structures_to_learn)), -1.0)
    for i, structure in enumerate(structures_to_learn):
        if structure not in contours_on_slice:
            continue

        _, _, masks[:, :, i] = calculate_aliased_mask(
            contours_on_slice[structure], dcm_ct)

    assert np.sum(np.isnan(masks)) == 0

    return dcm_ct.pixel_array, masks

In [None]:
def numpy_input_output_from_cache(ct_uid, structures_to_learn):
    """Return ``create_numpy_input_output(ct_uid)``, cached as an ``.npz`` file.

    The npz cache in ``npz_directory`` is tied to the list of structures being
    learnt. If ``structures_to_learn`` differs from the list recorded in
    ``structures_to_learn.json``, every cached ``.npz`` is deleted and the new
    list is recorded.

    Parameters
    ----------
    ct_uid : str
        CT slice SOP Instance UID; also used as the cache file stem.
    structures_to_learn : sequence of str
        Output channel order; compared with the recorded list.

    Returns
    -------
    input_array, output_array : numpy.ndarray
    """
    npz_path = npz_directory.joinpath(f'{ct_uid}.npz')
    structures_to_learn_cache_path = npz_directory.joinpath('structures_to_learn.json')

    # JSON stores a list, so compare as a list. Otherwise a tuple would never
    # match and the whole cache would be wiped on every call.
    structures_to_learn = list(structures_to_learn)

    try:
        with open(structures_to_learn_cache_path) as f:
            structures_to_learn_cache = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        structures_to_learn_cache = []

    if structures_to_learn_cache != structures_to_learn:
        logging.warning("Structures to learn has changed. Dumping npz cache.")
        for path in npz_directory.glob('*.npz'):
            path.unlink()
        with open(structures_to_learn_cache_path, "w") as f:
            json.dump(structures_to_learn, f)

    try:
        # The context manager closes the NpzFile's underlying file handle.
        with np.load(npz_path) as data:
            input_array = data['input_array']
            output_array = data['output_array']
    except FileNotFoundError:
        input_array, output_array = create_numpy_input_output(ct_uid)
        # Write to a temporary file, then rename. An interrupted write then
        # never leaves a truncated .npz that later loads would fail on.
        tmp_path = npz_path.with_name(npz_path.name + '.tmp')
        with open(tmp_path, 'wb') as f:
            np.savez(f, input_array=input_array, output_array=output_array)
        tmp_path.replace(npz_path)

    return input_array, output_array

In [None]:
# all_ct_uids = list(ct_image_paths.keys())
# random.shuffle(all_ct_uids)

# Shuffle reproducibly. Sorting first makes the order independent of how the
# UIDs were collected (and of any earlier run of this cell); the shuffle then
# uses a fixed seed. Change `shuffle_seed` to draw a different order.
shuffle_seed = 42
ct_uids_to_train_on = sorted(ct_uids_to_train_on)
random.Random(shuffle_seed).shuffle(ct_uids_to_train_on)


def from_numpy_generator():
    """Yield ``(ct_uid, input_array, output_array)`` for each training slice.

    Arrays come from ``numpy_input_output_from_cache``. ``input_array`` gets
    a trailing channel axis, so a ``(rows, cols)`` slice becomes
    ``(rows, cols, 1)``.
    """
    for ct_uid in ct_uids_to_train_on:
        input_array, output_array = numpy_input_output_from_cache(
            ct_uid, structures_to_learn)

        yield ct_uid, input_array[:, :, None], output_array


# (output_types, output_shapes) for tf.data.Dataset.from_generator.
# The static shapes assume every CT slice is 512 x 512 pixels.
from_numpy_generator_params = (
    (tf.string, tf.int32, tf.float64),
    (
        tf.TensorShape(()),
        tf.TensorShape([512, 512, 1]),
        tf.TensorShape([512, 512, len(structures_to_learn)]))
)

from_numpy_dataset = tf.data.Dataset.from_generator(
    from_numpy_generator, *from_numpy_generator_params)

In [None]:
def _bytes_feature(value):
    """Wrap the bytes of an eager scalar string tensor in a ``tf.train.Feature``."""
    raw_bytes = value.numpy()
    bytes_list = tf.train.BytesList(value=[raw_bytes])
    return tf.train.Feature(bytes_list=bytes_list)

def serialise(ct_uid, input_array, output_array):
    """Serialise one example into a ``tf.train.Example`` byte string.

    Each tensor is converted with ``tf.io.serialize_tensor`` and stored as a
    bytes feature under its own name ('ct_uid', 'input_array',
    'output_array').
    """
    named_tensors = {
        'ct_uid': ct_uid,
        'input_array': input_array,
        'output_array': output_array,
    }
    feature = {
        name: _bytes_feature(tf.io.serialize_tensor(tensor))
        for name, tensor in named_tensors.items()
    }

    features = tf.train.Features(feature=feature)
    return tf.train.Example(features=features).SerializeToString()


# Smoke test: serialising the first example of the dataset must not raise.
for ct_uid, input_array, output_array in from_numpy_dataset.take(count=1):
    serialise(ct_uid, input_array, output_array)

In [None]:
# # Details on this from https://www.tensorflow.org/tutorials/load_data/tfrecord
# def tf_serialise(ct_uid, input_array, output_array):
#     tf_string = tf.py_function(
#         serialise,
#         (ct_uid, input_array, output_array),
#         tf.string
#     )
#     return tf.reshape(tf_string, ())

In [None]:
# serialised_dataset = from_numpy_dataset.map(tf_serialise)

In [None]:
# serialised_dataset

In [None]:
tfrecord_path = str(tfrecord_directory.joinpath(
    'lense-eye-patient.tfrecord'))

# The cells below read this TFRecord file, and nothing else in the notebook
# writes it, so build it here when it is missing. Writing to a temporary file
# and renaming keeps an interrupted run from leaving a truncated file.
# NOTE: the file is not rebuilt when `structures_to_learn` changes; delete it
# to regenerate.
if not pathlib.Path(tfrecord_path).exists():
    tmp_tfrecord_path = tfrecord_path + '.tmp'
    with tf.io.TFRecordWriter(tmp_tfrecord_path) as writer:
        for ct_uid, input_array, output_array in from_numpy_dataset:
            writer.write(serialise(ct_uid, input_array, output_array))
    pathlib.Path(tmp_tfrecord_path).replace(tfrecord_path)

In [None]:
%%timeit

for ct_uid, input_array, output_array in from_numpy_dataset.take(100):
    pass

In [None]:
# Read the TFRecord cache back as a dataset of serialised examples.
raw_dataset = tf.data.TFRecordDataset(filenames=tfrecord_path)
raw_dataset

In [None]:
parse_features = {
    'ct_uid': tf.io.FixedLenFeature([], tf.string),
    'input_array': tf.io.FixedLenFeature([], tf.string),
    'output_array': tf.io.FixedLenFeature([], tf.string),
}

def _parse_dataset(example_proto):
    """Decode one serialised example into ``(ct_uid, input_array, output_array)``.

    Each feature holds a tensor serialised with ``tf.io.serialize_tensor``.
    It is parsed back with the dtype it was written with (string, int32 and
    float64, matching ``from_numpy_generator_params``).
    """
    parsed = tf.io.parse_single_example(example_proto, parse_features)
    ct_uid = tf.io.parse_tensor(parsed['ct_uid'], tf.string)
    input_array = tf.io.parse_tensor(parsed['input_array'], tf.int32)
    output_array = tf.io.parse_tensor(parsed['output_array'], tf.float64)

    # parse_tensor yields tensors of unknown shape. Restore (and check at
    # runtime) the static shapes that from_numpy_dataset declares, so both
    # datasets expose the same element spec downstream.
    ct_uid = tf.ensure_shape(ct_uid, [])
    input_array = tf.ensure_shape(input_array, [512, 512, 1])
    output_array = tf.ensure_shape(
        output_array, [512, 512, len(structures_to_learn)])

    return ct_uid, input_array, output_array


parsed_dataset = raw_dataset.map(_parse_dataset)

In [None]:
%%timeit

for ct_uid, input_array, output_array in parsed_dataset.take(100):
    pass