In [None]:
from jax import config

# Global JAX switches for this notebook, applied in this order:
# NaN debugging off, 64-bit floats on, JIT enabled, compile logging off.
for _jax_option, _jax_value in (
    ("jax_debug_nans", False),
    ("jax_enable_x64", True),
    ("jax_disable_jit", False),
    ("jax_log_compiles", False),
):
    config.update(_jax_option, _jax_value)
del _jax_option, _jax_value

from pathlib import Path

import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
import numpy as np

from src.constants import BOHR
from src.lib_tensors import read_tensor
from src.lib_phaseshifts import *
from src.data_structures import ReferenceData
from src.tensor_calculator import TensorLEEDCalculator

# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
# Devices JAX sees on its default backend (shown as the cell output).
available_devices = jax.devices()
available_devices

In [None]:
# Select where viperleed is imported from:
#   True  -> the installed package (`viperleed.calc` layout);
#   False -> a local checkout of the old `master` branch (`viperleed.tleedmlib`
#            layout). The directory containing that checkout is read from the
#            environment variable VIPERLEED_MASTER_PARENT (default: ~/GitHub).
# Note: `symmetry` and `getBeamCorrespondence` are only imported in the
# installable branch.
use_installable = True
if use_installable:
    import viperleed
    from viperleed.calc import symmetry
    from viperleed.calc.files import poscar
    from viperleed.calc.files import parameters
    from viperleed.calc.classes.rparams import Rparams
    from viperleed.calc.files.beams import readIVBEAMS, readOUTBEAMS
    from viperleed.calc.files.phaseshifts import readPHASESHIFTS
    from viperleed.calc.files.vibrocc import readVIBROCC
    from viperleed.calc.files.iorfactor import beamlist_to_array
    from viperleed.calc.lib.leedbase import getBeamCorrespondence
else:
    # master
    import os
    import sys
    viperleed_parent = os.environ.get('VIPERLEED_MASTER_PARENT',
                                      str(Path.home() / 'GitHub'))
    # Prepend (not append) so the checkout takes precedence over an installed
    # viperleed; skip when already present so re-running the cell is idempotent.
    # A viperleed already imported in this kernel is not replaced by this.
    if viperleed_parent not in sys.path:
        sys.path.insert(0, viperleed_parent)
    import viperleed
    from viperleed.tleedmlib.files import poscar
    from viperleed.tleedmlib.files import parameters
    from viperleed.tleedmlib.classes.rparams import Rparams
    from viperleed.tleedmlib.files.beams import readIVBEAMS, readOUTBEAMS
    from viperleed.tleedmlib.files.phaseshifts import readPHASESHIFTS
    from viperleed.tleedmlib.files.vibrocc import readVIBROCC
    from viperleed.tleedmlib.files.iorfactor import beamlist_to_array

In [None]:
# Directory with the Fe2O3(012) test inputs, relative to the working directory.
data_path = Path('tests', 'test_data', 'Fe2O3_012')

In [None]:
# Read in data from POSCAR and PARAMETERS files.
# PARAMETERS is interpreted against the slab, then the slab is updated from
# the interpreted parameters; the order of these calls matters.
slab = poscar.read(data_path / 'POSCAR')
rparams = parameters.read(data_path / 'PARAMETERS')
parameters.interpret(rparams, slab, silent=False)
slab.full_update(rparams)

# reading IVBEAMS
# Disabled: beam_indices is hard-coded in a later cell instead.
# rparams.ivbeams = readIVBEAMS(data_path / 'IVBEAMS')
# beam_indices = np.array([beam.hk for beam in rparams.ivbeams])

# reading VIBROCC (return value is not used; presumably updates rparams/slab
# in place — TODO confirm)
readVIBROCC(rparams, slab, data_path / 'VIBROCC')

# incidence angles: set after reading PARAMETERS, so these values win over any
# THETA/PHI given there (THETA = 0 presumably means normal incidence)
rparams.THETA = 0.0
rparams.PHI = 90.0

In [None]:
# Upper end of the LMAX setting; used below when reading tensors
# (l_max=LMAX+1) and when building the phase shifts.
LMAX = rparams.LMAX.max

In [None]:
# Theory energy grid built from the THEO_ENERGIES setting:
# n_energies evenly spaced points from start to stop (both included).
_theo_energies = rparams.THEO_ENERGIES
param_energies = np.linspace(start=_theo_energies.start,
                             stop=_theo_energies.stop,
                             num=_theo_energies.n_energies)
del _theo_energies

# Experimental data

In [None]:
# Experimental I(V) beams; only the energy grid and intensity array are kept.
expbeams = readOUTBEAMS(data_path.joinpath('EXPBEAMS.csv'))
exp_energies, _, _, exp_intensities = beamlist_to_array(expbeams)

In [None]:
# Reference (theory) I(V) beams; only the energy grid and intensity array are kept.
theobeams = readOUTBEAMS(data_path.joinpath('THEOBEAMS.csv'))
theo_energies, _, _, theo_intensities = beamlist_to_array(theobeams)

In [None]:
# Beams (h, k) to calculate, stored as a tuple of (float, float) pairs.
# Its length sets n_beams when reading the tensor files (presumably the order
# must match the tensor files — TODO confirm).
beam_indices = tuple(
    (float(h), float(k)) for h, k in (
        (1, 0), (1, 1), (1, -1), (0, 2), (0, -2),
        (2, 0), (1, 2), (1, -2), (2, 1), (2, -1),
        (2, 2), (2, -2), (1, 3), (1, -3), (3, 0),
        (3, 1), (3, -1), (2, 3), (2, -3), (3, 2),
        (3, -2), (0, 4), (0, -4), (1, 4), (1, -4),
        (4, 0), (3, 3), (3, -3), (4, 1), (4, -1),
        (2, 4), (4, 2), (1, 5), (4, 3), (4, -3),
        (4, 4), (3, 5), (0, -6),
    )
)

In [None]:
# Index into `expbeams` of each beam in `beam_indices` (first match wins).
# A beam without an experimental counterpart raises instead of silently
# mapping to index 0, which is what np.argmax over all-False would return.
def _find_exp_beam(hk):
    """Return the index of the first experimental beam whose .hk equals hk."""
    for idx, exp_beam in enumerate(expbeams):
        if hk == exp_beam.hk:
            return idx
    raise ValueError(f"Beam {hk} not found among the experimental beams")

corr = [_find_exp_beam(b) for b in beam_indices]

# Tensor files

In [None]:
def read_tensor_num(num):
    """Read tensor file Tensors/T_<num> for a single atom.

    The beam count, energy count and l_max (LMAX + 1) come from
    `beam_indices`, `param_energies` and `LMAX` defined above.
    """
    return read_tensor(data_path / 'Tensors' / f'T_{num}',
                       n_beams=len(beam_indices),
                       n_energies=param_energies.size,
                       l_max=LMAX+1)

# One tensor per non-bulk atom, in slab.atlist order.
non_bulk_atoms = [at for at in slab.atlist if not at.is_bulk]
tensors = [read_tensor_num(at.num) for at in non_bulk_atoms]

# NOTE(review): fix_lmax=10 is a fixed value, independent of LMAX — confirm intended.
ref = ReferenceData(tensors, fix_lmax=10)

# Drop the notebook's reference to the raw tensors so they can be garbage
# collected (effective only if ReferenceData keeps no reference to this list
# — TODO confirm).
del tensors

In [None]:
# read phase shifts; only the raw phase-shift data (2nd return value) is kept
phaseshifts_path = data_path.joinpath('PHASESHIFTS')
_, raw_phaseshifts, _, _ = readPHASESHIFTS(
    slab,
    rparams,
    ignoreEnRange=False,
    check=True,
    readfile=phaseshifts_path,
)


In [None]:
# TODO: site_indices needs a general solution once we implement chemical
#       perturbations.
# Site index for each of the 30 non-bulk atoms:
# 2 atoms on site 0, 10 on site 1, 2 on site 2, 16 on site 3.
site_indices = [0]*2 + [1]*10 + [2]*2 + [3]*16

In [None]:
# TODO: with this current implementation, we cannot treat chemical
#       perturbations, nor vacancies. We need to implement this.
#       See e.g. iodeltas.generateDeltaInput()
#       (Treating vacancies requires setting zeros for that site)

# Phase shifts on the reference-data energy grid, using LMAX and mapped to
# atoms through site_indices (exact semantics are defined in Phaseshifts).
phaseshifts = Phaseshifts(raw_phaseshifts, ref.energies, LMAX, site_indices)

# Calculator setup

In [None]:
calculator = TensorLEEDCalculator(ref, phaseshifts, slab, rparams, beam_indices)

# Reference point: the calculator's reference vibrational amplitudes and zero
# displacement, one (x3) row per non-bulk atom (i.e. per tensor read above)
# instead of a hard-coded atom count.
centered_vib_amps = calculator.ref_vibrational_amps
centered_displacements = np.zeros((len(non_bulk_atoms), 3))

In [None]:
# Symmetric search ranges around the reference values.
# NOTE(review): only v0r_range is used by the bound setters below; the
# vibration and displacement bounds there use a fixed ±0.01 window instead of
# vib_amp_range / geo_range.
v0r_range = (-3.5, 3.5)       # V0r, eV
vib_amp_range = (-0.1, 0.1)   # vibrational amplitude, Å
geo_range = (-0.1, 0.1)       # displacement, Å

In [None]:
# Centre values in reduced-parameter space: four vibrational amplitudes
# (presumably one per site in site_indices) and five displacement parameters
# starting at zero.
centered_reduced_vib_amps = np.array([0.089, 0.06, 0.141, 0.115])
centered_reduced_displacements = np.zeros(5)

In [None]:
# Parameter bounds: a ±0.01 window around each centre value for vibrational
# amplitudes and displacements; V0r uses v0r_range.
_transformer = calculator.parameter_transformer
_half_width = 0.01
_transformer.set_vib_amp_bounds(centered_reduced_vib_amps - _half_width,
                                centered_reduced_vib_amps + _half_width)
_transformer.set_v0r_bounds(*v0r_range)
_transformer.set_displacement_bounds(centered_reduced_displacements - _half_width,
                                     centered_reduced_displacements + _half_width)
del _transformer, _half_width

In [None]:
# Atoms that are not part of the bulk (same selection as used when the
# tensor files were read); nothing is read from disk here.
non_bulk_atoms = [atom for atom in slab.atlist if not atom.is_bulk]

In [None]:
# Vibration constraints: all atoms on the same site share one vibrational
# amplitude. Rows = irreducible vibrational amplitudes, columns = sites.
every_second_site = site_indices[::2]  # for Fe2O3, every 2nd atom is symmetry independent
vib_constraints = np.zeros(
    (calculator.parameter_transformer.n_irreducible_vib_amps, max(every_second_site) + 1))
vib_constraints[np.arange(len(every_second_site)), every_second_site] = 1.0

In [None]:
# Geometric constraints: only z of the topmost-layer atoms may move, each atom
# as its own free parameter (*L(1) z in viperleed).
# Rows = irreducible displacement components (assumed 3 per atom with z first,
# as implied by the 3*atom row index), columns = free parameters.
atoms_in_first_layer = [0, 1, 8, 9, 10]
geo_constraints = np.zeros(
    (calculator.parameter_transformer.n_irreducible_displacements, len(atoms_in_first_layer)))
geo_constraints[np.multiply(atoms_in_first_layer, 3), np.arange(len(atoms_in_first_layer))] = 1.0

In [None]:
# apply constraints to the calculator's parameter transformer (geometric
# first, then vibrational); presumably this fixes the reduced parameter space
# used by the *_from_reduced functions below
calculator.parameter_transformer.apply_geo_constraints(geo_constraints)
calculator.parameter_transformer.apply_vib_constraints(vib_constraints)

In [None]:
# Summary of the parameter transformer after bounds and constraints were set.
print(calculator.parameter_transformer.info)

## Set experimental intensities as reference

In [None]:
# Reorder the experimental beams (columns) so they line up with beam_indices.
aligned_exp_intensities = np.take(exp_intensities, corr, axis=1)

In [None]:
# set reference point: compare against the experimental intensities, with the
# beams reordered to match beam_indices
calculator.set_experiment_intensity(aligned_exp_intensities,
                                    exp_energies)

# Disabled alternative: use the calculated intensities `ref_int` (defined in
# the Timing section) as a synthetic "experiment".
#calculator.set_experiment_intensity(ref_int,param_energies)

## Timing

In [None]:
# Warm-up call: displace atom 0 by 0.2 in its first component and wait for
# the result. This is the first call of calculator.intensity in the notebook,
# so it also pays any JIT compilation cost (assuming it is jitted).
# The shape follows centered_displacements instead of a hard-coded atom count.
test_disp = np.zeros_like(centered_displacements)
test_disp[0, 0] = 0.2
calculator.intensity(centered_vib_amps, test_disp).block_until_ready()

In [None]:
# Compilation
# NOTE(review): the warm-up cell above already called calculator.intensity with
# arrays of the same shape, so if intensity is jitted this probably times a
# cached call rather than compilation — confirm.
%time ref_int = calculator.intensity(centered_vib_amps, centered_displacements).block_until_ready()

In [None]:
# Execution: repeated timed call; ref_int holds the intensities at zero
# displacement (blocking so async dispatch does not distort the timing).
%time ref_int = calculator.intensity(centered_vib_amps, centered_displacements).block_until_ready()

# Intensity

In [None]:
# some displacements to play with: atom 0's first component (z, as used for
# z_arr below) swept over 11 steps from -0.05 to +0.05; all else zero.
def _first_atom_shifted(dz):
    """Return zeros shaped like centered_displacements with element [0, 0] = dz."""
    disp = np.zeros_like(centered_displacements)
    disp[0, 0] = dz
    return disp

spaced_displacements = [_first_atom_shifted(i*0.01 - 0.05) for i in range(11)]

In [None]:
# 0-based index into beam_indices of the beam used by the plots below;
# 1 selects the second beam.
plot_beam = 1

In [None]:
# Intensity of beam `plot_beam` for every displacement of the sweep.
# The full intensity arrays are kept in `ints`.
fig, ax = plt.subplots()
ints = []
for d in spaced_displacements:
    intens = calculator.intensity(centered_vib_amps, d)
    ax.plot(param_energies, intens[:, plot_beam], label=f"d[0, 0] = {d[0, 0]:+.2f}")
    ints.append(intens)
ax.set(xlabel="Energy (eV)", ylabel="Intensity",
       title=f"Beam {beam_indices[plot_beam]}")
ax.legend(fontsize="small")
plt.show()

# Interpolation

In [None]:
# Interpolated intensity (deriv_deg=0) of beam `plot_beam` for every
# displacement of the sweep; same beam as the intensity and Y-function plots.
fig, ax = plt.subplots()
for d in spaced_displacements:
    ax.plot(calculator.interpolated(centered_vib_amps, d, deriv_deg=0)[:, plot_beam])
ax.set(title="Interpolated Intensity", xlabel="Interpolation grid index",
       ylabel="Intensity")
plt.show()

In [None]:
from src.rfactor import pendry_y

# Pendry Y-function of beam `plot_beam`, built from the interpolated intensity
# and its first derivative, for every displacement of the sweep.
# NOTE(review): 4.5 is presumably the imaginary inner potential V0i in eV —
# consider taking it from the run parameters instead of hard-coding it.
V0I = 4.5
fig, ax = plt.subplots()
for d in spaced_displacements:
    intensity = calculator.interpolated(centered_vib_amps, d, deriv_deg=0)[:, plot_beam]
    deriv = calculator.interpolated(centered_vib_amps, d, deriv_deg=1)[:, plot_beam]
    ax.plot(pendry_y(intensity, deriv, V0I))
ax.set(title="Interpolated Y-function", xlabel="Interpolation grid index", ylabel="Y")
plt.show()

# R-factor

In [None]:
# Reduced parameter vector for the timing runs below: 10 entries, all 0.5.
test_flat_param = np.full(10, 0.5)

### $R_P$

In [None]:
# compile time
%time calculator.R_pendry_from_reduced(test_flat_param)

In [None]:
# execution time
%timeit calculator.R_pendry_from_reduced(test_flat_param)

### $\nabla R_P$

In [None]:
# compile time
%time calculator.R_pendry_grad_from_reduced(test_flat_param)

In [None]:
# execution time
%timeit calculator.R_pendry_grad_from_reduced(test_flat_param)

### ($R_P$, $\nabla R_P$)

In [None]:
# compile time
%time calculator.R_pendry_val_and_grad_from_reduced(test_flat_param,)

In [None]:
# execution time
%timeit calculator.R_pendry_val_and_grad_from_reduced(test_flat_param,)

### Full (non-reduced) parameters

In [None]:
# compile time
%time calculator.R_pendry_val_and_grad(centered_vib_amps, centered_displacements, 0)

In [None]:
# Show (R_P, gradient) at the reference vibrational amplitudes and zero
# displacement. NOTE(review): the trailing positional 0 is undocumented here —
# presumably a V0r shift; confirm.
calculator.R_pendry_val_and_grad(centered_vib_amps, centered_displacements, 0)

In [None]:
# R_P value and gradient for every displacement of the sweep, together with
# atom 0's first displacement component as the x coordinate.
R_arr, R_grad_arr, z_arr = [], [], []
for d in spaced_displacements:
    R, gradient = calculator.R_pendry_val_and_grad(centered_vib_amps, d)
    z_arr.append(d[0][0])
    R_grad_arr.append(gradient)
    R_arr.append(R)

In [None]:
# R_P along the displacement sweep of atom 0.
fig, ax = plt.subplots()
ax.plot(z_arr, R_arr)
ax.set(xlabel="z displacement of atom 0 (Å)", ylabel="$R_P$",
       title="Pendry R-factor along the sweep");

In [None]:
# Two components of the R_P gradient along the displacement sweep of atom 0.
fig, ax = plt.subplots()
ax.plot(z_arr, [g[0, 0] for g in R_grad_arr], label="grad[0, 0]")
ax.plot(z_arr, [g[0, 1] for g in R_grad_arr], label="grad[0, 1]")
ax.set(xlabel="z displacement of atom 0 (Å)", ylabel="Gradient of $R_P$",
       title="$R_P$ gradient along the sweep")
ax.legend();

In [None]:
# Check the autodiff gradient against a finite-difference estimate of R_P.
# (R[i+1] - R[i]) / (z[i+1] - z[i]) approximates the derivative at the
# interval midpoint, so it is plotted there rather than at z[i].
z_sweep = np.asarray(z_arr, dtype=float)
R_sweep = np.asarray(R_arr, dtype=float)
fig, ax = plt.subplots()
ax.plot(z_sweep, [g[0, 0] for g in R_grad_arr], label="autodiff grad[0, 0]")
ax.plot(0.5 * (z_sweep[1:] + z_sweep[:-1]), np.diff(R_sweep) / np.diff(z_sweep),
        label="finite difference")
ax.set(xlabel="z displacement of atom 0 (Å)", ylabel=r"$\partial R_P / \partial z$")
ax.legend();

# $R_2$: sum of squared intensity differences

In [None]:
# Reference intensities (all beams) at zero displacement for the R2 objective.
# NOTE(review): `delta_intensity` is not defined in this notebook; it may come
# from `from src.lib_phaseshifts import *` or a removed cell — verify.
ref_intensity_all_beams = delta_intensity(jnp.zeros((30, 3)))

In [None]:
def lam_r2(z):
    """Toy R2 objective: summed squared intensity difference to the reference.

    Sets atom 0's first displacement component to `z` (all other
    displacements zero), computes the intensities and returns the real part
    of the summed squared difference to `ref_intensity_all_beams`. Written
    with jnp so it can be passed to jax.grad / jax.jit.

    NOTE(review): `delta_intensity` is not defined in this notebook — verify
    where it comes from.
    """
    displacements = jnp.zeros(centered_displacements.shape).at[0, 0].set(z)
    residual = delta_intensity(displacements) - ref_intensity_all_beams
    return jnp.real((residual**2).sum())

In [None]:
# R2 and its derivative over 100 displacements of atom 0.
# NOTE: this rebinds z_arr (previously the R_P sweep positions).
z_arr = jnp.linspace(-0.05, 0.05, 100)
lam_r2_grad = jax.grad(lam_r2)  # build the gradient function once, not per point
R2_arr = [lam_r2(r) for r in z_arr]
R2_grad_arr = [lam_r2_grad(r) for r in z_arr]

In [None]:
# R2 objective along the scan.
fig, ax = plt.subplots()
ax.plot(z_arr, R2_arr)
ax.set(xlabel="z displacement of atom 0 (Å)", ylabel="$R_2$",
       title="$R_2$ along the scan");

In [None]:
# Derivative of the R2 objective (from jax.grad) along the scan.
fig, ax = plt.subplots()
ax.plot(z_arr, R2_grad_arr)
ax.set(xlabel="z displacement of atom 0 (Å)", ylabel=r"$\mathrm{d}R_2/\mathrm{d}z$",
       title="$R_2$ derivative along the scan");

# Cost estimates and timing of the objectives

In [None]:
# Function cost
# NOTE(review): `lam_r` is not defined anywhere in this notebook, and
# `estimate_function_cost` is presumably provided by
# `from src.lib_phaseshifts import *`; on a fresh kernel this cell likely
# raises NameError — define lam_r or drop this cell.
estimate_function_cost(lam_r, 0.0)

In [None]:
# Function cost of the R2 objective, evaluated at z = 0.0
# (estimate_function_cost is not defined in this notebook — presumably from
# the star import of src.lib_phaseshifts)
estimate_function_cost(lam_r2, 0.0)

In [None]:
l = jax.jit(lam_r2).lower(0.0).compile()
%timeit l(0.0)

In [None]:
l2 = jax.jit(jax.grad(lam_r2)).lower(0.0).compile()
%timeit l2(0.0)