In [None]:
import pathlib
import tensorflow as tf
import tensorflow.keras.backend as K
import skimage

import imageio

import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Automatically propagate any changes made to pymedphys into the notebook
# without needing a kernel restart.
# NOTE(review): `reload` is not used anywhere visible in this notebook.
from IPython.lib.deepreload import reload
%load_ext autoreload
%autoreload 2

In [None]:
from pymedphys._experimental.autosegmentation import unet

In [None]:
# One sub-directory of data/ per structure UID. The names are sorted because
# Path.glob yields entries in arbitrary, filesystem-dependent order, which
# would otherwise make the train/test split below differ between machines.
# Stray files are skipped: images are matched to a UID by their parent
# directory's name, so only directories can be UIDs.
structure_uids = sorted(
    path.name for path in pathlib.Path('data').glob('*')
    if path.is_dir()
)

structure_uids

In [None]:
# Hold out the last NUM_TESTING_UIDS structures for validation; train on the rest.
NUM_TESTING_UIDS = 2

# Fail fast with a clear message instead of silently producing an empty
# training set when data/ holds too few structures.
assert len(structure_uids) > NUM_TESTING_UIDS, (
    f"Need more than {NUM_TESTING_UIDS} structures for a train/test split, "
    f"found {len(structure_uids)}"
)

split_num = len(structure_uids) - NUM_TESTING_UIDS
training_uids = structure_uids[:split_num]
testing_uids = structure_uids[split_num:]

In [None]:
# Structures used for training: every entry of structure_uids except the held-out tail.
training_uids

In [None]:
# Held-out structures, used as validation data during model.fit.
testing_uids

In [None]:
def get_image_paths_for_uids(uids, data_dir='data', rng=None):
    """Return shuffled paths of every ``*_image.png`` whose parent dir is in ``uids``.

    Parameters
    ----------
    uids : iterable of str
        Structure UIDs, matched against the name of each image's parent
        directory.
    data_dir : str or os.PathLike, default 'data'
        Root directory that is searched recursively.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness for the shuffle. When omitted, NumPy's global
        RNG is used (the original behaviour).

    Returns
    -------
    list of str
        Matching image paths in shuffled order.
    """
    uid_set = set(uids)  # O(1) membership test per globbed file

    # Sort before shuffling: glob order is filesystem-dependent, so this is
    # what makes a seeded shuffle reproducible across machines.
    image_paths = sorted(
        str(path) for path in pathlib.Path(data_dir).glob('**/*_image.png')
        if path.parent.name in uid_set
    )

    shuffler = np.random if rng is None else rng
    shuffler.shuffle(image_paths)

    return image_paths


def mask_paths_from_image_paths(image_paths):
    """Map each ``<name>_image.png`` path to its sibling ``<name>_mask.png``.

    Only the trailing ``_image.png`` suffix is replaced, so underscores
    elsewhere in the path (in directory names or the file stem) are kept.

    Parameters
    ----------
    image_paths : iterable of str or os.PathLike
        Paths ending in ``_image.png``.

    Returns
    -------
    list of str
        Mask paths, in the same order as ``image_paths``.

    Raises
    ------
    ValueError
        If a path does not end in ``_image.png``.
    """
    image_suffix = '_image.png'
    mask_suffix = '_mask.png'

    mask_paths = []
    for image_path in image_paths:
        image_path = str(image_path)
        if not image_path.endswith(image_suffix):
            raise ValueError(
                f"Expected an image path ending in '{image_suffix}', got '{image_path}'"
            )
        mask_paths.append(image_path[:-len(image_suffix)] + mask_suffix)

    return mask_paths

In [None]:
training_image_paths = get_image_paths_for_uids(training_uids)
training_mask_paths = mask_paths_from_image_paths(training_image_paths)

# One mask path is derived per image path, so both counts should match.
tuple(map(len, (training_image_paths, training_mask_paths)))

In [None]:
testing_image_paths = get_image_paths_for_uids(testing_uids)
testing_mask_paths = mask_paths_from_image_paths(testing_image_paths)

# One mask path is derived per image path, so both counts should match.
tuple(map(len, (testing_image_paths, testing_mask_paths)))

In [None]:
def _centre_crop(image):
    shape = image.shape
    cropped = image[
        shape[0]//4:3*shape[0]//4,
        shape[1]//4:3*shape[1]//4,
        ...
    ]
    return cropped

In [None]:
def _process_mask(png_mask):
    """Divide a PNG mask by 255 and keep its central region.

    Presumably the input holds 8-bit PNG values, giving a result in [0, 1];
    TODO confirm the mask PNG bit depth.
    """
    return _centre_crop(png_mask / 255)

In [None]:
# def _remove_mask_weights(weighted_mask):
#     return weighted_mask / mask_weights
    
 

In [None]:
# Visual sanity check on the first testing mask: the raw PNG next to the
# array produced by _process_mask. Later cells read `processed_mask`, so
# there must be at least one testing mask.
assert testing_mask_paths, "No testing masks found; check data/ and the train/test split"

for mask_path in testing_mask_paths[0:1]:
    png_mask = imageio.imread(mask_path)
    processed_mask = _process_mask(png_mask)

    plt.imshow(png_mask)
    plt.title('Raw mask PNG')
    plt.colorbar()
    plt.show()

    plt.imshow(processed_mask)
    plt.title('Processed mask (scaled by 1/255, centre-cropped)')
    plt.colorbar()
    plt.show()

In [None]:
# Shape of the mask previewed in the cell above, after centre cropping.
processed_mask.shape

In [None]:
def _process_image(png_image):
    """Append a channel axis, convert to float scaled by 1/255, and centre-crop.

    For a 2-D (rows, cols) PNG this yields a (rows // 2-ish, cols // 2-ish, 1)
    array; presumably the PNGs are single-channel greyscale — TODO confirm.
    """
    as_float = png_image[:, :, np.newaxis].astype(float)
    return _centre_crop(as_float / 255)

In [None]:
# Visual sanity check on the first testing image: the raw PNG next to the
# array produced by _process_image.
assert testing_image_paths, "No testing images found; check data/ and the train/test split"

for image_path in testing_image_paths[0:1]:
    png_image = imageio.imread(image_path)
    processed_image = _process_image(png_image)

    plt.imshow(png_image)
    plt.title('Raw image PNG')
    plt.colorbar()
    plt.show()

    plt.imshow(processed_image)
    plt.title('Processed image (scaled by 1/255, centre-cropped)')
    plt.colorbar()
    plt.show()

In [None]:
def get_datasets(image_paths, mask_paths):
    """Load paired image/mask PNGs into two float32 tensors.

    Parameters
    ----------
    image_paths, mask_paths : sequences of str
        Paths to the image PNGs and to their masks, paired by position.

    Returns
    -------
    images, masks : tf.Tensor
        float32 tensors stacking ``_process_image`` / ``_process_mask`` of
        each file along a new leading axis.

    Raises
    ------
    ValueError
        If the two path lists differ in length; ``zip`` would otherwise drop
        the unmatched tail silently.
    """
    image_paths = list(image_paths)
    mask_paths = list(mask_paths)
    if len(image_paths) != len(mask_paths):
        raise ValueError(
            f"Got {len(image_paths)} image paths but {len(mask_paths)} mask paths"
        )

    input_arrays = []
    output_arrays = []
    for image_path, mask_path in zip(image_paths, mask_paths):
        input_arrays.append(_process_image(imageio.imread(image_path)))
        output_arrays.append(_process_mask(imageio.imread(mask_path)))

    # Stack straight into float32, avoiding an intermediate float64 copy of
    # the whole stacked dataset before the cast.
    images = tf.cast(np.array(input_arrays, dtype=np.float32), tf.float32)
    masks = tf.cast(np.array(output_arrays, dtype=np.float32), tf.float32)

    return images, masks

In [None]:
# Read every PNG into memory as float32 tensors (images and masks for each split).
training_images, training_masks = get_datasets(training_image_paths, training_mask_paths)
testing_images, testing_masks = get_datasets(testing_image_paths, testing_mask_paths)

In [None]:
# dir(K)

In [None]:
# (number of mask files, rows, cols, channels); the model's grid size and
# output channel count are derived from this further down.
mask_dims = training_masks.shape
mask_dims

In [None]:
# Same layout for the held-out testing masks.
testing_masks.shape

In [None]:
def display(display_list, titles=('Input Image', 'True Mask', 'Predicted Mask')):
    """Plot images side by side in a single row, each with a title and colorbar.

    NOTE: defining this at notebook level shadows IPython's built-in
    ``display`` for every later cell.

    Parameters
    ----------
    display_list : sequence of array-likes
        Images accepted by ``plt.imshow``.
    titles : sequence of str, optional
        Panel titles by position. Panels beyond ``len(titles)`` get a generic
        "Image N" title instead of raising IndexError.
    """
    plt.figure(figsize=(18, 5))

    for i, image in enumerate(display_list):
        plt.subplot(1, len(display_list), i + 1)
        plt.title(titles[i] if i < len(titles) else f'Image {i + 1}')
        plt.imshow(image)
        plt.colorbar()
        plt.axis('off')

    plt.show()


display([testing_images[0, :, :, :], testing_masks[0, :, :, :]])

In [None]:
# Choose a single testing slice to monitor during training. Per the variable
# names, mask channel 1 is the brain and channel 0 the eyes (TODO confirm
# against the mask PNG export).
has_brain = np.sum(testing_masks[:, :, :, 1], axis=(1, 2))
has_eyes = np.sum(testing_masks[:, :, :, 0], axis=(1, 2))


def _ranks(values):
    """Return each element's 0-based rank within ``values`` (0 = smallest)."""
    # A single argsort gives the indices that would sort ``values``, not the
    # rank of each element; a second argsort inverts that permutation so the
    # result lines up element-wise with ``values``.
    return np.argsort(np.argsort(values))


# Rank weights shrink as a structure's area grows, counterbalancing the raw
# areas so the product does not simply favour the slices with the largest
# structures. Slices missing either structure score zero.
brain_sort = 1 - _ranks(has_brain) / len(has_brain)
eyes_sort = 1 - _ranks(has_eyes) / len(has_eyes)

max_combo = np.argmax(brain_sort * eyes_sort * has_brain * has_eyes)
sample_index = max_combo

In [None]:
# eyes_sort

In [None]:
# brain_sort

In [None]:
# Input and ground truth of the slice chosen above; shown after every epoch.
sample_image, sample_mask = testing_images[max_combo], testing_masks[max_combo]

In [None]:
# Input image of the monitored testing slice.
plt.imshow(sample_image)

In [None]:
# Ground-truth mask of the monitored testing slice.
plt.imshow(sample_mask)

In [None]:
# unet.unet is given a single grid_size, presumably for both spatial axes,
# so the cropped masks must be square.
assert mask_dims[1] == mask_dims[2]
grid_size, output_channels = int(mask_dims[2]), int(mask_dims[-1])

unet_config = dict(
    grid_size=grid_size,
    output_channels=output_channels,
    number_of_filters_start=32,
    max_filter_num=64,
    min_grid_size=8,
    num_of_fc=2,
    batch_normalisation=False,
    use_dropout=False,
)

# Reset Keras' global state before (re)building the model.
tf.keras.backend.clear_session()
model = unet.unet(**unet_config)
model.summary()

In [None]:
def show_prediction():
    """Plot the monitored testing slice: input, true mask and model prediction.

    Uses the notebook-level ``model``, ``sample_image`` and ``sample_mask``.
    Only the monitored slice is passed through the model. Predicting the
    whole testing set to keep one slice would repeat that full pass every
    epoch via DisplayCallback.
    """
    sample_batch = tf.expand_dims(sample_image, axis=0)  # add batch axis
    # verbose=0 keeps predict's progress bar out of the per-epoch output.
    predicted_mask = model.predict(sample_batch, verbose=0)[0]
    display([sample_image, sample_mask, predicted_mask])


class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that re-plots the sample prediction after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        show_prediction()
        # Keras passes a 0-based epoch index; report it 1-based.
        print('\nSample Prediction after epoch {}\n'.format(epoch + 1))


show_prediction()

In [None]:
# NOTE(review): BinaryCrossentropy defaults to from_logits=False, which
# assumes unet.unet ends in a sigmoid-style activation — TODO confirm.
evaluation_metrics = [
    tf.keras.metrics.BinaryAccuracy(),
    tf.keras.metrics.Recall(),
    tf.keras.metrics.Precision(),
]

model.compile(
    # Adam at its default learning rate; a commented-out alternative here
    # used learning_rate=0.0001.
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=evaluation_metrics,
)

In [None]:
# model.load_weights('./checkpoints/binomial-cross-entropy')

In [None]:
# Keras' default batch size applies (a commented-out alternative used a third
# of the training set per batch). DisplayCallback plots the monitored slice
# after every epoch, so expect one figure per epoch.
history = model.fit(
    x=training_images,
    y=training_masks,
    validation_data=(testing_images, testing_masks),
    epochs=100,
    callbacks=[DisplayCallback()],
)

In [None]:
# checkpoints_dir = pathlib.Path('checkpoints')
# checkpoints_dir.mkdir(exist_ok=True)

In [None]:
# model.save_weights(checkpoints_dir.joinpath('final'))