In [None]:
import json
import os
import time

import IPython

from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.path

import tensorflow as tf

# Reset Keras' global state so re-running the notebook starts from a clean slate.
tf.keras.backend.clear_session()  # For easy reset of notebook state.

In [None]:
# Show the installed TensorFlow version (the notebook uses the TF2 tf.keras / tf.function API).
tf.__version__

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pymedphys
import pymedphys._wlutz.findfield
import pymedphys._wlutz.iview
import pymedphys._wlutz.imginterp
import pymedphys._wlutz.reporting
import pymedphys._wlutz.interppoints

In [None]:
# Phantom geometry in image pixels. The `* 2` factors presumably convert mm to
# pixels at 2 px/mm, matching the `* 2` in `transform_to_abs` — TODO confirm.
bb_diameter = 8 * 2
edge_lengths = np.array([20, 24]) * 2  # field edge lengths along x and y
penumbra = 2 * 2

In [None]:
# Paths of the labelled WLutz training files on Zenodo: PNG images plus a JSON
# label file (they are separated by suffix in the next cell).
training_data_paths = pymedphys.zenodo_data_paths('wlutz_tensorflow_training_data')

In [None]:
# Index the PNG images by file stem (the stem is the label key used below) and
# take the first JSON file found as the label file.
image_paths = dict(
    (png_path.stem, png_path)
    for png_path in training_data_paths
    if png_path.suffix == '.png'
)
labels_path = [
    json_path for json_path in training_data_paths if json_path.suffix == '.json'
][0]

In [None]:
# Load the whole label file: a mapping from image key to that image's label record.
with open(labels_path, 'r') as labels_file:
    all_labels = json.load(labels_file)

In [None]:
# Keep only images whose 'pymedphys' label includes a BB position; that entry holds
# the fields read by `transform_labels` (field_centre, field_rotation, bb_centre).
labels = {key: label['pymedphys'] for key, label in all_labels.items() if 'bb_centre' in label['pymedphys']}

# Shuffle with a fixed seed so the validation/test/train split in the next cell is
# the same on every kernel restart. An unseeded shuffle produces a new split on each
# run, which can mix earlier test images into training when a saved checkpoint is reused.
SPLIT_SEED = 42
keys = np.array(list(labels.keys()))
np.random.default_rng(SPLIT_SEED).shuffle(keys)

In [None]:
# Split the shuffled keys: first 1/8 validation, next 1/8 test, remaining ~3/4 train.
# NOTE(review): `validation_keys` is not used by any visible cell.
split_a = len(keys) // 8
split_b = len(keys) // 4

validation_keys, test_keys, train_keys = np.split(keys, [split_a, split_b])

In [None]:
# Closed outline of the field rectangle centred on the origin. The cumulative sums of
# these steps give the corners (-w/2, -h/2), (-w/2, h/2), (w/2, h/2), (w/2, -h/2) and
# then the first corner again, so consecutive points trace the four edges.
rect_dx = [-edge_lengths[0] / 2, 0, edge_lengths[0], 0, -edge_lengths[0]]
rect_dy = [-edge_lengths[1] / 2, edge_lengths[1], 0, -edge_lengths[1], 0]

draw_x = tf.convert_to_tensor(np.cumsum(rect_dx), dtype=tf.float32)
draw_y = tf.convert_to_tensor(np.cumsum(rect_dy), dtype=tf.float32)

# NOTE(review): `coord` is not used by any visible cell.
coord = tf.range(0,128)

In [None]:
# Side length, in pixels, of the square images and masks fed to the networks.
IMG_SIZE = 128

In [None]:
# Pixel-index coordinates of the IMG_SIZE x IMG_SIZE image.
x = np.arange(0, IMG_SIZE)
y = np.arange(0, IMG_SIZE)

xx, yy = np.meshgrid(x, y)

# Super-sampled grid used to anti-alias the masks: each pixel is split into
# `supersample` x `supersample` sub-pixels whose centres lie at
# -0.5 + dx/2, -0.5 + 3*dx/2, ... `reduce_expanded_mask` averages each block
# back to one pixel (it hard-codes 16 sub-pixels, so keep the two in sync).
supersample = 16
dx = 1 / supersample
# Built from integer indices (not np.arange with a float stop) so the grid always
# has exactly IMG_SIZE * supersample points and follows IMG_SIZE.
x_expand = (np.arange(IMG_SIZE * supersample) + 0.5) * dx - 0.5
y_expand = (np.arange(IMG_SIZE * supersample) + 0.5) * dx - 0.5

xx_expand, yy_expand = np.meshgrid(x_expand, y_expand)

bb_radius_sqrd = (bb_diameter / 2)**2

In [None]:
def transform_to_abs(coords):
    """Map label coordinates to image pixel coordinates as ``63 - 2 * coords``.

    The values are scaled by 2 and their direction is flipped about pixel 63.
    Accepts anything ``np.array`` accepts and returns a numpy array.
    """
    scaled = np.array(coords) * 2
    return 63 - scaled

def transform_labels(label):
    """Encode one label record as the 5-element training target.

    Returns ``[field_x, field_y, field_rotation / 90, bb_x, bb_y]`` with both
    centres converted by ``transform_to_abs``.
    """
    rotation_fraction = label['field_rotation'] / 90
    field_px = transform_to_abs(label['field_centre'])
    bb_px = transform_to_abs(label['bb_centre'])
    return [field_px[0], field_px[1], rotation_fraction, bb_px[0], bb_px[1]]

In [None]:
@tf.function
def reduce_expanded_mask(expanded_mask):
    """Average a (2048, 2048) super-sampled mask down to (128, 128).

    Each output pixel is the mean of its 16 x 16 block of sub-pixels. For a
    boolean input this is the fraction of the pixel covered by the mask.
    """
    as_float = tf.dtypes.cast(expanded_mask, tf.float32)
    # Axes after the reshape: (row, sub-row, col, sub-col).
    blocks = tf.reshape(as_float, (128, 16, 128, 16))
    row_means = tf.reduce_mean(blocks, axis=1)  # (128, 128, 16)
    return tf.reduce_mean(row_means, axis=2)  # (128, 128)

In [None]:
@tf.function
def get_circle_mask(bb_centre):
    """Anti-aliased mask of the BB disc, mapped to [-1, 1].

    `bb_centre` is an (x, y) pair in pixel coordinates. Sub-pixels of the
    super-sampled grid within `bb_radius_sqrd` of it count as inside. The
    per-pixel covered fraction in [0, 1] is then mapped to [-1, 1].
    """
    offset_x = xx_expand - bb_centre[0]
    offset_y = yy_expand - bb_centre[1]
    inside = offset_x**2 + offset_y**2 <= bb_radius_sqrd
    return reduce_expanded_mask(inside) * 2 - 1

In [None]:
# tf.convert_to_tensor?

In [None]:
@tf.function
def get_transformation_matrix(field_centre, field_rotation):
    """Build a 3x3 homogeneous 2-D transform: rotate by `field_rotation`, then translate.

    Args:
        field_centre: (x, y) translation (pixel coordinates in this notebook).
        field_rotation: rotation angle in degrees.

    Returns:
        float32 tensor ``[[cos, sin, x], [-sin, cos, y], [0, 0, 1]]``.
    """
    field_rotation_radians = field_rotation / 180 * np.pi
    sin = tf.math.sin(field_rotation_radians)
    cos = tf.math.cos(field_rotation_radians)
    x = field_centre[0]
    y = field_centre[1]

    # No NumPy randomness here: inside a tf.function it would run only at trace
    # time, so it could not vary per call and would only advance np.random's state.
    return tf.convert_to_tensor([[cos, sin, x], [-sin, cos, y], [0, 0, 1]], dtype=tf.float32)


@tf.function
def apply_transform(xx, yy, transform):
    """Apply a 3x3 homogeneous transform to the points (xx, yy).

    `xx` and `yy` share a shape. They are flattened and stacked with a row of
    ones into homogeneous coordinates, then left-multiplied by `transform`. The
    transformed x and y are reshaped back to the input shapes.
    """
    flat_x = tf.reshape(xx, (-1,))
    flat_y = tf.reshape(yy, (-1,))
    homogeneous = tf.stack(
        [flat_x, flat_y, tf.ones_like(flat_x, dtype=tf.float32)], axis=0
    )
    mapped = tf.matmul(transform, homogeneous)

    return tf.reshape(mapped[0], xx.shape), tf.reshape(mapped[1], yy.shape)

In [None]:
@tf.function
def get_partial_rect_mask(field_centre, x1, x2, y1, y2):
    """Super-sampled half-plane mask for one edge of the field rectangle.

    The edge runs from (x1, y1) to (x2, y2). The returned boolean mask, on the
    `xx_expand` / `yy_expand` grid, is True on the side of the edge's line that
    contains `field_centre`, including points on the line itself.

    The side is decided by the sign of the 2-D cross product
    ``(x2 - x1) * (py - y1) - (y2 - y1) * (px - x1)``, not by a slope/intercept
    form. Vertical edges (x1 == x2, e.g. a field rotation of exactly 0 degrees)
    are therefore handled correctly rather than producing an infinite slope and
    NaN comparisons.
    """
    edge_dx = x2 - x1
    edge_dy = y2 - y1

    def side(px, py):
        # Zero on the edge's line; the sign says which side (px, py) lies on.
        return edge_dx * (py - y1) - edge_dy * (px - x1)

    centre_side = side(field_centre[0], field_centre[1])
    grid_side = side(xx_expand, yy_expand)

    if centre_side <= 0:
        rect_mask = grid_side <= 0
    else:
        rect_mask = grid_side >= 0

    return rect_mask

In [None]:
@tf.function
def get_rect_mask(field_centre, field_rotation):
    """Anti-aliased mask of the rotated field rectangle, mapped to [-1, 1].

    The centred outline (`draw_x`, `draw_y`) is rotated by `field_rotation`
    degrees and moved to `field_centre`. The interior is the intersection of
    the four half-planes bounded by the edges, each on the side containing
    `field_centre`.
    """
    transform = get_transformation_matrix(field_centre, field_rotation)
    corners_x, corners_y = apply_transform(draw_x, draw_y, transform)

    half_planes = []
    for corner in range(4):
        following = (corner + 1) % 4
        half_planes.append(
            get_partial_rect_mask(
                field_centre,
                corners_x[corner], corners_x[following],
                corners_y[corner], corners_y[following],
            )
        )

    inside = half_planes[0] & half_planes[1] & half_planes[2] & half_planes[3]

    return reduce_expanded_mask(inside) * 2 - 1

In [None]:
def extract_items_from_encoding(encoding):
    """Split one 5-element encoding into (field_centre, field_rotation, bb_centre).

    The centres are ``[x, y]`` lists. The rotation is converted back to
    degrees (the encoding stores rotation / 90).
    """
    return [encoding[0], encoding[1]], encoding[2] * 90, [encoding[3], encoding[4]]


def extract_items_from_encodings(encodings):
    """Batch form of `extract_items_from_encoding`.

    Each item is indexed as ``encoding[0, 0, k]``, so it must be at least 3-D
    with the five values on the last axis. Returns three parallel lists: field
    centres, rotations in degrees, and BB centres.
    """
    centres, rotations, bbs = [], [], []

    for encoding in encodings:
        centres.append([encoding[0, 0, 0], encoding[0, 0, 1]])
        rotations.append(encoding[0, 0, 2] * 90)
        bbs.append([encoding[0, 0, 3], encoding[0, 0, 4]])

    return centres, rotations, bbs

def decode(encoding):
    """Render a 5-element encoding as its (128, 128, 2) [BB, field] mask via `create_mask`."""
    field_centre, field_rotation, bb_centre = extract_items_from_encoding(encoding)
    return create_mask(field_centre, field_rotation, bb_centre)

In [None]:
def load(image_path, encoding):
    """Read one training example inside `tf.data`.

    Args:
        image_path: path of a PNG file (string tensor).
        encoding: the 5-element label encoding produced by `transform_labels`.

    Returns:
        (image, mask, encoding):
        - image: (1, H, W, 1) float32, rescaled so pixel value 0 -> 1 and 255 -> -1.
        - mask: the matching (1, 128, 128, 2) target mask from `decode`.
        - encoding: the encoding cast to float32.
    """
    image = tf.io.read_file(image_path)
    # Force a single channel: both networks take [128, 128, 1] inputs, and an RGB
    # or grey+alpha PNG would otherwise decode to 3 or 2 channels.
    image = tf.io.decode_png(image, channels=1)

    dim = tf.shape(image)
    if dim[0] == 1024 and dim[1] == 1024:
        # Halve full-resolution images. NOTE(review): this keeps odd rows but even
        # columns — confirm the half-pixel offset is intended.
        image = image[1::2, ::2, :]

    # Keep the central quarter (512 -> 128 for halved images) and mirror left-right.
    image = tf.image.central_crop(image, 0.25)
    image = tf.reverse(image, [1])
    image = tf.cast(image, tf.float32)

    # Rescale [0, 255] to [1, -1] (note the intensity inversion).
    image = 1 - (image / 127.5)

    encoding = tf.cast(encoding, dtype=tf.float32)

    mask = decode(encoding)

    # Prepend a batch axis of 1; the dataset itself is not batched.
    return image[None, ...], mask[None, ...], encoding


def get_dataset(keys, image_paths, labels):
    """Build an unbatched, shuffled `tf.data.Dataset` of (image, mask, encoding).

    Args:
        keys: label keys to include.
        image_paths: mapping from key to image path.
        labels: mapping from key to label record (see `transform_labels`).
    """
    image_paths_array = np.array([str(image_paths[key]) for key in keys])
    labels_array = np.array([transform_labels(labels[key]) for key in keys])

    dataset = tf.data.Dataset.from_tensor_slices((image_paths_array, labels_array))
    # Shuffle the cheap (path, encoding) pairs before decoding. The shuffle buffer
    # then holds 400 file names rather than 400 decoded images and masks.
    dataset = dataset.shuffle(400)
    dataset = dataset.map(load)

    return dataset

In [None]:
def create_mask(field_centre, field_rotation, bb_centre):
    """Render the target mask for one example as a (128, 128, 2) tensor.

    Channel 0 is the BB disc mask and channel 1 the field rectangle mask,
    each in [-1, 1] (see `get_circle_mask` and `get_rect_mask`).
    """
    channels = [
        get_circle_mask(bb_centre),
        get_rect_mask(field_centre, field_rotation),
    ]
    return tf.concat([channel[:, :, None] for channel in channels], axis=2)


# Only the train and test splits are built (`validation_keys` is not used by any
# visible cell). Each element already carries a leading batch axis of 1 from
# `load`, so the datasets are iterated unbatched.
train_dataset = get_dataset(train_keys, image_paths, labels)
test_dataset = get_dataset(test_keys, image_paths, labels)
# train_dataset.batch(1)

In [None]:
# Take one training example for the smoke tests below. `next(iter(...))` fails
# loudly on an empty dataset, instead of silently keeping stale `sample_*` values
# from a previous run in the kernel.
sample_image, sample_mask, sample_encoding = next(iter(train_dataset.take(1)))

In [None]:
# tf.keras.layers.Dense(5, activation=tf.keras.layers.LeakyReLU)

In [None]:
def downsample(filters, size, apply_batchnorm=True):
    """Conv2D (stride 2, 'same', no bias) -> optional BatchNorm -> LeakyReLU.

    Args:
        filters: number of output channels.
        size: convolution kernel size.
        apply_batchnorm: insert BatchNormalization after the convolution.

    Returns:
        A `tf.keras.Sequential` block that halves the spatial size.
    """
    conv = tf.keras.layers.Conv2D(
        filters,
        size,
        strides=2,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False,
    )

    block_layers = [conv]
    if apply_batchnorm:
        block_layers.append(tf.keras.layers.BatchNormalization())
    block_layers.append(tf.keras.layers.LeakyReLU())

    return tf.keras.Sequential(block_layers)

In [None]:
@tf.function
def decode_batch(encoding):
    """Render the target mask for a generator output.

    Only the first item is used. The five encoding values are read from
    ``encoding[0, 0, 0]``, which matches the generator's (1, 1, 1, 5) output.
    Returns a (1, 128, 128, 2) tensor: channel 0 is the BB mask, channel 1 the
    field mask.
    """
    values = encoding[0, 0, 0]

    bb_channel = get_circle_mask([values[3], values[4]])
    field_channel = get_rect_mask([values[0], values[1]], values[2] * 90)

    return tf.stack([bb_channel, field_channel], axis=-1)[None, ...]

In [None]:
# class DecodeToMask(tf.keras.layers.Layer):
#     def __init__(self, *args, input_dim=512, **kwargs):
#         super(DecodeToMask, self).__init__(*args, **kwargs)
#         units = 5
        
#         w_init = tf.random_normal_initializer()
#         self.w = tf.Variable(initial_value=w_init(shape=(input_dim, units),
#                                                   dtype='float32'),
#                              trainable=True)
#         b_init = tf.zeros_initializer()
#         self.b = tf.Variable(initial_value=b_init(shape=(units,),
#                                                   dtype='float32'),
#                              trainable=True)
        
#     def call(self, inputs):
#         encodings = tf.matmul(inputs, self.w) + self.b
#         return decode_batch(encodings)
    
# #     def compute_output_shape(self, *args, **kwargs):
# #         return (1, 128, 128, 2)

In [None]:
# down_model = downsample(3, 4)
# down_result = down_model(tf.expand_dims(inp, 0))
# print (down_result.shape)

In [None]:
# tf.keras.layers.Dense?

In [None]:
def Generator():
    """Encoder that regresses the 5-element label encoding from an image.

    Input: a (1, 128, 128, 1) image. Seven stride-2 `downsample` blocks reduce
    it to (1, 1, 1, 512). A linear Dense(5) head then outputs (1, 1, 1, 5) in the
    `transform_labels` layout: [field_x, field_y, field_rotation / 90, bb_x, bb_y].
    """
    inputs = tf.keras.layers.Input(shape=[128,128,1], batch_size=1)

    down_stack = [
        downsample(64, 4, apply_batchnorm=False), # (bs, 64, 64, 64)
        downsample(128, 4), # (bs, 32, 32, 128)
        downsample(256, 4), # (bs, 16, 16, 256)
        downsample(512, 4), # (bs, 8, 8, 512)
        downsample(512, 4), # (bs, 4, 4, 512)
        downsample(512, 4), # (bs, 2, 2, 512)
        downsample(512, 4), # (bs, 1, 1, 512)
    ]

    x = inputs

    # Downsampling through the model
    for down in down_stack:
        x = down(x)

    # Linear head (no activation) producing the five encoding values: (bs, 1, 1, 5).
    x = tf.keras.layers.Dense(5)(x)

    return tf.keras.Model(inputs=inputs, outputs=x)


# Build the generator once; training, plotting and checkpointing cells use this global.
generator = Generator()

In [None]:
# generator.trainable_variables

In [None]:
# Diagram of the generator's layers with their output shapes.
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

In [None]:
# Smoke test: run the untrained generator on one sample; expect a (1, 1, 1, 5) encoding.
gen_output = generator(sample_image, training=False)
print(gen_output)

In [None]:

# Binary cross-entropy on raw logits (the discriminator's final Conv2D has no activation).
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# loss_object?

In [None]:
def Discriminator():
    """Patch-wise discriminator over (image, mask) pairs.

    Inputs: 'input_image' (bs, 128, 128, 1) and 'target_image'
    (bs, 128, 128, 2), concatenated on the channel axis. Output: a
    (bs, 30, 30, 1) grid of logits, one per image patch.
    """
    init = tf.random_normal_initializer(0., 0.02)

    image_in = tf.keras.layers.Input(shape=[128, 128, 1], name='input_image')
    mask_in = tf.keras.layers.Input(shape=[128, 128, 2], name='target_image')

    h = tf.keras.layers.concatenate([image_in, mask_in])  # (bs, 128, 128, 3)
    h = downsample(128, 4, False)(h)  # (bs, 64, 64, 128)
    h = downsample(256, 4)(h)  # (bs, 32, 32, 256)

    h = tf.keras.layers.ZeroPadding2D()(h)  # (bs, 34, 34, 256)
    h = tf.keras.layers.Conv2D(
        512, 4, strides=1,
        kernel_initializer=init,
        use_bias=False)(h)  # (bs, 31, 31, 512)
    h = tf.keras.layers.BatchNormalization()(h)
    h = tf.keras.layers.LeakyReLU()(h)

    h = tf.keras.layers.ZeroPadding2D()(h)  # (bs, 33, 33, 512)
    logits = tf.keras.layers.Conv2D(
        1, 4, strides=1,
        kernel_initializer=init)(h)  # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=[image_in, mask_in], outputs=logits)

In [None]:
# Build the discriminator and show its layer diagram.
discriminator = Discriminator()
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)

In [None]:
# Smoke test: score the sample image against the untrained generator's mask and
# show the per-patch logit grid (red = "real", blue = "generated").
disc_out = discriminator([sample_image, decode_batch(gen_output)], training=False)
plt.imshow(disc_out[0,...,-1], vmin=-20, vmax=20, cmap='RdBu_r')
plt.colorbar()

In [None]:
def discriminator_loss(disc_real_output, disc_generated_output):
    """GAN discriminator loss: real patches -> label 1, generated patches -> label 0.

    Both arguments are discriminator logits. Returns the sum of the two binary
    cross-entropy terms computed with `loss_object`.
    """
    real_term = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    fake_term = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
    return real_term + fake_term

In [None]:
# Adam (learning rate 2e-4, beta_1 = 0.5) for each network.
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
# Checkpoint both networks together with their optimizer state; `fit` saves with
# the prefix ./training_checkpoints/ckpt.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

In [None]:
# for image, mask, encoding in train_dataset.take(10):
#     field_centre, field_rotation, bb_centre = extract_items_from_encoding(encoding)
    
#     fig, axs = pymedphys._wlutz.reporting.image_analysis_figure(
#         x, y, np.array(image)[:,:,0],
#         np.array(bb_centre), np.array(field_centre), np.array(field_rotation),
#         bb_diameter, edge_lengths, penumbra, units=''
#     )
    
#     plt.show()

In [None]:
def generate_images(model, image, ground_truth_encoding):
    """Plot the WLutz analysis figure for the ground truth, then for the model's prediction.

    Args:
        model: the generator. It is called with ``training=True``, so
            BatchNormalization uses batch statistics. NOTE(review): in Keras this
            may also update the moving averages from test images — confirm intended.
        image: a (1, 128, 128, 1) image batch from the dataset.
        ground_truth_encoding: the matching 5-element encoding.
    """
    predicted_encoding = model(image, training=True)
    predicted = predicted_encoding[0, 0, 0]

    def plot_analysis(values, title):
        # Decode one 5-value encoding and draw its analysis figure over the image.
        field_centre = np.array([values[0], values[1]])
        field_rotation = np.array(values[2] * 90)
        bb_centre = np.array([values[3], values[4]])

        _, axs = pymedphys._wlutz.reporting.image_analysis_figure(
            x, y, np.array(image)[0, :, :, 0],
            bb_centre, field_centre, field_rotation,
            bb_diameter, edge_lengths, penumbra, units=''
        )
        axs[0, 0].set_title(title)

    plot_analysis(ground_truth_encoding, "Ground Truth")
    plot_analysis(predicted, "Predicted")

    plt.show()

In [None]:
for image, mask, encoding in test_dataset.take(1):
    generate_images(generator, image, encoding)

In [None]:
# NOTE(review): this import belongs in the top imports cell.
import datetime
log_dir="logs/"

# One TensorBoard run directory per session, named by start time: logs/fit/YYYYmmdd-HHMMSS.
summary_writer = tf.summary.create_file_writer(
  log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

In [None]:
@tf.function
def train_step(input_image, target_mask, target_encoding, epoch):
    """Run one generator update and one discriminator update on a single example.

    Args:
        input_image: (1, 128, 128, 1) image batch.
        target_mask: (1, 128, 128, 2) ground-truth mask.
        target_encoding: ground-truth 5-element encoding.
        epoch: current epoch. Kept for backward compatibility; the summaries are
            logged against the optimizer's step counter instead (see below).
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)

        # NOTE(review): `decode_batch` builds its masks from comparisons (<=, >=),
        # which have no gradient. The adversarial term therefore most likely gives
        # the generator no gradient, and only the L1 encoding loss trains it — confirm.
        gen_output_image = decode_batch(gen_output)

        disc_real_output = discriminator([input_image, target_mask], training=True)
        disc_generated_output = discriminator([input_image, gen_output_image], training=True)

        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target_encoding)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    generator_gradients = gen_tape.gradient(gen_total_loss,
                                          generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss,
                                               discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))

    # Log against the global training step. Logging at `step=epoch` wrote every
    # example of an epoch to the same TensorBoard step.
    step = generator_optimizer.iterations
    with summary_writer.as_default():
        tf.summary.scalar('gen_total_loss', gen_total_loss, step=step)
        tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=step)
        tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=step)
        tf.summary.scalar('disc_loss', disc_loss, step=step)

In [None]:
def fit(train_ds, epochs, test_ds):
    """Train for `epochs` epochs, plotting one test prediction at the start of each.

    A checkpoint is written every 20 epochs and once more after the last epoch.
    """
    for epoch in range(epochs):
        start = time.time()

        IPython.display.clear_output(wait=True)

        for example_input, example_target_image, example_target_encoding in test_ds.take(1):
            generate_images(generator, example_input, example_target_encoding)
        print("Epoch: ", epoch)

        # Pass the epoch as a tensor. `train_step` is a tf.function, and a new
        # Python int each epoch would force it to be re-traced every epoch.
        epoch_tensor = tf.convert_to_tensor(epoch, dtype=tf.int64)

        # Train
        for n, (input_image, target_mask, target_encoding) in train_ds.enumerate():
            print('.', end='')
            if (n+1) % 100 == 0:
                print()
            train_step(input_image, target_mask, target_encoding, epoch_tensor)
        print()

        # saving (checkpoint) the model every 20 epochs
        if (epoch + 1) % 20 == 0:
            checkpoint.save(file_prefix = checkpoint_prefix)

        print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                            time.time()-start))
    checkpoint.save(file_prefix = checkpoint_prefix)

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir {log_dir}

In [None]:
def generator_loss(disc_generated_output, gen_output, target):
    """Generator loss: adversarial BCE plus a LAMBDA-weighted L1 term on the encoding.

    Args:
        disc_generated_output: discriminator logits for the generated mask.
        gen_output: generator output; the encoding is read from ``gen_output[0, 0, 0]``.
        target: ground-truth 5-element encoding.

    The L1 term is the mean absolute error over the four position values and
    one rotation value. The rotation error is taken modulo 180 degrees (the
    rectangle is unchanged by a half turn), folded to at most 90 degrees, and
    divided by 90. Reads the global `LAMBDA`.

    Returns:
        (total_gen_loss, gan_loss, l1_loss)
    """
    predicted = gen_output[0, 0, 0]

    rotation_error = (target[2] * 90 - predicted[2] * 90) % 180
    rotation_error = tf.reduce_min([rotation_error, 180 - rotation_error])
    rotation_term = tf.reshape(rotation_error / 90, (-1,))

    target_positions = tf.concat((target[0:2], target[3::]), axis=0)
    predicted_positions = tf.concat((predicted[0:2], predicted[3::]), axis=-1)
    position_error = tf.abs(target_positions - predicted_positions)

    l1_loss = tf.reduce_mean(tf.concat([position_error, rotation_term], axis=0))
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

    return gan_loss + (LAMBDA * l1_loss), gan_loss, l1_loss

In [None]:
# Weight of the L1 encoding loss relative to the GAN loss (read by `generator_loss`)
# and the number of training epochs passed to `fit`.
LAMBDA = 1/30
EPOCHS = 150

In [None]:
# Train for EPOCHS epochs; checkpoints are written under ./training_checkpoints.
fit(train_dataset, EPOCHS, test_dataset)

In [None]:

# x = np.arange(0, IMG_SIZE)
# y = np.arange(0, IMG_SIZE)

# for image, mask, _ in train_dataset.take(1):


In [None]:
# x = np.arange(0, IMG_SIZE)
# y = np.arange(0, IMG_SIZE)

# for image, mask, encoding in train_dataset.take(10):
#     field_centre, field_rotation, bb_centre = extract_items_from_encoding(encoding)
    
#     fig, axs = pymedphys._wlutz.reporting.image_analysis_figure(
#         x, y, np.array(image)[:,:,0],
#         np.array(bb_centre), np.array(field_centre), np.array(field_rotation),
#         bb_diameter, edge_lengths, penumbra, units=''
#     )

#     plt.contour(x, y, mask[:,:,0], [0], cmap='bwr_r', zorder=20)
#     plt.contour(x, y, mask[:,:,1], [0], cmap='bwr_r', zorder=20)
    
#     plt.show()