In [None]:
import json

from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.path

import tensorflow as tf

In [None]:
# Automatically reload imported modules before each cell runs, so edits to
# library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pymedphys
import pymedphys._wlutz.findfield
import pymedphys._wlutz.iview
import pymedphys._wlutz.imginterp
import pymedphys._wlutz.reporting
import pymedphys._wlutz.interppoints

In [None]:
# Geometry used by the mask builders and report figure, in pixel units of the
# processed image. Each value is (nominal size) * 2.
# NOTE(review): presumably nominal sizes are mm at 2 pixels per mm after
# downsampling — confirm against the imager geometry.
bb_diameter = 2 * 8
edge_lengths = 2 * np.array([20, 24])
penumbra = 2 * 2

In [None]:
# Fetch the WLutz training data files through pymedphys' Zenodo helper
# (NOTE(review): presumably downloads on first use and caches locally).
# Later cells iterate over the result twice and use .suffix/.stem on its
# items, so it is presumably a list of pathlib.Path objects — confirm.
training_data_paths = pymedphys.zenodo_data_paths('wlutz_tensorflow_training_data')

In [None]:
# Index the PNG images by file stem. Label keys are later used to look images
# up here, so stems must match the label keys. Also pick out the labels JSON.
image_paths = dict(
    (png_path.stem, png_path)
    for png_path in training_data_paths
    if png_path.suffix == '.png'
)
labels_path = [
    json_path for json_path in training_data_paths if json_path.suffix == '.json'
][0]

In [None]:
# JSON text is UTF-8 (RFC 8259). Pass the encoding explicitly so the load does
# not depend on the platform's locale default (e.g. cp1252 on Windows).
with open(labels_path, 'r', encoding='utf-8') as labels_file:
    all_labels = json.load(labels_file)

In [None]:
# Fixed seed so the shuffled train/validation/test split is identical on every
# Restart & Run All (an unseeded shuffle gives a different split each run).
SPLIT_SEED = 42

# Keep only examples whose pymedphys label includes a BB centre.
labels = {
    key: label['pymedphys']
    for key, label in all_labels.items()
    if 'bb_centre' in label['pymedphys']
}
keys = np.array(list(labels.keys()))
np.random.default_rng(SPLIT_SEED).shuffle(keys)

In [None]:
# Split the shuffled keys: roughly the first eighth goes to validation, the
# next eighth to test, and the remaining three quarters to training.
split_a, split_b = len(keys) // 8, len(keys) // 4

validation_keys, test_keys, train_keys = (
    keys[:split_a],
    keys[split_a:split_b],
    keys[split_b:],
)

In [None]:
# Pick the first shuffled example for the hand-inspection cells that follow.
key = keys[0]
label = labels[key]
image_path = str(image_paths[key])

# image = tf.io.read_file(image_path)
# image = tf.io.decode_png(image)

# tf.shape(image)

In [None]:
# Display the label dict of the example being inspected.
label

In [None]:
# Field rotation of the inspected example. NOTE(review): presumably degrees,
# given the `/ 180 * np.pi` conversion in get_transformation_matrix.
field_rotation = label['field_rotation']

In [None]:
# Preview of the preprocessing that `load` applies:
# - subsample 1024x1024 images;
# - keep the central 25% of each side;
# - mirror left-right;
# - rescale intensities so 0 -> 1 and 255 -> -1.
image = tf.io.decode_png(tf.io.read_file(image_path))

dim = tf.shape(image)
if dim[0] == 1024 and dim[1] == 1024:
    # Keep every second row (starting at row 1) and every second column.
    image = image[1::2, ::2, :]

image = tf.cast(tf.reverse(tf.image.central_crop(image, 0.25), [1]), tf.float32)
image = 1 - image / 127.5

plt.imshow(image[:, :, 0])
plt.colorbar()

In [None]:
# NOTE(review): bare list literal — it only echoes itself as cell output and is
# not referenced anywhere else here; looks like leftover scratch, consider removing.
[-0.5, -1]

In [None]:
# def transform_to_abs(coords):
#     return 63 - np.array(coords)*2


# field_centre = transform_to_abs(label['field_centre'])
# bb_centre = transform_to_abs(label['bb_centre'])

In [None]:
# dim[0] == 1024 and dim[1] == 1024

In [None]:
# np.array(pymedphys._wlutz.interppoints.translate_and_rotate_transform([10,20], 60))

In [None]:
# np.sin(60/180*np.pi)

In [None]:
# tf.constant([7, 9, 11], shape=[3])

In [None]:
# field_rotation_radians = field_rotation / 180 * np.pi
# sin = tf.math.sin(field_rotation_radians)
# cos = tf.math.cos(field_rotation_radians)
# x = field_centre[0]
# y = field_centre[1]

In [None]:
@tf.function
def load(image_path, encoding):
    """Read one training example and build its target mask.

    Args:
        image_path: path of a PNG image (a string tensor when mapped over the
            dataset built by get_dataset).
        encoding: 5-value label encoding as produced by transform_labels,
            i.e. [field_x, field_y, field_rotation, bb_x, bb_y].

    Returns:
        (image, mask, encoding): the preprocessed float32 image, the mask from
        decode(encoding), and the encoding passed through unchanged.
    """
    png_bytes = tf.io.read_file(image_path)
    pixels = tf.io.decode_png(png_bytes)

    shape = tf.shape(pixels)
    if shape[0] == 1024 and shape[1] == 1024:
        # Full-resolution images: keep every second row (from row 1) and column.
        pixels = pixels[1::2, ::2, :]

    pixels = tf.image.central_crop(pixels, 0.25)
    pixels = tf.reverse(pixels, [1])  # mirror left-right
    pixels = tf.cast(pixels, tf.float32)

    # Map intensities 0..255 onto 1..-1.
    image = 1 - pixels / 127.5

    return image, decode(encoding), encoding

In [None]:
def transform_to_abs(coords):
    """Convert label coordinates into the pixel coordinates used by the masks.

    Computes ``63 - 2 * coords`` element-wise, which scales by 2 and flips the
    sign, offset by 63.

    NOTE(review): presumably labels are mm relative to the image centre with
    0.5 mm pixels. With pixel i spanning [i - 0.5, i + 0.5], the centre of a
    128-pixel grid is 63.5, not 63 — confirm the offset.
    """
    return 63 - 2 * np.array(coords)


def transform_labels(label):
    """Flatten a pymedphys label dict into a 5-value encoding.

    The encoding is ``[field_x, field_y, field_rotation, bb_x, bb_y]``, the
    layout read back by extract_items_from_encoding. Both centres go through
    transform_to_abs; the rotation is passed through unchanged.
    """
    rotation = label['field_rotation']
    field_xy = transform_to_abs(label['field_centre'])
    bb_xy = transform_to_abs(label['bb_centre'])

    return [field_xy[0], field_xy[1], rotation, bb_xy[0], bb_xy[1]]

In [None]:

def get_dataset(keys, image_paths, labels):
    """Build a tf.data pipeline of (image, mask, encoding) for the given keys.

    Each key contributes its PNG path (as a string) and its transform_labels
    encoding. The pair is then mapped through `load`.
    """
    paths = np.array([str(image_paths[k]) for k in keys])
    encodings = np.array([transform_labels(labels[k]) for k in keys])

    return tf.data.Dataset.from_tensor_slices((paths, encodings)).map(load)

In [None]:
# Closed outline of an edge_lengths[0] x edge_lengths[1] rectangle centred on
# the origin. It is written as per-vertex steps and accumulated with cumsum;
# the fifth vertex returns to the first (-w/2, -h/2) corner.
rect_dx = [-edge_lengths[0] / 2, 0, edge_lengths[0], 0, -edge_lengths[0]]
rect_dy = [-edge_lengths[1] / 2, edge_lengths[1], 0, -edge_lengths[1], 0]

draw_x, draw_y = np.cumsum(rect_dx), np.cumsum(rect_dy)

# NOTE(review): `coord` is not used by any cell visible here.
coord = tf.range(128)

In [None]:
# Display the y-coordinates of the rectangle outline vertices.
draw_y

In [None]:
# # https://stackoverflow.com/a/31530106
# def extract_blocks(a, blocksize):
#     M,N = a.shape
#     b0, b1 = blocksize
    
#     a = tf.reshape(a, (M//b0,b0,N//b1,b1))
#     a = tf.transpose(a, (0,2,1))
#     a = tf.reshape(-1, b0, b1)
    
#     return a


In [None]:
IMG_SIZE = 128
# Sub-samples per pixel along each axis in the super-sampled ("expanded") masks.
# Must match the 1/16 sub-pixel spacing of the expanded sample grid.
SUPERSAMPLE = 16

@tf.function
def reduce_expanded_mask(expanded_mask):
    """Average a super-sampled mask down to per-pixel coverage.

    Args:
        expanded_mask: mask of shape
            (IMG_SIZE * SUPERSAMPLE, IMG_SIZE * SUPERSAMPLE), bool or numeric.

    Returns:
        float32 tensor of shape (IMG_SIZE, IMG_SIZE) with the mean of each
        pixel's SUPERSAMPLE x SUPERSAMPLE block. For a boolean mask this is the
        covered fraction, in [0, 1].
    """
    expanded_mask = tf.dtypes.cast(expanded_mask, tf.float32)
    # Split rows and columns into (pixel, sub-sample) pairs, then average over
    # both sub-sample axes at once. The shape is tied to IMG_SIZE rather than a
    # hard-coded 128.
    blocks = tf.reshape(expanded_mask, (IMG_SIZE, SUPERSAMPLE, IMG_SIZE, SUPERSAMPLE))
    return tf.reduce_mean(blocks, axis=(1, 3))

In [None]:
x = np.arange(0, IMG_SIZE)
y = np.arange(0, IMG_SIZE)

xx, yy = np.meshgrid(x, y)

# Super-sampled grid used to rasterise the BB circle and field rectangle with
# anti-aliasing. Pixel i spans [i - 0.5, i + 0.5] and is split into
# `subsamples_per_pixel` equal cells per axis, sampled at the cell centres.
# The factor must match the super-sampling assumed by reduce_expanded_mask.
subsamples_per_pixel = 16
dx = 1 / subsamples_per_pixel
# Build from an integer index rather than a float-step np.arange. This way the
# grid always has exactly IMG_SIZE * subsamples_per_pixel samples per axis, and
# the upper bound follows IMG_SIZE instead of a hard-coded 127.5.
n_expand = IMG_SIZE * subsamples_per_pixel
x_expand = -0.5 + dx / 2 + np.arange(n_expand) * dx
y_expand = -0.5 + dx / 2 + np.arange(n_expand) * dx

xx_expand, yy_expand = np.meshgrid(x_expand, y_expand)

bb_radius_sqrd = (bb_diameter / 2)**2

@tf.function
def get_circle_mask(bb_centre):
    """Rasterise the BB as an anti-aliased disc of diameter `bb_diameter`.

    Args:
        bb_centre: (x, y) centre of the BB in pixel coordinates.

    Returns:
        Per-pixel coverage rescaled to [-1, 1]: -1 fully outside, +1 fully
        inside, intermediate values along the edge.
    """
    expanded_mask = (xx_expand - bb_centre[0])**2 + (yy_expand - bb_centre[1])**2 <= bb_radius_sqrd
    circle_mask = reduce_expanded_mask(expanded_mask)

    return circle_mask * 2 - 1


# circle_mask = get_circle_mask(bb_centre)

# plt.pcolormesh(x - .5, y - .5, circle_mask)
# plt.colorbar()
# plt.axis('equal')

In [None]:
def get_transformation_matrix(field_centre, field_rotation):
    """Homogeneous 3x3 matrix: rotate by `field_rotation` degrees, then shift.

    The upper-left 2x2 block is the rotation [[cos, -sin], [sin, cos]]. The
    last column holds the translation (field_centre[0], field_centre[1]), so
    ``matrix @ [x, y, 1]`` rotates a point about the origin and then moves it
    onto the field centre.
    """
    angle = field_rotation / 180 * np.pi
    cos_a = tf.math.cos(angle)
    sin_a = tf.math.sin(angle)
    shift_x, shift_y = field_centre[0], field_centre[1]

    return tf.convert_to_tensor([
        [cos_a, -sin_a, shift_x],
        [sin_a, cos_a, shift_y],
        [0, 0, 1],
    ])


def apply_transform(xx, yy, transform):
    """Apply a 3x3 homogeneous `transform` to the points (xx, yy).

    The coordinates are flattened, stacked with a row of ones and multiplied
    by `transform`. The first two result rows are then reshaped back to the
    input shapes.
    """
    flat_x = np.ravel(xx)
    points = np.vstack([flat_x, np.ravel(yy), np.ones_like(flat_x)])
    mapped = transform @ points

    return tf.reshape(mapped[0], xx.shape), tf.reshape(mapped[1], yy.shape)

In [None]:
# transform = get_transformation_matrix(field_centre, field_rotation)
# transformed_x, transformed_y = apply_transform(draw_x, draw_y, transform)

# bounds_x = transformed_x[0:4]
# bounds_y = transformed_y[0:4]

# plt.plot(transformed_x, transformed_y)
# plt.axis('equal')

In [None]:
def get_partial_rect_mask(field_centre, x1, x2, y1, y2, xx=None, yy=None):
    """Half-plane mask for one edge of the field rectangle.

    The infinite line through (x1, y1) and (x2, y2) splits the plane in two.
    This returns the half (boundary included) that contains `field_centre`.

    The side test uses the sign of a 2-D cross product instead of the
    slope/intercept form y = m*x + c. The slope form divides by zero for
    vertical edges (x1 == x2, which happens for axis-aligned fields) and then
    yields an all-False mask from NaN comparisons.

    Args:
        field_centre: (x, y) point that lies inside the rectangle.
        x1, x2, y1, y2: end-points (x1, y1) and (x2, y2) of the edge.
        xx, yy: sample-point coordinate grids. They default to the module-level
            super-sampled grids `xx_expand` / `yy_expand`.

    Returns:
        Boolean mask shaped like the grids: True on the centre's side of the edge.
    """
    if xx is None:
        xx = xx_expand
    if yy is None:
        yy = yy_expand

    edge_dx = x2 - x1
    edge_dy = y2 - y1

    # z-component of edge x (point - start): its sign says which side of the
    # edge line a point lies on (0 means on the line).
    centre_side = edge_dx * (field_centre[1] - y1) - edge_dy * (field_centre[0] - x1)
    grid_side = edge_dx * (yy - y1) - edge_dy * (xx - x1)

    if centre_side <= 0:
        rect_mask = grid_side <= 0
    else:
        rect_mask = grid_side >= 0

    return rect_mask

In [None]:
# partial_masks = [
#     get_partial_rect_mask(
#         field_centre, bounds_x[i], bounds_x[(i + 1) % 4], bounds_y[i], bounds_y[(i + 1) % 4]
#     )
#     for i in range(4)]


# for i in range(4):
#     plt.figure()
    
#     plt.pcolormesh(
#         xx_expand, 
#         yy_expand, 
#         partial_masks[i]
#     )
#     plt.plot(transformed_x, transformed_y)
#     plt.axis('equal')

In [None]:
def get_rect_mask(field_centre, field_rotation):
    """Rasterise the field rectangle as an anti-aliased mask.

    The origin-centred outline (draw_x, draw_y) is rotated and moved onto
    `field_centre`. The four half-planes bounded by its edges are intersected
    on the super-sampled grid, and the result is averaged per pixel.

    Returns:
        Per-pixel coverage rescaled to [-1, 1].
    """
    # NOTE(review): the outline is rotated by (180 - field_rotation) rather
    # than field_rotation. Presumably this compensates for the left-right
    # image flip in `load` — confirm against the labelling convention.
    transform = get_transformation_matrix(field_centre, 180 - field_rotation)
    corners_x, corners_y = apply_transform(draw_x, draw_y, transform)

    # Only the first four outline vertices are distinct (the fifth closes the
    # loop), so edge k joins corner k to corner (k + 1) % 4.
    half_planes = [
        get_partial_rect_mask(
            field_centre,
            corners_x[start], corners_x[(start + 1) % 4],
            corners_y[start], corners_y[(start + 1) % 4],
        )
        for start in range(4)
    ]
    inside = half_planes[0] & half_planes[1] & half_planes[2] & half_planes[3]

    return reduce_expanded_mask(inside) * 2 - 1

In [None]:
# rect_mask = get_rect_mask(field_centre, field_rotation)

# plt.pcolormesh(x - .5, y - .5, rect_mask)
# plt.axis('equal')
# plt.colorbar()

In [None]:
def extract_items_from_encoding(encoding):
    """Unpack the 5-value label encoding.

    Returns ``(field_centre, field_rotation, bb_centre)``. The two centres are
    2-element lists taken from positions 0-1 and 3-4; the rotation is
    position 2.
    """
    return [encoding[0], encoding[1]], encoding[2], [encoding[3], encoding[4]]


@tf.function
def decode(encoding):
    """Rebuild the (circle, rectangle) target mask from a 5-value encoding."""
    field_centre, field_rotation, bb_centre = extract_items_from_encoding(encoding)
    return create_mask(field_centre, field_rotation, bb_centre)

In [None]:
@tf.function
def create_mask(field_centre, field_rotation, bb_centre):
    """Stack the BB and field masks into one two-channel tensor.

    Channel 0 is the circle (BB) mask and channel 1 is the rectangle (field)
    mask. Both hold pixel coverage rescaled to [-1, 1].
    """
    circle_mask = get_circle_mask(bb_centre)
    rect_mask = get_rect_mask(field_centre, field_rotation)

    # Same result as concatenating the masks along a new trailing axis.
    return tf.stack([circle_mask, rect_mask], axis=2)


train_dataset = get_dataset(train_keys, image_paths, labels)

In [None]:
# Plotting grid in pixel indices (same range as the mask grid).
x = np.arange(0, IMG_SIZE)
y = np.arange(0, IMG_SIZE)

# Visual sanity check: render the first 10 training examples with pymedphys'
# WLutz analysis figure and overlay the generated target masks.
for image, mask, encoding in train_dataset.take(10):
    field_centre, field_rotation, bb_centre = extract_items_from_encoding(encoding)
    
    fig, axs = pymedphys._wlutz.reporting.image_analysis_figure(
        x, y, np.array(image)[:,:,0],
        np.array(bb_centre), np.array(field_centre), np.array(field_rotation),
        bb_diameter, edge_lengths, penumbra, units=''
    )

    # Zero-level contours of the [-1, 1] masks, i.e. where pixel coverage is
    # 50%. Channel 0 is the BB circle; channel 1 is the field rectangle.
    # NOTE(review): these draw on pyplot's *current* axes, which depends on how
    # image_analysis_figure leaves the figure. Confirm they land on the image
    # panel (an explicit axs[...].contour would remove the ambiguity).
    plt.contour(x, y, mask[:,:,0], [0], cmap='bwr_r', zorder=20)
    plt.contour(x, y, mask[:,:,1], [0], cmap='bwr_r', zorder=20)
    
    plt.show()