In [None]:
# Re-import edited modules (e.g. pymedphys) automatically before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pathlib
import urllib.request

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.transforms
import scipy.ndimage.measurements
import scipy.interpolate
import scipy.optimize

import imageio

import pymedphys
import pymedphys._mocks.profiles
import pymedphys._wlutz.findfield
import pymedphys._wlutz.createaxis
import pymedphys._wlutz.interppoints
import pymedphys._wlutz.imginterp

In [None]:
# Sample image shipped with pymedphys (presumably a Winston-Lutz test image,
# going by the file name), resolved through `pymedphys.data_path`.
image_path = pymedphys.data_path("wlutz_image.png")
image_path

In [None]:
# Load the raw 16-bit image and sanity-check its size.
raw_img = imageio.imread(image_path)
assert raw_img.shape == (1024, 1024)

# Drop the first and last pixel columns.
cropped_img = raw_img[:, 1:-1]
assert cropped_img.shape == (1024, 1022)
assert cropped_img.dtype == np.uint16

# Flip vertically (row order reversed) and invert the intensity scale so that
# the stored value is `1 - raw / 2**16`, giving float64 values in (0, 1].
img = 1 - np.flipud(cropped_img) / 2**16
assert img.dtype == np.float64

In [None]:
shape = np.shape(img)
num_rows, num_cols = shape[0], shape[1]

# One coordinate per pixel, running from -n/2 up to n/2 - 1 and divided by 4.
# NOTE(review): 4 is presumably pixels per mm (0.25 mm pitch) — TODO confirm.
# The origin sits at pixel index n/2, so the grid is not exactly symmetric.
pixels_per_unit = 4
x = np.arange(-num_cols / 2, num_cols / 2) / pixels_per_unit
y = np.arange(-num_rows / 2, num_rows / 2) / pixels_per_unit

In [None]:
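# NOTE: apparently superseded by `pymedphys._wlutz.findfield.find_centre_and_rotation` (called below); kept for reference only.
# def img_centre_and_rotation(x, y, img, edge_lengths, penumbra=2, initial_rotation=0, rounding=True):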
#     field = pymedphys._wlutz.imginterp.create_interpolated_field(x, y, img)
#     initial_centre = pymedphys._wlutz.findfield._initial_centre(x, y, img)
    
#     (
#         centre, rotation
#     ) = pymedphys._wlutz.findfield.field_centre_and_rotation_refining(
#         field, edge_lengths, penumbra, initial_centre, initial_rotation=initial_rotation
#     )
    
#     if rounding:
#         centre = np.round(centre, decimals=2).tolist()
#         rotation = np.round(rotation, decimals=1)
    
#     return centre, rotation

In [None]:
# Expected field edge lengths along x and y, in the same units as `x` / `y`.
edge_lengths = [20, 20]

centre, rotation = pymedphys._wlutz.findfield.find_centre_and_rotation(
    x, y, img, edge_lengths
)
centre, rotation

In [None]:
# Rounded copies of the detected centre and rotation for display. They use the
# same precision as the old helper's `rounding=True` path: centre to 0.01 and
# rotation to 0.1 degrees. The old helper `img_centre_and_rotation` exists only
# as commented-out code, so calling it would raise a NameError on a fresh kernel.
centre_rounded = np.round(centre, decimals=2).tolist()
rotation_rounded = np.round(rotation, decimals=1)
centre_rounded, rotation_rounded

In [None]:
# Full-precision values, for comparison with the cell above.
(centre, rotation)

In [None]:
def draw_by_diff(dx, dy, transform):
    """Turn relative pen moves into transformed plotting coordinates.

    Parameters
    ----------
    dx, dy : sequence of float
        Successive steps along x and y in the untransformed frame. Their
        running totals are the vertices, so the first vertex is at
        ``(dx[0], dy[0])`` relative to the frame origin.
    transform
        Transform passed unchanged to
        ``pymedphys._wlutz.interppoints.apply_transform`` (an ``Affine2D``
        as used in this notebook).

    Returns
    -------
    tuple
        ``(xs, ys)``: the vertex coordinates after applying ``transform``.
    """
    vertices_x = np.cumsum(dx)
    vertices_y = np.cumsum(dy)

    moved_x, moved_y = pymedphys._wlutz.interppoints.apply_transform(
        vertices_x, vertices_y, transform
    )
    return moved_x, moved_y

In [None]:
# Maps the field's own frame onto image coordinates: rotate by `-rotation`
# degrees, then translate the origin to the detected centre. `Affine2D`'s
# methods mutate and return the transform itself, so the calls can be chained.
transform = (
    matplotlib.transforms.Affine2D()
    .rotate_deg(-rotation)
    .translate(*centre)
)
transform

In [None]:
# Regular sample grid (61 rows x 51 columns) spanning the field in its own
# frame, then mapped into image coordinates with `transform`.
half_width = edge_lengths[0] / 2
half_height = edge_lengths[1] / 2

rotation_x_points = np.linspace(-half_width, half_width, 51)
rotation_y_points = np.linspace(-half_height, half_height, 61)

rot_xx_points, rot_yy_points = pymedphys._wlutz.interppoints.apply_transform(
    *np.meshgrid(rotation_x_points, rotation_y_points), transform
)

In [None]:
# Reuse `edge_lengths` rather than a hard-coded [20, 20] copy, so these points
# stay consistent with the field drawn and searched for elsewhere.
# NOTE(review): 2 is presumably the penumbra width (cf. the `penumbra=2`
# default of the commented-out helper) — TODO confirm against the signature.
penumbra = 2
rotation_points = pymedphys._wlutz.interppoints.define_rotation_field_points(
    centre, edge_lengths, penumbra, rotation
)

In [None]:
# Relative pen moves (fed to `draw_by_diff`): start at the bottom-left corner,
# then trace the closed field outline (up, right, down, left).
rect_dx = [-edge_lengths[0]/2, 0, edge_lengths[0], 0, -edge_lengths[0]]
rect_dy = [-edge_lengths[1]/2, edge_lengths[1], 0, -edge_lengths[1], 0]

# Relative pen moves for the two field diagonals, joined along the top edge.
rect_crosshair_dx = [-edge_lengths[0]/2, edge_lengths[0], -edge_lengths[0], edge_lengths[0]]
rect_crosshair_dy = [-edge_lengths[1]/2, edge_lengths[1], 0, -edge_lengths[1]]

fig, ax = plt.subplots(figsize=(10, 10))
ax.pcolormesh(x, y, img)
ax.plot(*draw_by_diff(rect_dx, rect_dy, transform), 'k', lw=2)
ax.plot(*draw_by_diff(rect_crosshair_dx, rect_crosshair_dy, transform), 'k', lw=0.5)

# ax.plot(rot_xx_points, rot_yy_points, '.')

ax.scatter(centre[0], centre[1], c='r', s=1)
ax.scatter(rotation_points[0], rotation_points[1], s=1)

ax.axis('equal')
ax.set_xlim([-20, 20])
ax.set_ylim([-20, 20])
ax.set_title('Detected field outline, diagonals, centre (red) and rotation sample points')
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.show()

In [None]:
# Same overlay as the previous plot, minus the rotation sample points.
edge_x, edge_y = edge_lengths[0], edge_lengths[1]

# Relative pen moves: start at the bottom-left corner, then trace the closed
# field outline (up, right, down, left).
rect_dx = [-edge_x / 2, 0, edge_x, 0, -edge_x]
rect_dy = [-edge_y / 2, edge_y, 0, -edge_y, 0]

# Relative pen moves for the two field diagonals, joined along the top edge.
rect_crosshair_dx = [-edge_x / 2, edge_x, -edge_x, edge_x]
rect_crosshair_dy = [-edge_y / 2, edge_y, 0, -edge_y]

fig, ax = plt.subplots(figsize=(10, 10))
ax.pcolormesh(x, y, img)
for step_x, step_y, line_width in (
    (rect_dx, rect_dy, 2),
    (rect_crosshair_dx, rect_crosshair_dy, 0.5),
):
    ax.plot(*draw_by_diff(step_x, step_y, transform), 'k', lw=line_width)

# ax.plot(rot_xx_points, rot_yy_points, '.')

ax.scatter(centre[0], centre[1], c='r', s=1)

ax.axis('equal')
ax.set_xlim([-20, 20])
ax.set_ylim([-20, 20])

In [None]:
# Filled contours (30 levels) of the processed image around the origin.
fig, ax = plt.subplots()
image_contours = ax.contourf(x, y, img, 30)
ax.axis('equal')
ax.set_xlim([-25, 25])
ax.set_ylim([-25, 25])
fig.colorbar(image_contours, ax=ax)

In [None]:
# Evaluation grid matching `img` pixel-for-pixel: shape (len(y), len(x)).
xx, yy = np.meshgrid(x, y)

# Interpolated representation of the image, used by the two cells below.
# It is defined here because `field` otherwise exists only inside the
# commented-out helper, which makes those cells fail on a fresh kernel.
field = pymedphys._wlutz.imginterp.create_interpolated_field(x, y, img)

In [None]:
# Same view as the raw-image contour plot, but sampling the interpolated field
# on the pixel grid, for a visual comparison against the image itself.
fig, ax = plt.subplots()
field_contours = ax.contourf(x, y, field(xx, yy), 30)
ax.axis('equal')
ax.set_xlim([-25, 25])
ax.set_ylim([-25, 25])
fig.colorbar(field_contours, ax=ax)

In [None]:
# The interpolated field should reproduce the image at the pixel grid points.
# Compare with a tolerance: interpolation at the nodes is not guaranteed to be
# bit-exact in floating point. `assert_allclose` also checks that the shapes
# agree and reports the worst mismatch on failure.
np.testing.assert_allclose(field(xx, yy), img)