In [None]:
import os
import time
import datetime
import warnings

import tqdm

import IPython

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
tf.__version__

In [None]:
# Ensures that any changes made to pymedphys are automatically
# propagated into the notebook without needing a kernel restart.
from IPython.lib.deepreload import reload
%load_ext autoreload
%autoreload 2

In [None]:
# reversed?

In [None]:
# Reset Keras' global state before (re)building the model in this cell.
tf.keras.backend.clear_session()

def convolve_block(filters, kernel, apply_batchnorm=True):
    """Build a small Sequential block: Conv2D -> (BatchNorm) -> LeakyReLU.

    Parameters
    ----------
    filters
        Number of output filters of the convolution.
    kernel
        Kernel size passed to ``Conv2D``.
    apply_batchnorm : bool
        When true a ``BatchNormalization`` layer is inserted between the
        convolution and the activation.

    Returns
    -------
    tf.keras.Sequential
        The assembled block. The convolution uses stride 1 and ``'same'``
        padding, so the spatial size of its input is preserved.
    """
    conv = tf.keras.layers.Conv2D(
        filters,
        kernel,
        strides=1,
        padding='same',
        # Weights drawn from a normal distribution (mean 0, stddev 0.02).
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False,
    )

    stages = [conv]
    if apply_batchnorm:
        stages.append(tf.keras.layers.BatchNormalization())
    stages.append(tf.keras.layers.LeakyReLU())

    return tf.keras.Sequential(stages)


def evacuation_hatch(x):
    """Reduce a feature map to an 8-value vector.

    The input is globally average-pooled over its spatial axes and then
    passed through a fresh ``Dense(8)`` layer. New layers are created on
    every call, so each hatch has its own weights.
    """
    pooled = tf.keras.layers.GlobalAveragePooling2D()(x)
    return tf.keras.layers.Dense(8)(pooled)


def Model():
    """Assemble the parameter-regression network.

    The network takes a single-channel image of arbitrary height and width
    and outputs 8 values. It is made of:

    * seven "down" convolution blocks (the first without batch norm),
    * three residual ``Dense(32)`` stages applied on the channel axis,
    * seven "up" convolution blocks, each added to the matching down-block
      output (deepest first).

    After every stage an ``evacuation_hatch`` produces an 8-value summary;
    all summaries are concatenated and mapped to the final 8 outputs by a
    ``Dense(8)`` layer.

    Returns
    -------
    tf.keras.Model
    """
    n_filters = 32
    kernel_size = 3

    # Layers are created in the same order as they are used (down, fully
    # connected, up) so that Keras' automatic layer naming is stable.
    down_blocks = [convolve_block(n_filters, kernel_size, apply_batchnorm=False)]
    down_blocks += [convolve_block(n_filters, kernel_size) for _ in range(6)]
    fully_connected_blocks = [tf.keras.layers.Dense(32) for _ in range(3)]
    up_blocks = [convolve_block(n_filters, kernel_size) for _ in range(7)]

    inputs = tf.keras.layers.Input(shape=[None, None, 1], batch_size=None)

    skips = []
    evacs = []

    x = inputs
    for down in down_blocks:
        x = down(x)
        skips.append(x)
        evacs.append(evacuation_hatch(x))

    for fc in fully_connected_blocks:
        residual = x
        x = fc(x)
        x = tf.keras.layers.Add()([x, residual])
        evacs.append(evacuation_hatch(x))

    # Pair each up block with the down outputs, deepest first.
    for up, skip in zip(up_blocks, skips[::-1]):
        x = up(x)
        x = tf.keras.layers.Add()([x, skip])
        evacs.append(evacuation_hatch(x))

    merged = tf.keras.layers.Concatenate()(evacs)
    outputs = tf.keras.layers.Dense(8)(merged)

    return tf.keras.Model(inputs=inputs, outputs=outputs)


# Build the network once; later cells refer to this global `model`.
model = Model()

# Render the layer graph (with tensor shapes) as the cell output.
tf.keras.utils.plot_model(model, show_shapes=True, dpi=64)

In [None]:
# Print a textual, layer-by-layer summary of the model.
model.summary()

In [None]:
# model.trainable_variables

In [None]:
from pymedphys._mocks import wlutz, profiles
from pymedphys._wlutz import reporting

In [None]:
def create_single_dataset(grid_size, fixed_vals=None):
    """Generate one synthetic field-plus-BB image and its parameters.

    Parameters are drawn at random unless pinned through ``fixed_vals``.
    A candidate image is rejected and redrawn when the log of the mean
    squared difference between the field with and without the BB is not
    greater than -10.

    Parameters
    ----------
    grid_size : int
        Number of grid points per axis; the image is evaluated on a
        ``grid_size`` x ``grid_size`` grid spanning [-1, 1] on both axes.
    fixed_vals : dict, optional
        May contain any of ``'field_centre'``, ``'field_side_lengths'``,
        ``'field_rotation'`` (degrees), ``'bb_centre'`` and
        ``'bb_diameter'``. Missing keys are randomised.

    Returns
    -------
    parameters : tf.Tensor
        Eight values: field centre (2), field side lengths (2), field
        rotation divided by 180 (1), BB centre (2), BB diameter (1).
    img : tf.Tensor
        The float32 image of the field attenuated by the BB.
    """
    fixed_vals = {} if fixed_vals is None else fixed_vals

    def pinned(key):
        # Fixed values are always converted to float32 tensors.
        return tf.convert_to_tensor(fixed_vals[key], dtype=tf.float32)

    # The evaluation grid only depends on grid_size, so build it once.
    axis = np.linspace(-1, 1, grid_size)
    xx, yy = np.meshgrid(axis, axis)

    while True:
        if 'field_centre' in fixed_vals:
            field_centre = pinned('field_centre')
        else:
            field_centre = tf.random.uniform((2,), -0.5, 0.5)

        if 'field_side_lengths' in fixed_vals:
            field_side_lengths = pinned('field_side_lengths')
        else:
            field_side_lengths = tf.random.uniform((2,), 0.2, 1.5)

        field_penumbra = tf.random.uniform((), 0.05, 0.2).numpy()

        # The rotation is handled in degrees, but stored in the parameter
        # vector as degrees / 180 ("raw").
        if 'field_rotation' in fixed_vals:
            rotation_tensor = pinned('field_rotation')
            field_rotation_raw = rotation_tensor / 180
            field_rotation = rotation_tensor.numpy()
        else:
            field_rotation_raw = tf.random.uniform((), -1, 1).numpy()
            field_rotation = field_rotation_raw * 180

        if 'bb_centre' in fixed_vals:
            bb_centre = pinned('bb_centre')
        else:
            # Offset the BB from the field centre by at most half of the
            # longest field side along each axis.
            random_range = np.max(field_side_lengths.numpy()) / 2
            bb_offset = tf.random.uniform((2,), -random_range, random_range)
            bb_centre = bb_offset + field_centre

        if 'bb_diameter' in fixed_vals:
            bb_diameter = pinned('bb_diameter').numpy()
        else:
            bb_diameter = tf.random.uniform((), 0.05, 0.3).numpy()

        bb_max_attenuation = tf.random.uniform((), 0.2, 0.5).numpy()

        field = profiles.create_rectangular_field_function(
            field_centre, field_side_lengths, field_penumbra, field_rotation
        )
        bb_penumbra = field_penumbra / 3
        bb_attenuation_map = wlutz.create_bb_attenuation_func(
            bb_diameter, bb_penumbra, bb_max_attenuation
        )

        def field_with_bb(grid_x, grid_y):
            shifted_x = grid_x - bb_centre[0]
            shifted_y = grid_y - bb_centre[1]
            return field(grid_x, grid_y) * bb_attenuation_map(shifted_x, shifted_y)

        without_bb = field(xx, yy)
        with_bb = field_with_bb(xx, yy)

        # Warnings (e.g. from taking the log of zero) are silenced here; a
        # non-passing value simply triggers another draw.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            squared_difference = (without_bb - with_bb)**2
            log_mean_sqr_diff = np.log(np.mean(squared_difference))

        if log_mean_sqr_diff > -10:
            break

    parameters = tf.concat(
        [
            field_centre,
            field_side_lengths,
            [field_rotation_raw],
            bb_centre,
            [bb_diameter],
        ],
        0,
    )
    img = tf.convert_to_tensor(with_bb, dtype=tf.float32)

    return parameters, img

# Every entry left commented out is randomised by create_single_dataset;
# uncomment an entry to pin that parameter to the given value.
fixed_vals = {
#     'field_centre': [0, 0],
#     'field_side_lengths': [1, 1],
#     'field_rotation': 0,
#     'bb_centre': [0, 0],
#     'bb_diameter': 0.5
}

# Generate a single example on a 128 x 128 grid.
parameters, img = create_single_dataset(128, fixed_vals)


# Quick visual check of the generated image.
plt.figure()
plt.pcolormesh(img)
plt.axis('equal')

In [None]:
# Size of the first axis of the most recently generated `img`.
img.shape[0]

In [None]:
def create_pipeline_dataset(batch_size, fixed_vals=None):
    """Create an endless, batched ``tf.data.Dataset`` of synthetic examples.

    Each pass through the underlying generator yields ``batch_size``
    examples that share one image size, so every batch can be stacked even
    though the image size varies from batch to batch.

    Parameters
    ----------
    batch_size : int
        Number of examples per batch.
    fixed_vals : dict, optional
        Forwarded to ``create_single_dataset``. The extra key
        ``'image_size'`` pins the image size; otherwise an integer is drawn
        uniformly from [32, 128) for each pass.

    Returns
    -------
    tf.data.Dataset
        Yields ``(parameters, images)`` pairs with shapes
        ``(batch_size, 8)`` and ``(batch_size, size, size)``.
    """
    fixed_vals = {} if fixed_vals is None else fixed_vals

    def dataset_generator():
        # One image size per generator pass, i.e. per batch.
        if 'image_size' not in fixed_vals:
            shared_size = tf.random.uniform((), 32, 128, dtype=tf.int32).numpy()
        else:
            shared_size = fixed_vals['image_size']

        for _ in range(batch_size):
            yield create_single_dataset(shared_size, fixed_vals=fixed_vals)

    output_types = (tf.float32, tf.float32)
    output_shapes = (tf.TensorShape([8]), tf.TensorShape([None, None]))

    return (
        tf.data.Dataset.from_generator(
            dataset_generator, output_types, output_shapes
        )
        .repeat()
        .batch(batch_size)
    )

In [None]:
# Preview three batches of two images each; the printed shape shows that the
# image size is shared within a batch.
for parameters, img in create_pipeline_dataset(2).take(3):    
    dim = img.shape
    print(dim)
    # dim[0] is the number of images in the batch.
    for i in range(dim[0]):        
        plt.figure()
        plt.pcolormesh(img[i, :, :])
        plt.axis('equal')
    plt.show()

In [None]:
def extract_parameters(tensor_parameters):
    """Unpack an 8-element parameter vector into a keyword dictionary.

    The layout is: field centre (x, y), field side lengths (a, b), rotation
    stored as degrees / 180, BB centre (x, y), BB diameter. The rotation is
    converted back to degrees.

    Returns
    -------
    dict
        Keys ``'field_centre'``, ``'field_side_lengths'``,
        ``'field_rotation'``, ``'bb_centre'`` and ``'bb_diameter'``.
    """
    (
        centre_x, centre_y,
        side_a, side_b,
        rotation_raw,
        bb_x, bb_y,
        bb_diameter,
    ) = (tensor_parameters[index] for index in range(8))

    return {
        'field_centre': (centre_x, centre_y),
        'field_side_lengths': (side_a, side_b),
        'field_rotation': rotation_raw * 180,
        'bb_centre': (bb_x, bb_y),
        'bb_diameter': bb_diameter,
    }

In [None]:
def create_figure(image, field_centre, field_side_lengths, field_rotation, bb_centre, bb_diameter):
    """Draw the analysis figure for one image and one set of parameters.

    The image is placed on a [-1, 1] grid whose number of points is taken
    from the first image axis and reused for both axes. Returns whatever
    ``reporting.image_analysis_figure`` returns (callers in this notebook
    unpack it as ``fig, axs``).
    """
    # The same coordinate array is used for both x and y.
    grid_axis = np.linspace(-1, 1, image.shape[0])

    return reporting.image_analysis_figure(
        grid_axis,
        grid_axis,
        np.array(image),
        np.array(bb_centre),
        np.array(field_centre),
        np.array(field_rotation),
        bb_diameter,
        field_side_lengths,
        penumbra=0.2,
        units='',
    )

In [None]:
def results_figures(model, batch_images, batch_ground_truth_parameters):
    """Plot ground-truth and predicted parameters for every image in a batch.

    For each image two figures are created with ``create_figure``: one
    overlaying the ground-truth parameters and one overlaying the model's
    prediction.

    Parameters
    ----------
    model
        Callable model; invoked as ``model(batch_images, training=True)``.
    batch_images
        Batch of images, indexed as ``batch_images[i, :, :]``.
    batch_ground_truth_parameters
        Ground-truth parameter vectors, one row per image.

    Notes
    -----
    This function depends only on its arguments. It previously read a
    notebook-global ``dim`` to build coordinate arrays that were never used;
    that hidden-state dependency (and the dead code) has been removed.
    """
    num_images = batch_images.shape[0]

    # training=True is kept from the original notebook so that the
    # predictions shown here are computed the same way as during training.
    batch_predicted_parameters = model(batch_images, training=True)

    for i in range(num_images):
        image = batch_images[i, :, :]

        ground_truth_parameters = extract_parameters(batch_ground_truth_parameters[i, :])
        predicted_parameters = extract_parameters(batch_predicted_parameters[i, :])

        fig, axs = create_figure(image, **ground_truth_parameters)
        axs[0, 0].set_title("Ground Truth")

        fig, axs = create_figure(image, **predicted_parameters)
        axs[0, 0].set_title("Predicted")

        plt.show()

In [None]:
# Every entry left commented out is randomised; uncomment an entry to pin it.
fixed_vals = {
#     'field_centre': [0, 0],
#     'field_side_lengths': [1.5, 1],
#     'field_rotation': 45,
#     'bb_centre': [0, 0],
#     'bb_diameter': 0.4,
#     'image_size': 16,
}


# Take one single-image batch, show the raw image, then show the ground
# truth next to the model's current prediction.
for parameters, img in create_pipeline_dataset(1, fixed_vals).take(1):
    dim = img.shape
    print(dim)
    for i in range(dim[0]):        
        plt.figure()
        plt.pcolormesh(img[i, :, :])
        plt.axis('equal')
    plt.show()
    
    
    results_figures(model, img, parameters)

In [None]:
def flip_parameters(parameters):
    """Return the alternative parameterisation of the same field.

    Columns 2 and 3 (the two field side lengths) are swapped and 0.5 is
    added to column 4 (the rotation stored as degrees / 180, i.e. a 90
    degree turn). All other columns are passed through unchanged.

    ``parameters`` is indexed as ``[batch, column]``.
    """
    field_centre = parameters[:, 0:2]
    side_a = parameters[:, 2:3]
    side_b = parameters[:, 3:4]
    rotation_raw = parameters[:, 4:5]
    bb_values = parameters[:, 5:]

    flipped_columns = [
        field_centre,
        side_b,
        side_a,
        rotation_raw + 0.5,
        bb_values,
    ]
    return tf.concat(flipped_columns, axis=-1)


# Visual check of flip_parameters: the original and the flipped parameters
# are drawn over the same image so the two overlays can be compared.
for parameters, img in create_pipeline_dataset(1).take(1):
    flipped_parameters = flip_parameters(parameters)
    
    print(parameters)
    print(flipped_parameters)
    
    batch_dim = img.shape
    # First axis of the batch is the number of images it contains.
    num_batches = batch_dim[0]
    
    for i in range(num_batches):
        fig, axs = create_figure(img[i,:,:], **extract_parameters(parameters[i,:]))
        axs[0,0].set_title("Original")

        fig, axs = create_figure(img[i,:,:], **extract_parameters(flipped_parameters[i,:]))
        axs[0,0].set_title("Flipped")

In [None]:
def determine_rotation_diff(predicted_parameters, ground_truth_parameters):
    """Smallest rotation difference, in degrees, modulo 180.

    Column 4 of both inputs holds the rotation as degrees / 180. The
    difference is wrapped into [0, 180) and then folded so that the result
    is the smaller of ``d`` and ``180 - d``, one value per batch row.
    """
    rotation_column = 4

    predicted_degrees = predicted_parameters[:, rotation_column] * 180
    truth_degrees = ground_truth_parameters[:, rotation_column] * 180

    wrapped = (predicted_degrees - truth_degrees) % 180
    either_direction = tf.stack([wrapped, 180 - wrapped])

    return tf.reduce_min(either_direction, axis=0)


def adjust_rotation(parameters, degrees):
    """Return a copy of ``parameters`` with the rotation shifted by ``degrees``.

    The rotation lives in column 4 as degrees / 180, so ``degrees / 180`` is
    added to that column; every other column is unchanged.
    """
    before_rotation = parameters[:, 0:4]
    rotation_raw = parameters[:, 4:5]
    after_rotation = parameters[:, 5:]

    shifted_rotation = rotation_raw + degrees/180

    return tf.concat(
        [before_rotation, shifted_rotation, after_rotation], axis=-1
    )


# Sanity checks for determine_rotation_diff. Up to float rounding, the
# printed differences should be 0, 1, 90, 0, 1 and 1 degrees respectively,
# because rotations are compared modulo 180 and folded to the smaller angle.
for parameters, img in create_pipeline_dataset(4).take(1):    
    print(determine_rotation_diff(parameters, parameters))
    print(determine_rotation_diff(parameters, adjust_rotation(parameters, 1)))
    print(determine_rotation_diff(parameters, adjust_rotation(parameters, 90)))
    print(determine_rotation_diff(parameters, adjust_rotation(parameters, 180)))
    print(determine_rotation_diff(parameters, adjust_rotation(parameters, 181)))
    print(determine_rotation_diff(parameters, adjust_rotation(parameters, 179)))

In [None]:
def determine_diff_with_rotation(predicted_parameters, ground_truth_parameters):
    """Per-column difference with a rotation-aware entry for column 4.

    Columns 0-3 and 5 onwards are plain ``predicted - ground_truth``
    differences. Column 4 is the folded rotation difference from
    ``determine_rotation_diff`` scaled back by 1 / 180, so it is never
    negative.
    """
    leading_diff = predicted_parameters[:, 0:4] - ground_truth_parameters[:, 0:4]
    trailing_diff = predicted_parameters[:, 5:] - ground_truth_parameters[:, 5:]

    rotation_degrees = determine_rotation_diff(
        predicted_parameters, ground_truth_parameters
    )
    # Restore the column axis and return to the degrees / 180 scale.
    rotation_diff = tf.expand_dims(rotation_degrees, axis=-1) / 180

    return tf.concat([leading_diff, rotation_diff, trailing_diff], axis=-1)


# Sanity checks for determine_diff_with_rotation on one batch of four.
for parameters, img in create_pipeline_dataset(4).take(1):
    
    flipped_parameters = flip_parameters(parameters)
    # Identical inputs, then the flipped parameterisation as the reference.
    print(determine_diff_with_rotation(parameters, parameters))
    print(determine_diff_with_rotation(parameters, flipped_parameters))
    
    print()
    
    # A uniform offset of every column, then rotation-only offsets (degrees).
    print(determine_diff_with_rotation(parameters, parameters + 1))
    print(determine_diff_with_rotation(parameters, adjust_rotation(parameters, 1)))
    print(determine_diff_with_rotation(parameters, adjust_rotation(parameters, 90)))
    print(determine_diff_with_rotation(parameters, adjust_rotation(parameters, 180)))
    print(determine_diff_with_rotation(parameters, adjust_rotation(parameters, 181)))
    print(determine_diff_with_rotation(parameters, adjust_rotation(parameters, 179)))

In [None]:
def determine_corrected_diff(predicted_parameters, ground_truth_parameters):
    """Difference to the ground truth using the better of two parameterisations.

    The prediction is compared to the ground truth both as given and after
    ``flip_parameters``. For every batch row the candidate with the smaller
    sum of squared differences is returned (the unflipped one on ties, as
    ``tf.argmin`` returns the first minimum).

    Returns
    -------
    tf.Tensor
        One row of differences per batch row.
    """
    flipped_prediction = flip_parameters(predicted_parameters)

    # Axis 0 indexes the candidate: 0 = as predicted, 1 = flipped.
    candidate_diffs = tf.stack([
        determine_diff_with_rotation(predicted_parameters, ground_truth_parameters),
        determine_diff_with_rotation(flipped_prediction, ground_truth_parameters),
    ])

    # Sum of squares per candidate and per batch row.
    candidate_errors = tf.reduce_sum(candidate_diffs**2, axis=-1)

    best_candidate = tf.argmin(candidate_errors, axis=0)
    row_numbers = tf.range(best_candidate.shape[0], dtype=tf.int64)

    # (candidate, row) index pairs select one full row from candidate_diffs.
    selection = tf.stack([best_candidate, row_numbers], axis=-1)

    return tf.gather_nd(candidate_diffs, selection)


# Smoke test of determine_corrected_diff against a uniformly offset reference.
for parameters, img in create_pipeline_dataset(2).take(1):
    flipped_parameters = flip_parameters(parameters)
    print(determine_corrected_diff(parameters, parameters + 0.5))
#     print(determine_corrected_diff(parameters, flipped_parameters))
#     print(determine_corrected_diff(flipped_parameters, parameters))

In [None]:
def loss_function(predicted_parameters, ground_truth_parameters):
    """Element-wise absolute error after flip/rotation correction.

    The result is not reduced: it has one entry per batch row and per
    parameter column, as returned by ``determine_corrected_diff``.
    """
    corrected = determine_corrected_diff(
        predicted_parameters, ground_truth_parameters
    )
    return tf.abs(corrected)


# for parameters, img in create_pipeline_dataset(1).take(1):
#     print(parameters)
#     print(loss_function(parameters, parameters))
#     print(loss_function(parameters, parameters + 1))
    
    
# Smoke test of loss_function on one batch of five examples.
for parameters, img in create_pipeline_dataset(5).take(1):
    print(loss_function(parameters, parameters + 0.5))

In [None]:
# Adam optimiser with its default hyperparameters.
optimizer = tf.keras.optimizers.Adam()

# Checkpoints are written under ./training_checkpoints with the file prefix
# "ckpt" and track both the optimiser state and the model.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 model=model)

In [None]:
# Number of examples per training batch; also fixes the batch dimension of
# the traced training step.
BATCH_SIZE = 10

# Every entry left commented out is randomised; uncomment an entry to pin it.
fixed_vals = {
#     'field_centre': [0, 0],
#     'field_side_lengths': [1, 1],
#     'field_rotation': 0,
# #     'bb_centre': [0, 0],
#     'bb_diameter': 0.4,
# #     'image_size': 32,
}

# The test dataset yields single-image batches; the training dataset yields
# batches of BATCH_SIZE images.
test_dataset = create_pipeline_dataset(1, fixed_vals)
train_dataset = create_pipeline_dataset(BATCH_SIZE, fixed_vals)


@tf.function(experimental_relax_shapes=True)
def train_step(ground_truth_parameters, input_image):
    """Run one optimisation step and return the mean loss per parameter.

    Uses the notebook-level ``model``, ``optimizer`` and ``loss_function``.
    The unreduced loss is differentiated directly, and the returned value is
    that loss averaged over the batch axis (one entry per parameter column).
    """
    with tf.GradientTape() as tape:
        prediction = model(input_image, training=True)
        unreduced_loss = loss_function(prediction, ground_truth_parameters)

    variables = model.trainable_variables
    grads = tape.gradient(unreduced_loss, variables)
    optimizer.apply_gradients(zip(grads, variables))

    return tf.reduce_mean(unreduced_loss, axis=0)
        
        
# Trace train_step once for a fixed batch size while leaving the image
# height and width unspecified (None).
concrete_train_step = train_step.get_concrete_function(
    tf.TensorSpec(shape=[BATCH_SIZE, 8], dtype=tf.float32), 
    tf.TensorSpec(shape=[BATCH_SIZE, None, None], dtype=tf.float32)
)

In [None]:
# def extract_parameters(tensor_parameters):
#     parameters = {
#         'field_centre': (tensor_parameters[0], tensor_parameters[1]),
#         'field_side_lengths': (tensor_parameters[2], tensor_parameters[3]),
#         'field_rotation': tensor_parameters[4] * 180,
#         'bb_centre': (tensor_parameters[5], tensor_parameters[6]),
#         'bb_diameter': tensor_parameters[7]
#     }
    
#     return parameters

In [None]:
def fit(train_ds, epochs, test_ds):
    """Train for ``epochs`` epochs of ten iterations each.

    At the start of every epoch one test batch is plotted with
    ``results_figures``. After every training iteration the eight mean
    per-parameter losses are written as scalars through the notebook-level
    ``summary_writer``. A checkpoint is saved every 20th epoch and once more
    after the final epoch.

    Relies on the notebook-level ``model``, ``concrete_train_step``,
    ``summary_writer``, ``checkpoint`` and ``checkpoint_prefix``.
    """
    # Scalar tag for each column of the reduced loss, in column order.
    scalar_names = (
        'field_centre_x',
        'field_centre_y',
        'field_side_length_a',
        'field_side_length_b',
        'field_rotation',
        'bb_centre_x',
        'bb_centre_y',
        'bb_diameter',
    )
    iters_per_epoch = 10

    # Global step for the summaries; keeps counting across epochs.
    logging_index = tf.convert_to_tensor(0, dtype=tf.int64)

    for epoch in range(epochs):
        start = time.time()

        for test_parameters, test_img in test_ds.take(1):
            results_figures(model, test_img, test_parameters)

        training_batches = tqdm.tqdm(
            train_ds.take(iters_per_epoch), total=iters_per_epoch
        )
        for train_parameters, train_img in training_batches:
            reduced_loss = concrete_train_step(train_parameters, train_img)

            logging_index += 1
            with summary_writer.as_default():
                for column, scalar_name in enumerate(scalar_names):
                    tf.summary.scalar(
                        scalar_name, reduced_loss[column], step=logging_index
                    )

        if (epoch + 1) % 20 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        elapsed = time.time()-start
        print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1, elapsed))

    checkpoint.save(file_prefix=checkpoint_prefix)

```bash
poetry run tensorboard --logdir examples/site-specific/cancer-care-associates/production/Winston\ Lutz/prototyping/tf_model/logs/
```

In [None]:
# Number of epochs passed to fit().
EPOCHS = 150

# Each run writes its summaries to its own timestamped directory under
# logs/fit/.
log_dir="logs/"
summary_writer = tf.summary.create_file_writer(
  log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

# Start training.
fit(train_dataset, EPOCHS, test_dataset)