In [1]:
import os

# XLA allocator settings are read when the JAX backend initialises, so they are set
# before anything (including the jaxgym modules below) imports JAX.
# NOTE(review): XLA_PYTHON_CLIENT_* variables are documented for the GPU allocator;
# with the CPU platform selected below they are presumably inert.
# XLA_PYTHON_CLIENT_MEM_LIMIT_MB is not among the variables in JAX's GPU memory
# docs — TODO confirm it has any effect.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_MEM_LIMIT_MB"] = "400"

import jax

# Configure JAX before importing modules that may create arrays at import time,
# so they already see the CPU platform and 64-bit precision.
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np
import optimistix as optx
import sympy as sp
from daceypy import array, DA
from scipy.constants import h as h_planck, e, m_e
from scipy.integrate import simpson

import jaxgym.components as comp
from jaxgym.field import schiske_lens_expansion_xyz, obtain_first_order_electrostatic_lens_properties
from jaxgym.ode import odedopri, electron_equation_of_motion_DA
from jaxgym.ray import Ray
from jaxgym.run import run_to_end, calculate_derivatives
from jaxgym.taylor import poly_dict, order_indices, poly_dict_to_sympy_expr

In [2]:
X, Y, Z = sp.symbols('X Y Z')

# Length scale factor applied to the lens geometry and the wavelength.
# scale = 1 keeps all lengths in SI metres (1e3 would give mm, 1e6 um).
scale = 1

# Schiske electrostatic lens parameters.
z_init = -0.020*scale  # z of the input plane / start of ray integration [m * scale]
a = 0.0004*scale  # Schiske field length parameter a [m * scale]
phi_0 = 1  # Volts
k = 0.40**(1/2)  # Unitless

(
    phi_expansion_symbolic,
    E_lambda, phi_lambda,
    phi_lambda_axial,
    phi_lambda_prime,
    phi_lambda_double_prime,
    phi_lambda_quadruple_prime,
    phi_lambda_sextuple_prime
) = schiske_lens_expansion_xyz(X, Y, Z, phi_0, a, k)

# Non-relativistic de Broglie wavelength at the input plane,
# lambda = h / sqrt(2 e m_e phi(z_init)), expressed in the same length unit as above.
wavelength = h_planck/(2*abs(e)*m_e*phi_lambda_axial(z_init))**(1/2)*scale

In [3]:
# First-order (paraxial) lens properties sampled on a z grid of 1000 points.
# `h` here is a ray quantity from the lens solve; Planck's constant was imported
# as `h_planck` so the two names do not collide.
(
    z_pos, g, g_, h, h_,
    mag_real, z_image, z_focal_real, z_focal_asymp, z_pi,
) = obtain_first_order_electrostatic_lens_properties(
    z_init,
    phi_lambda_axial,
    phi_lambda_prime,
    phi_lambda_double_prime,
    z_sampling=1000,
)

# NOTE(review): R combines |z_pi| with the image position; its physical meaning is
# not established in this cell — TODO document.
R = abs(z_pi) + z_image

In [4]:
# Axial potential U(z) and its derivatives (first, second and fourth, judging by the
# phi_lambda_* helper names) sampled on the z grid z_pos.
U_val = phi_lambda_axial(z_pos)
U_val_ = phi_lambda_prime(z_pos)
U_val__ = phi_lambda_double_prime(z_pos)
U_val____ = phi_lambda_quadruple_prime(z_pos)

# Axial potential at the first grid point z_pos[0]; used to normalise Cs below.
Uz0 = phi_lambda_axial(z_pos[0])


def L_1():
    """Coefficient array L1 on the z grid for the F_020 aberration integrand.

    Computes (U''^2 / U - U'''') / (32 sqrt(U)) from the module-level samples
    U_val, U_val__ and U_val____.
    NOTE(review): presumably the L1 coefficient of the electrostatic third-order
    aberration integrals — TODO confirm against the reference used.
    """
    prefactor = 1 / (32 * jnp.sqrt(U_val))
    bracket = U_val__**2 / U_val - U_val____
    return prefactor * bracket


def L_2():
    """Coefficient array L2 on the z grid: U'' / (8 sqrt(U)).

    Reads the module-level samples U_val and U_val__.
    """
    prefactor = 1 / (8 * jnp.sqrt(U_val))
    return prefactor * U_val__


def L_3():
    """Coefficient array L3 on the z grid: sqrt(U) / 2, from the module-level U_val."""
    root_u = jnp.sqrt(U_val)
    return 0.5 * root_u


def F_020():
    """Integrand on the z grid whose Simpson integral over z_pos yields Cs.

    Combines L_1(), L_2() and L_3() with the module-level ray arrays h and h_:
    L1/4 * h^4 + L2/2 * h^2 h_^2 + L3/4 * h_^4.
    """
    pure_h = (L_1() / 4) * h * h * h * h
    mixed = (L_2() / 2) * h * h * h_ * h_
    pure_h_slope = (L_3() / 4) * h_ * h_ * h_ * h_
    return pure_h + mixed + pure_h_slope


# Integrate the F_020 integrand over z once, then rescale it by 4/sqrt(|U(z_pos[0])|)
# and the real magnification to obtain Cs.
B = simpson(F_020(), x=z_pos)
Cs = 4/jnp.sqrt(abs(Uz0))*B*mag_real

print('Cs in metres (Aberration Integral Method):', Cs)

Cs in metres (Aberration Integral Method): -184.69636175520125


In [5]:
order = 4

# Initialise the DA engine: truncation order `order`, 5 independent DA variables
# (x, y, x', y', opl, in the order used to build `x` below).
DA.init(order, 5)

# Expansion point of the map: DA variables are added on top of these values, so
# all-zero values expand the map about the on-axis ray.
x0 = 0.
y0 = 0.
x0_slope = 0.0
y0_slope = 0.0
opl = 0.

u0 = phi_lambda_axial(z_init)  # initial potential
x = array([
    value + DA(var_number)
    for var_number, value in enumerate((x0, y0, x0_slope, y0_slope, opl), start=1)
])

# Integrate the equation of motion with DA arithmetic so the final state x_f is a
# Taylor expansion in the initial ray.
# NOTE(review): the meaning of the positional numeric arguments is defined by
# jaxgym.ode.odedopri — TODO name them.
with DA.cache_manager():
    zf, x_f = odedopri(
        electron_equation_of_motion_DA, z_init, x, z_image,
        1e-1, 10000, 1e-15, int(1e5),
        (phi_lambda, E_lambda, u0),
    )

# Coefficient of x'**3 in the output x (DA variables 1-4 are x, y, x', y').
Cs_daceypy = x_f[0].getCoefficient([0, 0, 3, 0])

In [6]:
# Cs from the aberration integral is already available from the Cs cell; it is not
# recomputed here.
# From the OPL polynomial: d/dx' of c*x'**4 is 4c*x'**3, so scaling the x'**4
# coefficient c by mag_real*4/3 gives a value comparable with Cs.
Cs_opl = x_f[4].getCoefficient([0, 0, 4, 0]) * mag_real * 4/3

print('Cs (Aberration Integral Method):', Cs)
print('Cs (DA) - x polynomial', Cs_daceypy)
print('Cs (DA) - opl polynomial', Cs_opl)


Cs (Aberration Integral Method): -184.69636175520125
Cs (DA) - x polynomial -184.69636531277365
Cs (DA) - opl polynomial -184.69636531495235


In [7]:
# Promote the axial limits to JAX scalars for the jaxgym components.
# NOTE(review): this rebinds the float z_init/z_image used by earlier cells; re-running
# those cells afterwards will see JAX arrays.
z_init = jnp.array(z_init)
z_image = jnp.array(z_image)

# Optical model: input plane at z_init -> ODE-integrated electrostatic lens from
# z_init to z_image -> detector at z_image.
PointSource = comp.InputPlane(z=z_init)
ElectrostaticLens = comp.ODE(z=z_init, z_end=z_image, phi_lambda=phi_lambda, E_lambda=E_lambda)
# NOTE(review): pixel size presumably in the notebook length unit (metres for scale = 1) — confirm.
Detector = comp.Detector(z=z_image, det_pixel_size=(55e-6, 55e-6), det_shape = (256, 256))
model = [PointSource, ElectrostaticLens, Detector]

# Reference ray propagated through the model.
# NOTE(review): positional meaning of Ray's arguments is defined in jaxgym.ray — TODO confirm.
ray = Ray(0., 0., 0., 0., 0., z_init, 0.0)
ray_out = run_to_end(ray, model)

# Derivatives of the traced ray up to `order` (set in the DA cell); presumably the
# Taylor-map coefficients later converted by poly_dict — TODO confirm.
derivatives = calculate_derivatives(ray, model, order)

In [None]:
# Output variables to expand, in the same order as the DA state (x, y, x', y', opl).
selected_vars = 'x y dx dy pathlength'.split()
multi_indices = order_indices(order, n_vars=len(selected_vars))

In [9]:
# multi_indices[0] is skipped; presumably the all-zero (constant) multi-index, since the
# constant term is taken from ray_out in the coefficient-table cell — TODO confirm.
non_constant_indices = multi_indices[1:]
poly_dicts = poly_dict(derivatives, selected_vars, non_constant_indices)

In [10]:
# Input-ray symbols (x, y, x', y', S) and output symbols for the transfer equations.
x_var, y_var, dx_var, dy_var, x_out, y_out, dx_out, dy_out, opl_var = sp.symbols("x y x' y' x_out y_out dx_out, dy_out S", real=True)
polynomials = poly_dict_to_sympy_expr(
    poly_dicts,
    selected_vars,
    sym_vars=[x_var, y_var, dx_var, dy_var, opl_var],
)

# One equation per transverse output component: output symbol = polynomial in the input ray.
x_out_eq, y_out_eq, dx_out_eq, dy_out_eq = (
    sp.Eq(out_symbol, polynomials[key])
    for out_symbol, key in ((x_out, 'x'), (y_out, 'y'), (dx_out, 'dx'), (dy_out, 'dy'))
)
display(x_out_eq, y_out_eq, dx_out_eq, dy_out_eq)

# Invert the position equations for the input slopes x' and y'.
# NOTE(review): this rebinds dx_out/dy_out from sympy Symbols to lists of solutions;
# code that needs the Symbols must re-run this cell from the top.
dx_out = sp.solve(x_out_eq, dx_var)
dy_out = sp.solve(y_out_eq, dy_var)
S_opl = polynomials['pathlength']

display(dy_out[0].simplify())

Eq(x_out, -23069095.9559152*x**3 - 1384309.95856267*x**2*x' - 27693.3774475891*x*x'**2 - 23069095.9559152*x*y**2 - 922797.959323745*x*y*y' - 9231.12581594256*x*y'**2 - 0.883574443486702*x - 184.696365322523*x'**3 - 461511.999238918*x'*y**2 - 18462.2516316465*x'*y*y' - 184.696365322523*x'*y'**2 - 3.97026161946812e-14*x')

Eq(y_out, -23069095.9559152*x**2*y - 461511.999238918*x**2*y' - 922797.959323745*x*x'*y - 18462.2516316465*x*x'*y' - 9231.12581594256*x'**2*y - 184.696365322523*x'**2*y' - 23069095.9559152*y**3 - 1384309.95856267*y**2*y' - 27693.3774475891*y*y'**2 - 0.883574443486702*y - 184.696365322523*y'**3 - 3.97026161946812e-14*y')

Eq(dx_out, -1304694783.91109*x**3 - 78294402.5492511*x**2*x' - 1566362.58817763*x*x'**2 - 1304694783.91109*x*y**2 - 52187182.9264694*x*y*y' - 522024.349709946*x*y'**2 - 106.544172191599*x - 10447.0514652327*x'**3 - 26107219.6227816*x'*y**2 - 1044338.23846768*x'*y*y' - 10447.0514652327*x'*y'**2 - 1.13179197789377*x')

Eq(dy_out, -1304694783.91109*x**2*y - 26107219.6227816*x**2*y' - 52187182.9264694*x*x'*y - 1044338.23846768*x*x'*y' - 522024.349709946*x'**2*y - 10447.0514652327*x'**2*y' - 1304694783.91109*y**3 - 78294402.5492511*y**2*y' - 1566362.58817763*y*y'**2 - 106.544172191599*y - 10447.0514652327*y'**3 - 1.13179197789377*y')

(14.5575945808104*x**2 + 0.582359667236472*x*x' + 0.00582592611101888*x'**2 + 0.00614249111504384*y**2 - (49.9800079971435*y + 57.2155099431972*(0.333427112804452*x**2*y + 0.0133375776782558*x*x'*y + 0.000133421250430716*x'**2*y + y**3 - 0.000133421250429566*y*(2498.76059246218*x**2 + 99.9600159938562*x*x' + x'**2 + 7495.05793546798*y**2 + 2.14961545807093e-16) + 1.27706641041575e-8*y + 1.44534104605411e-8*y_out + (0.0164712872897871*(x**2 + 0.0400038388212933*x*x' + 0.000400198403567203*x'**2 + 0.000421944098040811*y**2 + 8.60272674603369e-20)**3 + (0.333427112804452*x**2*y + 0.0133375776782558*x*x'*y + 0.000133421250430716*x'**2*y + y**3 - 0.000133421250429566*y*(2498.76059246218*x**2 + 99.9600159938562*x*x' + x'**2 + 7495.05793546798*y**2 + 2.14961545807093e-16) + 1.27706641041575e-8*y + 1.44534104605411e-8*y_out)**2)**0.5)**(1/3))*(0.333427112804452*x**2*y + 0.0133375776782558*x*x'*y + 0.000133421250430716*x'**2*y + y**3 - 0.000133421250429566*y*(2498.76059246218*x**2 + 99.96001599

In [11]:

# x'-derivative of the path length scaled by mag_real/3, displayed next to the x
# polynomial so their x'**3 coefficients can be compared.
dS_opl_dx = sp.diff(S_opl, dx_var) * mag_real / 3

# Mixed second derivatives d^2 S / (d position d slope); the matrix row selects the
# slope (x', y') and the column selects the position (x, y).
dSdxdx_prime, dSdydx_prime, dSdxdy_prime, dSdydy_prime = (
    sp.diff(S_opl, position, slope)
    for slope in (dx_var, dy_var)
    for position in (x_var, y_var)
)
Hessian = sp.Matrix([[dSdxdx_prime, dSdydx_prime], [dSdxdy_prime, dSdydy_prime]])

display(dS_opl_dx)
display(polynomials['x'])
display(sp.sqrt(sp.det(Hessian)).simplify())

-43438535.6633821*x**3 - 2199428.57645442*x**2*x' - 35849.1228341767*x*x'**2 - 43438535.6633822*x*y**2 - 1466260.60217105*x*y*y' - 11949.7076113923*x*y'**2 + 1.27388136100571e-11*x - 184.696365307494*x'**3 - 733167.974283363*x'*y**2 - 23899.4152227845*x'*y*y' - 184.696365307494*x'*y'**2 + 2.65399198874395e-13*x'

-23069095.9559152*x**3 - 1384309.95856267*x**2*x' - 27693.3774475891*x*x'**2 - 23069095.9559152*x*y**2 - 922797.959323745*x*y*y' - 9231.12581594256*x*y'**2 - 0.883574443486702*x - 184.696365322523*x'**3 - 461511.999238918*x'*y**2 - 18462.2516316465*x'*y*y' - 184.696365322523*x'*y'**2 - 3.97026161946812e-14*x'

361267514.015826*sqrt(0.500000000000001*x**4 + 0.0337551240508255*x**3*x' + 0.000844798823248176*x**2*x'**2 + x**2*y**2 + 0.0337551240508255*x**2*y*y' + 0.000268583133642369*x**2*y'**2 - 1.95507106236636e-19*x**2 + 9.28585314015699e-6*x*x'**3 + 0.0337551240508255*x*x'*y**2 + 0.00115243137921162*x*x'*y*y' + 9.285853140157e-6*x*x'*y'**2 - 6.5993948836045e-21*x*x' + 3.78385259246288e-8*x'**4 + 0.000268583133642369*x'**2*y**2 + 9.285853140157e-6*x'**2*y*y' + 7.56770518492574e-8*x'**2*y'**2 - 5.37829537713126e-23*x'**2 + 0.500000000000001*y**4 + 0.0337551240508255*y**3*y' + 0.000844798823248176*y**2*y'**2 - 1.95507106236636e-19*y**2 + 9.28585314015699e-6*y*y'**3 - 6.5993948836045e-21*y*y' + 3.78385259246288e-8*y'**4 - 5.37829537713126e-23*y'**2 + 1.43336357208836e-38)

In [12]:
jax_poly_x = sp.lambdify([x_var, y_var, dx_var, dy_var], polynomials['x'], modules='jax')
jax_poly_y = sp.lambdify([x_var, y_var, dx_var, dy_var], polynomials['y'], modules='jax')

# The path-length polynomial also contains the input path length S (opl_var); its
# (0, 0, 0, 0, 1) coefficient is 1 in the coefficient table. Lambdifying over only
# (x, y, x', y') would leave S unbound and the function would raise NameError when
# called, so S is passed as a fifth argument that defaults to 0.0, the initial path
# length used for the DA run.
_jax_poly_opl_full = sp.lambdify(
    [x_var, y_var, dx_var, dy_var, opl_var], polynomials['pathlength'], modules='jax'
)


def jax_poly_opl(x, y, dx, dy, opl=0.0):
    """Evaluate the output path-length polynomial for an input ray.

    Parameters
    ----------
    x, y, dx, dy : array-like
        Input ray position and slopes.
    opl : array-like, optional
        Input path length S; defaults to 0.0 so 4-argument calls keep working.
    """
    return _jax_poly_opl_full(x, y, dx, dy, opl)

In [41]:
def image_position(params, args):
    """Residuals of the polynomial ray map, for root-finding the input slopes.

    Parameters
    ----------
    params : sequence of 2 values
        Input slopes (dx, dy) being solved for.
    args : sequence of 4 values
        (x_in, y_in, x_out, y_out): input position and target output position.

    Returns
    -------
    tuple
        (res_x, res_y): mapped output position minus the target; both are zero
        at the root.
    """
    dx, dy = params[0], params[1]
    x_in, y_in = args[0], args[1]
    x_out, y_out = args[2], args[3]
    # Both maps take the full input ray in the order (x, y, x', y').
    res_x = jax_poly_x(x_in, y_in, dx, dy) - x_out
    res_y = jax_poly_y(x_in, y_in, dx, dy) - y_out
    return res_x, res_y

solver = optx.Newton(rtol=1e-8, atol=1e-8)
params = jnp.array([0.0, 0.0])
# args = (x_in, y_in, x_out, y_out): solve for the input slopes that send a ray from
# (x_in, y_in) to the target output position (x_out, y_out).
args = jnp.array([0.0, 0.0, 1e-6, 1e-6])
sol = optx.root_find(image_position, solver, params, args)

print(sol.value)

# Evaluate the mixed-derivative matrix at the solved ray. sol.value holds only
# (dx, dy); JAX clamps out-of-bounds indices such as [2]/[3] to the last element
# instead of raising, so the indices must be [0]/[1].
hess = Hessian.subs({
    x_var: float(args[0]),
    y_var: float(args[1]),
    dx_var: float(sol.value[0]),
    dy_var: float(sol.value[1]),
})
hess = np.array(hess).astype(np.float64)

[-0.0013937 -0.0013937]


In [42]:
# 1/(i*lambda) * sqrt(det(hess)) at the solved ray.
# NOTE(review): presumably the amplitude prefactor of the propagator — TODO confirm.
# np.emath.sqrt matches np.sqrt for det >= 0 but returns a complex root (instead of
# nan plus a RuntimeWarning) when the determinant is negative.
(1 / (wavelength * 1j)) * np.emath.sqrt(np.linalg.det(hess))

-222582515.0322084j

In [15]:
var = 'pathlength'

# Position of each output variable in the DA state vector (x, y, x', y', opl),
# matching the order of `selected_vars`.
var_index = {'x': 0, 'y': 1, 'dx': 2, 'dy': 3, 'pathlength': 4}
if var not in var_index:
    raise ValueError(f"Unknown variable: {var}")
var_idx = var_index[var]


def format_coefficient_table(constant, terms):
    """Format a DACE-style coefficient listing as a list of text lines.

    Parameters
    ----------
    constant : float
        Zeroth-order coefficient, listed as monomial number 1.
    terms : iterable of (tuple[int, ...], float)
        ``(exponents, coefficient)`` pairs for the non-constant monomials.

    Returns
    -------
    list[str]
        Header, constant row, one row per term, then a dashed separator line.
    """
    lines = [
        f"{'I':>6}  {'COEFFICIENT':>3}   {'ORDER':>16} {'EXPONENTS':>4}",
        f"{1:6d}   {constant: .16e}   {0} {'  0  0  0  0  0':15s}",
    ]
    # The constant term is monomial 1, so the remaining terms are numbered from 2.
    for idx, (exponents, coeff) in enumerate(terms, start=2):
        exponents_str = " ".join(f"{p:2d}" for p in exponents)
        lines.append(f"{idx:6d}   {coeff: .16e}   {sum(exponents)}  {exponents_str}")
    lines.append('------------------------------------------------')
    return lines


# Exponent tuples of the non-constant monomials, shared by both columns so the JAX
# and DA coefficients are listed for the same monomials in the same order.
term_exponents = [tuple(map(int, entry[:-1])) for entry in poly_dicts[var]]

print_jax = format_coefficient_table(
    getattr(ray_out, var),
    [(exponents, entry[-1]) for exponents, entry in zip(term_exponents, poly_dicts[var])],
)
print_daceypy = format_coefficient_table(
    x_f[var_idx].getCoefficient([0, 0]),
    [(exponents, x_f[var_idx].getCoefficient(list(exponents))) for exponents in term_exponents],
)

# Print the two blocks side by side
for left, right in zip(print_jax, print_daceypy):
    print(f"{left:<60} {right}")


     I  COEFFICIENT              ORDER EXPONENTS                  I  COEFFICIENT              ORDER EXPONENTS
     1    3.7410593889031396e-02   0   0  0  0  0  0              1    3.7410593889038488e-02   0   0  0  0  0  0
     0    1.0000000000000000e+00   1   0  0  0  0  1              0    1.0000000000000000e+00   1   0  0  0  0  1
     1   -4.5055490371578915e-13   2   0  0  0  2  0              1    1.3498661281634294e-13   2   0  0  0  2  0
     2   -4.3252089410028560e-11   2   0  1  0  1  0              2    2.0658530441863832e-11   2   0  1  0  1  0
     3   -4.5055490371578915e-13   2   0  0  2  0  0              3    1.3498661281634294e-13   2   0  0  2  0  0
     4   -4.3252089410028560e-11   2   1  0  1  0  0              4    2.0658530441863832e-11   2   1  0  1  0  0
     5    4.7068796329324236e+01   2   0  2  0  0  0              5    4.7068796331136703e+01   2   0  2  0  0  0
     6    4.7068796329324236e+01   2   2  0  0  0  0              6    4.7068796331136703e+0