In [None]:
from jax import config
config.update("jax_debug_nans", True)
config.update("jax_enable_x64", True)
config.update("jax_disable_jit", False)
config.update("jax_log_compiles", False)

from pathlib import Path

import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
import numpy as np

from src.lib_tensors import read_tensor
from src.lib_phaseshifts import *
from src.data_structures import ReferenceData
from src.rfactor import pendry_y
from src.tensor_calculator import TensorLEEDCalculator

%matplotlib inline

In [None]:
# Show which devices JAX will dispatch computations to.
available_devices = jax.devices()
available_devices

In [None]:
# Choose where viperleed is imported from:
#   True  -> the installed package (`viperleed.calc` module layout)
#   False -> a source checkout on disk (`viperleed.tleedmlib` module layout,
#            the "master" branch)
use_installable = False
if use_installable:
    import viperleed
    from viperleed.calc.files import poscar
    from viperleed.calc.files import parameters
    from viperleed.calc.classes.rparams import Rparams
    from viperleed.calc.files.beams import readIVBEAMS
    from viperleed.calc.files.phaseshifts import readPHASESHIFTS
    from viperleed.calc.files.vibrocc import readVIBROCC
else:
    # master
    import os
    import sys
    # Directory containing the viperleed checkout. Set the environment
    # variable VIPERLEED_SOURCE_DIR instead of editing this cell; the default
    # keeps the original author's local path.
    viperleed_source_dir = os.environ.get('VIPERLEED_SOURCE_DIR',
                                          '/Users/alexander/GitHub/')
    # Only add the entry once, so re-running this cell does not keep growing
    # sys.path.
    if viperleed_source_dir not in sys.path:
        sys.path.append(viperleed_source_dir)
    import viperleed
    from viperleed.tleedmlib.files import poscar
    from viperleed.tleedmlib.files import parameters
    from viperleed.tleedmlib.classes.rparams import Rparams
    from viperleed.tleedmlib.files.beams import readIVBEAMS
    from viperleed.tleedmlib.files.phaseshifts import readPHASESHIFTS
    from viperleed.tleedmlib.files.vibrocc import readVIBROCC

In [None]:
# Input files for the Cu(111) test case (relative to the current working directory).
data_path = Path('tests', 'test_data', 'Cu_111')

In [None]:
# Read the slab geometry (POSCAR) and the run parameters (PARAMETERS), then let
# viperleed interpret the parameters against the slab and update the slab.
# NOTE: statement order matters here; interpret() and full_update() consume
# the objects created just above.
slab = poscar.read(data_path / 'POSCAR')
rparams = parameters.read(data_path / 'PARAMETERS')
parameters.interpret(rparams, slab, silent=False)
slab.full_update(rparams)

# reading IVBEAMS; `beam_indices` collects each beam's `hk` attribute into an
# array (not used further in the visible part of this notebook).
rparams.ivbeams = readIVBEAMS(data_path / 'IVBEAMS')
beam_indices = np.array([beam.hk for beam in rparams.ivbeams])

# reading VIBROCC; the return value is discarded, so the call is presumably
# made for its in-place effect on rparams/slab (TODO confirm).
readVIBROCC(rparams, slab, data_path / 'VIBROCC')

# incidence angles: both are forced to 0.0 after interpretation, overriding
# whatever PARAMETERS specified (presumably normal incidence).
rparams.THETA = 0.0
rparams.PHI = 0.0

In [None]:
# Upper end of the LMAX setting from PARAMETERS; presumably the maximum
# angular-momentum quantum number used for the phase shifts (TODO confirm).
LMAX = rparams.LMAX.max

In [None]:
# Energy grid of the reference calculation: `n_energies` evenly spaced points
# from `start` to `stop`, both endpoints included (np.linspace default).
param_energies = np.linspace(
    start=rparams.THEO_ENERGIES.start,
    stop=rparams.THEO_ENERGIES.stop,
    num=rparams.THEO_ENERGIES.n_energies,
)

## Read Tensor files

In [None]:
def read_tensor_num(num):
    """Read the tensor file ``<data_path>/Tensors/T_<num>``.

    The beam count, energy count and l_max passed to `read_tensor` come from
    the IVBEAMS beams, the `param_energies` grid and ``LMAX + 1``.
    """
    return read_tensor(data_path / 'Tensors' / f'T_{num}',
                       n_beams=len(rparams.ivbeams),
                       n_energies=param_energies.size,
                       l_max=LMAX+1)

non_bulk_atoms = [at for at in slab.atlist if not at.is_bulk]
tensors = [read_tensor_num(at.num) for at in non_bulk_atoms]

# NOTE(review): fix_lmax=10 is a hard-coded value independent of LMAX;
# confirm it matches the tensors being read.
ref = ReferenceData(tensors, fix_lmax=10)

# Drop the notebook's reference to the raw tensors to free memory. (Looping
# `del t` over the list only unbinds the loop variable and frees nothing.)
# This only releases memory if ReferenceData does not keep references to
# these arrays (TODO confirm).
del tensors

In [None]:
# read phase shifts: readPHASESHIFTS returns a 4-tuple, of which only the
# second element (the raw phase shifts) is kept.
phaseshifts_path = data_path / 'PHASESHIFTS'
_, raw_phaseshifts, _, _ = readPHASESHIFTS(
    slab,
    rparams,
    readfile=phaseshifts_path,
    check=True,
    ignoreEnRange=False,
)


In [None]:
# interpolate & assign phase shifts to atoms

def _site_index(atom):
    """Return the index in `slab.sitelist` of the first site equivalent to `atom.site`.

    Raises ValueError when no site matches. An argmax over an all-False list
    would silently return 0 and give the atom the phase shifts of site 0.
    """
    for index, site in enumerate(slab.sitelist):
        if atom.site.isEquivalent(site):
            return index
    raise ValueError(f"No site in slab.sitelist is equivalent to the site "
                     f"of atom {atom.num}.")

site_indices = [_site_index(at) for at in non_bulk_atoms]

# TODO: with this current implementation, we cannot treat chemical
#       perturbations, nor vacancies. We need to implement this.
#       See e.g. iodeltas.generateDeltaInput()
#       (Treating vacancies requires setting zeros for that site)

# `Phaseshifts` is presumably provided by the star import of src.lib_phaseshifts.
phaseshifts = Phaseshifts(raw_phaseshifts, ref.energies, LMAX, site_indices)

# Calculator

In [None]:
# Build the tensor-LEED calculator from the reference tensor data, the
# interpolated phase shifts, and the slab/parameter objects read above.
calculator = TensorLEEDCalculator(ref, phaseshifts, slab, rparams)

In [None]:
# Reference point: the calculator's own vibrational amplitudes and zero
# displacement for every atom. There is presumably one (x3) displacement row
# per non-bulk atom, matching the per-atom tensors read above, so the count is
# derived from `non_bulk_atoms` rather than hard-coded.
centered_vib_amps = calculator.ref_vibrational_amps
centered_displacements = np.zeros((len(non_bulk_atoms), 3))

In [None]:
# Unperturbed intensities. %time reports the cost of this first call, which
# likely includes JIT compilation (cf. the "compile time" cell further down).
%time ref_int = calculator.intensity(centered_vib_amps, centered_displacements)

In [None]:
# set reference point: use the unperturbed calculated intensities as the
# "experimental" data, so the R-factor scans below compare against the
# undisplaced structure (presumably giving the minimum at zero displacement).
calculator.set_experiment_intensity(ref_int)

# Intensity

In [None]:
# some displacements to play with: shift atom 0 along its first coordinate
# from -0.05 to +0.05 in 11 steps of 0.01; all other atoms stay at rest.
def _displace_first_atom(offset):
    """Zero displacements for all non-bulk atoms except entry [0, 0] = offset."""
    displacement = np.zeros((len(non_bulk_atoms), 3))
    displacement[0, 0] = offset
    return displacement

spaced_displacements = [_displace_first_atom(i*0.01 - 0.05) for i in range(11)]

In [None]:
# Index of the beam (column of the intensity arrays, `[:, beam]`) plotted below.
beam = 0

In [None]:
# Intensity of one beam for each displacement of atom 0.
# (Inline plotting is already enabled in the first cell.)
fig, ax = plt.subplots()
for d in spaced_displacements:
    ax.plot(param_energies,
            calculator.intensity(centered_vib_amps, d)[:, beam],
            label=f"{d[0, 0]:+.2f}")
ax.set(title=f"Intensity of beam {beam}", xlabel="Energy", ylabel="Intensity")
ax.legend(title="d[0, 0]", fontsize="small")
plt.show()

# Interpolation

In [None]:
plt.figure()
for d in spaced_displacements:
    plt.plot(calculator.interpolated(centered_vib_amps, d, deriv_deg=0)[:, beam])
plt.title("Interpolated Intensity")

In [None]:
plt.figure()
for d in spaced_displacements:
    plt.plot(calculator.interpolated(centered_vib_amps, d, deriv_deg=1)[:, beam])
plt.title("Interpolated Derivative")

In [None]:
from src.rfactor import pendry_y
plt.figure()
for d in spaced_displacements:
    intensity = calculator.interpolated(centered_vib_amps, d, deriv_deg=0)[:, beam]
    deriv = calculator.interpolated(centered_vib_amps, d, deriv_deg=1)[:, beam]
    plt.plot(pendry_y(intensity, deriv, 4.5))
plt.title("Interpolated Y-function")

# R-factor (Pendry)

In [None]:
# compile time: the first call of R_pendry_val_and_grad with these arguments,
# so %time presumably includes JAX tracing/compilation.
%time calculator.R_pendry_val_and_grad(centered_vib_amps, centered_displacements)

In [None]:
# execution time
%timeit calculator.R_pendry_val_and_grad(centered_vib_amps, centered_displacements, beam)

In [None]:
# Scan R_Pendry and its gradient over the displacements of atom 0; z_arr holds
# the scanned coordinate d[0][0].
R_arr, R_grad_arr, z_arr = [], [], []
for displacement in spaced_displacements:
    r_value, r_gradient = calculator.R_pendry_val_and_grad(centered_vib_amps,
                                                           displacement)
    R_arr.append(r_value)
    R_grad_arr.append(r_gradient)
    z_arr.append(displacement[0][0])

In [None]:
# R_Pendry along the displacement scan of atom 0.
fig, ax = plt.subplots()
ax.plot(z_arr, R_arr, marker="o")
ax.set(title="R_Pendry vs. displacement of atom 0",
       xlabel="z = d[0, 0]", ylabel="R_Pendry");

In [None]:
# Gradient entries [0, 0] and [0, 1] along the scan; presumably dR/dd[0, 0] and
# dR/dd[0, 1] (atom 0, first two coordinates).
fig, ax = plt.subplots()
ax.plot(z_arr, [g[0, 0] for g in R_grad_arr], label="g[0, 0]")
ax.plot(z_arr, [g[0, 1] for g in R_grad_arr], label="g[0, 1]")
ax.set(title="R_Pendry gradient", xlabel="z = d[0, 0]", ylabel="gradient")
ax.legend();

In [None]:
# Sanity check: autodiff gradient g[0, 0] vs. a finite difference of R_arr.
# The difference quotient between neighbouring scan points approximates the
# derivative at their midpoint, so it is plotted there. np.diff(z) also copes
# with non-uniform spacing.
z_scan = np.asarray(z_arr)
z_mid = 0.5 * (z_scan[1:] + z_scan[:-1])
R_finite_diff = np.diff(np.asarray(R_arr)) / np.diff(z_scan)
fig, ax = plt.subplots()
ax.plot(z_arr, [g[0, 0] for g in R_grad_arr], label="autodiff g[0, 0]")
ax.plot(z_mid, R_finite_diff, linestyle="--", label="finite difference")
ax.set(title="Gradient check", xlabel="z = d[0, 0]", ylabel="dR/dz")
ax.legend();

# R2 (sum of squared intensity deviations) — TODO: unfinished, relies on `delta_intensity`

In [None]:
ref_intensity_all_beams = delta_intensity(jnp.array([[0.0, 0.0, 0.0],]))

In [None]:
lam_r2 = lambda z: jnp.real(((delta_intensity(jnp.array([[z, 0.0, 0.0],])) - ref_intensity_all_beams)**2).sum())

In [None]:
z_arr = jnp.linspace(-0.05, 0.05, 100)
R2_arr = [lam_r2(r) for r in z_arr]
R2_grad_arr = [jax.grad(lam_r2)(r) for r in z_arr]

In [None]:
plt.figure()
plt.plot(z_arr, R2_arr)

In [None]:
plt.figure()
plt.plot(z_arr, R2_grad_arr)