In [1]:
import jax
import numpy as np
import jax.numpy as jnp
from jaxgym.ray import Ray
from jaxgym import CoordsXY
from microscope_calibration.model import Model, ModelParameters, create_stem_model
from jaxgym.run import run_to_end
from jaxgym.components import DescanError
from jaxgym.run import solve_model
from jaxgym.transfer import accumulate_transfer_matrices
from jaxgym import components as comp

import os

jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']='0.01'

def solve_model_fourdstem_wrapper(model: Model) -> tuple:
    # Unpack model components.
    PointSource = model.source
    ScanGrid = model.scan_grid
    Descanner = model.descanner
    Detector = model.detector

    ray = Ray(
        x=0.0,
        y=0.0,
        dx=0.0,
        dy=0.0,
        _one=1.0,
        z=PointSource.z,
        pathlength=jnp.zeros(1),
    )

    scan_coords = (0.0, 0.0, 0.0, 0.0)  # (scan_pos_x, scan_pos_y, scan_tilt_x, scan_tilt_y)

    def _solve_model(scan_pos, descanner, idx_one, idx_two):
        # Create a new Descanner with the current scan offsets.
        wrapped_descanner = comp.Descanner(
            z=ScanGrid.z,
            descan_error=descanner.descan_error,
            scan_pos_x=scan_pos[0],
            scan_pos_y=scan_pos[1],
            scan_tilt_x=scan_pos[2],
            scan_tilt_y=scan_pos[3],
        )

        # Make a new model each time:
        current_model = Model(PointSource, ScanGrid, wrapped_descanner, Detector)

        # via a single ray and it's jacobian, get the transfer matrices for the model
        transfer_matrices = solve_model(ray, current_model)

        total_tm = accumulate_transfer_matrices(
            transfer_matrices, idx_one, idx_two
        )
        return total_tm, total_tm


    model_jac_fn = jax.jacobian(_solve_model, has_aux=True)

    total_grad_tm, total_tm = model_jac_fn(scan_coords, Descanner, 0, 3)
    scangrid_to_det_grad_tm, scangrid_to_det_tm = model_jac_fn(scan_coords, Descanner, 1, 3)

    return (total_tm, total_grad_tm), (scangrid_to_det_tm, scangrid_to_det_grad_tm)


In [2]:
s = 0.01  # Scale for descan error

descan_error = DescanError(
        pxo_pxi=np.random.uniform(-0.1, -0.1),
        pxo_pyi=np.random.uniform(-0.4, -0.2),
        pyo_pxi=np.random.uniform(-0.04, -0.8),
        pyo_pyi=np.random.uniform(-0.1, -0.4),
        sxo_pxi=np.random.uniform(-0.02, -0.8),
        sxo_pyi=np.random.uniform(-0.3, -0.1),
        syo_pxi=np.random.uniform(-0.3, -0.1),
        syo_pyi=np.random.uniform(-0.5, -0.2),
        offpxi=np.random.uniform(-0.0, 0.0),
        offsxi=np.random.uniform(-0.0, 0.0),
        offpyi=np.random.uniform(-0.0, 0.0),
        offsyi=np.random.uniform(-0.0, 0.0)
    )

de = DescanError(*descan_error)

params = ModelParameters(
    semi_conv=0.001,
    defocus=0.00001,  # Distance from the crossover to the sample
    camera_length=1.0,  # Distance from the sample to detector
    scan_shape=(49, 49),  # YX!
    det_shape=(64, 64),  # YX!
    scan_step=(1e-9, 1e-9),  # YX!
    det_px_size=(1e-4, 1e-4),  # YX!
    scan_rotation=33.,
    descan_error=de,
    flip_y=False,
)

import sympy as sp

B0 = sp.symbols('B_pt_src_to_descan')
B1 = sp.symbols('B_descan_to_det')

# Define symbolic variables for DescanError parameters
pxo_pxi, pxo_pyi, pyo_pxi, pyo_pyi, txo_pxi, txo_pyi, tyo_pxi, tyo_pyi, offpxi, offpyi, offsxi, offsyi = sp.symbols(
    "pxo_pxi, pxo_pyi, pyo_pxi, pyo_pyi, txo_pxi, txo_pyi, tyo_pxi, tyo_pyi, offpxi, offpyi, offsxi, offsyi"
)

# Define symbolic matrix for scan coordinates
scan_pos_x, scan_pos_y, scan_tilt_x, scan_tilt_y = sp.symbols("sp_x, sp_y, sp_tilt_x, sp_tilt_y")


descan_sym = sp.Matrix([
    [1, 0, 0, 0, pxo_pxi*scan_pos_x + pxo_pyi*scan_pos_y - scan_pos_x],
    [0, 1, 0, 0, pyo_pxi*scan_pos_x + pyo_pyi*scan_pos_y - scan_pos_y],
    [0, 0, 0, 0, scan_pos_x*txo_pxi + scan_pos_y*txo_pyi - scan_tilt_x],
    [0, 0, 0, 0, scan_pos_x*tyo_pxi + scan_pos_y*tyo_pyi - scan_tilt_y],
    [0, 0, 0, 0, 1]
])

prop_sym0 = sp.Matrix([
    [1, 0, B0, 0, 0],
    [0, 1, 0, B0, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 1, 0],
    [0, 0, 0, 0, 1]
])

prop_sym1 = sp.Matrix([
    [1, 0, B1, 0, 0],
    [0, 1, 0, B1, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 1, 0],
    [0, 0, 0, 0, 1]
])

# Combine the symbolic matrices
transfer_matrix_symbolic = prop_sym1 * descan_sym * prop_sym0

# Differentiate total_sym with respect to scan_coords
total_sym_dspx = transfer_matrix_symbolic.diff(scan_pos_x)
total_sym_dspy = transfer_matrix_symbolic.diff(scan_pos_y)
total_sym_dstx = transfer_matrix_symbolic.diff(scan_tilt_x)
total_sym_dsty = transfer_matrix_symbolic.diff(scan_tilt_y)

final_matrix = scan_pos_x * total_sym_dspx + scan_pos_y * total_sym_dspy + scan_tilt_x * total_sym_dstx + scan_tilt_y * total_sym_dsty

final_matrix = final_matrix + sp.Matrix([
    [1, 0, B1, 0, 0],
    [0, 1, 0, B1, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 1, 0],
    [0, 0, 0, 0, 0]
])


print("\nTotal Symbolic Matrix: ray offset equation (0, 4 variable):")
display(transfer_matrix_symbolic)

print("\nTotal Symbolic Matrix from derivatives:")
display(final_matrix)

print(transfer_matrix_symbolic.equals(final_matrix))

total_tm_and_grad, scangrid_tm_and_grad = solve_model_fourdstem_wrapper(create_stem_model(params))
total_tm, total_tm_grad = total_tm_and_grad
scangrid_tm, scangrid_tm_grad = scangrid_tm_and_grad

# convert JAX arrays into Sympy matrices
total_tm = sp.Matrix(total_tm.tolist())
total_tm_grad = tuple(sp.Matrix(arr.tolist()) for arr in total_tm_grad)
scangrid_tm = sp.Matrix(scangrid_tm.tolist())
scangrid_tm_grad = tuple(sp.Matrix(arr.tolist()) for arr in scangrid_tm_grad)

final_matrix_jax_numerical = scan_pos_x * total_tm_grad[0] + scan_pos_y * total_tm_grad[1] + scan_tilt_x * total_tm_grad[2] + scan_tilt_y * total_tm_grad[3]

final_matrix_sympy_numerical = final_matrix.subs({
    B0: params['defocus'],
    B1: params['camera_length'],
    pxo_pxi: params['descan_error'].pxo_pxi,
    pxo_pyi: params['descan_error'].pxo_pyi,
    pyo_pxi: params['descan_error'].pyo_pxi,
    pyo_pyi: params['descan_error'].pyo_pyi,
    txo_pxi: params['descan_error'].sxo_pxi,
    txo_pyi: params['descan_error'].sxo_pyi,
    tyo_pxi: params['descan_error'].syo_pxi,
    tyo_pyi: params['descan_error'].syo_pyi,
    offpxi: params['descan_error'].offpxi,
    offpyi: params['descan_error'].offpyi,
    offsxi: params['descan_error'].offsxi,
    offsyi: params['descan_error'].offsyi
})


print(final_matrix_jax_numerical)
print(final_matrix_sympy_numerical)



Total Symbolic Matrix: ray offset equation (0, 4 variable):


Matrix([
[1, 0, B_pt_src_to_descan,                  0, B_descan_to_det*(-sp_tilt_x + sp_x*txo_pxi + sp_y*txo_pyi) + pxo_pxi*sp_x + pxo_pyi*sp_y - sp_x],
[0, 1,                  0, B_pt_src_to_descan, B_descan_to_det*(-sp_tilt_y + sp_x*tyo_pxi + sp_y*tyo_pyi) + pyo_pxi*sp_x + pyo_pyi*sp_y - sp_y],
[0, 0,                  0,                  0,                                                        -sp_tilt_x + sp_x*txo_pxi + sp_y*txo_pyi],
[0, 0,                  0,                  0,                                                        -sp_tilt_y + sp_x*tyo_pxi + sp_y*tyo_pyi],
[0, 0,                  0,                  0,                                                                                               1]])


Total Symbolic Matrix from derivatives:


Matrix([
[1, 0, B_descan_to_det,               0, -B_descan_to_det*sp_tilt_x + sp_x*(B_descan_to_det*txo_pxi + pxo_pxi - 1) + sp_y*(B_descan_to_det*txo_pyi + pxo_pyi)],
[0, 1,               0, B_descan_to_det, -B_descan_to_det*sp_tilt_y + sp_x*(B_descan_to_det*tyo_pxi + pyo_pxi) + sp_y*(B_descan_to_det*tyo_pyi + pyo_pyi - 1)],
[0, 0,               1,               0,                                                                             -sp_tilt_x + sp_x*txo_pxi + sp_y*txo_pyi],
[0, 0,               0,               1,                                                                             -sp_tilt_y + sp_x*tyo_pxi + sp_y*tyo_pyi],
[0, 0,               0,               0,                                                                                                                    0]])

False
Matrix([[0, 0, 0, 0, -1.0*sp_tilt_x - 1.24082042891629*sp_x - 0.353433914120824*sp_y], [0, 0, 0, 0, -1.0*sp_tilt_y - 0.663008361121168*sp_x - 1.79988774644879*sp_y], [0, 0, 0, 0, -1.0*sp_tilt_x - 0.140820428916292*sp_x - 0.105400821728097*sp_y], [0, 0, 0, 0, -1.0*sp_tilt_y - 0.113994563622735*sp_x - 0.490407760019323*sp_y], [0, 0, 0, 0, 0]])
Matrix([[1, 0, 1.00000000000000, 0, -1.0*sp_tilt_x - 1.24082042891629*sp_x - 0.353433914120824*sp_y], [0, 1, 0, 1.00000000000000, -1.0*sp_tilt_y - 0.663008361121168*sp_x - 1.79988774644879*sp_y], [0, 0, 1, 0, -sp_tilt_x - 0.140820428916292*sp_x - 0.105400821728097*sp_y], [0, 0, 0, 1, -sp_tilt_y - 0.113994563622735*sp_x - 0.490407760019323*sp_y], [0, 0, 0, 0, 0]])
