# Testing JAX

__Author:__ A. J. Tropiano [atropiano@anl.gov]<br/>
__Date:__ May 3, 2023

In this notebook, we test how to use `jax.numpy`, vectorize functions using `jax.vmap`, and possibly speed-up functions using `jax.jit`.

_Last update:_ May 10, 2023

In [1]:
# Python imports
from functools import partial
from jax import device_put, jit, vmap
import jax.numpy as jnp
from jax.scipy.special import sph_harm as sph_harm_jax
import numpy as np
import numpy.linalg as la
from scipy.interpolate import interp1d, RectBivariateSpline
from scipy.special import sph_harm, spherical_jn
from sympy.physics.quantum.cg import CG
import timeit

In [2]:
# Imports from scripts
from scripts.integration import momentum_mesh, unattach_weights_from_matrix
from scripts.srg import SRG

## Calculating arrays of Clebsch-Gordan coefficients

In [3]:
def compute_clebsch_gordan_table(j_max):
    """
    Calculate Clebsch-Gordan coefficients for combinations of j and m_j up
    to j_max.
        
    Parameters
    ----------
    j_max : int
        Maximum j value for j_1, j_2, and j_3. This also constrains m_j.
        
    Returns
    -------
    cg_table : dict
        Table of Clebsch-Gordan coefficients <j_1 m_j_1 j_2 m_j_2|j_3 m_j_3>
        for each combination of angular momenta.
            
    """
        
    j_array = np.arange(0, j_max+1/2, 1/2)
    
    cg_table = {}        
    for j_1 in j_array:
        for j_2 in j_array:
            j_3_array = np.arange(abs(j_1-j_2), j_1+j_2+1/2)
            for j_3 in j_3_array:
                for m_1 in np.arange(-j_1, j_1+1, 1):
                    for m_2 in np.arange(-j_2, j_2+1, 1):
                        m_3 = m_1 + m_2
                        if abs(m_3) <= j_3:
                            cg_table[(j_1,m_1,j_2,m_2,j_3,m_3)] = float(
                                CG(j_1,m_1,j_2,m_2,j_3,m_3).doit()
                            )
                                    
    return cg_table      

In [4]:
# Calculate the table of CG coefficients
cg_table = compute_clebsch_gordan_table(3)

In [5]:
def get_cg_coefficient(j1, m1, j2, m2, j3, m3, cg_table):
    """Clebsch-Gordan coefficient < j1 m1 j2 m2 | j3 m3 >."""
        
    try:
            
        return cg_table[(j1, m1, j2, m2, j3, m3)]
                
    except KeyError:
            
        return 0

In [6]:
def cgs_by_looping(j1, m1, j2, m2, j3, m3, cg_table):
    """Get array of CG's by looping."""
    
    N_batch = len(j1)
    cg_array = np.zeros(N_batch)
    for i in range(N_batch):
        cg_array[i] = get_cg_coefficient(j1[i], m1[i], j2[i], m2[i], j3[i],
                                         m3[i], cg_table)
        
    return cg_array

In [7]:
def compute_clebsch_gordan_array(j_max):
    
    j_array = np.arange(0, j_max+1/2, 1/2)
    N_j = j_array.size
    m_array = np.concatenate((j_array, -j_array[1:]))
    N_m = m_array.size
    
    cg_array = np.zeros((N_j, N_j, N_j, N_m, N_m, N_m))
    
    for i, j_1 in enumerate(j_array):
        m_1_array = np.arange(-j_1, j_1+1)
        for j, j_2 in enumerate(j_array):
            m_2_array = np.arange(-j_2, j_2+1)
            j_3_array = np.arange(np.abs(j_1-j_2), j_1+j_2+1)
            for k, j_3 in enumerate(j_array):
                m_3_array = np.arange(-j_3, j_3+1)
                for l, m_1 in enumerate(m_array):
                    for m, m_2 in enumerate(m_array):
                        for n, m_3 in enumerate(m_array):
                            
                            selection_rules = (
                                np.any(j_3 == j_3_array)
                                and np.any(m_1 == m_1_array)
                                and np.any(m_2 == m_2_array)
                                and np.any(m_3 == m_3_array)
                                and m_1 + m_2 == m_3
                            )

                            if selection_rules:
                                
                                cg = float(CG(j_1,m_1,j_2,m_2,j_3,m_3).doit())
                                cg_array[i, j, k, l, m, n] = cg
    
    return jnp.array(cg_array), N_j

In [8]:
# Calculate the array of CG coefficients
cg_array, N_j = compute_clebsch_gordan_array(3)

In [9]:
def cg_mapping(j, m, N_j):
    """Return the indices of the input angular momentum and projection for the
    array of Clebsch-Gordan coefficients.
    """
    
    j_index = jnp.array(j / 0.5, dtype=int)
    m_index = jnp.array(jnp.abs(m/0.5) + jnp.heaviside(-m, 0) * (N_j-1),
                        dtype=int)

    return j_index, m_index

In [10]:
# # HARD-CODING N_J AND CG_ARRAY CURRENTLY
# def get_cg_coefficient_jax(j1, m1, j2, m2, j3, m3):
#     """Clebsch-Gordan coefficient < j1 m1 j2 m2 | j3 m3 >."""
    
#     N_j = 7
    
#     ij, im = cg_mapping(j1, m1, N_j)
#     jj, jm = cg_mapping(j2, m2, N_j)
#     kj, km = cg_mapping(j3, m3, N_j)
    
#     return cg_array[ij, jj, kj, im, jm, km]

In [11]:
# Vectorize CG function using NumPy
cg_func_vect = np.vectorize(get_cg_coefficient, otypes=[float])

In [12]:
# # Vectorize CG function using JAX
# @jit
# def cg_func_vect_jax(j1, m1, j2, m2, j3, m3):
#     return vmap(get_cg_coefficient_jax)(j1, m1, j2, m2, j3, m3)

def get_cg_coefficient_jax(j1, m1, j2, m2, j3, m3, cg_array, N_j):
    """Clebsch-Gordan coefficient < j1 m1 j2 m2 | j3 m3 >."""
    
    ij, im = cg_mapping(j1, m1, N_j)
    jj, jm = cg_mapping(j2, m2, N_j)
    kj, km = cg_mapping(j3, m3, N_j)
    
    return cg_array[ij, jj, kj, im, jm, km]


@jit
def cg_func_vect_jax(j1, m1, j2, m2, j3, m3, cg_array, N_j):
    return vmap(
        get_cg_coefficient_jax, in_axes=(0, 0, 0, 0, 0, 0, None, None), out_axes=(0)
    )(j1, m1, j2, m2, j3, m3, cg_array, N_j)

In [13]:
# Batch of spin values
N_batch = 10000
random_numbers = np.random.random((N_batch, 2))
spin_samples = np.zeros((N_batch, 4))
for i in range(N_batch):
    for j in range(2):
        if random_numbers[i, j] > 0.5:
            spin_samples[i, 2*j+1] = 1/2
        else:
            spin_samples[i, 2*j+1] = -1/2  
spin_samples[:, 0], spin_samples[:, 2] = 1/2, 1/2

In [14]:
# Coupling two spin-1/2 particles
s1 = spin_samples[:, 0]
sigma_1 = spin_samples[:, 1]
s2 = spin_samples[:, 2]
sigma_2 = spin_samples[:, 3]
# S, M_S = 1, 0  # Taking S = 1, M_S = 0
S, M_S = np.repeat(1, N_batch), np.repeat(0, N_batch)

In [15]:
# JAX arrays
s1_jax = jnp.array(s1)
s2_jax = jnp.array(s2)
sigma_1_jax = jnp.array(sigma_1)
sigma_2_jax = jnp.array(sigma_2)
S_jax = jnp.array(S)
M_S_jax = jnp.array(M_S)

In [16]:
# Comparison

# Looping
%timeit cg_array_1 = cgs_by_looping(s1, sigma_1, s2, sigma_2, S, M_S, cg_table)

# numpy.vectorize
%timeit cg_array_2 = cg_func_vect(s1, sigma_1, s2, sigma_2, S, M_S, cg_table)

# JAX with NumPy arrays
%timeit cg_array_3 = cg_func_vect_jax(s1, sigma_1, s2, sigma_2, S, M_S, cg_array, N_j).block_until_ready()

# JAX with JAX arrays
%timeit cg_array_4 = cg_func_vect_jax(s1_jax, sigma_1_jax, s2_jax, sigma_2_jax, S_jax, M_S_jax, cg_array, N_j).block_until_ready()

5.19 ms ± 44.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.93 ms ± 20.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
105 µs ± 1.89 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
83.6 µs ± 14.8 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


## Calculating spherical harmonics given arrays of $l$, $m_l$, $\theta$, and $\phi$

In [17]:
def ylms_by_looping(l_array, m_array, theta_array, phi_array):
    
    N_batch = len(l_array)
    ylm_array = np.zeros(N_batch, dtype='complex')
    for i in range(N_batch):
        ylm_array[i] = sph_harm(m_array[i], l_array[i], phi_array[i],
                                theta_array[i])
    return ylm_array

In [18]:
# Batch of l, m_l, \theta, and \phi values
l_array = np.array([0, 1, 0, 1, 0, 0, 2, 1, 1, 2, 0, 1, 0, 1, 0, 0, 2, 1, 1, 2])
m_array = np.concatenate((l_array[:10], -l_array[10:]))
theta_array = np.linspace(0.0, np.pi, 20)
phi_array = np.linspace(0.0, 2*np.pi, 20)

In [19]:
# Comparison of the two

# Looping
%timeit ylm_array_1 = ylms_by_looping(l_array, m_array, theta_array, phi_array)

# Vectorized
%timeit ylm_array_2 = sph_harm(m_array, l_array, phi_array, theta_array)

# JAX
%timeit ylm_array_3 = sph_harm_jax(m_array, l_array, phi_array, theta_array).block_until_ready()

31.1 µs ± 283 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
888 ns ± 6.42 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
10.6 µs ± 36.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [20]:
ylm_array_3 = sph_harm_jax(m_array, l_array, phi_array, theta_array).block_until_ready()
print(ylm_array_3)

[ 0.28209478+0.j         -0.05378524-0.0184645j   0.28209478+0.j
 -0.08993854-0.13766117j  0.28209478+0.j          0.28209478+0.j
 -0.18335326-0.19917472j  0.21428804-0.23277888j  0.2945552 -0.1594053j
  0.36285338-0.12456774j  0.28209478+0.j         -0.2945552 +0.15940529j
  0.28209478+0.j         -0.11618476+0.26487473j  0.28209478+0.j
  0.28209478+0.j         -0.03514881+0.08013118j  0.08852715+0.06890342j
  0.05378524+0.0184645j   0.        +0.j        ]


## Vectorizing $\psi_\alpha(\mathbf{q};\sigma,\tau)$

In [21]:
class SingleParticleState:
    """
    Single-particle state class. Packs together the following single-particle
    quantum numbers into one object.
    
    Parameters
    ----------
    n : int
        Principal quantum number n = 1, 2, ...
    l : int
        Orbital angular momentum l = 0, 1, ...
    j : float
        Total angular momentum j = 1/2, 3/2, ...
    m_j : float
        Total angular momentum projection m_j = -j, -j+1, ..., j.
    m_t : float
        Isospin projection m_t = 1/2 or -1/2.
    
    """
    
    
    def __init__(self, n, l, j, m_j, m_t):
        
        # Check if m_j is valid
        if abs(m_j) > j:
            raise RuntimeError("m_j is not valid.")
            
        # Check that |m_t| = 1/2
        if abs(m_t) != 1/2:
            raise RuntimeError("m_t is not valid.")
            
        self.n = n
        self.l = l
        self.j = j
        self.m_j = m_j
        self.m_t = m_t
        
        if m_t == 1/2:
            self.nucleon = 'proton'
        elif m_t == -1/2:
            self.nucleon = 'neutron'
        
        
    def __eq__(self, sp_state):

        if (
            self.n == sp_state.n and self.l == sp_state.l
            and self.j == sp_state.j and self.m_j == sp_state.m_j
            and self.m_t == sp_state.m_t
        ):
            
            return True
        
        else:
            
            return False
        
        
    def __str__(self):
        
        # Spectroscopic notation of orbital angular momentum
        l_str = convert_l_to_string(self.l)  # E.g., 's', 'p', 'd', ...
        
        # Display j subscript as a fraction
        numerator = 2*int(self.j) + 1
        denominator = 2

        return fr"${self.n}{l_str}_{{{numerator}/{denominator}}}$"

In [22]:
class WoodsSaxon:
    """
    Woods-Saxon orbitals class. Handles the wave functions associated with the
    Woods-Saxon potential from the subroutine in woodsaxon.f90. Outputs wave
    functions in coordinate and momentum space.
    
    Parameters
    ----------
    nucleus_name : str
        Name of the nucleus (e.g., 'O16', 'Ca40', etc.)
    Z : int
        Proton number of the nucleus.
    N : int
        Neutron number of the nucleus.
    run_woodsaxon : bool, optional
        Option to run the Woods-Saxon subroutine to generate orbital files.
    n_max : int, optional
        Maximum principal quantum number where n = 1, 2, ..., n_max.
    l_max : int, optional
        Maximum orbital angular momentum where l = 0, 1, ..., l_max.
    rmax : float, optional
        Maximum r for orbital tables.
    ntab : int, optional
        Number of points for orbital tables.
        
    """
    
    
    def __init__(
        self, nucleus_name, Z, N, run_woodsaxon=True, n_max=0, l_max=0, rmax=40,
        ntab=2000
    ):
        
        # Set instance attributes
        self.woods_saxon_directory = f"../data/woods_saxon/{nucleus_name}/"
        self.dr = rmax / ntab
        self.r_array = np.arange(self.dr, rmax + self.dr, self.dr)

        # Generate orbitals?
        if run_woodsaxon:
            
            self.run_woods_saxon_code(nucleus_name, Z, N, n_max, l_max, rmax,
                                      ntab)

            # Move output files to relevant directory
            shutil.move("ws_log", self.woods_saxon_directory + "ws_log")
            shutil.move("ws_pot", self.woods_saxon_directory + "ws_pot")
            shutil.move("ws_rho", self.woods_saxon_directory + "ws_rho")
                
        # Order single-particle states with lowest energy first
        self.order_sp_states(Z, N)
        
        # Organize wave functions in dictionary with the file name as the key
        self.sp_wfs = {}
        for sp_state in self.sp_states:
            # Wave functions are independent of m_j, so fix m_j=j
            if sp_state.m_j == sp_state.j:
                file_name = get_orbital_file_name(sp_state.n, sp_state.l,
                                                  sp_state.j, sp_state.m_t)
                if run_woodsaxon:
                    shutil.move(file_name,
                                self.woods_saxon_directory + file_name)
                data = np.loadtxt(self.woods_saxon_directory + file_name)
                # Use file name as the key
                self.sp_wfs[file_name] = data[:, 1]

            
    def run_woods_saxon_code(
            self, nucleus_name, Z, N, n_max, l_max, rmax, ntab
    ):
        """Run Woods-Saxon code to generate data."""
        
        # Total number of nucleons
        A = Z + N
        
        # Type of orbitals: 1 - nucleons with no Coulomb
        #                   2 - distinguish protons and neutrons
        ntau = 2
        
        # Orbitals to consider (note, we track 2*j not j)
        norb, lorb, jorb = [], [], []
        for n in range(1, n_max+1):
            for l in range(0, l_max+1):
                norb.append(n)
                lorb.append(l)
                jorb.append(int(2*(l+1/2)))
                if int(2*(l-1/2)) > 0:  # Don't append negative j
                    norb.append(n)
                    lorb.append(l)
                    jorb.append(int(2*(l-1/2)))
        nrad = len(jorb)
        orbws = np.zeros(shape=(2,nrad,ntab), order='F')
    
        # Divide orbital by r? -> get R(r); false: get u(r)=r R(r)
        rdiv = False
        dens = True
    
        # Set parameters of the Woods-Saxon potential
        prm = np.zeros(shape=(2,9), order='F')
    
        # Starting with vws (p & n)
        if nucleus_name == 'He4':
            prm[:,0] = 76.8412
        elif nucleus_name == 'O16':
            prm[:,0] = 58.0611
        elif nucleus_name == 'Ca40':
            prm[:,0] = 54.3051
        elif nucleus_name == 'Ca48':
            prm[0,0] = 59.4522
            prm[1,0] = 46.9322
    
        # Not sure about these (better way to load these parameters?)
        prm[:,1] = 1.275
        prm[:,2] = 0.7
        prm[:,3] = 0.
        prm[:,4] = 1.
        prm[:,5] = 36
        prm[:,6] = 1.32
        prm[:,7] = 0.7
        prm[:,8] = 1.275
        
        # Print summary, potentials, and densities
        prnt = True
        prntorb = True

        # Run Fortran subroutine
        ws(ntau, A, Z, rmax, orbws, norb, lorb, jorb, prm, rdiv, prnt, prntorb,
           dens)
        
        
    def order_sp_states(self, Z, N):
        """Keep track of all s.p. states and occupied s.p. states"""

        self.sp_states = []
        self.occ_states = []
        proton_count = 0
        neutron_count = 0
        
        # File with details of the orbitals
        ws_file = self.woods_saxon_directory + "ws_log"
    
        # Order single-particle states using the ws_log file
        with open(ws_file, 'r') as f:
            for line in f:
                unit = line.strip().split()
                
                # Protons
                if len(unit) > 0 and unit[0] == '1':

                    j = int(unit[3])/2
                    for m_j in np.arange(-j, j+1, 1):
                        sp_state = SingleParticleState(
                            int(unit[1])+1, int(unit[2]), j, m_j, 1/2
                        )  # n, l, j, m_j, m_t
                    
                        self.sp_states.append(sp_state)
                    
                        if proton_count < Z:
                            self.occ_states.append(sp_state)
                            # Add up filled proton states
                            proton_count += 1
                    
                
                # Neutrons
                elif len(unit) > 0 and unit[0] == '2':

                    j = int(unit[3])/2
                    for m_j in np.arange(-j, j+1, 1):
                        sp_state = SingleParticleState(
                            int(unit[1])+1, int(unit[2]), j, m_j, -1/2
                        )  # n, l, j, m_j, m_t
                    
                        self.sp_states.append(sp_state)
                    
                        if neutron_count < N:
                            self.occ_states.append(sp_state)
                            # Add up filled neutron states
                            neutron_count += 1
                        
                        
    def get_wf_rspace(self, sp_state, print_normalization=False):
        """Single-particle wave function in coordinate space."""
        
        # Orbital file name is the key
        u_array = self.sp_wfs[get_orbital_file_name(sp_state.n, sp_state.l,
                                                    sp_state.j, sp_state.m_t)]

        # Normalization: \int dr |u(r)|^2 = 1
        if print_normalization:
            normalization = np.sum(self.dr*u_array**2)
            print(f"Coordinate space normalization = {normalization}.")

        return self.r_array, u_array
    
    
    def fourier_transformation(self, l, k_array):
        """Fourier transformation matrix for given orbital angular momentum."""
        
        # r_array column vectors and k_array row vectors where both grids are
        # n x m matrices
        r_cols, k_rows = np.meshgrid(self.r_array, k_array)
        
        # Transformation matrix with shape n x m, where m is the length of
        # r_array and n is the length of the k_array
        M = 1j**(-l) * np.sqrt(2/np.pi) * self.dr * r_cols * spherical_jn(
            l, k_rows*r_cols
        )
        
        return M
    
    
    def get_wf_kspace(
            self, sp_state, kmax, kmid, ntot, print_normalization=False,
            interpolate=False,
            
    ):
        """Single-particle wave function in momentum space."""
    
        # Set momentum mesh with more points at low momentum
        k_array, k_weights = momentum_mesh(kmax, kmid, ntot)
    
        # Get coordinate-space s.p. wave function
        _, u_array = self.get_wf_rspace(sp_state)

        # Fourier-transform the wave function to momentum space
        phi_array = self.fourier_transformation(sp_state.l, k_array) @ u_array
    
        # Normalization: \int dk k^2 |\phi(k)|^2 = 1
        if print_normalization:
            normalization = np.sum(k_weights*k_array**2*abs(phi_array)**2)
            print(f"Momentum space normalization = {normalization}.")
            
        # Interpolate and return function?
        if interpolate:
            phi_func = interp1d(k_array, phi_array, kind='linear',
                                bounds_error=False, fill_value='extrapolate')
            return phi_func
        
        # Otherwise return momentum, weights, and \phi(k)
        else:
            return k_array, k_weights, phi_array

In [23]:
def kronecker_delta(x, y):
    """Kronecker \delta function: \delta_{x,y}."""
    
    return int(x == y)

def get_orbital_file_name(n, l, j, m_t):
    """Returns the file name of the orbital."""
        
    # Proton
    if m_t == 1/2:
        file_name = f"p.n{int(n-1)}.l{int(l)}.j{int(2*j)}.orb"
    # Neutron
    elif m_t == -1/2:
        file_name = f"n.n{int(n-1)}.l{int(l)}.j{int(2*j)}.orb"
        
    return file_name

def get_sp_wave_functions(sp_basis, kmax, kmid, ntot):
    """Set interpolating functions for s.p. wave functions \phi."""
    
    occ_states = sp_basis.occ_states

    phi_functions = {}
    for sp_state in occ_states: 
        file_name = get_orbital_file_name(sp_state.n, sp_state.l, sp_state.j,
                                          sp_state.m_t)
        phi_functions[file_name] = sp_basis.get_wf_kspace(
            sp_state, kmax, kmid, ntot, interpolate=True)
            
    return phi_functions

In [24]:
# Get radial s.p. wave functions for O16
woods_saxon = WoodsSaxon('O16', 8, 8, run_woodsaxon=False)
phi_functions = get_sp_wave_functions(woods_saxon, 10.0, 2.0, 120)

In [25]:
def psi(n, l, j, m_j, m_t, k, theta, phi, sigma, tau, cg_table, phi_functions):
    """Single-particle wave function."""

    # Calculate \phi_\alpha(q)
    phi_sp_wf = phi_functions[get_orbital_file_name(n, l, j, m_t)](k)
    
    # Calculate spinor spherical harmonic
    Y_jml = spinor_spherical_harmonic(l, j, m_j, theta, phi, sigma, cg_table)
    
    # Isospinor indexed by \tau \chi_{m_t}(\tau)
    chi_tau = kronecker_delta(tau, m_t)

    return phi_sp_wf * Y_jml * chi_tau

def spinor_spherical_harmonic(l, j, m_j, theta, phi, sigma, cg_table):
    """Spinor spherical harmonic for a s.p. state described by the quantum
    numbers j, m_j, l, and s=1/2.
    """
    
    Y_ljm = 0+0j

    # Spinor indexed by \sigma \eta_{m_s}^(\sigma) = \delta_{m_s, \sigma}
    m_s = sigma
    
    # m_l must be fixed since m_j and m_s are determined
    m_l = m_j - m_s
    
    # Check that |m_l| <= l
    if np.abs(m_l) <= l:
        
        # Clebsch-Gordan coefficient for l-s coupling
        cg = cg_table[(l, m_l, 1/2, m_s, j, m_j)]
        
        # Spherical harmonic
        Y_lm = sph_harm(m_l, l, phi, theta)
        
        Y_ljm = cg * Y_lm

    return Y_ljm

In [26]:
def psis_by_looping(
    k_array, theta_array, phi_array, n_array, l_array, j_array, m_j_array,
    m_t_array, sigma_array, tau_array, cg_table, phi_functions
):
    
    N_batch = len(k_array)
    psi_array = np.zeros(N_batch, dtype='complex')
    for i in range(N_batch):
        
        n = n_array[i]
        l = l_array[i]
        j = j_array[i]
        m_j = m_j_array[i]
        m_t = m_t_array[i]
        k = k_array[i]
        theta = theta_array[i]
        phi = phi_array[i]
        sigma = sigma_array[i]
        tau = tau_array[i]
        
        psi_array[i] = psi(n, l, j, m_j, m_t, k, theta, phi, sigma, tau,
                           cg_table, phi_functions)
        
    return psi_array

In [27]:
psi_func_vect = np.vectorize(psi, otypes=[complex])

In [28]:
def kronecker_delta_jax(x, y):
    """Kronecker \delta function: \delta_{x,y}."""
    
    return jnp.array(x == y, dtype=int)

@jit
def vectorized_kronecker_delta(x, y):
    return vmap(kronecker_delta_jax)(x, y)

def get_phi_function(n, l, j, m_t, k):
    return phi_functions[get_orbital_file_name(n, l, j, m_t)](k)

vectorized_phi_function = np.vectorize(get_phi_function)

In [29]:
def psi_jax_1(n, l, j, m_j, m_t, k, theta, phi, sigma, tau, cg_array, N_j):
    """Single-particle wave function."""
    
    # Calculate \phi_\alpha(q)
    phi_sp_wf = vectorized_phi_function(n, l, j, m_t, k)
    
    # Calculate spinor spherical harmonic
    Y_jml = spinor_spherical_harmonic_jax_1(l, j, m_j, theta, phi, sigma, cg_array, N_j)
    
    # Isospinor indexed by \tau \chi_{m_t}(\tau)
    chi_tau = vectorized_kronecker_delta(tau, m_t).block_until_ready()

    return phi_sp_wf * Y_jml * chi_tau

def spinor_spherical_harmonic_jax_1(l, j, m_j, theta, phi, sigma, cg_array, N_j):
    """Spinor spherical harmonic for a s.p. state described by the quantum
    numbers j, m_j, l, and s=1/2.
    """

    # Spinor indexed by \sigma \eta_{m_s}^(\sigma) = \delta_{m_s, \sigma}
    m_s = sigma
    
    # m_l must be fixed since m_j and m_s are determined
    m_l = m_j - m_s

    # Clebsch-Gordan coefficient for l-s coupling
    cg = cg_func_vect_jax(l, m_l, np.repeat(1/2, m_s.size), m_s, j, m_j, cg_array, N_j).block_until_ready()
    
    # Check that |m_l| <= l
    Y_ljm = np.where(np.abs(m_l) <= l, cg * sph_harm(m_l, l, phi, theta), 0+0j)
    # Y_lm = sph_harm_jax(m_l, l, phi, theta).block_until_ready()
    # Y_ljm = np.where(np.abs(m_l) <= l, cg * Y_lm, 0+0j)

    return Y_ljm

def psi_jax_2(n, l, j, m_j, m_t, k, theta, phi, sigma, tau, cg_array, N_j, n_max):
    """Single-particle wave function."""
    
    # Calculate \phi_\alpha(q)
    phi_sp_wf = vectorized_phi_function(n, l, j, m_t, k)
    
    # Calculate spinor spherical harmonic
    Y_jml = vectorized_spinor_spherical_harmonic(l, j, m_j, theta, phi, sigma, cg_array, N_j, n_max).block_until_ready()
    
    # Isospinor indexed by \tau \chi_{m_t}(\tau)
    chi_tau = vectorized_kronecker_delta(tau, m_t).block_until_ready()

    return phi_sp_wf * Y_jml * chi_tau

def spinor_spherical_harmonic_jax_2(l, j, m_j, theta, phi, sigma, cg_array, N_j, n):
    """Spinor spherical harmonic for a s.p. state described by the quantum
    numbers j, m_j, l, and s=1/2.
    """

    # Spinor indexed by \sigma \eta_{m_s}^(\sigma) = \delta_{m_s, \sigma}
    m_s = sigma
    
    # m_l must be fixed since m_j and m_s are determined
    m_l = m_j - m_s

    # Clebsch-Gordan coefficient for l-s coupling
    cg = get_cg_coefficient_jax(l, m_l, np.repeat(1/2, m_s.size), m_s, j, m_j, cg_array, N_j)
    
    # Check that |m_l| <= l
#     Y_ljm = np.where(
#         np.abs(m_l) <= l, cg * sph_harm(m_l, l, phi, theta), 0+0j
#     )
    Y_lm = sph_harm_jax(m_l, l, phi, theta, n_max=n).block_until_ready()
    Y_ljm = jnp.where(
        jnp.abs(m_l) <= l, cg * Y_lm, 0+0j
    )
#     Y_ljm = jnp.where(
#         jnp.abs(m_l) <= l, cg * sph_harm_jax(m_l, l, phi, theta), 0+0j
#     )
#     Y_ljm = jnp.where(
#         jnp.abs(m_l) <= l, cg * sph_harm_jax(m_l, l, phi, theta, n_max=jnp.amax(l)), 0+0j
#     )
#     Y_ljm = jnp.where(
#         jnp.abs(m_l) <= l, cg * sph_harm_jax(m_l, l, phi, theta, n_max=2), 0+0j
#     )
#     Y_ljm = jnp.where(
#         np.abs(m_l) <= l, cg * sph_harm(m_l, l, phi, theta), 0+0j
#     )
    # THE SPH_HARM_JAX FUNCTION ISN'T WORKING!

    return Y_ljm

@partial(jit, static_argnames=['n_max'])
# @partial(jit, static_argnames=['cg_array', 'N_j', 'n_max'])
def vectorized_spinor_spherical_harmonic(l, j, m_j, theta, phi, sigma, cg_array, N_j, n_max):
    return vmap(
        spinor_spherical_harmonic_jax_2, in_axes=(0, 0, 0, 0, 0, 0, None, None, None), out_axes=(0)
    )(l, j, m_j, theta, phi, sigma, cg_array, N_j, n_max)
#     return vmap(spinor_spherical_harmonic_jax_2)(l, j, m_j, theta, phi, sigma, cg_array, N_j, n_max)

In [30]:
# Batch of momenta samples
N_batch = 100
k_array = np.random.random(N_batch) * 10  # 10 fm^-1 max
theta_array = np.random.random(N_batch) * np.pi  # \pi max
phi_array = np.random.random(N_batch) * 2*np.pi  # 2\pi max

In [31]:
# Batch of s.p. states
n_array = np.zeros(N_batch, dtype=int)
l_array = np.zeros(N_batch, dtype=int)
j_array = np.zeros(N_batch, dtype=float)
m_j_array = np.zeros(N_batch, dtype=float)
m_t_array = np.zeros(N_batch, dtype=float)

N_mod = len(woods_saxon.occ_states)

for i in range(N_batch):

    sp_state = woods_saxon.occ_states[i % N_mod]
    n_array[i] = sp_state.n
    l_array[i] = sp_state.l
    j_array[i] = sp_state.j
    m_j_array[i] = sp_state.m_j
    m_t_array[i] = sp_state.m_t

In [32]:
# Batch of \sigma values
random_numbers = np.random.random(N_batch)
sigma_array = np.zeros((N_batch, 2))
for i in range(N_batch):
    if random_numbers[i] > 0.5:
        sigma_array[i, 1] = 1/2
    else:
        sigma_array[i, 1] = -1/2
sigma_array[:, 0] = 1/2

In [33]:
# Batch of \tau values
random_numbers = np.random.random(N_batch)
tau_array = np.zeros((N_batch, 2))
for i in range(N_batch):
    if random_numbers[i] > 0.5:
        tau_array[i, 1] = 1/2
    else:
        tau_array[i, 1] = -1/2
tau_array[:, 0] = 1/2

In [34]:
# JAX arrays
k_array_jax = jnp.array(k_array)
theta_array_jax = jnp.array(theta_array)
phi_array_jax = jnp.array(phi_array)

n_array_jax = jnp.array(n_array)
l_array_jax = jnp.array(l_array)
j_array_jax = jnp.array(j_array)
m_j_array_jax = jnp.array(m_j_array)
m_t_array_jax = jnp.array(m_t_array)

sigma_array_jax = jnp.array(sigma_array[:, 1])
tau_array_jax = jnp.array(tau_array[:, 1])

In [35]:
# Try not using vmap on vector spherical harmonic, vectorized CG, Python spherical harmonic
# 1.33ms

# Try not using vmap on vector spherical harmonic, vectorized CG, JAX spherical harmonic
# TypeError

# Try using vmap on vector spherical harmonic, get_cg_coefficient_jax, Python spherical harmonic
# TracerArrayConversionError
# Need to use JAX arrays within vectorized vector spherical harmonic function

# Try using vmap on vector spherical harmonic, get_cg_coefficient_jax, JAX spherical harmonic
# ConcretizationTypeError
# Vague thing about n_max in the JAX sph_harm function. Setting this doesn't fix the issue...

# Try JAX arrays using vmap on vector spherical harmonic, get_cg_coefficient_jax, JAX spherical harmonic
# Same error as previous.

In [36]:
# Compare the cost of looping over N points in
# (n, l, j, m_j, m_t, k, theta, phi, \sigma, \tau) space to a vectorized func

# Looping
%timeit psi_array_1 = psis_by_looping(k_array, theta_array, phi_array, n_array, l_array, j_array, m_j_array, m_t_array, sigma_array[:, -1], tau_array[:, -1], cg_table, phi_functions)

# numpy.vectorize
%timeit psi_array_2 = psi_func_vect(n_array, l_array, j_array, m_j_array, m_t_array, k_array, theta_array, phi_array, sigma_array[:, -1], tau_array[:, -1], cg_table, phi_functions)

# JAX with vectorized CG, Python spherical harmonic, and NumPy array arguments
%timeit psi_array_3 = psi_jax_1(n_array, l_array, j_array, m_j_array, m_t_array, k_array, theta_array, phi_array, sigma_array[:, -1], tau_array[:, -1], cg_array, N_j).block_until_ready()

# JAX with vectorized_spinor_spherical_harmonic, JAX spherical harmonic, and JAX array arguments
%timeit psi_array_4 = psi_jax_2(n_array_jax, l_array_jax, j_array_jax, m_j_array_jax, m_t_array_jax, k_array_jax, theta_array_jax, phi_array_jax, sigma_array_jax, tau_array_jax, cg_array, N_j, 2).block_until_ready()

1.65 ms ± 7.13 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.55 ms ± 31.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.33 ms ± 16.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


IndexError: tuple index out of range

In [None]:
# JAX with vectorized_spinor_spherical_harmonic, JAX spherical harmonic, and JAX array arguments
%timeit psi_array_4 = psi_jax_2(n_array_jax, l_array_jax, j_array_jax, m_j_array_jax, m_t_array_jax, k_array_jax, theta_array_jax, phi_array_jax, sigma_array_jax, tau_array_jax, cg_array, N_j, 1).block_until_ready()



## Does JAX work with SciPy 2D interpolation?

JAX does not support SciPy's 2D interpolation functions yet.

In [None]:
def get_delta_U():
    
    # Get momentum mesh
    k_array, k_weights = momentum_mesh(15.0, 3.0, 120)
    
    srg = SRG(6, '1S0', 15.0, 3.0, 120, 'Wegner')
    U_matrix_weights = srg.load_srg_transformation(1.35)

    # Calculate \delta U = U - I
    I_matrix_weights = np.eye(len(U_matrix_weights))
    delU_matrix_weights = U_matrix_weights - I_matrix_weights

    delU_matrix = unattach_weights_from_matrix(k_array, k_weights,
                                               delU_matrix_weights[:120,:120])
        
    # Interpolate \delta U(k, k')
    delU_func = RectBivariateSpline(k_array, k_array, delU_matrix)

    return delU_func

In [None]:
# Vectorize this function with JAX
@jit
def get_delta_U_vectorized(k, kp, delU_func):
    
    return vmap(delU_func)(k, kp)

In [None]:
delU_func = get_delta_U()
a = delU_func.ev(1.0, 1.0)
b = get_delta_U_vectorized(1.0, 1.0, delU_func)