# Reference values for Unit Tests

This file creates and stores the reference values used in the unit tests.
These values should ONLY be changed after manually confirming that the results
are correct and reliable.

In [None]:
from pathlib import Path

import numpy as np
import jax

from viperleed_jax.propagator import calc_propagator

In [None]:
# Directory that receives the reference archives. The path is relative to the
# current working directory (presumably the repository root); the trailing
# expression displays whether it exists and should show True.
test_data_path = Path("tests", "test_data", "reference_values")
test_data_path.is_dir()

## T-matrix values

In [None]:
# TODO

## Propagator values

In [None]:
# Testing displacement vectors
TEST_DISP_VECTORS = (
    np.array([0.1, 0.0, 0.0]),
    np.array([0.0, 0.1, 0.0]),
    np.array([0.0, 0.0, 0.1]),
    np.array([0.1, 0.1, 0.0]),
    np.array([0.1, 0.0, 0.1]),
    np.array([0.1, 0.1, 0.1]),
    np.array([-0.1, 0.0, 0.0]),
    np.array([-0.1, -0.1, -0.]),
    np.array([0.0, -0.1, 0.1]),
    np.array([1.0, 0.0, 0.0]),
    np.array([1.0, 2.0, 3.0]),
    np.array([1e-3, 1e-3, 1e-3]),
    np.array([1e-4, 1e-4, 1e-4]),
    np.array([1e-5, 1e-5, 1e-5]),
    np.array([1e-6, 1e-6, 1e-6]),
)

In [None]:
# calculate energy jacobians - limit to l_max=8 for memory reasons
# jax.jacrev differentiates w.r.t. positional argument 2 (presumably the
# energy, matching the name suffix `e_1e0j`). holomorphic=True requires a
# complex-valued input, hence the energy is passed as 1.0+0j; the final 1.0
# is v_imag per the variable name.
jac_energy_propagator = jax.jacrev(calc_propagator, argnums=2, holomorphic=True)
# One jacobian per test displacement, stacked along a new leading axis.
reference_energy_jac_values_l_max_8_e_1e0j_v_imag_1e0 = np.array(
    [jac_energy_propagator(8, vec, 1.0+0j, 1.0) for vec in TEST_DISP_VECTORS]
)


In [None]:
# calculate displacement jacobians - limit to l_max=8 for memory reasons
def abs_calc_propagator(l_max, vec, e, v_imag):
    """Return the element-wise magnitude of the propagator.

    jax.jacrev (with the default holomorphic=False) requires real-valued
    outputs, so the complex propagator is reduced to its element-wise absolute
    value before differentiating with respect to the displacement.

    Parameters
    ----------
    l_max : int
        Maximum angular momentum passed to ``calc_propagator``.
    vec : array-like
        Displacement vector.
    e, v_imag
        Energy and imaginary potential passed to ``calc_propagator``.
    """
    return abs(calc_propagator(l_max, vec, e, v_imag))


# Jacobian w.r.t. positional argument 1, the displacement vector.
jac_disp_propagator = jax.jacrev(abs_calc_propagator, argnums=1)
reference_displacement_jac_values_l_max_8_e_1e0_v_imag_1e0 = np.array(
    [jac_disp_propagator(8, vec, 1.0, 1.0) for vec in TEST_DISP_VECTORS]
)

In [None]:
# Target archive for all propagator reference arrays, inside test_data_path.
propagator_reference_file_name = 'propagator_reference_values.npz'
propagator_reference_file_path = test_data_path / propagator_reference_file_name

In [None]:
# Plain propagator values at l_max=18. No earlier cell defines this array, so
# compute it here; otherwise np.savez below raises a NameError on a fresh
# kernel. Arguments follow the naming convention of the jacobian arrays:
# l_max=18, energy=1.0, v_imag=1.0.
reference_values_l_max_18_e_1e0_v_imag_1e0 = np.array(
    [calc_propagator(18, vec, 1.0, 1.0) for vec in TEST_DISP_VECTORS]
)

# Store all reference arrays in one .npz archive, keyed by variable name.
np.savez(
    propagator_reference_file_path,
    reference_values_l_max_18_e_1e0_v_imag_1e0=reference_values_l_max_18_e_1e0_v_imag_1e0,
    reference_energy_jac_values_l_max_8_e_1e0j_v_imag_1e0=reference_energy_jac_values_l_max_8_e_1e0j_v_imag_1e0,
    reference_displacement_jac_values_l_max_8_e_1e0_v_imag_1e0=reference_displacement_jac_values_l_max_8_e_1e0_v_imag_1e0,
)

# Temporary scratch work (symmetry checks; nothing below is saved as reference values)

In [None]:
import jax
import jax.numpy as jnp

from viperleed_jax.dense_quantum_numbers import DENSE_QUANTUM_NUMBERS
from viperleed_jax.gaunt_coefficients import CSUM_COEFFS
from viperleed_jax.lib_math import bessel, HARMONY, safe_norm, EPS
from viperleed_jax.atomic_units import kappa
def calc_propagator(LMAX: int, c, energy, v_imag):
    """Scratch re-implementation of the displacement propagator.

    This definition shadows the ``calc_propagator`` imported from
    ``viperleed_jax.propagator`` for every cell executed after this one.

    Parameters
    ----------
    LMAX : int
        Maximum angular momentum. Used for slicing and array shapes, so it
        must be a concrete Python int (not a traced value).
    c : array-like
        Displacement vector (presumably 3 components in the LEED (z, x, y)
        ordering -- TODO confirm against HARMONY).
    energy, v_imag
        Passed to ``kappa`` to obtain the wave number.

    Returns
    -------
    jax.numpy.ndarray, shape ((LMAX+1)**2, (LMAX+1)**2)
        Propagator matrix; the identity matrix when ``|c| < 100*EPS``.
    """
    c_norm = safe_norm(c)

    # Bessel values for orders up to 2*LMAX (presumably 0..2*LMAX, since
    # BJ[lpp] is read for lpp = 0..2*LMAX below).
    BJ = bessel(kappa(energy, v_imag) * c_norm, 2*LMAX)
    # NOTE(review): YLM is indexed up to l = 2*LMAX below; assumes
    # HARMONY(c, LMAX) returns that many entries -- TODO confirm.
    YLM = HARMONY(c, LMAX)  # TODO: move outside since it's not energy dependent

    # m and m' on the dense (lm, l'm') grid (per the variable names)
    dense_m_2d = DENSE_QUANTUM_NUMBERS[LMAX][:, :, 2]
    dense_mp_2d =  DENSE_QUANTUM_NUMBERS[LMAX][:, :, 3]

    # AI: I don't fully understand this, technically it should be MPP = -M - MP
    dense_mpp = dense_mp_2d - dense_m_2d

    # pre-computed coeffs, capped to LMAX
    # (axis 0 is l'' = 0..2*LMAX; the other two span the dense (lm) indices)
    capped_coeffs = CSUM_COEFFS[:2*LMAX+1, :(LMAX+1)**2, :(LMAX+1)**2]

    def propagator_lpp_element(lpp, running_sum):
        # Adds the l'' = lpp term to the running sum (fori_loop body).
        bessel_values = BJ[lpp]
        # Flattened (l, m) lookup l*l + l + m with m = -dense_mpp. Where
        # |dense_mpp| > lpp this lands outside the l = lpp block (negative
        # indices wrap around); presumably capped_coeffs is zero there
        # (Gaunt selection rule) -- the commented-out mask below would make
        # that explicit. TODO confirm.
        ylm_values = YLM[lpp*lpp+lpp-dense_mpp]
        # Equation (34) from Rous, Pendry 1989
        return running_sum + bessel_values * ylm_values * capped_coeffs[lpp,:,:] #* (abs(dense_mpp) <= lpp) * (1j)**(-lpp)

    # we could skip some computations because some lpp are guaranteed to give
    # zero contributions, but this would need a way around the non-static array
    # sizes

    # This is the propagator from the origin to C
    # (sum over lpp = 0 .. 2*LMAX inclusive)
    propagator = jax.lax.fori_loop(0, LMAX*2+1, propagator_lpp_element,
                             jnp.zeros(shape=((LMAX+1)**2, (LMAX+1)**2),
                                       dtype=jnp.complex128))
    propagator *= 4*jnp.pi
    # For (near-)zero displacements fall back to the identity matrix.
    return jnp.where(c_norm >= EPS*100, propagator, jnp.identity((LMAX+1)**2))


In [None]:
def rot_matrix(theta):
    """Build the 2x2 counter-clockwise rotation matrix for ``theta``.

    Parameters
    ----------
    theta : float
        Rotation angle in radians.

    Returns
    -------
    numpy.ndarray, shape (2, 2)
        ``[[cos, -sin], [sin, cos]]`` evaluated at ``theta``.
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    return np.array([[cos_t, -sin_t], [sin_t, cos_t]])

In [None]:
# 3x3 mirror matrix: negates component 1 of a 3-vector and leaves components
# 0 and 2 unchanged.
mir_matrix = np.diag([1.0, -1.0, 1.0])


In [None]:
# Arbitrary test displacement for the scratch symmetry checks below.
vec = np.array([1.0, 2.0, 3.0])

In [None]:
# Display the propagator for `vec` (l_max=8, energy=1.0, v_imag=1.0).
calc_propagator(8, vec, 1.0, 1.0)

In [None]:
# Module-level copies of the quantum-number arrays used inside
# calc_propagator; dense_mpp is needed by the mirror check in the next cell.
LMAX = 8
dense_m_2d = DENSE_QUANTUM_NUMBERS[LMAX][:, :, 2]
dense_mp_2d =  DENSE_QUANTUM_NUMBERS[LMAX][:, :, 3]

# AI: I don't fully understand this, technically it should be MPP = -M - MP
dense_mpp = dense_mp_2d - dense_m_2d


In [None]:
# Mirror check: propagator for `vec` minus the transposed, (-1)**m''-weighted
# propagator for the mirrored vector (mir_matrix @ vec). All-zero entries
# would mean the relation holds.
calc_propagator(8, vec, 1.0, 1.0) - (calc_propagator(8, mir_matrix @ vec, 1.0, 1.0)* (-1)**(dense_mpp)).T

In [None]:
# Rotation check: rotating the displacement by `phi` in the plane of
# components 1 and 2 should be equivalent to multiplying the propagator
# element-wise by symmetry_tensor(8, rot).
# NOTE: symmetry_tensor is defined in a later cell; run that cell first.
phi = 1.
rot = rot_matrix(phi)
trafo = np.identity(3)
trafo[1:3, 1:3] = rot

vec = np.array([-0.1, -0.1, 0.0])

s = symmetry_tensor(8, rot)
# Evaluate each propagator once and reuse it for both diagnostics.
propagator_original = calc_propagator(8, vec, 1.0, 1.0)
propagator_rotated = calc_propagator(8, trafo @ vec, 1.0, 1.0)
abs_diff = propagator_original * s - propagator_rotated
# Relative deviation; EPS goes inside the denominator to guard entries where
# the rotated propagator vanishes.
rel_diff = abs_diff / (abs(propagator_rotated) + EPS)

np.max(abs(abs_diff)), np.max(abs(rel_diff))

In [None]:
# Embed a rotation by 3*pi/4 into the lower-right 2x2 block of a 3x3 matrix.
phi = 3/4*np.pi
rot = rot_matrix(phi)
trafo = np.identity(3)
trafo[1:3, 1:3] = rot

# Recover the angle from the block: Re(log(cos + i*sin) / i) = arg(cos + i*sin),
# which should display phi (~2.356).
(np.log(trafo[2, 2] + 1j*trafo[2, 1])/1j).real

In [None]:
# Displacement next to its rotated counterpart.
vec, trafo @ vec

In [None]:
# Display the 3x3 embedding (component 0 is left unchanged).
trafo

In [None]:
# 2x2 in-plane mirror matrices:
#   my  negates the first component,
#   mx  negates the second component,
#   mxy swaps the two components.
my = np.diag([-1.0, 1.0])
mx = np.diag([1.0, -1.0])
mxy = np.array([[0.0, 1.0], [1.0, 0.0]])

In [None]:
# All three are mirrors, so each determinant should be -1.
np.linalg.det(my), np.linalg.det(mx), np.linalg.det(mxy)

In [None]:
# NOTE: get_rot_angle is defined in a later cell; run that cell first.
# mxy @ mx = [[0, -1], [1, 0]], i.e. a rotation by pi/2.
get_rot_angle(mxy @ mx)

In [None]:
# A mirror is its own inverse.
np.linalg.inv(mx)

In [None]:
mx

In [None]:
# my @ mx = -identity, i.e. a rotation by pi (requires get_rot_angle, defined later).
get_rot_angle(my@mx)

In [None]:
import numpy as np

In [None]:
def get_rot_angle(plane_symmetry_op):
    """Return the rotation angle encoded in a 2x2 plane symmetry operation.

    The angle is read from the second row of the matrix, which is
    ``[sin t, cos t]`` for a rotation ``[[cos t, -sin t], [sin t, cos t]]``.
    ``arctan2`` gives the same value as the former
    ``Re(log(cos + i*sin) / i)`` formulation for non-degenerate input, up to
    the sign on the branch cut (+-pi). It avoids ``log(0)``, which produced
    NaN plus a RuntimeWarning when both entries vanish.

    Parameters
    ----------
    plane_symmetry_op : ndarray, shape (2, 2)
        In-plane symmetry operation matrix.

    Returns
    -------
    float
        Angle in radians, in the interval [-pi, pi].
    """
    sin_part = plane_symmetry_op[1, 0]
    cos_part = plane_symmetry_op[1, 1]
    return np.arctan2(sin_part, cos_part)

In [None]:
def get_plane_symmetry_operation_rotation_angle(plane_symmetry_operation):
    """Return the rotation angle for a plane symmetry operation.

    NB: The rotation angle is calculated for the plane symmetry operation by
    embedding it in 3D space. In-plane symmetry operations (even mirror
    operations) can be converted into a rotation operation in 3D space, as the
    z-movement of linked atoms is equal.

    The 2x2 operation is placed in the lower-right block of a 3x3 identity
    matrix R (axis 0 is left untouched), and the angle is recovered from

        cos(theta) = (tr(R) - 1) / 2
        sin(theta) = -tr(Kz @ R) / 2

    where Kz is the generator of rotations about axis 0, i.e.
    R(theta) = expm(theta * Kz). These identities hold for proper rotations
    (det R = +1). For improper operations such as a bare mirror matrix, the
    returned value is not the angle of that operation.
    NOTE(review): confirm how mirrors are meant to be handled.

    Parameters
    ----------
    plane_symmetry_operation : ndarray (2,2)
        In plane symmetry operation matrix.

    Returns
    -------
    float
        Rotation angle in radians, in the interval [-pi, pi]. For a rotation
        matrix [[cos t, -sin t], [sin t, cos t]] this is t reduced to that
        interval.
    """
    full_rot_mat = np.identity(3)
    full_rot_mat[1:, 1:] = plane_symmetry_operation

    Kz = np.array([[0, 0, 0], [0, 0, -1], [0, 1, 0]])

    cos_theta = (np.trace(full_rot_mat) - 1) / 2
    # Note the minus sign: tr(Kz @ R(theta)) = -2 sin(theta). Without it the
    # function would return -theta.
    sin_theta = -np.trace(Kz @ full_rot_mat) / 2
    return np.arctan2(sin_theta, cos_theta)

In [None]:
def rot_matrix(theta):
    """Return the 2x2 rotation matrix for an angle ``theta`` given in radians.

    The result is ``[[cos, -sin], [sin, cos]]``, i.e. a counter-clockwise
    rotation in the standard orientation.
    """
    cosine = np.cos(theta)
    sine = np.sin(theta)
    first_row = [cosine, -sine]
    second_row = [sine, cosine]
    return np.array([first_row, second_row])

In [None]:
# Angle extracted for a rotation by 7*pi/4. Compare with the next cell: a
# correct result agrees with 7*pi/4 modulo 2*pi.
get_plane_symmetry_operation_rotation_angle(rot_matrix(7*np.pi/4))

In [None]:
# Reference value 7*pi/4 reduced to [0, 2*pi); *, / and % share precedence
# and are evaluated left to right.
7*np.pi/4 % (2*np.pi)

In [None]:
# get_rot_angle result for 7.9*pi/4, mapped to [0, 2*pi).
get_rot_angle(rot_matrix(7.9*np.pi/4)) % (2*np.pi)

In [None]:
def cart_to_polar(c):
    """Convert cartesian coordinates to spherical (r, theta, phi) coordinates.

    Note, this function uses safe division to avoid division by zero errors,
    and gives defined results and gradients for all inputs, EXCEPT for
    c = (0.0, 0.0, 0.0).

    Parameters
    ----------
    c : array-like, shape (3,)
        Cartesian vector in LEED ordering (z, x, y).

    Returns
    -------
    r, theta, phi
        Radius, polar angle measured from the z axis, and azimuth. phi is 0
        on the z axis (where it is undefined) and pi on the negative x axis.
    """
    z, x, y = c  # LEED coordinates

    x_y_norm = jnp.hypot(x, y)
    r = jnp.linalg.norm(c)
    # Half-angle form of theta = atan2(x_y_norm, z). The fallback
    # (1/EPS)*(1 - sign(z)) is used when the denominator vanishes
    # (x = y = 0, z <= 0) and drives theta towards pi.
    theta = 2*jnp.arctan(
        _divide_zero_safe(x_y_norm, (jnp.hypot(x_y_norm, z)+z), (1/EPS) * (1 - jnp.sign(z)))
    )

    # Half-angle form of phi = atan2(y, x). The +EPS keeps the denominator
    # positive, which forces phi to 0 on the z axis (where phi is undefined).
    phi = 2*jnp.arctan(
        _divide_zero_safe(y, (x_y_norm+x)+EPS, 0.0)
    )
    # On the negative x axis (y == 0, x < 0) the half-angle denominator
    # vanishes too and the expression above yields 0 instead of pi.
    on_negative_x_axis = (y == 0.0) & (x < 0.0)
    phi = jnp.where(on_negative_x_axis, jnp.pi, phi)

    return r, theta, phi


In [None]:
from viperleed_jax.lib_math import _divide_zero_safe
# Point on the +z axis (LEED ordering z, x, y): mathematically theta = 0, phi undefined (0 by convention).
cart_to_polar(jnp.array([1., 0., 0.]))

In [None]:
# In-plane point with x = y: mathematically theta = pi/2, phi = pi/4.
cart_to_polar(jnp.array([0., 1., 1.]))

In [None]:
# Points on the +x and -x axes: mathematically phi = 0 and phi = pi respectively.
cart_to_polar(jnp.array([0., 1., 0.])), cart_to_polar(jnp.array([0., -1., 0.]))

In [None]:
# Same theta for both; mathematically phi = -pi/4 and phi = -3*pi/4.
cart_to_polar(jnp.array([1., 1., -1.])), cart_to_polar(jnp.array([1., -1., -1.]))

In [None]:
# 3*pi/4 - pi/4 (= pi/2), presumably copied from the values printed above.
2.35619447 - 0.78539816

In [None]:
# NOTE(review): sum(A - A.T) is identically zero for any square matrix A, so
# this cannot detect asymmetry of s.real; abs(A - A.T).max() would.
(s.real - s.real.T).sum()

In [None]:
# Norm of the rotated vector; should equal the norm of vec.
safe_norm(trafo @ vec)

In [None]:
# Earlier variant of the rotation check, kept for reference:
#diff = (calc_propagator(8,  vec, 1.0, 1.0) - (calc_propagator(8, trafo @ vec, 1.0, 1.0)* symmetry_tensor(8, rot)))
#abs(diff)

In [None]:
vec, trafo @ vec

In [None]:
# Top-left 4x4 entries (presumably l, l' <= 1 if the (l, m) index is
# flattened as l*l + l + m) of both sides of the symmetry relation.
(calc_propagator(8, vec, 1.0, 1.0))[:4, :4], (calc_propagator(8, trafo @ vec, 1.0, 1.0) * symmetry_tensor(8, rot))[:4, :4]

In [None]:
from viperleed_jax.dense_quantum_numbers import DENSE_QUANTUM_NUMBERS
from jax import numpy as jnp
def symmetry_tensor(l_max, plane_symmetry_operation):
    """Return the element-wise phase tensor for an in-plane symmetry operation.

    Multiplying a propagator element-wise by this tensor is intended to map it
    onto the propagator of the displacement transformed by
    ``plane_symmetry_operation``. The exact sign/phase convention is still
    being checked numerically in the scratch cells of this notebook.

    Each entry is ``exp(i * angle * m'') * (-1)**(m'' * (l + l'))`` with
    ``m'' = m' - m`` taken from ``DENSE_QUANTUM_NUMBERS[l_max]``.

    Parameters
    ----------
    l_max : int
        Maximum angular momentum quantum number. Compiled as static argument.
    plane_symmetry_operation : 2x2 array
        The in-plane symmetry operation matrix.

    Returns
    -------
    jax.numpy.ndarray, shape=((l_max+1)**2, (l_max+1)**2)
        Tensor that can be applied element-wise to the propagator to apply the
        symmetry operation.
    """
    quantum_numbers = DENSE_QUANTUM_NUMBERS[l_max]
    dense_l_2d = quantum_numbers[:, :, 0]
    dense_lp_2d = quantum_numbers[:, :, 1]
    dense_m_2d = quantum_numbers[:, :, 2]
    dense_mp_2d = quantum_numbers[:, :, 3]

    # AI: I don't fully understand this, technically it should be MPP = -M - MP
    dense_mpp = dense_mp_2d - dense_m_2d

    # Rotation angle from the second row [sin, cos] of the 2x2 operation.
    # arctan2 matches the former Re(log(cos + i*sin) / i) but avoids log(0).
    plane_rotation_angle = np.arctan2(plane_symmetry_operation[1, 0],
                                      plane_symmetry_operation[1, 1])

    phase = jnp.exp(plane_rotation_angle * 1j * dense_mpp)

    # (-1)**n evaluated via parity: the exponent is negative whenever m' < m,
    # and NumPy raises for integers raised to negative integer powers.
    parity_exponent = dense_mpp * (dense_l_2d + dense_lp_2d)
    sign = jnp.where(parity_exponent % 2 == 0, 1, -1)
    return phase * sign


In [None]:
def cart_to_polar(c):
    """Convert cartesian coordinates to spherical (r, theta, phi) coordinates.

    Redefines the cart_to_polar from an earlier cell; only the azimuth
    computation differs.

    Note, this function uses safe division to avoid division by zero errors,
    and gives defined results and gradients for all inputs, EXCEPT for
    c = (0.0, 0.0, 0.0)."""
    z, x, y = c  # LEED coordinates

    x_y_norm = jnp.hypot(x, y)
    r = jnp.linalg.norm(c)
    # Half-angle form of theta = atan2(x_y_norm, z); the fallback value is
    # used when the denominator vanishes (x = y = 0, z <= 0).
    theta = 2*jnp.arctan(
        _divide_zero_safe(x_y_norm, (jnp.hypot(x_y_norm, z)+z), (1/EPS) * (1 - jnp.sign(z)))
    )

    # Previous azimuth formula (forced phi to 0 on the theta=0 axis, but also
    # gave 0 instead of pi on the negative x axis):
    # phi = 2*jnp.arctan(
    #     _divide_zero_safe(y, (x_y_norm+x)+EPS, 0.0)
    # )
    # Azimuth as sign(y) * arccos(x / rho); +EPS keeps the division finite.
    phi =  jnp.sign(y) * jnp.arccos(_divide_zero_safe(x, (x_y_norm)+EPS, 0.0))
    # sign(y) is 0 for y == 0, so set phi explicitly there: pi for x < 0 and
    # 0 for x >= 0 (heaviside(-x, 0) is 0 at x == 0, covering the z axis).
    phi = jnp.where(y != 0.0, phi, jnp.heaviside(-x, 0)*jnp.pi)

    return r, theta, phi


def spherical_to_cart(spherical_coordinates):
    """Convert (r, theta, phi) back to a cartesian vector in LEED (z, x, y) order.

    Parameters
    ----------
    spherical_coordinates : sequence of 3 scalars
        Radius, polar angle from the z axis, and azimuth (radians).

    Returns
    -------
    jax.numpy.ndarray, shape (3,)
        ``[z, x, y]``.
    """
    radius, polar_angle, azimuth = spherical_coordinates
    in_plane = radius * jnp.sin(polar_angle)
    return jnp.array([
        radius * jnp.cos(polar_angle),
        in_plane * jnp.cos(azimuth),
        in_plane * jnp.sin(azimuth),
    ])

In [None]:
import numpy as np
# Quick check of float -> bool casting (nonzero entries become True).
a = np.array([1., 0., 0.])
a.astype(bool)

In [None]:
# Round trip cartesian -> spherical -> cartesian should reproduce c. c lies on
# the z axis (LEED ordering), so the jacobians (forward and reverse mode) show
# whether the derivatives stay well defined there; ideally both are identity.
c = np.array([1., 0., 0.])
spherical_to_cart(cart_to_polar(c))
composition = lambda c: spherical_to_cart(cart_to_polar(c))
jax.jacfwd(composition)(c), jax.jacrev(composition)(c)


In [None]:
@jax.custom_jvp
def cart_to_polar_2(c):
    """Convert cartesian coordinates to spherical (r, theta, phi) coordinates.

    Same computation as cart_to_polar, wrapped in jax.custom_jvp so that its
    derivatives come from the rule registered below via
    cart_to_polar_2.defjvp.

    Note, this function uses safe division to avoid division by zero errors,
    and gives defined results and gradients for all inputs, EXCEPT for
    c = (0.0, 0.0, 0.0).
    NOTE(review): the derivative rule divides by hypot(x, y); check its
    behaviour on the z axis against this claim."""
    z, x, y = c  # LEED coordinates

    x_y_norm = jnp.hypot(x, y)
    r = jnp.linalg.norm(c)
    # Half-angle form of theta = atan2(x_y_norm, z) with a fallback when the
    # denominator vanishes (x = y = 0, z <= 0).
    theta = 2*jnp.arctan(
        _divide_zero_safe(x_y_norm, (jnp.hypot(x_y_norm, z)+z), (1/EPS) * (1 - jnp.sign(z)))
    )

    #forces phi to 0 on theta=0 axis (where phi is undefined)
    # phi = 2*jnp.arctan(
    #     _divide_zero_safe(y, (x_y_norm+x)+EPS, 0.0)
    # )
    # Azimuth as sign(y) * arccos(x / rho); y == 0 is handled on the next line
    # (pi for x < 0, otherwise 0).
    phi =  jnp.sign(y) * jnp.arccos(_divide_zero_safe(x, (x_y_norm)+EPS, 0.0))
    phi = jnp.where(y != 0.0, phi, jnp.heaviside(-x, 0)*jnp.pi)

    return r, theta, phi


@cart_to_polar_2.defjvp
def cart_to_polar_jacobian(primals, tangents):
    """Custom JVP rule for ``cart_to_polar_2``.

    Returns the primal output ``(r, theta, phi)`` and its tangents, using the
    analytic derivatives (rho = hypot(x, y)):

        dr     = (z dz + x dx + y dy) / r
        dtheta = (z (x dx + y dy) / rho - rho dz) / r**2
        dphi   = (x dy - y dx) / rho**2

    Both outputs are 3-tuples so that their pytree structure matches the tuple
    returned by ``cart_to_polar_2``. A custom_jvp rule must reproduce the
    primal function's output structure; returning a stacked array does not.

    theta and phi are not differentiable on the z axis (rho == 0). There, rho
    is replaced by 1 in the denominators. Because x = y = 0 on that axis, the
    affected terms then evaluate to 0 instead of NaN. The origin (r == 0)
    remains undefined.
    """
    z, x, y = primals[0]
    dz, dx, dy = tangents[0]
    r, theta, phi = cart_to_polar(primals[0])
    x_y_norm = jnp.hypot(x, y)
    # Only used in denominators; the numerators carry x or y and vanish on
    # the z axis.
    safe_x_y_norm = jnp.where(x_y_norm > 0, x_y_norm, 1.0)

    dr = z/r*dz + x/r*dx + y/r*dy
    dtheta = (-x_y_norm/r**2*dz
              + x*z/(r**2 * safe_x_y_norm)*dx
              + y*z/(r**2 * safe_x_y_norm)*dy)
    dphi = -y/(safe_x_y_norm**2)*dx + x/(safe_x_y_norm**2)*dy
    return (r, theta, phi), (dr, dtheta, dphi)

In [None]:
# Reverse-mode jacobian of the plain implementation at a point on the z axis
# (x = y = 0), where theta and phi are not differentiable.
jax.jacrev(cart_to_polar)(jnp.array([1., 0, 0]))

In [None]:
# Same point, using the custom JVP rule of cart_to_polar_2, for comparison.
jax.jacrev(cart_to_polar_2)(jnp.array([1., 0, 0]))