In [None]:
# %load header
# INPUT GENERATOR
DATA_PATH = "/home/matthew/priv/PROSTATE_TEST/"
STRUCTURE_NAMES = ["patient"]

# Number of neighbouring slices taken on each side of the centre slice; each
# input sample therefore has 2 * CONTEXT + 1 slices.
CONTEXT = 0
BATCH_SIZE = 5

# Train/Valid/Test
SPLIT_RATIO = (0.7, 0.2, 0.1)

# DATA SHAPES
# In-plane size (pixels) images and masks are resized to; must match the
# size the DataGen generator resizes to (128).
IMAGE_SIZE = 128
INPUT_SHAPE = (2*CONTEXT + 1, IMAGE_SIZE, IMAGE_SIZE, 1)
# The generator resizes masks to IMAGE_SIZE too, so the output is
# IMAGE_SIZE x IMAGE_SIZE (not the original 512 x 512 DICOM size).
OUTPUT_SHAPE = (1, IMAGE_SIZE, IMAGE_SIZE, len(STRUCTURE_NAMES))
OUTPUT_CHANNELS = OUTPUT_SHAPE[-1]

# MODEL COMPILING
EPOCHS = 1
OPTIMIZER = 'adam'

import tensorflow as tf
# The model's final layer already applies a sigmoid, so the loss must take
# probabilities. tf.nn.sigmoid_cross_entropy_with_logits would apply a second
# sigmoid to them and distort the gradients.
LOSS = tf.keras.losses.BinaryCrossentropy(from_logits=False)

METRICS = ['accuracy']


In [None]:
# Explicit names instead of `import *`: a star import could silently
# overwrite the config constants defined in the cell above.
from paths import get_paths, split_paths
# NOTE: the generator cell below re-defines DataGen, which shadows this import.
from generator import DataGen

import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [10, 10]

In [None]:
input_paths, context_paths, label_paths = get_paths(DATA_PATH, CONTEXT)

(train_paths,
 valid_paths,
 test_paths) = split_paths(input_paths, SPLIT_RATIO)

# Quick-iteration mode: keep two batches for training and validation and a
# single batch for testing, so a full cycle finishes quickly.
train_paths = train_paths[:BATCH_SIZE * 2]
valid_paths = valid_paths[:BATCH_SIZE * 2]
test_paths = test_paths[:BATCH_SIZE]

In [None]:
# Sanity check: list the (truncated) training inputs, one per line.
for path in train_paths:
    print(path)

In [None]:
# %load generator
import tensorflow as tf
import functools
import skimage.draw
import numpy as np
# Perhaps use tf.io instead
import pydicom
from pathlib import Path
import skimage.transform


class DataGen(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding (DICOM image stack, structure mask) batches.

    Each sample comes from one entry of ``input_paths`` (a DICOM image slice).
    That slice plus ``context`` neighbours on either side, looked up by
    position in ``context_paths``, are read and resized to
    ``resize x resize``. The matching RT structure set is the first entry of
    ``label_paths`` whose path contains the image's parent directory name.
    Its ROIs listed in ``structure_names`` are rasterised into a mask with one
    channel per name, in ``structure_names`` order.

    Args:
        input_paths: DICOM image paths, one per sample.
        context_paths: Ordered DICOM image paths used to find neighbouring
            slices; must contain every entry of ``input_paths``.
            NOTE(review): neighbours are taken purely by list position, so
            this presumably has to be sorted by patient and slice — confirm
            against ``paths.get_paths``.
        label_paths: RT structure-set DICOM paths.
        context: Number of neighbouring slices on each side.
        batch_size: Maximum number of samples per batch.
        structure_names: ROI names to extract. Samples whose structure set
            lacks any of them are skipped.
        resize: Side length in pixels of the returned images and masks.

    A non-empty batch from ``__getitem__`` has shapes
    ``(n, 2*context+1, resize, resize, 1)`` for inputs and
    ``(n, 1, resize, resize, len(structure_names))`` for labels, with
    ``n <= batch_size`` because unusable samples are dropped.
    """

    def __init__(self, input_paths, context_paths, label_paths, context,
                 batch_size, structure_names, resize=128):
        super().__init__()
        self.input_paths = input_paths
        self.context_paths = context_paths
        self.label_paths = label_paths
        self.context = context
        self.batch_size = batch_size
        self.structure_names = structure_names
        self.resize = resize

        # Parse every structure set up front so batches hit the cache.
        for path in self.label_paths:
            self.pre_cached_structures(path)

        self.on_epoch_end()

    # NOTE: lru_cache on a method keys on ``self`` and keeps the instance
    # alive as long as the cache does. With the default maxsize (128), only
    # the 128 most recently used structure sets stay cached.
    @functools.lru_cache()
    def pre_cached_structures(self, path):
        """Read (and memoise) the RT structure-set DICOM file at ``path``."""
        return pydicom.dcmread(path, force=True)

    def get_parent_dir(self, path):
        """Return the name of the directory that contains ``path``."""
        return Path(path).parent.name

    def _mask_path_for(self, image_path):
        """Return the first label path containing the image's parent dir name.

        Raises:
            IndexError: if no label path matches.
        """
        parent_dir = self.get_parent_dir(image_path)
        matches = [s for s in self.label_paths if parent_dir in s]
        if not matches:
            raise IndexError(
                f"no label path contains {parent_dir!r} (image {image_path!r})")
        return matches[0]

    def _context_window(self, image_path):
        """Return the 2*context+1 paths centred on ``image_path``, or None.

        None means the window would run past either end of ``context_paths``.
        """
        center = self.context_paths.index(image_path)
        start = center - self.context
        stop = center + self.context + 1
        # A negative start would wrap around under Python slicing, so treat it
        # (like a window past the end) as "not enough context".
        if start < 0 or stop > len(self.context_paths):
            return None
        return self.context_paths[start:stop]

    def _load_images(self, paths):
        """Read each DICOM in ``paths`` and resize its pixel array."""
        images = []
        for dcm_path in paths:
            dicom_ct = pydicom.dcmread(dcm_path, force=True)
            # Files read with force=True may have no TransferSyntaxUID in
            # their file meta; fall back to Implicit VR Little Endian.
            try:
                dicom_ct.file_meta.TransferSyntaxUID
            except AttributeError:
                dicom_ct.file_meta.TransferSyntaxUID = (
                    pydicom.uid.ImplicitVRLittleEndian)
            images.append(skimage.transform.resize(
                dicom_ct.pixel_array, (self.resize, self.resize)))
        return images

    def _contours_for_structures(self, dicom_structures):
        """Return one contour list per name in ``structure_names``, or None.

        ROIs are matched by ROIName in StructureSetROISequence. Their contours
        are then found through ROINumber / ReferencedROINumber, since the two
        sequences are not required to share an order. Returns None if any
        requested name is missing. An ROI without a ContourSequence yields an
        empty list.
        """
        numbers_by_name = {}
        for roi in dicom_structures.StructureSetROISequence:
            # Keep the first ROI when a name is duplicated.
            numbers_by_name.setdefault(roi.ROIName, roi.ROINumber)
        if any(name not in numbers_by_name for name in self.structure_names):
            return None

        contours_by_number = {
            roi_contour.ReferencedROINumber:
                getattr(roi_contour, 'ContourSequence', [])
            for roi_contour in dicom_structures.ROIContourSequence
        }
        return [contours_by_number.get(numbers_by_name[name], [])
                for name in self.structure_names]

    def _build_mask(self, image_path, mask_path):
        """Rasterise the requested structures on the slice at ``image_path``.

        Returns:
            Float array of shape (1, resize, resize, len(structure_names)), or
            None when the structure set lacks a requested ROI.

        Raises:
            ValueError: if the image and the structure set reference
                different frames of reference.
        """
        img = pydicom.dcmread(image_path, force=True)
        dicom_structures = self.pre_cached_structures(mask_path)

        referenced_frame = (dicom_structures.StructureSetROISequence[0]
                            .ReferencedFrameOfReferenceUID)
        if img.FrameOfReferenceUID != referenced_frame:
            raise ValueError(
                f"{image_path!r} and {mask_path!r} use different frames of "
                "reference")

        contour_sets = self._contours_for_structures(dicom_structures)
        if contour_sets is None:
            return None

        # DICOM PixelSpacing is (row spacing, column spacing). Row spacing is
        # the step along y (between rows); column spacing is the step along x.
        dy, dx = (float(v) for v in img.PixelSpacing)
        img_position = img.ImagePositionPatient
        Cx, Cy = float(img_position[0]), float(img_position[1])
        slice_z = img_position[2]
        img_orientation = img.ImageOrientationPatient
        Ox, Oy = img_orientation[0], img_orientation[4]

        mask = np.zeros(shape=(1, img.Rows, img.Columns, len(contour_sets)))
        for channel, contours in enumerate(contour_sets):
            # Draw every contour on this slice (a structure can have several
            # disjoint polygons per slice), not just the first one.
            for contour in contours:
                xyz = contour.ContourData
                # NOTE(review): exact z comparison; a tolerance may be needed
                # if positions are stored with differing precision.
                if xyz[2] != slice_z:
                    continue
                x = np.asarray(xyz[0::3], dtype=float)
                y = np.asarray(xyz[1::3], dtype=float)
                rows = (y - Cy) / dy * Oy
                cols = (x - Cx) / dx * Ox
                # shape= clips vertices outside the image instead of letting
                # negative indices wrap around to the opposite edge.
                rr, cc = skimage.draw.polygon(rows, cols, shape=mask.shape[1:3])
                mask[0, rr, cc, channel] = 1.0

        # Keep every structure channel when resizing.
        return skimage.transform.resize(
            mask, (1, self.resize, self.resize, mask.shape[-1]))

    def __getitem__(self, batch_index):
        """Return ``(inputs, labels)`` for batch number ``batch_index``.

        Samples without enough neighbouring slices, or whose structure set
        lacks a requested ROI, are dropped, so a batch may be smaller than
        ``batch_size``.
        """
        # Slicing clips at the end of the list, so the final partial batch
        # needs no special case. Mutating self.batch_size here (as before)
        # shifted the slice start and shrank every later batch.
        start = batch_index * self.batch_size
        batch_paths = self.input_paths[start:start + self.batch_size]

        batch_inputs = []
        batch_labels = []
        for image_path in batch_paths:
            mask_path = self._mask_path_for(image_path)
            input_paths = self._context_window(image_path)
            if input_paths is None:
                continue
            mask = self._build_mask(image_path, mask_path)
            if mask is None:
                continue
            batch_inputs.append(self._load_images(input_paths))
            batch_labels.append(mask)

        # (n, 2*context+1, resize, resize) -> add a trailing channel axis.
        inputs = np.array(batch_inputs)[..., np.newaxis]
        labels = np.array(batch_labels)
        return inputs, labels

    def __len__(self):
        """Number of batches per epoch, counting a final partial batch."""
        return int(np.ceil(len(self.input_paths) / float(self.batch_size)))

    def on_epoch_end(self):
        """Epoch-end hook; currently does nothing (no shuffling)."""


In [None]:
def _make_generator(sample_paths):
    """Build a DataGen over ``sample_paths`` using the notebook's settings."""
    return DataGen(sample_paths,
                   context_paths,
                   label_paths,
                   context=CONTEXT,
                   batch_size=BATCH_SIZE,
                   structure_names=STRUCTURE_NAMES)


train_gen = _make_generator(train_paths)
valid_gen = _make_generator(valid_paths)
# Needed by the evaluation cells further down (they use ``test_gen``).
test_gen = _make_generator(test_paths)

In [None]:
from random import randint

# len(train_gen) counts the final partial batch, so every index in
# 0 .. len(train_gen) - 1 is valid. The old round(len / BATCH_SIZE) - 1 bound
# skipped the partial batch and could even give an empty range.
batch_index = randint(0, len(train_gen) - 1)
print(batch_index)
inputs, labels = train_gen[batch_index]


In [None]:
# Shapes of the sampled input and label batches, one per line.
print(inputs.shape, labels.shape, sep="\n")

In [None]:
# Shape of each individual label in the batch.
for arr in labels:
    print(arr.shape)

In [None]:
index = 0  # sample within the batch to display

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 10), sharex=True, sharey=True)
# The label belongs to the centre slice of the 2*CONTEXT+1 input stack,
# which sits at position CONTEXT (slice 0 is only the centre when CONTEXT=0).
axes[0].imshow(inputs[index, CONTEXT, :, :, 0])
axes[0].set_title("Input (centre slice)")
axes[1].imshow(labels[index, 0, :, :, 0])
axes[1].set_title(f"Label: {STRUCTURE_NAMES[0]}")
fig.tight_layout()

In [None]:
# %load model_myfull_slim
import tensorflow as tf


def down_block(x, m, n, c, size):
    """Residual block: m in-plane 3x3 convs, then n depth-wise 'valid' convs.

    Args:
        x: 5-D input tensor; presumably (batch, depth, height, width,
            channels) with Keras' default channels_last — confirm.
        m: number of (1, 3, 3) 'same' Conv3D + ReLU layers.
        n: number of (size, 1, 1) 'valid' Conv3D layers. Each one removes
            size - 1 entries from axis 1.
        c: number of filters for every convolution in the block.
        size: kernel size along axis 1.

    Returns:
        The sum of the conv path and a 1x1x1-projected skip path. The skip
        path is cropped by the same n * (size - 1) entries along axis 1 so
        the two shapes match.
    """
    total_crop = n * (size - 1)
    # Split the crop as evenly as possible. When total_crop is odd, the extra
    # entry comes off the back, so the skip still matches the conv output
    # (a symmetric int(total / 2) crop would leave it one entry too long).
    front = total_crop // 2
    crop = tf.keras.layers.Cropping3D(
        cropping=((front, total_crop - front), (0, 0), (0, 0)))(x)
    crop = tf.keras.layers.Conv3D(c, 1, activation=None)(crop)

    result = x
    for _ in range(m):
        result = tf.keras.layers.Conv3D(c, (1, 3, 3),
                                        strides=1,
                                        padding='same')(result)
        result = tf.keras.layers.ReLU()(result)

    # Depth-wise convolutions; no activation is applied between them.
    for _ in range(n):
        result = tf.keras.layers.Conv3D(c, (size, 1, 1),
                                        strides=1,
                                        padding='valid',
                                        activation=None)(result)

    return tf.keras.layers.Add()([crop, result])


def pool(x, size):
    """Average-pool by ``size`` over the last two spatial axes; axis 1 is untouched."""
    pooling = tf.keras.layers.AveragePooling3D(pool_size=(1, size, size),
                                               strides=None,
                                               padding='valid')
    return pooling(x)


def fc_block(x, r):
    """Intended "fully connected" bottleneck block (its call in Model is commented out).

    NOTE(review): the returned tensor is ``Reshape((1, 8, 8, 256))(x)``. It
    reshapes the *input* ``x`` and discards ``result``, so the (1, 8, 8) conv
    and the ``r`` residual ReLU steps have no effect on the output. Passing
    ``result`` instead is not a drop-in fix: the conv output holds
    depth * 1024 values per sample, which generally won't match the
    8 * 8 * 256 the Reshape expects. This needs a design decision before use.

    Args:
        x: 5-D input tensor. The final Reshape only succeeds if it holds
            8 * 8 * 256 values per sample.
        r: number of residual ReLU steps applied to the conv output.

    Returns:
        ``x`` reshaped to (1, 8, 8, 256) per sample (see the note above).
    """
    initializer = tf.random_normal_initializer(0., 0.02)  # unused
    # Collapses an 8x8 in-plane extent to 1x1 with 1024 filters.
    result = tf.keras.layers.Conv3D(1024, (1, 8, 8),
                                    strides=1,
                                    padding='valid')(x)
    for repeat in range(r):
        # Residual step: result = result + ReLU(result).
        crop = result
        # TODO: Should this be a dense layer with RelU activation instead?
        result = tf.keras.layers.ReLU()(result)
        result = tf.keras.layers.Add()([crop, result])

#     result = tf.keras.layers.ReLU()(result)
    # NOTE(review): applied to ``x``, not ``result`` — see docstring.
    result = tf.keras.layers.Reshape((1, 8, 8, 256))(x)

    return result


def up_block(x, m, c):
    """Pre-activation residual block of ``m`` in-plane 3x3 convolutions.

    Args:
        x: 5-D input tensor.
        m: number of (1, 3, 3) 'same' Conv3D + ReLU layers.
        c: number of filters for every convolution (and the output channels).

    Returns:
        ReLU(x) passed through the conv stack, plus a 1x1x1 projection of x.
    """
    # 1x1x1 projection so the skip path has ``c`` channels like the conv path.
    skip = tf.keras.layers.Conv3D(c, 1, activation=None)(x)

    result = tf.keras.layers.ReLU()(x)
    for _ in range(m):
        result = tf.keras.layers.Conv3D(c, (1, 3, 3),
                                        strides=1,
                                        padding='same')(result)
        result = tf.keras.layers.ReLU()(result)
    return tf.keras.layers.Add()([skip, result])


def upscale(x, size):
    """Repeat entries ``size`` times along the last two spatial axes (axis 1 unchanged)."""
    upsampling = tf.keras.layers.UpSampling3D(size=(1, size, size))
    return upsampling(x)


def stack(x, skip):
    """Concatenate ``x`` then ``skip`` along axis 1, the first non-batch axis."""
    concatenate = tf.keras.layers.Concatenate(axis=1)
    return concatenate([x, skip])


# def Model(input_shape, output_channels):
#     input_shape = (1, 128, 128, 1)
#     inputs = tf.keras.layers.Input(shape=input_shape)
#     skips = []

#     x = down_block(inputs, 0, 2, 64, 5)
#     skips.append(x)
# #    x = pool(x, 4)

# #     x = down_block(x, 0, 2, 128, 4)
# #     skips.append(x)
# #     x = pool(x, 4)

#    # x = down_block(x, 0, 2, 256, 4)
#    # skips.append(x)
#    # x = pool(x, 4)

#     #x = fc_block(x, 2)

#    # x = upscale(x, 4)
#    # x = stack(skips[-1], x)
#    # x = up_block(x, 1, 128)

# #     x = upscale(x, 4)
# #     x = stack(skips[-1], x)
# #     x = up_block(x, 1, 64)

# #    x = upscale(x, 4)
#     x = stack(skips[-1], x)
#     x = up_block(x, 1, 1)

#     x = tf.keras.layers.Conv3D(filters=output_channels,
#                                kernel_size=(26, 1, 1),
#                                strides=1,
#                                activation='sigmoid',
#                                padding='valid')(x)

#     return tf.keras.Model(inputs=inputs, outputs=x)


In [None]:
def Model(input_shape, output_channels):
    """Build the slim encoder/decoder segmentation model.

    Pipeline: down block (depth kernel 5) -> 4x average pool -> 4x upsample
    -> concatenate with the skip along axis 1 -> up block -> a conv whose
    kernel spans all of axis 1, collapsing it to 1, with sigmoid output.

    Args:
        input_shape: per-sample shape (depth, height, width, channels), e.g.
            INPUT_SHAPE = (2*CONTEXT + 1, H, W, 1). Height and width must be
            divisible by 4 so pooling and upsampling round-trip exactly.
        output_channels: number of sigmoid output channels.

    Returns:
        A tf.keras.Model with output shape
        (batch, 1, height, width, output_channels).

    Raises:
        ValueError: if the depth is smaller than the depth-wise kernel.
    """
    depth_kernel = 5
    # Previously the shape argument was ignored and replaced by
    # (21, 128, 128, 1), which mismatched the generator's 2*CONTEXT+1 depth.
    depth = input_shape[0]
    if depth < depth_kernel:
        raise ValueError(
            f"input depth {depth} must be at least {depth_kernel} "
            "(i.e. CONTEXT >= 2) for the 'valid' depth-wise convolution")

    inputs = tf.keras.layers.Input(shape=input_shape)
    skips = []

    x = down_block(inputs, 0, 1, 64, depth_kernel)
    skips.append(x)
    x = pool(x, 4)

    x = upscale(x, 4)
    x = stack(skips[-1], x)
    x = up_block(x, 1, 1)

    # Collapse the whole (stacked) axis 1 to a single slice. The kernel is
    # derived from the tensor instead of hardcoded (it was 34, which only
    # fits a 21-slice input: 2 * (21 - 4)).
    x = tf.keras.layers.Conv3D(filters=output_channels,
                               kernel_size=(x.shape[1], 1, 1),
                               strides=1,
                               activation='sigmoid',
                               padding='valid')(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
model = Model(INPUT_SHAPE, OUTPUT_CHANNELS)

# METRICS is declared in the config cell; pass it so training reports it.
model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=METRICS)

tf.keras.utils.plot_model(model, show_shapes=True, dpi=64)

In [None]:
# Layer-by-layer table of output shapes and parameter counts.
model.summary()

In [None]:
# Model.fit accepts a keras Sequence directly; fit_generator is deprecated
# and has been removed in recent TensorFlow releases. Without
# steps_per_epoch, fit uses len(train_gen), which includes the final partial
# batch (len(train_paths) // BATCH_SIZE dropped it).
history = model.fit(train_gen,
                    validation_data=valid_gen,
                    epochs=EPOCHS)

In [None]:
# Get a random testing batch. Valid indices run 0 .. len(test_gen) - 1; the
# previous hardcoded index 7 ignored the computed one and pointed past the
# end of the truncated test split.
from random import randint
batch_index = randint(0, len(test_gen) - 1)
test_inputs, test_labels = test_gen[batch_index]

In [None]:
# Inspect the shape of the sampled test label batch.
test_labels.shape

In [None]:
# Sigmoid outputs of the model for every sample in the test batch.
predicts = model.predict(test_inputs)

In [None]:
index = 4  # sample within the test batch to display
# Samples can be dropped from a batch, so check that this one exists.
assert index < len(predicts), f"test batch only has {len(predicts)} samples"

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 10), sharex=True, sharey=True)
axes[0].imshow(predicts[index, 0, :, :, 0])
axes[0].set_title("Prediction")
axes[1].imshow(test_labels[index, 0, :, :, 0])
axes[1].set_title("Ground truth")
fig.tight_layout()