In [None]:
from jax import config
config.update("jax_debug_nans", False)
config.update("jax_enable_x64", True)
config.update("jax_disable_jit", False)
config.update("jax_log_compiles", False)

from pathlib import Path

from matplotlib import pyplot as plt
import numpy as np

from viperleed_jax.constants import BOHR
from viperleed_jax.lib_tensors import read_tensor
from viperleed_jax.lib_phaseshifts import *
from viperleed_jax.data_structures import ReferenceData
from viperleed_jax.tensor_calculator import TensorLEEDCalculator

%matplotlib inline

In [None]:
jax.devices()

In [None]:
use_installable = True
if use_installable:
    import viperleed
    from viperleed.calc import symmetry
    from viperleed.calc.files import poscar
    from viperleed.calc.files import parameters
    from viperleed.calc.classes.rparams import Rparams
    from viperleed.calc.files.beams import readIVBEAMS, readOUTBEAMS
    from viperleed.calc.files.phaseshifts import readPHASESHIFTS
    from viperleed.calc.files.vibrocc import readVIBROCC
    from viperleed.calc.files.iorfactor import beamlist_to_array
    from viperleed.calc.lib.leedbase import getBeamCorrespondence
else:
    # master
    import sys
    sys.path.append('/Users/alexander/GitHub/')
    import viperleed
    from viperleed.tleedmlib.files import poscar
    from viperleed.tleedmlib.files import parameters
    from viperleed.tleedmlib.classes.rparams import Rparams
    from viperleed.tleedmlib.files.beams import readIVBEAMS, readOUTBEAMS
    from viperleed.tleedmlib.files.phaseshifts import readPHASESHIFTS
    from viperleed.tleedmlib.files.vibrocc import readVIBROCC
    from viperleed.tleedmlib.files.iorfactor import beamlist_to_array

In [None]:
data_path = Path('tests') / 'test_data' / 'Fe2O3_012'

In [None]:
# Read in data from POSCAR and PARAMETERS files
slab = poscar.read(data_path / 'POSCAR')
rparams = parameters.read(data_path / 'PARAMETERS')
parameters.interpret(rparams, slab, silent=False)
slab.full_update(rparams)

# reading IVBEAMS
# rparams.ivbeams = readIVBEAMS(data_path / 'IVBEAMS')
# beam_indices = np.array([beam.hk for beam in rparams.ivbeams])

# reading VIBROCC
readVIBROCC(rparams, slab, data_path / 'VIBROCC')

# incidence angles
rparams.THETA = 0.0
rparams.PHI = 90.0

In [None]:
LMAX = rparams.LMAX.max

In [None]:
param_energies = np.linspace(rparams.THEO_ENERGIES.start,
                           rparams.THEO_ENERGIES.stop,
                           rparams.THEO_ENERGIES.n_energies)

# Experimental Data

In [None]:
expbeams = readOUTBEAMS(data_path / 'EXPBEAMS.csv')
exp_energies, _, _, exp_intensities = beamlist_to_array(expbeams)

In [None]:
theobeams = readOUTBEAMS(data_path / 'THEOBEAMS.csv')
theo_energies, _, _, theo_intensities = beamlist_to_array(theobeams)

In [None]:
beam_indices = ((1.00000,  0.00000), (1.00000,  1.00000), (1.00000, -1.00000), (0.00000,  2.00000), (0.00000, -2.00000), (2.00000,  0.00000), (1.00000,  2.00000), (1.00000, -2.00000), (2.00000,  1.00000), (2.00000, -1.00000), (2.00000,  2.00000), (2.00000, -2.00000), (1.00000,  3.00000), (1.00000, -3.00000), (3.00000,  0.00000), (3.00000,  1.00000), (3.00000, -1.00000), (2.00000,  3.00000), (2.00000, -3.00000), (3.00000,  2.00000), (3.00000, -2.00000), (0.00000,  4.00000), (0.00000, -4.00000), (1.00000,  4.00000), (1.00000, -4.00000), (4.00000,  0.00000), (3.00000,  3.00000), (3.00000, -3.00000), (4.00000,  1.00000), (4.00000, -1.00000), (2.00000,  4.00000), (4.00000,  2.00000), (1.00000,  5.00000), (4.00000,  3.00000), (4.00000, -3.00000), (4.00000,  4.00000), (3.00000,  5.00000), (0.00000, -6.00000), )

In [None]:
corr = [np.argmax([b == t.hk for t in expbeams])for b in beam_indices]

# Tensor files

In [None]:
def read_tensor_num(num):
    """Read the tensor file ``T_<num>`` from the ``Tensors`` folder of the data set.

    The number of beams, the number of energies and the maximum angular
    momentum are taken from the notebook-level ``beam_indices``,
    ``param_energies`` and ``LMAX`` (the latter passed as ``LMAX+1``).
    """
    return read_tensor(data_path / 'Tensors' / f'T_{num}',
                       n_beams=len(beam_indices),
                       n_energies=param_energies.size,
                       l_max=LMAX+1)


# one tensor file per non-bulk atom, addressed by the atom number
non_bulk_atoms = [at for at in slab.atlist if not at.is_bulk]
tensors = [read_tensor_num(at.num) for at in non_bulk_atoms]

ref = ReferenceData(tensors, fix_lmax=10)

# Drop the notebook's reference to the raw tensors. Deleting the loop variable
# inside a ``for`` loop only unbinds that name and never releases the list
# elements, so a single ``del`` of the list is all that is needed here.
# NOTE(review): memory is only freed if ``ref`` keeps no reference to them.
del tensors

In [None]:
# read phase shifts
phaseshifts_path = data_path /  'PHASESHIFTS'
_, raw_phaseshifts, _, _ = readPHASESHIFTS(
    slab, rparams, readfile=phaseshifts_path, check=True, ignoreEnRange=False)


In [None]:
# TODO: site_indices needs a general solution once we implement chemical perturbations
site_indices = [0,0,1,1,1,1,1,1,1,1,1,1,2,2,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3]

In [None]:
# TODO: with the current implementation, we cannot treat chemical
#       perturbations or vacancies. We need to implement this.
#       See e.g. iodeltas.generateDeltaInput()
#       (Treating vacancies requires setting zeros for that site)

phaseshifts = Phaseshifts(raw_phaseshifts, ref.energies, LMAX, site_indices)

# Calculator set up

In [None]:
calculator = TensorLEEDCalculator(ref, phaseshifts, slab, rparams, beam_indices)

centered_vib_amps = calculator.ref_vibrational_amps
centered_displacements = np.array([[0.0, 0.0, 0.0],]*30)

In [None]:
v0r_range = (-3.5, +3.5) # in eV
vib_amp_range = (-0.1, +0.1) # in A
geo_range = (-0.1, +0.1) # in A

In [None]:
centered_reduced_vib_amps = np.array([0.089, 0.06, 0.141, 0.115])
centered_reduced_displacements = np.array([[0.0],]*5).flatten()

In [None]:
# Bounds are the centered values shifted by the lower ([0]) and upper ([1])
# end of the respective range. The upper vib-amp bound previously reused
# vib_amp_range[0], which made lower and upper bound identical.
calculator.parameter_transformer.set_vib_amp_bounds(
    centered_reduced_vib_amps + vib_amp_range[0],
    centered_reduced_vib_amps + vib_amp_range[1],
)
calculator.parameter_transformer.set_v0r_bounds(*v0r_range)
calculator.parameter_transformer.set_displacement_bounds(
    centered_reduced_displacements + geo_range[0],
    centered_reduced_displacements + geo_range[1],
)

In [None]:
# collect the atoms of the slab that are not flagged as bulk
non_bulk_atoms = [at for at in slab.atlist if not at.is_bulk]

In [None]:
# vibration constraints to change all sites together
every_second_site = site_indices[::2] # for Fe2O3, every 2nd atom is symmetry independent
vib_constraints = np.zeros(shape=(calculator.parameter_transformer.n_irreducible_vib_amps, max(every_second_site)+1))
for at_id, site in enumerate(every_second_site):
    vib_constraints[at_id, site] = 1.0

In [None]:
# geometric constraints to move only z for the topmost layer (*L(1) z) in viperleed
atoms_in_first_layer = [0, 1, 8, 9, 10]
geo_constraints = np.zeros(shape=(calculator.parameter_transformer.n_irreducible_displacements, len(atoms_in_first_layer)))
for at_id, site in enumerate(atoms_in_first_layer):
    geo_constraints[site*3, at_id] = 1.0

In [None]:
# apply constraints
calculator.parameter_transformer.apply_geo_constraints(geo_constraints)
calculator.parameter_transformer.apply_vib_constraints(vib_constraints)

In [None]:
print(calculator.parameter_transformer.info)

In [None]:
calculator.exp_spline

In [None]:
test_flat_param = np.array([0.6, 0.5, 0.5, 0.5, 0.5, 0.501]+ [0.5]*4) # displacements

In [None]:
intens = calculator._intensity_from_reduced(test_flat_param)

In [None]:
spline_coeffs = calculator.interpolator.get_bspline_coeffs(intens)
orig = calculator.interpolator.evaluate_bspline_coeffs(spline_coeffs)

In [None]:
calculator.interpolator = CardinalNotAKnotSplineInterpolator(
            calculator.interpolator.origin_grid,
            calculator.target_grid,
            calculator.interpolation_deg # TODO: take from rparams.INTPOL_DEG
        )

In [None]:
import interpax

In [None]:
spline = interpax.CubicSpline(calculator.interpolator.origin_grid,
                     intens)

In [None]:
spline.derivative()

In [None]:
b_coeffs = calculator.interpolator.get_bspline_coeffs(intens)
orig = calculator.interpolator.evaluate_bspline_coeffs(b_coeffs)
pp_coeffs = calculator.interpolator.convert_b_to_pp_spline_coeffs(b_coeffs)
res = calculator.interpolator.evaluate_pp_spline_coeffs(pp_coeffs, 0, -1)
#res_2 = evaluate_pp_spline_coeffs(calculator.interpolator, a.copy(), 0, -1.5)

e = calculator.interpolator.target_grid

%matplotlib inline
plt.plot(e, res[0,:], marker='x')
#plt.plot(res_2[0,:])
plt.plot(e, orig[:,0])
plt.xlim([48,75])
plt.scatter(calculator.interpolator.origin_grid, intens[:,0])
#plt.ylim([0,0.01])

plt.plot(e, spline(e)[:,0])

plt.show()

In [None]:
# display the intensities computed from the reduced test parameters
intens

In [None]:
import interpax

In [None]:
%%timeit
b_coeffs = calculator.interpolator.get_bspline_coeffs(intens)
orig = calculator.interpolator.evaluate_bspline_coeffs(b_coeffs)
pp_coeffs = calculator.interpolator.convert_b_to_pp_spline_coeffs(b_coeffs)
res = calculator.interpolator.evaluate_pp_spline_coeffs(pp_coeffs, 0, -1)


In [None]:
interpax

In [None]:
x = calculator.interpolator.origin_grid
y = intens
@jax.jit
def f(x, y):
    """Construct an ``interpax.CubicSpline`` through (x, y), passing ``check=False``.

    The function is wrapped in ``jax.jit`` so the construction is compiled.
    """
    return interpax.CubicSpline(x, y, check=False)

In [None]:
f(x, y)

In [None]:
%timeit f(x, y)

# Other thing

In [None]:
print(intens.shape)
_intens = intens[:, 0]

In [None]:
n_values = len(intens)
first_col_row = jnp.array([4.0, 1.0] + [0.0]* (n_values-4))
mat = 1/3*scipy.linalg.toeplitz(first_col_row, first_col_row)
mat

step = 3.0
divided_diff = jnp.diff(_intens)/step

_c = jnp.linalg.solve(mat, jnp.diff(divided_diff)/step)

In [None]:
# Natural boundary conditions: pad the solved quadratic coefficients with a
# zero at both ends.
coeffs = jnp.pad(_c, (1,1), 'constant')

# Per-interval coefficients of a_i*t**3 + b_i*t**2 + c_i*t + d_i
# (presumably with t = x - x_i measured from the left knot of interval i).
a_i = jnp.diff(coeffs)/(3*step)
# The linear coefficient combines the quadratic coefficient of this knot with
# the one of the NEXT knot (coeffs[1:]). The former jnp.roll(coeffs, 1)[:-1]
# picked the previous knot (and wrapped the last value around to index 0).
c_i = divided_diff  - 2/3*coeffs[:-1]*step -1/3*coeffs[1:]*step
b_i = coeffs[:-1]
d_i = _intens[:-1]

In [None]:
a_i.shape, b_i.shape, c_i.shape, d_i.shape

In [None]:
pp_coeffs

In [None]:
calculator.interpolator.knots

In [None]:
plt.plot(e, res[0,:])
#plt.plot(res_2[0,:])
plt.plot(e, orig[:,0])
plt.xlim([625,650])
plt.ylim([0,0.0002])

plt.show()

In [None]:
calculator.interpolator.x_diffs[2, :20]

In [None]:
calculator.interpolator.knots

In [None]:
calculator.interpolator.knots, calculator.interpolator.origin_grid

In [None]:
calculator.interpolator.x_diffs

In [None]:
# Basis transformation from cardinal B-spline to cardinal piecewise polynomial
# in the form as given by de Boor, A practical guide to splines 1978, p.324
B_TO_PP_SPLINE_BASIS_TRANSFORMATION = np.array([[1/6, 2/3, 1/6, 0.0],
                                                [-1/2, 0.0, 1/2, 0.0],
                                                [1.0, -2.0, 1.0, 0.0],
                                                [-1.0, 3.0, -3.0, 1.0]])

def translate_cubic_pp_spline_coeffs(s):
    """Return the matrix that shifts the argument of a cubic polynomial by s.

    A cubic (spline piece) in the piecewise-polynomial basis is described by
    the coefficients (a, b, c, d) of
    f(x) = a*x**3 + b*x**2 + c*x + d.
    Left-multiplying the column vector (a, b, c, d) with the returned matrix
    gives the coefficients of g(x) = f(x + s).
    For s -> 0 the matrix approaches the identity.
    For a piecewise polynomial spline the shift is only meaningful while
    x + s stays in the same knot interval as x.

    Parameters
    ----------
    s : float
        The amount by which the argument is translated.

    Returns
    -------
    ndarray
        The transformation as a (4x4) matrix.
    """
    # one row per output coefficient, highest power first
    cubic_row = [1.0, 0.0, 0.0, 0.0]
    quadratic_row = [3*s, 1.0, 0.0, 0.0]
    linear_row = [3*s**2, 2*s, 1.0, 0.0]
    constant_row = [s**3, s**2, s, 1.0]
    return np.array([cubic_row, quadratic_row, linear_row, constant_row])


In [None]:
convert_b_to_pp_spline_coeffs(c, 0.5).shape

In [None]:
from functools import partial
@partial(jax.vmap, in_axes=(1, None))
def convert_b_to_pp_spline_coeffs(bspline_coeffs, step):
    """Convert B-spline coefficients to piecewise-polynomial coefficients.

    The function is vectorised with ``jax.vmap`` over axis 1 of
    ``bspline_coeffs`` (``step`` is shared), so the body below sees a single
    1-D coefficient vector at a time and the results are stacked along a new
    leading axis.

    Parameters
    ----------
    bspline_coeffs : array
        B-spline coefficients; axis 1 is the mapped axis.
    step : float
        Knot spacing used to scale the individual polynomial orders.

    Returns
    -------
    array
        For every mapped slice the stacked coefficients (a, b, c, d).
    """
    def _convolve_with_row(row):
        # full convolution with one row of the basis transformation; the
        # first two and the last entry of the result are discarded
        kernel = B_TO_PP_SPLINE_BASIS_TRANSFORMATION[row, :]
        return jnp.convolve(bspline_coeffs, kernel)[2:-1]

    cubic = _convolve_with_row(3)/step**3
    cubic = cubic/6

    quadratic = _convolve_with_row(2)/step**2
    quadratic = (quadratic - 6*cubic)/2

    linear = _convolve_with_row(1)/step
    linear = linear - 3*cubic - 2*quadratic

    constant = _convolve_with_row(0)
    constant = constant - cubic - quadratic - linear
    return jnp.array([cubic, quadratic, linear, constant])

In [None]:
def translate_bspline_coeffs(bspline_coeffs, shift):
    """Translate B-spline coefficients by convolving them with a shift kernel.

    The cubic translation matrix for ``shift`` is expressed in the B-spline
    basis (inverse basis transformation @ translation @ basis transformation).
    Row 1 of that matrix is used as convolution kernel for every column
    (beam) of ``bspline_coeffs``.

    Parameters
    ----------
    bspline_coeffs : array
        Coefficients indexed as (coefficient, beam).
    shift : float
        Shift handed to ``translate_cubic_pp_spline_coeffs``.

    Returns
    -------
    tuple
        ``(translated_bspline_coeffs, transformation)``: the convolved
        coefficients with the first and the last two rows removed, and the
        (4x4) transformation in the B-spline basis.
    """
    transformation = (
        jnp.linalg.inv(B_TO_PP_SPLINE_BASIS_TRANSFORMATION)
        @ translate_cubic_pp_spline_coeffs(shift)
        @ B_TO_PP_SPLINE_BASIS_TRANSFORMATION
    )
    kernel = transformation[1, :]

    n_beams = bspline_coeffs.shape[1]
    convolved_per_beam = [
        np.convolve(bspline_coeffs[:, beam], kernel, 'full')
        for beam in range(n_beams)
    ]
    # back to (coefficient, beam) ordering
    convolved = np.array(convolved_per_beam).swapaxes(0, 1)
    # remove added dummy coeffs from convolution
    return convolved[1:-2], transformation

In [None]:
def evaluate_bspline_coeffs(self, bspline_coeffs, knot_shift=0, deriv_order=0):
    """Evaluate a spline from its B-spline coefficients and De Boor values.

    Parameters
    ----------
    self : interpolator-like object
        Must provide ``intervals``, ``intpol_deg`` and ``de_boor_coeffs``.
        All three are now read from ``self``; previously the De Boor
        coefficients were taken from the global ``calculator.interpolator``
        regardless of the object passed in.
    bspline_coeffs : array
        Coefficients indexed as (coefficient, beam).
    knot_shift : int, optional
        Number of positions by which the interval indices and the De Boor
        coefficients are rolled. Default 0.
    deriv_order : int, optional
        Index into the first axis of ``self.de_boor_coeffs``. Default 0.

    Returns
    -------
    array
        Result of the contraction ``'ijb,ji->ib'`` of the selected
        coefficients with the De Boor coefficients.
    """
    # Extract the relevant coefficients for each interval
    lower_indices = self.intervals - self.intpol_deg
    lower_indices = jnp.roll(lower_indices, knot_shift)
    coeff_indices = lower_indices.reshape(-1,1) + jnp.arange(self.intpol_deg+1)
    coeff_subarrays = bspline_coeffs[coeff_indices]

    # NOTE(review): jnp.roll without ``axis`` rolls the flattened 2-D array,
    # so values wrap between rows; confirm whether axis=1 was intended.
    de_boor_coeffs = jnp.roll(self.de_boor_coeffs[deriv_order], knot_shift)

    # Element-wise multiplication between coefficients and de_boor values
    # then sum over basis functions
    return jnp.einsum('ijb,ji->ib', coeff_subarrays, de_boor_coeffs)

In [None]:
delta = -0.1
step = 0.5

knot_shift, frac_shift = divmod(delta, step)
knot_shift = int(knot_shift)
c = calculator.interpolator.get_bspline_coeffs(intens)
s, T = translate_bspline_coeffs(c, -frac_shift)

In [None]:
orig = evaluate_bspline_coeffs(calculator.interpolator, c)
shifted = evaluate_bspline_coeffs(calculator.interpolator, s, knot_shift=knot_shift)

%matplotlib inline
plt.plot(orig[:,0] / np.max(orig[:,0]))
plt.plot(shifted[:,0] / np.max(shifted[:,0]))
plt.xlim([25,40])
plt.ylim([0.9,1.01])
plt.show()


In [None]:
orig[:,0]/shifted[:,0]

In [None]:
c.shape, s.shape

In [None]:
orig = calculator.interpolator.evaluate_bspline_coeffs(c)

In [None]:
# use the interpolator method, as in the other cells of this notebook; a free
# function `get_bspline_coeffs` is neither defined nor imported by name here
spline_coeffs = calculator.interpolator.get_bspline_coeffs(intens)

In [None]:
delta = 2.5
step = 3.0

knot_shift, frac_shift = divmod(delta, step)
knot_shift = int(knot_shift)

intervals = np.arange(201)
coeffs = calculator.interpolator.de_boor_coeffs[0, ...] # non-derivative coeffs


first_order = coeffs[3-1, intervals] - coeffs[3-1, intervals+1]
second_order = 0.5*coeffs[3-2, intervals] + coeffs[3-2, intervals+1 ] + 0.5*coeffs[3-2, intervals+2]
third_order = np.full_like(first_order, fill_value=1.)

corr = frac_shift/step * first_order + (frac_shift/step)**2 * second_order + (frac_shift/step)**3 * third_order
corr = 1/3*np.array(corr)

In [None]:
frac_shift

In [None]:
# `interpolator` is only bound in a later cell; go through the calculator
calculator.interpolator.de_boor_coeffs.shape

In [None]:
corr = frac_shift/step

In [None]:
shifted_spline_coeffs = np.roll(spline_coeffs, knot_shift, axis=0)
# invalidate values:
nan_mask = jnp.full_like(corr, fill_value=1.0)
#nan_mask = nan_mask.at[:knot_shift].set(jnp.nan)

In [None]:
alpha_t = jnp.einsum('ib,->ib', shifted_spline_coeffs, (1+corr))
lower_indices = calculator.interpolator.intervals - calculator.interpolator.intpol_deg
coeff_indices = lower_indices.reshape(-1,1) + jnp.arange(calculator.interpolator.intpol_deg+1)
coeff_subarrays = alpha_t[coeff_indices]

In [None]:
deriv_order = 0

res = jnp.einsum('ijb,ji->ib',
                      coeff_subarrays,
                      calculator.interpolator.de_boor_coeffs[deriv_order])

In [None]:
b_to_pp_matrix = np.array([[1/6, 2/3, 1/6, 0.0],
                           [-1/2, 0.0, 1/2, 0.0],
                           [1.0, -2.0, 1.0, 0.0],
                           [-1.0, 3.0, -3.0, 1.0]])

In [None]:
def shift_matrix(d):
    """Return the (4x4) matrix that maps (a, b, c, d) of a cubic f to f(x + d).

    The cubic is f(x) = a*x**3 + b*x**2 + c*x + d (coefficients ordered from
    the highest power down); the matrix is meant to be applied from the left
    to the coefficient column. For ``d = 0`` it is the identity.

    Parameters
    ----------
    d : number
        Shift of the argument.

    Returns
    -------
    ndarray
        The (4x4) transformation matrix.
    """
    # one row per output coefficient, highest power first
    cubic_row = [1, 0, 0, 0]
    quadratic_row = [3*d, 1, 0, 0]
    linear_row = [3*d**2, 2*d, 1, 0]
    constant_row = [d**3, d**2, d, 1]
    return np.array([cubic_row, quadratic_row, linear_row, constant_row])


In [None]:
np.linalg.eig(R)[1][:,0]

In [None]:
delta = -15.3
step = 0.5

knot_shift, frac_shift = divmod(delta, step)
knot_shift = int(knot_shift)

R = np.linalg.inv(b_to_pp_matrix)@shift_matrix(frac_shift)@b_to_pp_matrix

R_vec = np.linalg.eig(R)[1][:,0]
t_coeffs = np.array([np.convolve(spline_coeffs[:, beam], R_vec)
                    for beam in range(spline_coeffs.shape[1])])

In [None]:
b = spline_coeffs[:, 0]
b.shape

In [None]:
spline_coeffs = calculator.interpolator.get_bspline_coeffs(intens)

In [None]:
lower_indices = calculator.interpolator.intervals - calculator.interpolator.intpol_deg
lower_indices = np.roll(lower_indices, knot_shift)
coeff_indices = lower_indices.reshape(-1,1) + jnp.arange(calculator.interpolator.intpol_deg+1)
coeff_subarrays = t_coeffs.swapaxes(0,1)[coeff_indices, :]


deriv_order = 0

res = jnp.einsum('ijb,ji->ib',
                      coeff_subarrays,
                      np.roll(calculator.interpolator.de_boor_coeffs[deriv_order],
                              knot_shift))

In [None]:
plt.plot(res[:,0])
plt.plot(orig[:,0])

In [None]:
from matplotlib import animation

In [None]:
t_coeffs.shape, coeff_indices

In [None]:
fig = plt.figure()


n_frames = 201

def update(delta_i):
    """Draw frame ``delta_i`` of the B-spline shift animation.

    The total shift ``delta = -9 + 13.5/n_frames*delta_i`` is split with
    ``divmod`` into a whole number of knot steps (step 0.5) and a fractional
    remainder. The fractional part is applied by convolving the B-spline
    coefficients with row 3 of the shift matrix expressed in the B-spline
    basis; the whole-knot part is applied by rolling the interval indices and
    the De Boor coefficients. The shifted and the unshifted curve of beam 0
    are plotted normalised to their maximum, together with the raw
    intensities of beam 0.

    Reads the notebook globals ``calculator``, ``spline_coeffs``, ``orig``,
    ``intens``, ``n_frames``, ``b_to_pp_matrix`` and ``shift_matrix``.
    """
    plt.cla()
    interp = calculator.interpolator
    step = 0.5
    delta = -9 + 13.5/n_frames*delta_i

    knot_shift, frac_shift = divmod(delta, step)
    knot_shift = int(knot_shift)

    # fractional shift, expressed in the B-spline basis
    shift_in_b_basis = (np.linalg.inv(b_to_pp_matrix)
                        @ shift_matrix(frac_shift)
                        @ b_to_pp_matrix)
    kernel = shift_in_b_basis[3, :]

    n_beams = spline_coeffs.shape[1]
    shifted_coeffs = np.array(
        [np.convolve(spline_coeffs[:, beam], kernel) for beam in range(n_beams)]
    ).swapaxes(0, 1)

    # whole-knot shift: roll the interval bookkeeping
    first_coeff = np.roll(interp.intervals - interp.intpol_deg, knot_shift)
    coeff_indices = first_coeff.reshape(-1, 1) + jnp.arange(interp.intpol_deg + 1)
    de_boor = np.roll(interp.de_boor_coeffs[0], knot_shift)

    shifted = jnp.einsum('ijb,ji->ib', shifted_coeffs[coeff_indices, :], de_boor)

    grid = interp.target_grid
    plt.plot(grid, shifted[:, 0]/np.max(shifted[:, 0]), marker='')
    plt.plot(grid, orig[:, 0]/np.max(orig[:, 0]))
    plt.scatter(interp.origin_grid, intens[:, 0], marker='x')

    plt.xlim([290, 340])
    plt.ylim([0, 0.1])

ani = animation.FuncAnimation(fig, update, frames=n_frames, interval=15)
ani.save('animate_bspline.gif')

In [None]:
spline_coeffs.shape

In [None]:
plt.plot(res[:,0])
plt.plot(orig[:,0])

plt.xlim([200, 400])
plt.ylim([0, 0.002])

In [None]:
R.trace()

In [None]:
spline_coeffs

In [None]:
spline_coeffs.shape

In [None]:
a = np.convolve(spline_coeffs[:, 0], b_to_pp_matrix[3, :])[2:-1]/step**3
a  =a/6

b = np.convolve(spline_coeffs[:, 0], b_to_pp_matrix[2, :])[2:-1]/step**2
b = (b-6*a)/2

c = np.convolve(spline_coeffs[:, 0], b_to_pp_matrix[1, :])[2:-1]/step
c = c- 3*a -2*b

d = np.convolve(spline_coeffs[:, 0], b_to_pp_matrix[0, :])[2:-1]
d = d - a - b- c



In [None]:
interpolator = calculator.interpolator

In [None]:
x_diffs = interpolator.knots[interpolator.intervals+1] - interpolator.target_grid

In [None]:
intervals = interpolator.intervals

In [None]:
r = a[intervals]*x_diffs**3 + b[intervals]*x_diffs**2 + c[intervals]*x_diffs + d[intervals]

In [None]:
plt.plot(a[intervals]*x_diffs**3 + b[intervals]*x_diffs**2 + c[intervals]*x_diffs + d[intervals])

In [None]:
orig = calculator.interpolator.evaluate_bspline_coeffs(spline_coeffs)

In [None]:
pp_coeffs = np.array([a,b,c,d])

In [None]:
delta = 0.4
step = 0.5

knot_shift, frac_shift = divmod(delta, step)
knot_shift = int(knot_shift)


In [None]:
intervals

In [None]:
# For every target-grid point: position found by searchsorted in the knot
# vector, clipped to [intpol_deg + 1, n_knots - intpol_deg - 1] and reduced
# by one. The clip bounds are passed positionally because the `a_min`/`a_max`
# keywords of jnp.clip are deprecated in favour of `min`/`max`; positional
# bounds work with both spellings.
intervals = jnp.clip(
    jnp.searchsorted(interpolator.knots,
                     interpolator.target_grid,
                     side='left'),
    interpolator.intpol_deg + 1,
    interpolator.knots.size - interpolator.intpol_deg - 1,
) - 1

In [None]:
# Apply the fractional shift to every polynomial piece, then move the pieces by
# the whole-knot part. axis=1 rolls along the interval axis only; without it
# np.roll acts on the flattened (4, n) array and mixes the a/b/c/d rows.
_a, _b, _c, _d = np.roll(shift_matrix(-frac_shift)@pp_coeffs, shift=-knot_shift, axis=1)

In [None]:


# Apply the fractional shift to every polynomial piece, then move the pieces by
# the whole-knot part. axis=1 rolls along the interval axis only; without it
# np.roll acts on the flattened (4, n) array and mixes the a/b/c/d rows.
_a, _b, _c, _d = np.roll(shift_matrix(-frac_shift)@pp_coeffs, shift=-knot_shift, axis=1)


# shifted piecewise polynomial vs. the unshifted B-spline evaluation and the raw data
plt.plot(calculator.interpolator.target_grid, _a[intervals]*x_diffs**3 + _b[intervals]*x_diffs**2 + _c[intervals]*x_diffs + _d[intervals], marker='o')
plt.plot(calculator.interpolator.target_grid, orig[:, 0])
plt.scatter(calculator.interpolator.origin_grid, intens[:,0], marker='x')

plt.xlim([60,100])
plt.ylim([0.002, 0.005])

In [None]:
fig = plt.figure()

n_frames = 151

def update(delta_i):
    """Draw frame ``delta_i`` of the piecewise-polynomial shift animation.

    The total shift ``delta = -4.5 + 9/n_frames*delta_i`` is split with
    ``divmod`` into a whole number of knot steps (step 0.5) and a fractional
    remainder. The fractional part is applied to the polynomial coefficients
    via ``shift_matrix``; the whole-knot part is applied by rolling the
    interval indices and the x-differences. The shifted curve is plotted
    together with the unshifted evaluation and the raw intensities of beam 0.

    Reads the notebook globals ``calculator``, ``pp_coeffs``, ``intervals``,
    ``x_diffs``, ``orig``, ``intens``, ``n_frames`` and ``shift_matrix``.
    """
    plt.cla()
    step = 0.5
    delta = -4.5 + 9/n_frames*delta_i

    knot_shift, frac_shift = divmod(delta, step)
    knot_shift = int(knot_shift)

    # whole-knot part: shift which interval each grid point refers to
    rolled_intervals = np.roll(intervals, -knot_shift)
    rolled_x_diffs = np.roll(x_diffs, -knot_shift)

    # fractional part: translate every polynomial piece
    cubic, quadratic, linear, constant = shift_matrix(-frac_shift)@pp_coeffs

    grid = calculator.interpolator.target_grid
    shifted_curve = (cubic[rolled_intervals]*rolled_x_diffs**3
                     + quadratic[rolled_intervals]*rolled_x_diffs**2
                     + linear[rolled_intervals]*rolled_x_diffs
                     + constant[rolled_intervals])
    plt.plot(grid, shifted_curve, marker='o')
    plt.plot(grid, orig[:, 0])
    plt.scatter(calculator.interpolator.origin_grid, intens[:, 0], marker='x')

    plt.xlim([60, 180])
    plt.ylim([0, 0.007])

ani = animation.FuncAnimation(fig, update, frames=n_frames, interval=20)
ani.save('animate.gif')

In [None]:
shift_matrix(0.001)

In [None]:
g = calculator.interpolator.origin_grid

p = a + b + c +d

In [None]:
plt.plot(p)

In [None]:
shifted_spline_coeffs

In [None]:
corr

In [None]:
calculator.comp_energies

In [None]:
from viperleed_jax.interpolation import evaluate_spline

In [None]:
def evaluate_spline(spline_coeffs, interpolator, deriv_order=0):
    """Evaluate a spline from B-spline coefficients and De Boor coefficients.

    Parameters
    ----------
    spline_coeffs : array
        B-spline coefficients indexed as (coefficient, beam).
    interpolator : object
        Must provide ``intervals``, ``intpol_deg`` and ``de_boor_coeffs``.
    deriv_order : int, optional
        Index into the first axis of ``interpolator.de_boor_coeffs``.

    Returns
    -------
    array
        Result of the contraction ``'ijb,ji->ib'`` of the selected
        coefficients with the De Boor coefficients.
    """
    degree = interpolator.intpol_deg
    # every interval uses degree+1 consecutive coefficients, starting here
    first_coeff = interpolator.intervals - degree
    coeff_window = first_coeff.reshape(-1, 1) + jnp.arange(degree + 1)
    selected_coeffs = spline_coeffs[coeff_window]

    # weight each coefficient with its De Boor value and sum over the
    # basis functions
    de_boor = interpolator.de_boor_coeffs[deriv_order]
    return jnp.einsum('ijb,ji->ib', selected_coeffs, de_boor)

In [None]:
evaluate_spline

## Set experimental intensities as reference

In [None]:
# `corr` is reassigned to scratch values in the spline-shift experiments
# further up in this notebook, so the beam correspondence is rebuilt here
# instead of relying on whatever `corr` currently holds.
# For every requested beam: position of the first experimental beam with the
# same (h, k) label (np.argmax returns 0 if there is no match at all).
exp_beam_positions = [np.argmax([b == t.hk for t in expbeams])
                      for b in beam_indices]
aligned_exp_intensities = exp_intensities[:, exp_beam_positions]
# set reference point
calculator.set_experiment_intensity(aligned_exp_intensities,
                                    exp_energies)

In [None]:
y = aligned_exp_intensities
x = exp_energies

In [None]:
import interpax

In [None]:
%matplotlib inline
plt.plot(calculator.target_grid, calculator.exp_spline(calculator.target_grid+1)[:,1])
plt.plot(calculator.target_grid, calculator.exp_spline(calculator.target_grid+50)[:,1])
plt.plot(exp_energies, aligned_exp_intensities[:,1])
plt.show()

# Intensity

In [None]:
# Index of the beam used for plotting (zero-based, so 1 selects the second beam)
plot_beam = 1

In [None]:
# some displacements to play with
spaced_displacements = [
    np.array([[i*0.01-0.05, 0.0, 0.0],] +[[0.00, 0.0, 0.0],]*29)
    for i in range (11)
]

# Interpolation

In [None]:
%matplotlib inline

In [None]:
plt.figure()
for d in spaced_displacements:
    plt.plot(calculator.interpolated(centered_vib_amps, d, deriv_deg=0)[:, 5])
plt.title("Interpolated Intensity")

In [None]:
from viperleed_jax.rfactor import pendry_y
plt.figure()
for d in spaced_displacements:
    intensity = calculator.interpolated(centered_vib_amps, d, deriv_deg=0)[:, plot_beam]
    deriv = calculator.interpolated(centered_vib_amps, d, deriv_deg=1)[:, plot_beam]
    plt.plot(pendry_y(intensity, deriv, 4.5))
plt.title("Interpolated Y-function")
plt.show()

In [None]:
from viperleed_jax.rfactor import y_ms
plt.figure()
for d in spaced_displacements:
    intensity = calculator.interpolated(centered_vib_amps, d, deriv_deg=0)[:, plot_beam]
    deriv = calculator.interpolated(centered_vib_amps, d, deriv_deg=1)[:, plot_beam]
    deriv2 = calculator.interpolated(centered_vib_amps, d, deriv_deg=2)[:, plot_beam]
    plt.plot(y_ms(intensity, deriv, deriv2, 4.5, 0.5))
plt.title("Interpolated Y-function")
plt.show()

# Rfactor

In [None]:
test_flat_param = np.array([0.5, 0.5, 0.5, 0.5, 0.5, 0.500]+ [0.5]*4) # displacements

In [None]:
calculator.parameter_transformer.unflatten_parameters(test_flat_param)

In [None]:
v = jnp.array([0.089, 0.089, 0.06 , 0.06 , 0.06 , 0.06 , 0.06 , 0.06 , 0.06 ,
        0.06 , 0.06 , 0.06 , 0.141, 0.141, 0.115, 0.115, 0.115, 0.115,
        0.115, 0.115, 0.115, 0.115, 0.115, 0.115, 0.115, 0.115, 0.115,
        0.115, 0.115, 0.115])
calculator.R(v, jnp.zeros((30,3)))

In [None]:
# Pendry R-factor
calculator.set_rfactor('pendry')

### $R_P$

In [None]:
# compile time
%time calculator.R_from_reduced(test_flat_param)

In [None]:
# execution time
%timeit calculator.R_from_reduced(test_flat_param)

In [None]:
# Recorded %timeit output (pasted text, kept as a comment so the cell parses):
# 3.58 s ± 396 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [None]:
config.update("jax_debug_nans", True)
config.update("jax_disable_jit", True)

In [None]:
# `.real` is an attribute of the array, not a method (the gradient cell uses it
# the same way); calling it would raise a TypeError
calculator._delta_amplitude_from_reduced(test_flat_param).sum().real

In [None]:
jax.grad(lambda p:calculator._delta_amplitude_from_reduced(p).sum().real)(test_flat_param)

### $\nabla R_P$

In [None]:
# compile time
%time calculator.R_grad_from_reduced(test_flat_param)

In [None]:
from viperleed_jax.lib_math import cart_to_polar

cart_to_polar(np.array([0.0, 0.0, 0.0]))

In [None]:
# execution time
# %time calculator.R_grad_from_reduced(test_flat_param)

### ($R_P$, $\nabla R_P$)

In [None]:
# compile time
%time calculator.R_val_and_grad_from_reduced(test_flat_param,)

In [None]:
# execution time
%timeit calculator.R_val_and_grad_from_reduced(test_flat_param,)

# Compare R-factors

In [None]:
# move topmost atom

params = [np.array([0.5] + [0.1*i] + [0.5]*8) for i in range(11)]

In [None]:
def normalize(R):
    """Rescale the values of R linearly so the minimum maps to 0 and the maximum to 1.

    Parameters
    ----------
    R : sequence of numbers
        Values to rescale (e.g. R-factors along a parameter scan).

    Returns
    -------
    list
        ``(r - min) / (max - min)`` for every ``r`` in ``R``; an empty list
        for empty input. As before, no special handling is done when all
        values are equal (the division by zero is left to the number type).
    """
    values = list(R)
    if not values:
        return []
    # min/max are loop invariants; compute them once instead of per element
    lowest = min(values)
    span = max(values) - lowest
    return [(r - lowest) / span for r in values]

%matplotlib inline
plt.figure()
plt.plot([p[1] for p in params], normalize(R_P), marker='x', label='Pendry')
plt.plot([p[1] for p in params], normalize(R_2), marker='x', label='R2')
plt.plot([p[1] for p in params], normalize(R_MS), marker='x', label='MS')
plt.plot([p[1] for p in params], normalize(R_1), marker='x', label='R1')
#plt.plot([p for p in params], R_1, label='R1')
plt.legend()

In [None]:
R_MS

In [None]:
first_atom_geo_index = 5
first_site_vib_index = 1

### Non-compressed parameters

In [None]:
# compile time
%time calculator.R_pendry_val_and_grad(centered_vib_amps, centered_displacements, 0)

In [None]:
calculator.R_pendry_val_and_grad(centered_vib_amps, centered_displacements, 0)

In [None]:
R_arr = []
R_grad_arr = []
z_arr = []
for d in spaced_displacements:
    R, gradient = calculator.R_pendry_val_and_grad(centered_vib_amps, d)
    R_arr.append(R)
    R_grad_arr.append(gradient)
    z_arr.append(d[0][0])

In [None]:
plt.figure()
plt.plot(z_arr, R_arr)

In [None]:
plt.figure()
plt.plot(z_arr, [g[0,0] for g in R_grad_arr])
plt.plot(z_arr, [g[0,1] for g in R_grad_arr])

In [None]:
plt.figure()
plt.plot(z_arr, [g[0,0] for g in R_grad_arr])
plt.plot(z_arr[:-1], jnp.diff(np.array(R_arr))/ (z_arr[1]-z_arr[0]))

# Optimization

In [None]:
normalization_bounds = scipy.optimize.Bounds(lb=[0.0,]*10, ub=[1.0,]*10)

In [None]:
def callback_print(intermediate_result):
    """Print the current objective value and stop once it is low enough.

    Used as ``callback`` for ``scipy.optimize.minimize``.

    Parameters
    ----------
    intermediate_result : object
        Must expose the current objective value as ``.fun``.

    Raises
    ------
    StopIteration
        When ``intermediate_result.fun`` is at or below 0.165.
    """
    print(f"R={intermediate_result.fun}")
    target_reached = intermediate_result.fun <= 0.165
    if target_reached:
        raise StopIteration

In [None]:
%%time
centered_flat_param = np.concatenate([np.array([0.0]), # v0r
                                [0.089, 0.06, 0.141, 0.115], # vib_amps
                                np.zeros(5)] # displacements
                                 )
start_flat_param = np.concatenate([np.array([1.0]), # v0r
                                [0.089, 0.06, 0.141, 0.115], # vib_amps
                                [0.05, -0.03, -0.02, 0.06, 0.02]] # displacements
                                 )

scipy.optimize.minimize(
    fun=calculator.R_pendry_val_and_grad_from_flat,
    jac=True,
    x0=start_flat_param,
    method="BFGS", # recommended method for expensive functions with access to gradient
    options={'disp':True, 'return_all':True, },
    callback=callback_print
)

