In [None]:
# Turn on JAX's NaN checker so that any computation producing a NaN raises an
# error immediately. `from jax.config import config` is deprecated in recent
# JAX releases; importing the config object from the top-level package works
# across versions and keeps the `config` name available.
# NOTE(review): the NaN checker adds overhead - consider disabling it for the
# timing cells at the end of the notebook.
from jax import config
config.update("jax_debug_nans", True)

from pathlib import Path

import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
import numpy as np

from src.lib_phaseshifts import readPHASESHIFTS
from src.lib_tensors import *
from src.lib_tscatf import *
from src.delta import *
from src.delta import _select_phaseshifts
from src.utils import *

# Interactive matplotlib figures via the ipympl ("widget") backend.
%matplotlib widget

In [None]:
# Calculation parameters, taken from "PARAM".
LMAX = 14     # maximum angular momentum used in the calculation
n_beams = 9   # number of TLEED output beams
n_atoms = 1   # currently 1 is the only possible choice
n_geo = 1     # number of geometric variations ('displacements') to consider

# Passed to tscatf when building the new t-matrix.
# NOTE(review): meaning and units are not documented here - presumably a
# vibrational amplitude; TODO confirm.
DR = 0.1908624

In [None]:
# Lateral unit-cell vectors in Angstrom.
u_vec1 = np.array([1.2722, -2.2036])
u_vec2 = np.array([1.2722,  2.2036])

# Area of the (overlayer) lateral unit cell in Angstrom^2 - in case TLEED wrt a
# smaller unit cell is used, TVA from the reference computation must be set.
# The area is |det([u_vec1; u_vec2])|, written out explicitly because np.cross
# on 2-element vectors is deprecated as of NumPy 2.0.
unit_cell_area = np.abs(u_vec1[0] * u_vec2[1] - u_vec1[1] * u_vec2[0])
# In Bohr radii: convert the area from Angstrom^2 to Bohr^2. BOHR is not
# defined in this notebook; presumably it comes from one of the star imports
# (the Bohr radius in Angstrom - TODO confirm).
# NOTE(review): this rebinds unit_cell_area in place, so re-running these two
# lines on their own divides by BOHR**2 again - re-run the whole cell instead.
unit_cell_area = unit_cell_area / BOHR**2

In [None]:
# Element number (in the phase shifts supplied with the input) for which the
# delta amplitudes will be calculated. It need not be the element used in the
# reference calculation. IEL = 0 means a vacancy will be assumed.
IEL = 1



In [None]:
# Input files. Change DATA_DIR to point the notebook at another calculation
# directory; the default "." reproduces the plain relative paths.
DATA_DIR = Path(".")
phaseshifts_file = DATA_DIR / "PHASESHIFTS"
T1_file = DATA_DIR / "T_1"

In [None]:
# Parse the PHASESHIFTS file. Only the second returned value (the phase-shift
# table) is kept; the other three are discarded.
_, phaseshifts, _, _ = readPHASESHIFTS(
    None,
    None,
    readfile=phaseshifts_file,
    check=False,
    ignoreEnRange=False,
)


In [None]:
# Number of energies handed to read_tensor: the count of lines in the tensor
# file that contain the substring '-1'.
# NOTE(review): this is a plain substring test - any line containing '-1'
# anywhere (e.g. inside a numeric field) is counted; confirm this matches the
# T_1 file format.
n_energies = 0
with open(T1_file, 'r') as tensor_file:
    n_energies = sum(1 for line in tensor_file if '-1' in line)


In [None]:
# Read the reference-calculation tensors. Use the n_beams constant from the
# parameter cell rather than repeating its value here, so the two cannot drift
# apart. l_max is passed as LMAX + 1.
# NOTE(review): confirm read_tensor's l_max convention.
tensor_dict = read_tensor(T1_file, n_beams=n_beams, n_energies=n_energies,
                          l_max=LMAX + 1)

# TODO: raise an error if the requested energies are outside the range of the
# phase-shift energies (interpolation is not possible out of range).

e_inside = tensor_dict['e_kin']  # computational energy inside crystal
interpolated_phaseshifts = interpolate_phaseshifts(phaseshifts, LMAX, e_inside)

In [None]:
t_matrix_ref = tensor_dict['t_matrix']  # atomic t-matrix of current site as used in reference calculation
# NOTE(review): VV is not used in the visible cells; v_real (intensity section)
# reads the same 'v0r' key.
VV = tensor_dict['v0r']  # real part of the inner potential
v_imag = tensor_dict['v0i_substrate']  # imaginary part of the inner potential, substrate

# Spherical wave amplitudes incident from exit beam NEXIT in a "time-reversed"
# LEED experiment (or rather, all terms of the Born series immediately after
# scattering on the current atom).
tensor_amps_out = tensor_dict['tensor_amps_out']
# Spherical wave amplitudes incident on the current atomic site in the reference
# calculation (i.e., the scattering path ends before scattering on that atom).
tensor_amps_in = tensor_dict['tensor_amps_in']
# NOTE(review): the tensors are not cropped to LMAX here (the former "crop"
# lines only re-assigned each variable to itself). read_tensor is called with
# l_max=LMAX+1 - confirm that already yields the required size.

# (Negative) absolute lateral momentum of the Tensor LEED beams (for use as
# incident beams in the time-reversed LEED calculation).
out_k_par2, out_k_par3 = tensor_dict['kx_in'], tensor_dict['ky_in']

# NewCAF: working array in which the current (displaced) atomic t-matrix is stored
# TODO: we could also either append empty phaseshifts to the phaseshifts array or move the conditional around tscatf
selected_phaseshifts = _select_phaseshifts(IEL, interpolated_phaseshifts)

# Vectorize tscatf over its second and third arguments. The phase shifts and
# energies are mapped along their leading axis (presumably energy), while LMAX
# and DR are broadcast. jax.vmap requires exactly one in_axes entry per
# positional argument actually passed (4 here); otherwise it raises ValueError.
tscatf_vmap = jax.vmap(tscatf, in_axes=(None, 0, 0, None))
t_matrix_new = tscatf_vmap(LMAX, selected_phaseshifts, e_inside, DR)

# amplitude differences
matel_dwg_vmap_energy = jax.vmap(MATEL_DWG, in_axes=(0, 0, 0, 0, None, 0, 0, 0, 0, None, None))


def delta_amp(displacement):
    """Delta amplitudes for ``displacement``, via the energy-vmapped MATEL_DWG.

    Every argument except LMAX, unit_cell_area and ``displacement`` is mapped
    along axis 0; the module-level values are looked up at call time.
    """
    return matel_dwg_vmap_energy(
        t_matrix_ref, t_matrix_new, e_inside, v_imag,
        LMAX,
        tensor_amps_out, tensor_amps_in, out_k_par2, out_k_par3,
        unit_cell_area,
        displacement,
    )

# Intensity from reference and delta amplitudes

In [None]:
from src.lib_intensity import *

In [None]:
# Reference Amplitudes
ref_amps = tensor_dict['ref_amps']

# One index pair per TLEED output beam (presumably (h, k) - TODO confirm). The
# tensors were read for n_beams beams, so the list must have that length.
beam_indices = np.array([[1, 0], [0, 1], [1, 1], [2, 0], [0, 2], [2, 1], [1, 2], [3, 0], [0, 3]])
assert len(beam_indices) == n_beams, "beam_indices must list exactly n_beams beams"

# Lateral unit-cell vectors as the rows of a 2x2 float array.
trar = np.array([u_vec1, u_vec2], dtype=float)

v_real = tensor_dict['v0r']

# NOTE(review): presumably the incidence angles (normal incidence) - confirm.
theta, phi = 0.0, 0.0

is_surface_atom = np.array([True])  # topmost atom is surface atom

In [None]:
def lam_prefactor(displacement):
    """Intensity prefactor for ``displacement``.

    Uses the module-level tensor_dict, beam_indices, theta, phi, trar and
    is_surface_atom.
    """
    return intensity_prefactor(
        tensor_dict, displacement, beam_indices, theta, phi, trar, is_surface_atom
    )

In [None]:
def delta_intensity(displacement):
    """Intensities for ``displacement``.

    Combines the prefactor, the reference amplitudes and the delta amplitudes
    via ``sum_intensity``.
    """
    prefactor = lam_prefactor(displacement)
    amplitude_change = delta_amp(displacement)
    return sum_intensity(prefactor, ref_amps, amplitude_change)

In [None]:
# Beam-0 intensity for a sweep of the first displacement component (-0.04 .. 0.04).
fig, ax = plt.subplots()
for disp_0 in -0.05 + 0.01 * np.arange(1, 10):
    disp = np.array([[disp_0, 0.0, 0.0]])
    ax.plot(e_inside, delta_intensity(disp)[:, 0], label=f"{disp_0:+.2f}")
ax.set(title="Intensity (beam 0)",
       xlabel="e_inside (computational energy inside crystal)",
       ylabel="Intensity")
ax.legend(title="disp[0]");


In [None]:
# Derivative of the beam-0 intensity with respect to the first displacement
# component, using forward-mode autodiff. The Jacobian transform is built once,
# outside the loop. Indexing [:, 0, 0, 0] picks all energies, beam 0, and the
# derivative w.r.t. displacement element [0, 0].
intensity_jacobian = jax.jacfwd(delta_intensity)
fig, ax = plt.subplots()
for disp_0 in -0.05 + 0.01 * np.arange(1, 10):
    disp = np.array([[disp_0, 0.0, 0.0]])
    ax.plot(e_inside, intensity_jacobian(disp)[:, 0, 0, 0], label=f"{disp_0:+.2f}")
ax.set(title="dI/d(disp[0]) (beam 0)",
       xlabel="e_inside (computational energy inside crystal)",
       ylabel="dI / d disp[0]")
ax.legend(title="disp[0]");

In [None]:
# Cost estimate for one evaluation of delta_intensity. The displacement is
# given explicitly instead of reusing the `disp` left over from a plotting
# loop, so this cell does not depend on which cells ran before it.
cost_disp = np.array([[0.04, 0.0, 0.0]])
estimate_function_cost(delta_intensity, cost_disp)

# Interpolation onto a finer energy grid

In [None]:
from src.interpolation import *

In [None]:
# Uniform energy grid spanning the computed energies, onto which the
# intensities are spline-interpolated.
# NOTE(review): the final argument 3 is presumably the spline degree - confirm.
n_target_points = 200
target_grid = jnp.linspace(e_inside[0], e_inside[-1], n_target_points)
interpolator = StaticNotAKnotSplineInterpolator(e_inside, target_grid, 3)

In [None]:
def intensity_interpolated(displacement, beam):
    """Spline-interpolate one beam's intensity onto the interpolator's grid.

    Parameters
    ----------
    displacement : array
        Passed unchanged to ``delta_intensity``.
    beam : int
        Column of ``delta_intensity(displacement)`` to interpolate.

    Returns
    -------
    tuple
        ``(intensity, derivative)``: the spline evaluated by ``evaluate_spline``
        with third argument 0 and 1, respectively.
    """
    coeffs = get_bspline_coeffs(
        interpolator, not_a_knot_rhs(delta_intensity(displacement)[:, beam])
    )
    return (evaluate_spline(coeffs, interpolator, 0),
            evaluate_spline(coeffs, interpolator, 1))

In [None]:
# Spline-interpolated beam-0 intensity for the displacement sweep.
# NOTE(review): HARTREE presumably converts the grid from Hartree to eV.
fig, ax = plt.subplots()
for disp_0 in -0.05 + 0.01 * np.arange(1, 10):
    disp = np.array([[disp_0, 0.0, 0.0]])
    ax.plot(target_grid * HARTREE, intensity_interpolated(disp, 0)[0],
            label=f"{disp_0:+.2f}")
ax.set(title="Interpolated Intensity", xlabel="target_grid * HARTREE",
       ylabel="Intensity (beam 0)")
ax.legend(title="disp[0]");

In [None]:
# Energy derivative of the spline-interpolated beam-0 intensity for the sweep.
# NOTE(review): HARTREE presumably converts the grid from Hartree to eV.
fig, ax = plt.subplots()
for disp_0 in -0.05 + 0.01 * np.arange(1, 10):
    disp = np.array([[disp_0, 0.0, 0.0]])
    ax.plot(target_grid * HARTREE, intensity_interpolated(disp, 0)[1],
            label=f"{disp_0:+.2f}")
ax.set(title="Interpolated Derivative", xlabel="target_grid * HARTREE",
       ylabel="dI/dE (beam 0)")
ax.legend(title="disp[0]");

In [None]:
from src.rfactor import *

In [None]:
# Pendry Y-function of the interpolated beam-0 intensity for the sweep.
# NOTE(review): 4.5 is also passed to pendry_R_vs_reference below - presumably
# the imaginary inner potential; confirm its units.
fig, ax = plt.subplots()
for disp_0 in -0.05 + 0.01 * np.arange(1, 10):
    disp = np.array([[disp_0, 0.0, 0.0]])
    # One spline evaluation supplies both the intensity and its derivative.
    interp_intensity, interp_deriv = intensity_interpolated(disp, 0)
    ax.plot(target_grid * HARTREE, pendry_y(interp_deriv, interp_intensity, 4.5),
            label=f"{disp_0:+.2f}")
ax.set(title="Interpolated Y-function", xlabel="target_grid * HARTREE",
       ylabel="Y (beam 0)")
ax.legend(title="disp[0]");

# Pendry R-factor

In [None]:
from src.rfactor import *

In [None]:
# Reference curve: beam-0 intensity at zero displacement.
ref_intensity = delta_intensity(jnp.zeros((1, 3)))[:, 0]

# R-factor function against that reference. The same interpolator is passed
# for both of its slots.
# NOTE(review): the meaning of the numeric arguments (4.5, 3.0, 0.5) is not
# documented here - confirm against pendry_R_vs_reference.
R_fun = pendry_R_vs_reference(ref_intensity, interpolator, interpolator,
                              4.5, 3.0, 0.5)

In [None]:
def lam_r(z):
    """Real part of the Pendry R-factor (beam 0, vs. the zero-displacement
    reference) when the first displacement component is ``z``."""
    displacement = jnp.array([[z, 0.0, 0.0]])
    return jnp.real(R_fun(delta_intensity(displacement)[:, 0]))

In [None]:
# R-factor and its gradient over a sweep of the first displacement component.
# jax.value_and_grad returns both from one forward + backward pass per point,
# and the transform is built once rather than on every iteration.
z_arr = jnp.linspace(-0.05, 0.05, 500)
lam_r_value_and_grad = jax.value_and_grad(lam_r)
r_results = [lam_r_value_and_grad(z) for z in z_arr]
R_arr = [value for value, _ in r_results]
R_grad_arr = [grad for _, grad in r_results]

In [None]:
# Pendry R-factor as a function of the first displacement component.
fig, ax = plt.subplots()
ax.plot(z_arr, R_arr)
ax.set(title="Pendry R-factor vs. displacement", xlabel="z (disp[0])", ylabel="R");

In [None]:
# Autodiff gradient of the Pendry R-factor with respect to z.
fig, ax = plt.subplots()
ax.plot(z_arr, R_grad_arr)
ax.set(title="dR/dz (jax.grad)", xlabel="z (disp[0])", ylabel="dR/dz");

In [None]:
# Check the autodiff gradient against a finite-difference estimate of R(z).
# A forward difference (R[i+1] - R[i]) / dz approximates the derivative at the
# midpoint of [z_i, z_{i+1}], so it is plotted there. Plotting it at z_i would
# shift the curve by half a grid step.
z_np = np.asarray(z_arr)
fd_grad = np.diff(np.asarray(R_arr)) / (z_np[1] - z_np[0])
z_mid = 0.5 * (z_np[:-1] + z_np[1:])

fig, ax = plt.subplots()
ax.plot(z_np, np.asarray(R_grad_arr), label="jax.grad")
ax.plot(z_mid, fd_grad, label="finite difference (midpoints)")
ax.set(title="dR/dz: autodiff vs. finite difference", xlabel="z (disp[0])",
       ylabel="dR/dz")
ax.legend();

# R2: sum of squared intensity deviations from the reference

In [None]:
# Reference intensities for all beams at zero displacement.
zero_displacement = jnp.zeros((1, 3))
ref_intensity_all_beams = delta_intensity(zero_displacement)

In [None]:
def lam_r2(z):
    """Real part of the summed squared deviation from the reference.

    Computes sum((I(z) - I_ref)**2) over all energies and beams, where I(z) is
    ``delta_intensity`` at displacement ``[[z, 0, 0]]`` and I_ref is
    ``ref_intensity_all_beams``.
    """
    deviation = delta_intensity(jnp.array([[z, 0.0, 0.0]])) - ref_intensity_all_beams
    return jnp.real(jnp.sum(deviation**2))

In [None]:
# Squared-deviation objective and its gradient on a coarser (100-point) sweep.
# jax.value_and_grad returns both from one forward + backward pass per point,
# and the transform is built once rather than on every iteration.
# NOTE(review): this rebinds z_arr, which the R-factor plot cells also use with
# their 500-point R_arr / R_grad_arr. Re-running those cells after this one
# mismatches the array lengths.
z_arr = jnp.linspace(-0.05, 0.05, 100)
lam_r2_value_and_grad = jax.value_and_grad(lam_r2)
r2_results = [lam_r2_value_and_grad(z) for z in z_arr]
R2_arr = [value for value, _ in r2_results]
R2_grad_arr = [grad for _, grad in r2_results]

In [None]:
# Squared-deviation objective R2 as a function of the first displacement component.
fig, ax = plt.subplots()
ax.plot(z_arr, R2_arr)
ax.set(title="R2 vs. displacement", xlabel="z (disp[0])", ylabel="R2");

In [None]:
# Autodiff gradient of R2 with respect to z.
fig, ax = plt.subplots()
ax.plot(z_arr, R2_grad_arr)
ax.set(title="dR2/dz (jax.grad)", xlabel="z (disp[0])", ylabel="dR2/dz");

# Timing

In [None]:
# Function cost: estimate for one evaluation of lam_r at z = 0.0, using the
# project helper estimate_function_cost (what it measures is not shown in this
# notebook - TODO confirm).
estimate_function_cost(lam_r, 0.0)

In [None]:
# Function cost: estimate for one evaluation of lam_r2 at z = 0.0, using the
# project helper estimate_function_cost (what it measures is not shown in this
# notebook - TODO confirm).
estimate_function_cost(lam_r2, 0.0)

In [None]:
l = jax.jit(lam_r2).lower(0.0).compile()
%timeit l(0.0)

In [None]:
l2 = jax.jit(jax.grad(lam_r2)).lower(0.0).compile()
%timeit l2(0.0)