# Testing vectorized functions

__Author:__ A. J. Tropiano [atropiano@anl.gov]<br/>
__Date:__ April 13, 2023

In this notebook, we test different ways of vectorizing Python functions. We compare brute-force looping to `numpy.vectorize` and note where Jax or Numba could provide further speed-ups (those variants are not implemented yet).

_Last update:_ April 18, 2023

In [1]:
# Python imports
import numpy as np
import numpy.linalg as la
from scipy.interpolate import interp1d
from scipy.special import sph_harm, spherical_jn
from sympy.physics.quantum.cg import CG

In [2]:
# Imports from scripts
from scripts.integration import momentum_mesh

## Calculating arrays of Clebsch-Gordan coefficients

In [3]:
def compute_clebsch_gordan_table(j_max):
    """
    Tabulate Clebsch-Gordan coefficients for every allowed coupling of j_1
    and j_2 up to j_max.
    
    Parameters
    ----------
    j_max : int
        Largest value taken by j_1 and j_2, which both run over
        0, 1/2, 1, ..., j_max. The coupled j_3 runs from |j_1-j_2| to j_1+j_2
        in integer steps, so it can be as large as 2*j_max.
    
    Returns
    -------
    cg_table : dict
        Maps the tuple (j_1, m_j_1, j_2, m_j_2, j_3, m_j_3) to the coefficient
        <j_1 m_j_1 j_2 m_j_2|j_3 m_j_3> as a float. Only combinations with
        m_j_3 = m_j_1 + m_j_2 and |m_j_3| <= j_3 are stored.
        
    """

    table = {}
    half_integers = np.arange(0, j_max+1/2, 1/2)

    for j_1 in half_integers:
        # Projections of j_1 do not depend on the inner loops
        m_1_values = np.arange(-j_1, j_1+1, 1)
        for j_2 in half_integers:
            m_2_values = np.arange(-j_2, j_2+1, 1)
            # Default step of 1 walks j_3 through the triangle-allowed values
            for j_3 in np.arange(abs(j_1-j_2), j_1+j_2+1/2):
                for m_1 in m_1_values:
                    for m_2 in m_2_values:
                        m_3 = m_1 + m_2
                        # Skip projections that do not fit in the j_3 multiplet
                        if abs(m_3) > j_3:
                            continue
                        coefficient = CG(j_1, m_1, j_2, m_2, j_3, m_3).doit()
                        table[(j_1, m_1, j_2, m_2, j_3, m_3)] = float(
                            coefficient
                        )

    return table

In [4]:
# Precompute the Clebsch-Gordan table once (j_1, j_2 <= 3) for fast look-up
cg_table = compute_clebsch_gordan_table(3)
# cg_func, psi, and psi_for_vect read this module-level table, so this cell
# must be executed before any of them is called.

In [5]:
def cg_func(j1, m1, j2, m2, j3, m3):
    """
    Look up the Clebsch-Gordan coefficient <j1 m1 j2 m2|j3 m3>.

    Parameters
    ----------
    j1, j2, j3 : float
        Angular momenta.
    m1, m2, m3 : float
        Corresponding projections.

    Returns
    -------
    float or int
        Coefficient read from the module-level `cg_table`, or 0 when any
        selection rule is violated.

    Raises
    ------
    KeyError
        If the combination is allowed but missing from `cg_table` (e.g., the
        table was built with too small a j_max).

    """

    # Projections must add up
    if m1 + m2 != m3:
        return 0

    # Each projection must fit inside its multiplet
    if abs(m1) > j1 or abs(m2) > j2 or abs(m3) > j3:
        return 0

    # Triangle inequality |j1-j2| <= j3 <= j1+j2
    if j3 < abs(j1 - j2) or j3 > j1 + j2:
        return 0

    # j - m must be an integer, and j3 must differ from j1+j2 by an integer;
    # otherwise the coupling does not exist and the table has no entry
    if (j1 - m1) % 1 != 0 or (j2 - m2) % 1 != 0 or (j1 + j2 - j3) % 1 != 0:
        return 0

    return cg_table[(j1, m1, j2, m2, j3, m3)]

In [6]:
# Batch of spin values with columns (s1, sigma_1, s2, sigma_2)
N_batch = 100
random_numbers = np.random.random((N_batch, 2))
# Both particles have spin 1/2, so start from 1/2 everywhere
spin_samples = np.full((N_batch, 4), 1/2)
# Projections (columns 1 and 3) are +1/2 where the draw exceeds 0.5, else -1/2
spin_samples[:, 1::2] = np.where(random_numbers > 0.5, 1/2, -1/2)

In [7]:
def cgs_by_looping(s1, sigma_1, s2, sigma_2, S, M_S):
    """
    Evaluate cg_func sample by sample with an explicit Python loop.

    The arrays s1, sigma_1, s2, and sigma_2 are indexed together; S and M_S
    are the same for every sample. Returns a float array with one
    coefficient per entry of s1.
    """

    coefficients = [
        cg_func(s1[idx], sigma_1[idx], s2[idx], sigma_2[idx], S, M_S)
        for idx in range(len(s1))
    ]

    return np.array(coefficients, dtype=float)

In [8]:
# Broadcast cg_func over array arguments. NumPy documents np.vectorize as a
# convenience wrapper that is essentially a for loop, so no large speed-up is
# expected; otypes fixes the output dtype to float.
cg_func_vect = np.vectorize(cg_func, otypes=[float])

In [9]:
# TODO: get array of CG's using jax.numpy.vectorize (not implemented yet)

In [10]:
# Comparison of looping and numpy.vectorize (the Jax variant is not
# implemented yet)
s1 = spin_samples[:, 0]
sigma_1 = spin_samples[:, 1]
s2 = spin_samples[:, 2]
sigma_2 = spin_samples[:, 3]
# S and M_S are passed to cg_func as (j3, m3) and are shared by every sample
S, M_S = 1, 0

# Looping
%time cg_array_1 = cgs_by_looping(s1, sigma_1, s2, sigma_2, S, M_S)

# numpy.vectorize
%time cg_array_2 = cg_func_vect(s1, sigma_1, s2, sigma_2, S, M_S)

# Jax
# ...

CPU times: user 138 µs, sys: 11 µs, total: 149 µs
Wall time: 150 µs
CPU times: user 104 µs, sys: 6 µs, total: 110 µs
Wall time: 112 µs


## Calculating spherical harmonics given arrays of $l$, $m_l$, $\theta$, and $\phi$

In [11]:
# Batch of (l, m_l, theta, phi) values for the spherical harmonics
l_array = np.array([0, 1, 0, 1, 0, 0, 2, 1, 1, 2, 0, 1, 0, 1, 0, 0, 2, 1, 1, 2])
# m_l = +l for the first ten entries and m_l = -l for the last ten
m_array = np.concatenate((l_array[:10], -l_array[10:]))
# One polar and one azimuthal angle per (l, m_l) pair
theta_array = np.linspace(0.0, np.pi, 20)
phi_array = np.linspace(0.0, 2*np.pi, 20)

In [12]:
# Looping
def ylms_by_looping(l_array, m_array, theta_array, phi_array):
    """
    Evaluate sph_harm sample by sample with an explicit Python loop.

    All four arrays are indexed together. sph_harm is called as
    sph_harm(m, l, phi, theta) for each sample, and the results are returned
    as a complex array with one entry per element of l_array.
    """

    harmonics = [
        sph_harm(m_array[idx], l_array[idx], phi_array[idx], theta_array[idx])
        for idx in range(len(l_array))
    ]

    return np.array(harmonics, dtype='complex')

In [13]:
# Spherical harmonic functions from scipy are already vectorized!

In [14]:
# Comparison of the two
# Note the argument order in both calls: sph_harm(m, l, azimuthal, polar),
# i.e., phi_array is passed before theta_array.

# Looping
%time ylm_array_1 = ylms_by_looping(l_array, m_array, theta_array, phi_array)

# Vectorized
%time ylm_array_2 = sph_harm(m_array, l_array, phi_array, theta_array)

CPU times: user 60 µs, sys: 18 µs, total: 78 µs
Wall time: 78.9 µs
CPU times: user 4 µs, sys: 1e+03 ns, total: 5 µs
Wall time: 5.01 µs


## Vectorizing $\psi_\alpha(\mathbf{q};\sigma)$

In [15]:
class SingleParticleState:
    """
    Single-particle state class. Packs together the following single-particle
    quantum numbers into one object.
    
    Instances compare equal when all five quantum numbers agree, and they are
    hashable so they can be used in sets or as dictionary keys. The hash is
    computed from the quantum numbers, so do not modify the attributes of a
    state after it has been stored in a set or dictionary.
    
    Parameters
    ----------
    n : int
        Principal quantum number n = 1, 2, ...
    l : int
        Orbital angular momentum l = 0, 1, ...
    j : float
        Total angular momentum j = 1/2, 3/2, ...
    m_j : float
        Total angular momentum projection m_j = -j, -j+1, ..., j.
    tau : float
        Isospin projection tau = 1/2 or -1/2.
        
    Attributes
    ----------
    nucleon : str
        'proton' for tau = 1/2 and 'neutron' for tau = -1/2.
        
    Raises
    ------
    RuntimeError
        If |m_j| > j or |tau| != 1/2.
    
    """
    
    
    def __init__(self, n, l, j, m_j, tau):
        
        # Check if m_j is valid
        if abs(m_j) > j:
            raise RuntimeError("m_j is not valid.")
            
        # Check that |\tau| = 1/2
        if abs(tau) != 1/2:
            raise RuntimeError("tau is not valid.")
            
        self.n = n
        self.l = l
        self.j = j
        self.m_j = m_j
        self.tau = tau
        
        if tau == 1/2:
            self.nucleon = 'proton'
        elif tau == -1/2:
            self.nucleon = 'neutron'
            
            
    def _quantum_numbers(self):
        """Tuple of quantum numbers used for equality and hashing."""
        
        return (self.n, self.l, self.j, self.m_j, self.tau)
        
        
    def __eq__(self, sp_state):
        """Two states are equal when all five quantum numbers agree."""
        
        # Objects without the quantum-number attributes are not comparable;
        # returning NotImplemented lets Python fall back to its default
        # (unequal) instead of raising AttributeError.
        try:
            other_numbers = (sp_state.n, sp_state.l, sp_state.j, sp_state.m_j,
                             sp_state.tau)
        except AttributeError:
            return NotImplemented
        
        return self._quantum_numbers() == other_numbers
    
    
    def __hash__(self):
        """Hash consistent with __eq__ (defining __eq__ alone would make
        instances unhashable)."""
        
        return hash(self._quantum_numbers())
        
        
    def __str__(self):
        """LaTeX label in spectroscopic notation, e.g., $1s_{1/2}$."""
        
        # Spectroscopic notation of orbital angular momentum
        # NOTE(review): convert_l_to_string is not defined in the visible part
        # of this notebook; it must be provided elsewhere before str() is used.
        l_str = convert_l_to_string(self.l)  # E.g., 's', 'p', 'd', ...
        
        # Display j subscript as a fraction 2j/2
        numerator = int(2*self.j)
        denominator = 2

        return fr"${self.n}{l_str}_{{{numerator}/{denominator}}}$"

In [16]:
class WoodsSaxon:
    """
    Woods-Saxon class. Handles the wave functions associated with the
    Woods-Saxon potential from the subroutine in woodsaxon.f90. Outputs wave
    functions in coordinate and momentum space.
    
    Parameters
    ----------
    nucleus_name : str
        Name of the nucleus (e.g., 'O16', 'Ca40', etc.)
    Z : int
        Proton number of the nucleus.
    N : int
        Neutron number of the nucleus.
    run_woodsaxon : bool, optional
        Option to run the Woods-Saxon subroutine to generate orbital files.
        If False, existing files in the nucleus directory are read instead.
    n_max : int, optional
        Maximum principal quantum number where n = 1, 2, ..., n_max.
    l_max : int, optional
        Maximum orbital angular momentum where l = 0, 1, ..., l_max.
    rmax : float, optional
        Maximum r for orbital tables.
    ntab : int, optional
        Number of points for orbital tables.
        
    Attributes
    ----------
    woods_saxon_directory : str
        Directory holding ws_log, ws_pot, ws_rho, and the orbital files.
    dr : float
        Radial step rmax / ntab.
    r_array : 1-D ndarray
        Radial grid dr, 2*dr, ..., rmax with exactly ntab points.
    sp_states, occ_states : list
        All and occupied single-particle states read from ws_log.
    sp_wfs : dict
        Second column of each orbital file, keyed by the orbital file name.
        
    """
    
    
    def __init__(
        self, nucleus_name, Z, N, run_woodsaxon=True, n_max=0, l_max=0, rmax=40,
        ntab=2000
    ):
        
        # Only needed for moving the files written by the Fortran code; they
        # are imported here because the notebook's import cell lacks them
        import os
        import shutil
        
        # Set instance attributes
        self.woods_saxon_directory = f"../data/woods_saxon/{nucleus_name}/"
        self.dr = rmax / ntab
        # linspace guarantees exactly ntab points; np.arange with a float step
        # can return one point too many because of round-off in the stop value
        self.r_array = np.linspace(self.dr, rmax, ntab)

        # Generate orbitals?
        if run_woodsaxon:
            
            self.run_woods_saxon_code(nucleus_name, Z, N, n_max, l_max, rmax,
                                      ntab)

            # Move output files to relevant directory (create it if missing)
            os.makedirs(self.woods_saxon_directory, exist_ok=True)
            for output_file in ("ws_log", "ws_pot", "ws_rho"):
                shutil.move(output_file,
                            self.woods_saxon_directory + output_file)
                
        # Order single-particle states with lowest energy first
        self.order_sp_states(Z, N)
        
        # Organize wave functions in dictionary with the file name as the key
        self.sp_wfs = {}
        for sp_state in self.sp_states:
            # Wave functions are independent of m_j, so fix m_j=j
            if sp_state.m_j == sp_state.j:
                file_name = get_orbital_file_name(sp_state)
                if run_woodsaxon:
                    shutil.move(file_name,
                                self.woods_saxon_directory + file_name)
                data = np.loadtxt(self.woods_saxon_directory + file_name)
                # Use file name as the key
                self.sp_wfs[file_name] = data[:, 1]

            
    def run_woods_saxon_code(self, nucleus_name, Z, N, n_max, l_max, rmax, ntab):
        """Run Woods-Saxon code to generate data."""
        
        # Total number of nucleons
        A = Z + N
        
        # Type of orbitals: 1 - nucleons with no Coulomb
        #                   2 - distinguish protons and neutrons
        ntau = 2
        
        # Orbitals to consider (note, we track 2*j not j)
        norb, lorb, jorb = [], [], []
        for n in range(1, n_max+1):
            for l in range(0, l_max+1):
                norb.append(n)
                lorb.append(l)
                jorb.append(int(2*(l+1/2)))
                if int(2*(l-1/2)) > 0:  # Don't append negative j
                    norb.append(n)
                    lorb.append(l)
                    jorb.append(int(2*(l-1/2)))
        nrad = len(jorb)
        orbws = np.zeros(shape=(2,nrad,ntab), order='F')
    
        # Divide orbital by r? -> get R(r); false: get u(r)=r R(r)
        rdiv = False
        dens = True
    
        # Set parameters of the Woods-Saxon potential
        prm = np.zeros(shape=(2,9), order='F')
    
        # Starting with vws (p & n)
        # NOTE(review): for any other nucleus_name this entry silently stays
        # at zero -- confirm whether that should be an error instead.
        if nucleus_name == 'He4':
            prm[:,0] = 76.8412
        elif nucleus_name == 'O16':
            prm[:,0] = 58.0611
        elif nucleus_name == 'Ca40':
            prm[:,0] = 54.3051
        elif nucleus_name == 'Ca48':
            prm[0,0] = 59.4522
            prm[1,0] = 46.9322
    
        # Not sure about these (better way to load these parameters?)
        prm[:,1] = 1.275
        prm[:,2] = 0.7
        prm[:,3] = 0.
        prm[:,4] = 1.
        prm[:,5] = 36
        prm[:,6] = 1.32
        prm[:,7] = 0.7
        prm[:,8] = 1.275
        
        # Print summary, potentials, and densities
        prnt = True
        prntorb = True

        # Run Fortran subroutine
        # NOTE(review): `ws` is not imported in the visible part of this
        # notebook; it must be made available before run_woodsaxon=True works.
        ws(ntau, A, Z, rmax, orbws, norb, lorb, jorb, prm, rdiv, prnt, prntorb,
           dens)
        
        
    def order_sp_states(self, Z, N):
        """Keep track of all s.p. states and occupied s.p. states"""

        self.sp_states = []
        self.occ_states = []
        
        # First column of ws_log: '1' labels protons and '2' labels neutrons.
        # Each maps to (isospin projection, number of nucleons to fill).
        species = {'1': (1/2, Z), '2': (-1/2, N)}
        filled_count = {'1': 0, '2': 0}
        
        # File with details of the orbitals
        ws_file = self.woods_saxon_directory + "ws_log"
    
        # Order single-particle states using the ws_log file
        with open(ws_file, 'r') as f:
            for line in f:
                unit = line.strip().split()
                
                # Skip blank lines and lines that do not describe an orbital
                if len(unit) == 0 or unit[0] not in species:
                    continue
                    
                label = unit[0]
                tau, nucleon_number = species[label]

                # The file stores n-1 and 2*j
                j = int(unit[3])/2
                for m_j in np.arange(-j, j+1, 1):
                    sp_state = SingleParticleState(
                        int(unit[1])+1, int(unit[2]), j, m_j, tau
                    )  # n, l, j, m_j, \tau
                    
                    self.sp_states.append(sp_state)
                    
                    # Fill states in file order until Z (or N) are occupied
                    if filled_count[label] < nucleon_number:
                        self.occ_states.append(sp_state)
                        filled_count[label] += 1
                        
                        
    def get_wf_rspace(self, sp_state, print_normalization=False):
        """Single-particle wave function in coordinate space.
        
        Returns the radial grid and the tabulated wave function of sp_state;
        optionally prints the normalization integral.
        """
        
        # Orbital file name is the key
        u_array = self.sp_wfs[get_orbital_file_name(sp_state)]

        # Normalization: \int dr |u(r)|^2 = 1
        if print_normalization:
            normalization = np.sum(self.dr*u_array**2)
            print(f"Coordinate space normalization = {normalization}.")

        return self.r_array, u_array
    
    
    def fourier_transformation(self, l, k_array):
        """Fourier transformation matrix for given orbital angular momentum.
        
        Returns a complex matrix of shape (len(k_array), len(r_array)) that
        maps u(r) on the radial grid to the momentum-space wave function.
        """
        
        # Both grids have shape (len(k_array), len(r_array)): r varies along
        # the columns and k varies along the rows
        r_cols, k_rows = np.meshgrid(self.r_array, k_array)
        
        # Transformation matrix including the radial integration weight dr
        M = 1j**(-l) * np.sqrt(2/np.pi) * self.dr * r_cols * spherical_jn(
            l, k_rows*r_cols
        )
        
        return M
    
    
    def get_wf_kspace(
            self, sp_state, kmax, kmid, ntot, print_normalization=False,
            interpolate=False,
            
    ):
        """Single-particle wave function in momentum space.
        
        Returns a linear interpolating function of k when interpolate=True;
        otherwise returns the momentum mesh, its weights, and the wave
        function on that mesh.
        """
    
        # Set momentum mesh with more points at low momentum
        k_array, k_weights = momentum_mesh(kmax, kmid, ntot)
    
        # Get coordinate-space s.p. wave function
        _, u_array = self.get_wf_rspace(sp_state)

        # Fourier-transform the wave function to momentum space
        phi_array = self.fourier_transformation(sp_state.l, k_array) @ u_array
    
        # Normalization: \int dk k^2 |\phi(k)|^2 = 1
        if print_normalization:
            normalization = np.sum(k_weights*k_array**2*abs(phi_array)**2)
            print(f"Momentum space normalization = {normalization}.")
            
        # Interpolate and return function?
        if interpolate:
            phi_func = interp1d(k_array, phi_array, kind='linear',
                                bounds_error=False, fill_value='extrapolate')
            return phi_func
        
        # Otherwise return momentum, weights, and \phi(k)
        else:
            return k_array, k_weights, phi_array

In [17]:
def get_orbital_file_name(sp_state):
    """
    Returns the file name of the orbital.

    The name has the form '<p|n>.n<n-1>.l<l>.j<2j>.orb', where the prefix is
    'p' for tau = 1/2 (proton) and 'n' for tau = -1/2 (neutron). It does not
    depend on m_j.

    Raises
    ------
    ValueError
        If sp_state.tau is neither 1/2 nor -1/2.

    """
        
    n, l, j = sp_state.n, sp_state.l, sp_state.j
    
    # Proton
    if sp_state.tau == 1/2:
        prefix = "p"
    # Neutron
    elif sp_state.tau == -1/2:
        prefix = "n"
    # Fail with a clear message instead of an UnboundLocalError further down
    else:
        raise ValueError(
            f"tau must be 1/2 or -1/2 but got {sp_state.tau!r}."
        )
        
    return f"{prefix}.n{int(n-1)}.l{l}.j{int(2*j)}.orb"

In [18]:
def get_sp_wave_functions(sp_basis, kmax, kmid, ntot):
    r"""
    Set interpolating functions for s.p. wave functions \phi.

    Parameters
    ----------
    sp_basis : WoodsSaxon
        Object providing `occ_states` and `get_wf_kspace`.
    kmax, kmid, ntot
        Forwarded unchanged to `sp_basis.get_wf_kspace` (momentum mesh
        settings).

    Returns
    -------
    phi_functions : dict
        Interpolating function of each occupied orbital, keyed by the orbital
        file name.

    """
    
    phi_functions = {}
    for sp_state in sp_basis.occ_states: 
        file_name = get_orbital_file_name(sp_state)
        
        # The file name does not depend on m_j, so all m_j substates of an
        # orbital share one entry; compute its Fourier transform only once
        if file_name in phi_functions:
            continue
            
        phi_functions[file_name] = sp_basis.get_wf_kspace(
            sp_state, kmax, kmid, ntot, interpolate=True)
            
    return phi_functions

In [19]:
def build_vector(k, theta, phi):
    r"""
    Convert spherical coordinates to Cartesian components.

    Parameters
    ----------
    k : float or ndarray
        Magnitude of the vector.
    theta : float or ndarray
        Polar angle of the vector in the range [0, \pi].
    phi : float or ndarray
        Azimuthal angle of the vector in the range [0, 2\pi].

    Returns
    -------
    k_vector : ndarray
        Cartesian components (k_x, k_y, k_z) stacked along the first axis:
        shape (3,) for scalar inputs and (3, N) for length-N array inputs.

    """

    # Projection of the vector onto the x-y plane
    k_transverse = k * np.sin(theta)

    k_x = k_transverse * np.cos(phi)
    k_y = k_transverse * np.sin(phi)
    k_z = k * np.cos(theta)

    return np.array([k_x, k_y, k_z])

In [20]:
def get_vector_components(k_vector):
    r"""
    Get the spherical coordinates from an input vector.

    Parameters
    ----------
    k_vector : ndarray
        Cartesian components stacked along the first axis, with shape (3,)
        for one vector or (3, N) for a batch of N vectors.

    Returns
    -------
    k : float or ndarray
        Magnitude of the vector.
    theta : float or ndarray
        Polar angle of the vector in the range [0, \pi]. For a zero vector
        the angle is undefined and theta = 0 is returned by convention.
    phi : float or ndarray
        Azimuthal angle of the vector in the range (-\pi, \pi], as returned
        by arctan2.

    """

    k_vector = np.asarray(k_vector)
    k = la.norm(k_vector, axis=0)

    # Avoid 0/0 for zero vectors: divide by 1 there and set cos(theta) = 1
    nonzero = k > 0
    safe_k = np.where(nonzero, k, 1.0)
    cos_theta = np.where(nonzero, k_vector[2] / safe_k, 1.0)

    # Round-off can push |cos(theta)| slightly above 1, which would give NaN
    theta = np.arccos(np.clip(cos_theta, -1.0, 1.0))
    phi = np.arctan2(k_vector[1], k_vector[0])

    return k, theta, phi

In [21]:
# O16 (Z = 8, N = 8); run_woodsaxon=False reads existing orbital files instead
# of running the Fortran code
woods_saxon = WoodsSaxon('O16', 8, 8, run_woodsaxon=False)
# Interpolated momentum-space wave functions of the occupied orbitals; the
# arguments after the basis are kmax, kmid, and ntot. psi and psi_for_vect
# read this module-level dictionary.
phi_functions = get_sp_wave_functions(woods_saxon, 10.0, 2.0, 120)

In [22]:
def psi(sp_state, q_vector, sigma):
    """Single-particle wave function including the Clebsch-Gordan coefficient 
    and spherical harmonic.
    
    Parameters
    ----------
    sp_state : SingleParticleState
        Single-particle state supplying l, j, and m_j.
    q_vector : ndarray
        Cartesian momentum vector, converted with get_vector_components.
    sigma : float
        Spin projection.
        
    Returns
    -------
    complex or int
        Product of the radial wave function, the Clebsch-Gordan coefficient,
        and the spherical harmonic, or 0 if m_l = m_j - sigma is not allowed.
        
    Notes
    -----
    Reads the module-level `phi_functions` and `cg_table`.
    """
    
    # m_l is determined by m_j and \sigma
    m_l = sp_state.m_j - sigma
        
    # Check that m_l is allowed before doing any further work, so forbidden
    # combinations do not pay for the interpolation
    if abs(m_l) > sp_state.l:
        return 0
        
    # Unpack q_vector into magnitude and angles
    q, theta, phi = get_vector_components(q_vector)
        
    # Calculate \phi_\alpha(q)
    phi_sp_wf = phi_functions[get_orbital_file_name(sp_state)](q)
        
    # Clebsch-Gordan coefficient
    cg = cg_table[(sp_state.l, m_l, 1/2, sigma, sp_state.j, sp_state.m_j)]
        
    # Spherical harmonic (azimuthal angle is passed before the polar angle)
    Y_lm = sph_harm(m_l, sp_state.l, phi, theta)

    return phi_sp_wf * cg * Y_lm

In [23]:
# Random batch of momentum vectors in spherical coordinates
# (this cell rebinds N_batch, theta_array, and phi_array from earlier cells)
N_batch = 100
k_array = np.random.random(N_batch) * 10  # 10 fm^-1 max
theta_array = np.random.random(N_batch) * np.pi  # \pi max
phi_array = np.random.random(N_batch) * 2*np.pi  # 2\pi max
# Cartesian components stacked along the first axis
k_vector_array = build_vector(k_array, theta_array, phi_array)

In [24]:
# Batch of single-particle states with n = 1
sp_states = []
for i in range(N_batch):
    # Protons if i is even, neutrons if i is odd
    tau = 1/2 if i % 2 == 0 else -1/2
    
    # p-waves (j = 3/2, m_j = 3/2) if i is divisible by 3
    if i % 3 == 0:
        j = 3/2
        m_j = 3/2
        l = 1
    # s-waves otherwise
    else:
        j = 1/2
        m_j = 1/2
        l = 0
    # NOTE: the original cell also had a d-wave branch for i divisible by 9,
    # but it could never run because such i are already caught by the
    # divisible-by-3 test above. It is dropped here, so the generated states
    # are unchanged.
    
    sp_states.append(SingleParticleState(1, l, j, m_j, tau))

In [25]:
# Columns are (spin, spin projection): spin is always 1/2 and the projection
# is +1/2 where the random draw exceeds 0.5 and -1/2 otherwise
random_numbers = np.random.random(N_batch)
sigma_array = np.full((N_batch, 2), 1/2)
sigma_array[:, 1] = np.where(random_numbers > 0.5, 1/2, -1/2)

In [26]:
def psis_by_looping(sp_states, k_vector_array, sigma_array):
    """
    Evaluate psi sample by sample with an explicit Python loop.

    Sample idx uses sp_states[idx], column idx of k_vector_array, and the
    spin projection sigma_array[idx, 1]. Returns a complex array with one
    value per state.
    """
    
    psi_values = np.zeros(len(sp_states), dtype='complex')
    for idx, sp_state in enumerate(sp_states):
        q_vector = k_vector_array[:, idx]
        sigma = sigma_array[idx, 1]
        psi_values[idx] = psi(sp_state, q_vector, sigma)
        
    return psi_values

In [27]:
def psi_for_vect(sp_state, k, theta, phi, spin, spin_projection):
    """Single-particle wave function including the Clebsch-Gordan coefficient 
    and spherical harmonic.
    
    Scalar-argument variant of psi intended for use with np.vectorize.
    
    Parameters
    ----------
    sp_state : SingleParticleState
        Single-particle state supplying l, j, and m_j.
    k : float
        Magnitude of the momentum.
    theta : float
        Polar angle of the momentum.
    phi : float
        Azimuthal angle of the momentum.
    spin : float
        Not used in the calculation (the spin is fixed to 1/2); kept so the
        call signature stays the same.
    spin_projection : float
        Spin projection.
        
    Returns
    -------
    complex or int
        Product of the radial wave function, the Clebsch-Gordan coefficient,
        and the spherical harmonic, or 0 if m_l = m_j - spin_projection is
        not allowed.
        
    Notes
    -----
    Reads the module-level `phi_functions` and `cg_table`.
    """
    
    # m_l is determined by m_j and \sigma
    m_l = sp_state.m_j - spin_projection
        
    # Check that m_l is allowed before doing any further work, so forbidden
    # combinations do not pay for the interpolation
    if abs(m_l) > sp_state.l:
        return 0

    # Calculate \phi_\alpha(q)
    phi_sp_wf = phi_functions[get_orbital_file_name(sp_state)](k)
        
    # Clebsch-Gordan coefficient
    cg = cg_table[(sp_state.l, m_l, 1/2, spin_projection, sp_state.j,
                   sp_state.m_j)]
        
    # Spherical harmonic (azimuthal angle is passed before the polar angle)
    Y_lm = sph_harm(m_l, sp_state.l, phi, theta)

    return phi_sp_wf * cg * Y_lm

In [28]:
# Broadcast psi_for_vect over array arguments (including the list of states).
# NumPy documents np.vectorize as a convenience wrapper that is essentially a
# for loop; otypes fixes the output dtype to complex.
psi_func_vect = np.vectorize(psi_for_vect, otypes=[complex])

In [31]:
# Compare the cost of looping over N points in (sp_state, q, theta, phi) space
# to a vectorized version of this function

# Looping
%time psi_array_1 = psis_by_looping(sp_states, k_vector_array, sigma_array)

# numpy.vectorize
# (sigma_array[:, 0] is passed as `spin` and sigma_array[:, 1] as
# `spin_projection`)
%time psi_array_2 = psi_func_vect(sp_states, k_array, theta_array, phi_array, sigma_array[:, 0], sigma_array[:, 1])

CPU times: user 11.1 ms, sys: 3.57 ms, total: 14.6 ms
Wall time: 12.3 ms
[-5.73599368e-08+3.08845712e-07j  0.00000000e+00+0.00000000e+00j
  6.32651522e-08+0.00000000e+00j  0.00000000e+00+0.00000000e+00j
  6.01138280e-08+0.00000000e+00j  0.00000000e+00+0.00000000e+00j
 -3.55922893e-02+3.49240289e-02j -8.97393304e-04+0.00000000e+00j
  4.21130663e-07+0.00000000e+00j -6.16547179e-08-9.47706005e-08j
  0.00000000e+00+0.00000000e+00j  2.32335801e-08+0.00000000e+00j
 -2.06641969e-07+1.23832263e-08j  1.59255682e-08+0.00000000e+00j
  2.30737952e-01+0.00000000e+00j -2.84408271e-01+1.50895640e-01j
 -7.36828450e-04+0.00000000e+00j  0.00000000e+00+0.00000000e+00j
  7.48024664e-05+5.28807279e-06j  0.00000000e+00+0.00000000e+00j
  3.98597511e-07+0.00000000e+00j  0.00000000e+00+0.00000000e+00j
  0.00000000e+00+0.00000000e+00j  8.09253480e-02+0.00000000e+00j
 -3.65497200e-07+2.16337139e-07j  4.87917484e-06+0.00000000e+00j
  0.00000000e+00+0.00000000e+00j  0.00000000e+00+0.00000000e+00j
  4.90823111e-01+

array([1, 1, 0])