In [None]:
import os
import time
import datetime
import warnings

import tqdm

import IPython

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
tf.__version__

In [None]:
# Automatically reload edited modules (e.g. pymedphys) before each cell runs,
# so source changes take effect without restarting the kernel.
from IPython.lib.deepreload import reload
%load_ext autoreload
%autoreload 2

In [None]:
# Reset Keras' global state (e.g. auto-generated layer names) so that
# re-running this cell rebuilds the model from a clean slate.
tf.keras.backend.clear_session()

def convolve_block(filters, apply_batchnorm=True):
    """Return a ``Conv2D -> [BatchNormalization] -> LeakyReLU`` Sequential block.

    The convolution is 3x3, stride 1, 'same' padding and bias-free, with
    weights drawn from N(0, 0.02).

    Args:
        filters: number of output channels of the convolution.
        apply_batchnorm: insert a BatchNormalization layer between the
            convolution and the activation when True.

    Returns:
        A ``tf.keras.Sequential`` containing the block's layers.
    """
    block_layers = [
        tf.keras.layers.Conv2D(
            filters,
            3,
            strides=1,
            padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        )
    ]
    if apply_batchnorm:
        block_layers.append(tf.keras.layers.BatchNormalization())
    block_layers.append(tf.keras.layers.LeakyReLU())

    return tf.keras.Sequential(block_layers)


def Model():
    """Build the parameter-regression network.

    Seven ``convolve_block`` layers (64, 128, 256, then four of 512 filters;
    only the first has no batch norm) feed a global average pooling layer
    and a Dense(8) head. The 8 outputs line up with the 8-element parameter
    vector built in ``create_single_dataset``.

    The input is a single-channel image of arbitrary height and width
    (``shape=[None, None, 1]``).

    NOTE(review): every convolution is stride 1 and is followed by global
    average pooling, so spatial position is largely averaged away. Regressing
    absolute positions (field/BB centres) may be hard for this architecture;
    worth confirming.
    """
    filter_counts = [64, 128, 256, 512, 512, 512, 512]
    feature_extractor = [
        convolve_block(count, apply_batchnorm=(position != 0))
        for position, count in enumerate(filter_counts)
    ]
    regression_head = [
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(8),
    ]

    inputs = tf.keras.layers.Input(shape=[None, None, 1], batch_size=None)
    outputs = inputs
    for layer in feature_extractor + regression_head:
        outputs = layer(outputs)

    return tf.keras.Model(inputs=inputs, outputs=outputs)


model = Model()

tf.keras.utils.plot_model(model, show_shapes=True, dpi=64)

In [None]:
# model.trainable_variables

In [None]:
from pymedphys._mocks import wlutz, profiles
from pymedphys._wlutz import reporting

In [None]:
def create_single_dataset(grid_size, min_log_mean_sqr_diff=-10):
    """Generate one random synthetic field + BB image and its parameter vector.

    A rectangular field with a BB attenuation is rendered on a
    ``grid_size`` x ``grid_size`` mesh spanning [-1, 1] on both axes. Random
    draws for which the BB barely changes the image are rejected and redrawn.

    Parameters
    ----------
    grid_size : int
        Number of sample points along each axis of the mesh.
    min_log_mean_sqr_diff : float, optional
        Rejection threshold. A draw is accepted only when
        ``log(mean((without_bb - with_bb) ** 2))`` is strictly greater than
        this value. The default of -10 matches the original hard-coded value.

    Returns
    -------
    parameters : tf.Tensor
        float32 vector of 8 values, in this order:
        field centre (2), field side lengths (2), raw rotation in [-1, 1),
        BB centre (2), BB diameter.
        The rotation given to the field function is ``raw rotation * 180``.
    img : tf.Tensor
        float32 image evaluated on the mesh (the BB-attenuated field).
    """
    # The sampling mesh does not depend on the random draws, so build it once
    # rather than on every pass of the rejection loop.
    axis = np.linspace(-1, 1, grid_size)
    xx, yy = np.meshgrid(axis, axis)

    while True:
        field_centre = tf.random.uniform((2,), -0.5, 0.5)
        field_side_lengths = tf.random.uniform((2,), 0.2, 1.5)
        field_penumbra = tf.random.uniform((), 0.05, 0.2).numpy()
        field_rotation_raw = tf.random.uniform((), -1, 1).numpy()
        field_rotation = field_rotation_raw * 180

        bb_centre = tf.random.uniform((2,), -0.5, 0.5)
        bb_diameter = tf.random.uniform((), 0.05, 0.3).numpy()
        bb_max_attenuation = tf.random.uniform((), 0.1, 0.5).numpy()

        field = profiles.create_rectangular_field_function(
            field_centre, field_side_lengths, field_penumbra, field_rotation
        )
        bb_penumbra = field_penumbra / 3
        bb_attenuation_map = wlutz.create_bb_attenuation_func(
            bb_diameter, bb_penumbra, bb_max_attenuation
        )

        without_bb = field(xx, yy)
        # Reuse the bare field instead of evaluating ``field`` a second time.
        with_bb = without_bb * bb_attenuation_map(xx - bb_centre[0], yy - bb_centre[1])

        # If the BB leaves the image unchanged the mean squared difference is 0
        # and np.log warns (-inf). Such draws are rejected just below.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            log_mean_sqr_diff = np.log(np.mean((without_bb - with_bb) ** 2))

        if log_mean_sqr_diff > min_log_mean_sqr_diff:
            break

    parameters = tf.concat(
        [field_centre, field_side_lengths, [field_rotation_raw], bb_centre, [bb_diameter]],
        0,
    )
    img = tf.convert_to_tensor(with_bb, dtype=tf.float32)

    return parameters, img

# Render one 128x128 example to eyeball the generator output.
parameters, img = create_single_dataset(128)

fig, ax = plt.subplots()
ax.pcolormesh(img)
ax.axis('equal')

In [None]:
img.shape[0]

In [None]:
def create_pipeline_dataset(batch_size):
    """Build an infinite ``tf.data`` pipeline of batched synthetic samples.

    Each generator run yields exactly ``batch_size`` samples that share one
    grid size, drawn uniformly from [32, 128). The generator is repeated
    forever and the output is batched by ``batch_size``, so each batch is one
    generator run. Its images therefore stack cleanly, while different batches
    generally differ in size.

    Elements are ``(parameters, image)`` pairs with shapes ``[8]`` and
    ``[None, None]``. After batching they become ``[batch_size, 8]`` and
    ``[batch_size, H, W]``.
    """
    def _samples_for_one_batch():
        # One size per generator run, i.e. per batch.
        shared_grid_size = tf.random.uniform((), 32, 128, dtype=tf.int32).numpy()
        yield from (create_single_dataset(shared_grid_size) for _ in range(batch_size))

    element_types = (tf.float32, tf.float32)
    element_shapes = (tf.TensorShape([8]), tf.TensorShape([None, None]))
    samples = tf.data.Dataset.from_generator(
        _samples_for_one_batch,
        element_types,
        element_shapes,
    )

    return samples.repeat().batch(batch_size)

In [None]:
test_dataset = create_pipeline_dataset(2)

# Preview three batches. Every image within a batch shares one size.
for parameters, img in test_dataset.take(3):
    dim = img.shape
    print(dim)
    for single_image in img:
        plt.figure()
        plt.pcolormesh(single_image)
        plt.axis('equal')
    plt.show()

In [None]:
def extract_parameters(tensor_parameters):
    """Unpack an 8-element parameter vector into ``create_figure`` keyword args.

    Layout: [field_centre x, y, field side lengths a, b, raw rotation,
    bb_centre x, y, bb_diameter]. The raw rotation is multiplied by 180.
    """
    values = tensor_parameters
    return dict(
        field_centre=(values[0], values[1]),
        field_side_lengths=(values[2], values[3]),
        field_rotation=values[4] * 180,
        bb_centre=(values[5], values[6]),
        bb_diameter=values[7],
    )

In [None]:
def create_figure(image, field_centre, field_side_lengths, field_rotation, bb_centre, bb_diameter):
    """Draw ``image`` with the given field and BB overlaid via pymedphys reporting.

    Both axes use ``image.shape[0]`` points spanning [-1, 1], so the image is
    assumed square. That matches the mesh built in ``create_single_dataset``.
    The penumbra passed to the report is fixed at 0.2.

    Returns whatever ``reporting.image_analysis_figure`` returns. Callers in
    this notebook unpack it as ``fig, axs``.
    """
    axis_coords = np.linspace(-1, 1, image.shape[0])

    return reporting.image_analysis_figure(
        axis_coords,
        axis_coords,
        np.array(image),
        np.array(bb_centre),
        np.array(field_centre),
        np.array(field_rotation),
        bb_diameter,
        field_side_lengths,
        penumbra=0.2,
        units='',
    )

In [None]:
def results_figures(model, batch_images, batch_ground_truth_parameters):
    """Show ground-truth vs predicted overlays for every image in a batch.

    For each image two figures are drawn with ``create_figure``: one titled
    "Ground Truth" and one titled "Predicted".

    Parameters
    ----------
    model : tf.keras.Model
        Network mapping a batch of images to a batch of 8-element parameter
        vectors.
    batch_images : tf.Tensor
        Batch of 2-D images, indexed as ``batch_images[i, :, :]``.
    batch_ground_truth_parameters : tf.Tensor
        Matching batch of 8-element parameter vectors.
    """
    # training=True matches how train_step calls the model, so batch norm uses
    # the current batch's statistics.
    # NOTE(review): in Keras this also updates the BatchNormalization moving
    # statistics with test batches. Consider training=False once those
    # statistics have converged.
    batch_predicted_parameters = model(batch_images, training=True)

    # The batch size comes from the argument itself, not from a global left
    # behind by an earlier cell.
    batch_size = batch_images.shape[0]

    for i in range(batch_size):
        ground_truth_parameters = extract_parameters(batch_ground_truth_parameters[i, :])
        predicted_parameters = extract_parameters(batch_predicted_parameters[i, :])

        _, axs = create_figure(batch_images[i, :, :], **ground_truth_parameters)
        axs[0, 0].set_title("Ground Truth")

        _, axs = create_figure(batch_images[i, :, :], **predicted_parameters)
        axs[0, 0].set_title("Predicted")

        plt.show()

In [None]:
for parameters, img in test_dataset.take(1):
    results_figures(model, img, parameters)

In [None]:
# def extract_parameters(tensor_parameters):
#     parameters = {
#         'field_centre': (tensor_parameters[0], tensor_parameters[1]),
#         'field_side_lengths': (tensor_parameters[2], tensor_parameters[3]),
#         'field_rotation': tensor_parameters[4] * 180,
#         'bb_centre': (tensor_parameters[5], tensor_parameters[6]),
#         'bb_diameter': tensor_parameters[7]
#     }
    
#     return parameters

In [None]:
def determine_diff_rotation_and_flipped(predicted_parameters, ground_truth_parameters):
    """Per-sample rotation + side-length error, invariant to rectangle symmetries.

    Rotating a rectangle by 180 leaves it unchanged, and swapping its two side
    lengths is equivalent to a 90 rotation. The error is therefore computed
    against both the ground truth as given and the side-swapped ground truth
    (with rotation offset by 90), and the smaller of the two is kept.

    Both inputs are indexed as ``[:, column]``. Column 4 is the raw rotation
    (scaled by 180 here) and columns 2 and 3 are the side lengths.

    Returns a per-sample tensor:
    ``min(angle_error + L1 side error, swapped_angle_error + swapped L1 side error)``.
    Each angle error is folded into [0, 90].

    NOTE(review): the angle term is in units of ``raw * 180`` (up to 90),
    while side lengths are O(1). The rotation term therefore dominates this
    sum; confirm that weighting is intended.
    """
    def _fold_angle(angle):
        # Distance to the nearest multiple of 180, in [0, 90].
        wrapped = angle % 180
        return tf.reduce_min(tf.stack([wrapped, 180 - wrapped]), axis=0)

    predicted_rotations = predicted_parameters[:, 4] * 180
    ground_truth_rotations = ground_truth_parameters[:, 4] * 180
    rotation_delta = predicted_rotations - ground_truth_rotations

    # Shape (2, batch): row 0 is side a, row 1 is side b.
    predicted_sides = tf.stack([predicted_parameters[:, 2], predicted_parameters[:, 3]])
    ground_truth_sides = tf.stack([ground_truth_parameters[:, 2], ground_truth_parameters[:, 3]])
    swapped_ground_truth_sides = tf.reverse(ground_truth_sides, [0])

    def _side_error(reference_sides):
        return tf.reduce_sum(tf.abs(predicted_sides - reference_sides), axis=0)

    as_given_error = _fold_angle(rotation_delta) + _side_error(ground_truth_sides)
    swapped_error = _fold_angle(rotation_delta + 90) + _side_error(swapped_ground_truth_sides)

    return tf.reduce_min(tf.stack([as_given_error, swapped_error]), axis=0)

In [None]:
def cost_function(predicted_parameters, ground_truth_parameters):
    """Per-sample loss: mean of a symmetric shape term and absolute positional errors.

    The shape term (rotation + side lengths, columns 2-4) comes from
    ``determine_diff_rotation_and_flipped``. The positional terms are the
    absolute errors of columns 0-1 (field centre) and 5 onwards (BB centre,
    BB diameter). For the 8-column layout that is six terms, averaged per
    sample.
    """
    shape_error = determine_diff_rotation_and_flipped(
        predicted_parameters, ground_truth_parameters)

    def _positional_columns(parameters):
        return tf.concat([parameters[:, 0:2], parameters[:, 5:]], axis=-1)

    positional_error = tf.abs(
        _positional_columns(predicted_parameters) - _positional_columns(ground_truth_parameters)
    )

    all_terms = tf.concat([shape_error[:, None], positional_error], axis=-1)
    return tf.reduce_mean(all_terms, axis=-1)

In [None]:
# Sanity check: a batch compared with itself must give exactly zero loss,
# including through the rotation / side-swap folding.
for parameters, img in test_dataset.take(1):
    print(parameters)
    self_loss = cost_function(parameters, parameters)
    print(self_loss)
    np.testing.assert_allclose(self_loss.numpy(), 0.0, atol=1e-6)

In [None]:
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 model=model)

In [None]:
log_dir="logs/"

summary_writer = tf.summary.create_file_writer(
  log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

In [None]:
@tf.function(experimental_relax_shapes=True)
def train_step(ground_truth_parameters, input_image, epoch):
    """Run one optimisation step on a batch and log its mean loss to TensorBoard.

    Uses the notebook-level ``model``, ``optimizer`` and ``summary_writer``.
    ``epoch`` is used as the summary step.

    NOTE(review): every iteration within an epoch logs at the same step. If
    ``epoch`` is passed as a Python int, tf.function retraces for each new
    value; pass a tensor instead.
    """
    with tf.GradientTape() as tape:
        per_sample_loss = cost_function(model(input_image, training=True), ground_truth_parameters)

    variables = model.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(per_sample_loss, variables), variables))

    mean_loss = tf.reduce_mean(per_sample_loss)
    with summary_writer.as_default():
        tf.summary.scalar('loss', mean_loss, step=epoch)

In [None]:
# IPython.display.clear_output?

In [None]:
def fit(train_ds, epochs, test_ds, iters_per_epoch=10, checkpoint_every=20):
    """Train the notebook-level ``model``, plotting and checkpointing along the way.

    Each epoch does three things:

    1. Plots predictions for one batch of ``test_ds`` via ``results_figures``.
    2. Runs ``iters_per_epoch`` calls of ``train_step`` on ``train_ds``.
    3. Saves ``checkpoint`` every ``checkpoint_every`` epochs.

    A final checkpoint is always saved after the last epoch.

    Parameters
    ----------
    train_ds, test_ds : tf.data.Dataset
        Datasets yielding ``(parameters, images)`` batches.
    epochs : int
        Number of epochs to run.
    iters_per_epoch : int, optional
        Training steps per epoch (default 10, the original hard-coded value).
    checkpoint_every : int, optional
        Epoch interval between periodic checkpoints (default 20). A value of 0
        or less disables periodic checkpoints; the final one is still saved.
    """
    for epoch in range(epochs):
        start = time.time()

        for parameters, img in test_ds.take(1):
            results_figures(model, img, parameters)

        # Pass the epoch as a tensor. A Python int argument makes the
        # tf.function ``train_step`` retrace for every new epoch value.
        epoch_step = tf.constant(epoch, dtype=tf.int64)
        for parameters, img in tqdm.tqdm(train_ds.take(iters_per_epoch), total=iters_per_epoch):
            train_step(parameters, img, epoch_step)

        if checkpoint_every > 0 and (epoch + 1) % checkpoint_every == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                            time.time()-start))
    checkpoint.save(file_prefix=checkpoint_prefix)

In [None]:
# tqdm.tqdm?

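To monitor training, point TensorBoard at this notebook's `logs/` directory (path given relative to the repository root):

```bash
poetry run tensorboard --logdir examples/site-specific/cancer-care-associates/production/Winston\ Lutz/prototyping/tf_model/logs/
```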

In [None]:
# LAMBDA = 1/30
EPOCHS = 150

In [None]:
test_dataset = create_pipeline_dataset(1)
train_dataset = create_pipeline_dataset(10)
fit(train_dataset, EPOCHS, test_dataset)