# CT scan UNet demo

This notebook builds and trains a 3D UNet on a minified version of the Medical Segmentation Decathlon lung CT dataset.

In [None]:
import pathlib
import urllib.request
import shutil
import collections
import random

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.keras.backend as K
import imageio

In [None]:
zip_url = 'https://github.com/pymedphys/data/releases/download/mini-lung/mini-lung-medical-decathlon.zip'
zip_filepath = 'data.zip'

data_directory = pathlib.Path('data') / 'unet' / 'lung'

# Fetch and extract the archive only when the extraction target is missing;
# an existing directory is reused as-is without checking its contents.
if not data_directory.exists():
    downloaded_path = urllib.request.urlretrieve(zip_url, zip_filepath)[0]
    shutil.unpack_archive(downloaded_path, data_directory)

In [None]:
# Sorted so the split below is the same regardless of directory listing order.
directories = sorted(data_directory.glob('*'))

# The last `split` directories (in sorted order) are held out for validation.
split = 4
training_directories, validation_directories = directories[:-split], directories[-split:]

In [None]:
# No directory may appear in both the training and the validation split.
assert not set(validation_directories) & set(training_directories)

In [None]:
# Trims 32 pixels from both ends of each of the two indexed image axes.
crop_slice = slice(32, -32)

def _load_image(image_path):
    """Read one image PNG, crop it with ``crop_slice`` and divide by 255.

    A trailing channel axis is appended, so a 2-D (single-channel) PNG comes
    back as ``(H', W', 1)``. The /255 scaling presumably assumes 8-bit PNGs.
    """
    pixels = imageio.imread(image_path)
    cropped = pixels[crop_slice, crop_slice, None]
    return cropped / 255


def _load_mask(mask_path):
    """Read one mask PNG, crop it with ``crop_slice`` and divide by 255.

    A leading length-1 axis and a trailing channel axis are added, so a 2-D
    PNG comes back as ``(1, H', W', 1)``.
    """
    pixels = imageio.imread(mask_path)
    cropped = pixels[None, crop_slice, crop_slice, None]
    return cropped / 255

In [None]:
# Number of consecutive image slices stacked into each model input.
z_size = 8
# Offset within that stack of the slice whose mask is the training target.
mask_image_index = z_size // 2

In [None]:
def load_dataset_type(directories, shuffle=True):
    """Build paired image stacks and masks from a collection of case directories.

    Every ``*_mask.png`` in each directory is considered. Masks with no
    non-zero pixel are skipped. For the rest, the integer before the first
    ``_`` in the mask filename is taken as its slice number, and the
    ``z_size`` image files ``NNNNNN_image.png`` (6-digit zero-padded slice
    numbers) starting ``mask_image_index`` slices below it are loaded. The
    mask's own slice therefore sits at position ``mask_image_index`` of the
    stack. If loading any of those images raises ``FileNotFoundError``, that
    mask is skipped.

    Parameters
    ----------
    directories : iterable of pathlib.Path
        Directories to scan for mask files.
    shuffle : bool
        If True, shuffle sample order with NumPy's global RNG. Otherwise
        samples come out in directory order, with each directory's masks in
        sorted path order.

    Returns
    -------
    images : numpy.ndarray
        One entry per kept mask, each a stack of ``z_size`` outputs of
        ``_load_image`` (``(N, z_size, H, W, 1)`` for single-channel PNGs).
    masks : numpy.ndarray
        The matching ``_load_mask`` outputs (``(N, 1, H, W, 1)`` for
        single-channel PNGs). If nothing is kept, both arrays are empty.
    """
    image_suffix = '_image.png'
    mask_suffix = '_mask.png'

    # Path.glob yields entries in arbitrary, filesystem-dependent order. Sort
    # them so unshuffled output (e.g. the validation set) is reproducible.
    mask_paths = []
    for directory in directories:
        mask_paths.extend(sorted(directory.glob(f'*{mask_suffix}')))

    if shuffle:
        np.random.shuffle(mask_paths)

    image_arrays = []
    mask_arrays = []
    for mask_path in mask_paths:
        mask = _load_mask(mask_path)
        # Only slices with at least one positive mask pixel are used.
        if np.sum(mask) == 0:
            continue

        mask_number = int(mask_path.name.split('_')[0])
        start_image_slice = mask_number - mask_image_index
        image_paths = [
            mask_path.parent / (str(slice_num).zfill(6) + image_suffix)
            for slice_num in range(start_image_slice, start_image_slice + z_size)
        ]
        try:
            images_for_mask = [_load_image(image_path) for image_path in image_paths]
        except FileNotFoundError:
            # The stack would need slices that are not on disk (presumably
            # near the ends of the volume); drop this sample.
            continue

        mask_arrays.append(mask)
        image_arrays.append(images_for_mask)

    images = np.array(image_arrays)
    masks = np.array(mask_arrays)

    return images, masks

In [None]:
# Training samples are shuffled. Validation samples are left in
# load_dataset_type's path order so index-based look-ups refer to a fixed sample.
training_images, training_masks = load_dataset_type(training_directories, shuffle=True)
validation_images, validation_masks = load_dataset_type(
    validation_directories,
    shuffle=False,
)

In [None]:
# Inspect the stacked training inputs: (samples, z, height, width, channels).
np.shape(training_images)

In [None]:
# Inspect the training masks array shape.
np.shape(training_masks)

In [None]:
# The first validation sample is used for the per-epoch preview plots.
i = 0
sample_image = validation_images[i]
sample_mask = validation_masks[i]

In [None]:
def display(image, mask, prediction=None):
    """Plot an input stack's target slice, its true mask and a predicted mask.

    ``image`` is shown as ``image[mask_image_index, :, :, 0]``; ``mask`` and
    ``prediction`` are shown as ``[0, :, :, 0]``. When ``prediction`` is not
    given, it is computed with the global ``model``. If no such name exists
    yet, only the first two panels are drawn.
    """
    plt.figure(figsize=(18, 5))

    def _panel(position, title, data):
        # One cell of the 1x3 grid: image, colour bar, no axes.
        plt.subplot(1, 3, position)
        plt.title(title)
        plt.imshow(data)
        plt.colorbar()
        plt.axis('off')

    _panel(1, 'Input Image', image[mask_image_index, :, :, 0])
    _panel(2, 'True Mask', mask[0, :, :, 0])

    if prediction is None:
        try:
            prediction = model.predict(image[None, ...])[0, ...]
        except NameError:
            # No model defined yet: leave the third panel empty.
            return

    _panel(3, 'Predicted Mask', prediction[0, :, :, 0])

    
    
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that plots a prediction on ``sample_image`` after each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Predict with the model this callback is attached to (``self.model``)
        # instead of relying on a global named ``model``. Otherwise the preview
        # could show another model, or silently omit the prediction panel.
        prediction = self.model.predict(sample_image[None, ...])[0, ...]
        display(sample_image, sample_mask, prediction)
        plt.show()
        print(f'\nSample Prediction after epoch {epoch + 1}\n')
    
    
# Preview the chosen validation sample before the model exists.
display(image=sample_image, mask=sample_mask)

In [None]:
def _activation(x):
    """Apply batch normalisation followed by a ReLU to ``x``."""
    normalised = tf.keras.layers.BatchNormalization()(x)
    return tf.keras.layers.Activation("relu")(normalised)


def _convolution(x, number_of_filters, kernel_size=3):
    """Apply a 'same'-padded, He-normal-initialised Conv3D layer to ``x``."""
    conv_layer = tf.keras.layers.Conv3D(
        filters=number_of_filters,
        kernel_size=kernel_size,
        padding="same",
        kernel_initializer="he_normal",
    )
    return conv_layer(x)


def _conv_transpose(x, number_of_filters, kernel_size=3):
    """Apply a stride-2, 'same'-padded Conv3DTranspose (upsampling) layer to ``x``."""
    upsample_layer = tf.keras.layers.Conv3DTranspose(
        filters=number_of_filters,
        kernel_size=kernel_size,
        strides=2,
        padding="same",
        kernel_initializer="he_normal",
    )
    return upsample_layer(x)

In [None]:
def encode(
    x,
    number_of_filters,
    number_of_convolutions=2,
):
    """Run one encoder stage of the UNet.

    Applies ``number_of_convolutions`` conv -> BN -> ReLU steps and keeps the
    result as the skip connection. The result is then max-pooled (default
    MaxPool3D) and passed through another BN -> ReLU.

    Returns ``(pooled_features, skip)``.
    """
    features = x
    for _ in range(number_of_convolutions):
        features = _activation(_convolution(features, number_of_filters))

    pooled = tf.keras.layers.MaxPool3D()(features)
    return _activation(pooled), features

In [None]:
def decode(
    x,
    skip,
    number_of_filters,
    number_of_convolutions=2,
):
    """Run one decoder stage of the UNet.

    Upsamples ``x`` with a stride-2 transposed convolution (+ BN -> ReLU) and
    concatenates it after ``skip`` along the channel axis. It then applies
    ``number_of_convolutions`` conv -> BN -> ReLU steps.
    """
    upsampled = _activation(_conv_transpose(x, number_of_filters))
    merged = tf.keras.layers.concatenate([skip, upsampled], axis=-1)

    for _ in range(number_of_convolutions):
        merged = _activation(_convolution(merged, number_of_filters))

    return merged

In [None]:
mask_dims = training_masks.shape
image_dims = training_images.shape

# Masks must be square; the image slices must share that grid.
assert mask_dims[2] == mask_dims[3]
grid_size = int(mask_dims[2])
assert image_dims[2] == grid_size and image_dims[3] == grid_size

# Each input must hold exactly z_size stacked slices.
assert image_dims[1] == z_size
output_channels = int(mask_dims[-1])

In [None]:
# Input: z_size single-channel slices on a grid_size x grid_size grid.
inputs = tf.keras.layers.Input((z_size, grid_size, grid_size, 1))
x = inputs
skips = []

# Encoder: each stage max-pools (default MaxPool3D) and records its
# pre-pooling features for the matching decoder stage. NOTE(review): the skip
# shapes only line up if z_size and grid_size are divisible by 8; confirm.
for number_of_filters in [32, 64, 128]:
    x, skip = encode(x, number_of_filters)
    skips.append(skip)

# Decoder stages consume the skips deepest-first.
skips.reverse()

for number_of_filters, skip in zip([256, 128, 64], skips):
    x = decode(x, skip, number_of_filters)

# Collapse the z axis to length 1 so the output matches the single-slice mask.
x = tf.keras.layers.MaxPool3D(
    pool_size=(z_size, 1, 1)
)(x)

# 1x1x1 sigmoid head. The channel count comes from the masks rather than
# being hard-coded, so it cannot drift from the training targets.
x = tf.keras.layers.Conv3D(
    filters=output_channels,
    kernel_size=1,
    activation="sigmoid",
    kernel_initializer="he_normal",
)(x)

model = tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
# Print the layer-by-layer summary of the model built above.
model.summary()

In [None]:
# (total pixels / positive-mask sum) - 1. For binary 0/1 masks this is the
# background-to-foreground pixel ratio; weighted_bce uses it to up-weight
# positive pixels.
balancing = training_masks.size / training_masks.sum() - 1
balancing

In [None]:
def weighted_bce(y_true, y_pred):
    """Mean binary cross-entropy with positive pixels weighted by ``balancing``.

    Each element's BCE is multiplied by ``1 + balancing * y_true``: weight
    ``balancing + 1`` where ``y_true`` is 1 and weight 1 where it is 0.
    """
    per_pixel_bce = K.binary_crossentropy(y_true, y_pred)
    pixel_weights = 1. + balancing * y_true
    return K.mean(pixel_weights * per_pixel_bce)

In [None]:
# Pixel-level metrics tracked alongside the weighted loss.
training_metrics = [
    tf.keras.metrics.BinaryAccuracy(),
    tf.keras.metrics.Recall(),
    tf.keras.metrics.Precision(),
]

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=weighted_bce,
    metrics=training_metrics,
)

# Show the untrained model's prediction on the preview sample.
display(sample_image, sample_mask)

In [None]:
# Train for 100 epochs, plotting the preview sample after every epoch.
history = model.fit(
    x=training_images,
    y=training_masks,
    epochs=100,
    validation_data=(validation_images, validation_masks),
    callbacks=[DisplayCallback()],
)

In [None]:
predictions = model.predict(validation_images)
# Bundle each validation input with its ground-truth mask and model output.
image_combos = list(
    zip(validation_images, validation_masks, predictions)
)

In [None]:
num_val_images = validation_images.shape[0]
# random.sample raises ValueError when asked for more items than exist, and
# the validation set can be small, so cap the preview count at its size.
number_to_show = min(30, num_val_images)
random_selection = np.array(random.sample(range(num_val_images), number_to_show))

for i in random_selection:
    print(i)
    image = validation_images[i]
    mask = validation_masks[i]
    prediction = predictions[i]
    display(image, mask, prediction)
    plt.show()