In [None]:
from jax import config
# Raise an error at the first NaN produced inside JAX computations (helps to
# locate numerical problems). NOTE(review): these checks add overhead, which
# also ends up in the timing cells further down.
config.update("jax_debug_nans", True)
# Use 64-bit floats/complex numbers (JAX defaults to 32-bit).
config.update("jax_enable_x64", True)

from pathlib import Path

from matplotlib import pyplot as plt
import numpy as np

from src.lib_tensors import *
from src.lib_delta import *
from src.delta import *
from src.utils import *
from src.hashable_array import HashableArray
from src.data_structures import *

%matplotlib inline

In [None]:
# List the devices JAX will run on.
# NOTE(review): `jax` is not imported explicitly in this notebook; it
# presumably comes in through one of the `from src... import *` lines.
jax.devices()

In [None]:
# master: import viperleed from a local source checkout.
# The checkout location defaults to the original author's path; set the
# VIPERLEED_SOURCE_DIR environment variable to use a different checkout
# without editing this cell.
import os
import sys

VIPERLEED_SOURCE_DIR = os.environ.get('VIPERLEED_SOURCE_DIR',
                                      '/Users/alexander/GitHub/')
# Guard against duplicate sys.path entries when this cell is re-run.
if VIPERLEED_SOURCE_DIR not in sys.path:
    sys.path.append(VIPERLEED_SOURCE_DIR)
import viperleed
from viperleed.tleedmlib.files import poscar
from viperleed.tleedmlib.files import parameters
from viperleed.tleedmlib.classes.rparams import Rparams
from viperleed.tleedmlib.files.beams import readIVBEAMS
from viperleed.tleedmlib.files.phaseshifts import readPHASESHIFTS
from viperleed.tleedmlib.files.vibrocc import readVIBROCC

In [None]:
# installable
# import viperleed
# from viperleed.calc.files import poscar
# from viperleed.calc.files import parameters
# from viperleed.calc.classes.rparams import Rparams
# from viperleed.calc.files.beams import readIVBEAMS
# from viperleed.calc.files.phaseshifts import readPHASESHIFTS
# from viperleed.calc.files.vibrocc import readVIBROCC

In [None]:
# Directory holding the Cu(111) test input files.
data_path = Path('tests', 'test_data', 'Cu_111')

In [None]:
# Read in data from POSCAR and PARAMETERS files
slab = poscar.read(data_path / 'POSCAR')
rparams = parameters.read(data_path / 'PARAMETERS')
# Interpret the raw PARAMETERS in the context of the slab (silent=False:
# presumably messages are not suppressed), then update the slab from rparams.
parameters.interpret(rparams, slab, silent=False)
slab.full_update(rparams)

In [None]:
# reading IVBEAMS
rparams.ivbeams = readIVBEAMS(data_path / 'IVBEAMS')
# One row per IV beam; presumably the (h, k) indices of each beam.
beam_indices = np.array([beam.hk for beam in rparams.ivbeams])

In [None]:
# reading VIBROCC
# The return value is discarded: readVIBROCC presumably updates rparams/slab
# in place (site vibration amplitudes are later read via at.site.vibamp).
readVIBROCC(rparams, slab, data_path / 'VIBROCC')

In [None]:
# incidence angles: normal incidence (THETA is set first, then PHI, both 0.0)
rparams.THETA, rparams.PHI = 0.0, 0.0

In [None]:
# TODO: add support for energy-dependent LMAX
# For now a single LMAX (presumably the upper end of the configured LMAX
# range) is used for all energies.
LMAX = rparams.LMAX.max

In [None]:
# Reference energy grid: n_energies equally spaced points from start to stop
# (np.linspace includes both endpoints). Presumably this must match the
# energy grid the tensor files were written on.
ref_energies = np.linspace(rparams.THEO_ENERGIES.start,
                           rparams.THEO_ENERGIES.stop,
                           rparams.THEO_ENERGIES.n_energies)

In [None]:
# Area of the 2D surface unit cell spanned by the in-plane lattice vectors
# (the columns of slab.ab_cell). |det| of the 2x2 in-plane block equals
# |a x b| and avoids np.cross on 2-element vectors, which is deprecated
# since NumPy 2.0.
unit_cell_area = abs(np.linalg.det(slab.ab_cell[:2, :2]))
# In Bohr radii (assumes BOHR is the Bohr radius in the POSCAR length unit,
# so dividing by BOHR**2 converts the area to bohr^2 — TODO confirm)
unit_cell_area = unit_cell_area / BOHR**2

In [None]:
# Select the atoms that are not part of the bulk (tensor files are read for
# these atoms further down).
non_bulk_atoms = list(filter(lambda atom: not atom.is_bulk, slab.atlist))

In [None]:
# assign vibration amplitudes: one entry per non-bulk atom, looked up on the
# atom's site by the atom's element
vibrational_amplitudes = np.array([at.site.vibamp[at.el] for at in non_bulk_atoms])

# TODO: similar to above, this implementation currently does not support
#       changes in vibrational amplitudes, mixed occupations or vacancies.

## Read Tensor files

In [None]:
def read_tensor_num(num):
    """Read the tensor file ``Tensors/T_<num>`` for the atom numbered ``num``.

    NOTE(review): l_max is passed as LMAX+1 — presumably read_tensor expects
    the number of angular-momentum channels rather than LMAX itself; confirm.
    """
    return read_tensor(data_path / 'Tensors' / f'T_{num}',
                       n_beams=len(rparams.ivbeams),
                       n_energies=ref_energies.size,
                       l_max=LMAX+1)


tensors = [read_tensor_num(at.num) for at in non_bulk_atoms]

ref = ReferenceData(tensors)

# Drop the notebook's reference to the raw tensors so they can be freed.
# `del` on a loop variable would only unbind that name, so a single
# `del tensors` is sufficient. Memory is only released if ReferenceData does
# not keep references to these objects itself.
del tensors

In [None]:
# read phase shifts
phaseshifts_path = data_path /  'PHASESHIFTS'
# Only the second return value (the raw phase shifts) is used here.
_, raw_phaseshifts, _, _ = readPHASESHIFTS(slab, rparams, readfile=phaseshifts_path,
                                       check=True, ignoreEnRange=False)


In [None]:
# interpolate & assign phase shifts to atoms
def _site_index(atom):
    """Return the index of the first site in slab.sitelist equivalent to atom.site.

    Raises
    ------
    ValueError
        If no site in slab.sitelist is equivalent. (np.argmax over an
        all-False list would silently return 0 and assign the wrong
        phase shifts.)
    """
    for index, site in enumerate(slab.sitelist):
        if atom.site.isEquivalent(site):
            return index
    raise ValueError(f"No site in slab.sitelist is equivalent to the site "
                     f"of atom {atom.num}.")


site_indices = [_site_index(at) for at in non_bulk_atoms]

# TODO: with this current implementation, we cannot treat chemical
#       perturbations, nor vacancies. We need to implement this.
#       See e.g. iodeltas.generateDeltaInput()
#       (Treating vacancies requires setting zeros for that site)

phaseshifts = Phaseshifts(raw_phaseshifts, ref.energies, LMAX, site_indices)


In [None]:
def delta_amp(displacement):
    """Delta amplitudes for the given atomic displacements.

    Thin wrapper around delta_amplitude. It fixes the vibrational amplitudes,
    reference data, unit-cell area and phase shifts to the notebook-level
    values.
    """
    return delta_amplitude(vibrational_amplitudes, displacement, ref,
                           unit_cell_area, phaseshifts)

# Test timing and gradient

In [None]:
test_vib_amps = vibrational_amplitudes
# One zero (x, y, z) displacement row per non-bulk atom. The row count is
# derived from the atom list rather than hard-coded.
test_displacements = np.zeros((len(non_bulk_atoms), 3))

operands = (test_vib_amps, test_displacements, ref, unit_cell_area, phaseshifts)

In [None]:
def _amplitude_objective(vib_amps, displacements, ref_data, cell_area, phase_shifts):
    """Scalar summary sum(|A|^2) of the delta amplitudes A, for gradient benchmarks.

    jax.grad only accepts functions with a real scalar output, whereas
    delta_amplitude returns an array of amplitudes. |A|^2 is computed as
    Re(A * conj(A)) rather than jnp.abs(A)**2 because the derivative of
    jnp.abs can be NaN at 0 for complex inputs, which jax_debug_nans would
    flag (e.g. for zero test displacements).
    """
    amps = delta_amplitude(vib_amps, displacements, ref_data, cell_area, phase_shifts)
    return jnp.sum(jnp.real(amps * jnp.conj(amps)))


jit = jax.jit(delta_amplitude, static_argnums=(2, 3, 4))
# jax.grad differentiates w.r.t. argument 0 (the vibrational amplitudes).
grad_jit = jax.jit(jax.grad(_amplitude_objective), static_argnums=(2, 3, 4))
jit_grad = jax.grad(jax.jit(_amplitude_objective, static_argnums=(2, 3, 4)))

In [None]:
# On a fresh kernel, this first call also includes tracing and XLA compilation.
%time jit(*operands).block_until_ready()

In [None]:
%%timeit jit(*operands).block_until_ready()

In [None]:
# Sanity check: total magnitude of all delta amplitudes.
print(jnp.abs(jit(*operands)).sum())

In [None]:
# jit applied around grad; on a fresh kernel this first call includes compilation.
%time grad_jit(*operands).block_until_ready()

In [None]:
%%timeit grad_jit(*operands).block_until_ready()

In [None]:
# grad applied outside of jit (only the forward function is jitted).
%time jit_grad(*operands).block_until_ready()

In [None]:
%%timeit jit_grad(*operands).block_until_ready()

# Intensity

In [None]:
from src.lib_intensity import *

In [None]:
# TODO: this needs a better implementation
# Boolean mask over non_bulk_atoms: True for atoms in layer number 0
# (presumably the topmost layer — confirm).
is_surface_atom = np.array([at.layer.num == 0 for at in non_bulk_atoms])

In [None]:
def lam_prefactor(displacements):
    """Intensity prefactor for the given displacements.

    Forwards to intensity_prefactor with the notebook-level reference data,
    beam indices, incidence angles (rparams.THETA, rparams.PHI), in-plane
    cell and surface-atom mask.
    """
    return intensity_prefactor(
        displacements, ref, beam_indices,
        rparams.THETA, rparams.PHI,
        slab.ab_cell, is_surface_atom,
    )

In [None]:
def delta_intensity(displacement):
    """Beam intensities for the given atomic displacements.

    Combines the intensity prefactor, the reference amplitudes (ref.ref_amps)
    and the delta amplitudes via sum_intensity.
    """
    prefactor = lam_prefactor(displacement)
    return sum_intensity(prefactor, ref.ref_amps, delta_amp(displacement))

In [None]:
# Beam-0 intensity while the first non-bulk atom is displaced along x by
# i * 0.01 for i = -5, ..., 4 (all other atoms stay fixed).
fig, ax = plt.subplots()
for i in range(-5, 5):
    disp = np.zeros((len(non_bulk_atoms), 3))
    disp[0, 0] = i*0.01
    ax.plot(ref_energies, delta_intensity(disp)[:,0], label=f"{i*0.01:+.2f}")
ax.set_xlabel("Reference energy")
ax.set_ylabel("Intensity (beam 0)")
ax.legend(title="x displacement of atom 0");


In [None]:
%matplotlib inline
# NOTE(review): with the inline backend, figures are rendered and closed at
# the end of the cell that created them, so this plt.show() presumably
# displays nothing new.
plt.show()

In [None]:
# Computing the full Jacobian of the intensities seems to be too expensive

In [None]:
# Try the forward-mode Jacobian of the intensities w.r.t. all displacements,
# at zero displacement (one row per non-bulk atom). Index [:, 0, 0, 0] picks
# d I(beam 0) / d x(atom 0), assuming delta_intensity returns an
# (energy, beam) array.
disp = np.zeros((len(non_bulk_atoms), 3))
jax.jacfwd(delta_intensity)(disp)[:, 0, 0, 0]

In [None]:
# plt.figure()
# for i in range(1, 10):
#     disp = np.array([[i*0.01, 0.0, 0.0],] +[[0.00, 0.0, 0.0],]*4)
#     plt.plot(ref_energies, jax.jacfwd(delta_intensity)(disp)[:, 0, 0, 0])

# Interpolation

In [None]:
from src.interpolation import *

In [None]:
# Denser target energy grid (200 points spanning the reference range) and an
# interpolator from the reference grid onto it. NOTE(review): the 3 is
# presumably the spline degree (cubic) — confirm against src.interpolation.
target_grid = jnp.linspace(ref_energies[0], ref_energies[-1], 200)
interpolator = StaticNotAKnotSplineInterpolator(ref_energies,
                                                target_grid, 3)

In [None]:
def intensity_interpolated(displacement, beam):
    """Interpolate one beam's intensity onto ``target_grid``.

    Parameters
    ----------
    displacement : array
        Atomic displacements, passed to ``delta_intensity``.
    beam : int
        Column of the ``delta_intensity`` output to interpolate.

    Returns
    -------
    tuple
        (interpolated intensity, interpolated first derivative), both
        evaluated on the interpolator's target grid.
    """
    beam_intensity = delta_intensity(displacement)[:, beam]
    coeffs = get_bspline_coeffs(interpolator, not_a_knot_rhs(beam_intensity))
    values, first_derivative = (evaluate_spline(coeffs, interpolator, order)
                                for order in (0, 1))
    return values, first_derivative

In [None]:
# Interpolated beam-0 intensity for x displacements 0.01 ... 0.09 of the
# first non-bulk atom.
fig, ax = plt.subplots()
for i in range(1, 10):
    disp = np.zeros((len(non_bulk_atoms), 3))
    disp[0, 0] = i*0.01
    ax.plot(target_grid*HARTREE, intensity_interpolated(disp, 0)[0])
ax.set_title("Interpolated Intensity")
ax.set_xlabel("Energy (target_grid * HARTREE)")
ax.set_ylabel("Intensity (beam 0)");

In [None]:
# Derivative of the interpolated beam-0 intensity for x displacements
# 0.01 ... 0.09 of the first non-bulk atom.
fig, ax = plt.subplots()
for i in range(1, 10):
    disp = np.zeros((len(non_bulk_atoms), 3))
    disp[0, 0] = i*0.01
    ax.plot(target_grid*HARTREE, intensity_interpolated(disp, 0)[1])
ax.set_title("Interpolated Derivative")
ax.set_xlabel("Energy (target_grid * HARTREE)")
ax.set_ylabel("dI/dE (beam 0)");

In [None]:
from src.rfactor import *

In [None]:
# Pendry Y-function of beam 0 for x displacements 0.01 ... 0.09 of the first
# non-bulk atom.
fig, ax = plt.subplots()
for i in range(1, 10):
    disp = np.zeros((len(non_bulk_atoms), 3))
    disp[0, 0] = i*0.01
    # Interpolate once and reuse both the intensity and its derivative.
    intensity, deriv = intensity_interpolated(disp, 0)
    # NOTE(review): 4.5 is presumably the imaginary inner potential V0i — confirm.
    ax.plot(target_grid*HARTREE, pendry_y(deriv, intensity, 4.5))
ax.set_title("Interpolated Y-function")
ax.set_xlabel("Energy (target_grid * HARTREE)")
ax.set_ylabel("Pendry Y (beam 0)");

# R-factor

In [None]:
from src.rfactor import *

In [None]:
# Reference curve: beam-0 intensity with no atom displaced (one zero row per
# non-bulk atom instead of a hard-coded count).
no_displacement = np.zeros((len(non_bulk_atoms), 3))
ref_intensity = delta_intensity(no_displacement)[:,0]
# NOTE(review): the positional constants 4.5, 3.0 and 0.5 are passed straight
# to pendry_R_vs_reference (4.5 matches the value given to pendry_y above,
# presumably V0i); name/document them against src.rfactor.
R_fun = pendry_R_vs_reference(
    ref_intensity,
    interpolator,
    interpolator,
    4.5,
    3.0,
    0.5,
)

In [None]:
def lam_r(z):
    """Pendry R-factor of beam 0 vs. the reference when the first non-bulk atom is shifted by z along x.

    All other atoms stay at zero displacement; the number of rows follows the
    atom list instead of being hard-coded. ``.at[].set`` keeps this
    differentiable w.r.t. ``z``.
    """
    displacement = jnp.zeros((len(non_bulk_atoms), 3)).at[0, 0].set(z)
    return jnp.real(R_fun(delta_intensity(displacement)[:,0]))

In [None]:
# lam_r is not wrapped in jax.jit here, so this presumably measures eager
# (op-by-op) execution; compare the AOT-compiled timings at the end.
%timeit lam_r(0.01)

In [None]:
# R-factor on 10 evenly spaced x displacements in [-0.05, 0.05].
z_arr = jnp.linspace(-0.05, 0.05, 10)
R_arr = list(map(lam_r, z_arr))

In [None]:
fig, ax = plt.subplots()
ax.plot(z_arr, R_arr)
ax.set_xlabel("x displacement of atom 0")
ax.set_ylabel("Pendry R (beam 0)");

In [None]:
# Derivative of the R-factor w.r.t. the x displacement of the first atom, at z = 0.01.
jax.grad(lam_r)(0.01)

In [None]:
# Not jitted: presumably includes tracing overhead on every call.
%timeit jax.grad(lam_r)(0.01)

In [None]:
# Autodiff derivative of the R-factor at every point of z_arr.
R_grad_arr = list(map(jax.grad(lam_r), z_arr))

In [None]:
fig, ax = plt.subplots()
ax.plot(z_arr, R_grad_arr)
ax.set_xlabel("x displacement of atom 0")
ax.set_ylabel("dR/dz (jax.grad)");

In [None]:
# Compare the autodiff gradient with a finite-difference estimate.
fig, ax = plt.subplots()
ax.plot(z_arr, R_grad_arr, label="jax.grad")
# A difference quotient over [z_k, z_k+1] best approximates the derivative
# at the interval midpoint, so plot it there rather than at z_k.
z_mid = (z_arr[:-1] + z_arr[1:]) / 2
ax.plot(z_mid, jnp.diff(np.array(R_arr)) / (z_arr[1] - z_arr[0]),
        label="finite difference")
ax.set_xlabel("x displacement of atom 0")
ax.set_ylabel("dR/dz")
ax.legend();

# R2 (sum of squared intensity differences over all beams)

In [None]:
# Reference intensities of all beams for the undisplaced structure, using one
# zero displacement row per non-bulk atom (consistent with the other cells).
ref_intensity_all_beams = delta_intensity(jnp.zeros((len(non_bulk_atoms), 3)))

In [None]:
def lam_r2(z):
    """Sum over all beams/energies of squared intensity differences vs. the reference.

    Only the first non-bulk atom is shifted by z along x, as in lam_r. The
    displacement has one row per non-bulk atom; a single (1, 3) row would
    instead depend on how delta_amplitude treats the mismatched atom count
    (broadcast to every atom, or rejected).
    """
    displacement = jnp.zeros((len(non_bulk_atoms), 3)).at[0, 0].set(z)
    difference = delta_intensity(displacement) - ref_intensity_all_beams
    return jnp.real((difference**2).sum())

In [None]:
# NOTE(review): this overwrites the 10-point z_arr used by the R-factor cells
# above; re-running those plotting cells after this one pairs a 100-point
# z_arr with the 10-point R_arr / R_grad_arr. Consider a separate name.
z_arr = jnp.linspace(-0.05, 0.05, 100)
R2_arr = [lam_r2(r) for r in z_arr]
R2_grad_arr = [jax.grad(lam_r2)(r) for r in z_arr]

In [None]:
fig, ax = plt.subplots()
ax.plot(z_arr, R2_arr)
ax.set_xlabel("x displacement of atom 0")
ax.set_ylabel("R2 (sum of squared differences)");

In [None]:
fig, ax = plt.subplots()
ax.plot(z_arr, R2_grad_arr)
ax.set_xlabel("x displacement of atom 0")
ax.set_ylabel("dR2/dz (jax.grad)");

# Timing

In [None]:
# Function cost
# NOTE(review): estimate_function_cost comes from a star import; presumably it
# reports the estimated compiled cost of evaluating lam_r at 0.0 — confirm.
estimate_function_cost(lam_r, 0.0)

In [None]:
# Function cost
# NOTE(review): as above, presumably the estimated compiled cost of lam_r2 at 0.0.
estimate_function_cost(lam_r2, 0.0)

In [None]:
# Compile ahead of time (lower + compile), so %timeit measures execution only.
l = jax.jit(lam_r2).lower(0.0).compile()
%timeit l(0.0)

In [None]:
# Same for the gradient: AOT-compiled, so %timeit excludes compilation.
l2 = jax.jit(jax.grad(lam_r2)).lower(0.0).compile()
%timeit l2(0.0)