In [None]:
import jax
import jax.numpy as jnp
import numpy as np

from jax.scipy.special import sph_harm

from viperleed_jax.dense_quantum_numbers import DENSE_M, DENSE_L
from viperleed_jax.dense_quantum_numbers import MAXIMUM_LMAX

from functools import partial

LMAX = 18  # presumably the maximum angular momentum for this exploration; NOTE(review): some HARMONY calls below pass 18 literally instead of this constant

import matplotlib.pyplot as plt

In [None]:
EPS = 1e-8


def _divide_zero_safe(
    numerator: jnp.ndarray,
    denominator: jnp.ndarray,
    limit_value: float = 0.0,
) -> jnp.ndarray:
    """Function that forces the result of dividing by 0 to be equal to a limit
    value in a jit- and autodiff-compatible way

    Args:
        numerator: Values in the numerator
        denominator: Values in the denominator, may contain zeros
        limit_value: Value to return where denominator == 0.0
    Returns:
        numerator / denominator with result == 0.0 where denominator == 0.0
    """
    denominator_masked = jnp.where(denominator == 0.0, 1.0, denominator)
    return jnp.where(
        denominator == 0.0,
        limit_value,
        numerator / denominator_masked,
    )

In [None]:
@jax.jit
def cart_to_polar(c):
    """Converts cartesian coordinates to polar coordinates.

    Args:
        c: Length-3 vector in LEED order (z, x, y).
    Returns:
        (r, theta, phi):
            r: the radius.
            theta: the polar angle in [0, pi], measured from +z.
            phi: the azimuth in (-pi, pi], measured from +x towards +y.

    Both angles use half-angle arctangent identities, with rho = hypot(x, y):
        theta = 2*arctan(rho / (r + z))
        phi   = 2*arctan(y / (rho + x))
    Unlike arctan2, these keep gradients finite on the pole axis.

    Each identity's denominator cancels to ~0 in one region:
        - near the -z axis (theta -> pi);
        - near the negative x half-axis (phi -> pi).
    There the identity loses all precision, and on the axis it returns 0
    instead of pi. In those half-spaces the mirrored identities are used
    instead, so the denominator never cancels:
        theta = pi - 2*arctan(rho / (r - z))          for z < 0
        phi   = +-pi - 2*arctan(y / (rho - x))        for x < 0

    Note, this function uses safe division and a small EPS in the
    denominators to avoid division by zero errors. It gives defined results
    and gradients for all inputs, EXCEPT for c = (0.0, 0.0, 0.0).
    """
    z, x, y = c  # LEED coordinates

    x_y_norm = jnp.hypot(x, y)
    r = jnp.linalg.norm(c)
    r_zxy = jnp.hypot(x_y_norm, z)

    # Polar angle. jnp.where evaluates both branches. Each denominator is
    # >= EPS (hypot(a, z) >= |z|), so the discarded branch stays finite and
    # cannot inject NaN into the gradient.
    theta_north = 2*jnp.arctan(
        _divide_zero_safe(x_y_norm, (r_zxy + z) + EPS, 0)
    )
    theta_south = jnp.pi - 2*jnp.arctan(
        _divide_zero_safe(x_y_norm, (r_zxy - z) + EPS, 0)
    )
    theta = jnp.where(z >= 0, theta_north, theta_south)

    # Azimuth. For x < 0, mirror through the y axis. The sign of y selects
    # +pi or -pi, so phi stays in (-pi, pi].
    phi_east = 2*jnp.arctan(
        _divide_zero_safe(y, (x_y_norm + x) + EPS, 0)
    )
    phi_west = jnp.where(y >= 0, jnp.pi, -jnp.pi) - 2*jnp.arctan(
        _divide_zero_safe(y, (x_y_norm - x) + EPS, 0)
    )
    phi = jnp.where(x >= 0, phi_east, phi_west)

    return r, theta, phi

In [None]:
@partial(jax.jit, static_argnames=('LMAX',))
def HARMONY(C, LMAX):
    """Generates the spherical harmonics for the vector C.

    This is a python implementation of the fortran subroutine HARMONY from
    TensErLEED. It uses the jax.scipy.special.sph_harm function to produce
    equivalent results.

    Args:
        C: Cartesian vector in LEED order (z, x, y), as taken by
            cart_to_polar.
        LMAX: Static integer. Harmonics are evaluated for the (l, m) pairs in
            DENSE_L[2*LMAX] / DENSE_M[2*LMAX] with n_max=2*LMAX.
            NOTE(review): assumes 2*LMAX is a valid index into DENSE_L/DENSE_M
            (presumably <= MAXIMUM_LMAX) - TODO confirm.
    Returns:
        Complex array of Y_l^m(theta, phi), one entry per (l, m) pair.

    On the pole axis (theta == 0 or theta == pi), the closed-form values,
    which depend on l and m only, are returned. In that case sph_harm is
    evaluated at a dummy polar angle. This keeps any ill-defined sph_harm
    gradient at the poles out of the result produced by jnp.where.
    """
    _, theta, phi = cart_to_polar(C)
    l = DENSE_L[2*LMAX]
    m = DENSE_M[2*LMAX]

    is_on_north_pole = theta == 0
    is_on_south_pole = theta == jnp.pi
    is_on_pole_axis = jnp.logical_or(is_on_north_pole, is_on_south_pole)
    # dummy polar angle away from the poles; the sph_harm value computed from
    # it is discarded on the pole axis
    _theta = jnp.where(is_on_pole_axis, 0.1, theta)

    # At the poles only m == 0 survives: P_l(cos(0)) = 1 at the north pole and
    # P_l(cos(pi)) = (-1)**l at the south pole.
    north_pole_values = (m == 0)*jnp.sqrt((2*l+1)/(4*jnp.pi))
    south_pole_values = jnp.where(l % 2 == 0, 1.0, -1.0)*north_pole_values
    pole_values = jnp.where(is_on_north_pole,
                            north_pole_values, south_pole_values)

    # jax's sph_harm takes (m, n, azimuthal angle, polar angle, n_max)
    non_pole_values = sph_harm(m, l,
                               jnp.asarray([phi]), jnp.asarray([_theta]),
                               n_max=2*LMAX)

    return jnp.where(is_on_pole_axis, pole_values, non_pole_values)


In [None]:
# Evenly spaced scan positions in [-0.1, 0.1]. The odd count puts the
# midpoint at (or within rounding of) 0.
coord = np.linspace(start=-0.1, stop=0.1, num=501)

In [None]:
# Scan `vec` along one cartesian component. For every point, record:
#   - the polar coordinates and their Jacobian;
#   - sum(Re(HARMONY)) together with its gradient.
# NOTE: `coord` may contain the origin. cart_to_polar documents its gradient
# as undefined there, so expect NaN at that point.
fig, axs = plt.subplots(3, 3)      # columns: r, theta, phi
dfig, daxs = plt.subplots(3, 3)    # columns: Jacobian rows of r, theta, phi
hfig, haxs = plt.subplots(3, 1)    # sum(Re(HARMONY))
hdfig, hdaxs = plt.subplots(3, 3)  # columns: d/dz, d/dx, d/dy of that sum

# Scanned component in LEED order: 0 -> z, 1 -> x, 2 -> y.
# NOTE: shadows the builtin `dir`; the name is kept because later cells read it.
dir = 0

# Build and jit the transformed functions once. Constructing jax.jacrev /
# jax.grad wrappers inside the loop re-traces them on every iteration.
# value_and_grad also avoids evaluating HARMONY twice per point.
polar_jacobian = jax.jit(jax.jacrev(cart_to_polar))
harmony_sum_and_grad = jax.jit(
    jax.value_and_grad(lambda v: HARMONY(v, LMAX).real.sum())
)

vec = np.array([0.0, 0.0, 0.0])
r, theta, phi, dr, dtheta, dphi, h, hg = [], [], [], [], [], [], [], []
for c in coord:
    vec[dir] = c
    _r, _theta, _phi = cart_to_polar(vec)
    r.append(_r)
    theta.append(_theta)
    phi.append(_phi)
    _dr, _dtheta, _dphi = polar_jacobian(vec)
    dr.append(_dr)
    dtheta.append(_dtheta)
    dphi.append(_dphi)
    _h, _hg = harmony_sum_and_grad(vec)
    h.append(_h)
    hg.append(_hg)

axs[dir, 0].plot(coord, r)
axs[dir, 1].plot(coord, theta)
axs[dir, 2].plot(coord, phi)

daxs[dir, 0].plot(coord, np.array(dr))
daxs[dir, 1].plot(coord, np.array(dtheta))
daxs[dir, 2].plot(coord, np.array(dphi))

hg_array = np.array(hg)
hdaxs[dir, 0].plot(coord, hg_array[:, 0])
hdaxs[dir, 1].plot(coord, hg_array[:, 1])
hdaxs[dir, 2].plot(coord, hg_array[:, 2])

haxs[dir].plot(coord, np.array(h))

scan_axis = 'zxy'[dir]
fig.suptitle(f'r, theta, phi vs. {scan_axis}')
dfig.suptitle(f'Jacobian of (r, theta, phi) w.r.t. (z, x, y) vs. {scan_axis}')
hfig.suptitle(f'sum(Re(HARMONY)) vs. {scan_axis}')
hdfig.suptitle(f'gradient of sum(Re(HARMONY)) w.r.t. (z, x, y) vs. {scan_axis}');


In [None]:
# Per-point gradients of sum(Re(HARMONY)) w.r.t. vec, collected by the scan cell.
hg

In [None]:

# Jacobian of Re(Y_1^1) w.r.t. the first angle passed to sph_harm.
# NOTE: jax's sph_harm takes (m, n, azimuthal angle, polar angle, n_max), so
# the first lambda argument is the azimuth, not the polar angle.
g = jax.jacrev(
    lambda azimuth, polar: sph_harm(
        jnp.array([1]),
        jnp.array([1]),
        jnp.array([azimuth]),
        jnp.array([polar]),
        jnp.array(2),
    ).real
)(-0.00000, 0.001)

In [None]:
%timeit ha = HARMONY(vec, 5).real

In [None]:
# `ha` assigned inside %timeit is not kept, so recompute the value to inspect it.
HARMONY(vec, 5).real

In [None]:
# Re-scan along component `dir` with a fresh vector. Keep the real part of
# every harmonic (LMAX=5) at each point. `vec` and `c` keep their final
# values afterwards.
ha = []
vec = np.zeros(3)
for c in coord:
    vec[dir] = c
    ha += [HARMONY(vec, 5).real]

In [None]:
# Stack the per-point HARMONY results into one array, one row per scan point.
ha = np.array(ha)

In [None]:
# (number of scan points, number of harmonics)
ha.shape

In [None]:
# Real part of every harmonic vs. scan index (one line per column of ha).
plt.plot(ha)

In [None]:
# Largest real part over all scan points and harmonics.
np.max(ha)

In [None]:
# Current contents of the scan vector (mutated in place by the scan loops).
vec

In [None]:
# Forward-mode Jacobian of HARMONY(vec, 18) w.r.t. vec only. The second
# positional argument (LMAX) is passed through as a static Python int.
a = jax.jacfwd(HARMONY, argnums=0)(vec, 18)

In [None]:
# The vector at which the Jacobian above was evaluated.
vec

In [None]:
# Jacobian of HARMONY(vec, 18) w.r.t. vec (complex, one row per harmonic).
a

In [None]:
# Mask of Jacobian entries whose real part is NaN (rows: harmonics, cols: z, x, y).
plt.imshow(jnp.isnan(a.real), aspect='auto')

In [None]:
# Mask of the m == 0 harmonics used by HARMONY's pole values. `m` only exists
# inside HARMONY, so rebuild it from DENSE_M for LMAX.
DENSE_M[2*LMAX] == 0

In [None]:
# HARMONY's north-pole values for LMAX. `m` and `l` only exist inside
# HARMONY, so rebuild them from DENSE_M / DENSE_L.
(DENSE_M[2*LMAX] == 0)*jnp.sqrt((2*DENSE_L[2*LMAX]+1)/(4*jnp.pi))

In [None]:
# Indices of the scan points where the azimuth is exactly zero. `phi` is a
# Python list, so convert it before comparing elementwise. With a single
# argument, jnp.where returns the indices (like nonzero).
jnp.where(jnp.asarray(phi) == 0.0)

In [None]:
# Shape of the HARMONY Jacobian.
a.shape

In [None]:
# Full lists of azimuth and polar angles from the scan cell (long output).
phi, theta

In [None]:
# (number of scan points, 3) - one gradient w.r.t. (z, x, y) per point.
np.array(hg).shape

In [None]:
# Number of NaN entries in the HARMONY-sum gradients.
np.isnan(np.array(hg)).sum()

In [None]:
# Gradient components (d/dz, d/dx, d/dy) vs. scan index.
plt.plot(np.array(hg))

In [None]:
# sum(Re(HARMONY)) vs. scan index.
plt.plot(h)

In [None]:
# Number of NaN values of sum(Re(HARMONY)) along the scan.
np.isnan(np.array(h)).sum()

In [None]:
# Shape of the NaN mask (same as the gradient array).
np.isnan(np.array(hg)).shape

In [None]:
# NaN mask of the first gradient component (d/dz in LEED order) vs. scan index.
plt.plot(np.isnan(np.array(hg))[:, 0])

In [None]:

# NaN mask of the second gradient component (d/dx in LEED order) vs. scan index.
plt.plot(np.isnan(np.array(hg))[:, 1])

In [None]:
# NOTE(review): same plot as an earlier cell - gradient components vs. scan index.
plt.plot(np.array(hg))

In [None]:
# NaN count of r, and the (point, component) indices of NaN entries in dr.
np.sum(np.isnan(np.array(r))), np.argwhere(np.isnan(np.array(dr)))

In [None]:
# NaN counts of theta and of its Jacobian rows.
np.sum(np.isnan(np.array(theta))), np.sum(np.isnan(np.array(dtheta)))

In [None]:
# NaN counts of phi and of its Jacobian rows.
np.sum(np.isnan(np.array(phi))), np.sum(np.isnan(np.array(dphi)))

In [None]:
# Jacobian of cart_to_polar at the current scan vector. cart_to_polar unpacks
# a 3-vector (z, x, y), whereas the loop variable `c` is a scalar.
jax.jacrev(cart_to_polar)(vec)