In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.path

from tqdm import tqdm

import pymedphys
import pymedphys._wlutz.findfield
import pymedphys._wlutz.iview
import pymedphys._wlutz.imginterp
import pymedphys._wlutz.reporting
import pymedphys._wlutz.interppoints

In [None]:
# Phantom geometry used when drawing the training masks.
# Units presumably mm, matching the label coordinates — TODO confirm.
# `penumbra` is not referenced by the mask construction in this notebook.
bb_diameter, edge_lengths, penumbra = 8, [20, 24], 2

In [None]:
# Fixed sampling grid, centred on zero, onto which every image is resampled.
dx = 10 / 8 / 4
vec_about_zero = np.arange(-20, 20, dx)
dim = (128, 128)

# The 1-D sample vector must match the expected image size along both axes.
assert all(len(vec_about_zero) == axis_length for axis_length in dim)

xx_about_zero, yy_about_zero = np.meshgrid(vec_about_zero, vec_about_zero)

In [None]:
# Local paths of the files in the 'wlutz_tensorflow_training_data' Zenodo record
# (presumably downloaded and cached on first use — TODO confirm against pymedphys docs).
training_data_paths = pymedphys.zenodo_data_paths('wlutz_tensorflow_training_data')

In [None]:
# Index the PNG images by file stem; the stems are the keys into the labels JSON.
image_paths = {path.stem: path for path in training_data_paths if path.suffix == '.png'}

# Expect exactly one JSON labels file. Fail with a clear message instead of an
# opaque IndexError (none found) or an arbitrary pick (several found).
json_paths = [path for path in training_data_paths if path.suffix == '.json']
if len(json_paths) != 1:
    raise ValueError(
        f"Expected exactly one '.json' labels file in the training data, "
        f"found {len(json_paths)}: {json_paths}"
    )
labels_path = json_paths[0]

In [None]:
# Parse the labels JSON; its top-level keys are the image stems used in `image_paths`.
with open(labels_path) as labels_fh:
    all_labels = json.load(labels_fh)

In [None]:
# Image stems, in the dict's insertion order.
keys = list(image_paths)

In [None]:
# The first 100 image stems form the training subset.
training_keys = keys[:100]

In [None]:
def create_mask(field_centre, field_rotation, bb_centre):
    """Build a 3-channel boolean segmentation mask on the ``xx_about_zero`` grid.

    Channels (last axis):
        0 -- background: outside both the field rectangle and the BB circle
        1 -- inside the ``edge_lengths``-sized field rectangle, placed by
             ``translate_and_rotate_transform(field_centre, field_rotation)``
        2 -- inside the circle of diameter ``bb_diameter`` centred on ``bb_centre``

    Channels 1 and 2 are not mutually exclusive. A BB lying inside the field
    sets both at the same pixel.

    Parameters
    ----------
    field_centre : sequence of 2 floats
        Field centre in the same frame as ``xx_about_zero``/``yy_about_zero``.
    field_rotation : float
        Rotation handed to ``translate_and_rotate_transform``. Units are whatever
        that helper expects (presumably degrees — TODO confirm).
    bb_centre : sequence of 2 floats
        BB centre in the grid frame.

    Returns
    -------
    numpy.ndarray of bool
        Shape ``xx_about_zero.shape + (3,)``.
    """
    field_transform = pymedphys._wlutz.interppoints.translate_and_rotate_transform(
        field_centre, field_rotation
    )

    # Closed rectangle outline centred on the origin: start at the (-x, -y) corner,
    # walk the four edges back to the start, then move it into place.
    half_x = edge_lengths[0] / 2
    half_y = edge_lengths[1] / 2
    draw_x = np.array([-half_x, -half_x, half_x, half_x, -half_x])
    draw_y = np.array([-half_y, half_y, half_y, -half_y, -half_y])

    rect_x, rect_y = pymedphys._wlutz.interppoints.apply_transform(
        draw_x, draw_y, field_transform
    )
    rectangle = matplotlib.path.Path(list(zip(rect_x, rect_y)))

    # Test every grid point, then fold the flat result back into the grid's own
    # shape. len() of a 2-D array is only its row count, which breaks non-square grids.
    grid_points = np.column_stack([xx_about_zero.ravel(), yy_about_zero.ravel()])
    rectangle_mask = rectangle.contains_points(grid_points).reshape(xx_about_zero.shape)

    within_bb = (
        np.sqrt((xx_about_zero - bb_centre[0]) ** 2 + (yy_about_zero - bb_centre[1]) ** 2)
        <= bb_diameter / 2
    )

    background = ~rectangle_mask & ~within_bb

    return np.stack([background, rectangle_mask, within_bb], axis=-1)

In [None]:
def load_and_regularise_data_and_labels(image_paths, all_labels, keys):
    """Resample labelled iView images onto the shared ``vec_about_zero`` grid.

    Only keys whose ``all_labels[key]['pymedphys']`` entry contains ``'bb_centre'``
    are used; the rest are skipped. For each kept key:

    - The image is interpolated on ``vec_about_zero`` shifted by the image's
      centre of mass.
    - The field and BB label coordinates are shifted by the same centre of mass,
      so they share the grid's frame.

    Parameters
    ----------
    image_paths : dict
        Maps each key to the image file passed to ``iview_image_transform``.
    all_labels : dict
        Maps each key to a dict whose ``'pymedphys'`` entry holds
        ``'field_centre'``, ``'field_rotation'`` and optionally ``'bb_centre'``.
    keys : iterable
        Keys to process, in order.

    Returns
    -------
    images : numpy.ndarray
        Resampled images, each with a trailing length-1 channel axis.
    masks : numpy.ndarray
        Segmentation masks produced by ``create_mask``.
    labels : numpy.ndarray
        Rows of ``[field_x, field_y, field_rotation, bb_x, bb_y]`` relative to
        each image's centre of mass.
    """
    images, masks, labels = [], [], []

    for key in tqdm(keys):
        label = all_labels[key]['pymedphys']
        if 'bb_centre' not in label:
            continue  # no BB position labelled for this image

        x, y, img = pymedphys._wlutz.iview.iview_image_transform(image_paths[key])

        centre_of_mass = np.array(
            pymedphys._wlutz.findfield.get_centre_of_mass(x, y, img)
        )
        field = pymedphys._wlutz.imginterp.create_interpolated_field(x, y, img)

        # Sample the image on the fixed grid, re-centred on its centre of mass.
        xx, yy = np.meshgrid(
            vec_about_zero + centre_of_mass[0],
            vec_about_zero + centre_of_mass[1],
        )
        interpolated_image = field(xx, yy)

        # Express the labels in the same centre-of-mass-relative frame.
        field_centre = np.array(label['field_centre']) - centre_of_mass
        field_rotation = label['field_rotation']
        bb_centre = np.array(label['bb_centre']) - centre_of_mass

        masks.append(create_mask(field_centre, field_rotation, bb_centre))
        labels.append(
            [field_centre[0], field_centre[1], field_rotation, bb_centre[0], bb_centre[1]]
        )
        images.append(interpolated_image[:, :, np.newaxis])

    return np.array(images), np.array(masks), np.array(labels)

In [None]:
# Build the training arrays from the first 100 image keys. Keys whose labels lack a
# 'bb_centre' are skipped, so there may be fewer than 100 samples.
images, masks, labels = load_and_regularise_data_and_labels(image_paths, all_labels, training_keys)

In [None]:
import tensorflow as tf

In [None]:
# One dataset element per sample, ordered (image, mask) so that the
# `for image, mask in dataset...` consumers below unpack it correctly.
# `from_tensor_slices` splits along the first axis; `from_tensors` would instead
# yield a single element holding the entire stacked arrays.
dataset = tf.data.Dataset.from_tensor_slices((images, masks))

In [None]:
# Pull the first element of `dataset` for a quick visual check.
# NOTE(review): the `image, mask` unpacking order must match the tuple order used
# when `dataset` was built — check that cell.
for image, mask in dataset.take(1):
    sample_image, sample_mask = image, mask

# NOTE(review): the plotting `display(display_list)` helper is defined in a later
# cell. On a fresh kernel this call resolves to IPython's built-in `display`, which
# just shows the list.
display([sample_image, sample_mask])

In [None]:
# Inspect the shape of the sampled image.
np.shape(sample_image)

In [None]:
# Inspect the shape of the sampled mask.
np.shape(sample_mask)

In [None]:
# Overlay the segmentation-mask outlines and labelled centres on one training image.
i = 20

sample_image = images[i]
sample_mask = masks[i]
sample_label = labels[i]

plt.figure(figsize=(10, 10))

# contourf/contour need 2-D arrays. Drop the image's trailing channel axis, and draw
# one outline per foreground channel (1 = field rectangle, 2 = BB). Those channels
# can overlap, so they cannot be collapsed into a single label map.
plt.contourf(xx_about_zero, yy_about_zero, sample_image[:, :, 0], 100)
for channel in (1, 2):
    plt.contour(
        xx_about_zero, yy_about_zero, sample_mask[:, :, channel].astype(float), [0.5]
    )

# Labelled field centre (columns 0-1) and BB centre (columns 3-4).
plt.scatter(sample_label[0], sample_label[1])
plt.scatter(sample_label[3], sample_label[4])

plt.axis('equal')

In [None]:
from pymedphys._vendor.tensorflow.pix2pix import downsample, upsample
import tensorflow as tf

In [None]:
def display(display_list):
    """Show up to three arrays side by side as images in one figure row.

    Panels are titled, in order, 'Input Image', 'True Mask' and 'Predicted Mask'.
    Each entry is converted with ``tf.keras.preprocessing.image.array_to_img``.

    NOTE(review): this name shadows IPython's built-in ``display`` from here on.
    """
    titles = ['Input Image', 'True Mask', 'Predicted Mask']
    n_panels = len(display_list)

    plt.figure(figsize=(15, 15))
    for position, item in enumerate(display_list, start=1):
        plt.subplot(1, n_panels, position)
        plt.title(titles[position - 1])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(item))
        plt.axis('off')
    plt.show()

In [None]:
# First U-Net, built from the vendored pix2pix blocks and sized for the resampled images.
# Shape comments assume each pix2pix downsample halves, and each upsample doubles, H and W.
output_channels = 3
norm_type = "batchnorm"

# The training images are single-channel (`interpolated_image[:, :, None]`) on a
# `dim`-sized grid, so the network input must be 1 channel, not RGB.
input_shape = [dim[0], dim[1], 1]

down_stack = [
    downsample(64, 3, norm_type, apply_norm=False),  # (bs, 64, 64, 64)
    downsample(128, 3, norm_type),  # (bs, 32, 32, 128)
    downsample(256, 3, norm_type),  # (bs, 16, 16, 256)
    downsample(512, 3, norm_type),  # (bs, 8, 8, 512)
]

# One upsampling stage per skip connection (three of them). zip() below would
# ignore any extra stage.
up_stack = [
    upsample(512, 3),  # 8x8 -> 16x16
    upsample(256, 3),  # 16x16 -> 32x32
    upsample(128, 3),  # 32x32 -> 64x64
]

initializer = tf.random_normal_initializer(0.0, 0.02)
# Softmax gives per-pixel class probabilities over the mask channels. That is what
# a crossentropy loss and the argmax in the prediction helper expect; tanh does not.
last = tf.keras.layers.Conv2DTranspose(
    output_channels,
    4,
    strides=2,
    padding="same",
    kernel_initializer=initializer,
    activation="softmax",
)  # (bs, 128, 128, output_channels)

inputs = tf.keras.layers.Input(shape=input_shape)
x = inputs

# Downsampling through the model
skips = []
for down in down_stack:
    x = down(x)
    skips.append(x)

skips = reversed(skips[:-1])

# Upsampling and establishing the skip connections. A fresh Concatenate per merge
# keeps each skip connection as its own layer in the graph.
for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

x = last(x)

model = tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
# Layer-by-layer summary of the U-Net currently bound to `model`.
model.summary()

In [None]:
# The targets from the geometry `create_mask` have 3 boolean channels per pixel
# (one-hot style), not integer class indices, so use categorical rather than
# sparse-categorical crossentropy.
# NOTE(review): BB pixels inside the field set both channel 1 and channel 2, so the
# targets are multi-hot there. Consider making the classes mutually exclusive.
model.compile(optimizer='adam', loss='categorical_crossentropy',
              metrics=['accuracy'])

In [None]:
# Render the model graph with tensor shapes (presumably needs pydot/graphviz installed).
tf.keras.utils.plot_model(model, show_shapes=True)

In [None]:
def create_mask(pred_mask):
    """Turn a batch of per-pixel class scores into the class-index map of item 0.

    Takes the argmax over the last axis, appends a length-1 channel axis, and
    returns batch element 0.

    NOTE(review): this redefines the geometry-based ``create_mask`` used by
    ``load_and_regularise_data_and_labels``. Re-running that loader after this cell
    would call this function instead, so consider renaming one of them.
    """
    class_indices = tf.argmax(pred_mask, axis=-1)
    return class_indices[..., tf.newaxis][0]

In [None]:
def show_predictions(dataset=None, num=1):
    """Display input, true mask and predicted mask using the global ``model``.

    Parameters
    ----------
    dataset : tf.data.Dataset, optional
        If given, show the first item of each of the first ``num`` elements. Each
        element is expected to be a batched ``(image, mask)`` pair. If omitted,
        predict on the global ``sample_image`` and show it with ``sample_mask``.
    num : int
        Number of dataset elements to show; ignored when ``dataset`` is None.
    """
    # Compare against None explicitly: the truthiness of a dataset object is not a
    # reliable "was one passed" check.
    if dataset is None:
        prediction = model.predict(sample_image[tf.newaxis, ...])
        display([sample_image, sample_mask, create_mask(prediction)])
        return

    for image, mask in dataset.take(num):
        pred_mask = model.predict(image)
        display([image[0], mask[0], create_mask(pred_mask)])

In [None]:
# Visual check of the (not yet trained) model's prediction on `sample_image`.
show_predictions()

In [None]:
# base_model = tf.keras.applications.MobileNetV2(input_shape=(96, 96, 3), include_top=True)

In [None]:
# MobileNetV2 without its classification head, used below as a U-Net encoder.
# `weights` is left at the Keras default (presumably ImageNet, downloaded on first use).
base_model = tf.keras.applications.MobileNetV2(input_shape=(96, 96, 3), include_top=False)

In [None]:
# tf.keras.utils.plot_model(base_model, show_shapes=True)

In [None]:
# Intermediate MobileNetV2 activations to tap as U-Net skip connections, shallowest
# first. The last (deepest) one seeds the decoder in `unet_model`.
layer_names = [
    'block_1_expand_relu',  
    'block_3_expand_relu',   
    'block_6_expand_relu',  
    'block_13_expand_relu', 
    'block_16_project',    
]
layers = [base_model.get_layer(name).output for name in layer_names]

# Create the feature extraction model: one input, one output per tapped layer.
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)

# Let the pretrained encoder weights be updated during training (not frozen).
down_stack.trainable = True

In [None]:
# Inspect the MobileNetV2 feature-extraction encoder and its tapped outputs.
down_stack.summary()

In [None]:
# Decoder stages for the MobileNetV2 U-Net. `upsample` is the function imported from
# pymedphys._vendor.tensorflow.pix2pix. The bare name `pix2pix` is never bound in
# this notebook, so `pix2pix.upsample` raises NameError on a fresh kernel.
# Size comments assume each stage doubles H and W, and that the deepest tap
# ('block_16_project') is 3x3 for the 96x96 input — TODO confirm with down_stack.summary().
up_stack = [
    upsample(512, 3),  # 3x3 -> 6x6
    upsample(256, 3),  # 6x6 -> 12x12
    upsample(128, 3),  # 12x12 -> 24x24
    upsample(64, 3),   # 24x24 -> 48x48
]

In [None]:
def unet_model(output_channels):
    """Assemble a U-Net from the MobileNetV2 ``down_stack`` and the global ``up_stack``.

    The model takes 96x96x3 inputs. The deepest encoder activation feeds the
    decoder. Each decoder stage is followed by concatenation with the next-shallower
    encoder activation. A final stride-2 transposed convolution applies a softmax
    over ``output_channels`` classes.
    """
    # Created first so layer auto-naming matches the original construction order.
    final_layer = tf.keras.layers.Conv2DTranspose(
        output_channels, 3, strides=2,
        padding='same', activation='softmax')

    inputs = tf.keras.layers.Input(shape=[96, 96, 3])

    # Encoder: every tapped activation; the last (deepest) one starts the decoder.
    *encoder_skips, x = down_stack(inputs)

    # Decoder: upsample, then merge with the matching encoder activation, deepest first.
    for up, skip in zip(up_stack, encoder_skips[::-1]):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    return tf.keras.Model(inputs=inputs, outputs=final_layer(x))

In [None]:
# Build the MobileNetV2-backed U-Net with three output classes, one per mask channel.
OUTPUT_CHANNELS = 3
model = unet_model(OUTPUT_CHANNELS)

In [None]:
# Scratch: recreate a symbolic 96x96x3 input to step through the encoder by hand.
# NOTE(review): this rebinds the global `inputs` and `x` used by the first U-Net cell.
inputs = tf.keras.layers.Input(shape=[96, 96, 3])
x = inputs

inputs

In [None]:
# Scratch: run the symbolic input through the encoder and show the tapped outputs.
skips = down_stack(x)
skips

In [None]:
# NOTE(review): identical to the previous cell; one of the two can be removed.
skips = down_stack(x)
skips

In [None]:
# Scratch arithmetic: scales a list of sizes by 8 (intent not stated — presumably
# checking feature-map/channel sizes).
np.array([4,8,16,32,64]) * 8

In [None]:

# Scratch: split the encoder outputs the way unet_model does. The deepest activation
# seeds the decoder; the rest become skip connections, deepest first.
# Materialise the list and keep `skips` intact, so the cell can be re-run and the
# result actually displays (a bare `reversed` iterator does not).
x = skips[-1]
skip_connections = list(reversed(skips[:-1]))

skip_connections

In [None]:
# Inspect the vendored pix2pix module. Only `downsample`/`upsample` were imported by
# name earlier, so bind the module explicitly before referring to it.
from pymedphys._vendor.tensorflow import pix2pix

pix2pix

In [None]:
# Names of every layer in the MobileNetV2 base model, for choosing skip-connection taps.
# Use a new name so the `layer_names` selection that builds the encoder is not
# overwritten; otherwise re-running that cell would look up every layer name.
base_model_layer_names = [layer.name for layer in base_model.layers]
base_model_layer_names

In [None]:
# NOTE(review): `layer` is not defined at module level. The loop variable of the list
# comprehension in the previous cell does not leak in Python 3, so this raises
# NameError on a fresh kernel. Possibly meant `base_model.layers[-1].name`.
layer.name

In [None]:
# Encoder stages for the small 32x32 U-Net. `downsample` is the function imported
# from pymedphys._vendor.tensorflow.pix2pix; `pix2pix` itself is not bound here.
# Size comments assume each stage halves H and W.
down_stack = [
    downsample(64, 3),   # 32x32 -> 16x16
    downsample(128, 3),  # 16x16 -> 8x8
]

In [None]:
# Decoder stages for the small 32x32 U-Net. `upsample` is the function imported from
# pymedphys._vendor.tensorflow.pix2pix. Size comments assume each stage doubles H and W.
up_stack = [
    upsample(128, 3),  # 8x8 -> 16x16
    upsample(64, 3),   # 16x16 -> 32x32
]

In [None]:
def unet_model(output_channels, input_shape=(32, 32, 3)):
    """Assemble a small U-Net from the global pix2pix ``down_stack``/``up_stack`` lists.

    Here ``down_stack`` is a plain list of layers, not a Keras model. The encoder is
    therefore applied layer by layer, recording each activation as a candidate skip
    connection. The deepest activation seeds the decoder, and the remaining ones are
    concatenated onto the decoder deepest first. ``zip`` pairs decoder stages with
    skips, so a decoder stage without a matching skip is not used.

    Parameters
    ----------
    output_channels : int
        Number of classes; the last layer applies a softmax over them.
    input_shape : sequence of int, optional
        Shape of one input sample. Defaults to ``(32, 32, 3)``.

    Returns
    -------
    tf.keras.Model
    """
    # This is the last layer of the model (16x16 -> 32x32 for the default input).
    last = tf.keras.layers.Conv2DTranspose(
        output_channels, 3, strides=2,
        padding='same', activation='softmax')

    inputs = tf.keras.layers.Input(shape=input_shape)
    x = inputs

    # Downsampling: a list of layers cannot be called directly, so apply each in turn.
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # Upsampling and establishing the skip connections. The deepest activation is
    # already in `x`, so pair the decoder with the remaining ones, deepest first.
    for up, skip in zip(up_stack, reversed(skips[:-1])):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    return tf.keras.Model(inputs=inputs, outputs=last(x))

In [None]:
# NOTE(review): repeats the earlier OUTPUT_CHANNELS definition with the same value.
OUTPUT_CHANNELS = 3

In [None]:
# Build the small pix2pix U-Net.
# NOTE(review): this rebinds `model`, replacing the U-Net built earlier.
model = unet_model(OUTPUT_CHANNELS)

In [None]:
# Scratch arithmetic (evaluates to 128; presumably relating the 32x32 toy input to the
# 128x128 image grid — intent not stated).
32 * 4