In [1]:
import numpy as np
import basix
import matplotlib.pyplot as plt
from basix import CellType, ElementFamily, LagrangeVariant, LatticeType
import jax_pn
from jax.experimental import sparse
import jax.numpy as jnp
import jax
%load_ext autoreload
%autoreload 2
import jax
jax.config.update("jax_enable_x64", True)

In [2]:
lagrange    = basix.create_element(ElementFamily.P, CellType.interval, degree= 5, lagrange_variant= LagrangeVariant.gll_warped)
N_max = 3
elements_per_cm = 10
regions = [
    (2.0, [20.0], np.array([[[0.0]]]), [20.0]),
    (1.0, [1.0],  np.array([[[0.0]]]),  [0.0]),
    (2.0, [0.0],  np.array([[[0.0]]]),  [0.0]),
    (1.0, [1.0],  np.array([[[0.9]]]), [1.0]),
    (2.0, [1.0],  np.array([[[0.9]]]), [0.0]),
]


dpn_prob = jax_pn.DPN.DPN_Problem.from_regions_per_cm(regions, elements_per_cm, N_max, lagrange, L_scat=0)
pn_prob  = jax_pn.PN.PN_Problem.from_regions_per_cm(regions, elements_per_cm, N_max, lagrange, L_scat=0)
adpn_prob = jax_pn.ADPN.ADPN_Problem.from_regions_per_cm(regions, elements_per_cm, N_max, lagrange, L_scat=0)

sol =adpn_prob.Solve_Multigroup_System('vacuum',1)

Solving system with shape: (1608, 1608) and 1608 equations.


In [3]:
from jax_pn.ADPN import residualPN_jit, GlobalSettings
parameters_eg = {
    'sigma_t_i'       : adpn_prob.jax_sigma_t[:, 0],
    'sigma_s_k_i_gg'  : adpn_prob.jax_sigma_s[:, :, 0, 0],
    'h_i'             : adpn_prob.jax_h_i,
    'q_i_k_j'         : adpn_prob.jax_q_i_k_j[:, :, :, 0]
}

solution = jnp.array(adpn_prob.solution) + 1.2

def residual_naive(global_settings : GlobalSettings, matrix_settings, parameters, solution):
    data, rows, cols, bdata, brows = jax_pn.ADPN.total_matrix_assembly_vacuum_bcs_single_g_jit(global_settings, matrix_settings, parameters)
    total_dofs = global_settings.n_dofs_per_eg    
    
    indices = jnp.stack([rows, cols], axis=1)
    A_total = jax.experimental.sparse.BCOO((data, indices), shape=(total_dofs, total_dofs))
    b_cols = jnp.zeros_like(brows)
    b_indices = jnp.stack([brows, b_cols], axis=1)    
    b_total = jax.experimental.sparse.BCOO((bdata, b_indices), shape=(total_dofs, 1))        

    return (A_total @ solution - b_total.todense()[:,0])

np.set_printoptions(precision=12)
res_old = residualPN_jit(adpn_prob.global_settings, adpn_prob.matrix_settings, parameters_eg, solution)
res_new = residual_naive(adpn_prob.global_settings, adpn_prob.matrix_settings, parameters_eg, solution)

print("Difference: ", np.max(np.abs(res_old - res_new)))

Difference:  6.071532165918825e-16


In [4]:
#%timeit residualPN_jit(adpn_prob.global_settings, adpn_prob.matrix_settings, parameters_eg, solution).block_until_ready()
#%timeit residual_naive(adpn_prob.global_settings, adpn_prob.matrix_settings, parameters_eg, solution).block_until_ready()

In [5]:
jac_A= jax.jit(jax.jacrev(residualPN_jit, argnums = (3,)), static_argnums = 0)

In [6]:
def A_total_f(global_settings : GlobalSettings, matrix_settings, parameters, solution):
    data, rows, cols, bdata, brows = jax_pn.ADPN.total_matrix_assembly_vacuum_bcs_single_g_jit(global_settings, matrix_settings, parameters)
    total_dofs = global_settings.n_dofs_per_eg
        
    indices = jnp.stack([rows, cols], axis=1)
    A_total = jax.experimental.sparse.BCOO((data, indices), shape=(total_dofs, total_dofs))
    return A_total

a_normal = A_total_f(adpn_prob.global_settings, adpn_prob.matrix_settings, parameters_eg, solution).todense()
Ares = jac_A(adpn_prob.global_settings, adpn_prob.matrix_settings, parameters_eg, solution)[0]
print(np.max(np.abs(a_normal - Ares)))

5.551115123125783e-17


In [7]:
def edge_flux(sol, settings : GlobalSettings):
   right_dof = settings.right_dof
   return sol[right_dof]

In [8]:
dfdx = jax.grad(edge_flux, argnums = 0)(solution, adpn_prob.global_settings)

at_jax = A_total_f(adpn_prob.global_settings, adpn_prob.matrix_settings, parameters_eg, solution).T

data = np.array(at_jax.data)
indices = np.array(at_jax.indices)
shape = at_jax.shape
import scipy 
import scipy.sparse as sp
at_scipy = scipy.sparse.coo_matrix((data, indices.T), shape=shape).tocsr()
adjoint_var = sp.linalg.spsolve(at_scipy, np.array(dfdx))
print(adjoint_var)

[ 1.570602919567e-03 -1.956439323124e-03  2.421829364644e-03 ...
 -1.281609255364e+00 -6.155900403630e-07  8.112819637707e-01]


In [9]:
drdtheta = jax.jacfwd(residualPN_jit, argnums = 2)

In [10]:
drdtheta_res = drdtheta(adpn_prob.global_settings, adpn_prob.matrix_settings, parameters_eg, solution)
print(adjoint_var, - np.array(drdtheta_res['h_i']))
dfdh = np.einsum("i, ik->k", adjoint_var, - np.array(drdtheta_res['h_i']))
print(dfdh)

[ 1.570602919567e-03 -1.956439323124e-03  2.421829364644e-03 ...
 -1.281609255364e+00 -6.155900403630e-07  8.112819637707e-01] [[-0.483012503807 -0.             -0.             ... -0.
  -0.             -0.            ]
 [-0.786075789506 -0.787555543157 -0.             ... -0.
  -0.             -0.            ]
 [-0.             -0.798979384697 -0.799090016017 ... -0.
  -0.             -0.            ]
 ...
 [-0.             -0.             -0.             ... -0.
  -0.             -0.            ]
 [-0.             -0.             -0.             ... -0.
  -0.             -0.            ]
 [-0.             -0.             -0.             ... -0.
  -0.             -0.            ]]
[-1.754317952031e-05  9.488726076914e-06 -6.523683253780e-06
  5.471572523603e-06 -3.097568769511e-06  1.137929687978e-06
  3.871952596515e-06 -1.106850440542e-05  2.548548665109e-05
 -4.896553160071e-05  9.110166717825e-05 -1.613760903858e-04
  2.819200860513e-04 -4.849885969828e-04  8.109899263765e-04
 -1.

In [20]:
drdtheta_res = drdtheta(adpn_prob.global_settings, adpn_prob.matrix_settings, parameters_eg, solution)
print(adjoint_var, - np.array(drdtheta_res['h_i']))
dfdh = np.einsum("i, ik->k", adjoint_var, - np.array(drdtheta_res['h_i']))
print(dfdh)

[ 1.570602919567e-03 -1.956439323124e-03  2.421829364644e-03 ...
 -1.281609255364e+00 -6.155900403630e-07  8.112819637707e-01] [[-0.483012503807 -0.             -0.             ... -0.
  -0.             -0.            ]
 [-0.786075789506 -0.787555543157 -0.             ... -0.
  -0.             -0.            ]
 [-0.             -0.798979384697 -0.799090016017 ... -0.
  -0.             -0.            ]
 ...
 [-0.             -0.             -0.             ... -0.
  -0.             -0.            ]
 [-0.             -0.             -0.             ... -0.
  -0.             -0.            ]
 [-0.             -0.             -0.             ... -0.
  -0.             -0.            ]]
[-1.754317952031e-05  9.488726076914e-06 -6.523683253780e-06
  5.471572523603e-06 -3.097568769511e-06  1.137929687978e-06
  3.871952596515e-06 -1.106850440542e-05  2.548548665109e-05
 -4.896553160071e-05  9.110166717825e-05 -1.613760903858e-04
  2.819200860513e-04 -4.849885969828e-04  8.109899263765e-04
 -1.

In [12]:
raise ValueError("Stop here, the rest of the code is not needed for the adjoint test.")

ValueError: Stop here, the rest of the code is not needed for the adjoint test.

In [21]:
sigma_s_ki_gg = np.array(drdtheta_res['sigma_s_k_i_gg'])
print(sigma_s_ki_gg.shape)
dfdsigma_s = np.einsum("i, ikl->kl", adjoint_var, -sigma_s_ki_gg)


(1608, 80, 4)


In [22]:
dfdh_edge = dfdh[-1]
print(dfdh_edge)

print(dfdh[-60])

2.2483160564849967
-0.2016727164186521


In [18]:
def finite_difference_h(index, stepsize = 1e-4):
    h_i_default = np.array(adpn_prob.jax_h_i)

    def compute_with_param(parameters_eg):
        a = jax_pn.ADPN.total_matrix_assembly_vacuum_bcs_single_g_jit(adpn_prob.global_settings, adpn_prob.matrix_settings, parameters_eg)
        A_total = scipy.sparse.coo_matrix((a[0], (a[1], a[2])), shape=(adpn_prob.global_settings.n_dofs_per_eg, adpn_prob.global_settings.n_dofs_per_eg)).tocsr()
        b_total = scipy.sparse.coo_matrix((a[3], (a[4], np.zeros_like(a[4]))), shape=(adpn_prob.global_settings.n_dofs_per_eg, 1)).tocsr()
        u_base  = sp.linalg.spsolve(A_total, b_total.todense()[:,0])
        return edge_flux(u_base, adpn_prob.global_settings)
    
    parameters_eg = {
        'sigma_t_i'       : adpn_prob.jax_sigma_t[:, 0],
        'sigma_s_k_i_gg'  : adpn_prob.jax_sigma_s[:, :, 0, 0],
        'h_i'             : h_i_default,
        'q_i_k_j'         : adpn_prob.jax_q_i_k_j[:, :, :, 0]
    }
    base = compute_with_param(parameters_eg)
    h_i_perturbed = np.copy(h_i_default)
    h_i_perturbed[index] += stepsize
    parameters_eg['h_i'] = h_i_perturbed    
    edge = compute_with_param(parameters_eg)
    print((edge - base) / stepsize)
    
index = 0 
print(dfdh[index])
finite_difference_h(index)


def finite_difference_sigma_s(index,l_value,  stepsize = 1e-4):
    

    def compute_with_param(parameters_eg):
        a = jax_pn.ADPN.total_matrix_assembly_vacuum_bcs_single_g_jit(adpn_prob.global_settings, adpn_prob.matrix_settings, parameters_eg)
        A_total = scipy.sparse.coo_matrix((a[0], (a[1], a[2])), shape=(adpn_prob.global_settings.n_dofs_per_eg, adpn_prob.global_settings.n_dofs_per_eg)).tocsr()
        b_total = scipy.sparse.coo_matrix((a[3], (a[4], np.zeros_like(a[4]))), shape=(adpn_prob.global_settings.n_dofs_per_eg, 1)).tocsr()
        u_base  = sp.linalg.spsolve(A_total, b_total.todense()[:,0])
        return edge_flux(u_base, adpn_prob.global_settings)
    
    sigma_s_default = adpn_prob.jax_sigma_s[:, :, 0, 0]
    parameters_eg = {
        'sigma_t_i'       : adpn_prob.jax_sigma_t[:, 0],
        'sigma_s_k_i_gg'  : sigma_s_default,
        'h_i'             : adpn_prob.jax_h_i,
        'q_i_k_j'         : adpn_prob.jax_q_i_k_j[:, :, :, 0]
    }

    base = compute_with_param(parameters_eg)

    sigma_s_perturbed = np.copy(sigma_s_default)
    sigma_s_perturbed[index, l_value] += stepsize #* sigma_s_perturbed[index, l_value] 
    
    parameters_eg['sigma_s_k_i_gg'] = sigma_s_perturbed
    edge = compute_with_param(parameters_eg)
    print((edge - base) / stepsize)
    
index = - 1 
l_value = 1
print(dfdsigma_s[index, l_value])
finite_difference_sigma_s(index, l_value, stepsize=1e-6)


-1.7543179520307546e-05
-5.588058349381697e-06
0.10017789537042807
0.010538375616597762


### Implementation

We have 2 objective functions:

1. The flux at $x=5$
2. The flux at $x=8$

In [None]:
from jax_pn.ADPN import GlobalSettings
def edge_flux(sol, settings : GlobalSettings):
    right_dof = settings.right_dof
    return sol[right_dof]

def five_flux(sol, settings : GlobalSettings, nodes):
    # assumes equal spacing of nodes
    n_nodes = settings.right_dof + 1
    val_node = ( (n_nodes -1) * 5) / 8
    val_node = int(val_node)
    return sol[val_node]

print(edge_flux(sol, adpn_prob.global_settings))
print(five_flux(sol, adpn_prob.global_settings, adpn_prob.nodes))

0.24351924850098738
1.2125756808954822


In [None]:
import jax
def residual(global_settings : GlobalSettings, total_dofs : int, matrix_settings, parameters, solution):
    data, rows, cols, bdata, brows = jax_pn.ADPN.total_matrix_assembly_vacuum_bcs_single_g_jit(global_settings, matrix_settings, parameters)
    indices = jnp.stack([rows, cols], axis=1)
    A_total = jax.experimental.sparse.BCOO((data, indices), shape=(total_dofs, total_dofs))    
    b_cols = jnp.zeros_like(brows)
    b_indices = jnp.stack([brows, b_cols], axis=1)    
    b_total = jax.experimental.sparse.BCOO((bdata, b_indices), shape=(total_dofs, 1))        

    return A_total @ solution - b_total.todense()[:,0]

solution = jnp.array(adpn_prob.solution)
parameters_eg = {
    'sigma_t_i'       : adpn_prob.jax_sigma_t[:, 0],
    'sigma_s_k_i_gg'  : adpn_prob.jax_sigma_s[:, :, 0, 0],
    'h_i'             : adpn_prob.jax_h_i,
    'q_i_k_j'         : adpn_prob.jax_q_i_k_j[:, :, :, 0]
}

residual_jit = jax.jit(residual, static_argnums=(0,1))

jac_res = jax.jacrev(residual_jit, argnums = (4))
adata, rows, cols, _, _ = jax_pn.ADPN.total_matrix_assembly_vacuum_bcs_single_g_jit(adpn_prob.global_settings, adpn_prob.matrix_settings, parameters = parameters_eg)
indices = jnp.stack([rows, cols], axis=1)
A_total                = jax.experimental.sparse.BCOO((adata, indices), shape=(adpn_prob.dofs_per_eg, adpn_prob.dofs_per_eg))    

In [None]:
A2 = jac_res(adpn_prob.global_settings, adpn_prob.dofs_per_eg, adpn_prob.matrix_settings, parameters_eg, solution)

print(jnp.max(A2 - A_total.todense()))

0.0


#### Direct

In [None]:

# dA/dtheta & db/dtheta
energy_group = 0
parameters_eg = {
    'sigma_t_i'       : adpn_prob.jax_sigma_t[:, energy_group],
    'sigma_s_k_i_gg'  : adpn_prob.jax_sigma_s[:, :, energy_group, energy_group],
    'h_i'             : adpn_prob.jax_h_i,
    'q_i_k_j'         : adpn_prob.jax_q_i_k_j[:, :, :, energy_group]
}
adata, rows, val, _, _ =jax_pn.ADPN.total_matrix_assembly_vacuum_bcs_single_g_jit(adpn_prob.global_settings, adpn_prob.matrix_settings, parameters = parameters_eg)

In [None]:
import jax
def generate_A(global_settings : GlobalSettings, total_dofs : int,  matrix_settings, sigma_s):
    energy_group = 0
    sigma_t = matrix_settings['sigma_t'][:, energy_group]
    h_i     = matrix_settings['h_i']
    q_i_k_j = matrix_settings['q_i_k_j'][:, :, :, energy_group]

    no_elements = global_settings.n_elements
    n_moments =  global_settings.n_moments    

    # Calculate split index for 5/8 and 3/8
    split_idx = int(no_elements * 5 / 8)

    # Create the array
    sigma_s_new = jnp.zeros((no_elements, n_moments))
    sigma_s_new = sigma_s_new.at[split_idx:, 0].set(sigma_s)

    parameters_eg = {
        'sigma_t_i'       : sigma_t,
        'sigma_s_k_i_gg'  : sigma_s_new,
        'h_i'             : h_i,
        'q_i_k_j'         : q_i_k_j
    }
    data, rows, cols, bdata, brows = jax_pn.ADPN.total_matrix_assembly_vacuum_bcs_single_g_jit(global_settings, matrix_settings, parameters_eg)

    # BCOO expects indices as a single array of shape (nnz, ndim)
    indices = jnp.stack([rows, cols], axis=-1)
    A_total = jax.experimental.sparse.BCOO((data, indices), shape=(total_dofs, total_dofs))

    return A_total

m2 = adpn_prob.matrix_settings
m2["sigma_t"] = adpn_prob.jax_sigma_t
m2["h_i"] = adpn_prob.jax_h_i
m2["q_i_k_j"] = adpn_prob.jax_q_i_k_j
m2['total_dofs'] = adpn_prob.dofs_per_eg

generate_A(adpn_prob.global_settings, adpn_prob.dofs_per_eg, m2, 0.9)

BCOO(float64[1608, 1608], nse=34592)

In [None]:
import jax
def generate_A(global_settings : GlobalSettings, total_dofs : int,  matrix_settings, sigma_s):
    energy_group = 0
    sigma_t = matrix_settings['sigma_t'][:, energy_group]
    h_i     = matrix_settings['h_i']
    q_i_k_j = matrix_settings['q_i_k_j'][:, :, :, energy_group]

    no_elements = global_settings.n_elements
    n_moments =  global_settings.n_moments    

    # Calculate split index for 5/8 and 3/8
    split_idx = int(no_elements * 5 / 8)

# Create the array
    sigma_s_new = jnp.zeros((no_elements, n_moments))
    sigma_s_new = sigma_s_new.at[split_idx:, 0].set(sigma_s)

    parameters_eg = {
        'sigma_t_i'       : sigma_t,
        'sigma_s_k_i_gg'  : sigma_s_new,
        'h_i'             : h_i,
        'q_i_k_j'         : q_i_k_j
    }
    data, rows, cols, bdata, brows = jax_pn.ADPN.total_matrix_assembly_vacuum_bcs_single_g_jit(global_settings, matrix_settings, parameters_eg)
    indices = jnp.stack([rows, cols], axis=1)
    A_total = jax.experimental.sparse.BCOO((data,indices), shape=(total_dofs, total_dofs))
    
    return A_total

m2 = adpn_prob.matrix_settings
m2["sigma_t"] = adpn_prob.jax_sigma_t
m2["h_i"] = adpn_prob.jax_h_i
m2["q_i_k_j"] = adpn_prob.jax_q_i_k_j
m2['total_dofs'] = adpn_prob.dofs_per_eg

generate_A(adpn_prob.global_settings, adpn_prob.dofs_per_eg, m2, 0.9)

BCOO(float64[1608, 1608], nse=34592)