In [None]:
# Enable IPython's autoreload extension in mode 2, so that imported modules are
# reloaded before each cell runs and edits to library source are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import json

import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

import pymedphys
import pymedphys._wlutz.findfield
import pymedphys._wlutz.iview
import pymedphys._wlutz.imginterp
import pymedphys._wlutz.reporting

In [None]:
# Geometry settings for the ball-bearing / field model.
# NOTE(review): units are not stated in this notebook (presumably mm) and none
# of these values are consumed by the cells visible here -- confirm both
# against the cells that build the masks.
bb_diameter = 8          # diameter of the ball bearing
penumbra = 2             # width of the field-edge penumbra
edge_length = [20, 24]   # the two side lengths of the rectangular field

In [None]:
# Ask pymedphys for the file paths belonging to the
# 'wlutz_tensorflow_training_data' Zenodo record. The returned paths are
# filtered by suffix (.png images, .json labels) in the next cell.
# NOTE(review): presumably this downloads and caches the record on first use --
# confirm against the pymedphys documentation.
training_data_paths = pymedphys.zenodo_data_paths('wlutz_tensorflow_training_data')

In [None]:
# Index every PNG in the dataset by its file stem (file name without suffix);
# the stems double as the keys used to look up labels later on.
image_paths = dict(
    (png.stem, png)
    for png in training_data_paths
    if png.suffix == '.png'
)

# The labels live in a JSON file alongside the images; the first JSON file
# found is used (an IndexError here means the dataset contained none).
labels_path = [
    candidate
    for candidate in training_data_paths
    if candidate.suffix == '.json'
][0]

In [None]:
# Load the raw label dictionary (keyed by image stem).
# JSON is UTF-8 by specification, so the encoding is given explicitly instead
# of relying on the platform default (which is not UTF-8 on every system).
with open(labels_path, 'r', encoding='utf-8') as labels_file:
    labels = json.load(labels_file)

In [None]:
# Image stems in dictionary (insertion) order; iterating a dict yields its keys.
keys = list(image_paths)

In [None]:
# Use (at most) the first 100 images as the training subset.
training_keys = keys[:100]

In [None]:
def load_and_regularise_data_and_labels(image_paths, all_labels, keys, dx=1/8, img_range=40):
    """Resample labelled images onto a common grid centred on their centre of mass.

    For every key that has a ``'bb_centre'`` label, the image is loaded,
    its centre of mass is found, and the image is interpolated onto a square
    grid of spacing ``dx`` placed around that centre of mass. The labelled
    field centre and BB centre are re-expressed relative to the same centre of
    mass so that images and labels share one coordinate origin.

    Parameters
    ----------
    image_paths : mapping
        Maps each key to the path of the image to load.
    all_labels : mapping
        Maps each key to a dict whose ``'pymedphys'`` entry holds the
        ``'field_centre'``, ``'field_rotation'`` and, optionally,
        ``'bb_centre'`` labels. Entries without ``'bb_centre'`` are skipped.
    keys : iterable
        Keys to process, in order.
    dx : float, optional
        Spacing between neighbouring grid points, in the coordinate units
        returned by ``iview_image_transform``. Must be positive.
    img_range : float, optional
        Side length of the resampling window, in the same units.

    Returns
    -------
    images : numpy.ndarray
        The interpolated images stacked along the first axis, one per
        retained key.
    labels : numpy.ndarray
        One row per retained key, laid out as ``[field_centre[0],
        field_centre[1], field_rotation, bb_centre[0], bb_centre[1]]`` with
        both centres given relative to the image's centre of mass.

    Raises
    ------
    ValueError
        If ``dx`` is not strictly positive.
    """
    if dx <= 0:
        raise ValueError("dx must be positive, got {!r}".format(dx))

    # Derive the number of grid intervals as an integer and scale an integer
    # index by ``dx``. ``np.arange`` with a floating-point step does not have a
    # numerically stable length, which could silently change the image shape
    # for spacings such as 0.1.
    n_steps = img_range / dx
    if np.isclose(n_steps, round(n_steps)):
        n_intervals = int(round(n_steps))
    else:
        # ``dx`` does not divide ``img_range``: keep stepping until the upper
        # edge of the window has been passed (same extent as the previous
        # ``np.arange``-based grid).
        n_intervals = int(np.ceil(n_steps))
    vec_about_zero = -img_range / 2 + dx * np.arange(n_intervals + 1)

    labels = []
    images = []

    for key in tqdm(keys):
        label = all_labels[key]['pymedphys']

        # Only images with a located ball bearing are usable as training data.
        if 'bb_centre' not in label:
            continue

        image_path = image_paths[key]
        x, y, img = pymedphys._wlutz.iview.iview_image_transform(image_path)

        centre_of_mass = pymedphys._wlutz.findfield.get_centre_of_mass(x, y, img)
        field = pymedphys._wlutz.imginterp.create_interpolated_field(x, y, img)

        # Shift the common grid so that it is centred on this image's centre of mass.
        x_interp = vec_about_zero + centre_of_mass[0]
        y_interp = vec_about_zero + centre_of_mass[1]

        xx, yy = np.meshgrid(x_interp, y_interp)
        interpolated_image = field(xx, yy)

        # Express the labelled positions relative to the same origin as the grid.
        field_centre = np.array(label['field_centre']) - np.array(centre_of_mass)
        field_rotation = label['field_rotation']
        bb_centre = np.array(label['bb_centre']) - np.array(centre_of_mass)

        labels.append([field_centre[0], field_centre[1], field_rotation, bb_centre[0], bb_centre[1]])
        images.append(interpolated_image)

    return np.array(images), np.array(labels)

In [None]:
# `labels` holds the raw label dictionary when this cell first runs and is
# rebound below to the regularised label array. Keep the dictionary under its
# own name so that re-running this cell does not pass the array back in as
# `all_labels` (which would fail when it is indexed by image key).
if isinstance(labels, dict):
    all_labels = labels

images, labels = load_and_regularise_data_and_labels(image_paths, all_labels, training_keys)

In [None]:
# Sanity check: (number of retained images, grid rows, grid columns).
images.shape

In [None]:
# Sanity check: by this point `labels` is the array returned by the loader,
# with one row of five values per retained image.
labels.shape

In [None]:
def create_masks