In [None]:
import zipfile
from urllib import request
import pathlib
import collections
import warnings
import random
import copy

import numpy as np
import matplotlib.pyplot as plt
import imageio

import IPython

import tensorflow as tf

import ipywidgets

In [None]:
# url = 'https://github.com/pymedphys/data/releases/download/mini-lung/mini-lung-medical-decathlon.zip'
# filename = url.split('/')[-1]

In [None]:
# request.urlretrieve(url, filename)

In [None]:
# Root folder of the extracted mini-lung dataset (relative to the notebook).
data_path = pathlib.Path() / 'data'

In [None]:
# with zipfile.ZipFile(filename, 'r') as zip_ref:
#     zip_ref.extractall(data_path)

In [None]:
# Assumes each slice is stored as a `<stem>_image.png` / `<stem>_mask.png`
# pair in the same folder: collect the images, then derive each mask path.
# NOTE(review): `sorted` orders paths lexicographically, and that order is
# also the slice order within each patient later on — confirm the filenames
# are zero-padded so that e.g. slice 10 does not sort before slice 2.
image_paths = sorted(data_path.glob('**/*_image.png'))

mask_paths = [
    image_path.with_name(image_path.name.replace('_image.png', '_mask.png'))
    for image_path in image_paths
]

In [None]:
# Load every slice into memory, grouped by patient label (the name of the
# folder containing the slice). Each entry is an (image, mask) tuple as
# returned by imageio.imread, in the order of `image_paths`.
image_mask_pairs = collections.defaultdict(list)

for image_path, mask_path in zip(image_paths, mask_paths):
    patient_label = image_path.parent.name
    image, mask = imageio.imread(image_path), imageio.imread(mask_path)
    image_mask_pairs[patient_label].append((image, mask))

In [None]:
def get_contours_from_mask(mask, contour_level=128):
    """Extract iso-contour polylines from a 2-D mask at ``contour_level``.

    Matplotlib's contouring is used as the tracing engine; the throw-away
    figure it needs is always closed, even if contouring fails.

    Args:
        mask: 2-D array-like (numpy array or eager tensor) indexed
            ``[row, col]``.
        contour_level: value at which to trace the contour. The default of
            128 is the midpoint of 0-255 mask data.

    Returns:
        A list of ``(N, 2)`` vertex arrays whose columns are ``(x, y)`` =
        ``(col, row)``, ready for ``plt.plot(*contour.T)`` over ``imshow``.
        Empty when the mask is empty or never reaches ``contour_level``.
    """
    mask = np.asarray(mask)
    if mask.size == 0 or np.max(mask) < contour_level:
        return []

    n_rows, n_cols = mask.shape
    fig, ax = plt.subplots()
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            # x runs along columns (axis 1) and y along rows (axis 0); using
            # shape[0] for both breaks non-square masks.
            cs = ax.contour(np.arange(n_cols), np.arange(n_rows), mask, [contour_level])
        # allsegs[0] lists the vertex arrays for our single level. Unlike
        # cs.collections (deprecated in Matplotlib 3.8) it works across versions.
        contours = [np.asarray(segment) for segment in cs.allsegs[0]]
    finally:
        plt.close(fig)

    return contours

In [None]:
def display(patient_label, chosen_slice):
    """Plot one slice of a patient with its mask outline drawn in red.

    NOTE(review): this name shadows IPython's ``display`` function in the
    notebook namespace; consider renaming it (the caller is ``view_patient``).

    Args:
        patient_label: key into ``image_mask_pairs``.
        chosen_slice: index of the slice within that patient's list.
    """
    image, mask = image_mask_pairs[patient_label][chosen_slice]

    plt.figure(figsize=(10,10))
    # Fixed intensity window: pixel values above 100 render saturated.
    plt.imshow(image, vmin=0, vmax=100)

    for outline in get_contours_from_mask(mask):
        plt.plot(*outline.T, 'r', lw=3)

In [None]:
def view_patient(patient_label):
    """Show an interactive slice browser for one patient.

    Builds an ``IntSlider`` over the patient's slice indices; moving it calls
    ``display(patient_label, chosen_slice)``.

    Args:
        patient_label: key into ``image_mask_pairs``.
    """
    def view_slice(chosen_slice):
        display(patient_label, chosen_slice)

    number_of_slices = len(image_mask_pairs[patient_label])

    # IntSlider's ``max`` is inclusive, so the last valid index is
    # ``number_of_slices - 1``. Using ``number_of_slices`` let the slider reach
    # an out-of-range index and raise IndexError.
    slider = ipywidgets.IntSlider(
        min=0, max=max(number_of_slices - 1, 0), step=1, value=0
    )
    ipywidgets.interact(view_slice, chosen_slice=slider)

In [None]:
# Deterministic (sorted) patient order; the positional train/test/validation
# split relies on it.
patient_labels = sorted(image_mask_pairs)
# patient_labels

In [None]:
# Patient picker: `interact` turns the list of labels into a dropdown, and each
# choice rebuilds the per-patient slice browser from `view_patient`.
# The trailing `;` suppresses the echoed return value in the notebook output.
ipywidgets.interact(view_patient, patient_label=patient_labels);

In [None]:
# For each patient, one flag per slice (in slice order): does the mask reach
# the 128 threshold, i.e. does the slice contain tumour pixels?
has_tumour_map = collections.defaultdict(list)
for patient_label, pairs in image_mask_pairs.items():
    for image, mask in pairs:
        tumour_present = np.max(mask) >= 128
        has_tumour_map[patient_label].append(tumour_present)

In [None]:
# Invert has_tumour_map: for each patient, map the flag ("slice contains
# tumour?") to the list of slice indices with that flag. Unseen patients or
# flags default to an empty list.
tumour_to_slice_map = collections.defaultdict(
    lambda: collections.defaultdict(list)
)

for patient_label, tumour_slices in has_tumour_map.items():
    for slice_index, has_tumour in enumerate(tumour_slices):
        tumour_to_slice_map[patient_label][has_tumour].append(slice_index)

In [None]:
# Deterministic split over the sorted patient labels:
# first 50 -> training, next 10 -> test, the remainder -> validation.
training, test, validation = (
    patient_labels[:50],
    patient_labels[50:60],
    patient_labels[60:],
)

len(validation)

In [None]:
# Number of patients held out for testing.
len(test)

In [None]:
# Each round of random_select_from_each_patient yields one slice per training
# patient, so a training batch of this size spans `num_images_per_patient`
# rounds.
num_images_per_patient = 5
batch_size = num_images_per_patient * len(training)
batch_size

In [None]:
# NOTE(review): scratch cell. It draws from the global `random` generator,
# which the slice sampling below also uses, so running or deleting it shifts
# every later random draw. No seed is set in this notebook, so runs are not
# reproducible in any case.
random.uniform(0, 1)

In [None]:
# random.shuffle?

In [None]:
# tensor_image_mask_pairs = collections.defaultdict(lambda: [])

# for patient_label, pairs in image_mask_pairs.items():
#     for image, mask in pairs:
#         tensor_image_mask_pairs[patient_label].append((
#             tf.convert_to_tensor(image[:,:,None], dtype=tf.float32) / 255 * 2 - 1,
#             tf.convert_to_tensor(mask[:,:,None], dtype=tf.float32) / 255 * 2 - 1
#         ))

In [None]:
def random_select_from_each_patient(patient_labels, tumour_class_probability):
    """Draw one (image, mask) slice from every patient, in shuffled order.

    For each patient, a biased coin with P(tumour) = ``tumour_class_probability``
    decides whether to request a slice whose mask contains tumour (max >= 128)
    or one that does not. If the patient has no slices of the requested class,
    the other class is used instead. ``random.choice([])`` would otherwise
    raise IndexError for patients with only one kind of slice.

    Args:
        patient_labels: sequence of patient keys; it is not modified.
        tumour_class_probability: probability in [0, 1] of requesting a
            tumour-containing slice.

    Returns:
        ``(images, masks)``: two equal-length lists with one entry per patient,
        in the shuffled patient order.

    Raises:
        ValueError: if a patient has no slices at all.
    """
    patient_labels_to_use = list(patient_labels)
    random.shuffle(patient_labels_to_use)

    images = []
    masks = []

    for patient_label in patient_labels_to_use:
        find_tumour = random.uniform(0, 1) < tumour_class_probability

        candidate_slices = tumour_to_slice_map[patient_label][find_tumour]
        if not candidate_slices:
            # Requested class is absent for this patient: use the other one.
            find_tumour = not find_tumour
            candidate_slices = tumour_to_slice_map[patient_label][find_tumour]
        if not candidate_slices:
            raise ValueError(f"Patient {patient_label!r} has no slices to sample from.")

        slice_to_use = random.choice(candidate_slices)
        image, mask = image_mask_pairs[patient_label][slice_to_use]

        # Sanity check that the slice-index map agrees with the mask contents.
        if find_tumour:
            assert np.max(mask) >= 128
        else:
            assert np.max(mask) < 128

        images.append(image)
        masks.append(mask)

    return images, masks

In [None]:
def create_pipeline_dataset(patient_labels, batch_size, grid_size=128, tumour_class_probability=0.5):
    """Build an infinite, batched ``tf.data`` pipeline of (image, mask) pairs.

    Each pass of the underlying generator draws one slice per patient via
    ``random_select_from_each_patient`` and yields them one at a time. The
    generator loops forever, so consumers must bound it (e.g. ``take`` or
    ``steps_per_epoch``).

    Images and masks get a trailing channel axis and are rescaled from
    [0, 255] to [-1, 1] (this assumes 8-bit PNG data).

    Args:
        patient_labels: patients to sample from.
        batch_size: number of slices per batch.
        grid_size: expected (square) spatial size of every slice.
        tumour_class_probability: per-patient probability of requesting a
            tumour-containing slice. It defaults to 0.5, the previously
            hard-coded value.

    Returns:
        A ``tf.data.Dataset`` of ``(image, mask)`` float32 batches shaped
        ``(batch_size, grid_size, grid_size, 1)``.
    """
    def to_model_range(array):
        # Add a channel axis and map [0, 255] -> [-1, 1].
        return tf.convert_to_tensor(array[:, :, None], dtype=tf.float32) / 255 * 2 - 1

    def image_mask_generator():
        while True:
            images, masks = random_select_from_each_patient(
                patient_labels, tumour_class_probability=tumour_class_probability
            )
            for image, mask in zip(images, masks):
                yield to_model_range(image), to_model_range(mask)

    element_shape = tf.TensorShape([grid_size, grid_size, 1])
    dataset = tf.data.Dataset.from_generator(
        image_mask_generator,
        (tf.float32, tf.float32),
        (element_shape, element_shape),
    )

    return dataset.batch(batch_size)

# Training batches mix several slices per patient; each validation batch holds
# exactly one slice from every validation patient.
training_dataset = create_pipeline_dataset(training, batch_size=batch_size)
validation_dataset = create_pipeline_dataset(validation, batch_size=len(validation))

In [None]:
# Sanity check: pull a single batch and print the shapes the pipeline produces.
for image, mask in training_dataset.take(1):
    for tensor in (image, mask):
        print(tensor.shape)

In [None]:
# random_select_from_each_patient()

In [None]:
# random_select_from_each_patient()

In [None]:
def display_first_of_batch(image, mask):
    """Plot the first slice of a batch with its mask outline in red.

    Args:
        image: batch indexed ``[batch, row, col, channel]``; displayed with a
            fixed [-1, 1] intensity window.
        mask: matching mask batch. The outline is traced at level 0, which for
            masks scaled to [-1, 1] lies between background and foreground.
    """
    first_image = image[0, :, :, 0]
    first_mask = mask[0, :, :, 0]

    plt.figure(figsize=(10,10))
    plt.imshow(first_image, vmin=-1, vmax=1)

    for outline in get_contours_from_mask(first_mask, contour_level=0):
        plt.plot(*outline.T, 'r', lw=3)
        

# Visual check of a single pipeline batch.
for image, mask in training_dataset.take(1):
    display_first_of_batch(image=image, mask=mask)

In [None]:
def encode(x, convs, filters, kernel, drop=False, pool=True, norm=True):
    """One U-Net encoder stage.

    Applies ``convs`` repetitions of Conv2D -> (BatchNorm) -> ReLU at the
    current resolution and keeps that result as the skip tensor. It then
    optionally applies Dropout(0.2) and a stride-2 Conv2D -> (BatchNorm) ->
    ReLU down-sampling step. Flags are checked with ``is True``, so only the
    literal ``True`` enables them.

    Returns:
        ``(x, skip)``: the stage output and the tensor before regularisation
        and down-sampling.
    """
    layers = tf.keras.layers

    def conv_block(tensor, **extra):
        # Conv2D (optionally strided) -> optional BatchNorm -> ReLU.
        tensor = layers.Conv2D(
            filters, kernel, padding="same", kernel_initializer="he_normal", **extra
        )(tensor)
        if norm is True:
            tensor = layers.BatchNormalization()(tensor)
        return layers.Activation("relu")(tensor)

    for _ in range(convs):
        x = conv_block(x)

    skip = x

    if drop is True:
        x = layers.Dropout(0.2)(x)
    if pool is True:
        x = conv_block(x, strides=2)

    return x, skip


def decode(x, skip, convs, filters, kernel, drop=False, norm=False):
    """One U-Net decoder stage.

    The stage runs these steps in order:

    1. Up-samples ``x`` with a stride-2 Conv2DTranspose -> (BatchNorm) -> ReLU.
    2. Concatenates ``[skip, x]`` on axis 3.
    3. Applies ``convs`` repetitions of Conv2D -> (BatchNorm) -> ReLU.
    4. Optionally applies Dropout(0.2).

    Flags are checked with ``is True``, so only the literal ``True`` enables
    them.

    Returns:
        The stage output tensor.
    """
    layers = tf.keras.layers

    def maybe_norm_then_relu(tensor):
        if norm is True:
            tensor = layers.BatchNormalization()(tensor)
        return layers.Activation("relu")(tensor)

    # Learnable 2x up-sampling, then merge with the encoder's skip tensor.
    upsampled = layers.Conv2DTranspose(
        filters, kernel, strides=2, padding="same", kernel_initializer="he_normal"
    )(x)
    x = layers.concatenate([skip, maybe_norm_then_relu(upsampled)], axis=3)

    for _ in range(convs):
        conv = layers.Conv2D(
            filters, kernel, padding="same", kernel_initializer="he_normal"
        )
        x = maybe_norm_then_relu(conv(x))

    if drop is True:
        x = layers.Dropout(0.2)(x)

    return x


def create_network(grid_size=128, output_channels=1):
    """Build a 4-level U-Net style encoder/decoder model.

    Args:
        grid_size: spatial size of the square single-channel input. The four
            stride-2 stages mean it should be divisible by 16.
        output_channels: number of output channels.

    Returns:
        A ``tf.keras.Model`` mapping ``(grid_size, grid_size, 1)`` inputs to
        ``(grid_size, grid_size, output_channels)`` outputs in (-1, 1).
    """
    inputs = tf.keras.layers.Input((grid_size, grid_size, 1))

    # Trailing comments give the stage output as "size, channels" for the
    # default grid_size=128.
    encoder_args = [
        # convs, filter, kernel, drop, pool, norm
        (2, 32, 3, False, True, True),  # 64, 32
        (2, 64, 3, False, True, True),  # 32, 64
        (2, 128, 3, False, True, True),  # 16, 128
        (2, 256, 3, False, True, True),  # 8, 256
    ]

    decoder_args = [
        # convs, filter, kernel, drop, norm
        (2, 128, 3, True, True),  # 16, 128
        (2, 64, 3, True, True),  # 32, 64
        (2, 32, 3, False, True),  # 64, 32
        (2, 16, 3, False, True),  # 128, 16
    ]

    x = inputs
    skips = []

    for args in encoder_args:
        x, skip = encode(x, *args)
        skips.append(skip)

    # The deepest skip pairs with the first decoder stage.
    for skip, args in zip(reversed(skips), decoder_args):
        x = decode(x, skip, *args)

    # tanh rather than sigmoid: the training masks are scaled to [-1, 1]
    # (background -1, tumour +1) and predictions are contoured at level 0.
    # A sigmoid head can never go below 0, so it could neither fit the
    # background nor ever produce a level-0 contour.
    output_layer = tf.keras.layers.Conv2D(
        output_channels,
        1,
        activation="tanh",
        padding="same",
        kernel_initializer="he_normal",
    )

    return tf.keras.Model(inputs=inputs, outputs=output_layer(x))

In [None]:
# Reset Keras' global state first so that re-running this cell starts from a
# clean session rather than accumulating state from earlier runs.
tf.keras.backend.clear_session()

model = create_network()

# NOTE(review): 'accuracy' is unlikely to be meaningful here. The targets are
# continuous values in [-1, 1], not {0, 1} class labels; a segmentation metric
# (e.g. Dice / IoU on thresholded masks) would be more informative.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.MeanAbsoluteError(),
    metrics=['accuracy'],
)

In [None]:
# tf.keras.utils.plot_model(model, show_shapes=True, dpi=64)

In [None]:
# model.summary()

In [None]:
def show_a_prediction(number_to_show=10, dataset=None):
    """Plot ground-truth and predicted outlines for a few slices.

    Draws one batch from ``dataset`` and predicts on its first
    ``number_to_show`` slices in a single ``model.predict`` call. It then shows
    one figure per slice: the image, the ground-truth contour (black dashed)
    and the predicted contour (red), both traced at level 0.

    Args:
        number_to_show: maximum number of slices to plot.
        dataset: dataset of ``(image, mask)`` batches; defaults to the global
            ``training_dataset``.
    """
    if dataset is None:
        dataset = training_dataset

    # One batch and one predict call are enough. Pulling a whole batch per
    # displayed slice, with a separate predict each time, made every
    # epoch-end preview regenerate ~10 full training batches.
    for images, masks in dataset.take(1):
        count = min(number_to_show, int(images.shape[0]))
        if count <= 0:
            return
        predicted_masks = model.predict(images[:count], verbose=0)

        for i in range(count):
            plt.figure(figsize=(10,10))
            plt.imshow(images[i, :, :, 0], vmin=-1, vmax=1)

            for contour in get_contours_from_mask(masks[i, :, :, 0], contour_level=0):
                plt.plot(*contour.T, 'k--', lw=1)

            for contour in get_contours_from_mask(predicted_masks[i, :, :, 0], contour_level=0):
                plt.plot(*contour.T, 'r', lw=3)

            plt.show()


show_a_prediction()

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that refreshes the prediction preview after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Clear the cell output and redraw ``show_a_prediction()``.

        The epoch number and logs are not used.
        """
        # wait=True keeps the old output until new output replaces it,
        # which avoids flicker in the notebook.
        IPython.display.clear_output(wait=True)
        show_a_prediction()

In [None]:
EPOCHS = 5
STEPS_PER_EPOCH = 1
VALIDATION_STEPS = 1

# Both datasets come from infinite generators, so steps_per_epoch and
# validation_steps are what bound each epoch.
# `use_multiprocessing` and `shuffle` were dropped. Keras ignores both for
# tf.data.Dataset inputs, and Keras 3 no longer accepts `use_multiprocessing`.
model_history = model.fit(
    training_dataset,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=validation_dataset,
    validation_steps=VALIDATION_STEPS,
    callbacks=[
        DisplayCallback(),
        # tensorboard_callback,
    ],
)