In [None]:
import pathlib
import tensorflow as tf
import tensorflow.keras.backend as K
import skimage

import imageio

import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Automatically propagate any changes made to pymedphys into the notebook,
# without needing a kernel restart.
from IPython.lib.deepreload import reload
%load_ext autoreload
%autoreload 2

In [None]:
from pymedphys._experimental.autosegmentation import unet

In [None]:
# One sub-directory of data/ per structure UID. The listing is sorted because
# glob order follows the filesystem and is not guaranteed. The train/test split
# below slices this list, so an unsorted listing makes the split machine-dependent.
structure_uids = sorted(
    path.name for path in pathlib.Path('data').glob('*')
    if path.is_dir()
)

structure_uids

In [None]:
# Hold out the last NUM_TESTING_UIDS structure UIDs (in structure_uids order)
# for testing. Clamping at 0 makes the "fewer UIDs than requested" case explicit
# rather than relying on a negative slice index. With fewer UIDs than requested,
# every UID goes to testing and training is empty.
NUM_TESTING_UIDS = 2
split_num = max(len(structure_uids) - NUM_TESTING_UIDS, 0)
training_uids = structure_uids[:split_num]
testing_uids = structure_uids[split_num:]

In [None]:
# Structure UIDs used for training.
list(training_uids)

In [None]:
# Structure UIDs held out for testing.
list(testing_uids)

In [None]:
def get_image_paths_for_uids(uids, data_dir='data', shuffle=True, rng=None):
    """Collect every ``*_image.png`` whose parent directory name is in ``uids``.

    Parameters
    ----------
    uids : iterable of str
        Directory names (one per structure UID) whose images should be returned.
    data_dir : str or path-like, optional
        Root directory that is searched recursively. Defaults to ``'data'``.
    shuffle : bool, optional
        Shuffle the returned list in place. Defaults to True.
    rng : numpy.random.Generator, optional
        Generator used for shuffling. When None, the legacy global
        ``np.random`` state is used (seed it for reproducible results).

    Returns
    -------
    list of str
        Matching image paths, shuffled unless ``shuffle`` is False.
    """
    wanted_uids = set(uids)  # O(1) membership per file instead of a list scan

    # Sort first so the pre-shuffle order does not depend on the filesystem's
    # directory-listing order. A seeded RNG then yields the same list everywhere.
    image_paths = sorted(
        str(path) for path in pathlib.Path(data_dir).glob('**/*_image.png')
        if path.parent.name in wanted_uids
    )

    if shuffle:
        shuffler = np.random if rng is None else rng
        shuffler.shuffle(image_paths)

    return image_paths


def mask_paths_from_image_paths(image_paths):
    """Map each ``<stem>_image.png`` path to its sibling ``<stem>_mask.png`` path.

    Only the trailing ``_image.png`` suffix is replaced. Underscores elsewhere,
    in UID directory names or in the file stem, are preserved instead of the
    path being cut at its first underscore.

    Parameters
    ----------
    image_paths : iterable of str or path-like
        Image paths, normally as returned by ``get_image_paths_for_uids``.

    Returns
    -------
    list of str
        Mask paths, in the same order as ``image_paths``.
    """
    image_suffix = '_image.png'
    mask_paths = []
    for image_path in map(str, image_paths):
        if image_path.endswith(image_suffix):
            stem = image_path[:-len(image_suffix)]
        else:
            # Unexpected suffix: drop only the last "_..." component.
            stem = image_path.rsplit('_', 1)[0]
        mask_paths.append(f"{stem}_mask.png")

    return mask_paths

In [None]:
# Shuffled image paths for the training UIDs, paired with their mask paths.
training_image_paths = get_image_paths_for_uids(uids=training_uids)
training_mask_paths = mask_paths_from_image_paths(image_paths=training_image_paths)

(len(training_image_paths), len(training_mask_paths))

In [None]:
# Shuffled image paths for the testing UIDs, paired with their mask paths.
testing_image_paths = get_image_paths_for_uids(uids=testing_uids)
testing_mask_paths = mask_paths_from_image_paths(image_paths=testing_image_paths)

(len(testing_image_paths), len(testing_mask_paths))

In [None]:
def _centre_crop(image):
    shape = image.shape
    cropped = image[
        shape[0]//4:3*shape[0]//4,
        shape[1]//4:3*shape[1]//4,
        ...
    ]
    return cropped

In [None]:
def _process_mask(png_mask):
    """Divide a PNG mask by 255 and return its central crop.

    Assumes 8-bit pixel values (0-255), so the result lies in [0, 1].
    TODO confirm for every mask export.
    """
    return _centre_crop(png_mask / 255)

In [None]:
# def _remove_mask_weights(weighted_mask):
#     return weighted_mask / mask_weights
    
 

In [None]:
# Preview the first few testing masks before and after preprocessing, side by side.
for mask_path in testing_mask_paths[0:5]:
    png_mask = imageio.imread(mask_path)
    processed_mask = _process_mask(png_mask)

    fig, (raw_ax, processed_ax) = plt.subplots(1, 2, figsize=(12, 5))
    raw_ax.imshow(png_mask)
    raw_ax.set_title(f'Raw mask: {pathlib.Path(mask_path).name}')
    processed_artist = processed_ax.imshow(processed_mask)
    processed_ax.set_title('Normalised, centre-cropped mask')
    fig.colorbar(processed_artist, ax=processed_ax)
    plt.show()

In [None]:
# Shape of the last mask processed by the preview loop above.
np.shape(processed_mask)

In [None]:
def _process_image(png_image):
    """Add a channel axis at position 2, convert to float, divide by 255 and centre-crop.

    Assumes a 2-D grayscale PNG with 8-bit values, giving an (h, w, 1) result
    in [0, 1]. TODO confirm.
    """
    with_channel_axis = np.expand_dims(png_image, axis=2).astype(float)
    return _centre_crop(with_channel_axis / 255)

In [None]:
# Preview the first few testing images before and after preprocessing, side by side.
for image_path in testing_image_paths[0:5]:
    png_image = imageio.imread(image_path)
    processed_image = _process_image(png_image)

    fig, (raw_ax, processed_ax) = plt.subplots(1, 2, figsize=(12, 5))
    raw_artist = raw_ax.imshow(png_image)
    raw_ax.set_title(f'Raw image: {pathlib.Path(image_path).name}')
    fig.colorbar(raw_artist, ax=raw_ax)
    processed_artist = processed_ax.imshow(processed_image)
    processed_ax.set_title('Normalised, centre-cropped image')
    fig.colorbar(processed_artist, ax=processed_ax)
    plt.show()

In [None]:
def get_datasets(image_paths, mask_paths):
    """Load paired image/mask PNGs and stack them into float32 tensors.

    Each image goes through ``_process_image`` and each mask through
    ``_process_mask``. The results are stacked along a new leading axis.

    Parameters
    ----------
    image_paths, mask_paths : sequence of str
        Paths paired element-wise (``image_paths[i]`` belongs with ``mask_paths[i]``).

    Returns
    -------
    images, masks : tf.Tensor
        float32 tensors whose first axis indexes the input pairs.

    Raises
    ------
    ValueError
        If the two sequences differ in length. Otherwise ``zip`` would
        silently drop the unmatched tail.
    """
    image_paths = list(image_paths)
    mask_paths = list(mask_paths)
    if len(image_paths) != len(mask_paths):
        raise ValueError(
            f"Expected one mask per image, got {len(image_paths)} image paths "
            f"and {len(mask_paths)} mask paths."
        )

    input_arrays = []
    output_arrays = []
    for image_path, mask_path in zip(image_paths, mask_paths):
        input_arrays.append(_process_image(imageio.imread(image_path)))
        output_arrays.append(_process_mask(imageio.imread(mask_path)))

    images = tf.cast(np.array(input_arrays), tf.float32)
    masks = tf.cast(np.array(output_arrays), tf.float32)

    return images, masks

In [None]:
# Load both splits into memory as float32 tensors.
training_images, training_masks = get_datasets(
    image_paths=training_image_paths, mask_paths=training_mask_paths
)
testing_images, testing_masks = get_datasets(
    image_paths=testing_image_paths, mask_paths=testing_mask_paths
)

In [None]:
# dir(K)

In [None]:
# Training-mask shape. It is used below to size the U-Net
# (index 2 -> grid size, last index -> output channels).
mask_dims = training_masks.get_shape()
mask_dims

In [None]:
# Testing-mask shape, for comparison with the training masks.
testing_masks.get_shape()

In [None]:
def display(display_list, titles=('Input Image', 'True Mask', 'Predicted Mask')):
    """Plot images side by side in one row, each with its own colorbar.

    NOTE: this shadows IPython's built-in ``display`` in the notebook namespace.

    Parameters
    ----------
    display_list : sequence of array-likes
        Images to plot, one per panel, left to right.
    titles : sequence of str, optional
        Panel titles, matched by position. Panels beyond ``len(titles)`` are
        left untitled instead of raising ``IndexError``.
    """
    num_panels = len(display_list)
    if num_panels == 0:
        return  # plt.subplots cannot create a figure with zero columns

    fig, axes = plt.subplots(1, num_panels, figsize=(18, 5), squeeze=False)
    for i, (ax, image) in enumerate(zip(axes[0], display_list)):
        if i < len(titles):
            ax.set_title(titles[i])
        artist = ax.imshow(image)
        fig.colorbar(artist, ax=ax)
        ax.axis('off')

    plt.show()
    
# First testing slice and its ground-truth mask.
display([testing_images[0], testing_masks[0]])

In [None]:
def sigmoid_about_zero(mask):
    """Sigmoid rescaled from (0, 1) to (-1, 1), so an input of 0 maps to 0."""
    return 2 * K.sigmoid(mask) - 1

In [None]:
# 3x3 Scharr-style derivative kernels (weights 47/162/47), as Keras constants.
# scharr_x responds to horizontal intensity changes; scharr_y is its transpose.
scharr_x = np.array(
    [[47, 0, -47],
     [162, 0, -162],
     [47, 0, -47]],
    dtype=np.float32,
)
scharr_x, scharr_y = K.constant(scharr_x), K.constant(scharr_x.T)

def _apply_sharr_filter(image):
    """Per-channel Scharr gradient magnitude of a rank-4 batch.

    Each channel ``image[:, :, :, i]`` is convolved on its own with
    ``scharr_x`` and ``scharr_y`` ('VALID' padding, so each spatial size
    shrinks by 2). Both responses are divided by 255 and combined as
    sqrt(gx**2 + gy**2). The per-channel magnitudes are re-joined on the last axis.

    NOTE(review): the derivative of sqrt is unbounded at 0, so gradients through
    flat (zero-edge) regions could be inf/NaN if this were used inside a loss.
    """
    kernel_x = scharr_x[:, :, None, None]
    kernel_y = scharr_y[:, :, None, None]

    def _magnitude(channel):
        gx = tf.compat.v1.nn.convolution(channel, kernel_x, padding="VALID") / 255
        gy = tf.compat.v1.nn.convolution(channel, kernel_y, padding="VALID") / 255
        return K.sqrt(gx**2 + gy**2)

    magnitudes = [
        _magnitude(image[:, :, :, channel_index][:, :, :, None])
        for channel_index in range(image.shape[-1])
    ]
    return K.concatenate(magnitudes, axis=-1)

In [None]:
def _collect_scharr_filters_without_merging(image):
    """Raw Scharr x/y responses of every channel, without combining them.

    For each channel ``image[:, :, :, i]`` the 'VALID' convolutions with
    ``scharr_x`` then ``scharr_y`` (each divided by 255) are emitted. The output
    has twice as many channels as the input, ordered [c0_x, c0_y, c1_x, c1_y, ...].
    """
    responses = []
    for channel_index in range(image.shape[-1]):
        channel = image[:, :, :, channel_index][:, :, :, None]
        for kernel in (scharr_x, scharr_y):
            responses.append(
                tf.compat.v1.nn.convolution(
                    channel, kernel[:, :, None, None], padding="VALID"
                ) / 255
            )
    return K.concatenate(responses, axis=-1)

In [None]:
# testing_masks is already a float32 tensor (get_datasets casts it), so it can
# be filtered directly without re-casting or re-wrapping it as a constant.
filtered_masks = _apply_sharr_filter(testing_masks)

In [None]:
fig, ax = plt.subplots()
artist = ax.imshow(testing_masks[0, :, :, 2])
ax.set_title('True mask: testing slice 0, channel 2 (patient)')
fig.colorbar(artist, ax=ax);

In [None]:
fig, ax = plt.subplots()
artist = ax.imshow(filtered_masks[0, :, :, 2])
ax.set_title('Scharr edge magnitude: testing slice 0, channel 2 (patient)')
fig.colorbar(artist, ax=ax);

In [None]:
# Two copies of the masks offset by one pixel along axis 2: one with the first
# column dropped, one with the last column dropped.
left_pixel_removed = testing_masks[..., 1:, :]
right_pixel_removed = testing_masks[..., :-1, :]

In [None]:
# Edge magnitudes of the two one-pixel-offset mask copies.
edge_reference, edge_evaluation = (
    _apply_sharr_filter(left_pixel_removed),
    _apply_sharr_filter(right_pixel_removed),
)

In [None]:
fig, ax = plt.subplots()
artist = ax.imshow(edge_reference[0, :, :, 2])
ax.set_title('Edges, first column removed: slice 0, channel 2 (patient)')
fig.colorbar(artist, ax=ax);

In [None]:
# Same view as the reference edge plot, including the colorbar, so the two can be compared.
fig, ax = plt.subplots()
artist = ax.imshow(edge_evaluation[0, :, :, 2])
ax.set_title('Edges, last column removed: slice 0, channel 2 (patient)')
fig.colorbar(artist, ax=ax);

In [None]:
# Pixel-wise absolute difference of the two edge maps (order is irrelevant under abs).
diff_image = K.abs(edge_reference - edge_evaluation)


In [None]:
# diff_image is non-negative, so sigmoid_about_zero maps it into [0, 1).
fig, ax = plt.subplots()
artist = ax.imshow(sigmoid_about_zero(diff_image[0, :, :, 2]))
ax.set_title('Squashed edge difference: slice 0, channel 2 (patient)')
fig.colorbar(artist, ax=ax);

In [None]:
# edge_reference - edge_evaluation

In [None]:
# 1 - sum(sqrt(a*b)) / sum((a+b)/2). Since sqrt(ab) <= (a+b)/2 for a, b >= 0,
# this is >= 0, and it is 0 when the two edge maps agree everywhere.
1 - K.sum(K.sqrt(edge_reference * edge_evaluation)) / K.sum(0.5 * (edge_reference + edge_evaluation))

In [None]:
# Sanity check: the same distance for an edge map against itself should be ~0.
1 - K.sum(K.sqrt(edge_reference * edge_reference)) / K.sum(0.5 * (edge_reference + edge_reference))

In [None]:
# sqrt(a * a) for a non-negative edge map should reproduce the map itself.
fig, ax = plt.subplots()
artist = ax.imshow(K.sqrt(edge_reference * edge_reference)[0, :, :, 2])
ax.set_title('sqrt(edge_reference ** 2): slice 0, channel 2 (patient)')
fig.colorbar(artist, ax=ax);

In [None]:
# Total absolute edge difference, normalised by the total edge mass of both maps.
K.sum(K.abs(edge_evaluation - edge_reference)) / K.sum(edge_evaluation + edge_reference)

In [None]:
# Sanity check: the same ratio for an edge map against itself (numerator is 0).
K.sum(K.abs(edge_reference - edge_reference)) / K.sum(2 * edge_reference)

In [None]:
# plt.imshow(_absolute_like_sigmoid(_apply_sharr_filter(testing_masks) * _apply_sharr_filter(testing_masks))[0,:,:,2])
# plt.colorbar()

In [None]:
# loss = K.sum(diff_image) / (K.sum(edge_evaluation) + K.sum(edge_reference))
# loss

In [None]:
# dir(K)

In [None]:
# intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
# sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
# jac = (intersection + smooth) / (sum_ - intersection + smooth)

In [None]:
# NOTE(review): this repeats an earlier self-comparison cell verbatim; consider deleting.
1 - K.sum(K.sqrt(K.square(edge_reference))) / K.sum((edge_reference + edge_reference) / 2)

In [None]:
# Smallest edge magnitude (should be >= 0, since it comes from a sqrt).
edge_reference.numpy().min()

In [None]:
# Keras binary cross-entropy on probabilities (from_logits defaults to False).
bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)

# Multiplier on the edge term in scharr_and_bce.
scharr_weight = 1
# beta = 0.99  (currently unused)

def scharr(reference, evaluation):
    """Edge-aware loss: BCE between sigmoid-squashed Scharr x/y responses.

    ``reference`` is passed to ``bce`` as y_true and ``evaluation`` as y_pred.
    Because BCE(p, p) equals the entropy of p, comparing a batch with itself
    gives a positive value, not zero.
    """
    squashed_reference, squashed_evaluation = (
        K.sigmoid(_collect_scharr_filters_without_merging(masks))
        for masks in (reference, evaluation)
    )
    return bce(squashed_reference, squashed_evaluation)

def scharr_and_bce(reference, evaluation):
    """Pixel-wise BCE plus the ``scharr`` edge loss scaled by ``scharr_weight``."""
    return bce(reference, evaluation) + scharr_weight * scharr(reference, evaluation)

In [None]:
# for i, label in enumerate(['eye', 'brain', 'patient']):
#     loss = scharr(testing_masks[...,i:i+1], predicted_masks[...,i:i+1])
#     print(f"{label} loss = {loss}")

In [None]:
# Per-structure multipliers used by weighted_scharr.
# NOTE(review): the divisors look like typical per-structure scharr loss
# magnitudes (so each weighted term is roughly 1) — confirm.
weights = dict(
    eye=1 / 0.055,
    brain=1 / 0.24,
    patient=1 / 1.7,
)

def weighted_scharr(reference, evaluation):
    """Sum of per-channel ``scharr`` losses, each scaled by ``weights[label]``.

    Channel ``i`` of both inputs is treated as structure
    ``('eye', 'brain', 'patient')[i]``.
    """
    labels = ('eye', 'brain', 'patient')
    return sum(
        weights[label] * scharr(reference[..., i:i+1], evaluation[..., i:i+1])
        for i, label in enumerate(labels)
    )

In [None]:
# Edge loss between the two one-pixel-offset mask copies.
scharr(reference=left_pixel_removed, evaluation=right_pixel_removed)

In [None]:
# Same comparison with the arguments swapped (BCE is not symmetric).
scharr(reference=right_pixel_removed, evaluation=left_pixel_removed)

In [None]:
# Loss of a batch against itself: the floor of this loss, not zero.
scharr(reference=right_pixel_removed, evaluation=right_pixel_removed)

In [None]:
# np.random.randint

In [None]:
# np.random.choice?
testing_images.get_shape()

In [None]:
# Per-slice summed mask values for brain (channel 1) and eyes (channel 0).
has_brain = np.sum(testing_masks[:,:,:,1], axis=(1,2))
has_eyes = np.sum(testing_masks[:,:,:,0], axis=(1,2))

# Rank fractions aligned with the slice index. np.argsort alone returns the
# permutation that sorts the array (entry j is the slice that belongs at sorted
# position j), not the rank of slice j. Applying argsort a second time turns it
# into per-slice ranks, so these arrays can be multiplied element-wise with
# has_brain / has_eyes. 1 - rank/n is ~1 for the smallest value and 1/n for the largest.
brain_sort = 1 - np.argsort(np.argsort(has_brain, kind='stable')) / len(has_brain)
eyes_sort = 1 - np.argsort(np.argsort(has_eyes, kind='stable')) / len(has_eyes)

# Score is 0 for slices missing either structure. The rank factors damp the
# slices with the very largest totals.
# NOTE(review): confirm this trade-off is the intended selection criterion.
max_combo = np.argmax(brain_sort * eyes_sort * has_brain * has_eyes)
sample_index = max_combo

In [None]:
# eyes_sort

In [None]:
# brain_sort

In [None]:
# The testing slice selected above, reused for every sample visualisation.
sample_image, sample_mask = testing_images[sample_index], testing_masks[sample_index]

In [None]:
fig, ax = plt.subplots()
ax.imshow(sample_image)
ax.set_title(f'Sample testing image (index {sample_index})');

In [None]:
fig, ax = plt.subplots()
ax.imshow(sample_mask)
ax.set_title(f'Sample true mask (index {sample_index})');

In [None]:
# grid_size is taken from axis 2 of the masks, so axes 1 and 2 must match.
assert mask_dims[1] == mask_dims[2]
grid_size, output_channels = int(mask_dims[2]), int(mask_dims[-1])

# Start from a clean Keras session so repeated runs don't accumulate layers.
K.clear_session()
model = unet.unet(
    grid_size=grid_size,
    output_channels=output_channels,
    number_of_filters_start=32,
    max_filter_num=32,
    min_grid_size=8,
    num_of_fc=2,
)
model.summary()

In [None]:
def show_prediction():
    """Predict the whole testing set, plot the sample slice and print edge losses.

    Reads the notebook globals ``model``, ``testing_images``, ``testing_masks``,
    ``sample_image``, ``sample_mask`` and ``sample_index``. The printed losses
    are ``scharr`` over all testing slices, one per channel/structure.
    """
    predictions = model.predict(testing_images)

    display([sample_image, sample_mask, predictions[sample_index, :, :, :]])

    for channel, label in enumerate(['eye', 'brain', 'patient']):
        channel_slice = slice(channel, channel + 1)
        loss = scharr(testing_masks[..., channel_slice], predictions[..., channel_slice])
        print(f"{label} loss = {loss}")
        
        
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that shows a sample prediction at the end of every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # ``epoch`` is 0-based, hence the + 1. The header is printed before the
        # plot and per-structure losses so it actually labels them.
        print('\nSample Prediction after epoch {}\n'.format(epoch + 1))
        show_prediction()
        
# Baseline: plot the untrained model's prediction (and print its per-structure
# edge losses) before compiling and fitting.
show_prediction()

In [None]:
# NOTE(review): the model is trained with plain `bce`, although `scharr_and_bce`
# is defined earlier and the weights are saved under the name 'scharr' at the
# end of the notebook. Confirm which loss is intended.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),  # default learning rate; alternative: learning_rate=0.0001
    loss=bce,
    metrics=[
        tf.keras.metrics.BinaryAccuracy(),
        tf.keras.metrics.Recall(),
        tf.keras.metrics.Precision(),
    ],
)

In [None]:
# model.load_weights('./checkpoints/binomial-cross-entropy')

In [None]:
# Train for a fixed 100 epochs, plotting a sample prediction after each one.
history = model.fit(
    x=training_images,
    y=training_masks,
    epochs=100,
    validation_data=(testing_images, testing_masks),
    callbacks=[DisplayCallback()],
    # batch_size=training_masks.shape[0] // 3,
)

In [None]:

        
# show_predictions(num=1)

In [None]:
# loss_object = scharr_loss_with_bce
# optimizer = tf.keras.optimizers.Adam()

# def train_step(images, masks):
#     with tf.GradientTape() as tape:
#         logits = model(images, training=True)
#         loss_value = loss_object(masks, logits, debug=True)
        
#     grads = tape.gradient(loss_value, model.trainable_variables)
#     optimizer.apply_gradients(zip(grads, model.trainable_variables))
    

# def train(epochs):
#     for epoch in range(epochs):
#         train_step(training_images, training_masks)

#         print ('Epoch {} finished'.format(epoch))
#         show_predictions()
        
# train(100)

In [None]:
# Interactive help lookup left over from development. A trailing `?` is
# IPython-only syntax (invalid when exported to a .py script), so it is kept
# commented out like the other help lookups in this notebook.
# tf.keras.optimizers.Adam?

In [None]:
# model.fit?

In [None]:
# Directory for saved model weights, relative to the working directory.
checkpoints_dir = pathlib.Path('.', 'checkpoints')
checkpoints_dir.mkdir(parents=False, exist_ok=True)

In [None]:
# NOTE(review): saved as 'scharr', but the model is compiled with plain `bce`.
# Confirm the checkpoint name reflects the loss actually used.
model.save_weights(checkpoints_dir / 'scharr')