In [None]:
from jax import config
# Global JAX switches for this notebook:
#  - jax_debug_nans: raise an error as soon as a NaN is produced. This is useful
#    while developing, but it may add checking overhead, so the %timeit numbers
#    further down are measured with it switched on.
#  - jax_enable_x64: enable 64-bit (double-precision) arrays.
for _option, _value in (("jax_debug_nans", True), ("jax_enable_x64", True)):
    config.update(_option, _value)

from pathlib import Path

from matplotlib import pyplot as plt
import numpy as np

from src.lib_phaseshifts import readPHASESHIFTS
from src.lib_tensors import *
from src.lib_delta import *
from src.delta import *
from src.utils import *
from src.hashable_array import HashableArray
from src.data_structures import ReferenceData

#%matplotlib widget

In [None]:
# Calculation parameters, transcribed from the "PARAM" input.
LMAX: int = 14      # maximum angular momentum used in the calculation
n_beams: int = 9    # number of TLEED output beams
n_atoms: int = 1    # number of atoms; currently 1 is the only possible choice
n_geo: int = 1      # number of geometric variations ("displacements") considered

# Vibrational amplitude. The literal is scaled by the unit constant BOHR, which
# presumably comes from one of the src.* star imports.
vib_amps = BOHR * 0.1908624

In [None]:
# Real-space lateral unit-cell vectors in Angstrom.
u_vec1 = np.array([1.2722, -2.2036])
u_vec2 = np.array([1.2722,  2.2036])

# Area of the (overlayer) lateral unit cell, |u_vec1 x u_vec2|. If TLEED is done
# with respect to a smaller unit cell, the area (TVA) from the reference
# computation must be used here instead.
# The 2D cross product is written out explicitly because np.cross on
# 2-element vectors is deprecated since NumPy 2.0.
unit_cell_area = np.abs(u_vec1[0] * u_vec2[1] - u_vec1[1] * u_vec2[0])

In [None]:
# Element number (index into the phase shifts supplied with the input) for which
# delta amplitudes will be calculated. This need not be the element used in the
# reference calculation; IEL = 0 means a vacancy is assumed.
# NOTE(review): IEL is not referenced anywhere else in this notebook.
IEL: int = 1



In [None]:
# Input files of the test case stored under tests/test_data/Cu_111_2.
data_path = Path('tests', 'test_data', 'Cu_111_2')
phaseshifts_file = data_path.joinpath('PHASESHIFTS')
T1_file = data_path.joinpath('Tensors', 'T_1')

In [None]:
# Read the phase shifts from `phaseshifts_file`. Of the four values returned,
# only the second one (the raw phase shifts) is kept.
(_, raw_phaseshifts, _, _) = readPHASESHIFTS(
    None,
    None,
    readfile=phaseshifts_file,
    check=False,
    ignoreEnRange=False,
)


In [None]:
# Number of energies in the tensor file, counted as the number of lines that
# contain the substring '-1'.
# NOTE(review): a substring match also hits lines where '-1' appears inside a
# number (e.g. '-1.5' or an exponent such as 'E-10'). Confirm that this is a
# reliable per-energy marker for the tensor file format.
with open(T1_file, 'r') as tensor_file:
    n_energies = sum(1 for line in tensor_file if '-1' in line)


In [None]:
# Read the single tensor file. Pass the same l_max (LMAX + 1) that the
# reference-data read uses, so that both reads parse the file identically.
# NOTE(review): `tensors` is not used anywhere in this notebook. This cell only
# duplicates the read that feeds ReferenceData and could be deleted.
tensors = [
    read_tensor(T1_file, n_beams, n_energies, l_max=LMAX + 1),
]

In [None]:
# Read the tensor file (a single file, T_1, here) as a one-element tuple. Use the
# n_beams parameter instead of repeating the literal 9 so that the two stay in sync.
tensor_data = (
    read_tensor(T1_file, n_beams=n_beams, n_energies=n_energies, l_max=LMAX + 1),
)

# convert to reference data
ref_data = ReferenceData(tensor_data)

In [None]:
# Display ref_data.max_lmax (presumably the largest l_max over all energies — confirm).
ref_data.max_lmax

In [None]:
# Element-wise comparison of ref_data.lmax with ref_data.needed_lmax. The
# result is only displayed, not asserted.
ref_data.lmax == jnp.asarray(ref_data.needed_lmax)

In [None]:
# One entry per atom; the single atom uses index 0.
# NOTE(review): meaning of site_indices inferred from its name — confirm.
site_indices = np.array([0])

# Phase shifts on the reference-data energy grid, up to LMAX.
phaseshift_args = (raw_phaseshifts, ref_data.energies, LMAX, site_indices)
phaseshifts = Phaseshifts(*phaseshift_args)

In [None]:
def delta_amp(displacement):
    """Delta amplitudes for the given atomic displacement(s).

    Thin wrapper around ``delta_amplitude``. Every other argument is fixed:
    the vibrational amplitude (as a 1-tuple), the reference data, the unit-cell
    area, the phase shifts, and ``batch_lmax=True``.
    """
    return delta_amplitude(
        (vib_amps,),
        displacement,
        ref_data,
        unit_cell_area,
        phaseshifts,
        batch_lmax=True,
    )

In [None]:
# Delta amplitudes for a displacement of 0.05 along the first component.
probe_displacement = np.array([[0.05, 0.0, 0.0]])
my_delta = delta_amp(probe_displacement)

In [None]:
%timeit delta_amp(np.array([[0.05, 0.0, 0.0],]))

In [None]:
# Sum of |delta amplitude| for displacements -0.05 ... +0.05 (step 0.01) along the
# first component. The sum is plotted against the displacement itself, not the
# integer step index.
scan_steps = range(-5, 6)
abs_amp = [
    np.sum(abs(delta_amp(np.array([[d * 0.01, 0.0, 0.0]]))))
    for d in scan_steps
]

plt.figure()
plt.scatter([d * 0.01 for d in scan_steps], abs_amp)
plt.xlabel("Displacement (Å)")
plt.ylabel(r"$\sum |\Delta A|$")
plt.title("Summed delta-amplitude magnitude vs. displacement")
plt.show()

# Intensity

In [None]:
from src.lib_intensity import *

In [None]:
# Reference Amplitudes
ref_amps = ref_data.ref_amps

# One index pair per output beam (presumably (h, k) — confirm).
beam_indices = np.array([[1, 0], [0, 1], [1, 1], [2, 0], [0, 2],
                         [2, 1], [1, 2], [3, 0], [0, 3]])
assert beam_indices.shape == (n_beams, 2), \
    "beam_indices must contain one index pair per output beam (n_beams)"

# Row vectors used by intensity_prefactor. Numerically these match the 2*pi
# reciprocal lattice vectors of u_vec1/u_vec2 expressed in 1/Bohr.
trar = np.array([[1.306759, -0.7544285],
                 [1.306759,  0.7544285]], dtype=float)

# NOTE(review): v_real is not used below; later cells use ref_data.v0r directly.
v_real = ref_data.v0r

# Incidence angles, both zero (presumably normal incidence).
theta, phi = 0.0, 0.0

is_surface_atom = np.array([True])  # topmost atom is surface atom

In [None]:
def lam_prefactor(displacements):
    """Intensity prefactor for the given displacements.

    Wraps ``intensity_prefactor``, fixing the reference data, beam indices,
    incidence angles, ``trar`` and surface-atom flags defined above.
    """
    return intensity_prefactor(
        displacements, ref_data, beam_indices, theta, phi, trar, is_surface_atom
    )

In [None]:
# Alias for the reference-data energy grid; it is the source grid of the spline
# interpolator in the Interpolation section.
e_inside = ref_data.energies

In [None]:
def delta_intensity(displacement):
    """Beam intensities for the structure with ``displacement`` applied.

    Combines the intensity prefactor, the reference amplitudes and the delta
    amplitudes via ``sum_intensity``. Callers in this notebook index the
    result as ``[:, beam]``, with energies along the first axis.
    """
    prefactor = lam_prefactor(displacement)
    amplitudes = delta_amp(displacement)
    return sum_intensity(prefactor, ref_amps, amplitudes)

In [None]:
# Beam-0 intensity vs. energy for displacements -0.05 ... +0.05 Å (step 0.01)
# along the first component. The undisplaced curve (i == 5) is drawn thick and black.
energy_axis_ev = (ref_data.energies - ref_data.v0r) * HARTREE

plt.figure()
for i in range(11):
    disp = np.array([[-0.05 + 0.01 * i, 0.0, 0.0]])
    if i == 5:
        line_style = dict(linewidth=3, color='black', label='0.00 Å')
    else:
        line_style = dict(label=str(np.round(-0.05 + 0.01 * i, 2)) + ' Å')
    plt.plot(energy_axis_ev, delta_intensity(disp)[:, 0], **line_style)

plt.title("Beam 0 intensity vs. displacement")
plt.xlabel("Energy (eV)")
plt.ylabel("Intensity")
plt.legend(fontsize='small')
plt.show()

In [None]:
# plotting the numerical derivatives
# dI/dx of the beam-0 intensity over the displacement grid -0.05 ... +0.05
# (step 0.01), obtained with np.gradient along the displacement axis.
step = 0.01
n_disp = 11

plt.figure()
my_intensities = np.full(shape=(n_disp, n_energies), dtype=np.complex128, fill_value=np.nan)
for i in range(n_disp):
    disp = np.array([[-0.05 + step * i, 0.0, 0.0]])
    my_intensities[i, :] = delta_intensity(disp)[:, 0]
# Differentiate with respect to displacement (axis 0) for all energies at once.
my_diffs = np.gradient(my_intensities, step, axis=0)

energy_axis_ev = (ref_data.energies - ref_data.v0r) * HARTREE
for i in range(n_disp):
    if i == 5:
        line_style = dict(linewidth=3, color='black', label='0.00 Å')
    else:
        line_style = dict(label=str(np.round(-0.05 + step * i, 2)) + ' Å')
    # The intensities are stored as complex numbers. Plot the real part
    # explicitly rather than leaving the complex-to-real cast to matplotlib.
    plt.plot(energy_axis_ev, my_diffs[i, :].real, **line_style)

plt.title("Beam 0 dI/dx (finite differences)")
plt.xlabel("Energy (eV)")
plt.ylabel(r'$\left. \frac{dI}{dx} \right|_{x=x_0}$')
plt.ylim([-0.045, 0.05])
plt.legend(fontsize='small', loc='upper left')
plt.show()

In [None]:
# The same derivative computed by forward-mode autodiff. The Jacobian of
# delta_intensity has shape (energy, beam) + displacement shape (1, 3), so
# [:, 0, 0, 0] selects d(beam-0 intensity)/d(first component of row 0) at every energy.
d_intensity_d_disp = jax.jacfwd(delta_intensity)  # build the Jacobian function once
energy_axis_ev = (ref_data.energies - ref_data.v0r) * HARTREE

plt.figure()
for i in range(11):
    disp = np.array([[-0.05 + 0.01 * i, 0.0, 0.0]])
    if i == 5:
        line_style = dict(linewidth=3, color='black', label='0.00 Å')
    else:
        line_style = dict(label=str(np.round(-0.05 + 0.01 * i, 2)) + ' Å')
    plt.plot(energy_axis_ev, d_intensity_d_disp(disp)[:, 0, 0, 0], **line_style)

plt.title("Beam 0 dI/dx (jax.jacfwd)")
plt.xlabel("Energy (eV)")
plt.ylabel(r'$\left. \frac{dI}{dx} \right|_{x=x_0}$')
plt.ylim([-0.045, 0.05])
plt.legend(fontsize='small', loc='upper left')
plt.show()

In [None]:
# Cost estimate of delta_intensity at an explicit displacement. Using an explicit
# value instead of the loop variable `disp` from an earlier plotting cell keeps
# this cell independent of execution order.
cost_displacement = np.array([[0.05, 0.0, 0.0]])
estimate_function_cost(delta_intensity, cost_displacement)

# Interpolation

In [None]:
from src.interpolation import *

In [None]:
# Spline interpolator from the reference energy grid (e_inside) onto a uniform
# grid of 200 points spanning the same range. The third argument, 3, is
# presumably the spline degree — confirm.
n_target_points = 200
target_grid = jnp.linspace(e_inside[0], e_inside[-1], n_target_points)
interpolator = StaticNotAKnotSplineInterpolator(e_inside, target_grid, 3)

In [None]:
def intensity_interpolated(displacement, beam):
    """Spline-interpolated intensity of one beam and its energy derivative.

    The beam's intensity on the reference energy grid is fitted with the
    not-a-knot B-spline set up in ``interpolator`` and evaluated on the
    interpolator's target grid.

    Args:
        displacement: displacement array passed to ``delta_intensity``.
        beam: column index of the beam in the intensity array.

    Returns:
        ``(intensity, derivative)``, i.e. the spline evaluated with its last
        argument set to 0 and to 1 respectively.
    """
    coeffs = get_bspline_coeffs(
        interpolator, not_a_knot_rhs(delta_intensity(displacement)[:, beam])
    )
    return (evaluate_spline(coeffs, interpolator, 0),
            evaluate_spline(coeffs, interpolator, 1))

In [None]:
# Interpolated beam-0 intensity for displacements -0.04 ... +0.04 (i = 1..9).
# NOTE(review): the x axis is target_grid * HARTREE, which is not shifted by
# ref_data.v0r as in the earlier plots — confirm which energy scale is intended.
plt.figure()
for i in range(1, 10):
    disp = np.array([[-0.05 + 0.01 * i, 0.0, 0.0]])
    plt.plot(target_grid * HARTREE, intensity_interpolated(disp, 0)[0],
             label=str(np.round(-0.05 + 0.01 * i, 2)) + ' Å')
plt.title("Interpolated Intensity")
plt.xlabel("Energy (eV)")
plt.ylabel("Intensity")
plt.legend(fontsize='small')
plt.show()

In [None]:
# Energy derivative of the interpolated beam-0 intensity for i = 1..9.
# NOTE(review): as in the intensity plot, the x axis is not shifted by ref_data.v0r.
plt.figure()
for i in range(1, 10):
    disp = np.array([[-0.05 + 0.01 * i, 0.0, 0.0]])
    plt.plot(target_grid * HARTREE, intensity_interpolated(disp, 0)[1],
             label=str(np.round(-0.05 + 0.01 * i, 2)) + ' Å')
plt.title("Interpolated Derivative")
plt.xlabel("Energy (eV)")
plt.ylabel("dI/dE")
plt.legend(fontsize='small')
plt.show()

In [None]:
from src.rfactor import *

In [None]:
# Pendry Y-function of the interpolated beam-0 intensity for i = 1..9.
# NOTE(review): 4.5 is passed as the third argument of pendry_y (presumably the
# imaginary part of the inner potential). The energy grid appears to be in
# Hartree (it is multiplied by HARTREE for plotting in eV), so confirm that 4.5
# is given in the units pendry_y expects.
plt.figure()
for i in range(1, 10):
    disp = np.array([[-0.05 + 0.01 * i, 0.0, 0.0]])
    # Evaluate the spline once and reuse both the intensity and its derivative.
    interp_intensity, interp_deriv = intensity_interpolated(disp, 0)
    plt.plot(target_grid * HARTREE, pendry_y(interp_deriv, interp_intensity, 4.5),
             label=str(np.round(-0.05 + 0.01 * i, 2)) + ' Å')
plt.title("Interpolated Y-function")
plt.xlabel("Energy (eV)")
plt.ylabel("Y")
plt.legend(fontsize='small')
plt.show()

# R-factor

In [None]:
from src.rfactor import *

In [None]:
# Undisplaced beam-0 intensity; the reference curve for the R-factor.
ref_intensity = delta_intensity(jnp.zeros((1, 3)))[:, 0]

# R-factor as a function of a trial beam-0 intensity, compared with
# ref_intensity. The same interpolator is passed for both curves.
# NOTE(review): the numeric arguments (4.5, 3.0, 0.5) are interpreted by
# pendry_R_vs_reference — confirm their meaning and consider naming them.
R_fun = pendry_R_vs_reference(ref_intensity, interpolator, interpolator,
                              4.5, 3.0, 0.5)

In [None]:
def lam_r(z):
    """Real part of the R-factor for displacement z along the first component."""
    trial_intensity = delta_intensity(jnp.array([[z, 0.0, 0.0]]))[:, 0]
    return jnp.real(R_fun(trial_intensity))

In [None]:
# Re-assert the JAX debugging switches before the gradient scans below:
# NaN checking stays on and JIT compilation stays enabled.
for _flag, _setting in {"jax_debug_nans": True, "jax_disable_jit": False}.items():
    config.update(_flag, _setting)

In [None]:
# Scan the R-factor and its gradient over z in [-0.05, 0.05] (500 points).
# The plotting cells below use both R_arr and R_grad_arr, so both must be
# computed here for the notebook to run top to bottom.
z_arr = jnp.linspace(-0.05, 0.05, 500)
grad_lam_r = jax.grad(lam_r)
R_arr = [lam_r(z) for z in z_arr]
R_grad_arr = [grad_lam_r(z) for z in z_arr]

In [None]:
# Gradient of the R-factor at the lower edge of the scan range.
grad_at_lower_edge = jax.grad(lam_r)(-0.05)
grad_at_lower_edge  # previously recorded value: -9.61726374

In [None]:
# R-factor vs. displacement, using the 500-point scan (z_arr, R_arr).
plt.figure()
plt.plot(z_arr, R_arr)
plt.xlabel("Displacement z (Å)")
plt.ylabel("R-factor")
plt.title("Pendry R-factor vs. displacement")
plt.show()

In [None]:
# Autodiff gradient of the R-factor vs. displacement.
plt.figure()
plt.plot(z_arr, R_grad_arr)
plt.xlabel("Displacement z (Å)")
plt.ylabel("dR/dz")
plt.title("R-factor gradient (jax.grad)")
plt.show()

In [None]:
# Compare the autodiff gradient with a finite-difference estimate from R_arr.
# The forward difference (R[k+1] - R[k]) / dz approximates the derivative at the
# midpoint between z[k] and z[k+1], so it is plotted there. Plotting it at z[k]
# would shift the curve by half a step.
dz = z_arr[1] - z_arr[0]
z_mid = (z_arr[:-1] + z_arr[1:]) / 2
plt.figure()
plt.plot(z_arr, R_grad_arr, label="jax.grad")
plt.plot(z_mid, jnp.diff(np.array(R_arr)) / dz, label="finite difference")
plt.xlabel("Displacement z (Å)")
plt.ylabel("dR/dz")
plt.legend()
plt.show()

# R2 (sum of squared intensity differences)

In [None]:
# Undisplaced intensities of all beams; the reference for the R2 metric below.
ref_intensity_all_beams = delta_intensity(jnp.zeros((1, 3)))

In [None]:
def lam_r2(z):
    """Sum of squared deviations from the undisplaced intensities.

    The sum runs over all energies and beams for displacement z along the
    first component; the real part of the sum is returned.
    """
    deviation = delta_intensity(jnp.array([[z, 0.0, 0.0]])) - ref_intensity_all_beams
    return jnp.real(jnp.sum(deviation ** 2))

In [None]:
# Scan R2 and its gradient over z in [-0.05, 0.05] (100 points).
# NOTE(review): this overwrites z_arr, replacing the 500-point grid used by the
# R-factor plots. Re-running those plots after this cell will pair a 100-point
# z_arr with the 500-point R_arr. Consider giving this grid its own name.
z_arr = jnp.linspace(-0.05, 0.05, 100)
grad_lam_r2 = jax.grad(lam_r2)
R2_arr = [lam_r2(z) for z in z_arr]
R2_grad_arr = [grad_lam_r2(z) for z in z_arr]

In [None]:
# R2 vs. displacement, using the 100-point scan (z_arr, R2_arr).
plt.figure()
plt.plot(z_arr, R2_arr)
plt.xlabel("Displacement z (Å)")
plt.ylabel("R2")
plt.title("Sum of squared intensity differences vs. displacement")
plt.show()

In [None]:
# Autodiff gradient of R2 vs. displacement.
plt.figure()
plt.plot(z_arr, R2_grad_arr)
plt.xlabel("Displacement z (Å)")
plt.ylabel("dR2/dz")
plt.title("R2 gradient (jax.grad)")
plt.show()

# Timing

In [None]:
# Function cost of the R-factor objective lam_r, evaluated at z = 0.0
estimate_function_cost(lam_r, 0.0)

In [None]:
# Function cost of the R2 objective lam_r2, evaluated at z = 0.0
estimate_function_cost(lam_r2, 0.0)

In [None]:
l = jax.jit(lam_r2).lower(0.0).compile()
%timeit l(0.0)

In [None]:
l2 = jax.jit(jax.grad(lam_r2)).lower(0.0).compile()
%timeit l2(0.0)