In [None]:
import zipfile
from urllib import request
import pathlib
import collections
import warnings
import random
import copy

import numpy as np
import matplotlib.pyplot as plt
import imageio

import IPython
import ipywidgets

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow_examples.models.pix2pix import pix2pix

In [None]:
# Root directory that is searched recursively for `*_image.png` slices (and their `*_mask.png` partners).
data_path = pathlib.Path('data')

In [None]:
# url = 'https://github.com/pymedphys/data/releases/download/mini-lung/mini-lung-medical-decathlon.zip'
# filename = url.split('/')[-1]

# request.urlretrieve(url, filename)

# with zipfile.ZipFile(filename, 'r') as zip_ref:
#     zip_ref.extractall(data_path)

In [None]:
# Every image slice found anywhere below `data_path`, in sorted path order.
image_paths = sorted(data_path.glob('**/*_image.png'))

# The mask of a slice lives next to its image; only the file-name suffix differs.
mask_paths = [
    image_path.with_name(image_path.name.replace('_image.png', '_mask.png'))
    for image_path in image_paths
]

In [None]:
# Decoded (image, mask) pairs grouped per patient. The patient label is the
# name of the directory that contains the slice.
image_mask_pairs = collections.defaultdict(list)

for image_path, mask_path in zip(image_paths, mask_paths):
    patient_label = image_path.parent.name

    # Image first, then mask -- same read order as the paths were paired in.
    image, mask = (
        imageio.imread(slice_path) for slice_path in (image_path, mask_path)
    )

    image_mask_pairs[patient_label].append((image, mask))

In [None]:
def get_contours_from_mask(mask, contour_level=0):
    """Trace the contour lines of a 2D mask at a single level.

    Parameters
    ----------
    mask : 2D array-like
        Values to contour, indexed as ``mask[row, column]``. It does not need
        to be square.
    contour_level : float, optional
        Level at which the contour lines are traced. Defaults to 0.

    Returns
    -------
    list of numpy.ndarray
        One ``(n_points, 2)`` array of ``(x, y)`` vertices per contour line,
        where ``x`` runs along the columns of ``mask`` and ``y`` along its
        rows. The list is empty when every value of ``mask`` is below
        ``contour_level``.
    """
    if np.max(mask) < contour_level:
        return []

    # Axes.contour expects len(x) == number of columns and len(y) == number of rows.
    n_rows, n_cols = mask.shape[0], mask.shape[1]

    with warnings.catch_warnings():
        # Matplotlib emits a UserWarning when no contour is found at the level.
        warnings.simplefilter("ignore", UserWarning)

        # A scratch figure is only needed to reach Matplotlib's contouring code.
        fig, ax = plt.subplots()
        try:
            cs = ax.contour(range(n_cols), range(n_rows), mask, [contour_level])

            # `allsegs[0]` holds the line segments of the first (only) level.
            # It is used instead of `cs.collections`, which is deprecated since
            # Matplotlib 3.8 and no longer splits a level into separate lines.
            contours = [
                np.asarray(segment) for segment in cs.allsegs[0] if len(segment)
            ]
        finally:
            # Always release the scratch figure, even if contouring failed.
            plt.close(fig)

    return contours

In [None]:
# Per patient, one flag per slice (in slice order): True when the largest
# value in the slice's mask is at least 128.
has_tumour_map = collections.defaultdict(list)
for patient_label, pairs in image_mask_pairs.items():
    has_tumour_map[patient_label].extend(
        np.max(mask) >= 128 for _, mask in pairs
    )

In [None]:
# patient label -> tumour flag (True / False) -> indices of the slices carrying that flag.
tumour_to_slice_map = collections.defaultdict(
    lambda: collections.defaultdict(list)
)

for patient_label, tumour_slices in has_tumour_map.items():
    slices_by_class = tumour_to_slice_map[patient_label]
    for slice_index, has_tumour in enumerate(tumour_slices):
        slices_by_class[has_tumour].append(slice_index)

In [None]:
# Deterministic split by sorted patient label: the first 50 patients train,
# the next 10 are held out as `test`, and whatever remains is `validation`.
patient_labels = sorted(image_mask_pairs)

training, test, validation = (
    patient_labels[:50],
    patient_labels[50:60],
    patient_labels[60:],
)

In [None]:
def random_select_from_each_patient(patient_labels, tumour_class_probability):
    """Draw one random slice from each patient, in shuffled patient order.

    For every patient a slice class is drawn first: "has tumour" with
    probability ``tumour_class_probability``, otherwise "no tumour". A slice
    of that class is then chosen uniformly at random. If the patient has no
    slice of the drawn class, the other class is used instead of failing; a
    patient with no slices at all is skipped.

    Parameters
    ----------
    patient_labels : iterable of str
        Labels of the patients to sample from. The argument is not modified.
    tumour_class_probability : float
        Probability of drawing the "has tumour" class for a patient.

    Returns
    -------
    images : list
        The selected image slices.
    masks : list
        The masks belonging to ``images``, in the same order.
    """
    # Work on a private list so the caller's sequence keeps its order
    # (and so tuples or other iterables are accepted as well).
    patient_labels_to_use = list(patient_labels)
    random.shuffle(patient_labels_to_use)

    images = []
    masks = []

    for patient_label in patient_labels_to_use:
        find_tumour = random.uniform(0, 1) < tumour_class_probability

        # `.get` avoids inserting empty entries into the defaultdicts.
        slices_by_class = tumour_to_slice_map.get(patient_label, {})
        candidate_slices = slices_by_class.get(find_tumour)
        if not candidate_slices:
            # No slice of the drawn class for this patient: fall back to the other class.
            find_tumour = not find_tumour
            candidate_slices = slices_by_class.get(find_tumour)
        if not candidate_slices:
            continue

        slice_to_use = random.choice(candidate_slices)
        image, mask = image_mask_pairs[patient_label][slice_to_use]

        # Sanity check that the slice index maps agree with the stored masks.
        assert bool(np.max(mask) >= 128) == find_tumour, (
            "slice class does not match its mask"
        )

        images.append(image)
        masks.append(mask)

    return images, masks

In [None]:
def create_pipeline_dataset(patient_labels, batch_size, grid_size=128, tumour_class_probability=0.5):
    """Build an endless, batched ``tf.data`` pipeline of (image, mask) pairs.

    Slices are drawn with ``random_select_from_each_patient`` over and over
    again; every image and mask gets a trailing channel axis and is rescaled
    with ``x / 255 * 2 - 1`` (so 0 maps to -1 and 255 maps to 1).

    Parameters
    ----------
    patient_labels : list
        Patients to sample slices from.
    batch_size : int
        Number of (image, mask) pairs per batch.
    grid_size : int, optional
        Height and width declared for every element; elements are declared
        as ``(grid_size, grid_size, 1)``.
    tumour_class_probability : float, optional
        Forwarded to ``random_select_from_each_patient``.

    Returns
    -------
    tf.data.Dataset
        Batched and prefetched dataset yielding ``(images, masks)`` as float32.
    """
    def to_scaled_tensor(pixels):
        # Append the channel axis, then rescale exactly as x / 255 * 2 - 1.
        return tf.convert_to_tensor(pixels[:, :, None], dtype=tf.float32) / 255 * 2 - 1

    def image_mask_generator():
        # Never terminates: each pass re-draws one slice per patient.
        while True:
            images, masks = random_select_from_each_patient(
                patient_labels, tumour_class_probability)

            for image, mask in zip(images, masks):
                yield to_scaled_tensor(image), to_scaled_tensor(mask)

    element_types = (tf.float32, tf.float32)
    element_shapes = (
        tf.TensorShape([grid_size, grid_size, 1]),
        tf.TensorShape([grid_size, grid_size, 1]),
    )

    dataset = tf.data.Dataset.from_generator(
        image_mask_generator, element_types, element_shapes
    )

    return dataset.batch(batch_size).prefetch(
        buffer_size=tf.data.experimental.AUTOTUNE
    )


# A training batch is sized as `num_images_per_patient` slices per training patient.
num_images_per_patient = 5
batch_size = num_images_per_patient * len(training)

training_dataset = create_pipeline_dataset(training, batch_size=batch_size)

# One validation batch holds one slice per validation patient.
validation_dataset = create_pipeline_dataset(validation, batch_size=len(validation))

In [None]:
def display(display_list):
    """Show the given images side by side in a single row.

    The panels are titled, in order, 'Input Image', 'True Mask' and
    'Predicted Mask', so at most three images are supported.

    Parameters
    ----------
    display_list : sequence
        Arrays accepted by ``tf.keras.preprocessing.image.array_to_img``.
    """
    panel_titles = ['Input Image', 'True Mask', 'Predicted Mask']
    n_panels = len(display_list)

    plt.figure(figsize=(15, 15))

    for position, item in enumerate(display_list, start=1):
        plt.subplot(1, n_panels, position)
        plt.title(panel_titles[position - 1])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(item))
        plt.axis('off')

    plt.show()

In [None]:
# Single-slice batches from the validation patients, with the tumour class
# requested for every draw (probability 1).
tumour_from_validation_dataset = create_pipeline_dataset(
    validation,
    batch_size=1,
    tumour_class_probability=1,
)

def show_a_prediction():
    """Plot one slice from ``tumour_from_validation_dataset`` with contours.

    A single (image, mask) batch is taken from the dataset and drawn as up to
    three panels in one row:

    1. the image, overlaid with the predicted contour in red (when a
       prediction is available) and the gold-standard contour in blue;
    2. the gold-standard mask;
    3. the predicted mask (only when a prediction is available).

    A prediction is made only when a global named ``model`` exists at call
    time, so the function can also be used before the model has been built.
    Contours are traced at level 0.
    """
    # Resolved on every call: the notebook creates `model` in a later cell.
    current_model = globals().get('model')

    for image, mask in tumour_from_validation_dataset.take(1):
        predicted_mask = None
        if current_model is not None:
            predicted_mask = current_model.predict(image[0:1, :, :, 0:1])

        # Explicit axes are used throughout: get_contours_from_mask opens and
        # closes its own figure, so pyplot's "current figure" is not reliable here.
        fig = plt.figure(figsize=(18, 10))

        image_ax = fig.add_subplot(1, 3, 1)
        image_ax.imshow(image[0, :, :, 0], vmin=-1, vmax=1)

        if predicted_mask is not None:
            predicted_contours = get_contours_from_mask(
                predicted_mask[0, :, :, 0], contour_level=0)
            for contour in predicted_contours:
                image_ax.plot(*contour.T, 'r', lw=2, alpha=0.5)

        contours = get_contours_from_mask(mask[0, :, :, 0], contour_level=0)
        for contour in contours:
            image_ax.plot(*contour.T, 'b', lw=2, alpha=0.5)

        image_ax.set_title('Image')

        mask_ax = fig.add_subplot(1, 3, 2)
        mask_ax.imshow(mask[0, :, :, 0], vmin=-1, vmax=1)
        mask_ax.set_title('Gold standard mask')

        if predicted_mask is not None:
            prediction_ax = fig.add_subplot(1, 3, 3)
            prediction_ax.imshow(predicted_mask[0, :, :, 0], vmin=-1, vmax=1)
            prediction_ax.set_title('Predicted mask')

        plt.show()
        
    
# Preview one validation slice. The prediction overlay and panel are only
# drawn if a global `model` already exists in the kernel.
show_a_prediction()

In [None]:
# base_model = tf.keras.applications.ResNet101V2(input_shape=[128, 128, 3], include_top=False)

In [None]:
# MobileNetV2 without its classification head serves as the encoder.
base_model = tf.keras.applications.MobileNetV2(
    include_top=False,
    input_shape=[128, 128, 3],
)

# Activations of these layers are exposed as the encoder outputs
layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
layers = [
    base_model.get_layer(layer_name).output for layer_name in layer_names
]

# Feature extraction model: one input, one output per selected layer
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)

# Freeze the encoder so that its weights are not updated during training
down_stack.trainable = False

In [None]:
# Decoder: four pix2pix upsample blocks (size 3) with decreasing filter counts,
# taking the features 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64.
up_stack = [
    pix2pix.upsample(filters, 3) for filters in (512, 256, 128, 64)
]

In [None]:
def unet_model():
    """Assemble the U-Net style segmentation model.

    The model takes a ``128 x 128 x 1`` input, expands it to three channels
    with a 1x1 convolution, runs it through ``down_stack`` and then through
    the blocks of ``up_stack``, concatenating the matching ``down_stack``
    output after each block. A final stride-2 transposed convolution
    produces a single output channel.

    Returns
    -------
    tf.keras.Model
        The uncompiled model.
    """
    inputs = tf.keras.layers.Input(shape=[128, 128, 1])

    # Spread the single input channel out to the 3 channels the backbone expects
    x = tf.keras.layers.Conv2D(3, 1, padding="same")(inputs)

    # Encoder: the last output is the bottleneck, the others become skip connections
    *skip_features, x = down_stack(x)

    # Decoder: upsample, then join with the skip connection, deepest skip first
    for up, skip in zip(up_stack, skip_features[::-1]):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    # Last layer of the model: 64x64 -> 128x128, one output channel
    x = tf.keras.layers.Conv2DTranspose(
        1, 3, strides=2, padding='same'
    )(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

## Train the model
Now, all that is left to do is to compile and train the model. The network here outputs a single channel, and both the images and the masks are rescaled to the range [-1, 1], so a multi-class loss such as `losses.SparseCategoricalCrossentropy(from_logits=True)` over per-pixel class labels does not apply. Instead, the loss being used is `weighted_dsc_loss`, defined below as one minus a weighted Dice similarity coefficient computed between the true mask and the network output.
When a prediction is displayed, the predicted tumour outline is the contour of the output channel at level 0, i.e. the midpoint of the [-1, 1] mask range.

In [None]:
def weighted_dsc(y_true, y_pred):
    """Weighted Dice-style similarity between a true and a predicted mask.

    Sums are taken over axes ``(0, 1, 2)``, leaving one value per index of the
    last axis. Each such value is weighted by
    ``total_pixels / (sum(y_true) + smooth)`` with ``total_pixels`` fixed at
    ``128 * 128 * 1``, and the result is::

        (2 * sum(w * intersection) + smooth) / (sum(w * union) + smooth)

    where ``intersection = sum(y_true * y_pred)`` and
    ``union = sum(y_true + y_pred)``.

    NOTE(review): the data pipeline in this notebook scales masks to
    [-1, 1]; Dice-style overlaps are usually defined on values in [0, 1] --
    confirm that this is intended.
    """
    smooth = 1
    total_pixels = 128 * 128 * 1
    reduce_axes = (0, 1, 2)

    true_sum = K.sum(y_true, axis=reduce_axes)
    weights = total_pixels / (true_sum + smooth)

    intersection = K.sum(y_true * y_pred, axis=reduce_axes)
    union = K.sum(y_true + y_pred, axis=reduce_axes)

    numerator = 2 * K.sum(intersection * weights) + smooth
    denominator = K.sum(union * weights) + smooth

    return numerator / denominator


def weighted_dsc_loss(y_true, y_pred):
    """Loss form of ``weighted_dsc``: one minus the weighted similarity."""
    similarity = weighted_dsc(y_true, y_pred)
    return 1 - similarity

In [None]:
# Reset the Keras backend session before building the model
K.clear_session()

# Build the U-Net and train it against the weighted Dice loss with Adam
model = unet_model()
model.compile(
    optimizer='adam',
    loss=weighted_dsc_loss,
    metrics=['accuracy'],
)

Have a quick look at the resulting model architecture:

In [None]:
# Render the layer graph of `model`, annotated with tensor shapes.
tf.keras.utils.plot_model(model, show_shapes=True)

In [None]:
# Sample prediction before `model.fit` has been called (when the cells are run in order).
show_a_prediction()

Let's observe how the model improves while it is training. To accomplish this task, a callback function is defined below. 

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that refreshes the sample prediction after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Replace the cell output with a fresh prediction plot and a caption.

        The caption reports the epoch as ``epoch + 1``.
        """
        # wait=True defers clearing until the new output is ready
        IPython.display.clear_output(wait=True)
        show_a_prediction()

        caption = '\nSample Prediction after epoch {}\n'.format(epoch + 1)
        print(caption)

In [None]:
# The datasets are fed by endless generators, so the number of batches per
# epoch (training and validation) has to be stated explicitly.
EPOCHS = 3
STEPS_PER_EPOCH = 1
VALIDATION_STEPS = 1

model_history = model.fit(
    training_dataset,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=validation_dataset,
    validation_steps=VALIDATION_STEPS,
    callbacks=[DisplayCallback()],
)

In [None]:
# Per-epoch training and validation loss recorded by `model.fit`
loss, val_loss = (
    model_history.history[key] for key in ('loss', 'val_loss')
)

epochs = range(EPOCHS)

plt.figure()
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'bo', label='Validation loss')

# Fixed y-range of [0, 1]
plt.ylim([0, 1])
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.title('Training and Validation Loss')

plt.legend()
plt.show()