# Testing integration with `vegas`

__Author:__ A. J. Tropiano [atropiano@anl.gov]<br/>
__Date:__ April 13, 2023

In this notebook, we test how to vectorize Python functions. We compare brute-force looping with `numpy.vectorize` and explore potential speed-ups from JAX or Numba.

_Last update:_ April 13, 2023

In [None]:
# Python imports
import numpy as np
from scipy.special import sph_harm
from sympy.physics.quantum.cg import CG

## Calculating arrays of Clebsch-Gordan coefficients

In [None]:
def compute_clebsch_gordan_table(j_max):
    """
    Tabulate Clebsch-Gordan coefficients <j_1 m_1 j_2 m_2|j_3 m_3> for all
    allowed combinations with j_1, j_2 <= j_max.
    
    Parameters
    ----------
    j_max : int or half-integer
        Largest j_1 and j_2 included. Both run over the grid
        0, 1/2, 1, ..., j_max. j_3 is NOT limited by j_max: it covers the
        full triangle range |j_1 - j_2|, ..., j_1 + j_2 in unit steps.
    
    Returns
    -------
    cg_table : dict
        Maps the key (j_1, m_1, j_2, m_2, j_3, m_3) to the coefficient as a
        Python float. m_1 and m_2 run from -j to j in unit steps and
        m_3 = m_1 + m_2. Combinations with |m_3| > j_3 are omitted.
        
    """

    def allowed_quantum_numbers():
        # Half-integer grid shared by j_1 and j_2.
        half_integers = np.arange(0, j_max + 1/2, 1/2)
        for j_a in half_integers:
            for j_b in half_integers:
                # Triangle rule: j_3 steps by 1 from |j_1 - j_2| to j_1 + j_2.
                for j_c in np.arange(abs(j_a - j_b), j_a + j_b + 1/2):
                    for m_a in np.arange(-j_a, j_a + 1, 1):
                        for m_b in np.arange(-j_b, j_b + 1, 1):
                            m_c = m_a + m_b
                            # The projection cannot exceed its total angular momentum.
                            if abs(m_c) > j_c:
                                continue
                            yield (j_a, m_a, j_b, m_b, j_c, m_c)

    # The key order matches CG's positional signature, so it can be splatted.
    return {
        key: float(CG(*key).doit())
        for key in allowed_quantum_numbers()
    }

In [None]:
# Precompute every coefficient with j_1, j_2 <= 3 once, so later cells can
# look values up instead of re-evaluating sympy's CG on every call.
cg_table = compute_clebsch_gordan_table(j_max=3)

In [None]:
# Batch of spin values
# ...

In [None]:
# Get array of CG's by brute force looping

In [None]:
# Get array of CG's using numpy.vectorize

In [None]:
# Get array of CG's using jax.numpy.vectorize

In [None]:
# Comparison of the three
# ...

## Calculating spherical harmonics given arrays of $l$, $m_l$, $\theta$, and $\phi$

In [None]:
# Brute force looping

In [None]:
# numpy.vectorize

In [None]:
# jax

In [None]:
# Comparison of the three
# ...

## Vectorizing $\psi_\alpha(\mathbf{q};\sigma)$

In [None]:
def psi(sp_state, q_vector, sigma, cg_table, phi_functions):
    """Single-particle wave function including the Clebsch-Gordan coefficient 
    and spherical harmonic.

    Evaluates phi_alpha(q) * <l m_l 1/2 sigma | j m_j> * Y_{l m_l}(theta, phi)
    at a single point.

    Parameters
    ----------
    sp_state : object
        Single-particle state. Its attributes ``l``, ``j``, and ``m_j`` are
        read here. The object is also passed to ``get_orbital_file_name`` to
        select the radial function.
    q_vector : array_like
        Momentum vector, unpacked by ``get_vector_components`` into
        ``(q, theta, phi)``.
    sigma : float
        Spin projection. The table lookup couples it to spin 1/2, so it is
        +1/2 or -1/2.
    cg_table : dict
        Clebsch-Gordan table keyed by ``(j_1, m_1, j_2, m_2, j_3, m_3)``,
        e.g. the output of ``compute_clebsch_gordan_table``.
    phi_functions : dict
        Maps ``get_orbital_file_name(sp_state)`` to a callable of ``q`` that
        returns the radial wave function phi_alpha(q).

    Returns
    -------
    output : complex
        Value of the wave function. A complex zero is returned when
        |m_l| > l, matching the complex value produced by the spherical
        harmonic in the allowed case.
    """

    # m_l is fixed by m_j and sigma (m_j = m_l + sigma). Check it before doing
    # any other work, because forbidden components vanish identically.
    m_l = sp_state.m_j - sigma
    if abs(m_l) > sp_state.l:
        # Return a complex zero, not the int 0. numpy.vectorize infers the
        # output dtype from the first call, and an int there would truncate
        # every later complex value.
        return 0j

    # Unpack q_vector into magnitude and angles.
    # NOTE(review): assumes theta is the polar angle and phi the azimuthal
    # angle -- TODO confirm against get_vector_components.
    q, theta, phi = get_vector_components(q_vector)

    # Calculate \phi_\alpha(q)
    phi_sp_wf = phi_functions[get_orbital_file_name(sp_state)](q)

    # Clebsch-Gordan coefficient <l m_l 1/2 sigma | j m_j>
    cg = cg_table[(sp_state.l, m_l, 1/2, sigma, sp_state.j, sp_state.m_j)]

    # Spherical harmonic. scipy.special.sph_harm's signature is
    # (m, n, azimuthal, polar), hence the (phi, theta) argument order.
    Y_lm = sph_harm(m_l, sp_state.l, phi, theta)

    return phi_sp_wf * cg * Y_lm

In [None]:
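# Compare the cost of looping over N points in (sp_state, q, theta, phi) space
# to a vectorized version of this function

# Note: cg_table and phi_functions are shared lookup tables, not per-point
# inputs, so they should not take part in broadcasting. Either make them
# attributes of a class or pass them via numpy.vectorize's `excluded` argument.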