In [None]:
import jax
from jax import config
# Global JAX flags; they are set here, before any JAX arrays are created below.
config.update("jax_debug_nans", False)   # NaN checking disabled
config.update("jax_enable_x64", True)    # use 64-bit floats / complex128
config.update("jax_disable_jit", False)  # keep jit enabled (True helps debugging)
config.update("jax_log_compiles", False) # True logs every compilation
import jax.numpy as jnp

from pathlib import Path
import viperleed

from matplotlib import pyplot as plt
import numpy as np

%matplotlib inline

# Show which devices JAX will run on.
jax.devices()

In [None]:
from jax.scipy.special import sph_harm

from viperleed_jax.dense_quantum_numbers import DENSE_M, DENSE_L
from viperleed_jax.dense_quantum_numbers import MAXIMUM_LMAX
from viperleed_jax.lib_math import _divide_zero_safe, EPS

In [None]:
@jax.named_scope("HARMONY")
def HARMONY(C, LMAX):
    """Spherical harmonics for the direction C, via jax's sph_harm.

    Python version of the TensErLEED Fortran subroutine HARMONY. The values
    come from jax.scipy.special.sph_harm, chosen so that they reproduce the
    Fortran results.

    Parameters
    ----------
    C : array of length 3
        Direction vector in LEED (z, x, y) ordering.
    LMAX : int
        Harmonics are produced for the (l, m) pairs in DENSE_L[2*LMAX] /
        DENSE_M[2*LMAX], with sph_harm's n_max set to 2*LMAX. Must be static
        under jax.jit, because it is used for Python-level indexing.

    Returns
    -------
    Array with one value per (l, m) pair.
    """
    polar_angle, azimuth = cart_to_polar(C)[1:]
    degree_max = 2*LMAX
    degrees = DENSE_L[degree_max]
    orders = DENSE_M[degree_max]

    on_pole = abs(polar_angle) <= EPS
    # Dummy polar angle on the pole axis, so that sph_harm is evaluated at a
    # regular point there; its output is discarded for the pole below.
    polar_for_sph = jnp.where(on_pole, 0.1, polar_angle)

    # On the pole only m == 0 survives, with value sqrt((2l+1)/(4 pi)).
    at_pole = (orders == 0)*jnp.sqrt((2*degrees + 1)/(4*jnp.pi))
    # sph_harm argument order: (m, l, azimuthal angle, polar angle).
    off_pole = sph_harm(orders, degrees,
                        jnp.asarray([azimuth]), jnp.asarray([polar_for_sph]),
                        n_max=degree_max)

    return jnp.where(on_pole, at_pole, off_pole)


def cart_to_polar(c):
    """Convert a vector from LEED cartesian to spherical coordinates.

    Parameters
    ----------
    c : array-like of length 3
        Vector in LEED ordering ``(z, x, y)``.

    Returns
    -------
    r : scalar
        Euclidean norm of ``c``.
    theta : scalar
        Polar angle measured from the +z axis, in ``[0, pi]``.
    phi : scalar
        Azimuth ``atan2(y, x)`` in ``[-pi, pi]``. It is set to 0 on the z axis
        (``x = y = 0``), where the azimuth is undefined.

    Values and gradients are defined for all inputs EXCEPT
    ``c = (0.0, 0.0, 0.0)``. The gradient of ``phi`` grows like
    ``1/hypot(x, y)`` as the z axis is approached (the true derivative of the
    azimuth), and it is exactly 0 on the axis itself.
    """
    z, x, y = c  # LEED coordinates

    x_y_norm = jnp.hypot(x, y)
    r = jnp.linalg.norm(c)
    # Half-angle form tan(theta/2) = rho / (r + z). The denominator vanishes
    # on the -z half-axis (and at the origin). There the fallback
    # (1/EPS)*(1 - sign(z)) is used (assuming _divide_zero_safe returns its
    # third argument for a zero denominator), which pushes theta towards pi.
    theta = 2*jnp.arctan(
        _divide_zero_safe(x_y_norm, (jnp.hypot(x_y_norm, z)+z), (1/EPS) * (1 - jnp.sign(z)))
    )

    # Azimuth. The half-angle form 2*arctan(y / (rho + x + EPS)) returned 0
    # instead of pi on the negative x half-axis (y = 0, x < 0), and it was
    # wrong in a thin band around that axis. arctan2 is exact everywhere. On
    # the z axis (x = y = 0), phi is forced to 0 and arctan2 gets dummy
    # arguments ("double where"), so neither the value nor the gradient
    # becomes NaN there.
    on_z_axis = x_y_norm == 0
    safe_x = jnp.where(on_z_axis, 1.0, x)
    safe_y = jnp.where(on_z_axis, 0.0, y)
    phi = jnp.where(on_z_axis, 0.0, jnp.arctan2(safe_y, safe_x))

    return r, theta, phi

# Jitted HARMONY. LMAX (argument 1) is static because HARMONY uses it for
# Python-level indexing (DENSE_L[2*LMAX]) and as sph_harm's n_max.
harm = jax.jit(HARMONY, static_argnums=(1,))

In [None]:
# One test direction in LEED (z, x, y) ordering, plus a (100, 3) batch of
# copies of it for benchmarking the vmapped HARMONY.
c_scalar = jnp.array([0.15, 0.09, 0.1])
c_arr = jnp.tile(c_scalar, (100, 1))

In [None]:
# Evaluate the jitted (sph_harm-based) HARMONY for the single test direction.
harm(c_scalar, MAXIMUM_LMAX)

In [None]:
harm = jax.jit(HARMONY, static_argnums=(1,))
%time harm(c_scalar, MAXIMUM_LMAX)
%timeit harm(c_scalar, MAXIMUM_LMAX)

In [None]:
harm_vmap = jax.jit(jax.vmap(HARMONY, in_axes=(0, None)), static_argnums=(1,))
%time harm_vmap(c_arr, MAXIMUM_LMAX)
%timeit harm_vmap(c_arr, MAXIMUM_LMAX)

# New implementation: HARMONY via `lpmn_values`

In [None]:
@jax.named_scope("HARMONY")
def HARMONY(C, LMAX):
    """Spherical harmonics for the direction C, built from Legendre tables.

    Python version of the TensErLEED Fortran subroutine HARMONY. Instead of
    calling jax.scipy.special.sph_harm, it assembles the harmonics from the
    normalized associated Legendre values of
    jax.scipy.special.lpmn_values and the azimuthal phase exp(i m phi). The
    cells below compare this construction against sph_harm.

    Parameters
    ----------
    C : array of length 3
        Direction vector in LEED (z, x, y) ordering.
    LMAX : int
        Harmonics are produced for the (l, m) pairs in DENSE_L[2*LMAX] /
        DENSE_M[2*LMAX]. Must be static under jax.jit, because it is used for
        Python-level indexing and as the lpmn_values bounds.

    Returns
    -------
    Array with one value per (l, m) pair.
    """
    degree_max = 2*LMAX
    _, polar_angle, azimuth = cart_to_polar(C)
    l = DENSE_L[degree_max]
    m = DENSE_M[degree_max]

    on_pole = abs(polar_angle) <= EPS
    # Dummy polar angle on the pole axis; the pole result is substituted below.
    polar_safe = jnp.where(on_pole, 0.1, polar_angle)

    # On the pole only m == 0 survives, with value sqrt((2l+1)/(4 pi)).
    at_pole = (m == 0)*jnp.sqrt((2*l + 1)/(4*jnp.pi))

    # Normalized associated Legendre table, indexed [|m|, l, point].
    legendre_table = jax.scipy.special.lpmn_values(
        degree_max, degree_max, jnp.cos(jnp.array([polar_safe])), True)
    phase = jnp.exp(1j*m*jnp.array([azimuth]))
    off_pole = legendre_table[abs(m), l, 0] * phase
    # Negative orders get a factor (-1)^|m|, i.e. a sign flip for odd m.
    needs_flip = jnp.logical_and(m < 0, m % 2 != 0)
    off_pole = jnp.where(needs_flip, -off_pole, off_pole)

    return jnp.where(on_pole, at_pole, off_pole)



def cart_to_polar(c):
    """Convert a vector from LEED cartesian to spherical coordinates.

    Parameters
    ----------
    c : array-like of length 3
        Vector in LEED ordering ``(z, x, y)``.

    Returns
    -------
    r : scalar
        Euclidean norm of ``c``.
    theta : scalar
        Polar angle measured from the +z axis, in ``[0, pi]``.
    phi : scalar
        Azimuth ``atan2(y, x)`` in ``[-pi, pi]``. It is set to 0 on the z axis
        (``x = y = 0``), where the azimuth is undefined.

    Values and gradients are defined for all inputs EXCEPT
    ``c = (0.0, 0.0, 0.0)``. The gradient of ``phi`` grows like
    ``1/hypot(x, y)`` as the z axis is approached (the true derivative of the
    azimuth), and it is exactly 0 on the axis itself.
    """
    z, x, y = c  # LEED coordinates

    x_y_norm = jnp.hypot(x, y)
    r = jnp.linalg.norm(c)
    # Half-angle form tan(theta/2) = rho / (r + z). The denominator vanishes
    # on the -z half-axis (and at the origin). There the fallback
    # (1/EPS)*(1 - sign(z)) is used (assuming _divide_zero_safe returns its
    # third argument for a zero denominator), which pushes theta towards pi.
    theta = 2*jnp.arctan(
        _divide_zero_safe(x_y_norm, (jnp.hypot(x_y_norm, z)+z), (1/EPS) * (1 - jnp.sign(z)))
    )

    # Azimuth. The half-angle form 2*arctan(y / (rho + x + EPS)) returned 0
    # instead of pi on the negative x half-axis (y = 0, x < 0), and it was
    # wrong in a thin band around that axis. arctan2 is exact everywhere. On
    # the z axis (x = y = 0), phi is forced to 0 and arctan2 gets dummy
    # arguments ("double where"), so neither the value nor the gradient
    # becomes NaN there.
    on_z_axis = x_y_norm == 0
    safe_x = jnp.where(on_z_axis, 1.0, x)
    safe_y = jnp.where(on_z_axis, 0.0, y)
    phi = jnp.where(on_z_axis, 0.0, jnp.arctan2(safe_y, safe_x))

    return r, theta, phi

# Jit the HARMONY defined in this cell (it shadows the earlier definition).
# LMAX (argument 1) is static because it is used for Python-level indexing
# (DENSE_L[2*LMAX]) and as the lpmn_values degree/order bound.
harm = jax.jit(HARMONY, static_argnums=(1,))

In [None]:
# Evaluate the lpmn_values-based HARMONY for the same test direction.
harm(c_scalar, MAXIMUM_LMAX)

In [None]:
# Compare the lpmn_values-based HARMONY against jax's sph_harm for c_scalar.
# The reference `res` is built here; it was not defined anywhere earlier in
# the notebook, so this cell raised a NameError on a fresh kernel.
# c_scalar is off the pole axis, so no pole special-casing is needed.
_, theta_ref, phi_ref = cart_to_polar(c_scalar)
res = sph_harm(DENSE_M[2*MAXIMUM_LMAX], DENSE_L[2*MAXIMUM_LMAX],
               jnp.asarray([phi_ref]), jnp.asarray([theta_ref]),
               n_max=2*MAXIMUM_LMAX)

# (2*LMAX + 1)**2 is the number of (l, m) pairs up to degree 2*LMAX.
n_side = 2*MAXIMUM_LMAX + 1
abs_diff = abs(harm(c_scalar, MAXIMUM_LMAX) - res).reshape(n_side, n_side)

fig, ax = plt.subplots()
image = ax.imshow(abs_diff)
fig.colorbar(image, ax=ax, label="|lpmn_values HARMONY - sph_harm|")
ax.set(title=f"HARMONY vs sph_harm (max |diff| = {float(abs_diff.max()):.1e})",
       xlabel="column of reshaped (l, m) index",
       ylabel="row of reshaped (l, m) index");

In [None]:
harm = jax.jit(HARMONY, static_argnums=(1,))
%time harm(c_scalar, MAXIMUM_LMAX)
%timeit harm(c_scalar, MAXIMUM_LMAX)

In [None]:
harm_vmap = jax.jit(jax.vmap(HARMONY, in_axes=(0, None)), static_argnums=(1,))
%time harm_vmap(c_arr, MAXIMUM_LMAX)
%timeit harm_vmap(c_arr, MAXIMUM_LMAX)

In [None]:
# Scratch: inspect the dense (l, m) index arrays that HARMONY uses.
DENSE_M[2*MAXIMUM_LMAX]

In [None]:
DENSE_M[2*MAXIMUM_LMAX].shape

In [None]:
# NOTE(review): this passes 0-d angle arrays, while HARMONY passes 1-element
# arrays; confirm that sph_harm accepts 0-d input here.
sph_harm(DENSE_M[2*MAXIMUM_LMAX], DENSE_L[2*MAXIMUM_LMAX],jnp.asarray(0.1), jnp.asarray(0.1),n_max=2*MAXIMUM_LMAX)

In [None]:
# Single harmonic m = 1, l = 1 (sph_harm argument order: m, l, azimuth, polar).
sph_harm(1, 1, jnp.asarray([0.1]), jnp.asarray([0.1]), n_max=2*MAXIMUM_LMAX)

In [None]:
DENSE_L[5]

In [None]:
DENSE_M[5]

In [None]:
# Scratch: raw associated Legendre output (order and degree up to 3) at
# x = 0.1. lpmn returns the values (a) and their derivatives (b).
a, b = jax.scipy.special.lpmn(3, 3, jnp.array([0.1]))

In [None]:
a.shape

In [None]:
j = jax.jit(jax.scipy.special.lpmn_values, static_argnums=(0, 1, 3))
%time j(2*MAXIMUM_LMAX, 2*MAXIMUM_LMAX, jnp.array([0.1]*100,), True)
%timeit j(2*MAXIMUM_LMAX, 2*MAXIMUM_LMAX, jnp.array([0.1]*100,), True)

In [None]:
# Scratch cross-check of the lpmn_values construction at a small degree bound.
# Reference: sph_harm at azimuth 0.1, polar angle 0.2.
L=4
jax.scipy.special.sph_harm(DENSE_M[L], DENSE_L[L], jnp.array([0.1]), jnp.array([0.2]))

In [None]:
# Same harmonics built by hand: normalized Legendre values at cos(0.2),
# indexed [|m|, l, 0], times exp(i m * 0.1). No sign fix for negative m yet.
# NOTE(review): this reuses the name `a` from the lpmn probe above.
a = jax.scipy.special.lpmn_values(L, L, jnp.cos(jnp.array([0.2])), True)[abs(DENSE_M[L]),DENSE_L[L],0] * jnp.exp(1j*DENSE_M[L]*jnp.array([0.1]))
a

In [None]:
# Flip the sign for negative odd m, i.e. apply (-1)^|m| for m < 0, and
# subtract sph_harm. The result should be ~0 if the construction matches.
jnp.where(jnp.logical_and(DENSE_M[L]<0, DENSE_M[L]%2!=0), -a, a) - jax.scipy.special.sph_harm(DENSE_M[L], DENSE_L[L], jnp.array([0.1]), jnp.array([0.2]))

In [None]:
a

In [None]:
# Scratch: (l, m) arrays for index 2 (the index is the degree bound, as in
# HARMONY's DENSE_L[2*LMAX]).
DENSE_L[2], DENSE_M[2]

In [None]:
# NOTE(review): this evaluates at cos(-1 rad) ~ 0.54, not at x = -1;
# confirm which was intended.
jax.scipy.special.lpmn_values(3, 3, jnp.cos(jnp.array([-1.0])), True)

In [None]:
DENSE_M[2]

In [None]:
import scipy