In [None]:
import glob

import os

import pydicom

#import tensorflow_io as tfio

import matplotlib.pyplot as plt

import numpy as np

import tensorflow as tf

from pathlib import Path

from random import randint

import functools

import skimage.draw

import matplotlib

# Colormap used to give each structure contour its own hue when plotting.
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `pyplot.get_cmap` returns the same Colormap and is available in every release.
cmap = plt.get_cmap('hsv')

plt.rcParams['figure.figsize'] = [15, 15]

import random

import timeit

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


# HEADER

In [None]:
# FOR INPUT GENERATOR
dataset_path = "/home/matthew/priv/PROSTATE_TEST/"
#structure_names = ["patient", "RT HOF", "LT HOF", "BLADDER", "RECTUM", "Couch Foam Half Couch", "Couch Outer Half Couch", "Couch Edge"]
structure_names = ["patient"]

# Number of neighbouring CT slices taken on EACH side of the slice being labelled.
context = 10
batch_size = 2

train_ratio = 0.7
valid_ratio = 0.2
test_ratio = 0.1

# FOR MODEL BUILDING
OPTIMIZER = 'adam'
# The network ends in a per-structure sigmoid and the labels are 0/1 masks with
# one channel per structure, i.e. an independent binary decision per pixel and
# channel. Binary cross-entropy matches that layout; sparse categorical
# cross-entropy would treat the last axis as mutually exclusive class scores
# and expects integer class ids as labels.
LOSS = tf.keras.losses.BinaryCrossentropy()
METRICS = ['accuracy']


# GLOBAL VARS
# Per-sample shapes (no batch axis): (slices, rows, cols, channels).
input_shape = (2*context + 1, 512, 512, 1)
output_shape = (1, 512, 512, len(structure_names))
output_channels = output_shape[-1]

# INPUT GENERATOR

In [None]:
# GET PATHS FOR GENERATOR
# Start the clock once, before any globbing, so the reported time covers the
# whole cell (previously `start` was reset right before `end`, printing ~0 s).
start = timeit.default_timer()

patient_paths = glob.glob(dataset_path + "/*")

# Every CT slice of every patient, sorted; the generator looks up a slice in
# this list and takes its neighbours as context.
context_paths = glob.glob(dataset_path + "/*/*CT*", recursive=True)
context_paths.sort()

# Slices that may be used as the centre of a sample: each patient's slices with
# `context` files trimmed from both ends. glob returns files in arbitrary
# order, so each patient's list is sorted first (same ordering as
# `context_paths`) to make sure it is the edge slices that are dropped.
# The explicit end index keeps every slice when context == 0
# (`[0:-0]` would be empty).
input_paths = []
for patient_path in patient_paths:
    patient_slices = sorted(glob.glob(patient_path + "/*CT*"))
    input_paths.extend(patient_slices[context:len(patient_slices) - context])
random.shuffle(input_paths)

label_paths = glob.glob(dataset_path + "/*/*RS*", recursive=True)

# One structure file per patient is expected, each patient losing 2*context slices.
assert len(context_paths) - (len(label_paths) * 2 * context) == len(input_paths)

end = timeit.default_timer()
print(f"Time to generate paths (s): {end-start}")
print("-------")
for path in patient_paths: print(path)

In [None]:
# Work on a small subset of the (already shuffled) centre slices.
input_paths = input_paths[0:50]

num = len(input_paths)
num_train = int(num*train_ratio // 1)
num_valid = int(num*valid_ratio // 1)
# The test split is whatever is left after train and validation, so count it
# that way; flooring `num*test_ratio` separately could report a number that
# differs from len(test_paths) and a total that does not add up.
num_test = num - num_train - num_valid

print(f"Total: {num} = Train: {num_train} + Valid: {num_valid} + Test: {num_test}")

train_paths = input_paths[0:num_train]
valid_paths = input_paths[num_train:num_train+num_valid]
test_paths = input_paths[num_train+num_valid:]

assert len(train_paths) + len(valid_paths) + len(test_paths) == num

In [None]:
class DataGen(tf.keras.utils.Sequence):
    """Keras Sequence producing (CT slab, structure mask) batches from DICOM files.

    Each sample is built around one "centre" CT slice: the input is that slice
    plus `context` neighbouring slices on each side, and the label is a mask of
    the requested structures drawn on the centre slice only.

    Args:
        input_paths: paths of the centre slices this generator iterates over.
        context_paths: list of all CT slice paths; neighbours of a centre slice
            are taken by position in this list, so it must be ordered so that
            adjacent entries are adjacent slices of the same patient.
        label_paths: paths of the structure-set files. A structure file is
            matched to a CT slice when both live in a directory of the same name.
        batch_size: number of centre slices per batch.
        structure_names: ROI names to rasterise; mask channel k corresponds to
            structure_names[k].
    """

    def __init__(self, input_paths, context_paths, label_paths, batch_size, structure_names):
        super().__init__()
        self.input_paths = input_paths
        self.context_paths = context_paths
        self.label_paths = label_paths
        self.batch_size = batch_size
        self.structure_names = structure_names

        # Read every structure file once up front so batches hit the cache.
        for path in self.label_paths:
            self.pre_cached_structures(path)

    @functools.lru_cache()
    def pre_cached_structures(self, path):
        """Return the parsed structure-set file at `path` (memoised per instance and path)."""
        return pydicom.dcmread(path, force=True)

    def get_parent_dir(self, path):
        """Return the name of the directory that directly contains `path`."""
        return Path(path).parent.name

    def _read_slab(self, slice_paths):
        """Read the CT slices at `slice_paths` and return their pixel arrays as a list."""
        images = []
        for dcm_path in slice_paths:
            dicom_ct = pydicom.dcmread(dcm_path, force=True)
            # Files read with force=True may have no transfer syntax recorded;
            # fall back to implicit VR little endian so pixel_array can decode.
            try:
                dicom_ct.file_meta.TransferSyntaxUID
            except AttributeError:
                dicom_ct.file_meta.TransferSyntaxUID = pydicom.uid.ImplicitVRLittleEndian
            images.append(dicom_ct.pixel_array)
        return images

    def _draw_mask(self, img, dcm_rs, structure_indexes):
        """Rasterise the selected structures on the plane of CT slice `img`.

        Returns an array of shape (1, 512, 512, len(structure_indexes)) holding
        1.0 inside a contour and 0.0 elsewhere.
        """
        img_position = img.ImagePositionPatient
        img_spacing = [x for x in img.PixelSpacing] + [img.SliceThickness]
        img_orientation = img.ImageOrientationPatient

        # NOTE(review): PixelSpacing is (row spacing, column spacing) in DICOM;
        # dx/dy below are only interchangeable for square pixels - confirm.
        dx, dy, *rest = img_spacing
        Cx, Cy, *rest = img_position
        Ox, Oy = img_orientation[0], img_orientation[4]

        mask = np.zeros(shape=(1, 512, 512, len(structure_indexes)))

        for mask_index, structure_index in enumerate(structure_indexes):
            # NOTE(review): assumes ROIContourSequence is ordered like
            # StructureSetROISequence - confirm for this dataset.
            roi = dcm_rs.ROIContourSequence[structure_index]
            # An ROI without any contour simply leaves its channel empty.
            for contour in getattr(roi, "ContourSequence", []):
                xyz = contour.ContourData
                # ContourData is a flat x, y, z, x, y, z, ... list; element 2 is
                # the z of the first point. Draw every contour lying on this
                # slice, not just the first one found.
                if xyz[2] != img_position[2]:
                    continue

                x = np.array(xyz[0::3])
                y = np.array(xyz[1::3])

                # Patient coordinates -> pixel row/column on this slice.
                r = (y - Cy) / dy * Oy
                c = (x - Cx) / dx * Ox

                # `shape` clips the polygon to the mask so points outside the
                # image cannot index out of bounds.
                rr, cc = skimage.draw.polygon(r, c, shape=mask.shape[1:3])
                mask[:, rr, cc, mask_index] = True

        return mask

    def __getitem__(self, batch_index, context = 10):
        """Build batch number `batch_index`.

        Args:
            batch_index: zero-based index of the batch.
            context: number of neighbouring slices taken on each side of the
                centre slice.

        Returns:
            (batch_inputs, batch_labels) as numpy arrays. Inputs get a trailing
            channel axis appended; labels are the per-sample masks stacked on a
            new leading axis. Samples whose structure file lacks any of
            `structure_names` are skipped, so a batch may hold fewer than
            `batch_size` samples.
        """
        # Slicing already truncates the final batch; self.batch_size must not
        # be modified here because __len__ and later batches depend on it.
        first = batch_index * self.batch_size
        batch_paths = self.input_paths[first:first + self.batch_size]

        batch_inputs = []
        batch_labels = []

        for image_path in batch_paths:
            # Match on the exact directory name; a substring test would let
            # e.g. "patient1" pick up the structure file of "patient10".
            parent_dir = self.get_parent_dir(image_path)
            mask_path = [s for s in self.label_paths if self.get_parent_dir(s) == parent_dir][0]

            dcm_rs = self.pre_cached_structures(mask_path)
            dcm_rs_struct_names = [structure.ROIName for structure in dcm_rs.StructureSetROISequence]

            # Skip the sample when a requested structure is missing.
            if any(name not in dcm_rs_struct_names for name in self.structure_names):
                continue
            # Index in the order of structure_names so that a mask channel
            # means the same structure for every patient.
            structure_indexes = [dcm_rs_struct_names.index(name) for name in self.structure_names]

            img = pydicom.dcmread(image_path, force=True)
            assert img.FrameOfReferenceUID == dcm_rs.StructureSetROISequence[0].ReferencedFrameOfReferenceUID, \
                "CT slice and structure set are in different frames of reference"

            # Centre slice plus `context` neighbours on each side.
            image_index = self.context_paths.index(image_path)
            slab_paths = self.context_paths[image_index-context:image_index+context+1]

            batch_inputs.append(self._read_slab(slab_paths))
            batch_labels.append(self._draw_mask(img, dcm_rs, structure_indexes))

        batch_inputs = np.array(batch_inputs)[..., np.newaxis]
        batch_labels = np.array(batch_labels)
        return batch_inputs, batch_labels

    def __len__(self):
        """Number of batches per epoch (the last one may be partial)."""
        return int(np.ceil(len(self.input_paths)/float(self.batch_size)))

In [None]:
# Both generators share the full slice list and the structure files;
# they differ only in which centre slices they iterate over.
training_gen = DataGen(
    train_paths,
    context_paths,
    label_paths,
    batch_size=batch_size,
    structure_names=structure_names,
)
valid_gen = DataGen(
    valid_paths,
    context_paths,
    label_paths,
    batch_size=batch_size,
    structure_names=structure_names,
)

In [None]:
# randint includes both end points, so the upper bound has to be the last valid
# batch index; round(num_train / batch_size) could land one past the end.
batch_index = randint(0, len(training_gen) - 1)

batch_inputs, batch_labels = training_gen.__getitem__(batch_index=batch_index, context=context)

print(batch_inputs.shape)
print(batch_labels.shape)

# [0] to remove batch size
# NOTE wants (batch_size, 2*context + 1, 512, 512, 1)
assert batch_inputs[0].shape == input_shape
# NOTE wants (batch_size, 1, 512, 512, number_of_structures)
assert batch_labels[0].shape == output_shape

In [None]:
# Show the centre slice of one random sample with its structure contours.
index_in_batch = randint(0, batch_inputs.shape[0]-1)

# Drop the channel axis of the input and the singleton slice axis of the label.
images = batch_inputs[index_in_batch, ..., 0]
masks = batch_labels[index_in_batch, 0, ...]

num_mask = masks.shape[-1]
plt.imshow(images[context], cmap='gray')
for i in range(num_mask):
    mask = masks[...,i]
    # Contouring an all-zero mask has nothing to draw, so skip empty structures.
    if not np.max(mask) > 0:
        continue
    plt.contour(mask, colors = [cmap((i)/num_mask)])

# MODEL BUILDING

In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf


def down_block(x, m, n, c, size):
    """Residual encoder block.

    Args:
        x: input tensor.
        m: number of (1, 3, 3) 'same' convolutions, each followed by a ReLU.
        n: number of [(1, 3, 3) 'same' conv -> (3, 1, 1) 'valid' conv] pairs;
            a ReLU follows every pair except the last one.
        c: number of filters of every convolution in the block.
        size: cropping passed to Cropping3D on the shortcut branch, so the
            shortcut can be added to the main branch after its 'valid' convs.

    Returns:
        Sum of the cropped, 1x1x1-projected shortcut and the main branch.
    """
    init = tf.random_normal_initializer(0., 0.02)

    def conv(kernel, padding):
        # All main-branch convolutions share filters, stride, init and no bias.
        return tf.keras.layers.Conv3D(c, kernel,
                                      strides=1,
                                      padding=padding,
                                      kernel_initializer=init,
                                      use_bias=False)

    shortcut = tf.keras.layers.Cropping3D(cropping=size, data_format=None)(x)
    shortcut = tf.keras.layers.Conv3D(c, 1, activation=None)(shortcut)

    out = tf.keras.layers.ReLU()(x)
    for _ in range(m):
        out = conv((1, 3, 3), 'same')(out)
        out = tf.keras.layers.ReLU()(out)

    last_pair = n - 1
    for pair in range(n):
        out = conv((1, 3, 3), 'same')(out)
        out = conv((3, 1, 1), 'valid')(out)
        if pair != last_pair:
            out = tf.keras.layers.ReLU()(out)

    return tf.keras.layers.Add()([shortcut, out])


def pool(x):
    """Average-pool `x` with a (1, 2, 2) window; the first pooled axis uses size 1."""
    pooling = tf.keras.layers.AveragePooling3D(
        pool_size=(1, 2, 2),
        strides=None,
        padding='valid',
        data_format='channels_last',
    )
    return pooling(x)


def fc_block(x, r):
    """Bottleneck block: collapse the feature map to a 1024-vector and expand it back.

    The previous version applied the final Reshape to the block's *input* `x`,
    so the 1024-filter convolution and the residual steps never reached the
    output and the block was an identity reshape. The computed features are now
    projected to 8*8*256 values and reshaped, so the output keeps the
    (1, 8, 8, 256) shape the decoder expects while actually depending on them.

    Args:
        x: tensor whose per-sample shape is expected to be (1, 8, 8, channels),
            as implied by the (1, 8, 8) 'valid' kernel and the target shape.
        r: number of parameter-free residual ReLU steps (result + ReLU(result)).

    Returns:
        Tensor of per-sample shape (1, 8, 8, 256).
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    out_depth, out_rows, out_cols, out_channels = 1, 8, 8, 256

    # An 8x8 'valid' convolution over an 8x8 map acts as a fully connected layer.
    result = tf.keras.layers.Conv3D(1024, (1, 8, 8),
                                    strides=1,
                                    padding='valid',
                                    kernel_initializer=initializer,
                                    use_bias=False)(x)
    for repeat in range(r):
        crop = result
        result = tf.keras.layers.ReLU()(result)
        result = tf.keras.layers.Add()([crop, result])

    result = tf.keras.layers.ReLU()(result)

    # 1x1x1 convolution == dense layer on the 1024-vector; it supplies exactly
    # the number of values the Reshape below needs.
    # NOTE(review): left linear because the decoder's first up_block starts
    # with a ReLU; confirm this matches the intended architecture.
    result = tf.keras.layers.Conv3D(out_depth * out_rows * out_cols * out_channels, 1,
                                    strides=1,
                                    padding='valid',
                                    kernel_initializer=initializer)(result)
    result = tf.keras.layers.Reshape((out_depth, out_rows, out_cols, out_channels))(result)

    return result


def up_block(x, m, c):
    """Residual decoder block.

    Args:
        x: input tensor.
        m: number of (1, 3, 3) 'same' convolutions, each followed by a ReLU.
        c: number of filters of every convolution in the block.

    Returns:
        Sum of the 1x1x1-projected input and the convolved branch. The input
        and output tensors are printed as a build-time trace.
    """
    init = tf.random_normal_initializer(0., 0.02)
    print("\n-----------\nUPBLOCKING:")
    print(x)

    # Shortcut: 1x1x1 projection so the channel count matches the main branch.
    shortcut = tf.keras.layers.Conv3D(c, 1, activation=None)(x)

    out = tf.keras.layers.ReLU()(x)
    for _ in range(m):
        conv = tf.keras.layers.Conv3D(c, (1, 3, 3),
                                      strides=1,
                                      padding='same',
                                      kernel_initializer=init,
                                      data_format='channels_last',
                                      use_bias=False)
        out = tf.keras.layers.ReLU()(conv(out))
    out = tf.keras.layers.Add()([shortcut, out])

    print("\nReturn")
    print(out)
    return out


def upscale(x):
    """Upsample `x` by (1, 2, 2) with UpSampling3D, printing input and output as a trace."""
    print("\n-----------\nUPSCALING:")
    print(x)
    upsampled = tf.keras.layers.UpSampling3D(size=(1, 2, 2))(x)
    print("\nReturn")
    print(upsampled)
    return upsampled


def stack(x, skip):
    """Concatenate `x` and `skip` along axis 1, printing both inputs and the result.

    Axis 0 is the batch axis, so axis 1 is the first per-sample axis; the two
    tensors are joined end to end along it, `x` first.
    """
    print("\n-----------\nSTACKING:")
    print(x)
    print(skip)
    joined = tf.keras.layers.Concatenate(axis=1)([x, skip])
    print("\nReturn")
    print(joined)
    return joined


def Model(input_shape, output_channels):
    """Build the encoder / bottleneck / decoder segmentation network.

    Args:
        input_shape: per-sample input shape, (slices, rows, cols, channels).
        output_channels: number of output maps (one per structure).

    Returns:
        A tf.keras.Model mapping the input slab to sigmoid maps with a single
        slice axis entry. Intermediate tensors are printed while building.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)

    # (m, n, filters, shortcut crop) for each encoder stage; see down_block.
    encoder_stages = [
        (3, 0, 32, 0),
        (3, 0, 32, 0),
        (3, 0, 64, 0),
        (1, 2, 64, (2, 0, 0)),
        (1, 2, 128, (2, 0, 0)),
        (1, 2, 128, (2, 0, 0)),
        (0, 4, 256, (4, 0, 0)),
    ]

    skips = []
    x = inputs
    for stage, (m, n, c, size) in enumerate(encoder_stages):
        # Pool between stages only; the first stage sees the raw input.
        if stage > 0:
            x = pool(x)
        x = down_block(x, m, n, c, size)
        skips.append(x)

    x = fc_block(x, 2)

    # (trace banner, m, filters) for each decoder stage; see up_block.
    decoder_stages = [
        ("\n========================================================\nBLOCK 1:", 4, 128),
        ("\n=========================================================\nBLOCK 2:", 4, 128),
        ("\n=========================================================\nBLOCK 3:", 4, 64),
        ("\n=========================================================\nBLOCK 4:", 3, 64),
        ("\n=========================================================\nBLOCK 5:", 3, 32),
        ("\n=========================================================\nBLOCK 6:", 3, 32),
        ("\n=========================================================\nBLOCK 7:", 3, 32),
    ]

    for level, (banner, m, c) in enumerate(decoder_stages, start=1):
        print(banner)
        # The deepest stage joins the bottleneck at the same resolution;
        # every later stage first doubles the in-plane resolution.
        if level > 1:
            x = upscale(x)
        # Walk the skips from deepest to shallowest.
        x = stack(skips[-level], x)
        x = up_block(x, m, c)

    print(
        "\n=========================================================\nOUTPUT:")

    # Every stack() lengthens axis 1, so collapse it with a 'valid' kernel that
    # spans the whole axis. Reading the length from the tensor (104 for the
    # default 21-slice input) avoids a magic number tied to one configuration.
    stacked_depth = int(x.shape[1])
    x = tf.keras.layers.Conv3D(filters=output_channels,
                               kernel_size=(stacked_depth, 1, 1),
                               strides=1,
                               padding='valid')(x)

    x = tf.keras.layers.Conv3D(filters=output_channels,
                               kernel_size=(1, 1, 1),
                               strides=1,
                               activation='sigmoid',
                               kernel_initializer='he_normal',
                               padding='same')(x)

    print(x)
    return tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
model = Model(input_shape, output_channels)

# METRICS is declared in the configuration cell but was never handed to
# compile, so nothing besides the loss was reported during training.
# The arguments previously passed as None were all defaults and are omitted.
model.compile(optimizer=OPTIMIZER,
              loss=LOSS,
              metrics=METRICS)

In [None]:
#tf.keras.utils.plot_model(model, show_shapes=True)

In [None]:
#model.summary()

# TRAINING MODEL

In [None]:
# Whole batches available in the training split (a trailing partial batch is not counted).
steps_per_epoch, _ = divmod(len(train_paths), batch_size)
print(steps_per_epoch)

In [None]:
# Model.fit accepts keras.utils.Sequence objects directly; fit_generator is
# deprecated and no longer exists in recent TensorFlow / Keras releases.
model.fit(training_gen,
          validation_data=valid_gen,
          steps_per_epoch=steps_per_epoch,
          epochs=1)