In [96]:
def get_marr_wigvals(nl1, nl2):
    enn1, ell1 = nl1
    enn2, ell2 = nl2
    ell = min(ell1, ell2)
    m = np.arange(-ell, ell+1)
    s_arr = np.array([1, 3, 5])

    wigvals = np.zeros((2*ell+1, len(s_arr)))
    for i in range(len(s_arr)):
        wigvals[:, i] = qdcls.w3j_vecm(ell1, s_arr[i], ell2, -m, 0*m, m)
    return m, wigvals

def test1(n, l, eig_dict):
    arg_str = f"{n}.{l}"
    U1 = eig_dict['eigU'][arg_str]*2.0
    return U1   

def create_submatrix(nl1, nl2, rws, omegaref, mwigs, eig_dict, use_precomputed=False):
    enn1, ell1 = nl1
    enn2, ell2 = nl2
    r, wsr = rws
    m, wigvals = mwigs

    eigU = eig_dict['eigU']
    eigV = eig_dict['eigV']
    s_arr = jnp.array([1, 3, 5])


    Tsr = jnp.zeros((len(s_arr), len(r)))
    arg_str1 = f"{enn1}.{ell1}"
    arg_str2 = f"{enn2}.{ell2}"
    if use_precomputed:
        arg_str1 = f"{enn1}.{ell1}"
        arg_str2 = f"{enn2}.{ell2}"
        U1 = eigU[arg_str1]
        U2 = eigU[arg_str2]
        V1 = eigV[arg_str1]
        V2 = eigV[arg_str2]
    else:
        U1, V1 = eig_dict['eigU'][arg_str1], eig_dict['eigV'][arg_str1]
        U2, V2 = eig_dict['eigU'][arg_str2], eig_dict['eigV'][arg_str2]
    L1sq = ell1*(ell1+1)
    L2sq = ell2*(ell2+1)
    Om1 = qdcls.Omega(ell1, 0)
    Om2 = qdcls.Omega(ell2, 0)

    for i in range(len(s_arr)):
        s = s_arr[i]
        ls2fac = L1sq + L2sq - s*(s+1)
        eigfac = U2*V1 + V2*U1 - U1*U2 - 0.5*V1*V2*ls2fac
        wigval = qdcls.w3j(ell1, s, ell2, -1, 0, 1)
        Tsr[i, :] = -(1 - qdcls.minus1pow(ell1 + ell2 + s)) * \
            Om1 * Om2 * wigval * eigfac / r
        LOGGER.debug(" -- s = {}, eigmax = {}, wigval = {}, Tsrmax = {}"\
                        .format(s, abs(eigfac).max(), wigval, abs(Tsr[i, :]).max()))

    # -1 factor from definition of toroidal field
    integrand = Tsr * wsr   # since U and V are scaled by sqrt(rho) * r
    integral = simps(integrand, axis=1, x=r)
    prod_gammas = qdcls.gamma(ell1) * qdcls.gamma(ell2) * qdcls.gamma(s_arr)
    omegaref = omegaref
    Cvec = qdcls.minus1pow_vec(m) * 8*np.pi * omegaref * (wigvals @ (prod_gammas * integral))
    return Cvec


def create_submatrix2(nl1, nl2, r, wsr, U1, V1, omegaref, m, wigvals, use_precomputed=False):
    enn1, ell1 = nl1
    enn2, ell2 = nl2

    s_arr = jnp.array([1, 3, 5])
    
    U2, V2 = U1, V1

    Tsr = jnp.zeros((len(s_arr), len(r)))
    arg_str1 = f"{enn1}.{ell1}"
    arg_str2 = f"{enn2}.{ell2}"
    L1sq = ell1*(ell1+1)
    L2sq = ell2*(ell2+1)
    Om1 = qdcls.Omega(ell1, 0)
    Om2 = qdcls.Omega(ell2, 0)

    for i in range(len(s_arr)):
        s = s_arr[i]
        ls2fac = L1sq + L2sq - s*(s+1)
        eigfac = U2*V1 + V2*U1 - U1*U2 - 0.5*V1*V2*ls2fac
        wigval = qdcls.w3j(ell1, s, ell2, -1, 0, 1)
        Tsr[i, :] = -(1 - qdcls.minus1pow(ell1 + ell2 + s)) * \
            Om1 * Om2 * wigval * eigfac / r
        LOGGER.debug(" -- s = {}, eigmax = {}, wigval = {}, Tsrmax = {}"\
                        .format(s, abs(eigfac).max(), wigval, abs(Tsr[i, :]).max()))

    # -1 factor from definition of toroidal field
    integrand = Tsr * wsr   # since U and V are scaled by sqrt(rho) * r
    integral = simps(integrand, axis=1, x=r)
    prod_gammas = qdcls.gamma(ell1) * qdcls.gamma(ell2) * qdcls.gamma(s_arr)
    omegaref = omegaref
    Cvec = qdcls.minus1pow_vec(m) * 8*np.pi * omegaref * (wigvals @ (prod_gammas * integral))
    return Cvec


def load_eig_dict(nl_list, nl_idxs):
    eig_dict = {}
    eig_dict['eigU'] = {}
    eig_dict['eigV'] = {}
    fname_prefix = "/scratch/g.samarth/get-solar-eigs/efs_Jesper/snrnmais_files/eig_files"
    for i, idx in enumerate(nl_idxs):
        nl = nl_list[i]
        n, l = nl[0], nl[1]
        arg_str = f"{n}.{l}"
        fname_suffix = f"{idx}.dat"
        eig_dict['eigU'][arg_str] = np.loadtxt(f"{fname_prefix}/U{fname_suffix}")[rmin_idx:rmax_idx]
        eig_dict['eigV'][arg_str] = np.loadtxt(f"{fname_prefix}/V{fname_suffix}")[rmin_idx:rmax_idx]
    return eig_dict

test1 = jit(test1, static_argnums=(2,))
create_submatrix = jit(create_submatrix, static_argnums=())
create_submatrix2 = jit(create_submatrix2, static_argnums=())
        
        

In [61]:
import jax
from jax import jit
import jax.numpy as jnp
import numpy as np
from scipy.integrate import simps

In [2]:
cd /home/g.samarth/qdPy

/home/g.samarth/qdPy


In [4]:
run qdpt.py --n0 5 --l0 50

[Rank: 0] Creating submatrices: 
Time taken =    3.13 seconds


In [98]:
nl_idx = analysis_modes.nl_idx
omega0 = analysis_modes.omega0
nl_list = analysis_modes.nl_neighbors
nl_idxs = analysis_modes.nl_neighbors_idx
rmin_idx = super_matrix.gvar.rmin_idx
rmax_idx = super_matrix.gvar.rmax_idx
omegaref = analysis_modes.omega0
r = analysis_modes.gvar.r
wsr = spline_dict.wsr


nl1 = (5, 50)
nl2 = (5, 50)
eig_dict = load_eig_dict(nl_list, nl_idxs)

mwigs = get_marr_wigvals(nl1, nl2)
#e1 = test1(5, 50, eig_dict)
U1, V1 = eig_dict['eigU']['5.50'], eig_dict['eigV']['5.50']
Cvec = create_submatrix2(nl1, nl2, r, wsr, omegaref, mwigs[0], mwigs[1], U1, V1)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function create_submatrix2 at <ipython-input-96-6a28b7daa510>:66, transformed by jit., this concrete value was not available in Python because it depends on the value of the arguments to create_submatrix2 at <ipython-input-96-6a28b7daa510>:66, transformed by jit. at flattened positions [1], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
 (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError)

In [86]:
e1

{'5.46': DeviceArray([0.       , 0.       , 0.       , ..., 3.521161 , 3.4828827,
              3.44371  ], dtype=float32),
 '5.48': DeviceArray([0.       , 0.       , 0.       , ..., 3.7098951, 3.6697822,
              3.6287205], dtype=float32),
 '5.50': DeviceArray([0.       , 0.       , 0.       , ..., 3.9018595, 3.8598988,
              3.8169336], dtype=float32),
 '5.52': DeviceArray([0.       , 0.       , 0.       , ..., 4.0971284, 4.053307 ,
              4.008423 ], dtype=float32),
 '5.54': DeviceArray([0.       , 0.       , 0.       , ..., 4.295543 , 4.249849 ,
              4.2030334], dtype=float32),
 '6.46': DeviceArray([0.      , 0.      , 0.      , ..., 4.663378, 4.614552,
              4.564485], dtype=float32)}