In [None]:
import pathlib
import tensorflow as tf
import imageio

import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Makes it so that any changes in pymedphys are automatically
# propagated into the notebook without needing a kernel restart.
from IPython.lib.deepreload import reload
%load_ext autoreload
%autoreload 2

In [None]:
from pymedphys._experimental.autosegmentation import unet

In [None]:
# Build the U-Net with one output channel per segmentation class, in the
# label order used by _normalise_mask:
#   0 -> background, 1 -> patient, 2 -> brain, 3 -> eyes
model = unet.unet(
    grid_size=64,
    output_channels=4,
)
# tf.keras.utils.plot_model(model, show_shapes=True)

In [None]:
# Show the model object's repr as this cell's output.
model

In [None]:
# Recursively collect every '*_image.png' file under ./data as plain strings,
# then shuffle the list in place. No seed is set in the visible cells, so the
# resulting order differs from run to run.
data_root = pathlib.Path('data')
image_paths = list(map(str, data_root.glob('**/*_image.png')))
np.random.shuffle(image_paths)

In [None]:
# Derive each mask path from its image path by swapping the trailing
# '_image.png' for '_mask.png'. Only that final suffix is split off, so
# underscores elsewhere in the path (directory names, file stems) are kept;
# splitting on the first '_' would truncate such paths and point at the
# wrong file.
mask_paths = [
    f"{image_path.rsplit('_image.png', 1)[0]}_mask.png"
    for image_path in image_paths
]
# mask_paths

In [None]:
def _normalise_mask(png_mask):
    normalised_mask = np.round(png_mask / 255).astype(bool)  # It would be nice if this wasn't needed
    categorical_mask = np.zeros_like(normalised_mask[:,:,0]).astype(np.uint8)  # 0 -> background
    categorical_mask[normalised_mask[:,:,2]] = 1  # patient
    categorical_mask[normalised_mask[:,:,1]] = 2  # brain
    categorical_mask[normalised_mask[:,:,0]] = 3  # eyes
    
    return categorical_mask[:,:,None]

In [None]:
# Visual check on the first sample: load its mask PNG and show it, then show
# the categorical mask derived from it.
png_mask = imageio.imread(mask_paths[0])
categorical_mask = _normalise_mask(png_mask)
plt.imshow(png_mask)
plt.show()  # finish the first figure before drawing the second image
plt.imshow(categorical_mask)
plt.colorbar()  # colour scale for the class labels assigned by _normalise_mask

In [None]:
# Show the shape of `categorical_mask` as this cell's output.
categorical_mask.shape

In [None]:
def _normalise_image(png_image):
    normalised_image = png_image[:,:,None] / 255
    return normalised_image

In [None]:
# Visual check on the first sample: load and normalise its image, then show it.
input_array = _normalise_image(imageio.imread(image_paths[0]))
plt.imshow(input_array)

In [None]:
# Read every image/mask pair from disk and keep the normalised arrays in two
# parallel lists: index i of `input_arrays` matches index i of `output_arrays`.
input_arrays, output_arrays = [], []
for image_path, mask_path in zip(image_paths, mask_paths):
    normalised_image = _normalise_image(imageio.imread(image_path))
    input_arrays.append(normalised_image)

    normalised_mask = _normalise_mask(imageio.imread(mask_path))
    output_arrays.append(normalised_mask)

In [None]:
# Build a tf.data.Dataset whose elements are (image, mask) pairs, one per
# entry of the two lists.
# NOTE(review): assumes every image (and every mask) has the same shape so the
# lists can be converted to tensors — confirm against the data.
dataset = tf.data.Dataset.from_tensor_slices((input_arrays, output_arrays))

In [None]:
# Input-pipeline settings used when shuffling and batching `dataset`.
BATCH_SIZE = 64  # number of (image, mask) pairs per batch
SHUFFLE_BUFFER_SIZE = 100  # buffer size passed to Dataset.shuffle

In [None]:
# Shuffle the samples, then group them into batches.
# NOTE: this rebinds `dataset`. Re-running this cell without first re-running
# the cell that creates `dataset` would shuffle and batch the already batched
# dataset.
dataset = dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)

In [None]:
# The masks hold integer class labels (not one-hot vectors), hence the sparse
# categorical cross-entropy loss.
# NOTE(review): from_logits=True assumes the U-Net's final layer outputs raw
# logits (no softmax) — confirm against unet.unet.
sparse_cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=sparse_cross_entropy, metrics=['accuracy'])

In [None]:
# Train for a single epoch of 20 batches and keep the returned History object.
# NOTE(review): `dataset` is not repeated, so if it yields fewer than 20
# batches the epoch will run out of data early — confirm the dataset size.
model_history = model.fit(dataset, epochs=1, steps_per_epoch=20)