In [51]:
from jaxgym.ode import odedopri,  electron_equation_of_motion_DA
from jaxgym.field import schiske_lens_expansion_xyz, obtain_first_order_electrostatic_lens_properties
import jaxgym.components as comp
from jaxgym.ray import Ray
from jaxgym.run import run_to_end, calculate_derivatives
from jaxgym.taylor import poly_dict, order_indices, poly_dict_to_sympy_expr

import jax
import jax.numpy as jnp
import numpy as np
import optimistix as optx

from scipy.constants import h as h_planck, e, m_e
from daceypy import array, DA
import sympy as sp
import os

from scipy.integrate import simpson

# XLA client allocator settings.
# NOTE(review): these are set after `import jax`; they can only take effect if the
# XLA backend has not been initialised yet — confirm, or move above the imports.
os.environ.update({
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    "XLA_PYTHON_CLIENT_MEM_LIMIT_MB": "400",
})

# Run JAX on the CPU in double precision: the notebook compares aberration
# coefficients from different methods to many significant digits.
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

In [130]:
X, Y, Z = sp.symbols('X Y Z')

# ---- Schiske electrostatic lens parameters ----
# Length unit: with scale = 1 every length below is in metres.
scale = 1

z_init = -0.020*scale  # starting (object/source) plane: 20 mm before z = 0
a = 0.0004*scale       # length parameter `a` passed to schiske_lens_expansion_xyz (0.4 mm)
phi_0 = 60e4           # potential parameter of the lens, in volts
k = 0.40**(1/2)        # dimensionless lens parameter, k**2 = 0.40

# Symbolic field expansion plus callables for the potential, the field and the
# axial potential with its derivatives.
(
    phi_expansion_symbolic,
    E_lambda, phi_lambda,
    phi_lambda_axial,
    phi_lambda_prime,
    phi_lambda_double_prime,
    phi_lambda_quadruple_prime,
    phi_lambda_sextuple_prime,
) = schiske_lens_expansion_xyz(X, Y, Z, phi_0, a, k)

# Non-relativistic electron wavelength, lambda = h / sqrt(2 e m_e U), at the
# axial potential of the starting plane.
wavelength = h_planck/(2*abs(e)*m_e*phi_lambda_axial(z_init))**(1/2)*scale

In [131]:
# Paraxial (first-order) lens properties on 1000 z samples. The names follow the
# return order of obtain_first_order_electrostatic_lens_properties: sample
# positions, two first-order rays and their derivatives (presumably g/h are the
# fundamental rays — TODO confirm), real magnification, image plane and focal /
# principal plane positions.
(
    z_pos,
    g, g_,
    h, h_,
    mag_real,
    z_image,
    z_focal_real, z_focal_asymp,
    z_pi,
) = obtain_first_order_electrostatic_lens_properties(
    z_init,
    phi_lambda_axial,
    phi_lambda_prime,
    phi_lambda_double_prime,
    z_sampling=1000,
)

# Distance used later to turn input slopes into aperture coordinates (x_a = x' * R).
# NOTE(review): the value is |z_init| (distance from the starting plane to z = 0),
# not an image-to-aperture distance as the name suggests — confirm the intended plane.
R_image_to_aperture = np.abs(z_init)

# Axial potential at the image plane and at the object (starting) plane.
p_I = phi_lambda_axial(z_image)
p_O = phi_lambda_axial(z_init)


In [132]:
# Axial potential at the first sample point (presumably z_init — TODO confirm).
Uz0 = phi_lambda_axial(z_pos[0])

# Axial potential U and its 1st, 2nd and 4th derivatives sampled on z_pos; these
# arrays are the default inputs of the aberration kernels L_1, L_2, L_3.
# (U_val_ is not used by those kernels.)
U_val, U_val_, U_val__, U_val____ = (
    phi_lambda_axial(z_pos),
    phi_lambda_prime(z_pos),
    phi_lambda_double_prime(z_pos),
    phi_lambda_quadruple_prime(z_pos),
)


def L_1(U=None, U_pp=None, U_pppp=None):
    """First kernel of the spherical-aberration integrand F_020.

        L_1 = (U''**2 / U - U'''') / (32 * sqrt(U))

    All arguments are optional. Each one falls back to the matching notebook array
    sampled on ``z_pos``, so ``L_1()`` behaves exactly as before. Passing arrays
    explicitly lets the kernel be evaluated for another potential without
    rebinding notebook globals.

    Args:
        U: axial potential U(z) (default ``U_val``).
        U_pp: second derivative U''(z) (default ``U_val__``).
        U_pppp: fourth derivative U''''(z) (default ``U_val____``).

    Returns:
        The kernel, broadcast over the inputs.
    """
    U = U_val if U is None else U
    U_pp = U_val__ if U_pp is None else U_pp
    U_pppp = U_val____ if U_pppp is None else U_pppp
    return (1/(32*jnp.sqrt(U)))*((U_pp**2)/(U)-U_pppp)


def L_2(U=None, U_pp=None):
    """Second kernel of the spherical-aberration integrand F_020.

        L_2 = U'' / (8 * sqrt(U))

    Both arguments are optional. Each falls back to the notebook array sampled on
    ``z_pos``, so ``L_2()`` behaves exactly as before.

    Args:
        U: axial potential U(z) (default ``U_val``).
        U_pp: second derivative U''(z) (default ``U_val__``).

    Returns:
        The kernel, broadcast over the inputs.
    """
    U = U_val if U is None else U
    U_pp = U_val__ if U_pp is None else U_pp
    return (1/(8*jnp.sqrt(U)))*(U_pp)


def L_3(U=None):
    """Third kernel of the spherical-aberration integrand F_020.

        L_3 = sqrt(U) / 2

    Args:
        U: axial potential U(z). Optional; defaults to the notebook array
            ``U_val`` sampled on ``z_pos``, so ``L_3()`` behaves as before.

    Returns:
        The kernel, broadcast over the input.
    """
    U = U_val if U is None else U
    return 1/2*(jnp.sqrt(U))


def F_020():
    """Integrand of the third-order spherical-aberration integral.

    Evaluated pointwise as

        L_1/4 * h**4 + L_2/2 * h**2 * h_**2 + L_3/4 * h_**4

    ``h`` and ``h_`` are the first-order ray and its derivative returned by
    ``obtain_first_order_electrostatic_lens_properties``. They are read, together
    with the kernels L_1, L_2 and L_3, from notebook globals at call time.
    """
    ray_term = (L_1()/4)*h*h*h*h
    cross_term = (L_2()/2)*h*h*h_*h_
    slope_term = (L_3()/4)*h_*h_*h_*h_
    return ray_term + cross_term + slope_term


# Spherical aberration coefficient from the aberration integral:
#   Cs = 4 / sqrt(|U(z_0)|) * integral(F_020 dz) * M
# B holds the integral itself. It is evaluated once and reused for Cs, instead of
# running the Simpson quadrature of F_020 twice.
B = simpson(F_020(), x=z_pos)
Cs = 4/jnp.sqrt(abs(Uz0))*B*mag_real

print('Cs in metres (Aberration Integral Method):', Cs)

Cs in metres (Aberration Integral Method): -184.69636177930158


In [133]:
# Taylor-map order (also used by the JAX derivative cell) and number of DA
# variables: x, y, x', y' and optical path length -> DA(1) .. DA(5).
order = 4
DA.init(order, 5)

# Expansion point: the on-axis reference ray with zero slope and zero path length.
# Only the map *about* this ray is needed, so the zero values are intentional.
x0 = y0 = 0.
x0_slope = y0_slope = 0.0
opl = 0.

u0 = phi_lambda_axial(z_init)  # axial potential at the starting plane

# Seed coordinate i with DA variable i (1-based), so the propagated state is a
# Taylor polynomial in all five initial coordinates.
x = array([
    start_value + DA(var_number)
    for var_number, start_value in enumerate((x0, y0, x0_slope, y0_slope, opl), start=1)
])

# Integrate the DA equation of motion from z_init to z_image. The remaining
# positional arguments are solver settings, passed through unchanged.
# NOTE(review): see odedopri's signature for the meaning of 1e-1, 10000, 1e-15, 1e5.
with DA.cache_manager():
    zf, x_f = odedopri(
        electron_equation_of_motion_DA, z_init, x, z_image,
        1e-1, 10000, 1e-15, int(1e5),
        (phi_lambda, E_lambda, u0),
    )

# Coefficient of x'**3 in the final x: the third-order spherical aberration term.
Cs_daceypy = x_f[0].getCoefficient([0, 0, 3, 0])

In [134]:
# Compare the spherical aberration coefficient from the three methods.
# Cs comes from the aberration-integral cell and is not recomputed here.
# The path-length estimate uses the x'**4 coefficient of the optical path length,
# scaled by mag_real * 4/3.
print('Cs (Aberration Integral Method):', Cs)
print('Cs (DA) - x polynomial', Cs_daceypy)
print('Cs (DA) - opl polynomial', x_f[4].getCoefficient([0, 0, 4, 0, ]) * mag_real * 4/3)


Cs (Aberration Integral Method): -184.69636177930158
Cs (DA) - x polynomial -184.69636531168646
Cs (DA) - opl polynomial -184.69636531166043


In [144]:
# JAX ray-tracing model: point source -> numerically integrated electrostatic
# lens -> pixel detector.
# NOTE(review): this rebinds the float globals z_init / z_image to jnp arrays, so
# re-running earlier cells afterwards (e.g. the DA integration) sees arrays.
z_init, z_image = jnp.array(z_init), jnp.array(z_image)

PointSource = comp.InputPlane(z=z_init)
ElectrostaticLens = comp.ODE(
    z=z_init, z_end=z_image,
    phi_lambda=phi_lambda, E_lambda=E_lambda,
)
Detector = comp.Detector(
    z=z_image,
    det_pixel_size=(55e-6, 55e-6),  # 55e-6 per pixel (55 um with lengths in metres)
    det_shape=(256, 256),
)
model = [PointSource, ElectrostaticLens, Detector]

# Largest x and y pixel coordinates of the detector grid.
det_coords = Detector.get_coords()
print(jnp.max(det_coords[0]), jnp.max(det_coords[1]))

# Trace the on-axis reference ray, then take derivatives of the output ray with
# respect to its inputs up to `order` (positional Ray fields: see jaxgym.ray.Ray).
ray = Ray(0., 0., 0., 0., 0., z_init, 0.0)
ray_out = run_to_end(ray, model)
derivatives = calculate_derivatives(ray, model, order)

0.00704 0.00704


In [58]:
# Ray variables whose Taylor coefficients are extracted. The order matches the DA
# variables 1..5 used in the DA propagation cell.
selected_vars = [
    'x', 'y',        # position
    'dx', 'dy',      # slope
    'pathlength',    # optical path length
]
# Exponent multi-indices up to `order` over these variables (presumably starting
# with the all-zero index — TODO confirm against jaxgym.taylor.order_indices).
multi_indices = order_indices(order, n_vars=len(selected_vars))

In [59]:
# Collect, for every selected output variable, its Taylor coefficients from the JAX
# derivatives. multi_indices[1:] skips the first multi-index — presumably the
# all-zero (constant) term; the coefficient tables take the constant from ray_out
# instead. TODO confirm.
poly_dicts = poly_dict(derivatives, selected_vars, multi_indices[1:])

In [136]:
# Real symbols for the polynomial maps: input position (x, y), input slope (x', y'),
# output position (x_out, y_out), unused output slopes and the path length S.
# The name list is whitespace-separated only (it previously mixed in a stray comma).
x_var, y_var, dx_var, dy_var, x_out, y_out, dx_out, dy_out, opl_var = sp.symbols(
    "x y x' y' x_out y_out dx_out dy_out S", real=True)
x_a, y_a = sp.symbols("x_a y_a", real=True)
polynomials = poly_dict_to_sympy_expr(poly_dicts, selected_vars, sym_vars=[x_var, y_var, dx_var, dy_var, opl_var])

poly_opl = polynomials['pathlength']
print(polynomials['x'])
poly_x = sp.Eq(polynomials['x'], x_out)
poly_y = sp.Eq(polynomials['y'], y_out)
# NOTE(review): the x/y maps are cubic in the slopes, so sp.solve returns several
# roots. [0] takes whichever sympy lists first, which is not guaranteed to be the
# physical (real, small-slope) branch. The numerical Newton solve is unambiguous.
poly_dx = sp.solve(poly_x, dx_var)[0]
poly_dy = sp.solve(poly_y, dy_var)[0]


-23069095.9559152*x**3 - 1384309.95856267*x**2*x' - 27693.3774475891*x*x'**2 - 23069095.9559152*x*y**2 - 922797.959323745*x*y*y' - 9231.12581594256*x*y'**2 - 0.883574443486702*x - 184.696365322523*x'**3 - 461511.999238918*x'*y**2 - 18462.2516316465*x'*y*y' - 184.696365322523*x'*y'**2 - 3.97026161946812e-14*x'


Convert the polynomials from functions of the input slope to functions of the aperture position: everywhere replace x' with x_a / R and y' with y_a / R.

For this we simply assume x_a = x' · R and y_a = y' · R, where R is the distance from the object plane to the aperture (here R = |z_init|).

In [145]:
# Express the maps in aperture coordinates: x' = x_a / R, y' = y_a / R.
_slope_to_aperture = {dx_var: x_a / R_image_to_aperture, dy_var: y_a / R_image_to_aperture}
poly_x_xa, poly_y_xa, poly_opl_xa = (
    polynomials[key].subs(_slope_to_aperture) for key in ('x', 'y', 'pathlength')
)

In [146]:
# JAX-callable maps with inputs (x, y, x_a, y_a), i.e. aperture coordinates.
# (The `_ya` suffix on the y map is historical; it takes the same arguments.)
# NOTE: S is not an argument, so the path-length maps still contain the symbol S.
jax_poly_x_xa, jax_poly_y_ya, jax_poly_opl_xa = (
    sp.lambdify([x_var, y_var, x_a, y_a], expr, modules='jax')
    for expr in (poly_x_xa, poly_y_xa, poly_opl_xa)
)

# The same maps with inputs (x, y, x', y'), i.e. input slopes.
jax_poly_x, jax_poly_y, jax_poly_opl = (
    sp.lambdify([x_var, y_var, dx_var, dy_var], polynomials[key], modules='jax')
    for key in ('x', 'y', 'pathlength')
)

In [None]:
def input_slope_from_image_position(params, args):
    """Residual for root-finding the input slope that hits a target image point.

    Args:
        params: length-2 array (x', y') of input slopes being solved for.
        args: tuple (x_in, y_in, x_out, y_out) — fixed input position and the
            target output position.

    Returns:
        Array [x_map - x_out, y_map - y_out] from the slope-coordinate maps.
    """
    slope_x, slope_y = params
    obj_x, obj_y, target_x, target_y = args
    return jnp.stack([
        jax_poly_x(obj_x, obj_y, slope_x, slope_y) - target_x,
        jax_poly_y(obj_x, obj_y, slope_x, slope_y) - target_y,
    ])

def input_aperture_from_image_position(params, args):
    """Residual for root-finding the aperture position that hits a target image point.

    Args:
        params: length-2 array (x_a, y_a) of aperture coordinates being solved for.
        args: tuple (x_in, y_in, x_out, y_out) — fixed input position and the
            target output position.

    Returns:
        Array [x_map - x_out, y_map - y_out] from the aperture-coordinate maps.
    """
    aperture_x, aperture_y = params
    obj_x, obj_y, target_x, target_y = args
    return jnp.stack([
        jax_poly_x_xa(obj_x, obj_y, aperture_x, aperture_y) - target_x,
        jax_poly_y_ya(obj_x, obj_y, aperture_x, aperture_y) - target_y,
    ])


# Find the input ray that reaches a chosen detector point by Newton root-finding on
# the polynomial maps, once in slope and once in aperture coordinates.
solver = optx.Newton(rtol=1e-13, atol=1e-13)
params = jnp.array([0.0, 0.0])  # initial guess: the on-axis ray
# (x_in, y_in, x_out, y_out): source on axis, target 7e-3 off axis in x and y,
# just inside the largest detector coordinate printed by the model cell.
args = (0.0, 0.0, 7e-3, 7e-3)
# NOTE(review): at the on-axis guess the linear x' coefficient of the x map is
# ~4e-14 (see the printed x polynomial), so the first Newton step is near-singular.
# A guess based on the cubic x'**3 term may converge more reliably.
sol = optx.root_find(input_slope_from_image_position, solver, params, args).value
sol_a = optx.root_find(input_aperture_from_image_position, solver, params, args).value

print('Input Aperture Pos:', sol_a)
print('Input Slopes:', sol)

# Forward-map both solutions; each x / y value should reproduce the 7e-3 target.
print(jax_poly_x_xa(0, 0, sol_a[0], sol_a[1]))
print(jax_poly_y_ya(0, 0, sol_a[0], sol_a[1]))
print(jax_poly_x(0, 0, sol[0], sol[1]))
print(jax_poly_y(0, 0, sol[0], sol[1]))

# Optical path length along both solutions. It still contains the symbol S,
# because the input path length is not a lambdified argument.
print(jax_poly_opl_xa(0, 0, sol_a[0], sol_a[1]))
print(jax_poly_opl(0, 0, sol[0], sol[1]))

Input Aperture Pos: [-0.00053321 -0.00053321]
Input Slopes: [-0.0266606 -0.0266606]
0.006999999999999999
0.006999999999999999
0.006999999999999999
0.007000000000000001
1.0*S + 0.00031682252557148
1.0*S + 0.00031682252557148


In [151]:
# Mixed second derivatives of the optical path length S with respect to the input
# position (x, y) and the second pair of input coordinates: aperture (x_a, y_a) or
# slope (x', y'). Their determinants enter the amplitude prefactor below.
_position_syms = (x_var, y_var)


def _mixed_hessian(expr, second_syms):
    """2x2 matrix of d2(expr)/(dp dq) for p in (x, y) (rows) and q in second_syms (cols)."""
    return sp.Matrix(2, 2, lambda i, j: sp.diff(expr, _position_syms[i], second_syms[j]))


hess_xa = _mixed_hessian(poly_opl_xa, (x_a, y_a))
hess_xp = _mixed_hessian(poly_opl, (dx_var, dy_var))

# Individual entries, kept as named globals.
(dSdxodxa, dSdxodya), (dSdyodxa, dSdyodya) = hess_xa.tolist()
(dSdxodxp, dSdxodyp), (dSdyodxp, dSdyodyp) = hess_xp.tolist()

# Despite the det_ prefix, these hold sqrt(det(H)).
# NOTE(review): where det(H) < 0 the JAX-lambdified sqrt returns nan.
det_hess_xa = sp.sqrt(hess_xa.det())
det_hess_xp = sp.sqrt(hess_xp.det())

det_hess_xa_f = sp.lambdify([x_var, y_var, x_a, y_a], det_hess_xa, modules='jax')
det_hess_xp_f = sp.lambdify([x_var, y_var, dx_var, dy_var], det_hess_xp, modules='jax')

print('sqrt(Hessian det) (xa):', det_hess_xa)
print('sqrt(Hessian det) (xp):', det_hess_xp)


Hessian det (xa): 30660304012.1624*sqrt(0.173546124834627*x**4 + 0.585807097233285*x**3*x_a + 0.733057810197867*x**2*x_a**2 + 0.347092249669253*x**2*y**2 + 0.585807097233286*x**2*y*y_a + 0.233057810197868*x**2*y_a**2 - 6.78590013299994e-20*x**2 + 0.402880957064423*x*x_a**3 + 0.585807097233286*x*x_a*y**2 + x*x_a*y*y_a + 0.402880957064423*x*x_a*y_a**2 - 1.14529940830302e-19*x*x_a + 0.0820841192959234*x_a**4 + 0.233057810197868*x_a**2*y**2 + 0.402880957064423*x_a**2*y*y_a + 0.164168238591846*x_a**2*y_a**2 - 4.66691160458557e-20*x_a**2 + 0.173546124834627*y**4 + 0.585807097233285*y**3*y_a + 0.733057810197867*y**2*y_a**2 - 6.78590013299994e-20*y**2 + 0.402880957064423*y*y_a**3 - 1.14529940830302e-19*y*y_a + 0.0820841192959234*y_a**4 - 4.66691160458557e-20*y_a**2 + 4.97509386830106e-39)
Hessian det (xp): 361267514.015826*sqrt(0.500000000000001*x**4 + 0.0337551240508255*x**3*x' + 0.000844798823248176*x**2*x'**2 + x**2*y**2 + 0.0337551240508255*x**2*y*y' + 0.000268583133642369*x**2*y'**2 - 1.9

In [152]:
# Amplitude prefactor 1 / (i * lambda) * sqrt(det H), evaluated on the solved ray
# in aperture coordinates and in slope coordinates.
amplitude_xa, amplitude_xp = (
    1 / (1j * wavelength) * det_hess_xa_f(0, 0, sol_a[0], sol_a[1]),
    1 / (1j * wavelength) * det_hess_xp_f(0, 0, sol[0], sol[1]),
)

# Since x' = x_a / R, every aperture-Hessian entry carries a factor 1/R, so
# sqrt(det) scales by 1/R. Multiplying by R makes the two values comparable.
print('Amplitude (xa):', amplitude_xa * R_image_to_aperture)
print('Amplitude (xp):', amplitude_xp)

Amplitude (xa): -63090702606539.63j
Amplitude (xp): -63090702606539.6j


In [153]:
# Side-by-side listing of the Taylor coefficients of one output variable:
# left = JAX derivatives (poly_dicts, constant from ray_out), right = DA map x_f.
var = 'pathlength'

# Position of each output variable in selected_vars / the DA array x_f.
_VAR_INDEX = {'x': 0, 'y': 1, 'dx': 2, 'dy': 3, 'pathlength': 4}
if var not in _VAR_INDEX:
    raise ValueError(f"Unknown variable: {var}")
var_idx = _VAR_INDEX[var]

header = f"{'I':>6}  {'COEFFICIENT':>3}   {'ORDER':>16} {'EXPONENTS':>4}"


def _coefficient_table(table_header, entries, constant, coefficient_of):
    """Format a coefficient listing as a list of text rows.

    Row I = 1 is the constant (zeroth-order) term and the monomials in `entries`
    follow as I = 2, 3, ... Previously the first monomial was also numbered 1 and
    numbering restarted at 0.

    Args:
        table_header: header line placed first.
        entries: poly_dicts rows, each (exponent_1, ..., exponent_n, coefficient).
        constant: value of the zeroth-order term.
        coefficient_of: callable (entry, exponents) -> coefficient of that monomial.
    """
    rows = [table_header, f"{1:6d}   {constant: .16e}   {0} {'  0  0  0  0  0':15s}"]
    for row_number, entry in enumerate(entries, start=2):
        exponents = tuple(map(int, entry[:-1]))
        exponents_str = " ".join(f"{p:2d}" for p in exponents)
        coeff = coefficient_of(entry, exponents)
        rows.append(f"{row_number:6d}   {coeff: .16e}   {sum(exponents)}  {exponents_str}")
    rows.append('------------------------------------------------')
    return rows


print_jax = _coefficient_table(
    header, poly_dicts[var], getattr(ray_out, var),
    lambda entry, _exponents: entry[-1],
)
print_daceypy = _coefficient_table(
    header, poly_dicts[var], x_f[var_idx].getCoefficient([0, 0]),
    lambda _entry, exponents: x_f[var_idx].getCoefficient(list(exponents)),
)

# Print the two blocks side by side
for left, right in zip(print_jax, print_daceypy):
    print(f"{left:<60} {right}")


     I  COEFFICIENT              ORDER EXPONENTS                  I  COEFFICIENT              ORDER EXPONENTS
     1    3.7410593888927583e-02   0   0  0  0  0  0              1    3.7410593888934779e-02   0   0  0  0  0  0
     0    1.0000000000000000e+00   1   0  0  0  0  1              0    1.0000000000000000e+00   1   0  0  0  0  1
     1   -4.5055490371578915e-13   2   0  0  0  2  0              1    6.8988353441393202e-14   2   0  0  0  2  0
     2   -4.3252089410028560e-11   2   0  1  0  1  0              2    8.1939524609886405e-12   2   0  1  0  1  0
     3   -4.5055490371578915e-13   2   0  0  2  0  0              3    6.8988353441393202e-14   2   0  0  2  0  0
     4   -4.3252089410028560e-11   2   1  0  1  0  0              4    8.1939524609886405e-12   2   1  0  1  0  0
     5    4.7068796329324236e+01   2   0  2  0  0  0              5    4.7068796330550931e+01   2   0  2  0  0  0
     6    4.7068796329324236e+01   2   2  0  0  0  0              6    4.7068796330550931e+0