In [None]:
import pathlib

import imageio
import matplotlib.pyplot as plt
import numpy as np
import skimage
import skimage.filters
import tensorflow as tf
import tensorflow.keras.backend as K

In [None]:
# Makes it so any changes in pymedphys is automatically
# propagated into the notebook without needing a kernel reset.
from IPython.lib.deepreload import reload
%load_ext autoreload
%autoreload 2

In [None]:
from pymedphys._experimental.autosegmentation import unet

In [None]:
output_channels=3

In [None]:
model = unet.unet(grid_size=64, output_channels=output_channels)  # background, patient, brain, eyes

# tf.keras.utils.plot_model(model, show_shapes=True)

In [None]:
model

In [None]:
structure_uids = [
    path.name for path in pathlib.Path('data').glob('*')
]

structure_uids

In [None]:
split_num = len(structure_uids) - 2
training_uids = structure_uids[0:split_num]
testing_uids = structure_uids[split_num:]

In [None]:
training_uids

In [None]:
testing_uids

In [None]:
def get_image_paths_for_uids(uids, data_dir='data'):
    """Return the shuffled paths of every ``*_image.png`` belonging to ``uids``.

    An image belongs to a UID when its parent directory is named after that
    UID, i.e. ``<data_dir>/<uid>/<name>_image.png``.

    Parameters
    ----------
    uids : iterable of str
        Directory names (structure UIDs) whose images should be returned.
    data_dir : str or pathlib.Path, optional
        Root directory searched recursively. Defaults to ``'data'``.

    Returns
    -------
    list of str
        Matching image paths, shuffled with ``np.random.shuffle``.
    """
    wanted = set(uids)
    # Keep images whose parent directory IS one of the requested UIDs.
    # Sort before shuffling so the result depends only on the NumPy RNG state,
    # not on the filesystem's glob order.
    image_paths = sorted(
        str(path) for path in pathlib.Path(data_dir).glob('**/*_image.png')
        if path.parent.name in wanted
    )
    np.random.shuffle(image_paths)

    return image_paths


def mask_paths_from_image_paths(image_paths):
    """Map each ``..._image.png`` path to its sibling ``..._mask.png`` path.

    Only the trailing ``_image.png`` suffix is replaced, so underscores
    elsewhere in the path (directory names, file stems) are preserved.

    Raises
    ------
    ValueError
        If a path does not end with ``_image.png``.
    """
    image_suffix = '_image.png'
    mask_suffix = '_mask.png'

    mask_paths = []
    for image_path in image_paths:
        if not image_path.endswith(image_suffix):
            raise ValueError(
                f"expected a path ending in {image_suffix!r}, got {image_path!r}"
            )
        mask_paths.append(image_path[:-len(image_suffix)] + mask_suffix)

    return mask_paths

In [None]:
# Resolve (shuffled) image paths per split first, then derive the matching
# mask path for every image.
training_image_paths = get_image_paths_for_uids(training_uids)
testing_image_paths = get_image_paths_for_uids(testing_uids)

training_mask_paths = mask_paths_from_image_paths(training_image_paths)
testing_mask_paths = mask_paths_from_image_paths(testing_image_paths)

In [None]:
# training_image_paths

In [None]:

# mask_paths

In [None]:
# mask_weights = np.array([
#     0.9864694074789978, 0.9251728496022601, 0.0883577429187421
# ])[None, None, :]

In [None]:
def _normalise_mask(png_mask):
    normalised_mask = np.round(png_mask / 255).astype(float)
    
    return normalised_mask

In [None]:
# def _remove_mask_weights(weighted_mask):
#     return weighted_mask / mask_weights
    
 

In [None]:
png_mask = imageio.imread(testing_mask_paths[0])
normalised_mask = _normalise_mask(png_mask)
plt.imshow(png_mask)
plt.show()
plt.imshow(normalised_mask)
plt.colorbar()

In [None]:
normalised_mask.shape

In [None]:
np.max(normalised_mask[:,:,0])

In [None]:
np.max(normalised_mask[:,:,1])

In [None]:
np.max(normalised_mask[:,:,2])

In [None]:
def _normalise_image(png_image):
    normalised_image = png_image[:,:,None].astype(float) / 255
    return normalised_image

In [None]:
# Preview the first testing image after normalisation.
input_array = _normalise_image(
    imageio.imread(testing_image_paths[0])
)
plt.imshow(input_array)

In [None]:
BATCH_SIZE = 128
SHUFFLE_BUFFER_SIZE = 200

def get_dataset(image_paths, mask_paths, batch_size=None, shuffle_buffer_size=None):
    """Load image/mask PNG pairs into an infinite, shuffled, batched dataset.

    Parameters
    ----------
    image_paths, mask_paths : sequence of str
        Paired PNG paths; element ``k`` of each list must describe the same
        slice. Images go through ``_normalise_image`` and masks through
        ``_normalise_mask``.
    batch_size : int, optional
        Defaults to the module-level ``BATCH_SIZE`` (read at call time).
    shuffle_buffer_size : int, optional
        Defaults to the module-level ``SHUFFLE_BUFFER_SIZE`` (read at call time).

    Returns
    -------
    tf.data.Dataset
        ``(image, mask)`` batches. The dataset repeats forever, so callers must
        bound iteration (``take``, ``steps_per_epoch``, ``validation_steps``).

    Raises
    ------
    ValueError
        If the two path lists differ in length (``zip`` would otherwise drop
        the extras silently) or are empty.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    if shuffle_buffer_size is None:
        shuffle_buffer_size = SHUFFLE_BUFFER_SIZE

    image_paths = list(image_paths)
    mask_paths = list(mask_paths)
    if len(image_paths) != len(mask_paths):
        raise ValueError(
            f"got {len(image_paths)} image paths but {len(mask_paths)} mask paths"
        )
    if not image_paths:
        raise ValueError("no image paths given; cannot build an empty dataset")

    # Stack once into contiguous arrays rather than handing TF a list of arrays.
    input_arrays = np.stack(
        [_normalise_image(imageio.imread(path)) for path in image_paths]
    )
    output_arrays = np.stack(
        [_normalise_mask(imageio.imread(path)) for path in mask_paths]
    )

    dataset = tf.data.Dataset.from_tensor_slices((input_arrays, output_arrays))
    dataset = dataset.repeat().shuffle(shuffle_buffer_size).batch(batch_size)

    return dataset

In [None]:
training_dataset = get_dataset(training_image_paths, training_mask_paths)
testing_dataset = get_dataset(testing_image_paths, testing_mask_paths)

In [None]:
for image, mask in training_dataset.take(1):
    sample_image_raw, sample_mask_raw = image, mask

has_brain = np.sum(sample_mask_raw[:,:,:,1], axis=(1,2))
has_eyes = np.sum(sample_mask_raw[:,:,:,0], axis=(1,2))

max_brain_eyes_combo = np.argmax(has_brain * has_eyes)

sample_image = sample_image_raw[max_brain_eyes_combo,:,:,:]
sample_mask = sample_mask_raw[max_brain_eyes_combo,:,:,:]

In [None]:
np.sum(sample_mask_raw[:,:,:,0])

In [None]:
np.sum(sample_mask_raw[:,:,:,0]==0) / np.sum(sample_mask_raw[:,:,:,0])

In [None]:
# scharr_operators 

In [None]:
# sch_mag = np.sqrt(sum([scharr(image, axis=i)**2
#                        for i in range(image.ndim)]) / image.ndim)

In [None]:
def _add_channels(kernel, output_channels, batch_size):
    kernel = np.concatenate([kernel[:,:,None],]*output_channels, axis=-1)
#     kernel = np.concatenate([kernel[None,:,:,:],]*batch_size, axis=0)
    return kernel

In [None]:
# 3x3 Scharr-style x-gradient kernel (float32); the y kernel is its transpose.
scharr_x = np.array(
    [[47, 0, -47], [162, 0, -162], [47, 0, -47]],
    dtype=np.float32,
)
scharr_y = scharr_x.T
scharr_x, scharr_y = K.constant(scharr_x), K.constant(scharr_y)


# scharr_x = _add_channels(scharr_x, output_channels, BATCH_SIZE)
# scharr_y = _add_channels(scharr_y, output_channels, BATCH_SIZE)

In [None]:
sample_mask_raw.shape

In [None]:
sample_mask_raw[0,:,:,2][None,:,:,None].shape

In [None]:
scharr_x[None,:,:,None].shape

In [None]:
# dir(K)

In [None]:
def _apply_sharr_filter(image):
    """Return the per-channel Scharr gradient magnitude of a batch of images.

    Each channel of ``image`` is convolved on its own with the module-level
    ``scharr_x`` and ``scharr_y`` kernels using ``VALID`` padding, so each
    spatial dimension shrinks by 2. The two responses are combined as
    ``sqrt(gx**2 + gy**2 + K.epsilon())`` and the channels are concatenated
    back along the last axis.

    Parameters
    ----------
    image : 4-D float32 tensor
        Indexed as ``(batch, rows, cols, channels)``, the default
        channels-last layout of ``tf.nn.convolution``.

    Notes
    -----
    The epsilon keeps the square root differentiable in flat regions. Without
    it, the gradient of ``sqrt`` at 0 is infinite and becomes NaN during
    back-propagation whenever this is used inside a loss.
    """
    kernel_x = scharr_x[:, :, None, None]
    kernel_y = scharr_y[:, :, None, None]

    magnitudes = []
    for channel in range(image.shape[-1]):
        single_channel = image[:, :, :, channel:channel + 1]
        grad_x = tf.nn.convolution(single_channel, kernel_x, padding="VALID")
        grad_y = tf.nn.convolution(single_channel, kernel_y, padding="VALID")
        magnitudes.append(K.sqrt(grad_x ** 2 + grad_y ** 2 + K.epsilon()))

    return K.concatenate(magnitudes, axis=-1)

In [None]:
image = K.constant(tf.cast(sample_mask_raw, tf.float32))

In [None]:
filtered = _apply_sharr_filter(image)

In [None]:
# K.conv1d?

In [None]:
# def _apply_kernel(image, kernel):
#     return K.conv2d(image[0:1,:,:,2:], kernel[:,:,None], padding="same", data_format='channels_last', dilation_rate=1, strides=1)

In [None]:
# x_dir = _apply_kernel(sample_mask_raw, scharr_x)
# y_dir = _apply_kernel(sample_mask_raw, scharr_y)

# magnitude = K.sqrt(x_dir**2 + y_dir**2)

In [None]:
plt.imshow(filtered[0,:,:,2])

In [None]:
edge_reference = skimage.filters.scharr(sample_mask_raw[0,:,:,2])
plt.imshow(edge_reference)

In [None]:
def skimage_scharr_loss(reference, evaluation, edge_filter=None):
    """Normalised L1 distance between the edge maps of two images.

        score = sum(|E(evaluation) - E(reference)|) / sum(E(evaluation) + E(reference))

    where ``E`` is ``edge_filter`` (``skimage.filters.scharr`` by default).
    For non-negative edge maps (such as edge magnitudes) the score lies in
    [0, 1]: 0 for identical edge maps, 1 when they do not overlap at all.

    Parameters
    ----------
    reference, evaluation : array_like
        Images to compare; they must yield edge maps of the same shape.
    edge_filter : callable, optional
        Maps an image to an edge-map array. Defaults to
        ``skimage.filters.scharr``.

    Returns
    -------
    float
        The score above, or ``0.0`` when both edge maps sum to zero (e.g. two
        blank masks), where the ratio would otherwise be 0/0 = nan.
    """
    if edge_filter is None:
        edge_filter = skimage.filters.scharr

    edge_reference = edge_filter(reference)
    edge_evaluation = edge_filter(evaluation)

    total_edge = np.sum(edge_evaluation + edge_reference)
    if total_edge == 0:
        return 0.0

    score = np.sum(np.abs(edge_evaluation - edge_reference)) / total_edge

    return score

In [None]:
custom_weights = [0.98, 0.92, 0.08]

In [None]:
def scharr_loss(reference, evaluation):
    edge_reference = _apply_sharr_filter(reference)
    edge_evaluation = _apply_sharr_filter(evaluation)

    score = 0
    for i in range(edge_evaluation.shape[-1]):
        score += custom_weights[i] * K.sum(K.abs(edge_evaluation[:,:,:,i] - edge_reference[:,:,:,i]))
    
    return score

In [None]:
def jaccard_distance_loss(y_true, y_pred, smooth=100):
    """
    Smoothed Jaccard distance, reduced over the last axis.

    Jaccard = (|X & Y|) / (|X| + |Y| - |X & Y|)
            = sum(|A*B|) / (sum(|A|) + sum(|B|) - sum(|A*B|))

    The Jaccard distance loss is useful for unbalanced datasets. It has been
    shifted so it converges on 0 and is smoothed to avoid exploding or
    disappearing gradients.

    Ref: https://en.wikipedia.org/wiki/Jaccard_index

    @url: https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96
    @author: wassname
    """
    overlap = K.sum(K.abs(y_true * y_pred), axis=-1)
    combined = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
    smoothed_index = (overlap + smooth) / (combined - overlap + smooth)
    distance = 1 - smoothed_index
    return distance * smooth

In [None]:
# Class-balancing weights for ``weighted_cross_entropy``, estimated from the
# sample batch. For every mask channel, pixels equal to 1 are weighted by the
# fraction of 0-pixels and vice versa, so the rarer value gets the larger weight.
sample_mask_array = np.asarray(sample_mask_raw)

cross_entropy_weights = []
for i in range(sample_mask_array.shape[-1]):
    channel = sample_mask_array[..., i]
    num_of_ones = np.sum(channel)
    num_of_zeros = np.sum(channel == 0)
    total = num_of_ones + num_of_zeros

    # Note the grouping: 1 - (count / total). Writing (1 - count) / total
    # instead yields negative weights as soon as count > 1.
    one_weight = 1 - num_of_ones / total
    zero_weight = 1 - num_of_zeros / total

    cross_entropy_weights.append([one_weight, zero_weight])

In [None]:
def weighted_cross_entropy(y_true, y_pred):
    """Sum over channels of a class-weighted mean binary cross-entropy.

    For channel ``c`` the per-pixel binary cross-entropy is scaled by
    ``cross_entropy_weights[c][0]`` where ``y_true`` is 1 and by
    ``cross_entropy_weights[c][1]`` where it is 0 (linearly blended for values
    in between), then averaged; the channel means are summed.
    """
    def _channel_term(channel):
        weight_for_one, weight_for_zero = cross_entropy_weights[channel]
        target = y_true[:, :, :, channel]
        per_pixel = K.binary_crossentropy(target, y_pred[:, :, :, channel])
        pixel_weights = target * weight_for_one + (1 - target) * weight_for_zero
        return K.mean(pixel_weights * per_pixel)

    return sum(_channel_term(channel) for channel in range(y_pred.shape[-1]))

In [None]:
scharr_loss(reference=image, evaluation=image)  # identical inputs: expect 0

In [None]:
# sample_mask

In [None]:
# total_class_weight_normalisation = 1/4 * (
#     number_of_not_background + number_of_not_patient + number_of_not_brain + number_of_not_eyes)

# class_weights = [
#     number_of_not_background / total_class_weight_normalisation,
#     number_of_not_patient / total_class_weight_normalisation,
#     number_of_not_brain / total_class_weight_normalisation,
#     number_of_not_eyes / total_class_weight_normalisation
# ]

# class_weights

In [None]:
import tensorflow.keras.backend as K



In [None]:
# ``weighted_cross_entropy`` scores every mask channel independently with a
# binary cross-entropy, so report per-pixel binary accuracy. The plain string
# 'accuracy' makes Keras choose categorical accuracy (argmax across channels)
# for a multi-channel output. Named 'accuracy' so history keys stay the same.
model.compile(
    optimizer='adam',
    loss=weighted_cross_entropy,
    metrics=[tf.keras.metrics.BinaryAccuracy(name='accuracy')],
)

In [None]:
def display(display_list):
    """Plot the given images side by side, each with its own colorbar.

    Panels are titled, in order, 'Input Image', 'True Mask' and
    'Predicted Mask'; only three titles exist, so pass at most three images.
    """
    titles = ['Input Image', 'True Mask', 'Predicted Mask']
    panel_count = len(display_list)

    plt.figure(figsize=(18, 5))

    for index, item in enumerate(display_list):
        plt.subplot(1, panel_count, index + 1)
        plt.title(titles[index])
        plt.imshow(item)
        plt.colorbar()
        plt.axis('off')

    plt.show()


display([sample_image, sample_mask])

In [None]:
def show_predictions(dataset=None, num=1):
    """Plot input, true mask and the model's prediction side by side.

    Parameters
    ----------
    dataset : dataset of (image, mask) batches, optional
        Presumably a batched ``tf.data.Dataset``. When given, ``num`` batches
        are taken and the first example of each is plotted. When ``None`` (the
        default), the pre-selected ``sample_image`` / ``sample_mask`` pair is
        plotted instead.
    num : int, optional
        Number of batches to plot when ``dataset`` is given.
    """
    if dataset is None:
        prediction = model.predict(sample_image[tf.newaxis, ...])[0]
        display([sample_image, sample_mask, prediction])
        return

    for image, mask in dataset.take(num):
        pred_mask = model.predict(image)
        display([image[0], mask[0], pred_mask[0]])


show_predictions()

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that plots the sample prediction at the end of each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Plot first, then announce the (1-based) epoch that just finished.
        show_predictions()
        message = '\nSample Prediction after epoch {}\n'.format(epoch + 1)
        print(message)

In [None]:
# testing_dataset repeats forever, so Keras needs an explicit number of
# validation batches per epoch; use roughly one pass over the testing images.
validation_steps = max(1, int(np.ceil(len(testing_image_paths) / BATCH_SIZE)))

model_history = model.fit(
    training_dataset, epochs=1000,
    steps_per_epoch=10,
    validation_data=testing_dataset,
    validation_steps=validation_steps,
    callbacks=[DisplayCallback()],
)