In [None]:
import json

from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.path

import tensorflow as tf

In [None]:
# Reload imported modules automatically before each cell runs, so edits to
# library source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pymedphys
import pymedphys._wlutz.findfield
import pymedphys._wlutz.iview
import pymedphys._wlutz.imginterp
import pymedphys._wlutz.reporting
import pymedphys._wlutz.interppoints

In [None]:
# Geometry constants. Every nominal value is multiplied by the same factor of
# two (transform_to_abs applies the same factor to coordinates).
# NOTE(review): presumably two pixels per length unit -- TODO confirm.
_scale = 2

bb_diameter = _scale * 8
edge_lengths = _scale * np.array([20, 24])
penumbra = _scale * 2

In [None]:
# Paths of the files in the 'wlutz_tensorflow_training_data' Zenodo record
# (presumably downloaded/cached on first use -- TODO confirm in pymedphys docs).
training_data_paths = pymedphys.zenodo_data_paths('wlutz_tensorflow_training_data')

In [None]:
# Index every PNG by its file stem, and take the first JSON file as the labels file.
image_paths = dict(
    (candidate.stem, candidate)
    for candidate in training_data_paths
    if candidate.suffix == '.png'
)
labels_path = list(
    filter(lambda candidate: candidate.suffix == '.json', training_data_paths)
)[0]

In [None]:
# Parse the labels JSON into a dictionary keyed like image_paths.
with open(labels_path) as labels_file:
    all_labels = json.loads(labels_file.read())

In [None]:
# Keep only the entries whose 'pymedphys' result contains a 'bb_centre'.
labels = {key: label['pymedphys'] for key, label in all_labels.items() if 'bb_centre' in label['pymedphys']}
keys = np.array(list(labels.keys()))

# Shuffle with a fixed seed so that the train/validation/test split derived
# from `keys` is identical on every Restart & Run All. An unseeded shuffle
# moves samples between the splits from run to run.
SPLIT_SEED = 0
np.random.default_rng(SPLIT_SEED).shuffle(keys)

In [None]:
# First eighth -> validation, second eighth -> test, remaining three quarters -> train.
split_a = len(keys) // 8
split_b = len(keys) // 4

validation_keys, test_keys, train_keys = np.split(keys, [split_a, split_b])

In [None]:
# Pick the first (shuffled) key as a worked example for the cells below.
key = keys[0]

label = labels[key]
image_path = str(image_paths[key])

In [None]:
# Display the label entry of the example sample.
label

In [None]:
# Field rotation of the example sample; get_transformation_matrix treats this
# value as degrees (it multiplies by pi / 180).
field_rotation = label['field_rotation']

In [None]:
# Decode the example PNG.
image = tf.io.decode_png(tf.io.read_file(image_path))

# 1024 x 1024 images are reduced by keeping every second pixel
# (odd rows, even columns).
dim = tf.shape(image)
if dim[0] == 1024 and dim[1] == 1024:
    image = image[1::2, ::2, :]

# Keep the central quarter along each axis, reverse axis 1, convert to float32.
image = tf.cast(
    tf.reverse(tf.image.central_crop(image, 0.25), [1]),
    tf.float32,
)

# Rescale: a pixel value of 0 maps to 1 and 255 maps to -1.
image = 1 - (image / 127.5)

plt.imshow(image[:, :, 0])
plt.colorbar()

In [None]:
# Integer axes 0..127; x and y refer to the same tensor.
x = y = tf.range(0, 128)

In [None]:
# Scratch value: only echoed as the cell output; it is not assigned to a name.
[-0.5, -1]

In [None]:
def transform_to_abs(coords):
    """Map label coordinates to absolute coordinates: ``63 - 2 * coords``.

    `coords` may be a scalar or any array-like; the result is a NumPy value of
    the same shape.
    """
    doubled = 2 * np.asarray(coords)
    return 63 - doubled


# Absolute coordinates of the field centre and the ball-bearing centre of the example sample.
field_centre, bb_centre = (
    transform_to_abs(label[name]) for name in ('field_centre', 'bb_centre')
)

In [None]:
# Check whether the decoded example image has the 1024 x 1024 shape that
# triggers the downsampling step.
dim[0] == 1024 and dim[1] == 1024

In [None]:
# Scratch check: show pymedphys' own transform for the arguments [10, 20] and 60
# (presumably a centre and a rotation in degrees -- TODO confirm) as an array.
np.array(pymedphys._wlutz.interppoints.translate_and_rotate_transform([10,20], 60))

In [None]:
# Reference value: sine of 60 degrees.
np.sin(60/180*np.pi)

In [None]:
# Scratch check: tf.constant built from a flat list with an explicit shape.
tf.constant([7, 9, 11], shape=[3])

In [None]:
# Sine of the example sample's field rotation. The angle is converted from
# degrees to radians here because `field_rotation_radians` is not defined
# anywhere else at notebook scope (the cell raised NameError on a fresh kernel).
field_rotation_radians = field_rotation / 180 * np.pi
tf.math.sin(field_rotation_radians)

In [None]:
# field_rotation_radians = field_rotation / 180 * np.pi
# sin = tf.math.sin(field_rotation_radians)
# cos = tf.math.cos(field_rotation_radians)
# x = field_centre[0]
# y = field_centre[1]

In [None]:
def get_transformation_matrix(field_centre, field_rotation):
    """Build the 3 x 3 homogeneous rotate-then-translate matrix for a field.

    Parameters
    ----------
    field_centre
        Indexable with the x translation at index 0 and the y translation at
        index 1.
    field_rotation
        Rotation angle in degrees (scalar number or scalar tensor).

    Returns
    -------
    tf.Tensor
        ``[[cos, -sin, x], [sin, cos, y], [0, 0, 1]]`` with the dtype of the
        computed sine.
    """
    field_rotation_radians = field_rotation / 180 * np.pi
    sin = tf.math.sin(field_rotation_radians)
    cos = tf.math.cos(field_rotation_radians)

    # Bring every entry to one dtype. tf.stack requires matching dtypes and,
    # unlike tf.constant on a nested list, also accepts symbolic tensors.
    dtype = sin.dtype
    x = tf.cast(field_centre[0], dtype)
    y = tf.cast(field_centre[1], dtype)
    zero = tf.zeros([], dtype=dtype)
    one = tf.ones([], dtype=dtype)

    return tf.stack([
        tf.stack([cos, -sin, x]),
        tf.stack([sin, cos, y]),
        tf.stack([zero, zero, one]),
    ])

In [None]:
def create_mask(field_centre, field_rotation, bb_centre):
    """Build a three-channel boolean segmentation mask for one image.

    The channels along the last axis are, in order: background (neither field
    nor ball bearing), inside the field rectangle, inside the ball bearing.

    Reads the notebook-level names `edge_lengths`, `bb_diameter`,
    `xx_about_zero` and `yy_about_zero`.
    NOTE(review): `xx_about_zero` / `yy_about_zero` are not defined in the
    visible cells -- presumably a 2-D coordinate grid; confirm before running.
    """
    transform = get_transformation_matrix(field_centre, field_rotation)

    # Cumulative steps trace the closed outline of a rectangle centred on the
    # origin, starting and ending at the (-width / 2, -height / 2) corner.
    width, height = edge_lengths[0], edge_lengths[1]
    outline_x = np.cumsum([-width / 2, 0, width, 0, -width])
    outline_y = np.cumsum([-height / 2, height, 0, -height, 0])

    corners_x, corners_y = pymedphys._wlutz.interppoints.apply_transform(
        outline_x, outline_y, transform
    )
    outline = matplotlib.path.Path(list(zip(corners_x, corners_y)))

    # One (x, y) row per grid point, then back to the grid's 2-D layout.
    grid_points = np.column_stack([xx_about_zero.ravel(), yy_about_zero.ravel()])
    rectangle_mask = outline.contains_points(grid_points).reshape(
        len(yy_about_zero), len(xx_about_zero)
    )

    distance_to_bb = np.sqrt(
        (xx_about_zero - bb_centre[0])**2 + (yy_about_zero - bb_centre[1])**2
    )
    within_bb = distance_to_bb <= bb_diameter/2

    background = np.logical_and(
        np.logical_not(rectangle_mask), np.logical_not(within_bb)
    )

    return np.stack([background, rectangle_mask, within_bb], axis=2)

In [None]:
def decode(tensor):
    """Unpack a five-element label encoding and build its segmentation mask.

    The entries are read in the order field centre x, field centre y,
    field rotation, ball-bearing centre x, ball-bearing centre y.
    """
    field_x, field_y, field_rotation, bb_x, bb_y = (
        tensor[index] for index in range(5)
    )

    return create_mask([field_x, field_y], field_rotation, [bb_x, bb_y])

In [None]:
def load(image_path, encoding):
    """Read one PNG and build the matching segmentation mask.

    Parameters
    ----------
    image_path
        Path of the PNG file to read.
    encoding
        Five-element label encoding, passed unchanged to `decode`.

    Returns
    -------
    tuple
        ``(image, mask)``: the float32 image rescaled with
        ``1 - pixel / 127.5`` and the mask returned by `decode`.

    NOTE(review): `decode` relies on NumPy/matplotlib code; this may not work
    when the function is traced by `tf.data.Dataset.map` -- confirm.
    """
    pixels = tf.io.decode_png(tf.io.read_file(image_path))

    shape = tf.shape(pixels)
    if shape[0] == 1024 and shape[1] == 1024:
        # Keep every second pixel (odd rows, even columns).
        pixels = pixels[1::2, ::2, :]

    pixels = tf.image.central_crop(pixels, 0.25)
    pixels = tf.reverse(pixels, [1])
    pixels = tf.cast(pixels, tf.float32)

    return 1 - (pixels / 127.5), decode(encoding)

In [None]:
# train_labels = {key: labels[key] for key in train_keys}
# train_labels

In [None]:
def transform_to_abs(coords):
    """Map label coordinates to absolute coordinates: ``63 - 2 * coords``."""
    return np.subtract(63, np.multiply(np.array(coords), 2))


def transform_labels(label):
    """Flatten one label dictionary into a five-element encoding.

    Returns ``[field_x, field_y, field_rotation, bb_x, bb_y]`` where both
    centres have been passed through `transform_to_abs` and the rotation is
    taken from the label unchanged.
    """
    rotation = label['field_rotation']
    centre_of_field = transform_to_abs(label['field_centre'])
    centre_of_bb = transform_to_abs(label['bb_centre'])

    return [
        centre_of_field[0],
        centre_of_field[1],
        rotation,
        centre_of_bb[0],
        centre_of_bb[1],
    ]

In [None]:
def get_dataset(keys, image_paths, labels):
    """Build a `tf.data.Dataset` of (image, mask) pairs for the given keys.

    Parameters
    ----------
    keys
        Iterable of sample keys to include.
    image_paths
        Mapping from key to the path of that sample's PNG file.
    labels
        Mapping from key to that sample's label dictionary.

    Returns
    -------
    tf.data.Dataset
        Dataset of (image path, label encoding) slices mapped through `load`.
    """
    image_paths_array = np.array([str(image_paths[key]) for key in keys])
    labels_array = np.array([transform_labels(labels[key]) for key in keys])

    dataset = tf.data.Dataset.from_tensor_slices((image_paths_array, labels_array))
    dataset = dataset.map(load)

    # The dataset was previously built but never returned, so callers got None.
    return dataset

In [None]:
# Build the training pipeline from the training keys.
train_dataset = get_dataset(train_keys, image_paths, labels)

In [None]:
# The original line (two bare names separated by a space) was a SyntaxError.
# Show the sizes of both collections instead of dumping their full contents.
len(image_paths), len(train_keys)

In [None]:
# Scratch experiment: a dataset that yields only the training keys.
# NOTE(review): this rebinds `train_dataset`, replacing the value assigned
# from get_dataset in the earlier cell.
train_dataset = tf.data.Dataset.from_tensor_slices((train_keys,))
train_dataset

In [None]:
# NOTE(review): this binds the `from_generator` function itself (it is not
# called), so `train_dataset` is no longer a dataset after this cell -- looks
# unfinished; confirm the intended generator and arguments.
train_dataset = tf.data.Dataset.from_generator

In [None]:
# fig, axs = pymedphys._wlutz.reporting.image_analysis_figure(
#     x, y, np.array(image)[:,:,0],
#     bb_centre, field_centre, field_rotation,
#     bb_diameter, edge_lengths, penumbra, units=''
# )

# # plt.contourf(xx_about_zero, yy_about_zero, segmentation_mask, alpha=0.5, cmap='bwr')