In [2]:
from jaxgym.ode import odedopri, electron_equation_of_motion, electron_equation_of_motion_DA
from jaxgym.field import schiske_lens_expansion_xyz, obtain_first_order_electrostatic_lens_properties
import jaxgym.components as comp
from jaxgym.ray import Ray, ray_matrix
from jaxgym.run import run_to_end

import jax.numpy as jnp
import sympy as sp

from scipy.constants import h as h_planck, e, m_e
from daceypy import array, DA

import jax

ModuleNotFoundError: No module named 'jaxgym'

In [2]:
# Symbolic coordinates handed to the Schiske potential expansion.
X, Y, Z = sp.symbols('X Y Z')

scale = 1e3  # length scale factor: 1e3 converts metres to millimetres
# Set Parameters for Schiske Electrostatic Lens
# Define lens parameters

z_init = -0.020*scale  # start plane: -0.020 m expressed in mm
a = 0.0004*scale  # Schiske lens parameter `a`: 0.0004 m expressed in mm
phi_0 = 1  # Volts
k = 0.40**(1/2)  # Unitless

# schiske_lens_expansion_xyz returns eight objects; judging by their names and
# how they are used below, they are the symbolic potential expansion plus
# callables for the field, the potential, and the axial potential with its
# derivatives — TODO confirm against jaxgym.field.
(
    phi_expansion_symbolic,
    E_lambda, phi_lambda,
    phi_lambda_axial,
    phi_lambda_prime,
    phi_lambda_double_prime,
    phi_lambda_quadruple_prime,
    phi_lambda_sextuple_prime
) = schiske_lens_expansion_xyz(X, Y, Z, phi_0, a, k)

# Non-relativistic de Broglie wavelength h / sqrt(2 |e| m_e phi), evaluated with
# the axial potential at z_init and multiplied by `scale` (metres -> mm).
wavelength = h_planck/(2*abs(e)*m_e*phi_lambda_axial(z_init))**(1/2)*scale

In [3]:
# First-order properties of the lens, computed from the axial potential and its
# first two derivatives starting at z_init, with z_sampling=1000.
(
    z_pos,
    g, g_,
    h, h_,
    mag_real,
    z_image,
    z_focal_real,
    z_focal_asymp,
    z_pi,
) = obtain_first_order_electrostatic_lens_properties(
    z_init,
    phi_lambda_axial,
    phi_lambda_prime,
    phi_lambda_double_prime,
    z_sampling=1000,
)

In [4]:
# Show the first-order magnification together with the start plane z_init and
# the image plane z_image returned by the first-order solver.
print(mag_real, z_init, z_image)

-0.8835744434957277 -20.0 17.67049607348454


In [5]:
# Differential-algebra (DA) trace through the lens: the ray state is propagated
# as a truncated Taylor polynomial in its initial conditions, so the transfer
# coefficients can be read directly from the result.

# Third-order expansion in five independent variables.
DA.init(3, 5)

# Expansion point. The on-axis ray (all zeros) is used; the DA variables added
# below carry the dependence on the initial conditions.
x0 = 0.
y0 = 0.
x0_slope = 0.
y0_slope = 0.
opl = 0.

u0 = 1.0  # passed through to the equation of motion via the args tuple below

# State vector: DA(1)=x, DA(2)=y, DA(3)=x slope, DA(4)=y slope, DA(5)=optical
# path length (variable meaning taken from the names of the expansion points).
x = array([
    x0 + DA(1),
    y0 + DA(2),
    x0_slope + DA(3),
    y0_slope + DA(4),
    opl + DA(5),
])

# Integrate the DA equation of motion from z_init to z_image.
with DA.cache_manager():
    zf, x_f = odedopri(electron_equation_of_motion_DA, z_init,
                       x, z_image, 1e-6, 10000, 1e-15,
                       int(1e6), (phi_lambda, E_lambda, u0))

# Exponent lists are spelled out for all five variables, in the order
# (x, y, x slope, y slope, opl), so each coefficient is unambiguous.
# Magnification: linear dependence of the final x on the initial x.
magnification = x_f[0].getCoefficient([1, 0, 0, 0, 0])
# Third-order aperture term: dependence of the final x on (initial x slope)^3.
# The previous index [0, 3] selected y^3 instead of the slope variable.
# NOTE(review): this is the raw Taylor coefficient in the image plane; confirm
# whether the Cs convention in use also divides by the magnification.
Cs_DA = x_f[0].getCoefficient([0, 0, 3, 0, 0])

print(mag_real, magnification)

-0.8835744434957277 -0.8835743719960909


In [1]:
# Trace the same system with jaxgym components: a point source at z_init, the
# ODE-integrated electrostatic lens from z_init to z_image, and a detector at
# z_image.
z_init, z_image = jnp.array(z_init), jnp.array(z_image)

PointSource = comp.PointSource(z=z_init, semi_conv=0.0)
ElectrostaticLens = comp.ODE(
    z=z_init,
    z_end=z_image,
    phi_lambda=phi_lambda,
    E_lambda=E_lambda,
)
# NOTE(review): lengths elsewhere in this notebook are scaled by `scale = 1e3`;
# confirm that det_pixel_size is expressed in the same unit.
Detector = comp.Detector(
    z=z_image,
    det_pixel_size=[1e-6, 1e-6],
    det_shape=(128, 128),
)
model = [PointSource, ElectrostaticLens, Detector]

# Single ray with all-zero entries apart from its start plane z_init.
ray = ray_matrix(0., 0., 0., 0., z_init, 0.)
ray_out = run_to_end(ray, model)

print(ray_out)


NameError: name 'jnp' is not defined

In [None]:

def _trace_from_y(y_start):
    """Run one ray through `model`, starting at height `y_start` in the plane z_init.

    jax.jacfwd needs a callable of the variable being differentiated; the
    previous code passed it the already-computed result of run_to_end.
    """
    # assumes ray_matrix's positional order is (x, y, dx, dy, z, pathlength),
    # i.e. the second argument is the initial y position — TODO confirm
    return run_to_end(ray_matrix(0., y_start, 0., 0., z_init, 0.), model)


# Third derivative of the traced ray with respect to the start coordinate,
# evaluated at y0, by nesting forward-mode Jacobians.
# NOTE(review): presumes run_to_end returns arrays (or a pytree of arrays) that
# JAX can differentiate through — verify.
jax.jacfwd(jax.jacfwd(jax.jacfwd(_trace_from_y)))(y0)