In [None]:
# Notebook setup: JAX configuration, imports and the plotting backend.
import jax
from jax import config
# JAX flags are set immediately after `import jax`, before anything else uses JAX;
# keep them here so they take effect for every later cell.
config.update("jax_debug_nans", False)   # NaN debugging checks off
config.update("jax_enable_x64", True)    # enable 64-bit (float64) arrays
config.update("jax_disable_jit", False)  # keep JIT compilation enabled
config.update("jax_log_compiles", False) # do not log each compilation
import jax.numpy as jnp

from pathlib import Path
import viperleed

from matplotlib import pyplot as plt
import numpy as np

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Show the devices JAX will use (displayed as this cell's output).
jax.devices()

In [None]:
%%time
from fe2O3_r_cut_example import calculator

In [None]:
# 10-element parameter vector with every entry 0.1 (the author labelled these
# "displacements"); same length as the `params` vectors that are passed to
# `calculator.R_from_reduced` in the scan cells.
test_flat_param = np.full(10, 0.1)

# MS R-factor

In [None]:
# Index (along axis 1 of the interpolated intensity arrays) of the beam to plot.
# Indexing is 0-based, so 1 selects the *second* beam, not the first.
plot_beam = 1

In [None]:
# Reference ("centered") configuration: the calculator's reference vibrational
# amplitudes and an all-zero (30, 3) displacement array.
# NOTE(review): presumably one row per atom and one column per Cartesian
# direction — confirm against the calculator setup.
centered_vib_amps = calculator.ref_vibrational_amps
centered_displacements = np.zeros((30, 3))

In [None]:
# Eleven (30, 3) displacement arrays for experimentation: only entry [0, 0]
# is non-zero, stepping from -0.05 to +0.05 in increments of 0.01.
spaced_displacements = []
for step in range(11):
    displacement = np.zeros((30, 3))
    displacement[0, 0] = step * 0.01 - 0.05
    spaced_displacements.append(displacement)

In [None]:
%matplotlib inline
from src.rfactor import y_ms, pendry_y

In [None]:
# Intensity, its first two derivatives, and the derived Y-functions for the
# selected beam, evaluated at the reference (centered) configuration.
fig, axs = plt.subplots(5, 1, figsize=(8, 12))

# deriv_deg = 0, 1, 2 -> interpolated intensity and its 1st / 2nd derivative.
intensity, deriv, deriv2 = (
    calculator.interpolated(centered_vib_amps, centered_displacements, deriv_deg=deg)[:, plot_beam]
    for deg in range(3)
)

# Store the result under a new name: the original `y_ms = y_ms(...)` rebound
# the imported function to an array, so re-running this cell raised TypeError.
# NOTE(review): 4.5 is presumably the imaginary inner potential V0i, and the
# meaning of the extra 0.5 argument to y_ms is not visible here — confirm in
# src.rfactor.
y_ms_values = y_ms(intensity, deriv, deriv2, 4.5, 0.5)
y_p = pendry_y(intensity, deriv, 4.5)

panels = [
    (intensity, "Intensity"),
    (deriv, "1st derivative"),
    (deriv2, "2nd derivative"),
    (y_ms_values, "Y_ms"),
    (y_p, "Y_Pendry"),
]
for ax, (curve, title) in zip(axs, panels):
    ax.plot(calculator.target_grid, curve)
    ax.set_title(title)
axs[-1].set_xlabel("calculator.target_grid")
fig.tight_layout()
plt.show()

In [None]:
# move topmost atom: scan params[1] over 0.0, 0.1, ..., 1.0 while the other
# nine reduced parameters stay fixed at 0.5.
params = []
for step in range(11):
    reduced = np.full(10, 0.5)
    reduced[1] = 0.1 * step
    params.append(reduced)

In [None]:
# Pendry R-factor at every parameter vector in `params`.
# NOTE(review): set_rfactor appears to switch the calculator's active
# R-factor; that choice presumably persists into later cells.
calculator.set_rfactor('pendry')
R_P = []
for reduced in params:
    R_P.append(calculator.R_from_reduced(reduced))


In [None]:
# R2 R-factor at every parameter vector in `params`.
calculator.set_rfactor('R2')
R_2 = []
for reduced in params:
    R_2.append(calculator.R_from_reduced(reduced))

In [None]:
# MS R-factor at every parameter vector in `params`.
calculator.set_rfactor('MS')
R_MS = []
for reduced in params:
    R_MS.append(calculator.R_from_reduced(reduced))

In [None]:
# R1 R-factor at every parameter vector in `params`.
calculator.set_rfactor('R1')
R_1 = []
for reduced in params:
    R_1.append(calculator.R_from_reduced(reduced))

In [None]:
# RZJ R-factor at every parameter vector in `params` (selected via 'ZJ').
calculator.set_rfactor('ZJ')
R_ZJ = []
for reduced in params:
    R_ZJ.append(calculator.R_from_reduced(reduced))

In [None]:
def normalize(R):
    """Min-max rescale a sequence of values onto [0, 1].

    Parameters
    ----------
    R : iterable of numbers
        Values to rescale (e.g. R-factors from a parameter scan). Any iterable
        is accepted; it is consumed exactly once.

    Returns
    -------
    list
        ``(r - min(R)) / (max(R) - min(R))`` for each ``r`` in input order.
        An empty input yields ``[]``. If all values are equal (zero range),
        each entry is ``r - min(R)`` (i.e. zero) instead of dividing by zero.
    """
    values = list(R)  # materialize once so generators are not exhausted
    if not values:
        return []
    # Compute the extrema once instead of once per element (was O(n^2)).
    lo = min(values)
    span = max(values) - lo
    if span == 0:
        # Constant input: avoid ZeroDivisionError / NaN and return zeros.
        return [r - lo for r in values]
    return [(r - lo) / span for r in values]

# Compare how each R-factor varies as params[1] is scanned; each curve is
# min-max normalized so their shapes can be compared on a single axis.
scan_values = [p[1] for p in params]
r_factor_curves = {
    'Pendry': R_P,
    'R2': R_2,
    'MS': R_MS,
    'R1': R_1,
    'RZJ': R_ZJ,
}

fig, ax = plt.subplots()
for label, curve in r_factor_curves.items():
    ax.plot(scan_values, normalize(curve), marker='x', label=label)
ax.set_xlabel("params[1] (scanned reduced parameter)")
ax.set_ylabel("normalized R-factor")
ax.legend()
plt.show()

In [None]:
# Grid resolution for a (geometry x vibration) R-factor surface scan.
geo_steps = 10
vib_steps = 10

# NOTE(review): presumably indices into the reduced parameter vector for the
# geometry and vibration parameters; unused in the cells visible here — confirm.
vib_id = 1
geo_id = 5

# Evenly spaced scan values on [0, 1] along each axis.
geos = np.linspace(0.0, 1.0, num=geo_steps)
vibs = np.linspace(0.0, 1.0, num=vib_steps)

# Result grid, NaN until filled (rows: geometry, columns: vibration).
R_surf = np.full((geo_steps, vib_steps), np.nan)

def get_R_surf(calculator):
    """Placeholder for an R-factor surface scan over (geometry, vibration).

    NOTE(review): unfinished. It only allocates a local NaN-filled array of
    shape ``(geo_steps, vib_steps)``, never uses ``calculator``, and returns
    ``None``. TODO: fill the grid over ``geos`` x ``vibs`` and return it.
    """
    surface = np.empty((geo_steps, vib_steps))
    surface.fill(np.nan)

In [None]:
# RZJ curve on its own, min-max normalized as in the comparison plot.
fig, ax = plt.subplots()
ax.plot([p[1] for p in params], normalize(R_ZJ), marker='x', label='RZJ')
ax.set_xlabel("params[1] (scanned reduced parameter)")
ax.set_ylabel("normalized R-factor")
ax.legend()
plt.show()