# Deuteron electrodisintegration

_Last update:_ January 31, 2025

In [1]:
# Python imports
from functools import partial
from matplotlib.offsetbox import AnchoredText
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from sympy.physics.quantum.cg import CG

In [2]:
# JAX imports
from jax import config, jit
import jax.numpy as jnp
from jax.numpy.linalg import solve
from jax.scipy.special import sph_harm
from jaxinterp2d import interp2d

In [3]:
# Imports from scripts
from scripts.figures import set_rc_parameters
from scripts.tools import channel_L_value, coupled_channel

## Set-up

In [4]:
# Enable double precision: switch JAX to 64-bit floats/complex via the
# "jax_enable_x64" flag before any arrays are created below
config.update("jax_enable_x64", True)

In [5]:
# Run this cell to turn on customized matplotlib graphics (applies the
# settings defined by set_rc_parameters in scripts.figures)
set_rc_parameters()

## Classes

In [6]:
# TODO: Convert to JAX
class ClebschGordan:
    """Evaluator for Clebsch-Gordan coefficients.

    Currently a placeholder: the constructor takes no arguments and sets no
    attributes.
    """

    def __init__(self):
        # Nothing to initialize yet
        return None

In [7]:
# TODO: Convert to JAX
class FormFactors:
    """Form factors read from Sushant's data files.

    Currently a placeholder: the constructor takes no arguments and sets no
    attributes.
    """

    def __init__(self):
        # Nothing to initialize yet
        return None

In [8]:
class Potential:
    """Load partial-wave potentials from disk and expose them as JAX arrays.

    The constructor sets the following attributes:
        kvnn_str: label of the potential as it appears in file names.
        directory: directory containing the mesh and potential files.
        k_array, k_weights: momentum nodes and weights [fm^-1].
        kmax, kmid, ntot: momentum mesh specifications as passed in.
        uncoupled_potentials: stack of uncoupled-channel potentials [fm].
        coupled_potentials: stack of coupled-channel potentials [fm].

    The look-up methods are jitted with `self` as a static argument, so the
    instance itself is what JAX uses as the compilation cache key.
    """

    # h-bar^2 / M [MeV fm^2]
    hbar_sq_over_m = 41.47

    def __init__(self, kvnn, kmax, kmid, ntot, L_max=2):
        """Locate the data directory, then load the mesh and potentials.

        Args:
            kvnn: number labeling the potential.
            kmax: maximum momentum of the mesh [fm^-1].
            kmid: mid-point of the mesh [fm^-1].
            ntot: number of mesh points.
            L_max: largest orbital angular momentum L of channels to load.
        """

        # File names use a zero-padded label for single-digit kvnn
        kvnn_str = ('0' if kvnn < 10 else '') + str(kvnn)
        self.kvnn_str = kvnn_str

        # Directory names contain kmax and kmid truncated to integers
        kmax_int, kmid_int = int(kmax), int(kmid)
        self.directory = (
            f'../data/potentials/vsrg_kvnn_{kvnn_str}_lam12.0_kmax{kmax_int:d}'
            f'_kmid{kmid_int:d}_ntot{ntot:d}/'
        )

        # Momentum nodes and weights [fm^-1]
        self.k_array, self.k_weights = self.get_momentum_mesh()

        # Keep the mesh specifications around (ntot is needed for reshaping)
        self.kmax, self.kmid, self.ntot = kmax, kmid, ntot

        # One stacked JAX array for uncoupled channels and one for coupled
        # channels, both restricted to L <= L_max
        self.uncoupled_potentials, self.coupled_potentials = (
            self.get_potentials(L_max)
        )

    def get_momentum_mesh(self):
        """Read the momentum mesh file.

        Returns:
            Tuple (k_array, k_weights) of JAX arrays taken from the first and
            second columns of the mesh file, in units [fm^-1].
        """

        filename = f'vsrg_1S0_kvnn_{self.kvnn_str}_lam12.0_reg_0_3_0_mesh.out'
        mesh = np.loadtxt(self.directory + filename)

        return jnp.asarray(mesh[:, 0]), jnp.asarray(mesh[:, 1])

    def get_potentials(self, L_max):
        """Load every channel with L <= L_max.

        Returns:
            Tuple (uncoupled_potentials, coupled_potentials) of JAX arrays,
            each stacked along a leading channel axis in the order the
            channels are listed below.
        """

        # Partial wave channels up to L = 5; the order fixes the indices used
        # by get_uncoupled_potential and get_coupled_potential
        channels = (
            '1S0', '3S1', '3P0', '1P1', '3P1', '3P2', '1D2', '3D2', '3D3',
            '1F3', '3F3', '3F4', '1G4', '3G4', '3G5', '1H5', '3H5', '3H6'
        )

        uncoupled, coupled = [], []
        for channel in channels:

            # Skip channels above L_max
            within_L_max = channel_L_value(channel) <= L_max
            if not within_L_max:
                continue

            # Potential in units [fm]
            V_matrix = self.load_potential(channel)

            # File the matrix under the matching category
            target = coupled if coupled_channel(channel) else uncoupled
            target.append(V_matrix)

        # Stack into arrays with a leading channel axis
        return jnp.asarray(uncoupled), jnp.asarray(coupled)

    def load_potential(self, channel):
        """Read one partial wave channel from file.

        Returns:
            JAX array in units [fm] with shape (ntot, ntot) for an uncoupled
            channel or (2*ntot, 2*ntot) for a coupled channel.
        """

        filename = f'vnn_{channel}_kvnn_{self.kvnn_str}_lam12.0_reg_0_3_0.out'
        data = np.loadtxt(self.directory + filename)
        block_shape = (self.ntot, self.ntot)

        # Uncoupled channels store a single matrix in column 2
        if not coupled_channel(channel):
            return jnp.reshape(data[:, 2], block_shape)

        # Coupled channels store the four sub-blocks in columns 2 through 5
        V11, V12, V21, V22 = (
            jnp.reshape(data[:, column], block_shape)
            for column in range(2, 6)
        )

        return jnp.block([[V11, V12], [V21, V22]])

    @partial(jit, static_argnums=(0,))
    def coupled_channel_mapping(self, J):
        """Index of the coupled channel with total angular momentum J in
        the stack of coupled-channel potentials.
        """

        # Coupled channels start at J = 1 and are stored in increasing J
        return J - 1

    @partial(jit, static_argnums=(0,))
    def get_uncoupled_potential(self, L, S, J):
        """Look up the uncoupled-channel potential with quantum numbers
        L, S, and J. Units are [fm] and the shape is (ntot, ntot).
        """

        # 3P0 (L = 1, S = 1, J = 0) is the one channel where L + S + J does
        # not give the position in the stack
        is_3P0 = jnp.logical_and(jnp.logical_and(L == 1, S == 1), J == 0)
        position = jnp.where(is_3P0, 1, L + S + J)

        return self.uncoupled_potentials[position]

    @partial(jit, static_argnums=(0,))
    def get_coupled_potential(self, J):
        """Look up the coupled-channel potential with total angular momentum
        J. Units are [fm] and the shape is (2*ntot, 2*ntot).
        """

        position = J - 1

        return self.coupled_potentials[position]

    @partial(jit, static_argnums=(0,))
    def is_coupled(self, L, J):
        """Boolean value that is true when J > 0 and J differs from L, which
        is the criterion used here for a coupled channel.
        """

        nonzero_J = J > 0
        J_differs_from_L = J != L

        return jnp.logical_and(nonzero_J, J_differs_from_L)

In [9]:
class TMatrix(Potential):
    """Class that computes the half off-shell T-matrix.

    The T-matrix is obtained on the momentum mesh extended by one extra
    point p', by solving the linear system F T = V with
    F_ij = delta_ij + D_j V_ij (see `D_vector`).

    All public methods are jitted with `self` as a static argument.
    """

    def __init__(self, kvnn, kmax, kmid, ntot, L_max=2):
        """Load the potentials up to L_max and store the momentum cutoff.

        Args:
            kvnn: number labeling the potential.
            kmax: maximum momentum of the mesh [fm^-1].
            kmid: mid-point of the mesh [fm^-1].
            ntot: number of mesh points.
            L_max: largest orbital angular momentum L of channels to load.
        """

        # Initializes potentials up to L_max
        super().__init__(kvnn, kmax, kmid, ntot, L_max)

        # Maximum momentum [fm^-1]
        self.lamb = kmax

    @partial(jit, static_argnums=(0,))
    def compute(self, pp, L, Lp, S, J):
        """Compute the half off-shell T-matrix for given momentum p' with
        energy E' = p'^2 / M for a particular partial wave channel.

        Args:
            pp: momentum p' [fm^-1]; see `D_vector` for the values it may
                take.
            L, Lp: orbital angular momenta selecting the sub-block; Lp is
                only used for coupled channels.
            S: spin; only used for uncoupled channels.
            J: total angular momentum.

        Returns:
            Half off-shell T-matrix in units [fm] with shape
            (ntot+1, ntot+1).
        """

        # Condition on whether the channel is coupled or not
        cond = self.is_coupled(L, J)

        # NOTE(review): jnp.where evaluates both arguments, so the coupled and
        # the uncoupled linear systems are both solved on every call and the
        # unused result is discarded. jax.lax.cond would avoid the extra
        # solve, provided both branches return identical shapes and dtypes.
        T_matrix = jnp.where(
            cond, self.compute_coupled(pp, L, Lp, J),
            self.compute_uncoupled(pp, L, S, J)
        )

        # Half off-shell T-matrix in units [fm] with shape (ntot+1, ntot+1)
        return T_matrix

    @partial(jit, static_argnums=(0,))
    def compute_uncoupled(self, pp, L, S, J):
        """Compute the half off-shell T-matrix in an uncoupled channel.

        Returns:
            T-matrix in units [fm] with shape (ntot+1, ntot+1), where the
            last row and column correspond to the momentum p'.
        """

        # Evaluate D-vector for solving matrix inversion problem
        D_vector = self.D_vector(pp)

        # Meshes including the p' point (ntot+1, ntot+1)
        k_grid, kp_grid = self._momentum_grids(pp)

        # Load potential in units [fm] with shape (ntot, ntot)
        V_matrix = self.get_uncoupled_potential(L, S, J)

        # Append p' points by linear interpolation (ntot+1, ntot+1)
        V_interpolated = self._interpolate_block(k_grid, kp_grid, V_matrix)

        # Build F-matrix [unitless] where F_ij = \delta_ij + D_j V_ij; the
        # D-vector is broadcast along the rows so that column j is scaled
        # by D_j
        F_matrix = (jnp.identity(self.ntot + 1)
                    + D_vector[jnp.newaxis, :] * V_interpolated)

        # Solve F T = V for the T-matrix
        T_matrix = solve(F_matrix, V_interpolated)

        # Units are [fm] and shape is (ntot+1, ntot+1)
        return T_matrix

    @partial(jit, static_argnums=(0,))
    def compute_coupled(self, pp, L, Lp, J):
        """Compute the half off-shell T-matrix in a coupled channel.

        Returns:
            The (L, L') sub-block of the coupled-channel T-matrix in units
            [fm] with shape (ntot+1, ntot+1).
        """

        # Evaluate D-vector for solving matrix inversion problem
        D_vector = self.D_vector(pp)

        # Meshes including the p' point (ntot+1, ntot+1)
        k_grid, kp_grid = self._momentum_grids(pp)

        # Load potential in units [fm] with shape (2*ntot, 2*ntot)
        V_matrix = self.get_coupled_potential(J)

        # Append p' points by linear interpolation (2*ntot+2, 2*ntot+2)
        V_interpolated = self.interpolate_coupled_potential(k_grid, kp_grid,
                                                            V_matrix)

        # Build F-matrix [unitless] where F_ij = \delta_ij + D_j V_ij; the
        # D-vector is repeated once per sub-block column and broadcast along
        # the rows
        F_matrix = (
            jnp.identity(2 * (self.ntot + 1))
            + jnp.tile(D_vector, 2)[jnp.newaxis, :] * V_interpolated
        )

        # Solve F T = V for the T-matrix
        T_matrix = solve(F_matrix, V_interpolated)

        # Select particular sub-block of coupled-channel matrix
        T_subblock = self.select_subblock(L, Lp, J, T_matrix)

        # Units are [fm] and shape is (ntot+1, ntot+1)
        return T_subblock

    @partial(jit, static_argnums=(0,))
    def D_vector(self, pp):
        r"""Compute D-vector used for solving matrix inversion problem:
            T = F^-1 V
        where F_ij = \delta_ij + D_j V_ij.

        The expressions below divide by (k_i^2 - p'^2) and by p', and take
        the logarithm of (lamb + p') / (lamb - p'). The result is therefore
        only finite for 0 < p' < lamb with p' not equal to any mesh point.

        Returns:
            Complex array of shape (ntot+1,) in units [fm^-1]; the last
            element belongs to the appended momentum p'.
        """

        # Shared denominator k_i^2 - p'^2 [fm^-2]
        denominator = self.k_array ** 2 - pp ** 2

        # First ntot elements of D_vector [fm^-1]
        D_vector = (2.0 / jnp.pi * (self.k_weights * self.k_array ** 2)
                    / denominator)

        # ntot + 1 element of D_vector [fm^-1]
        D_last = (
            -2.0 / jnp.pi * pp ** 2 * (
                jnp.sum(self.k_weights / denominator)
                + jnp.log((self.lamb + pp) / (self.lamb - pp)) / (2.0 * pp))
        ) + 1j * pp

        # Append ntot + 1 element to D_vector
        return jnp.append(D_vector, D_last)

    @partial(jit, static_argnums=(0,))
    def interpolate_coupled_potential(self, k_grid, kp_grid, V_matrix):
        """Interpolate a coupled-channel potential onto the mesh extended by
        the momentum p'.

        Args:
            k_grid, kp_grid: meshes of the extended momenta with shape
                (ntot+1, ntot+1).
            V_matrix: coupled-channel potential with shape (2*ntot, 2*ntot).

        Returns:
            Interpolated potential with shape (2*ntot+2, 2*ntot+2).
        """

        # Interpolate each (ntot, ntot) sub-block separately, giving blocks
        # of shape (ntot+1, ntot+1)
        V11, V12, V21, V22 = (
            self._interpolate_block(k_grid, kp_grid, block)
            for block in self.get_subblocks(V_matrix)
        )

        # Recombine sub-blocks for coupled-channel matrix
        V_interpolated = jnp.vstack((
            jnp.hstack((V11, V12)),
            jnp.hstack((V21, V22))
        ))

        # Shape is (2*ntot+2, 2*ntot+2)
        return V_interpolated

    @partial(jit, static_argnums=(0,))
    def select_subblock(self, L, Lp, J, T_matrix):
        """Select a particular sub-block (L and L') of a coupled-channel T-
        matrix.

        The first matching condition wins. If no condition matches,
        jnp.select falls back to its default and an array of zeros is
        returned.

        Returns:
            Sub-block in units [fm] with shape (ntot+1, ntot+1).
        """

        # Size of one sub-block
        n = self.ntot + 1

        # List of conditions for each sub-block
        condlist = [
            jnp.logical_and(L == Lp, J > L),   # 0-0 sub-block
            jnp.logical_and(L != Lp, J > L),   # 0-2 sub-block
            jnp.logical_and(L != Lp, J > Lp),  # 2-0 sub-block
            jnp.logical_and(L == Lp, J < L)    # 2-2 sub-block
        ]

        # List of the T-matrix in each sub-block
        choicelist = [
            T_matrix[:n, :n],
            T_matrix[:n, n:],
            T_matrix[n:, :n],
            T_matrix[n:, n:]
        ]

        # Get particular sub-block with shape (ntot+1, ntot+1)
        T_subblock = jnp.select(condlist, choicelist)

        # Units remain [fm]
        return T_subblock

    @partial(jit, static_argnums=(0,))
    def get_subblocks(self, V_matrix):
        """Gets each sub-block (L and L') from a coupled-channel potential.

        Returns:
            Tuple (V11, V12, V21, V22) obtained by splitting V_matrix at
            row and column index ntot.
        """

        n = self.ntot

        V11 = V_matrix[:n, :n]
        V12 = V_matrix[:n, n:]
        V21 = V_matrix[n:, :n]
        V22 = V_matrix[n:, n:]

        return V11, V12, V21, V22

    def _momentum_grids(self, pp):
        """Return the (k, k') meshes of the momentum mesh with p' appended,
        each with shape (ntot+1, ntot+1).
        """

        # Append p' to end of mesh (ntot+1,)
        k_full = jnp.append(self.k_array, pp)

        return jnp.meshgrid(k_full, k_full, indexing='ij')

    def _interpolate_block(self, k_grid, kp_grid, V_block):
        """Interpolate one (ntot, ntot) block, tabulated on the original
        momentum mesh, at the points of k_grid and kp_grid.
        """

        return interp2d(k_grid, kp_grid, self.k_array, self.k_array, V_block)

In [10]:
# TODO: Convert to JAX
class DeuteronElectrodisintegration:
    """Calculator for the longitudinal structure function of deuteron
    electrodisintegration.

    Currently a placeholder: the constructor takes no arguments and sets no
    attributes.
    """

    def __init__(self):
        # Nothing to initialize yet
        return None

## Functions

In [11]:
# Potential label and momentum mesh specifications (kmax and kmid in fm^-1)
kvnn = 6
kmax, kmid = 30.0, 4.0
ntot = 120

# Load the potentials up to the default L_max
potential = Potential(kvnn, kmax, kmid, ntot)

In [12]:
# Shapes of the stacked uncoupled- and coupled-channel potentials, one per line
print(
    potential.uncoupled_potentials.shape,
    potential.coupled_potentials.shape,
    sep='\n',
)

(6, 120, 120)
(3, 240, 240)


In [13]:
# 1S0 channel (uncoupled): L = S = J = 0
L = S = J = 0
V_1S0 = potential.get_uncoupled_potential(L, S, J)
print(V_1S0.shape)

(120, 120)


In [14]:
# 3S1 channel (coupled): the look-up only needs J
L = 0
S = J = 1
V_3S1 = potential.get_coupled_potential(J)
print(V_3S1.shape)

(240, 240)


In [15]:
# T-matrix solver using the same potential and momentum mesh as above
tmatrix = TMatrix(kvnn, kmax, kmid, ntot)
# Momentum p' passed to tmatrix.compute in the cells below
pp = 0.5

In [17]:
# 1S0 channel (uncoupled): L = L' = S = J = 0
L = Lp = S = J = 0
T_1S0 = tmatrix.compute(pp, L, Lp, S, J)
print(T_1S0.shape)

(121, 121)


In [18]:
# 3S1-3D1 sub-block (L = 0, L' = 2) of the J = 1 coupled channel
L, Lp = 0, 2
S = J = 1
T_3S1_3D1 = tmatrix.compute(pp, L, Lp, S, J)
print(T_3S1_3D1.shape)

(121, 121)
