In [None]:
import numpy as np
import basix
import matplotlib.pyplot as plt
from basix import CellType, ElementFamily, LagrangeVariant, LatticeType
import jax_pn
from jax.experimental import sparse
import jax.numpy as jnp
import jax
%load_ext autoreload
%autoreload 2
import jax
#jax.config.update("jax_enable_x64", True)

# Adjoint method with Autodiff partial derivatives

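Let $\boldsymbol{R}(\boldsymbol{x}, \boldsymbol{\theta}) = \boldsymbol{0}$ be the discretized residual equations, with solution $\boldsymbol{x} \in \mathbb{R}^N$ and parameters $\boldsymbol{\theta}$, and let $f(\boldsymbol{x}, \boldsymbol{\theta})$ be a scalar objective. We define the adjoint variable $\boldsymbol{\lambda} \in \mathbb{R}^N$ as the solution to: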
$$
    \left( \frac{\partial \boldsymbol{R}}{\partial \boldsymbol{x}} \right)^T \boldsymbol{\lambda} = \left( \frac{\partial f}{\partial \boldsymbol{x}} \right)^T
$$


This allows us to write the objective sensitivity as:
$$
    \frac{d f}{d \boldsymbol{\theta}} 
    = -\boldsymbol{\lambda}^T \frac{\partial \boldsymbol{R}}{\partial \boldsymbol{\theta}} + \frac{\partial f}{\partial \boldsymbol{\theta}}
$$

In this notebook, the adjoint method is implemented using autodiff for the partial derivatives, and the resulting sensitivities are compared against finite differences.

In [None]:
# Degree-5 Lagrange finite element on an interval, with GLL-warped node placement.
lagrange    = basix.create_element(ElementFamily.P, CellType.interval, degree= 5, lagrange_variant= LagrangeVariant.gll_warped)
# Highest P_N order of the angular expansion (presumably moments 0..N_max — confirm in jax_pn).
N_max = 3
# Mesh density handed to ADPN_Problem.from_regions_per_cm (presumably elements per cm of region width).
elements_per_cm = 10
# One entry per material region, ordered along the slab.
# NOTE(review): tuple layout assumed to be
#   (width [cm], sigma_t per group, sigma_s array with shape (L, G, G), source per group)
# — inferred from the values only; confirm against from_regions_per_cm.
regions = [
    (2.0, [20.0], np.array([[[0.0]]]), [20.0]),
    (1.0, [1.0],  np.array([[[0.0]]]),  [0.0]),
    (2.0, [0.0],  np.array([[[0.0]]]),  [0.0]),
    (1.0, [1.0],  np.array([[[0.9]]]), [1.0]),
    (2.0, [1.0],  np.array([[[0.9]]]), [0.0]),
]

# Build the discretized problem; L_scat=0 presumably keeps only the l=0 (isotropic) scattering moment.
adpn_prob = jax_pn.ADPN.ADPN_Problem.from_regions_per_cm(regions, elements_per_cm, N_max, lagrange, L_scat=0)
# Forward solve. 'vacuum' selects the boundary-condition type; the meaning of the second argument is not visible here.
solution  = adpn_prob.Solve_Multigroup_System('vacuum',1)

#### Checking the residual

To check that the residual is implemented correctly, we compare it with the residual $\boldsymbol{R} = A\boldsymbol{x} - \boldsymbol{b}$ computed from the explicitly assembled matrix.

We also check that $\partial R / \partial x$ reproduces the original finite-element matrix (including the augmented boundary conditions).

In [None]:
from jax_pn.ADPN import residualPN_jit, GlobalSettings
# Design parameters as a dict pytree: JAX differentiates w.r.t. every array leaf.
# The keys are presumably the names read by the jax_pn assembly/residual routines — keep them in sync.
parameters_eg = {
    'sigma_t_i'       : adpn_prob.jax_sigma_t,
    'sigma_s_k_i_gg'  : adpn_prob.jax_sigma_s,
    'h_i'             : adpn_prob.jax_h_i,
    'q_i_k_j'         : adpn_prob.jax_q_i_k_j
}

def residual_naive(global_settings : GlobalSettings, matrix_settings, parameters, solution):
    """Reference residual R(x) = A x - b from the explicitly assembled system.

    Serves as a cross-check for ``residualPN_jit``: both should agree to
    round-off. Returns a dense vector of length ``global_settings.n_dofs_per_eg``.
    """
    assembled = jax_pn.ADPN.total_matrix_assembly_vacuum_bcs_single_g_jit(
        global_settings, matrix_settings, parameters
    )
    values, row_idx, col_idx, rhs_values, rhs_rows = assembled
    n = global_settings.n_dofs_per_eg

    # Sparse system matrix in COO form; repeated (row, col) pairs are summed when used.
    coords = jnp.stack([row_idx, col_idx], axis=1)
    system = jax.experimental.sparse.BCOO((values, coords), shape=(n, n))

    # Scatter-add the right-hand-side entries into a dense vector (duplicates are summed).
    rhs = jnp.zeros(n, dtype=rhs_values.dtype).at[rhs_rows].add(rhs_values)

    return system @ solution - rhs

# Print floats with 12 digits of precision so small discrepancies are visible.
np.set_printoptions(precision=12)
# Library residual vs. the reference A x - b built in residual_naive;
# the maximum absolute difference should be at round-off level.
res_old = residualPN_jit(adpn_prob.global_settings, adpn_prob.matrix_settings, parameters_eg, solution)
res_new = residual_naive(adpn_prob.global_settings, adpn_prob.matrix_settings, parameters_eg, solution)
print("Difference: ", np.max(np.abs(res_old - res_new)))

In [None]:
def A_total_f(global_settings : GlobalSettings, matrix_settings, parameters, solution):
    """Assemble the global system matrix A as a JAX ``BCOO`` sparse matrix.

    ``solution`` is accepted only so the signature mirrors the residual
    functions; it is not used. The right-hand-side data returned by the
    assembly routine is discarded.
    """
    values, row_idx, col_idx, _, _ = jax_pn.ADPN.total_matrix_assembly_vacuum_bcs_single_g_jit(
        global_settings, matrix_settings, parameters
    )
    n = global_settings.n_dofs_per_eg
    coords = jnp.stack([row_idx, col_idx], axis=1)
    return jax.experimental.sparse.BCOO((values, coords), shape=(n, n))
# dR/dx via forward-mode autodiff w.r.t. `solution` (argnums=(3,), so the result is a 1-tuple).
# global_settings is marked static, so it must be hashable.
jac_A    = jax.jit(jax.jacfwd(residualPN_jit, argnums = (3,)), static_argnums = 0)
# Dense A from the direct assembly, for comparison.
a_normal = A_total_f(adpn_prob.global_settings, adpn_prob.matrix_settings, parameters_eg, solution).todense()
Ares     = jac_A(adpn_prob.global_settings, adpn_prob.matrix_settings, parameters_eg, solution)[0]
# R(x) = A x - b is linear in x, so dR/dx should reproduce A; print the max abs deviation.
print(np.max(np.abs(a_normal - Ares)))

### Defining an objective function

We define a very simple objective function here: the flux at the edge of the domain.

In [None]:
def edge_flux(sol, settings : GlobalSettings):
    """Objective f(x): the solution value at the right-edge degree of freedom.

    Parameters
    ----------
    sol : array-like
        Global solution vector.
    settings : GlobalSettings
        Must provide ``right_dof``, the index of the DOF at the right edge.

    Returns
    -------
    The entry ``sol[settings.right_dof]``.
    """
    return sol[settings.right_dof]

This solves the adjoint system:

- First, $\partial f / \partial \boldsymbol{x}$ is computed using autodiff (this is a very simple differentiation, but we keep it as-is).
- $\partial \boldsymbol{R} / \partial \boldsymbol{x}$ is not computed with autodiff; instead, the sparse system matrix is constructed directly (it equals the Jacobian, as checked above).

Then, using scipy, the adjoint system is solved.

In [None]:
import scipy
import scipy.sparse as sp
# Import the linalg submodule explicitly: `import scipy.sparse` alone is not
# guaranteed to expose `sp.linalg` on every SciPy version.
import scipy.sparse.linalg

# Adjoint right-hand side (df/dx)^T via autodiff; only the solution argument is differentiated.
dfdx = jax.grad(edge_flux, argnums = 0)(solution, adpn_prob.global_settings)

# (dR/dx)^T = A^T: R(x) = A x - b is linear in x, so the assembled matrix is the Jacobian
# (checked numerically in the Jacobian-comparison cell).
at_jax = A_total_f(adpn_prob.global_settings, adpn_prob.matrix_settings, parameters_eg, solution).T

# Convert the JAX BCOO matrix to SciPy CSR. BCOO stores indices as (nnz, 2); SciPy wants (rows, cols).
# Duplicate entries are summed by tocsr(), matching BCOO semantics.
at_rows, at_cols = np.asarray(at_jax.indices).T
at_scipy = sp.coo_matrix((np.asarray(at_jax.data), (at_rows, at_cols)), shape=at_jax.shape).tocsr()

# Solve A^T λ = (df/dx)^T for the adjoint variable λ.
adjoint_var = sp.linalg.spsolve(at_scipy, np.asarray(dfdx))
# Sanity check: the adjoint solution should satisfy its own system to round-off.
print("Adjoint residual:", np.max(np.abs(at_scipy @ adjoint_var - np.asarray(dfdx))))

Computing the $\partial \boldsymbol{R} / \partial \boldsymbol{\theta}$ term is then straightforward with autodiff. Note that the parameters are passed as a dictionary; JAX handles this as well, returning a dictionary of partial derivatives with the same keys.

In [None]:
# Jacobian of the residual w.r.t. the parameter dict (argnums=2). The argument is a dict,
# so the result is a dict with the same keys; each entry has shape
# (n_residuals, *parameter_shape) per jacfwd's output convention.
drdtheta = jax.jit(jax.jacfwd(residualPN_jit, argnums = 2), static_argnums=0)
drdtheta_res = drdtheta(adpn_prob.global_settings, adpn_prob.matrix_settings, parameters_eg, solution)


In [None]:
# Summarise the design parameters by shape and size instead of dumping every array.
for name, value in parameters_eg.items():
    print(f"{name:16s} shape={np.shape(value)}  size={np.size(value)}")

Finally, the adjoint is used to obtain the desired total derivatives.

Here, we compute the derivatives of the edge flux with respect to the element sizes (note that it gets them all at the same time!) and the scattering matrices (which are 320 design variables!).

In [None]:
# Total derivatives df/dθ = -λ^T ∂R/∂θ. The ∂f/∂θ term vanishes because edge_flux
# depends only on the solution. Each einsum contracts the residual axis (i) of the
# parameter Jacobian with the adjoint vector, leaving the parameter's own axes.
dfdh = np.einsum("i, ik->k", adjoint_var, - np.array(drdtheta_res['h_i']))
dfdsigma_s = np.einsum("i, iklgp->klgp", adjoint_var, -np.array(drdtheta_res['sigma_s_k_i_gg']))
dfdq= np.einsum("i, ikljg->kljg", adjoint_var, -np.array(drdtheta_res['q_i_k_j']))

A comparison with a naive finite-difference approach shows that the adjoint method provides the correct derivatives. However, note that the finite-difference method would require $80 + 320 + 80 + 480 \times 4 = 2400$ function evaluations (one perturbed solve per parameter) to get the same number of derivatives!

In [None]:
def compute_with_param(parameters_eg):
    """Solve the single-group system for ``parameters_eg`` and return the edge flux.

    This is the black-box objective for the finite-difference checks. It
    assembles A and b with ``total_matrix_assembly_vacuum_bcs_single_g_jit``,
    solves A x = b with SciPy, and evaluates ``edge_flux`` on x.
    """
    # Local imports keep this helper independent of where SciPy was imported in the notebook.
    from scipy.sparse import coo_matrix
    from scipy.sparse.linalg import spsolve

    n_dofs = adpn_prob.global_settings.n_dofs_per_eg
    data, rows, cols, bdata, brows = jax_pn.ADPN.total_matrix_assembly_vacuum_bcs_single_g_jit(
        adpn_prob.global_settings, adpn_prob.matrix_settings, parameters_eg)

    A_total = coo_matrix((np.asarray(data), (np.asarray(rows), np.asarray(cols))),
                         shape=(n_dofs, n_dofs)).tocsr()
    # Dense 1-D RHS; duplicate row entries are summed (COO semantics). Using
    # toarray().ravel() avoids the 2-D np.matrix that .todense()[:, 0] returns.
    brows_np = np.asarray(brows)
    b_total = coo_matrix((np.asarray(bdata), (brows_np, np.zeros_like(brows_np))),
                         shape=(n_dofs, 1)).toarray().ravel()

    u_base = spsolve(A_total, b_total)
    return edge_flux(u_base, adpn_prob.global_settings)
    
def finite_difference_h(index, stepsize = 1e-4, central = False):
    """Finite-difference estimate of d(edge flux) / d h_i[index].

    Parameters
    ----------
    index : int
        Element whose size ``h_i[index]`` is perturbed (negative indices allowed).
    stepsize : float
        Perturbation added to ``h_i[index]``.
    central : bool, default False
        False keeps the original forward difference ``(f(h + dh) - f(h)) / dh``.
        True uses the second-order central difference
        ``(f(h + dh) - f(h - dh)) / (2 dh)``, which has smaller truncation error
        at the cost of one extra solve.
    """
    h_i_default = np.array(adpn_prob.jax_h_i)
    # Group-0 slices of the problem data: the check assembles a single-group system.
    parameters_eg = {
        'sigma_t_i'       : adpn_prob.jax_sigma_t[:, 0],
        'sigma_s_k_i_gg'  : adpn_prob.jax_sigma_s[:, :, 0, 0],
        'h_i'             : h_i_default,
        'q_i_k_j'         : adpn_prob.jax_q_i_k_j[:, :, :, 0]
    }

    def _flux_with_shift(delta):
        # Perturb a copy so the unperturbed baseline is never modified.
        h_i_perturbed = np.copy(h_i_default)
        h_i_perturbed[index] += delta
        return compute_with_param({**parameters_eg, 'h_i': h_i_perturbed})

    if central:
        return (_flux_with_shift(stepsize) - _flux_with_shift(-stepsize)) / (2 * stepsize)
    base = compute_with_param(parameters_eg)
    return (_flux_with_shift(stepsize) - base) / stepsize
    
# Adjoint ("Autodiff") vs. forward-difference derivative w.r.t. one element size h_i[index].
index = -5
print(f"H comparison: index {index}")
print("Autodiff",  dfdh[index])
print("Finite Diff", finite_difference_h(index))


def finite_difference_sigma_s(index, l_value, stepsize = 1e-4, central = False):
    """Finite-difference estimate of d(edge flux) / d sigma_s[index, l_value].

    Parameters
    ----------
    index, l_value : int
        Position in the group-0 ``sigma_s`` slice to perturb: ``sigma_s[index, l_value]``.
    stepsize : float
        Absolute perturbation added to that entry.
    central : bool, default False
        False keeps the original forward difference; True uses the
        second-order central difference ``(f(+dh) - f(-dh)) / (2 dh)``.
    """
    sigma_s_default = adpn_prob.jax_sigma_s[:, :, 0, 0]
    # Group-0 slices of the problem data: the check assembles a single-group system.
    parameters_eg = {
        'sigma_t_i'       : adpn_prob.jax_sigma_t[:, 0],
        'sigma_s_k_i_gg'  : sigma_s_default,
        'h_i'             : adpn_prob.jax_h_i,
        'q_i_k_j'         : adpn_prob.jax_q_i_k_j[:, :, :, 0]
    }

    def _flux_with_shift(delta):
        # np.copy yields a writable NumPy array even though the default is a JAX array.
        sigma_s_perturbed = np.copy(sigma_s_default)
        sigma_s_perturbed[index, l_value] += delta
        return compute_with_param({**parameters_eg, 'sigma_s_k_i_gg': sigma_s_perturbed})

    if central:
        return (_flux_with_shift(stepsize) - _flux_with_shift(-stepsize)) / (2 * stepsize)
    base = compute_with_param(parameters_eg)
    return (_flux_with_shift(stepsize) - base) / stepsize
    
# Adjoint ("Autodiff") vs. forward-difference derivative for one scattering entry sigma_s[index, l_value].
# dfdsigma_s[index, l_value] keeps the trailing (g, g') axes from the einsum, so it prints as a small array.
# NOTE(review): a larger step (1e-2) is used here, presumably to limit float32 round-off
# (x64 is left disabled at the top of the notebook) — confirm.
index = - 10 
l_value = 1
print(f"Sigma_s comparison: index {index}, l_value {l_value}")
print("Autodiff:" , dfdsigma_s[index, l_value])
print("Finite diff", finite_difference_sigma_s(index, l_value, stepsize=1e-2))

def finite_difference_q(index, l_value, local_dof, stepsize = 1e-4, central = False):
    """Finite-difference estimate of d(edge flux) / d q[index, l_value, local_dof].

    Parameters
    ----------
    index, l_value, local_dof : int
        Position in the group-0 source array ``q_i_k_j`` to perturb.
    stepsize : float
        Absolute perturbation added to that entry.
    central : bool, default False
        False keeps the original forward difference; True uses the
        second-order central difference ``(f(+dh) - f(-dh)) / (2 dh)``.
    """
    q_default = adpn_prob.jax_q_i_k_j[:, :, :, 0]
    # Group-0 slices of the problem data: the check assembles a single-group system.
    parameters_eg = {
        'sigma_t_i'       : adpn_prob.jax_sigma_t[:, 0],
        'sigma_s_k_i_gg'  : adpn_prob.jax_sigma_s[:, :, 0, 0],
        'h_i'             : adpn_prob.jax_h_i,
        'q_i_k_j'         : q_default
    }

    def _flux_with_shift(delta):
        # Perturb a writable NumPy copy of the source so the baseline stays intact.
        q_perturbed = np.copy(q_default)
        q_perturbed[index, l_value, local_dof] += delta
        return compute_with_param({**parameters_eg, 'q_i_k_j': q_perturbed})

    if central:
        return (_flux_with_shift(stepsize) - _flux_with_shift(-stepsize)) / (2 * stepsize)
    base = compute_with_param(parameters_eg)
    return (_flux_with_shift(stepsize) - base) / stepsize
    
# Adjoint ("Autodiff") vs. forward-difference derivative for one source entry q[index, l_value, local_dof].
# dfdq[index, l_value, local_dof] keeps the trailing group axis from the einsum.
index = - 1 
l_value = 1
local_dof = 0
print(f"Q comparison: index {index}, l_value {l_value}, local_dof {local_dof}")
print("Autodiff:" , dfdq[index, l_value, local_dof])
print("Finite diff", finite_difference_q(index, l_value, local_dof, stepsize=1e-2))

