In [10]:
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
from jax.scipy.interpolate import RegularGridInterpolator

import jaxgym.components as comp
from jaxgym.stemoverfocus import compute_fourdstem_dataset

from scipy.ndimage import rotate
from scipy.ndimage import zoom, center_of_mass
import json

from scipy.optimize import curve_fit
%matplotlib widget
jax.config.update('jax_platform_name', 'cpu')


In [11]:
# Interactive preview of one simulated dataset with LiberTEM.
# NOTE(review): "inline" is presumably LiberTEM's in-process executor (see the LiberTEM
# Context docs). The path is hard-coded to what appears to be the 0.5 camera-length dataset.
from libertem_ui.windows.imaging import VirtualDetectorWindow
import libertem.api as lt

preview_dataset_path = "./fourdstem_array_0.5.npy"
ctx = lt.Context.make_with("inline")
ds = ctx.load("npy", preview_dataset_path)

In [12]:
# Open LiberTEM-UI's interactive virtual-detector window on dataset `ds` using context `ctx`.
# NOTE(review): exploratory only — no later cell in this notebook reads `v_window`.
v_window = VirtualDetectorWindow.using(ctx, ds)
v_window.layout()

BokehModel(combine_events=True, render_bundle={'docs_json': {'8438c257-84fc-495f-be43-55598d787ab1': {'version…

In [4]:
# from libertem_ui.windows.com import CoMImagingWindow
# com_window = CoMImagingWindow.using(ctx, ds)
# com_window.layout()

In [5]:

# Load the USAF 1951 resolution target and use its first colour channel as the sample.
sample_image = plt.imread(r'SilverFast_Resolution_Target_USAF_1951.png')[:, :, 0]
# NOTE(review): this is the shape *before* the border crop and downsampling below.
image_shape = sample_image.shape

# Remove the one-pixel black border on the edges.
sample_image = sample_image[1:-1, 1:-1]

# Downsample: a zoom factor of 0.5 halves each dimension (i.e. downsample by 2).
downsample_factor = 0.5
sample_image = zoom(sample_image, downsample_factor)
sample_image_shape = sample_image.shape

# The sample rotation (degrees) is something we can never access in the reverse model,
# but it lets the forward model produce a rotated image on the detector.
sample_rotation = 64

sample_image = np.array(sample_image, dtype=np.complex64)
sample_px_size = 0.0001
# reshape=True enlarges the array to hold the whole rotated image; new corners are set to 1.0.
# NOTE(review): outside the grid the interpolant below returns 0.0, not 1.0 — confirm intended.
sample_rotated = rotate(sample_image, sample_rotation, reshape=True, cval=1.0)
sample_rotated_edge_length_x = sample_rotated.shape[1] * sample_px_size
sample_rotated_edge_length_y = sample_rotated.shape[0] * sample_px_size


def pixel_centre_coords(n_pixels, pixel_size):
    """Return the ``n_pixels`` pixel-centre coordinates of a grid centred on 0.

    Consecutive coordinates are exactly ``pixel_size`` apart, so the outer pixel
    *edges* sit at +/- n_pixels * pixel_size / 2 and the first/last centres are
    half a pixel inside them.
    """
    half_span = (n_pixels - 1) * pixel_size / 2
    return np.linspace(-half_span, half_span, n_pixels)


# Grid coordinates at the physical centres of the pixels. A linspace running from
# -edge_length/2 to +edge_length/2 would place the first/last samples on the outer
# edges instead, giving a pitch of edge_length/(N-1) rather than sample_px_size.
sample_coords_x = pixel_centre_coords(sample_rotated.shape[1], sample_px_size)
sample_coords_y = pixel_centre_coords(sample_rotated.shape[0], sample_px_size)

# Image row 0 is the top (+y) edge, but jax's RegularGridInterpolator cannot handle a
# grid running from positive to negative, so flip rows to pair with ascending y coordinates.
sample_rotated_flipped = np.flipud(sample_rotated)

# Build the RegularGridInterpolator
sample_interpolant = RegularGridInterpolator(
    (sample_coords_y, sample_coords_x), sample_rotated_flipped,
    method='nearest', bounds_error=False, fill_value=0.0
)

# Plot extent (left, right, bottom, top) spans the outer pixel edges, not the centres.
extent = (
    -sample_rotated_edge_length_x / 2, +sample_rotated_edge_length_x / 2,
    -sample_rotated_edge_length_y / 2, +sample_rotated_edge_length_y / 2,
)

In [6]:
# Load the simulation parameters (`json` is imported in the imports cell at the top).
# JSON is UTF-8 by specification, so don't depend on the platform's default encoding.
with open('params.json', encoding='utf-8') as f:
    params_dict = json.load(f)

# Camera lengths to analyse; each is expected to have a fourdstem_array_<camera_length>.npy file.
camera_lengths = params_dict['camera_length']

In [7]:
# For every camera length, collect the fit inputs: the scan positions (x_in, y_in),
# the propagation distance B from the scan plane to the detector, and the centre of
# mass of each detector frame converted to metres (x_det, y_det).
x_ins = []
y_ins = []
Bs = []
x_dets = []
y_dets = []

# Pixels whose magnitude is not above this value are ignored in the centre of mass.
# 0.0 keeps every non-zero pixel.
COM_THRESHOLD = 0.0

# The scan grid sits at z = defocus and does not depend on the camera length,
# so build it (and everything derived from it) once, outside the loop.
ScanGrid = comp.ScanGrid(z=jnp.array([params_dict['defocus']]), scan_step=params_dict['scan_step'], scan_shape=params_dict['scan_shape'], scan_rotation=params_dict['scan_rotation'])
scan_coordinates = ScanGrid.coords
n_scan_positions = ScanGrid.scan_shape[0] * ScanGrid.scan_shape[1]
assert len(scan_coordinates) == n_scan_positions, "ScanGrid.coords does not match scan_shape"
x_in, y_in = scan_coordinates[:, 0], scan_coordinates[:, 1]

# Inspection helpers: the scan position at the last index along both scan axes.
scan_idx_x, scan_idx_y = -1, -1
scan_coords = scan_coordinates.reshape(ScanGrid.scan_shape[0], ScanGrid.scan_shape[1], 2)
scan_pos_x, scan_pos_y = scan_coords[scan_idx_y, scan_idx_x]

for camera_length in camera_lengths:
    Detector = comp.Detector(z=jnp.array(camera_length), det_shape=params_dict['det_shape'], det_pixel_size=params_dict['det_px_size'], flip_y=params_dict['flip_y'])

    # Load the 4D-STEM array for the current camera length and flatten the scan axes,
    # giving one detector frame per scan position (same order as scan_coordinates).
    fourdstem_array = np.load(f'fourdstem_array_{camera_length}.npy')
    fourdstem_array = fourdstem_array.reshape(n_scan_positions, *Detector.det_shape)

    x_ins.append(x_in)
    y_ins.append(y_in)

    # Propagation distance from the scan plane (z = defocus) to the detector
    # (z = camera_length); the same for every scan position.
    B = camera_length - params_dict['defocus']
    B = np.ones_like(x_in) * B
    Bs.append(B)

    for image in fourdstem_array:
        # Threshold the magnitude, then take the intensity-weighted centre of mass in
        # (row, column) pixel order. An all-zero frame yields NaN here.
        magnitude = np.abs(image)
        weights = np.where(magnitude > COM_THRESHOLD, magnitude, 0)
        yx_px_det = center_of_mass(weights)

        # NOTE(review): assumes pixels_to_metres takes (y, x) pixels and returns (x, y) metres — confirm.
        xy_det = Detector.pixels_to_metres(yx_px_det)
        x_det, y_det = xy_det[0], xy_det[1]
        x_dets.append(float(x_det))
        y_dets.append(float(y_det))

    # Restore the 4-D layout and keep the frame at the inspection scan position
    # (after the loop this holds the frame for the last camera length).
    fourdstem_array = fourdstem_array.reshape(ScanGrid.scan_shape[0], ScanGrid.scan_shape[1], *Detector.det_shape)
    det_image_selected = fourdstem_array[scan_idx_y, scan_idx_x]


In [8]:
# Symbolic ray-transfer model from the sample plane to the detector.
# A ray is the column vector (x, y, dx, dy): position followed by slope.
import sympy as sp
from sympy import symbols, Matrix

x_in, y_in, dx_in, dy_in = symbols('x_in y_in dx_in dy_in')
x_det, y_det, dx_det, dy_det = symbols('x_det y_det dx_det dy_det')

ray_in = Matrix([x_in, y_in, dx_in, dy_in])
ray_out = Matrix([x_det, y_det, dx_det, dy_det])

# Unknown coefficients: A (descan position error), B (camera-length propagation),
# C (descan slope error). Only the diagonal B terms enter the model below.
Axx_dp, Axy_dp, Ayx_dp, Ayy_dp = symbols('A_{xx\\_dpos} A_{xy\\_dpos} A_{yx\\_dpos} A_{yy\\_dpos}')
Bxx_cl, Bxy_cl, Byx_cl, Byy_cl = symbols('B_{xx\\_cl} B_{xy\\_cl} B_{yx\\_cl} B_{yy\\_cl}')
Cxx_ds, Cxy_ds, Cyx_ds, Cyy_ds = symbols('C_{xx\\_dslope} C_{xy\\_dslope} C_{yx\\_dslope} C_{yy\\_dslope}')

# Descan error: output position = A @ position; output slope = C @ position + input slope.
A_block = Matrix([[Axx_dp, Axy_dp], [Ayx_dp, Ayy_dp]])
C_block = Matrix([[Cxx_ds, Cxy_ds], [Cyx_ds, Cyy_ds]])
descan_error_matrix = Matrix.vstack(
    Matrix.hstack(A_block, sp.zeros(2, 2)),
    Matrix.hstack(C_block, sp.eye(2)),
)

# Free-space propagation over the camera length: position += B * slope.
camera_length_prop_matrix = Matrix.vstack(
    Matrix.hstack(sp.eye(2), sp.diag(Bxx_cl, Byy_cl)),
    Matrix.hstack(sp.zeros(2, 2), sp.eye(2)),
)

print('Transfer Matrix from sample to detector:')
T = camera_length_prop_matrix @ descan_error_matrix
display(T)

print('Transfer Equation with input slope of rays set to 0.0 (i.e finding the central ray of the point source through the system)')
# Zero input slope selects the central ray of a point source at the sample.
eq = sp.Eq(ray_out, T @ ray_in).subs({dx_in: 0.0, dy_in: 0.0})

# The first two rows give the detector position; show them as separate equations.
eq_x, eq_y = (sp.Eq(lhs, eq.rhs[row]) for row, lhs in enumerate((x_det, y_det)))
display(eq_x)
display(eq_y)

Transfer Matrix from sample to detector:


Matrix([
[A_{xx\_dpos} + B_{xx\_cl}*C_{xx\_dslope}, A_{xy\_dpos} + B_{xx\_cl}*C_{xy\_dslope}, B_{xx\_cl},          0],
[A_{yx\_dpos} + B_{yy\_cl}*C_{yx\_dslope}, A_{yy\_dpos} + B_{yy\_cl}*C_{yy\_dslope},          0, B_{yy\_cl}],
[                          C_{xx\_dslope},                           C_{xy\_dslope},          1,          0],
[                          C_{yx\_dslope},                           C_{yy\_dslope},          0,          1]])

Transfer Equation with input slope of rays set to 0.0 (i.e finding the central ray of the point source through the system)


Eq(x_det, x_in*(A_{xx\_dpos} + B_{xx\_cl}*C_{xx\_dslope}) + y_in*(A_{xy\_dpos} + B_{xx\_cl}*C_{xy\_dslope}))

Eq(y_det, x_in*(A_{yx\_dpos} + B_{yy\_cl}*C_{yx\_dslope}) + y_in*(A_{yy\_dpos} + B_{yy\_cl}*C_{yy\_dslope}))

In [9]:
# Thresholded, unweighted centroid of a detector image.
# NOTE(review): not called anywhere in this notebook; the detector centres of mass used
# in the fit come from scipy.ndimage.center_of_mass (intensity-weighted) instead.

def get_centre_of_mass(image, threshold=0.0):
    """Return the (x, y) centroid of the pixels of ``image`` strictly above ``threshold``.

    This is an *unweighted* centroid: every pixel above the threshold counts equally,
    regardless of its value.

    Parameters
    ----------
    image : 2-D array_like
        Image to analyse; rows are y, columns are x.
    threshold : float, default 0.0
        Only pixels with ``image > threshold`` contribute.

    Returns
    -------
    tuple
        ``(x_com, y_com)`` in (column, row) pixel units — note this is the reverse of
        scipy.ndimage.center_of_mass's (row, column) order. ``(None, None)`` when no
        pixel exceeds the threshold.
    """
    # Both index arrays always have the same length, so one emptiness check suffices.
    y_indices, x_indices = np.nonzero(np.asarray(image) > threshold)
    if x_indices.size == 0:
        return None, None
    return np.mean(x_indices), np.mean(y_indices)


def model_x(vars, Axx_dpos, Cxx_dslope, Axy_dpos, Cxy_dslope):
    """Linear descan model for the detector x-coordinate.

    ``vars`` is the triple ``(x_in, y_in, B)`` — scan position and propagation
    distance, as scalars or equal-length arrays. Returns
    ``x_in*(Axx_dpos + B*Cxx_dslope) + y_in*(Axy_dpos + B*Cxy_dslope)``.
    """
    x_scan, y_scan, distance = vars
    gain_from_x = Axx_dpos + distance * Cxx_dslope
    gain_from_y = Axy_dpos + distance * Cxy_dslope
    return x_scan * gain_from_x + y_scan * gain_from_y

def model_y(vars, Ayx_dpos, Cyx_dslope, Ayy_dpos, Cyy_dslope):
    """Linear descan model for the detector y-coordinate.

    ``vars`` is the triple ``(x_in, y_in, B)`` — scan position and propagation
    distance, as scalars or equal-length arrays. Returns
    ``x_in*(Ayx_dpos + B*Cyx_dslope) + y_in*(Ayy_dpos + B*Cyy_dslope)``.
    """
    x_scan, y_scan, distance = vars
    gain_from_x = Ayx_dpos + distance * Cyx_dslope
    gain_from_y = Ayy_dpos + distance * Cyy_dslope
    return x_scan * gain_from_x + y_scan * gain_from_y

# Flatten the per-camera-length lists into 1-D sample arrays
# (one entry per scan position per camera length).
# NOTE(review): Bxx_cl / Byy_cl here overwrite the sympy symbols of the same name.
x_ins = np.array(x_ins).flatten()
y_ins = np.array(y_ins).flatten()
Bxx_cl = np.array(Bs).flatten()
Byy_cl = np.array(Bs).flatten()
x_dets = np.array(x_dets).flatten()
y_dets = np.array(y_dets).flatten()

n_samples = x_ins.size
assert y_ins.size == Bxx_cl.size == Byy_cl.size == x_dets.size == y_dets.size == n_samples, \
    "scan positions, B values and detector centres of mass are misaligned"

# An all-zero detector frame gives a NaN centre of mass, and curve_fit (check_finite=True)
# would raise on it; fit only the samples whose centre of mass is finite.
finite = np.isfinite(x_dets) & np.isfinite(y_dets)
if not finite.all():
    print(f"Skipping {np.count_nonzero(~finite)} of {n_samples} samples with a non-finite centre of mass")

popt_x, pcov_x = curve_fit(model_x, (x_ins[finite], y_ins[finite], Bxx_cl[finite]), x_dets[finite], p0=np.zeros(4))
Axx_dpos, Cxx_dslope, Axy_dpos, Cxy_dslope = popt_x

popt_y, pcov_y = curve_fit(model_y, (x_ins[finite], y_ins[finite], Byy_cl[finite]), y_dets[finite], p0=np.zeros(4))
Ayx_dpos, Cyx_dslope, Ayy_dpos, Cyy_dslope = popt_y

print("\nFit ABCD Values:")
print(f"Axx_dpos = {Axx_dpos:.1f}  Cxx_dslope = {Cxx_dslope:.1f}")
print(f"Axy_dpos = {Axy_dpos:.1f}  Cxy_dslope = {Cxy_dslope:.1f}")
print(f"Ayx_dpos = {Ayx_dpos:.1f}  Cyx_dslope = {Cyx_dslope:.1f}")
print(f"Ayy_dpos = {Ayy_dpos:.1f}  Cyy_dslope = {Cyy_dslope:.1f}")

# Ground truth used by the simulation. NOTE(review): assumes params.json stores
# descan_error in the order Axx, Axy, Ayx, Ayy, Cxx, Cxy, Cyx, Cyy — confirm.
Axx, Axy, Ayx, Ayy, Cxx, Cxy, Cyx, Cyy = params_dict['descan_error']
print("\nActual ABCD Values:")
print(f"Axx_dpos = {Axx:.1f}  Cxx_dslope = {Cxx:.1f}")
print(f"Axy_dpos = {Axy:.1f}  Cxy_dslope = {Cxy:.1f}")
print(f"Ayx_dpos = {Ayx:.1f}  Cyx_dslope = {Cyx:.1f}")
print(f"Ayy_dpos = {Ayy:.1f}  Cyy_dslope = {Cyy:.1f}")





Fit ABCD Values:
Axx_dpos = 8.1  Cxx_dslope = -9.7
Axy_dpos = 12.1  Cxy_dslope = -6.1
Ayx_dpos = -14.1  Cyx_dslope = 10.1
Ayy_dpos = 6.1  Cyy_dslope = -7.7

Actual ABCD Values:
Axx_dpos = 8.0  Cxx_dslope = -10.0
Axy_dpos = 12.0  Cxy_dslope = -6.0
Ayx_dpos = -14.0  Cyx_dslope = 10.0
Ayy_dpos = 6.0  Cyy_dslope = -8.0
