In [None]:
import os
import time
import datetime
import warnings

import tqdm

import IPython

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
tf.__version__

The following references may be helpful:

- https://github.com/tensorflow/tensorflow/issues/6271#issuecomment-266893850
- https://arxiv.org/pdf/1807.03146.pdf


In [None]:
# Makes it so any changes in pymedphys is automatically
# propagated into the notebook without needing a kernel reset.
from IPython.lib.deepreload import reload
%load_ext autoreload
%autoreload 2

In [None]:
# Side length, in pixels, of the square synthetic images fed to the model.
GRID_SIZE = 2 ** 10
# Pixel density: the grid covers GRID_SIZE / PIXELS_PER_UNIT units per side
# (pixel edge to pixel edge).
PIXELS_PER_UNIT = 2 ** 9

In [None]:
# reverse

In [None]:
# Reset Keras global state so re-running this cell rebuilds the model from a
# clean slate.
tf.keras.backend.clear_session()

# Shared initializer for the trainable convolution kernels: N(mean=0, std=0.02).
initializer = tf.random_normal_initializer(mean=0.0, stddev=0.02)


def down_block(x, depth, num_convs, channels, pool, batch_norm=True):
    """Encoder block: residual conv stack with a fixed 1x1 shortcut, then optional pooling.

    Args:
        x: input tensor (Keras functional API).
        depth: index used only to name the inner Sequential models.
        num_convs: number of 3x3 convolutions in the stack.
        channels: output channel count of every convolution.
        pool: average-pooling factor applied after the residual sum; 0 skips
            pooling.
        batch_norm: insert BatchNormalization before each ReLU when True.

    Returns:
        ``(x, unet_short_circuit)``: the (possibly pooled) output, and the
        pre-pooling residual sum kept for the matching decoder skip connection.
    """
    def activation_layers():
        # Pre-activation ordering: [BatchNorm] -> ReLU ahead of each conv.
        acts = [tf.keras.layers.BatchNormalization()] if batch_norm else []
        acts.append(tf.keras.layers.ReLU())
        return acts

    conv_layers = activation_layers()
    for conv_index in range(num_convs):
        conv_layers.append(
            tf.keras.layers.Conv2D(
                channels, (3, 3), strides=1, padding='same',
                kernel_initializer=initializer, use_bias=False)
        )
        # No activation after the final conv: the residual sum follows.
        if conv_index < num_convs - 1:
            conv_layers.extend(activation_layers())
    convolution_sequence = tf.keras.Sequential(
        conv_layers, name=f'down-convolution-d{depth}')

    # NOTE(review): the 1x1 shortcut kernel is frozen at all-ones, so every
    # output channel is the sum of all input channels rather than a copy of
    # one; magnitudes can grow with the input channel count. Confirm intended.
    short_circuit_sequence = tf.keras.Sequential(
        [
            tf.keras.layers.Conv2D(
                channels, (1, 1), strides=1, padding='same',
                kernel_initializer=tf.ones_initializer(),
                use_bias=False, trainable=False)
        ],
        name=f'down-short-circuit-d{depth}')

    residual_sum = tf.keras.layers.Add()(
        [convolution_sequence(x), short_circuit_sequence(x)]
    )

    if pool == 0:
        return residual_sum, residual_sum

    pooled = tf.keras.layers.AveragePooling2D(
        (pool, pool), strides=None, padding='valid')(residual_sum)
    return pooled, residual_sum
    
    
def fully_connected_block(x, input_size, internal_channels, output_channels):
    """Dense bottleneck between the encoder and decoder.

    An ``(input_size, input_size)`` convolution with 'valid' padding collapses
    the feature map. Two pre-activation residual Dense layers follow, then a
    final Dense layer whose output is reshaped to
    ``(input_size, input_size, output_channels)``.

    Args:
        x: encoder output; assumed to be input_size x input_size spatially so
            the collapsing conv yields a single position (TODO confirm for
            non-default configurations).
        input_size: spatial side length of ``x`` and of the returned tensor.
        internal_channels: width of the residual Dense layers.
        output_channels: channel count of the returned tensor.
    """
    def bn_relu(tensor):
        # Pre-activation: BatchNorm then ReLU ahead of each Dense layer.
        normalised = tf.keras.layers.BatchNormalization()(tensor)
        return tf.keras.layers.ReLU()(normalised)

    hidden = tf.keras.layers.Conv2D(
        internal_channels,
        kernel_size=(input_size, input_size),
        strides=1,
        padding='valid',
        kernel_initializer=initializer,
        use_bias=False,
    )(x)

    for _ in range(2):
        activated = bn_relu(hidden)
        residual = tf.keras.layers.Dense(internal_channels)(activated)
        hidden = tf.keras.layers.Add()([residual, hidden])

    activated = bn_relu(hidden)
    expanded = tf.keras.layers.Dense(
        input_size * input_size * output_channels)(activated)
    return tf.keras.layers.Reshape(
        (input_size, input_size, output_channels))(expanded)
    

def up_block(x, unet_short_circuit, depth, num_convs, channels, up_scale):
    """Decoder block: optional upsample, add the skip, then a residual conv stack.

    Args:
        x: decoder input tensor.
        unet_short_circuit: matching encoder tensor; must have the same shape
            as ``x`` after upsampling for the Add to succeed.
        depth: index used only to name the inner Sequential models.
        num_convs: number of 3x3 convolutions in the stack.
        channels: output channel count of every convolution.
        up_scale: nearest-neighbour upsampling factor; 0 skips upsampling.

    Returns:
        The residual sum of the conv stack and a fixed 1x1 shortcut.
    """
    upsampled = x
    if up_scale != 0:
        upsampled = tf.keras.layers.UpSampling2D(size=(up_scale, up_scale))(upsampled)

    merged = tf.keras.layers.Add()([upsampled, unet_short_circuit])

    def bn_relu_layers():
        return [tf.keras.layers.BatchNormalization(), tf.keras.layers.ReLU()]

    conv_layers = bn_relu_layers()
    for conv_index in range(num_convs):
        conv_layers.append(
            tf.keras.layers.Conv2D(
                channels, (3, 3), strides=1, padding='same',
                kernel_initializer=initializer, use_bias=False)
        )
        # No activation after the final conv: the residual sum follows.
        if conv_index < num_convs - 1:
            conv_layers.extend(bn_relu_layers())
    convolution_sequence = tf.keras.Sequential(
        conv_layers, name=f'up-convolution-d{depth}')

    # NOTE(review): frozen all-ones 1x1 kernel sums all input channels into
    # each output channel instead of acting as an identity. Confirm intended.
    internal_short_circuit = tf.keras.Sequential(
        [
            tf.keras.layers.Conv2D(
                channels, (1, 1), strides=1, padding='same',
                kernel_initializer=tf.ones_initializer(),
                use_bias=False, trainable=False)
        ],
        name=f'up-short-circuit-d{depth}')

    return tf.keras.layers.Add()(
        [convolution_sequence(merged), internal_short_circuit(merged)]
    )
    
    
def Model(grid_size=GRID_SIZE):
    """Build the encoder -> fully-connected bottleneck -> decoder network.

    The input is a single-channel ``grid_size`` x ``grid_size`` image. The
    output has the same spatial size with 2 channels (the last decoder block
    uses 2 channels).

    Args:
        grid_size: side length of the square input. Must be divisible by the
            encoder's total pooling factor (2 * 4 * 4 * 4 = 128 with the block
            parameters below).

    Raises:
        ValueError: if ``grid_size`` is not divisible by the total pooling
            factor.
    """
    # Each entry is (depth, (num_convs, channels, pool)); pool 0 = no pooling.
    # Shape comments assume the default grid_size of 1024.
    down_block_params = [
        (0, (3, 12, 2)),  # BS, 1024, 1024,  1 --> BS, 512, 512, 12
        (1, (3, 12, 4)),  # BS,  512,  512, 12 --> BS, 128, 128, 12
        (2, (3, 12, 4)),  # BS,  128,  128, 12 --> BS,  32,  32, 12
        (3, (3, 12, 4)),  # BS,   32,   32, 12 --> BS,   8,   8, 12
        (4, (4, 24, 0)),  # BS,    8,    8, 12 --> BS,   8,   8, 24
    ]
    # Each entry is (depth, (num_convs, channels, up_scale)); up_scale 0 = no
    # upsampling. Scales mirror the encoder pools so skip connections line up.
    up_block_params = [
        (4, (4, 12, 0)),
        (3, (4, 12, 4)),
        (2, (3, 12, 4)),
        (1, (3, 12, 4)),
        (0, (3,  2, 2)),
    ]

    # Derive the bottleneck side length from grid_size instead of hard-coding
    # 8, so grid sizes other than 1024 produce a consistent network.
    total_pooling = 1
    for _, (_, _, pool) in down_block_params:
        if pool != 0:
            total_pooling *= pool
    if grid_size % total_pooling != 0:
        raise ValueError(
            f"grid_size ({grid_size}) must be divisible by the total pooling "
            f"factor ({total_pooling})"
        )
    bottleneck_size = grid_size // total_pooling
    fully_connected_params = (bottleneck_size, 96, 24)

    inputs = tf.keras.layers.Input(shape=[grid_size, grid_size, 1], batch_size=None)
    x = inputs

    unet_short_circuits = []
    for depth, down_block_param in down_block_params:
        x, unet_short_circuit = down_block(x, depth, *down_block_param)
        unet_short_circuits.append(unet_short_circuit)

    x = fully_connected_block(x, *fully_connected_params)

    # The decoder consumes the skip connections deepest-first.
    for unet_short_circuit, (depth, up_block_param) in zip(
            reversed(unet_short_circuits), up_block_params):
        x = up_block(x, unet_short_circuit, depth, *up_block_param)

    return tf.keras.Model(inputs=inputs, outputs=x)


model = Model()

# NOTE(review): with Adam, a learning rate of 1e-10 makes weight updates
# negligible. Kept unchanged; confirm it is intentional before long runs.
LEARNING_RATE = 1e-10

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=tf.keras.losses.MeanAbsoluteError(),
    # Targets are continuous per-pixel offsets. 'accuracy' would resolve to a
    # classification metric (categorical accuracy) that is meaningless here,
    # so report regression metrics instead.
    metrics=[
        tf.keras.metrics.MeanAbsoluteError(),
        tf.keras.metrics.MeanSquaredError(),
    ],
)

tf.keras.utils.plot_model(model, show_shapes=True, dpi=64)

In [None]:
# Print the layer-by-layer summary of the compiled model.
model.summary()

In [None]:
from pymedphys._mocks import wlutz, profiles
from pymedphys._wlutz import reporting, interppoints

In [None]:
def create_single_dataset(grid_size, pixels_per_unit, include_params=True):
    """Generate one synthetic Winston-Lutz style sample and its target field.

    A rectangular field with random centre, side lengths, penumbra and rotation
    is rendered with a randomly placed attenuating BB. The rendering uses a
    ``grid_size`` x ``grid_size`` grid of pixel centres spaced
    ``1 / pixels_per_unit`` apart and centred on the origin.

    All random draws use ``np.random``, so a single ``np.random.seed`` call
    makes a sample reproducible.

    Args:
        grid_size: number of pixels per side.
        pixels_per_unit: pixel density of the coordinate grid.
        include_params: also return the generating parameters, and keep the
            coordinate channels in the returned input.

    Returns:
        If ``include_params``: ``(model_input, model_output, parameters)``.
            - ``model_input``: float32 ``(grid_size, grid_size, 3)`` holding
              the x coordinates, y coordinates and image.
            - ``model_output``: ``(grid_size, grid_size, 2)`` holding, per
              pixel, ``bb_centre - coordinate``.
            - ``parameters``: length-8 vector ``[field_centre (2),
              field_side_lengths (2), field_rotation, bb_centre (2),
              bb_diameter]``.
        Otherwise: ``(image, model_output)``, where ``image`` is
        ``(grid_size, grid_size, 1)``.
    """
    # Coordinate of the outermost pixel centre (grid units, not necessarily mm).
    half_extent = (grid_size / 2 - 0.5) / pixels_per_unit

    field_centre = np.random.uniform(-0.2, 0.2, size=2)
    # Log-uniform side lengths in [0.15, 1.5).
    field_side_lengths = np.exp(np.random.uniform(np.log(0.15), np.log(1.5), size=2))
    field_penumbra = np.random.uniform(0.01, 0.05)
    # Presumably degrees, given the [-180, 180) range.
    field_rotation = np.random.uniform(-180, 180)

    transform = interppoints.translate_and_rotate_transform(field_centre, field_rotation)

    # Place the BB uniformly within the un-transformed field, then map it into
    # image coordinates with the field's own translation/rotation.
    bb_centre_before_transform = [
        np.random.uniform(-field_side_lengths[0] / 2, field_side_lengths[0] / 2),
        np.random.uniform(-field_side_lengths[1] / 2, field_side_lengths[1] / 2),
    ]
    bb_centre = interppoints.apply_transform(*bb_centre_before_transform, transform)

    bb_diameter = np.random.uniform(0.005, 0.1)
    bb_max_attenuation = np.random.uniform(0.1, 0.5)

    field = profiles.create_rectangular_field_function(
        field_centre, field_side_lengths, field_penumbra, field_rotation
    )
    bb_penumbra = field_penumbra / 3
    bb_attenuation_map = wlutz.create_bb_attenuation_func(
        bb_diameter, bb_penumbra, bb_max_attenuation
    )

    axis = np.linspace(-half_extent, half_extent, grid_size)
    xx, yy = np.meshgrid(axis, axis)

    # Evaluate the field once on the full grid and attenuate it by the BB.
    img = field(xx, yy) * bb_attenuation_map(xx - bb_centre[0], yy - bb_centre[1])

    parameters = tf.concat(
        [field_centre, field_side_lengths, [field_rotation], bb_centre, [bb_diameter]], 0
    )

    xx = tf.convert_to_tensor(xx, dtype=tf.float32)
    yy = tf.convert_to_tensor(yy, dtype=tf.float32)
    img = tf.convert_to_tensor(img, dtype=tf.float32)

    model_input = tf.stack([xx, yy, img], axis=-1)

    # Adding this offset to the coordinate channels recovers the BB centre at
    # every pixel.
    model_output = tf.stack([bb_centre[0] - xx, bb_centre[1] - yy], axis=-1)

    if include_params:
        return model_input, model_output, parameters

    return model_input[:, :, 2:], model_output


model_input, model_output, parameters = create_single_dataset(GRID_SIZE, PIXELS_PER_UNIT)

sample_x = model_input[:, :, 0]
sample_y = model_input[:, :, 1]
sample_image = model_input[:, :, 2]

# Raw synthetic image.
image_fig, image_ax = plt.subplots(figsize=(10, 10))
image_ax.pcolormesh(sample_x, sample_y, sample_image)
image_ax.axis('equal')

# Filled contours of the same image, with the zero contours of both target
# channels overlaid; the two lines cross at the BB centre.
contour_fig, contour_ax = plt.subplots(figsize=(10, 10))
contour_ax.contourf(sample_x, sample_y, sample_image, 100)
contour_ax.contour(sample_x, sample_y, model_output[:, :, 0], levels=[0])
contour_ax.contour(sample_x, sample_y, model_output[:, :, 1], levels=[0])
contour_ax.axis('equal')

In [None]:
# Target channel 0: x offset from each pixel to the BB centre.
offset_axes = plt.gca()
x_offset_mesh = offset_axes.pcolormesh(
    model_input[:, :, 0], model_input[:, :, 1], model_output[:, :, 0]
)
plt.colorbar(x_offset_mesh)
offset_axes.axis('equal')

In [None]:
# Sanity check: x coordinate + x offset equals the BB x centre at every pixel,
# so the average must match parameters[5] (bb_centre x).
recovered_bb_x = np.average(model_input[:, :, 0] + model_output[:, :, 0])
np.testing.assert_allclose(recovered_bb_x, parameters[5], atol=1e-5)
recovered_bb_x

In [None]:
# Target channel 1: y offset from each pixel to the BB centre.
offset_axes = plt.gca()
y_offset_mesh = offset_axes.pcolormesh(
    model_input[:, :, 0], model_input[:, :, 1], model_output[:, :, 1]
)
plt.colorbar(y_offset_mesh)
offset_axes.axis('equal')

In [None]:
# Sanity check: y coordinate + y offset equals the BB y centre at every pixel,
# so the average must match parameters[6] (bb_centre y).
recovered_bb_y = np.average(model_input[:, :, 1] + model_output[:, :, 1])
np.testing.assert_allclose(recovered_bb_y, parameters[6], atol=1e-5)
recovered_bb_y

In [None]:
def create_pipeline_dataset(batch_size, grid_size=GRID_SIZE, pixels_per_unit=PIXELS_PER_UNIT, include_params=True):
    """Build an endless, batched ``tf.data`` stream of synthetic samples.

    Each element is produced by ``create_single_dataset``; the declared
    shapes and float32 dtypes match what it returns for the given
    ``include_params`` setting.
    """
    def dataset_generator():
        # Yields one sample per invocation; .repeat() below re-invokes it, so
        # every element is freshly generated.
        yield create_single_dataset(grid_size, pixels_per_unit, include_params=include_params)

    input_channels = 3 if include_params else 1
    output_shapes = (
        tf.TensorShape([grid_size, grid_size, input_channels]),
        tf.TensorShape([grid_size, grid_size, 2]),
    )
    if include_params:
        output_shapes += (tf.TensorShape([8]),)
    output_types = (tf.float32,) * len(output_shapes)

    return (
        tf.data.Dataset.from_generator(dataset_generator, output_types, output_shapes)
        .repeat()
        .batch(batch_size)
    )

In [None]:
def plot_raw_images(model_input, model_output):
    """Plot the image and both target channels for every sample in a batch.

    Each panel is its own figure, drawn against the coordinate channels
    (channels 0 and 1 of ``model_input``). Panels are: the image (input
    channel 2), then output channels 0 and 1.
    """
    batch_size = model_input.shape[0]
    for sample in range(batch_size):
        xs = model_input[sample, :, :, 0]
        ys = model_input[sample, :, :, 1]
        panels = (
            model_input[sample, :, :, 2],
            model_output[sample, :, :, 0],
            model_output[sample, :, :, 1],
        )
        for values in panels:
            plt.figure()
            plt.pcolormesh(xs, ys, values)
            plt.colorbar()
            plt.axis('equal')
    plt.show()


# Preview two single-sample batches from the pipeline. The loop rebinds the
# notebook-level names model_input, model_output and parameters to the last
# batch.
preview_batches = create_pipeline_dataset(1).take(2)
for model_input, model_output, parameters in preview_batches:
    plot_raw_images(model_input, model_output)

In [None]:
def extract_parameters(parameters):
    """Split a flat parameter vector into a dict of named fields.

    Only the first eight entries are read, in this order:
    field_centre (2), field_side_lengths (2), field_rotation, bb_centre (2),
    bb_diameter. Paired fields are returned as 2-tuples.
    """
    (centre_a, centre_b,
     side_a, side_b,
     rotation,
     bb_a, bb_b,
     diameter) = (parameters[index] for index in range(8))

    return dict(
        field_centre=(centre_a, centre_b),
        field_side_lengths=(side_a, side_b),
        field_rotation=rotation,
        bb_centre=(bb_a, bb_b),
        bb_diameter=diameter,
    )

In [None]:
def create_figure(image, field_centre, field_side_lengths, field_rotation, bb_centre, bb_diameter, penumbra=0.03):
    """Draw ``reporting.image_analysis_figure`` for a single sample.

    ``image`` is expected to be one ``(H, W, 3)`` sample in the layout built by
    ``create_single_dataset``: x coordinates, y coordinates and intensity.
    Because the grid comes from ``np.meshgrid``, the x axis is read from the
    first row of channel 0 and the y axis from the first column of channel 1.

    Args:
        penumbra: penumbra passed through to the report figure. Defaults to the
            previously hard-coded 0.03.

    Returns:
        Whatever ``reporting.image_analysis_figure`` returns; callers in this
        notebook unpack it as ``(fig, axs)``.
    """
    x_axis = np.array(image[0, :, 0])
    y_axis = np.array(image[:, 0, 1])
    intensity = np.array(image[:, :, 2])

    return reporting.image_analysis_figure(
        x_axis, y_axis, intensity,
        np.array(bb_centre), np.array(field_centre), np.array(field_rotation),
        bb_diameter, field_side_lengths, penumbra=penumbra, units=''
    )

In [None]:
def _bb_centre_from_offsets(coordinates, offsets):
    """Estimate the BB centre as the per-axis median of coordinate + offset.

    ``coordinates`` is one sample whose channels 0 and 1 are the x and y
    coordinates; ``offsets`` is the matching 2-channel offset field.
    """
    return (
        np.median(coordinates[:, :, 0] + offsets[:, :, 0]),
        np.median(coordinates[:, :, 1] + offsets[:, :, 1]),
    )


def results_figures(model, batch_model_inputs, batch_model_outputs, batch_parameters, predicted):
    """Show ground-truth vs predicted report figures for every sample in a batch.

    The BB centre for each figure is recovered from the corresponding offset
    field. The coordinates come from ``batch_model_inputs`` itself, not from
    any notebook-level variable. All other field parameters come from
    ``batch_parameters``.

    Args:
        model: unused; kept for call compatibility.
        batch_model_inputs: ``(B, H, W, 3)`` batch of coordinate + image inputs.
        batch_model_outputs: ``(B, H, W, 2)`` ground-truth offset fields.
        batch_parameters: ``(B, 8)`` generating parameters.
        predicted: ``(B, H, W, 2)`` model-predicted offset fields.
    """
    batch_size = batch_model_inputs.shape[0]

    for i in range(batch_size):
        sample_input = batch_model_inputs[i, :, :, :]
        parameters = extract_parameters(batch_parameters[i, :])

        ground_truth_parameters = {
            **parameters,
            'bb_centre': _bb_centre_from_offsets(sample_input, batch_model_outputs[i]),
        }
        predicted_parameters = {
            **parameters,
            'bb_centre': _bb_centre_from_offsets(sample_input, predicted[i]),
        }

        _fig, axs = create_figure(sample_input, **ground_truth_parameters)
        axs[0, 0].set_title("Ground Truth")

        _fig, axs = create_figure(sample_input, **predicted_parameters)
        axs[0, 0].set_title("Predicted")

        plt.show()

In [None]:
def show_predictions(training=False):
    """Run the model on one freshly generated sample and plot the results.

    Args:
        training: forwarded to the model call. Defaults to False, so
            BatchNormalization uses its moving statistics and does not update
            them; displaying predictions then leaves the model unchanged. Pass
            True for batch-statistics behaviour, which the original hard-coded.
    """
    for model_input, model_output, parameters in create_pipeline_dataset(1).take(1):
        # The network only receives the image channel; the coordinate channels
        # are kept for plotting and for recovering the BB centre.
        predicted = model(model_input[:, :, :, 2:], training=training)

        plot_raw_images(model_input, predicted)
        results_figures(model, model_input, model_output, parameters, predicted)


show_predictions()

In [None]:
# optimizer = tf.keras.optimizers.Adam()

# checkpoint_dir = './training_checkpoints'
# checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# checkpoint = tf.train.Checkpoint(optimizer=optimizer,
#                                  model=model)

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that redraws the prediction plots after each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Clear the cell's previous output, then call ``show_predictions``.

        ``epoch`` and ``logs`` come from the Keras callback interface and are
        not used.
        """
        ipython_display = IPython.display
        # wait=True keeps the old output visible until new output replaces it.
        ipython_display.clear_output(wait=True)
        show_predictions()

In [None]:
# One TensorBoard log directory per run, named by local start time.
run_timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = f"logs/fit/{run_timestamp}"
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)

To monitor training, point TensorBoard at the log directory:

```bash
poetry run tensorboard --logdir examples/site-specific/cancer-care-associates/production/Winston\ Lutz/prototyping/tf_model/logs/
```

In [None]:
BATCH_SIZE = 10
EPOCHS = 10000
STEPS_PER_EPOCH = 2
VALIDATION_STEPS = 1

# The model only receives the image channel, so both pipelines drop the
# coordinate channels and the generating parameters.
test_dataset = create_pipeline_dataset(1, include_params=False)
train_dataset = create_pipeline_dataset(BATCH_SIZE, include_params=False)

# `use_multiprocessing` only applies to generator / Sequence inputs, and
# `shuffle` is ignored for tf.data.Dataset inputs, so neither is passed. The
# datasets are infinite, so steps_per_epoch / validation_steps bound each epoch.
model_history = model.fit(
    train_dataset,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=test_dataset,
    validation_steps=VALIDATION_STEPS,
    callbacks=[DisplayCallback(), tensorboard_callback],
)