In [None]:
import pathlib
import urllib.request
import shutil

import numpy as np
import matplotlib.pyplot as plt
import imageio

In [None]:
import tensorflow as tf

In [None]:
# Source archive (Zenodo record 4448689) and the local paths the rest of the
# notebook reads from.
zip_url = 'https://zenodo.org/record/4448689/files/minified-animal-patient-brain-orbits.zip?download=1'
zip_filepath = 'data.zip'

data_directory = pathlib.Path('data')

if not data_directory.exists():
    # Extract into a sibling staging directory and only rename it to
    # `data_directory` once extraction has completed. Extracting straight into
    # `data_directory` means an interrupted extraction leaves a partial `data/`
    # behind, and the `exists()` guard above would then skip the download on
    # every later run.
    staging_directory = data_directory.with_name(data_directory.name + '.partial')
    if staging_directory.exists():
        # Left over from a previously interrupted run; start clean.
        shutil.rmtree(staging_directory)

    urllib.request.urlretrieve(zip_url, zip_filepath)
    shutil.unpack_archive(zip_filepath, staging_directory)
    staging_directory.rename(data_directory)

In [None]:
# Names of the sub-directories of the extracted archive; the cells below load
# the 'training' and 'validation' ones.
dataset_types = [
    entry.name
    for entry in data_directory.glob('*')
    if entry.is_dir()
]
dataset_types

In [None]:
def _load_image(image_path):
    """Read one PNG image and return it as a float array divided by 255.

    A new axis is inserted at position 2, so a 2-D (H, W) image becomes
    (H, W, 1). NOTE(review): this assumes the PNG decodes to a 2-D
    (single-channel) array, which matches the 1-channel model input; confirm.
    """
    pixels = imageio.imread(image_path)
    with_channel_axis = pixels[:, :, None]
    return with_channel_axis.astype(float) / 255


def _load_mask(mask_path):
    """Read one mask PNG and return it divided by 255 (no axes added).

    The channel axis is kept as stored in the file; elsewhere in the notebook
    channel 0 is used as 'eyes' and channel 1 as 'brain'.
    """
    return imageio.imread(mask_path) / 255

In [None]:
def load_dataset_type(dataset_type, seed=None):
    """Load every image/mask pair found under ``data_directory / dataset_type``.

    Images are discovered recursively by the ``_image.png`` suffix. Each
    image's mask must sit next to it with the suffix ``_mask.png``. The pairs
    are shuffled together, so ``images[i]`` always corresponds to ``masks[i]``.

    Args:
        dataset_type: Name of a sub-directory of ``data_directory``
            (e.g. ``'training'`` or ``'validation'``).
        seed: Optional seed for the shuffle. ``None`` (the default) shuffles
            with NumPy's global random state, as before.

    Returns:
        ``(images, masks)``: NumPy arrays stacking the outputs of
        ``_load_image`` and ``_load_mask`` along a new leading axis.

    Raises:
        FileNotFoundError: if no ``*_image.png`` files are found, or if any
            image has no matching ``*_mask.png`` file.
    """
    image_suffix = '_image.png'
    mask_suffix = '_mask.png'

    dataset_directory = data_directory.joinpath(dataset_type)
    # Sort first so the shuffled order depends only on the random state,
    # not on the filesystem's directory-listing order.
    image_paths = sorted(dataset_directory.glob(f'**/*{image_suffix}'))
    if not image_paths:
        # An empty split would otherwise become a shape-(0,) array and fail
        # much later with an unrelated indexing error.
        raise FileNotFoundError(
            f'No *{image_suffix} files found under {dataset_directory}'
        )

    if seed is None:
        np.random.shuffle(image_paths)
    else:
        np.random.default_rng(seed).shuffle(image_paths)

    # Swap only the trailing suffix (guaranteed by the glob pattern), so an
    # earlier '_image.png' inside the file name is left untouched.
    mask_paths = [
        path.with_name(path.name[:-len(image_suffix)] + mask_suffix)
        for path in image_paths
    ]
    missing_masks = [str(path) for path in mask_paths if not path.exists()]
    if missing_masks:
        raise FileNotFoundError(f'Missing mask files: {missing_masks}')

    images = np.array([_load_image(path) for path in image_paths])
    masks = np.array([_load_mask(path) for path in mask_paths])

    return images, masks

In [None]:
# Load the training split first, then the validation split, each as an
# (images, masks) pair.
(training_images, training_masks), (validation_images, validation_masks) = (
    load_dataset_type(split_name) for split_name in ('training', 'validation')
)

In [None]:
def _find_image_with_most_variety(images, masks):
    has_brain = np.sum(masks[:,:,:,1], axis=(1,2))
    has_eyes = np.sum(masks[:,:,:,0], axis=(1,2))

    brain_sort = 1 - np.argsort(has_brain) / len(has_brain)
    eyes_sort = 1 - np.argsort(has_eyes) / len(has_eyes)

    max_combo = np.argmax(brain_sort * eyes_sort * has_brain * has_eyes)

    sample_image = images[max_combo,:,:,:]
    sample_mask = masks[max_combo,:,:,:]
    
    return sample_image, sample_mask


# The preview pair shown by `display` / `DisplayCallback` comes from the
# validation split.
sample_image, sample_mask = _find_image_with_most_variety(
    images=validation_images,
    masks=validation_masks,
)

In [None]:
def display(image, mask, prediction_model=None):
    """Plot an input image, its true mask and (when available) a prediction.

    NOTE: in Jupyter this name shadows IPython's built-in ``display``.

    Args:
        image: Array indexed ``image[row, col, channel]``. Only channel 0 is
            plotted; the full array, with a batch axis prepended, is what gets
            passed to the model.
        mask: Ground-truth mask, passed to ``plt.imshow`` unchanged.
        prediction_model: Object whose ``predict`` produces the third panel.
            Defaults to the notebook-level ``model`` if it has been defined.
            If there is no model, only the first two panels are drawn.
    """
    if prediction_model is None:
        # `model` does not exist until the U-Net cell has run. Look it up
        # explicitly instead of catching NameError: that would also silently
        # hide any genuine NameError raised inside `predict`.
        prediction_model = globals().get('model')

    panels = [
        ('Input Image', image[:, :, 0]),
        ('True Mask', mask),
    ]
    if prediction_model is not None:
        predicted_mask = prediction_model.predict(image[None, ...])[0, ...]
        panels.append(('Predicted Mask', predicted_mask))

    plt.figure(figsize=(18, 5))
    for position, (title, data) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(title)
        plt.imshow(data)
        plt.colorbar()
        plt.axis('off')

    plt.show()
    
    
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that re-plots the sample pair at the end of every epoch.

    Uses the module-level ``sample_image`` / ``sample_mask`` and ``display``.
    """

    def on_epoch_end(self, epoch, logs=None):
        display(sample_image, sample_mask)
        # `epoch` is reported 1-based (Keras epoch indices presumably start at 0).
        print(f'\nSample Prediction after epoch {epoch + 1}\n')


# Show the sample pair once before training (plus a prediction if a model
# already exists).
display(sample_image, sample_mask)

In [None]:
def _activation(x):
    """Return ``x`` passed through a new ReLU ``Activation`` layer."""
    return tf.keras.layers.Activation("relu")(x)


def _convolution(x, number_of_filters, kernel_size=3):
    """Apply a new ``Conv2D`` layer ("same" padding, He-normal init) to ``x``.

    No activation is applied here; callers follow it with ``_activation``.
    """
    conv_layer = tf.keras.layers.Conv2D(
        number_of_filters,
        kernel_size,
        padding="same",
        kernel_initializer="he_normal",
    )
    return conv_layer(x)


def _conv_transpose(x, number_of_filters, kernel_size=3):
    """Apply a new stride-2 ``Conv2DTranspose`` layer (upsampling) to ``x``.

    Uses "same" padding and He-normal initialisation, like ``_convolution``.
    """
    upsample_layer = tf.keras.layers.Conv2DTranspose(
        number_of_filters,
        kernel_size,
        padding="same",
        strides=2,
        kernel_initializer="he_normal",
    )
    return upsample_layer(x)

In [None]:
def encode(
    x,
    number_of_filters,
    number_of_convolutions=2,
):
    """One U-Net encoder stage.

    Applies ``number_of_convolutions`` (conv -> ReLU) pairs, keeps the result
    as the skip connection, then max-pools and applies a ReLU.

    Returns:
        ``(pooled, skip)``: the pooled tensor for the next stage and the
        pre-pool tensor for the matching decoder stage.
    """
    features = x
    for _ in range(number_of_convolutions):
        features = _activation(_convolution(features, number_of_filters))

    pooled = tf.keras.layers.MaxPool2D()(features)
    # NOTE(review): when number_of_convolutions >= 1, `features` is already
    # a ReLU output (non-negative), so this ReLU after pooling is an identity
    # op. It only has an effect when number_of_convolutions == 0.
    pooled = _activation(pooled)

    return pooled, features

In [None]:
def decode(
    x,
    skip,
    number_of_filters,
    number_of_convolutions=2,
):
    """One U-Net decoder stage.

    Upsamples ``x`` with a stride-2 transposed conv + ReLU, concatenates it
    after ``skip`` along axis 3, then applies ``number_of_convolutions``
    (conv -> ReLU) pairs.
    """
    upsampled = _activation(_conv_transpose(x, number_of_filters))
    merged = tf.keras.layers.concatenate([skip, upsampled], axis=3)

    for _ in range(number_of_convolutions):
        merged = _activation(_convolution(merged, number_of_filters))

    return merged

In [None]:
# Derive the network geometry from the training masks, which are indexed
# (sample, row, col, channel): the model below takes square inputs and emits
# one output channel per mask channel.
mask_dims = training_masks.shape
mask_height, mask_width, mask_channels = mask_dims[1], mask_dims[2], mask_dims[-1]
assert mask_height == mask_width
grid_size = int(mask_width)
output_channels = int(mask_channels)

In [None]:
# U-Net: each encoder stage halves the spatial size and keeps a skip
# connection. Each decoder stage doubles it and concatenates the matching
# skip (deepest first). A final 1x1 sigmoid conv gives one map per mask channel.
encoder_filters = [32, 64, 128]
decoder_filters = [256, 128, 64]

if len(decoder_filters) != len(encoder_filters):
    # zip() below would otherwise silently drop stages.
    raise ValueError('encoder_filters and decoder_filters must have the same length')

# Max-pooling floors odd sizes, while the stride-2 transposed conv always
# doubles. So the skip and upsampled tensors only line up for `concatenate`
# when grid_size is a multiple of 2 ** (number of stages). Fail early with a
# clear message instead of an opaque Keras shape-mismatch error.
downsampling_factor = 2 ** len(encoder_filters)
if grid_size % downsampling_factor != 0:
    raise ValueError(
        f'grid_size ({grid_size}) must be divisible by {downsampling_factor} '
        f'for a {len(encoder_filters)}-stage U-Net'
    )

inputs = tf.keras.layers.Input((grid_size, grid_size, 1))
x = inputs
skips = []

for number_of_filters in encoder_filters:
    x, skip = encode(x, number_of_filters)
    skips.append(skip)

for number_of_filters, skip in zip(decoder_filters, reversed(skips)):
    x = decode(x, skip, number_of_filters)

x = tf.keras.layers.Conv2D(
    output_channels,
    1,
    activation="sigmoid",
    padding="same",
    kernel_initializer="he_normal",
)(x)

model = tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
# Print Keras' layer-by-layer summary of the U-Net built in the previous cell.
model.summary()

In [None]:
# Binary cross-entropy matches the per-pixel sigmoid output layer. Pixel-level
# accuracy, recall and precision are tracked alongside the loss.
optimizer = tf.keras.optimizers.Adam()
loss = tf.keras.losses.BinaryCrossentropy()
training_metrics = [
    tf.keras.metrics.BinaryAccuracy(),
    tf.keras.metrics.Recall(),
    tf.keras.metrics.Precision(),
]
model.compile(optimizer=optimizer, loss=loss, metrics=training_metrics)

In [None]:
# Train with validation after every epoch. DisplayCallback re-plots the sample
# prediction each epoch, so this cell produces one figure per epoch.
NUMBER_OF_EPOCHS = 100
history = model.fit(
    x=training_images,
    y=training_masks,
    epochs=NUMBER_OF_EPOCHS,
    validation_data=(validation_images, validation_masks),
    callbacks=[DisplayCallback()],
)