# Testing JAX

__Author:__ A. J. Tropiano [atropiano@anl.gov]<br/>
__Date:__ May 3, 2023

In this notebook, we test how to use `jax.numpy`, vectorize functions using `jax.vmap`, and possibly speed up functions using `jax.jit`.

_Last update:_ May 16, 2023

In [1]:
# Python imports
from functools import partial
from jax import config, jit, vmap
import jax.numpy as jnp
from jax.scipy.special import sph_harm as sph_harm_jax
import numpy as np
from scipy.special import sph_harm, spherical_jn
from sympy.physics.quantum.cg import CG
import timeit

In [2]:
# Imports from scripts
# ...

In [3]:
# Enable double-precision with JAX arrays
config.update("jax_enable_x64", True)

## Function that depends on scalar and array parameters

In [4]:
def compute_cg_array(j_max):
    """Tabulate Clebsch-Gordan coefficients < j_1 m_1 j_2 m_2 | j_3 m_3 >.

    Angular momenta run over j = 0, 1/2, 1, ..., j_max. The projection axes
    use the ordering m = [0, 1/2, ..., j_max, -1/2, ..., -j_max], which is the
    layout that `cg_mapping` indexes into.

    Parameters
    ----------
    j_max : float
        Largest angular momentum to tabulate (integer or half-integer).

    Returns
    -------
    cg_array : jax.Array, shape (N_j, N_j, N_j, N_m, N_m, N_m)
        cg_array[i, j, k, l, m, n] = < j_1 m_1 j_2 m_2 | j_3 m_3 > with
        j_1 = j_array[i], j_2 = j_array[j], j_3 = j_array[k],
        m_1 = m_array[l], m_2 = m_array[m], m_3 = m_array[n]. Entries that
        violate the selection rules are zero.
    N_j : int
        Number of tabulated j values.
    """

    j_array = np.arange(0, j_max+1/2, 1/2)
    N_j = j_array.size
    # Projections ordered [0, 1/2, ..., j_max, -1/2, ..., -j_max]
    m_array = np.concatenate((j_array, -j_array[1:]))
    N_m = m_array.size

    # For each j in j_array: indices into m_array of the allowed projections
    # m = -j, -j+1, ..., j. Computed once instead of inside the m loops.
    allowed_m = [
        [i_m for i_m, m_value in enumerate(m_array)
         if np.any(m_value == np.arange(-j_value, j_value+1))]
        for j_value in j_array
    ]
    # Reverse lookup: projection value -> index in m_array (values are unique)
    m_lookup = {m_value: i_m for i_m, m_value in enumerate(m_array)}

    cg_array = np.zeros((N_j, N_j, N_j, N_m, N_m, N_m))

    for i_j1, j_1 in enumerate(j_array):
        for i_j2, j_2 in enumerate(j_array):
            # Triangle rule: j_3 = |j_1 - j_2|, ..., j_1 + j_2 in unit steps
            j_3_array = np.arange(np.abs(j_1-j_2), j_1+j_2+1)
            for i_j3, j_3 in enumerate(j_array):
                if not np.any(j_3 == j_3_array):
                    continue
                # Loop only over allowed (m_1, m_2); m_3 = m_1 + m_2 is then
                # fixed, so the N_m**3 scan over all projection triples
                # (almost all rejected by the selection rules) is avoided.
                for i_m1 in allowed_m[i_j1]:
                    m_1 = m_array[i_m1]
                    for i_m2 in allowed_m[i_j2]:
                        m_2 = m_array[i_m2]
                        i_m3 = m_lookup.get(m_1 + m_2)
                        if i_m3 is None or i_m3 not in allowed_m[i_j3]:
                            continue
                        m_3 = m_array[i_m3]
                        cg_array[i_j1, i_j2, i_j3, i_m1, i_m2, i_m3] = float(
                            CG(j_1, m_1, j_2, m_2, j_3, m_3).doit()
                        )

    return jnp.array(cg_array, dtype=jnp.float64), N_j

In [5]:
def cg_mapping(j, m, N_j):
    """Return the indices of the input angular momentum and projection for the
    array of Clebsch-Gordan coefficients.

    Uses the table layout of `compute_cg_array`: j = 0, 1/2, 1, ... maps to
    index 2j. Projections are ordered [0, 1/2, ..., j_max, -1/2, ..., -j_max],
    so m >= 0 maps to 2m and m < 0 maps to 2|m| + (N_j - 1).

    Parameters
    ----------
    j, m : float or array_like
        Angular momentum and projection (integer or half-integer values).
    N_j : int
        Number of tabulated j values (second output of `compute_cg_array`).

    Returns
    -------
    j_index, m_index : jax.Array of int
        Indices along the j axes and m axes of the table.
    """

    # Round 2j and 2m to the nearest integer rather than truncating, so values
    # carrying floating-point noise (e.g. 2.4999999999 or -1e-15) still map to
    # the right slot.
    j_index = jnp.rint(2 * jnp.asarray(j)).astype(int)
    two_m = jnp.rint(2 * jnp.asarray(m)).astype(int)
    # Negative projections are stored after the N_j non-negative ones
    m_index = jnp.abs(two_m) + jnp.where(two_m < 0, N_j - 1, 0)

    return j_index, m_index

In [6]:
def clebsch_gordan_coefficient(j1, m1, j2, m2, j3, m3, cg_array, N_j):
    """Clebsch-Gordan coefficient < j1 m1 j2 m2 | j3 m3 >.

    Looks the value up in the pre-tabulated `cg_array`. Each (j, m) pair is
    first converted to (j index, m index) with `cg_mapping`; the three j
    indices address the first three axes and the three m indices the last
    three.
    """

    index_pairs = [
        cg_mapping(j_val, m_val, N_j)
        for j_val, m_val in ((j1, m1), (j2, m2), (j3, m3))
    ]
    j_indices = tuple(j_idx for j_idx, _ in index_pairs)
    m_indices = tuple(m_idx for _, m_idx in index_pairs)

    return cg_array[j_indices + m_indices]

In [7]:
@partial(jit, static_argnums=(7,))
def clebsch_gordan_coefficient_vmap(j1, m1, j2, m2, j3, m3, cg_array, N_j):
    """Jitted, batched `clebsch_gordan_coefficient`.

    The six quantum-number arguments are mapped over their leading axis. The
    table `cg_array` and `N_j` are shared by every element of the batch.
    `N_j` is a static argument of the jitted function.
    """
    per_element_axes = (0,) * 6
    shared_axes = (None, None)
    batched_lookup = vmap(
        clebsch_gordan_coefficient,
        in_axes=per_element_axes + shared_axes,
        out_axes=0,
    )
    return batched_lookup(j1, m1, j2, m2, j3, m3, cg_array, N_j)

In [8]:
# Tabulate CG coefficients for j, m up to 1 (covers the s = 1/2, S <= 1 cases below)
cg_array, N_j = compute_cg_array(1)

In [9]:
# Spin arrays
# Ten test cases for < 1/2 m1 1/2 m2 | S M >, used only as timing input below.
# Several rows have m1 + m2 != M (e.g. the first: 1/2 + 1/2 vs M = 0); those
# entries of the table were never filled, so they evaluate to zero.
s_array = jnp.repeat(1/2, 10)
m1_array = jnp.array([1/2, -1/2, 1/2, -1/2, 1/2, -1/2, 1/2, -1/2, 1/2, -1/2])
m2_array = jnp.array([1/2, 1/2, 1/2, 1/2, 1/2, -1/2, -1/2, -1/2, -1/2, -1/2])
S_array = jnp.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
M_array = jnp.array([0, 0, 0, 0, 0, 1, 0, 1, 0, 1])

In [10]:
%timeit cgs = clebsch_gordan_coefficient_vmap(s_array, m1_array, s_array, m2_array, S_array, M_array, cg_array, N_j)

3.4 µs ± 62.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


## Function that depends on several functions that could be vectorized

In [11]:
def kronecker_delta(x, y):
    r"""Kronecker \delta function: \delta_{x,y}.

    Returns 1 where x == y and 0 otherwise, elementwise (with broadcasting),
    as an integer JAX array. The docstring is a raw string so that `\delta`
    is not parsed as an (invalid) escape sequence.
    """

    return jnp.array(x == y, dtype=int)

### Sub-functions are vectorized

In [12]:
@jit
def kronecker_delta_vmap(x, y):
    """Jitted `kronecker_delta` mapped over the leading axis of x and y."""
    batched_delta = vmap(kronecker_delta, in_axes=(0, 0))
    return batched_delta(x, y)

In [13]:
def spinor_sph_harm(theta, phi, l, j, m_j, m_t, sigma, tau, cg_array, N_j):
    """Spin-isospin spherical harmonic, evaluated elementwise.

    Computes
        < l m_l 1/2 sigma | j m_j > * delta_{tau, m_t} * Y_{l m_l}(theta, phi)
    with m_l = m_j - sigma. The result is 0 wherever |m_l| > l.

    The CG and Kronecker-delta factors use the vmapped helpers, so the inputs
    are mapped over their leading axis. Presumably all are 1-D arrays of
    equal length (sigma.size sets the length of the spin-1/2 array) — TODO
    confirm. theta is the polar angle and phi the azimuthal angle, following
    SciPy's sph_harm(m, n, azimuthal, polar) argument order used below.
    """

    m_s = sigma
    # Orbital projection from m_j = m_l + m_s
    m_l = m_j - m_s
    s = jnp.repeat(1/2, sigma.size)

    # CG coupling < l m_l 1/2 m_s | j m_j > via clebsch_gordan_coefficient_vmap;
    # block_until_ready waits for the device computation to finish here.
    cg = clebsch_gordan_coefficient_vmap(l, m_l, s, m_s, j, m_j, cg_array, N_j)
    cg = cg.block_until_ready()

    # Isospinor \chi_{m_t}(\tau) as a Kronecker delta
    chi_tau = kronecker_delta_vmap(tau, m_t).block_until_ready()

    # SciPy (not JAX) spherical harmonic; SciPy takes azimuthal before polar.
    # NOTE(review): scipy.special.sph_harm is deprecated in recent SciPy
    # releases in favour of sph_harm_y — confirm the installed version.
    ylm = sph_harm(m_l, l, phi, theta)

    # |m_l| > l is unphysical: zero those entries
    return jnp.where(jnp.abs(m_l) <= l, cg * chi_tau * ylm, 0+0j)

In [14]:
# Re-tabulate up to j_max = 3; this rebinds cg_array and N_j (the j_max = 1
# table above is replaced) and covers the l, j, m_j values used below.
cg_array, N_j = compute_cg_array(3)

In [15]:
# Every array below has length 100, so they can be combined elementwise
# Polar angle on [0, pi] and azimuthal angle on [0, 2*pi]
theta = jnp.linspace(0.0, jnp.pi, 100)
phi = jnp.linspace(0.0, 2*jnp.pi, 100)
# Orbital angular momentum: l = 0 for the first 50 entries, l = 1 afterwards
l = jnp.repeat(jnp.array([0, 1]), 50)
# Total angular momentum in blocks of 25
j = jnp.repeat(jnp.array([1/2, 1/2, 3/2, 1/2]), 25)
# Total projection in blocks of 10
m_j = jnp.repeat(
    jnp.array([1/2, -1/2, 1/2, -1/2, 3/2, 1/2, -1/2, -3/2, 1/2, -1/2]), 10
)
# Isospin projection of the isospinor
m_t = jnp.repeat(1/2, 100)
# Spin and isospin arguments: +1/2 for the first 50 entries, -1/2 afterwards
sigma = jnp.repeat(jnp.array([1/2, -1/2]), 50)
tau = jnp.repeat(jnp.array([1/2, -1/2]), 50)

In [16]:
%timeit Y_lmj = spinor_sph_harm(theta, phi, l, j, m_j, m_t, sigma, tau, cg_array, N_j)

248 µs ± 7.45 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### Main function is vectorized

_Work in progress: this fully vectorized version does not run yet, so the cells below are commented out._

In [17]:
# def spinor_sph_harm_2(theta, phi, l, m_l, sigma, j, m_j, m_t, tau, cg_array, N_j, l_max):
    
#     s = jnp.repeat(1/2, sigma.size)
    
#     # Calls clebsch_gordan_coefficient_vmap
#     cg = clebsch_gordan_coefficient(l, m_l, s, sigma, j, m_j, cg_array, N_j)
    
#     # Isospinor indexed by \tau \chi_{m_t}(\tau)
#     chi_tau = kronecker_delta(tau, m_t)
    
#     # Calls sph_harm (which is already vectorized)
#     Y_lm = sph_harm_jax(m_l, l, phi, theta, n_max=l_max)
#     Y_lmj = jnp.where(jnp.abs(m_l) <= l, cg * chi_tau * Y_lm, 0+0j)

#     return Y_lmj

In [18]:
# @partial(jit, static_argnums=(10,11,))
# def spinor_sph_harm_vmap(theta, phi, l, m_l, sigma, j, m_j, m_t, tau, cg_array, N_j, l_max):
    
#     return vmap(
#         spinor_sph_harm_2, in_axes=(0, 0, 0, 0, 0, 0, 0, 0, 0, None, None, None),
#         out_axes=(0)
#     )(theta, phi, l, m_l, sigma, j, m_j, m_t, tau, cg_array, N_j, l_max)

In [19]:
# m_l = m_j - sigma

In [20]:
# %timeit Y_lmj = spinor_sph_harm_vmap(theta, phi, l, m_l, sigma, j, m_j, m_t, tau, cg_array, N_j, 1).block_until_ready()

## Function with scalar input $x$ that outputs quantum number arrays

In [21]:
# Quantum-number table read from a path relative to the working directory;
# the cells below select one row of it for each x in x_array.
delta_U_quantum_numbers = np.loadtxt('O16_quantum_numbers.txt')
# Number of rows in the table
delU_Ntot = delta_U_quantum_numbers.shape[0]
# Sample points; x * (delU_Ntot - 1) is floored to a row index below
x_array = np.linspace(0.01, 1.0, 100)

In [22]:
# Looping
def get_quantum_numbers_looping(x_array):
    """Select a row of `delta_U_quantum_numbers` for each x with a Python loop.

    Each x is mapped to row floor(x * (delU_Ntot - 1)) of the global table.
    Presumably 0 <= x <= 1, so the index stays inside the table — TODO
    confirm for callers other than this notebook.

    Returns
    -------
    quantum_numbers : np.ndarray, shape (x_array.size, n_columns)
        One table row per x; n_columns is the table's column count.
    """

    # Size the output from the table instead of hard-coding 28 columns
    n_columns = delta_U_quantum_numbers.shape[1]
    quantum_numbers = np.zeros((x_array.size, n_columns))

    for i, x in enumerate(x_array):

        # Index for combination of s.p. and partial wave channel quantum numbers
        index = np.floor(x * (delU_Ntot-1)).astype(int)

        quantum_numbers[i] = delta_U_quantum_numbers[index]

    return quantum_numbers

%timeit quantum_numbers_1 = get_quantum_numbers_looping(x_array)

127 µs ± 2.08 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [23]:
# NumPy
def get_quantum_numbers(x):
    """Return the row(s) of `delta_U_quantum_numbers` selected by x.

    x (scalar or array) is mapped to row floor(x * (delU_Ntot - 1)).
    """
    row = np.floor(x * (delU_Ntot - 1)).astype(int)
    return delta_U_quantum_numbers[row]

# With this signature the whole 1-D x array is passed to get_quantum_numbers in
# a single call; in np.vectorize signatures '28' is a core-dimension name
# whose size is taken from the output.
get_quantum_numbers_numpy = np.vectorize(
    get_quantum_numbers, signature='(n)->(n,28)'
)

%timeit quantum_numbers_2 = get_quantum_numbers_numpy(x_array)

21 µs ± 84.4 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [24]:
# JAX
# Keep a separate JAX copy of the table, built from the NumPy table already
# loaded above, rather than overwriting `delta_U_quantum_numbers`. That way
# re-running the NumPy cells after this one still uses the NumPy table, and
# the file is not read from disk twice.
delta_U_quantum_numbers_jax = jnp.asarray(delta_U_quantum_numbers)
delU_Ntot = delta_U_quantum_numbers_jax.shape[0]
# x_array is rebound to a JAX array for the timing below
x_array = jnp.linspace(0.01, 1.0, 100)

def get_quantum_numbers_jax(x):
    """Row of the JAX quantum-number table selected by scalar x.

    x is mapped to row floor(x * (delU_Ntot - 1)).
    """

    index = jnp.floor(x * (delU_Ntot-1)).astype(int)
    return delta_U_quantum_numbers_jax[index]

@jit
def get_quantum_numbers_jax_vmap(x_array):
    """Jitted `get_quantum_numbers_jax` mapped over the leading axis of x_array."""
    return vmap(get_quantum_numbers_jax)(x_array)

%timeit quantum_numbers_3 = get_quantum_numbers_jax_vmap(x_array).block_until_ready()

2.77 µs ± 19 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


## Function that calls an interpolating function

In [25]:
# Functions rely on this function
def phi_func(k):
    """Gaussian test wave function phi(k) = exp(-k^2) / 4 (elementwise)."""
    gaussian = np.exp(-k**2)
    return gaussian / 4

In [26]:
# Momentum grid; same length (100) as the angle/quantum-number arrays above so
# psi / psi_jax can multiply phi(k) by the spinor harmonic elementwise
k_array = np.linspace(0, 15, 100)

In [27]:
def spinor_sph_harm_python(theta, phi, l, j, m_j, m_t, sigma, tau, cg_array,
                           N_j):
    """Spin-isospin spherical harmonic without the vmapped/jitted helpers.

    Same formula as `spinor_sph_harm`:
        where(|m_l| <= l, CG * delta_{tau, m_t} * Y_{l m_l}(theta, phi), 0)
    with m_l = m_j - sigma. It calls `clebsch_gordan_coefficient` and
    `kronecker_delta` directly on whole arrays. Despite the name, those
    helpers return JAX arrays.
    """

    m_s = sigma
    m_l = m_j - m_s
    s = np.repeat(1/2, sigma.size)

    # Direct (non-vmapped) table lookup on whole arrays
    cg = clebsch_gordan_coefficient(l, m_l, s, m_s, j, m_j, cg_array, N_j)

    # Isospinor factor
    chi_tau = kronecker_delta(tau, m_t)

    is_allowed = np.abs(m_l) <= l
    # SciPy argument order: (m, n, azimuthal, polar)
    return np.where(
        is_allowed, cg * chi_tau * sph_harm(m_l, l, phi, theta), 0+0j
    )

In [28]:
# Python vectorized functions
def psi(k, theta, phi, l, j, m_j, m_t, sigma, tau, cg_array, N_j, sp_wf):
    """Product sp_wf(k) * spinor harmonic, using `spinor_sph_harm_python`.

    sp_wf is any callable applied to k (e.g. `phi_func`); its output is
    multiplied elementwise with the spinor spherical harmonic.
    """

    # Evaluate the existing (non-JAX) function of k
    k_part = sp_wf(k)

    angular_part = spinor_sph_harm_python(
        theta, phi, l, j, m_j, m_t, sigma, tau, cg_array, N_j
    )

    return k_part * angular_part

%timeit psi(k_array, theta, phi, l, j, m_j, m_t, sigma, tau, cg_array, N_j, phi_func)

1.71 ms ± 19.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [29]:
# JAX vectorized functions with call to phi_func
def psi_jax(k, theta, phi, l, j, m_j, m_t, sigma, tau, cg_array, N_j, sp_wf):
    """Product sp_wf(k) * spinor harmonic, using the vmapped `spinor_sph_harm`.

    sp_wf is any callable applied to k (e.g. `phi_func`); its output is
    multiplied elementwise with the spinor spherical harmonic.
    """

    k_part = sp_wf(k)

    angular_part = spinor_sph_harm(
        theta, phi, l, j, m_j, m_t, sigma, tau, cg_array, N_j
    )

    return k_part * angular_part

%timeit psi_jax(k_array, theta, phi, l, j, m_j, m_t, sigma, tau, cg_array, N_j, phi_func)

247 µs ± 1.94 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Which case is the most expensive within the integrand function?

In [None]:
# CHECK INTEGRAND FUNCTION - WHAT TAKES THE LONGEST