In [1]:
import os

# Limit the fraction of GPU memory XLA preallocates. This must happen before
# JAX initialises its backend; importing jax/jaxgym first risks the backend
# being created (e.g. by module-level array construction) and the setting
# then having no effect.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".2"

import json

import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import tqdm.auto as tqdm
from scipy.ndimage import rotate
from scipy.ndimage import zoom

import jaxgym.components as comp
from jaxgym.run import propagate, run_to_component
from jaxgym.stemoverfocus import project_frame_backward, project_frame_forward
from jaxgym.ray import Ray



In [None]:
# Load the 4D-STEM array and the acquisition parameters stored alongside it.
fourdstem_array = np.load('fourdstem_array.npy')
# Use a context manager so the file handle is closed deterministically.
with open('params.json') as params_file:
    params_dict = json.load(params_file)

# Unpack the parameters into named variables used by the following cells.
semi_conv = params_dict['semi_conv']
defocus = params_dict['defocus']
camera_length = params_dict['camera_length']
scan_shape = params_dict['scan_shape']
det_shape = params_dict['det_shape']
scan_step_yx = params_dict['scan_step_yx']
det_px_size = params_dict['det_px_size']
scan_rotation = params_dict['scan_rotation']
descan_error = params_dict['descan_error']

In [None]:
# Build the optical model: point source at the crossover plane (z = 0),
# scan grid and descanner at the defocus plane, detector at the camera length.
# Parameters come from the variables unpacked in the loading cell so there is
# a single source of truth.
crossover_z = jnp.zeros((1,))

PointSource = comp.PointSource(z=crossover_z, semi_conv=semi_conv)
ScanGrid = comp.ScanGrid(z=jnp.array([defocus]),
                         scan_step_yx=scan_step_yx,
                         scan_shape_yx=scan_shape,
                         scan_rotation=scan_rotation,
                         )

# Coordinates of the first scan position. Column 0 is read as y and column 1
# as x (presumably ScanGrid.coords is (n_scan, 2) in yx order — TODO confirm).
scan_coords = ScanGrid.coords
scan_y, scan_x = scan_coords[0, 0], scan_coords[0, 1]

# The descanner is configured with the offset of that same scan position.
Descanner = comp.Descanner(z=jnp.array([defocus]),
                           descan_error=descan_error,
                           offset_x=scan_x,
                           offset_y=scan_y,
                           )

Detector = comp.Detector(z=jnp.array([camera_length]),
                         det_shape_yx=det_shape,
                         det_pixel_size_yx=det_px_size,
                         rotation=0.0,
                         )

# Components in the order rays are propagated through them.
model = [PointSource, ScanGrid, Descanner, Detector]

In [None]:
def run_model_for_jacobians(ray, model):
    """Collect homogeneous per-step transfer matrices along ``model``.

    Starting from ``ray``, for every component after ``model[0]`` this
    propagates to the component's ``z`` and then applies its ``step``,
    recording the Jacobian (w.r.t. the input ray) of each of those two
    operations. Each Jacobian is then reduced to a square matrix:
    ``d(matrix_out)/d(matrix_in)`` with its last column overwritten by
    ``d(pathlength_out)/d(matrix_in)`` and its bottom-right entry set to 1.
    The pathlength-derivative column is a workaround used to carry shifts;
    it is kept exactly as-is.

    Args:
        ray: input ``Ray``; ``model[0].step`` is never applied here, so the
            ray is presumably already at ``model[0]``'s plane.
        model: ordered sequence of components exposing ``z`` and ``step``.

    Returns:
        Array stacking ``2 * (len(model) - 1)`` matrices, ordered
        propagate, step, propagate, step, ... from the source side.
    """

    def _to_homogeneous(ray_jac):
        # dr_out/dr_in block of the Jacobian pytree.
        abcd = ray_jac.matrix.matrix
        # dOPL_out/dr_in; row 0 is written into the last column.
        opl_grad = ray_jac.pathlength.matrix
        abcd = abcd.at[:, -1].set(opl_grad[0, :])
        return abcd.at[-1, -1].set(1.0)

    d_propagate_d_ray = jax.jacobian(propagate, argnums=1)

    jacobians = []
    current = ray
    for component in model[1:]:
        gap = (component.z - current.z).squeeze()

        # Free-space propagation from the previous plane to this component.
        jacobians.append(d_propagate_d_ray(gap, current))
        current = propagate(gap, current)

        # The component's own action on the ray.
        jacobians.append(jax.jacobian(component.step)(current))
        current = component.step(current)

    return jnp.array([_to_homogeneous(jac) for jac in jacobians])

In [None]:


# Input ray for the first scan position. The trailing 1.0 is the homogeneous
# coordinate (run_model_for_jacobians also pins the matrices' bottom-right
# entry to 1).
input_ray_positions = jnp.array([scan_x, scan_y, 0.0, 0.0, 1.0])

ray = Ray(
    z=PointSource.z,
    matrix=input_ray_positions,
    amplitude=jnp.ones(1),
    pathlength=jnp.zeros(1),
    wavelength=jnp.ones(1),
    blocked=jnp.zeros(1, dtype=float)
)

# One matrix per propagation and per component step, ordered source -> detector.
transfer_matrices = run_model_for_jacobians(ray, model)
assert transfer_matrices.ndim == 3 and transfer_matrices.shape[1] == transfer_matrices.shape[2], \
    "expected a stack of square transfer matrices"

# Compose in application order: M_total = T[-1] @ ... @ T[1] @ T[0].
total_transfer_matrix = transfer_matrices[-1]
for tm in reversed(transfer_matrices[:-1]):
    total_transfer_matrix = total_transfer_matrix @ tm

# Inverse map: detector-plane rays back to the point-source plane.
detector_to_point_source = jnp.linalg.inv(total_transfer_matrix)
assert bool(jnp.all(jnp.isfinite(detector_to_point_source))), \
    "total transfer matrix is not invertible"

In [None]:
def loss(det_coords, tilts, inv_matrix, scan_pos):
    """Residual between back-projected detector rays and a scan position.

    Each detector ray is built as the row vector
    ``[det_coords[:, 1], det_coords[:, 0], tilts[:, 0], tilts[:, 1], 1]``,
    mapped to the point-source plane with ``inv_matrix`` and its first two
    entries (x, y) are compared against ``scan_pos``.

    Args:
        det_coords: ``(n_rays, 2)`` array; column 1 is used as x, column 0 as y.
        tilts: ``(n_rays, 2)`` array placed in ray-vector slots 2 and 3
            (presumably the x/y slopes — TODO confirm).
        inv_matrix: square matrix acting on column ray vectors
            (e.g. the detector-to-point-source transfer matrix).
        scan_pos: a single ``(x, y)`` position shared by every ray, or
            anything broadcastable against ``(2, n_rays)`` (e.g. per-ray
            positions of shape ``(2, n_rays)``).

    Returns:
        Scalar Frobenius norm of the ``(2, n_rays)`` position residuals.
    """
    n_rays = det_coords.shape[0]

    # Rays are stored one per row: shape (n_rays, 5).
    rays_det_matrix = jnp.array([det_coords[:, 1],
                                 det_coords[:, 0],
                                 tilts[:, 0],
                                 tilts[:, 1],
                                 jnp.ones(n_rays)]).T

    # Apply M to every row vector: (M @ r.T).T == r @ M.T. Writing
    # `inv_matrix @ rays_det_matrix` only type-checks when n_rays == 5 and
    # then mixes different rays together.
    rays_pt_source = rays_det_matrix @ inv_matrix.T

    rays_pt_source_x = rays_pt_source[:, 0]
    rays_pt_source_y = rays_pt_source[:, 1]

    # Shape (2, n_rays).
    start_position = jnp.array([rays_pt_source_x, rays_pt_source_y])

    # A single (x, y) position must broadcast along the ray axis, not the
    # coordinate axis.
    scan_pos = jnp.asarray(scan_pos)
    if scan_pos.ndim == 1:
        scan_pos = scan_pos[:, None]

    error = jnp.linalg.norm(start_position - scan_pos)

    return error


In [None]:
# Detector pixel coordinates. `loss` reads column 0 as y and column 1 as x, so
# this is presumably (n_pixels, 2) in (y, x) order — TODO confirm against
# Detector.get_coords.
det_coords = Detector.get_coords()