In [None]:
import pathlib
import tensorflow as tf
import tensorflow.keras.backend as K
import skimage

import imageio

import numpy as np
import matplotlib.pyplot as plt

In [None]:
def get_image_paths_for_uids(uids, data_dir='data', rng=None):
    """Collect the ``*_image.png`` files that belong to ``uids``, shuffled.

    ``data_dir`` is searched recursively for ``*_image.png``. A file is kept
    when the name of its immediate parent directory is one of ``uids``.

    Args:
        uids: Iterable of structure directory names to include.
        data_dir: Root directory to search. Defaults to ``'data'`` (the
            previous hard-coded value).
        rng: Optional ``np.random.Generator`` used for the shuffle. When
            ``None`` the global NumPy RNG is used, as before.

    Returns:
        A shuffled list of image paths as strings.
    """
    # A set gives O(1) membership instead of scanning a list for every file.
    wanted_uids = set(uids)
    # Sort first: glob order depends on the filesystem, so even a seeded
    # shuffle is only reproducible when it starts from a fixed order.
    image_paths = sorted(
        str(path) for path in pathlib.Path(data_dir).glob('**/*_image.png')
        if path.parent.name in wanted_uids
    )
    shuffle = np.random.shuffle if rng is None else rng.shuffle
    shuffle(image_paths)

    return image_paths


def mask_paths_from_image_paths(image_paths):
    """Return the ``*_mask.png`` path paired with each ``*_image.png`` path.

    Only the trailing ``_image.png`` suffix is swapped for ``_mask.png``, so
    underscores elsewhere in the path (uid directory, file stem) are kept.

    Args:
        image_paths: Iterable of paths (str or path-like) ending in
            ``_image.png``.

    Returns:
        List of mask path strings, in the same order as ``image_paths``.

    Raises:
        ValueError: if a path does not end with ``_image.png``.
    """
    image_suffix = '_image.png'
    mask_paths = []
    for image_path in map(str, image_paths):
        # Splitting on the first '_' (the old approach) truncated paths such
        # as 'data/my_uid/3_image.png' to 'data/my' -> wrong mask file.
        if not image_path.endswith(image_suffix):
            raise ValueError(
                f"Expected a path ending in {image_suffix!r}, got {image_path!r}"
            )
        mask_paths.append(image_path[:-len(image_suffix)] + '_mask.png')

    return mask_paths

In [None]:
# Each sub-directory of `data` is one structure (uid). The listing is sorted
# because glob order is filesystem-dependent; without sorting, the held-out
# structures would differ between machines. Stray files are skipped.
NUM_TESTING_STRUCTURES = 2  # number of structures held out for validation

structure_uids = sorted(
    path.name for path in pathlib.Path('data').glob('*') if path.is_dir()
)
assert len(structure_uids) > NUM_TESTING_STRUCTURES, (
    f"Need more than {NUM_TESTING_STRUCTURES} structure directories in 'data', "
    f"found {len(structure_uids)}"
)
split_num = len(structure_uids) - NUM_TESTING_STRUCTURES
training_uids = structure_uids[0:split_num]
testing_uids = structure_uids[split_num:]

training_image_paths = get_image_paths_for_uids(training_uids)
training_mask_paths = mask_paths_from_image_paths(training_image_paths)

testing_image_paths = get_image_paths_for_uids(testing_uids)
testing_mask_paths = mask_paths_from_image_paths(testing_image_paths)

In [None]:
def _centre_crop(image):
    shape = image.shape
    cropped = image[
        shape[0]//4:3*shape[0]//4,
        shape[1]//4:3*shape[1]//4,
        ...
    ]
    return cropped

In [None]:
def _process_mask(png_mask):
    """Divide a raw PNG mask by 255 and centre-crop it.

    Presumably the PNG is 8-bit, so the result lies in [0, 1]; confirm for
    other bit depths. No channel axis is added here (unlike
    ``_process_image``).
    """
    return _centre_crop(png_mask / 255)

In [None]:
def _process_image(png_image):
    """Turn a 2-D PNG image into a centre-cropped ``(H, W, 1)`` float array.

    A trailing channel axis is appended, the pixels are cast to float and
    divided by 255, and the result is centre-cropped.
    """
    # assumes png_image is single-channel (2-D) — TODO confirm for the dataset
    with_channel = png_image[:, :, np.newaxis]
    scaled = with_channel.astype(float) / 255
    return _centre_crop(scaled)

In [None]:
def get_datasets(image_paths, mask_paths):
    """Read every image/mask pair from disk into two float32 tensors.

    Pairs are formed positionally (``zip``), each file is read with
    ``imageio.imread`` and pre-processed, then the results are stacked along a
    new leading axis.

    Returns:
        ``(images, masks)`` as ``tf.float32`` tensors.
    """
    # Each pair is read and processed image-first, then mask, in file order.
    processed_pairs = [
        (
            _process_image(imageio.imread(img_path)),
            _process_mask(imageio.imread(msk_path)),
        )
        for img_path, msk_path in zip(image_paths, mask_paths)
    ]

    images = tf.cast(np.array([pair[0] for pair in processed_pairs]), tf.float32)
    masks = tf.cast(np.array([pair[1] for pair in processed_pairs]), tf.float32)

    return images, masks

In [None]:
# Decode every PNG pair into memory as float32 (images, masks) tensors.
training_images, training_masks = get_datasets(
    training_image_paths, training_mask_paths
)
testing_images, testing_masks = get_datasets(
    testing_image_paths, testing_mask_paths
)

In [None]:
def display(display_list, titles=('Input Image', 'True Mask', 'Predicted Mask')):
    """Show images side by side in one row, each with its own colour bar.

    NOTE: this name shadows IPython's built-in ``display`` in the notebook.

    Args:
        display_list: Sequence of array-likes accepted by ``imshow``.
        titles: Panel titles matched to ``display_list`` by position. Panels
            beyond the end of ``titles`` are left untitled; previously an
            extra panel raised ``IndexError``.
    """
    if len(display_list) == 0:
        return

    # squeeze=False keeps `axes` 2-D even for a single panel.
    fig, axes = plt.subplots(
        1, len(display_list), figsize=(18, 5), squeeze=False
    )
    for i, (ax, image) in enumerate(zip(axes[0], display_list)):
        if i < len(titles):
            ax.set_title(titles[i])
        shown = ax.imshow(image)
        fig.colorbar(shown, ax=ax)
        ax.axis('off')

    plt.show()

In [None]:
def _fractional_rank(values):
    """Return each element's ascending rank divided by ``len(values)``.

    ``np.argsort`` alone returns the indices that would sort the array, not
    the rank of each element. Applying it twice gives the ranks.
    """
    values = np.asarray(values)
    order = np.argsort(values, kind='stable')
    return np.argsort(order, kind='stable') / len(values)


# Per-slice amount of each structure: mask channel 1 is treated as brain and
# channel 0 as eyes (naming taken from the variables below).
has_brain = np.sum(testing_masks[:,:,:,1], axis=(1,2))
has_eyes = np.sum(testing_masks[:,:,:,0], axis=(1,2))

# NOTE(review): the original used `1 - np.argsort(x) / n`, which weights image
# i by the index of the i-th smallest value, not by image i's own rank, so
# the weights were unrelated to the images. With true ranks, `1 - rank`
# down-weights the largest structures while the raw sums up-weight them.
# Presumably the product is meant to favour slices where both structures are
# present at a moderate size — confirm intent.
brain_sort = 1 - _fractional_rank(has_brain)
eyes_sort = 1 - _fractional_rank(has_eyes)

max_combo = np.argmax(brain_sort * eyes_sort * has_brain * has_eyes)

sample_image = testing_images[max_combo,:,:,:]
sample_mask = testing_masks[max_combo,:,:,:]

In [None]:
def _activation(x):
    """Apply a ReLU non-linearity as a standalone Keras layer."""
    relu_layer = tf.keras.layers.Activation("relu")
    return relu_layer(x)


def _convolution(x, number_of_filters, kernel_size=3):
    """Apply a 'same'-padded Conv2D (default strides) with He-normal init."""
    conv_layer = tf.keras.layers.Conv2D(
        filters=number_of_filters,
        kernel_size=kernel_size,
        padding="same",
        kernel_initializer="he_normal",
    )
    return conv_layer(x)


def _conv_transpose(x, number_of_filters, kernel_size=3):
    """Upsample ``x`` with a stride-2, 'same'-padded transposed convolution."""
    upsample_layer = tf.keras.layers.Conv2DTranspose(
        filters=number_of_filters,
        kernel_size=kernel_size,
        strides=2,
        padding="same",
        kernel_initializer="he_normal",
    )
    return upsample_layer(x)

In [None]:
def decode(
    x,
    skip,
    number_of_filters,
    number_of_convolutions=2,
):
    """One decoder stage: upsample, concatenate the skip, then refine.

    Args:
        x: Feature map from the previous (coarser) stage.
        skip: Encoder feature map concatenated with the upsampled ``x``
            (``skip`` first) along axis 3.
        number_of_filters: Width of the transposed conv and every refinement
            convolution.
        number_of_convolutions: How many conv + ReLU pairs follow the merge.

    Returns:
        The refined feature map.
    """
    upsampled = _activation(_conv_transpose(x, number_of_filters))
    merged = tf.keras.layers.concatenate([skip, upsampled], axis=3)

    refined = merged
    for _ in range(number_of_convolutions):
        refined = _activation(_convolution(refined, number_of_filters))

    return refined

In [None]:
def encode(
    x,
    number_of_filters,
    number_of_convolutions=2,
):
    """One encoder stage: conv + ReLU stack, then max-pool.

    Args:
        x: Input feature map.
        number_of_filters: Width of every convolution in the stage.
        number_of_convolutions: How many conv + ReLU pairs precede pooling.

    Returns:
        ``(pooled, skip)``. ``skip`` is the pre-pool feature map, to be
        handed to the matching decoder stage.
    """
    features = x
    for _ in range(number_of_convolutions):
        features = _activation(_convolution(features, number_of_filters))

    pooled = tf.keras.layers.MaxPool2D()(features)
    # NOTE(review): when number_of_convolutions > 0 the pooled values are
    # maxima of ReLU outputs and already non-negative, so this ReLU is a
    # no-op. It is kept so the layer graph is unchanged.
    pooled = _activation(pooled)

    return pooled, features

In [None]:
def fully_connected(
    x,
    fc_channels,
    interface_grid_size,
    interface_channels,
    fc_repeats=2,
):
    """Dense bottleneck between encoder and decoder, added back residually.

    Steps:
    1. A 'valid' Conv2D whose kernel is ``interface_grid_size`` wide maps the
       grid to ``fc_channels`` features.
    2. ``fc_repeats`` residual (ReLU -> Dense) blocks follow.
    3. A final Dense layer projects back up. Its output is reshaped to
       ``(interface_grid_size, interface_grid_size, interface_channels)`` and
       added to ``x``.

    Args:
        x: 4-D Keras tensor whose non-batch shape must be
            ``(interface_grid_size, interface_grid_size, interface_channels)``.
        fc_channels: Width of the dense bottleneck.
        interface_grid_size: Spatial size of ``x`` (and of the conv kernel).
        interface_channels: Channel count of ``x``.
        fc_repeats: Number of residual dense blocks.

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        ValueError: if ``x`` does not have the declared interface shape. The
            check runs up front because the mismatch otherwise surfaces later
            as a cryptic Reshape/Add error.
    """
    expected_shape = (interface_grid_size, interface_grid_size, interface_channels)
    actual_shape = tuple(x.shape)[1:]
    # Unknown (None) dimensions are not checked; only known ones must agree.
    if len(actual_shape) != 3 or any(
        dim is not None and dim != want
        for dim, want in zip(actual_shape, expected_shape)
    ):
        raise ValueError(
            f"fully_connected expects input shape (batch, *{expected_shape}), "
            f"got (batch, *{actual_shape})"
        )

    start = x

    x = tf.keras.layers.Conv2D(
        fc_channels, interface_grid_size, padding="valid", kernel_initializer="he_normal"
    )(x)

    for _ in range(fc_repeats):
        residual = x
        x = _activation(x)
        x = tf.keras.layers.Dense(fc_channels)(x)
        x = tf.keras.layers.Add()([residual, x])

    x = _activation(x)

    x = tf.keras.layers.Dense(
        interface_grid_size * interface_grid_size * interface_channels
    )(x)
    x = tf.keras.layers.Reshape(expected_shape)(x)

    x = tf.keras.layers.Add()([start, x])
    return x

In [None]:
# Masks are stacked as (batch, height, width, channels); the network below
# assumes square inputs.
mask_dims = training_masks.shape
assert mask_dims[1] == mask_dims[2], (
    f"Expected square masks (height == width), got shape {mask_dims}"
)
grid_size = int(mask_dims[2])
output_channels = int(mask_dims[-1])

In [None]:
# Encoder widths (one MaxPool2D halving per entry) and decoder widths.
ENCODER_FILTERS = [32, 64]
DECODER_FILTERS = [128, 64]
FC_CHANNELS = 256

# Each encoder stage halves the grid and each decoder stage doubles it
# (stride-2 transposed conv), so the skips only line up if the grid divides
# evenly.
downsample_factor = 2 ** len(ENCODER_FILTERS)
assert grid_size % downsample_factor == 0, (
    f"grid_size {grid_size} must be divisible by {downsample_factor}"
)

inputs = tf.keras.layers.Input((grid_size, grid_size, 1))
x = inputs
skips = []

for number_of_filters in ENCODER_FILTERS:
    x, skip = encode(x, number_of_filters)
    skips.append(skip)

skips.reverse()

# The bottleneck interface is derived from the encoder output. The previous
# hard-coded 8x8x64 only built when the cropped masks were 32x32.
x = fully_connected(
    x,
    fc_channels=FC_CHANNELS,
    interface_grid_size=grid_size // downsample_factor,
    interface_channels=ENCODER_FILTERS[-1],
    fc_repeats=2,
)

for number_of_filters, skip in zip(DECODER_FILTERS, skips):
    x = decode(x, skip, number_of_filters)

x = tf.keras.layers.Conv2D(
    output_channels,
    1,
    activation="sigmoid",
    padding="same",
    kernel_initializer="he_normal",
)(x)

model = tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
# Print the model's layer-by-layer architecture and parameter counts.
model.summary()

In [None]:
def show_prediction():
    """Plot the sample test image, its true mask and the current prediction."""
    batched_image = sample_image[None, :, :, :]  # predict() expects a batch axis
    predicted_mask = model.predict(batched_image)[0, :, :, :]
    display([sample_image, sample_mask, predicted_mask])


class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that redraws the sample prediction after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        show_prediction()
        message = '\nSample Prediction after epoch {}\n'.format(epoch + 1)
        print(message)


show_prediction()

In [None]:
# Sigmoid outputs per pixel -> binary cross-entropy, with pixel-level metrics.
adam_optimizer = tf.keras.optimizers.Adam()
bce_loss = tf.keras.losses.BinaryCrossentropy()
pixel_metrics = [
    tf.keras.metrics.BinaryAccuracy(),
    tf.keras.metrics.Recall(),
    tf.keras.metrics.Precision(),
]
model.compile(optimizer=adam_optimizer, loss=bce_loss, metrics=pixel_metrics)

In [None]:
# Train on the in-memory arrays, validate on the held-out structures, and
# redraw the sample prediction at the end of every epoch.
EPOCHS = 100
history = model.fit(
    training_images,
    training_masks,
    validation_data=(testing_images, testing_masks),
    epochs=EPOCHS,
    callbacks=[DisplayCallback()],
)

In [None]:
# checkpoints_dir = pathlib.Path('checkpoints')
# checkpoints_dir.mkdir(exist_ok=True)

In [None]:
# model.save_weights(checkpoints_dir.joinpath('final'))